#! /usr/bin/gawk -f
# Last edited on 2000-05-31 21:40:24 by stolfi

BEGIN {
  # "abort" stays negative while no error has been reported; the
  # error helpers set it to the exit status before calling exit.
  abort = -1;

  # Command-line synopsis.
  usage = ( ARGV[0] " \\\n" \
    " -f factor-table.gawk \\\n" \
    " -v rows=RWORDFILE \\\n" \
    " -v cols=CWORDFILE \\\n" \
    " [ -v counted=BOOL ] \\\n" \
    " [ -v digits=NNN ]" );

  # Purpose:
  #   Reads a sequence of word pairs from stdin and prints to stdout a
  #   table with the number of occurrences of each pair, followed by
  #   tables of row-relative and column-relative probabilities.
  #
  # Input:
  #   Each line of stdin should contain two words separated by a blank.
  #   If "counted" is true, each pair should be preceded by an integer
  #   count, as produced by "uniq -c".
  #
  # Options:
  #   The variables "rows" and "cols" should be filenames.  Each file
  #   should contain the list of words to be used as row and column
  #   labels, respectively.  Blank lines in the "rows" file generate
  #   rows of dashes.
  #
  #   Only pairs whose left member is in "rows" and whose right member
  #   is in "cols" are entered in the table.
  #
  # Global variables:
  #   nr         number of row words.
  #   rw         array indexed [1..nr] containing the row words in order.
  #   rsz        length of the longest row word.
  #
  #   nc         number of column words.
  #   cw         array indexed [1..nc] containing the col words in order.
  #   csz        length of the longest column word.
  #
  #   n          total word pair count.
  #   rn[x]      num occurrences of row word x on left side of pairs.
  #   cn[x]      num occurrences of col word x on right side of pairs.
  #   rcn[x,y]   num occurrences of pair "x y" where x is row and y is col.
  #
  #   log10      log(10), used to convert natural logs to base 10.

  log10 = log(10);

  # Load the row labels and report the widest one.
  split("", rw);
  split("", rn);
  nr = readwords("rows", rows, rw, rn);
  rsz = maxlength(rw);
  print "max row word length = " rsz > "/dev/stderr";

  # Load the column labels and report the widest one.
  split("", cw);
  split("", cn);
  nc = readwords("cols", cols, cw, cn);
  csz = maxlength(cw);
  print "max col word length = " csz > "/dev/stderr";

  # Defaults for the optional settings.
  if (digits == "" || digits < 1)
    { digits = 3; }
  if (counted == "")
    { counted = 0; }

  # Create a zero cell for every (row word, column word) combination.
  split("", rcn);
  for (x in rn)
    {
      for (y in cn)
        { rcn[x,y] = 0; }
    }
}

# Once an error has been flagged, stop processing input.
(abort >= 0) { exit abort; }

# Empty lines carry no data.
/^$/ { next; }

# Every non-empty line is one word pair, optionally preceded by a count.
/./ {
  # Pick the count k, the left word x and the right word y, and the
  # number of fields NFX the line must have.
  if (! counted)
    { NFX = 2; k = 1; x = $1; y = $2; }
  else
    { NFX = 3; k = $1; x = $2; y = $3; }

  if (NF != NFX)
    { file_error(("wrong number of fields = " NF)); }

  xok = (x in rn);
  yok = (y in cn);

  # Marginal totals count every pair whose word is a known label;
  # the cell itself is only bumped when both words are known.
  if (xok)
    { rn[x] += k;
      if (yok) { rcn[x,y] += k; }
    }
  if (yok)
    { cn[y] += k; }

  # The grand total includes pairs that fall outside the table.
  n += k;
  next;
}

END {
  # Propagate the status of an earlier error without printing anything.
  if (abort >= 0)
    { exit abort; }

  # Emit all the tables, framed by one blank line before and after.
  printf "\n";
  printfreqsprobs(n,   nr, rw, rsz, rn,  nc, cw, csz, cn,   rcn,  digits);
  printf "\n";
}

function file_error(msg)
{
  # Reports a problem with the current input line (FNR) on stderr,
  # records the failure in the global "abort", and terminates with
  # status 1.
  abort = 1;
  printf("count-diword-freqs: line %d: %s\n", FNR, msg) > "/dev/stderr";
  exit 1;
}

function arg_error(msg)
{
  # Reports a problem with the command-line settings on stderr,
  # records the failure in the global "abort", and terminates with
  # status 1.
  abort = 1;
  printf("count-diword-freqs: %s\n", msg) > "/dev/stderr";
  exit 1;
}

function max(x, y)
{
  # Returns x when x compares greater than y, and y otherwise
  # (so y is returned when the two are equal).
  if (x > y)
    { return x; }
  return y;
}

function readwords(filetype, file, fw, fn,    n, w, status)
{
  # Reads the lines of "file" and stores them in fw[1..n], one word per
  # line, in file order.  Also sets fn[w] = 0 for every line w read.
  # Blank lines are stored too, as empty words.
  #
  # Inputs:
  #   filetype = name of the option that supplied the file ("rows" or
  #              "cols"); used only in the error message.
  #   file     = name of the file to read.
  # Outputs:
  #   fw[1..n] = the words, in order.
  #   fn[w]    = 0 for each word w.
  # Returns the count n.
  # Calls arg_error (which terminates the program) if no file name was
  # given or if the file cannot be read.
  if (file == "")
    { arg_error(("\"" filetype "\" file not specified.\n")); }
  n = 0;
  # getline returns 1 for a record, 0 at end of file and -1 on error.
  # Keep the last status so that failure can be told apart from EOF.
  while ((status = (getline w < file)) > 0)
    { n++;
      fw[n] = w; fn[w] = 0;
    }
  # Test the getline status rather than the contents of ERRNO: ERRNO
  # is only meaningful after a failure, and its "no error" value is not
  # the string "0" in current gawk, so comparing it against "0" would
  # flag every successful read as an error.
  if (status < 0) { arg_error((file ": " ERRNO)); }
  close(file);
  return n;
}

function maxlength(fw,   i, len, sz)
{
  # Returns the length of the longest string stored in the array fw,
  # or 0 when fw has no elements.
  sz = 0;
  for (i in fw)
    { len = length(fw[i]);
      if (len > sz) { sz = len; }
    }
  return sz;
}

function printfreqsprobs(t, nrows, rw, rsz, rt,  ncols, cw, csz, ct, rct,  digs, \
                p, rp, cp, rcp, cfmt)
{
  # Prints four tables for the word pairs, in this order: the raw
  # counts, the row probabilities, the column probabilities, and the
  # anomalies.  Consecutive tables are separated by a blank line.
  #
  # Inputs:
  #   t     = total totalorum.
  #   nrows = number of rows.
  #   rw    = row words, indexed [1..nrows].
  #   rsz   = max row word length.
  #   rt    = row totals, indexed by word.
  #   ncols = number of columns.
  #   cw    = column words, indexed [1..ncols].
  #   csz   = max col word length.
  #   ct    = column totals, indexed by word.
  #   rct   = word pair counts, indexed by [row word,column word].
  #   digs  = number of digits in an entry of the raw count table.
  #
  # Locals: p receives the corner value returned by each compute*
  # function; rp, cp and rcp are scratch arrays that those functions
  # fill with row, column and cell values.

  # Table 1: the counts themselves, "digs" characters per entry.
  cfmt = "%" digs "d";
  ptable("Raw pair counts", t, nrows, rw, rsz, rt,  ncols, cw, csz, ct, rct, digs, cfmt, 0);

  # Make sure the scratch variables are (empty) arrays.
  split("", rp);  split("", cp);  split("", rcp);

  # Table 2: distribution along each row, two characters per entry.
  p = computerowdistrs(t, nrows, rw, rt,  ncols, cw, ct, rct,  rp, cp, rcp);
  printf "\n";
  ptable("Row probabilities", p, nrows, rw, rsz, rp,  ncols, cw, csz, cp, rcp, 2, "%2d", 0);

  # Table 3: distribution along each column, two characters per entry.
  p = computecoldistrs(t, nrows, rw, rt,  ncols, cw, ct, rct,  rp, cp, rcp);
  printf "\n";
  ptable("Col probabilities", p, nrows, rw, rsz, rp,  ncols, cw, csz, cp, rcp, 2, "%2d", 0);

  # Table 4: signed anomalies, three characters per entry.
  p = computeanomalies(t, nrows, rw, rt,  ncols, cw, ct, rct,  rp, cp, rcp);
  printf "\n";
  ptable("Anomalies", p, nrows, rw, rsz, rp,  ncols, cw, csz, cp, rcp, 3, "%+2d", 0);
}

function printfreqs(t, nrows, rw, rsz, rt,  ncols, cw, csz, ct, rct,  digs, \
                    cfmt)
{
  # Prints only the table of raw counts for the word pairs.
  #
  # Inputs:
  #   t     = total totalorum.
  #   nrows = number of rows.
  #   rw    = row words, indexed [1..nrows].
  #   rsz   = max row word length.
  #   rt    = row totals, indexed by word.
  #   ncols = number of columns.
  #   cw    = column words, indexed [1..ncols].
  #   csz   = max col word length.
  #   ct    = column totals, indexed by word.
  #   rct   = word pair counts, indexed by [row word,column word].
  #   digs  = number of digits in a table entry.

  # Each entry is a decimal integer, "digs" characters wide.
  cfmt = "%" digs "d";
  ptable("Raw pair counts", t, nrows, rw, rsz, rt,  ncols, cw, csz, ct, rct, digs, cfmt, 0);
}


function computecoldistrs(t, nrows, rw, rt, ncols, cw, ct, rct, \
                           rp, cp, rcp, \
                           r, c, rword, cword, sum, tot, temp )
{
  # Computes the distribution of the counts along each column, from
  # the raw counts and the column totals.
  #
  # Inputs:
  #   t         total totalorum (not used by this function).
  #   nrows     number of rows.
  #   rw[r]     row words, indexed [1..nrows]; empty words are skipped.
  #   rt[x]     row totals, indexed by word (not used by this function).
  #   ncols     number of columns.
  #   cw[c]     column words, indexed [1..ncols].
  #   ct[y]     column totals, indexed by word.
  #   rct[x,y]  word pair count, indexed by [row word,column word].
  #
  # Outputs (integers, scaled to [0..99]; previous contents discarded):
  #   rcp[x,y] = rct[x,y]/ct[y].
  #   cp[y]    = sum{rct[x,y] : x in rw}/ct[y].
  #   rp[x]    = sum{rct[x,y] : y in cw}/sum{ct[y] : y in cw}.
  #   returned result = sum{rct[x,y] : x in rw and y in cw}/sum{ct[y] : y in cw}.
  #
  # A zero denominator is replaced by 1.
  #
  # If rct[x,y] is the count of consecutive word pairs (x,y), then
  #   rcp[x,y] = probability of y being preceded by x.
  #   cp[y]    = probability of y being preceded by a word of rw.
  #   rp[x]    = probability of a word of cw being preceded by x.
  #   returned result = probability of a word in cw being preceded by a word in rw.

  split("", rp);  split("", cp);  split("", rcp);

  # Every row label, separator rows included, gets a zero accumulator.
  for (r = 1; r <= nrows; r++)
    { rp[rw[r]] = 0; }

  # Pass 1: walk the table column by column.  rp[] collects raw row
  # sums for now; it is rescaled in pass 2.
  tot = 0;
  for (c = 1; c <= ncols; c++)
    {
      cword = cw[c];
      sum = 0;
      for (r = 1; r <= nrows; r++)
        {
          rword = rw[r];
          if (rword == "") { continue; }
          temp = rct[rword,cword];
          sum += temp;
          rp[rword] += temp;
          rcp[rword, cword] = int(99.9999 * temp / max(1, ct[cword]));
        }
      cp[cword] = int(99.9999 * sum / max(1, ct[cword]));
      tot += ct[cword];
    }

  # Pass 2: turn the raw row sums into fractions of the sum of all
  # column totals, and accumulate the corner value.
  sum = 0;
  for (r = 1; r <= nrows; r++)
    {
      rword = rw[r];
      temp = rp[rword];
      sum += temp;
      rp[rword] = int(99.9999 * temp / max(1, tot));
    }
  return int(99.9999 * sum / max(1, tot));
}

function computerowdistrs(t, nrows, rw, rt, ncols, cw, ct, rct, \
                           rp, cp, rcp, \
                           r, c, rword, cword, sum, tot, temp )
{
  # Computes probability distribution along each row,
  # given raw counts and totals.
  # Inputs:
  #   t     = total totalorum (not used by this function).
  #   nrows = number of rows.
  #   rw[r]   row words, indexed [1..nrows]; empty words are skipped.
  #   rt[x]   row totals, indexed by word.
  #   ncols = number of columns.
  #   cw[c]   column words, indexed [1..ncols].
  #   ct[y]   column totals, indexed by word (not used by this function).
  #   rct[x,y]  word pair count, indexed by [row word,column word].
  # Outputs (integers scaled to [0..99]; previous contents discarded):
  #   rcp[x,y] = rct[x,y]/rt[x].
  #   rp[x]    = sum{rct[x,y] : y in cw}/rt[x].
  #   cp[y]    = sum{rct[x,y] : x in rw}/sum{rt[x] : x in rw}.
  #   returned result = sum{rct[x,y] : x in rw and y in cw}/sum{rt[x] : x in rw}.
  # A zero denominator is replaced by 1.
  # If rct[x,y] is the count of consecutive word pairs (x,y), then
  #   rcp[x,y] = probability of x being followed by y.
  #   rp[x]    = probability of x being followed by a word of cw.
  #   cp[y]    = probability of a word of rw being followed by y.
  #   returned result = probability of a word in rw being followed by a word in cw.
  split("", rp);
  split("", cp);
  split("", rcp);
  # Zero the accumulator of every column.  (There must be no ";" after
  # the loop header: with one, the loop body is empty and the braces
  # run once with c = ncols+1, creating cw[ncols+1] in the caller's
  # array and a spurious cp[""] entry.)
  for(c=1; c<= ncols; c++)
    { cword = cw[c]; cp[cword] = 0; }
  # First pass: row by row.  cp[] holds raw column sums for now.
  tot = 0;
  for (r=1; r<= nrows; r++)
    { rword = rw[r];
      if (rword != "") 
        { sum = 0;
          for(c=1; c<= ncols; c++)
            { cword = cw[c];
              temp = rct[rword,cword];
              rcp[rword, cword] = int(99.9999*(temp+0.0)/max(1,rt[rword]));
              sum += temp;
              cp[cword] += temp;
            }
          rp[rword] = int(99.9999*(sum+0.0)/max(1,rt[rword]));
          tot += rt[rword];
        }
    }
  # Second pass: rescale the raw column sums by the sum of the row
  # totals, and accumulate the corner value.
  sum = 0;
  for(c=1; c<= ncols; c++)
    { cword = cw[c];
      temp = cp[cword];
      cp[cword] = int(99.9999*(temp+0.0)/max(1,tot));
      sum += temp;
    }
  return int(99.9999*(sum+0.0)/max(1,tot));
}

function computeanomalies(t, nrows, rw, rt, ncols, cw, ct, rct, \
                           rp, cp, rcp, \
                           r, c, rword, cword, temp, obs, pred, eps )
{
  # Computes `log anomaly factor' for each element (part of the element's
  # probability that cannot be explained as a consequence of 
  # row and column totals).
  # Inputs:
  #   t     = total totalorum.
  #   nrows = number of rows.
  #   rw[r]   row words, indexed [1..nrows]; empty words get no cells.
  #   rt[x]   row totals, indexed by word (not used by this function).
  #   ncols = number of columns.
  #   cw[c]   column words, indexed [1..ncols].
  #   ct[y]   column totals, indexed by word (not used by this function).
  #   rct[x,y]  word pair count, indexed by [row word,column word].
  # Outputs (integers, scaled by 10, clipped to [-99..+99]):
  #   cp[y]    = log_10 of basic factor for column y.
  #   rp[x]    = log_10 of basic factor for row x.
  #   rcp[x,y] = log_10(rct[x,y]) - rp[x] - cp[y].
  #   returned result = log_10(t).
  #
  # The basic factors come from fctb_factor, which is not defined in
  # this file (presumably it lives in factor-table.gawk, named in the
  # usage string).  NOTE(review): it is assumed to fill rp[] and cp[]
  # with row and column factors whose product predicts rct[x,y];
  # confirm against its definition.
  
  eps = 1.0;     # noise level.
  split("", rp);
  split("", cp);
  split("", rcp);
  for(r=1; r<= nrows; r++) { rword = rw[r]; rp[rword] = 0; }
  for(c=1; c<= ncols; c++) { cword = cw[c]; cp[cword] = 0; }
  fctb_factor(rct,rp,cp,eps);
  
  # Cell anomalies: log of (noise-regularized observed count)/(predicted).
  for (c=1; c<= ncols; c++)
    { cword = cw[c];
      for(r=1; r<= nrows; r++)
        { rword = rw[r];
          if (rword != "") 
            { obs = rct[rword,cword];
              pred = rp[rword]*cp[cword];
              if (pred > 0)
                { temp = clippedlog10(sqrt(obs*obs + eps*eps)/pred, 10.0); }
              else
                { # A zero prediction would be a fatal division by zero.
                  # The numerator is at least eps > 0, so the ratio is
                  # unbounded: saturate at the upper clip value.
                  temp = 10.0;
                }
              rcp[rword, cword] = int(9.99999*(temp+0.0));
            }
        }
    }
  # Replace each row factor by its scaled log.
  for(r=1; r<= nrows; r++) 
    { rword = rw[r];
      temp = clippedlog10(rp[rword], 10.0);
      rp[rword] = int(9.99999*(temp+0.0));
    }
  # Replace each column factor by its scaled log.  This uses the
  # factor cp[y], as the contract above states and as the row loop
  # does, not the raw column total ct[y].
  for (c=1; c<= ncols; c++)
    { cword = cw[c];
      temp = clippedlog10(cp[cword], 10.0);
      cp[cword] = int(9.99999*(temp+0.0));
    }
  temp = clippedlog10(t, 10.0);
  return int(9.99999*(temp+0.0));
}

function clippedlog10(x,m,   y)
{
  # Returns the base-10 logarithm of x, limited to the range [-m..+m].
  # A zero x yields -m.  Relies on the global "log10" holding log(10).
  if (x == 0)
    { return -m; }
  y = log(x) / log10;
  # Apply the upper limit first, then the lower one.
  y = (y >= +m ? +m : y);
  y = (y <= -m ? -m : y);
  return y;
}

function ptable(title, t, nrows, rw, rsz, rt,  ncols, cw, csz, ct,  rct, esz, fmt, null,  \
                rword, cword, x, y, i,j, \
                efmt, tfmt, rfmt, temp)
{
  # Prints one table of word pair data to stdout: a title, the column
  # headers, one line per row word and a final line of column totals.
  #
  # Inputs:
  #   title = table title.
  #   t     = total totalorum (printed in the corner of the totals line).
  #   nrows = number of rows.
  #   rw    = row words, indexed [1..nrows]; an empty word produces a
  #           line of dashes instead of a data line.
  #   rsz   = max row word length (values below 3 are treated as 3).
  #   rt    = row totals, indexed by word.
  #   ncols = number of columns.
  #   cw    = column words, indexed [1..ncols].
  #   csz   = max col word length (values below 3 are treated as 3).
  #   ct    = column totals, indexed by word.
  #   rct   = word pair data, indexed by [row word,column word].
  #   esz   = table entry width (excluding separating whitespace).
  #   fmt   = table entry format (should print at most esz chars).
  #   null  = null value; a body entry equal to it prints as ".".
  #
  # In the totals line the "." replaces entries equal to 0, whatever
  # the value of "null".

  printtitle(title);

  # Both widths must leave room for the three-letter label "TOT".
  if (rsz < 3) { rsz = 3; }
  if (csz < 3) { csz = 3; }

  # Header: the column words, between two rules.
  printheaderdashes(rsz, ncols, cw, csz, esz);
  printcolheaders(rsz, ncols, cw, csz, esz);
  printheaderdashes(rsz, ncols, cw, csz, esz);

  # Field formats: row label (left-justified), totals column, data cell.
  rfmt = ("%-" rsz "s");
  tfmt = ("%" (esz + 3) "s");
  efmt = ("%" (esz + 1) "s");

  # Body: one line per row word.
  for (i = 1; i <= nrows; i++)
    {
      rword = rw[i];
      if (rword == "")
        { # Separator row.
          printheaderdashes(rsz, ncols, cw, csz, esz);
          continue;
        }
      printf rfmt, rword;
      printf tfmt, sprintf(fmt, rt[rword]);
      for (j = 1; j <= ncols; j++)
        {
          temp = rct[rword, cw[j]];
          printf efmt, (temp == null ? "." : sprintf(fmt, temp));
        }
      printf "\n";
    }

  printheaderdashes(rsz, ncols, cw, csz, esz);

  # Footer: the grand total followed by the column totals.
  printf rfmt, "TOT";
  printf tfmt, sprintf(fmt, t);
  for (j = 1; j <= ncols; j++)
    {
      temp = ct[cw[j]];
      printf efmt, (temp == 0 ? "." : sprintf(fmt, temp));
    }
  printf "\n";
}

function printcolheaders(rsz, ncols, cw, csz, esz,    i, j, len, ctfmt)
{
  # Prints the column headers of a table.  Each column word is written
  # vertically, one letter per output line, bottom-aligned, so the
  # header occupies csz lines.  The totals column is labelled "TOT".
  #
  # rsz = maximum row word length (values below 3 are treated as 3).
  # ncols = number of columns.
  # cw = list of column words, indexed [1..ncols].
  # csz = maximum column word length (values below 3 are treated as 3).
  # esz = nominal count width.
  
  ctfmt = ("%" (esz + 1) "s");
  # Temporarily store the totals label as column 0, so that the loop
  # below can treat it like any other column.
  cw[0] = "TOT"; 
  if (csz < 3) { csz = 3; }
  if (rsz < 3) { rsz = 3; }
  
  # Line i shows, for every word, its i-th letter counted from the
  # end, or a blank if the word is shorter than i letters.
  for(i=csz;i>0;i--)
    { 
      for(j=1;j<=rsz;j++) printf " ";
      printf "  ";
      for(j=0;j<=ncols;j++)
        { len = length(cw[j]);
          printf ctfmt, (i > len ? " " : substr(cw[j], len-i+1, 1));
        }
      printf "\n";
    }
  # Remove the temporary label, so that the caller's array is again
  # indexed [1..ncols] only.
  delete cw[0];
}

function printheaderdashes(rsz, ncols, cw, csz, esz,    i, line, cell)
{
  # Prints one horizontal rule of the table: a run of dashes under the
  # row labels, a run under the totals column, and a run under each of
  # the ncols data columns, separated by single blanks.
  #
  # rsz = maximum row word length (values below 3 are treated as 3).
  # ncols = number of columns.
  # cw = list of column words (not used here).
  # csz = maximum word length (not used here).
  # esz = nominal count width.

  if (rsz < 3) { rsz = 3; }

  # Rule under the row labels.
  line = "";
  for (i = 1; i <= rsz; i++) { line = line "-"; }

  # Rule under the totals column, which is two characters wider than
  # a data column.
  line = line " ";
  for (i = 0; i < esz + 2; i++) { line = line "-"; }

  # Rule under one data column, repeated for every column.
  cell = "";
  for (i = 0; i < esz; i++) { cell = cell "-"; }
  for (i = 1; i <= ncols; i++) { line = line " " cell; }

  printf "%s\n", line;
}

function printtitle(title, rsz,   i)
{
  # Prints the table title followed by a colon, with a blank line
  # before and after it.  (rsz and i are not used.)
  printf "\n%s:\n\n", title;
}