#! /usr/bin/gawk -f
# Last edited on 1998-07-27 05:50:56 by stolfi

BEGIN { 
  # Command-line synopsis; written to stderr when a required option is missing.
  usage = ( \
      "tabulate-triple-counts \\\n" \
      "  -v rows=RWORDFILE \\\n" \
      "  -v cols=CWORDFILE \\\n" \
      "  [ -v rowFreqs=BOOL ] \\\n" \
      "  [ -v colFreqs=BOOL ] \\\n" \
      "  [ -v freqs=BOOL ] \\\n" \
      "  [ -v counts=BOOL ] \\\n" \
      "  [ -v digits=NNN ] \\\n" \
      "  < TRCOUNTS " \
    );

  # The TRCOUNTS records should have the format COUNTS KEY X Y
  # where COUNTS is an integer and KEY,X,Y are non-empty strings.
  # The file must be sorted by KEY.

  # For each value of KEY, prints to stdout one or more tables,
  # giving counts and/or probabilities for each indexed pair X,Y. 
  #
  # If "counts" is true, prints absolute counts per pair.
  # If "freqs" is true, prints frequencies relative to table totals.
  # If "rowFreqs" is true, prints frequencies relative to row totals.
  # If "colFreqs" is true, prints frequencies relative to col totals.
  # At least one of these four options must be given.
  # 
  # The variables "rows" and "cols" should be filenames. Each file
  # should contain the list of strings to be used as row and column
  # labels, respectively.  The script only prints separate statistics
  # for pairs where the left member is in "rows" and the right member
  # is in "cols".  Other strings are mapped to "ETC". 
  #
  # "digits" is the entry width of the raw-count table; it defaults
  # to 3 when missing or less than 1.

  # Global variables:
  #   nr         number of row strings (including the final "ETC")
  #   rw         array indexed [1..nr] containing the row strings in order.
  #   rsz        length of the longest row string.
  #   
  #   nc         number of column strings (including the final "ETC")
  #   cw         array indexed [1..nc] containing the col strings in order.
  #   csz        length of the longest column string.
  #
  #   curkey     KEY of the table being accumulated ("" before the first record).
  #   abort      -1 normally; set to the exit status by error().
  #
  #   n          total string pair count
  #   rn[x]      num occurrences of row string x on left side of pairs.
  #   cn[y]      num occurrences of col string y on right side of pairs.
  #   rcn[x,y]   num occurrences of pair "x y" where x is row and y is col.

  # The rules and the END block test "abort >= 0" to find out whether
  # error() has already been called (the END block is still run after
  # an "exit" issued from BEGIN or from a rule).
  abort = -1;
  
  if ((counts == "") && (freqs == "") && (colFreqs == "") && (rowFreqs == "")) 
    { printf "usage: %s\n", usage > "/dev/stderr";
      error(("must specify \"[row|col]Freqs\" or  \"freqs\" or \"counts\""));
    }

  # readstrings() reports the specific missing option below; show the
  # synopsis first so that the user can see what is expected.
  if ((rows == "") || (cols == "")) 
    { printf "usage: %s\n", usage > "/dev/stderr"; }
  
  split("", rw);
  nr = readstrings("rows", rows, rw);
  rsz = maxlength(rw);
  print ("max row string length = " rsz) > "/dev/stderr";
  
  split("", cw);
  nc = readstrings("cols", cols, cw);
  csz = maxlength(cw);
  print ("max col string length = " csz) > "/dev/stderr";
  
  if ((digits == "") || (digits < 1)) { digits = 3; }
  
  curkey = "";
}

function clearcounts(   r,c,x,y)
{
  # Starts a fresh table: zeroes the grand total "n" and rebuilds the
  # global arrays rn, cn and rcn so that they hold an explicit 0 for
  # every row label in rw[1..nr], every column label in cw[1..nc],
  # and every (row label, column label) pair.
  n = 0;
  split("", rn);
  split("", cn);
  split("", rcn);

  r = 1;
  while (r <= nr) { rn[rw[r]] = 0; r++; }

  c = 1;
  while (c <= nc) { cn[cw[c]] = 0; c++; }

  for (r = 1; r <= nr; r++)
    { x = rw[r];
      for (c = 1; c <= nc; c++)
        { y = cw[c];
          rcn[x,y] = 0;
        }
    }
}

/./ {
  # Data record "COUNTS KEY X Y": adds COUNTS to the table of KEY.
  # When KEY differs from the key of the previous record, the table
  # accumulated so far is printed and the counters are reset.
  if (abort >= 0) { exit abort; }
  if (NF != 4) { error(("line " NR ": bad record format")); }
  k = $1;
  
  # Concatenating "" turns the key into a plain string, so that the
  # test against curkey is always a textual comparison.  Input fields
  # that look like numbers are otherwise compared numerically, which
  # would merge distinct keys such as "10" and "10.0" into one table.
  key = ($2 "");
  if (key != curkey) { printslice(curkey); clearcounts(); curkey = key; }
    
  # Strings that are not row (resp. column) labels are pooled under "ETC".
  x = $3; if (!(x in rn)) { x = "ETC"; }
  y = $4; if (!(y in cn)) { y = "ETC"; }
  
  rcn[x,y] += k;
  rn[x] += k;
  cn[y] += k;
  n += k;
}

/^$/ {
  # Empty line: carries no data.  Terminate if error() was called,
  # otherwise go straight to the next record.
  if (abort < 0) { next; }
  exit abort;
}

END {
  # Prints the table of the last key, unless error() was called,
  # in which case the status saved in "abort" is returned instead.
  if (abort < 0)
    { printslice(curkey); }
  else
    { exit abort; }
}

function printslice(key)
{
  # Prints the tables for the counts currently held in the global
  # arrays, under a heading that names "key".  Does nothing when
  # "key" is the empty string.
  if (key == "") { return; }

  printf "Pairs with key = %s\n", key;
  printfreqsprobs(n,   nr, rw, rsz, rn,  nc, cw, csz, cn,   rcn,  digits);
  printf "\n";
}

function max(x, y)
{ 
  # Returns the larger of x and y (y when they compare equal).
  if (x > y) { return x; }
  return y;
}

function readstrings(filetype, file, fw,     n,w,status)
{
  # Reads the lines of the given file and stores them, verbatim and
  # in order, in fw[1..n-1]; then appends the default string "ETC"
  # as fw[n].  Returns the count n (which includes the "ETC" entry).
  #
  #   filetype = name of the option that supplied the file ("rows" or
  #              "cols"); used only in error messages.
  #   file     = name of the file to read.
  #   fw       = array that receives the strings.
  #
  # Calls error() if no file name was given or the file cannot be read.
  if (file == "") { error(("file \"" filetype "\" not specified")); }
  n = 0;
  # getline returns 1 for each line read, 0 at end of file, and -1 when
  # the file cannot be opened or read; keep the status so that a failure
  # is not mistaken for an empty list.
  while ((status = (getline w < file)) > 0)
    { n++;
      fw[n] = w;
    }
  if (status < 0) 
    { error(("cannot read \"" filetype "\" file \"" file "\"")); }
  close(file);
  n++; fw[n] = "ETC";
  return n;
}

function maxlength(fw,   i,len,sz)
{
  # Returns the length of the longest string stored in the array fw,
  # or 0 if the array has no elements.
  sz = 0;
  for (i in fw)
    { len = length(fw[i]);
      if (len > sz) { sz = len; }
    }
  return sz;
}

function computecoldistrs(t, nrows, rw, rt, ncols, cw, ct, rct, \
                           rp, cp, rcp, rce, \
                           r, c, rlab, clab, acc, cnt, den )
{
  # Computes, for each column, the distribution of the counts along
  # that column.
  # Inputs:
  #   t        = grand total of the table.
  #   nrows    = number of rows.
  #   rw[r]    = row strings, indexed [1..nrows].
  #   rt[x]    = row totals, indexed by row string.
  #   ncols    = number of columns.
  #   cw[c]    = column strings, indexed [1..ncols].
  #   ct[y]    = column totals, indexed by column string.
  #   rct[x,y] = pair counts, indexed by [row string, column string].
  # Outputs (integers, scaled by 99.9999 and truncated):
  #   rcp[x,y] = (rct[x,y] + 1)/(ct[y] + nrows), the biased fraction
  #              of column y that falls in row x.
  #   rce[x,y] = 1/(ct[y] + nrows), the value of an empty cell.
  #   cp[y]    = 99 (the whole column).
  #   rp[x]    = (rt[x] + 1)/(t + nrows), the biased fraction of the
  #              grand total that falls in row x.
  #   returned result = 99 (the whole table).
  # The "+ 1" bias adds one to every count, hence "nrows" to every
  # denominator.  Calls error() if the cells of a column do not add
  # up to ct[y], or the row totals do not add up to t.
  c = 1;
  while (c <= ncols)
    { clab = cw[c];
      den = ct[clab] + nrows;
      acc = 0;
      r = 1;
      while (r <= nrows)
        { rlab = rw[r];
          cnt = rct[rlab,clab] + 1;
          acc += cnt;
          rcp[rlab, clab] = int(99.9999*cnt/den);
          rce[rlab, clab] = int(99.9999/den);
          r++;
        }
      if (acc != den) { error(("col \"" clab "\": bad total")); }
      cp[clab] = int(99.9999);
      c++;
    }

  # Marginal distribution of the rows.
  den = t + nrows;
  acc = 0;
  r = 1;
  while (r <= nrows)
    { rlab = rw[r];
      cnt = rt[rlab] + 1;
      acc += cnt;
      rp[rlab] = int(99.9999*cnt/den);
      r++;
    }
  if (acc != den) { error(("bad table total")); }
  return int(99.9999);
}

function computerowdistrs(t, nrows, rw, rt, ncols, cw, ct, rct, \
                           rp, cp, rcp, rce, \
                           r, c, rlab, clab, acc, cnt, den )
{
  # Computes, for each row, the distribution of the counts along
  # that row.
  # Inputs:
  #   t        = grand total of the table.
  #   nrows    = number of rows.
  #   rw[r]    = row strings, indexed [1..nrows].
  #   rt[x]    = row totals, indexed by row string.
  #   ncols    = number of columns.
  #   cw[c]    = column strings, indexed [1..ncols].
  #   ct[y]    = column totals, indexed by column string.
  #   rct[x,y] = pair counts, indexed by [row string, column string].
  # Outputs (integers, scaled by 99.9999 and truncated):
  #   rcp[x,y] = (rct[x,y] + 1)/(rt[x] + ncols), the biased fraction
  #              of row x that falls in column y.
  #   rce[x,y] = 1/(rt[x] + ncols), the value of an empty cell.
  #   rp[x]    = 99 (the whole row).
  #   cp[y]    = (ct[y] + 1)/(t + ncols), the biased fraction of the
  #              grand total that falls in column y.
  #   returned result = 99 (the whole table).
  # The "+ 1" bias adds one to every count, hence "ncols" to every
  # denominator.  Calls error() if the cells of a row do not add up
  # to rt[x], or the column totals do not add up to t.
  r = 1;
  while (r <= nrows)
    { rlab = rw[r];
      den = rt[rlab] + ncols;
      acc = 0;
      c = 1;
      while (c <= ncols)
        { clab = cw[c];
          cnt = rct[rlab,clab] + 1;
          acc += cnt;
          rcp[rlab, clab] = int(99.9999*cnt/den);
          rce[rlab, clab] = int(99.9999/den);
          c++;
        }
      if (acc != den) { error(("row \"" rlab "\": bad total")); }
      rp[rlab] = int(99.9999);
      r++;
    }

  # Marginal distribution of the columns.
  den = t + ncols;
  acc = 0;
  c = 1;
  while (c <= ncols)
    { clab = cw[c];
      cnt = ct[clab] + 1;
      acc += cnt;
      cp[clab] = int(99.9999*cnt/den);
      c++;
    }
  if (acc != den) { error(("bad table total")); }
  return int(99.9999);
}

function computepairdistrs(t, nrows, rw, rt, ncols, cw, ct, rct, \
                           rp, cp, rcp, rce, \
                           r, c, rlab, clab, acc, cnt, ncells, den )
{
  # Computes the distribution of the counts over the entire table.
  # Inputs:
  #   t        = grand total of the table.
  #   nrows    = number of rows.
  #   rw[r]    = row strings, indexed [1..nrows].
  #   rt[x]    = row totals, indexed by row string.
  #   ncols    = number of columns.
  #   cw[c]    = column strings, indexed [1..ncols].
  #   ct[y]    = column totals, indexed by column string.
  #   rct[x,y] = pair counts, indexed by [row string, column string].
  # Outputs (integers, scaled by 999.999 and truncated):
  #   rcp[x,y] = (rct[x,y] + 1)/(t + nrows*ncols), the biased fraction
  #              of all pairs that are "x y".
  #   rce[x,y] = 1/(t + nrows*ncols), the value of an empty cell.
  #   rp[x]    = (rt[x] + 1)/(t + nrows), the biased fraction of pairs
  #              whose row member is x.
  #   cp[y]    = (ct[y] + 1)/(t + ncols), the biased fraction of pairs
  #              whose column member is y.
  #   returned result = 999 (the whole table).
  # The "+ 1" bias adds one to every count, hence the number of terms
  # to every denominator.  Calls error() if the cells, the column
  # totals or the row totals do not add up to t.
  ncells = ncols*nrows;
  den = t + ncells;
  acc = 0;
  r = 1;
  while (r <= nrows)
    { rlab = rw[r];
      c = 1;
      while (c <= ncols)
        { clab = cw[c];
          cnt = rct[rlab,clab] + 1;
          acc += cnt;
          rcp[rlab, clab] = int(999.999*cnt/den);
          rce[rlab, clab] = int(999.999/den);
          c++;
        }
      r++;
    }
  if (acc != den) { error(("bad total totalorum")); }

  # Marginal distribution of the columns.
  den = t + ncols;
  acc = 0;
  c = 1;
  while (c <= ncols)
    { clab = cw[c];
      cnt = ct[clab] + 1;
      acc += cnt;
      cp[clab] = int(999.999*cnt/den);
      c++;
    }
  if (acc != den) { error(("bad col totals")); }

  # Marginal distribution of the rows.
  den = t + nrows;
  acc = 0;
  r = 1;
  while (r <= nrows)
    { rlab = rw[r];
      cnt = rt[rlab] + 1;
      acc += cnt;
      rp[rlab] = int(999.999*cnt/den);
      r++;
    }
  if (acc != den) { error(("bad row totals")); }
  return int(999.999);
}

function printcolheaders(rsz, ncols, cw, csz, esz,    line,k,len,cellfmt,last,ch)
{
  # Prints the column labels vertically, one character per output
  # line, aligned at the bottom, followed by a "TOT" column.
  #   rsz   = width of the row-label margin.
  #   ncols = number of columns.
  #   cw    = column strings, indexed [1..ncols].
  #   csz   = number of header lines (max column string length).
  #   esz   = table entry width.
  
  cellfmt = (" %" esz "s");

  # The totals column is handled as a temporary extra label, which
  # is removed again before returning.
  last = ncols+1;
  cw[last] = "TOT";
  
  line = csz;
  while (line > 0)
    { 
      # Blank margin above the row labels.
      for (k = 1; k <= rsz; k++) { printf " "; }

      for (k = 1; k <= last; k++)
        { len = length(cw[k]);
          # The totals column is two characters wider than the others.
          if (k > ncols) { printf "  "; }
          # Labels shorter than "line" have no character at this height.
          if (line > len)
            { ch = " "; }
          else
            { ch = substr(cw[k], len-line+1, 1); }
          printf cellfmt, ch;
        }
      printf "\n";
      line--;
    }
  delete cw[last];
}

function printheaderdashes(rsz, ncols, cw, csz, esz,    k, bar, rule)
{
  # Prints a horizontal rule: "rsz" dashes for the row-label margin,
  # then, for each column, a blank and "esz" dashes; the final
  # (totals) column gets two extra dashes.
  #   rsz   = width of the row-label margin.
  #   ncols = number of columns.
  #   cw    = list of column strings (not used).
  #   csz   = maximum column string length (not used).
  #   esz   = table entry width.
  
  # Dashes under the row labels.
  rule = "";
  for (k = 1; k <= rsz; k++) { rule = rule "-"; }

  # Dashes under one table entry.
  bar = "";
  for (k = 0; k < esz; k++) { bar = bar "-"; }

  for (k = 1; k <= ncols+1; k++)
    { rule = rule " ";
      if (k > ncols) { rule = rule "--"; }
      rule = rule bar;
    }
  printf "%s\n", rule;
}

function printtitle(title, rsz,   i)
{
  # Writes the table title, followed by a colon, on a line of its own.
  # The "rsz" argument is accepted but not used.
  printf "%s:\n", title;
}

function ptable(title, t, nrows, rw, rsz, rt,  ncols, cw, csz, ct,  rct, rce, esz, fmt,  \
                rlab, r, c, \
                cellfmt, totfmt, labfmt, val)
{
  # Prints one table of string pair data, with a title, a header of
  # column labels, one line per row ending with the row total, and a
  # final line of column totals ending with the grand total.
  # Inputs:
  #   title = table title.
  #   t     = grand total, printed at the bottom right.
  #   nrows = number of rows.
  #   rw    = row strings, indexed [1..nrows].
  #   rsz   = max row string length.
  #   rt    = row totals, indexed by row string.
  #   ncols = number of columns.
  #   cw    = column strings, indexed [1..ncols].
  #   csz   = max col string length.
  #   ct    = column totals, indexed by column string.
  #   rct   = table entries, indexed by [row string, column string].
  #   rce   = per-entry threshold, indexed like rct: an entry that does
  #           not exceed its threshold is printed as ".".
  #   esz   = table entry width (excluding separating whitespace).
  #   fmt   = table entry format (should print at most esz chars).
  
  printtitle(title);
  printheaderdashes(rsz, ncols, cw, csz, esz);
  printcolheaders(rsz, ncols, cw, csz, esz);
  printheaderdashes(rsz, ncols, cw, csz, esz);
  
  labfmt = ("%-" rsz "s");           # row label, left-justified
  cellfmt = (" %" esz "s");          # ordinary entry
  totfmt = (" %" (esz + 2) "s");     # entry of the totals column
  
  # Table body, one line per row.
  r = 1;
  while (r <= nrows)
    { rlab = rw[r];
      printf labfmt, rlab;
      c = 1;
      while (c <= ncols)
        { val = rct[rlab, cw[c]];
          printf cellfmt, (val <= rce[rlab, cw[c]] ? "." : sprintf(fmt, val));
          c++;
        }
      printf totfmt, sprintf(fmt, rt[rlab]);
      printf "\n";
      r++;
    }

  printheaderdashes(rsz, ncols, cw, csz, esz);
  
  # Line of column totals; a zero total is shown as ".".
  printf labfmt, "TOT";
  c = 1;
  while (c <= ncols)
    { val = ct[cw[c]];
      printf cellfmt, (val == 0 ? "." : sprintf(fmt, val));
      c++;
    }
  printf totfmt, sprintf(fmt, t);
  printf "\n";
}

function printfreqsprobs(t, nrows, rw, rsz, rt,  ncols, cw, csz, ct, rct,  digs, \
                tot, rp, cp, rcp, rce)
{
  # Prints the tables selected by the global options "counts",
  # "rowFreqs", "colFreqs" and "freqs", in that order. Inputs:
  #  t     = grand total of the table.
  #  nrows = number of rows.
  #  rw    = row strings, indexed [1..nrows].
  #  rsz   = max row string length.
  #  rt    = row totals, indexed by row string.
  #  ncols = number of columns.
  #  cw    = column strings, indexed [1..ncols].
  #  csz   = max col string length.
  #  ct    = column totals, indexed by column string.
  #  rct   = string pair counts, indexed by [row string, column string].
  #  digs  = entry width of the raw-count table.
  #
  # The work arrays rp, cp, rcp and rce are emptied before each
  # table, so that no table sees values left over from another.

  if (counts)
    { # Raw counts: with an empty threshold array only the entries
      # that are not positive come out as ".".
      split("", rce);
      ptable("Raw pair counts", \
        t, nrows, rw, rsz, rt,  ncols, cw, csz, ct, rct, rce, digs, ("%" digs "d"));
    }

  if (rowFreqs)
    { split("", rp); split("", cp); split("", rcp); split("", rce);
      tot = computerowdistrs(t, nrows, rw, rt,  ncols, cw, ct, rct,  rp, cp, rcp, rce);
      printf "\n";
      ptable("Row probabilities (×99)", \
        tot, nrows, rw, rsz, rp,  ncols, cw, csz, cp, rcp, rce, 2, "%2d");
    }

  if (colFreqs)
    { split("", rp); split("", cp); split("", rcp); split("", rce);
      tot = computecoldistrs(t, nrows, rw, rt,  ncols, cw, ct, rct,  rp, cp, rcp, rce);
      printf "\n";
      ptable("Col probabilities (×99)", \
        tot, nrows, rw, rsz, rp,  ncols, cw, csz, cp, rcp, rce, 2, "%2d");
    }

  if (freqs)
    { split("", rp); split("", cp); split("", rcp); split("", rce);
      tot = computepairdistrs(t, nrows, rw, rt,  ncols, cw, ct, rct,  rp, cp, rcp, rce);
      printf "\n";
      ptable("Pair probabilities (×999)", \
        tot, nrows, rw, rsz, rp,  ncols, cw, csz, cp, rcp, rce, 3, "%3d");
    }
}

function printfreqs(t, nrows, rw, rsz, rt,  ncols, cw, csz, ct, rct, rce,  digs, \
                    cfmt)
{
  # Prints only the raw-count table for string pairs. Inputs:
  #  t     = grand total of the table.
  #  nrows = number of rows.
  #  rw    = row strings, indexed [1..nrows].
  #  rsz   = max row string length.
  #  rt    = row totals, indexed by row string.
  #  ncols = number of columns.
  #  cw    = column strings, indexed [1..ncols].
  #  csz   = max col string length.
  #  ct    = column totals, indexed by column string.
  #  rct   = string pair counts, indexed by [row string, column string].
  #  rce   = min pair counts, indexed like rct; entries that do not
  #          exceed them are printed as ".".
  #  digs  = entry width; counts are formatted as decimal integers
  #          of that width.
  ptable("Raw pair counts", \
    t, nrows, rw, rsz, rt,  ncols, cw, csz, ct, rct, rce, digs, ("%" digs "d"));
}

function error(msg)
{
  # Reports "msg" on stderr and terminates with status 1.  The status
  # is also saved in the global "abort", which the rules and the END
  # block test before doing any further work.
  abort = 1;
  printf "%s\n", msg > "/dev/stderr";
  exit abort;
}