#! /usr/bin/gawk -f
# Last edited on 2008-02-04 18:34:58 by stolfi

# Usage: "$0 -v rows=RWORDFILE  -v cols=CWORDFILE [ -v digits=NNN ]" 

# Reads a sequence of word pairs from stdin.
# Prints to stdout a table with the number of occurrences of each pair.
# Also prints tables of row-relative and column-relative probabilities.
# Each line of stdin should contain two words separated by a blank.
# The variables "rows" and "cols" should be filenames. Each file
# should contain the list of words to be used as row and
# column labels, respectively. 
# The script only counts pairs where the left member is in "rows"
# and the right member is in "cols".

# Global variables:
#   nr         number of row words
#   rw         array indexed [1..nr] containing the row words in order.
#   
#   nc         number of column words
#   cw         array indexed [1..nc] containing the col words in order.
#   
#   n          total word pair count
#   rn[x]      num occurrences of row word x on left side of pairs.
#   cn[x]      num occurrences of col word x on right side of pairs.
#   rcn[x,y]   num occurrences of pair "x y" where x is row and y is col.

function max(x, y)
{
  # Returns the larger of x and y; when neither is greater, returns y.
  if (x > y) return x
  return y
}

function readwords(filetype, file, fw, fn,    n, w, status)
{
  # Reads words (one per line) from the given file and stores them in fw[1..n].
  # Also sets fn[w] = 0 for all w in fw.
  # Returns the count n.
  #   filetype = label used in error messages (e.g. "rows" or "cols").
  #   file     = name of the file to read.
  # If the file name is empty or the file cannot be read, prints a
  # message to stderr, sets the global awk_is_stupid, and exits with status 1.
  if (file == "") 
    { printf "count-diword-freqs: \"%s\" file not specified.\n", \
        filetype \
        > "/dev/stderr";
      # The flag must be set BEFORE "exit": the END rule still runs after
      # an exit and tests this flag to skip the report.
      awk_is_stupid = 1
      exit(1);
    }
  n = 0
  while ((status = (getline w < file)) > 0)
    { n++ 
      fw[n] = w; fn[w] = 0
    }
  # Close the file so that a later call on the same name (e.g. rows == cols)
  # starts reading from the beginning instead of hitting end-of-file at once.
  close(file)
  if (status < 0)
    { # getline returns -1 on a read/open error; do not treat it as an empty list.
      printf "count-diword-freqs: cannot read \"%s\" file \"%s\".\n", \
        filetype, file \
        > "/dev/stderr";
      awk_is_stupid = 1
      exit(1);
    }
  return n
}

function maxlength(fw,   i, w, sz)
{
  # Returns the length of the longest string stored in fw,
  # or 0 when fw has no elements.
  sz = 0
  for (i in fw)
    { w = length(fw[i])
      if (w > sz) sz = w
    }
  return sz
}

BEGIN { 
  # Load the row labels; rn receives a zeroed counter for each row word.
  split("", rw); split("", rn)
  nr = readwords("rows", rows, rw, rn)
  rsz = maxlength(rw)
  print ("max row word length = " rsz) > "/dev/stderr"

  # Load the column labels; cn receives a zeroed counter for each col word.
  split("", cw); split("", cn)
  nc = readwords("cols", cols, cw, cn)
  csz = maxlength(cw)
  print ("max col word length = " csz) > "/dev/stderr"

  # Fall back to 3-digit table entries when "digits" is missing or below 1.
  if (digits == "" || digits < 1) digits = 3

  # Start every (row word, col word) pair counter at zero.
  split("", rcn)
  for (x in rn)
    { for (y in cn)
        { rcn[x,y] = 0 }
    }
}

function computecoldistrs(t, nrows, rw, rt, ncols, cw, ct, rct, \
                           rp, cp, rcp, \
                           i, j, rkey, ckey, acc, grand, cnt )
{
  # Computes the probability distribution along each column,
  # given raw counts and totals.
  # Inputs:
  #   t     = total totalorum (not used by this function).
  #   nrows = number of rows.
  #   rw[r]   row words, indexed [1..nrows].
  #   rt[x]   row totals, indexed by word (not used by this function).
  #   ncols = number of columns.
  #   cw[c]   column words, indexed [1..ncols].
  #   ct[y]   column totals, indexed by word.
  #   rct[x,y]  word pair count, indexed by [row word,column word].
  # Outputs (integers scaled to [0..99]; arrays are cleared first):
  #   rcp[x,y] = rct[x,y]/ct[y].
  #   cp[y]    = sum{rct[x,y] : x in rw}/ct[y].
  #   rp[x]    = sum{rct[x,y] : y in cw}/sum{ct[y] : y in cw}.
  #   returned result = sum{rct[x,y] : x in rw and y in cw}/sum{ct[y] : y in cw}.
  # Every denominator is clamped to at least 1 to avoid division by zero.
  # If rct[x,y] is the count of consecutive word pairs (x,y), then
  #   rcp[x,y] = probability of y being preceded by x.
  #   cp[y]    = probability of y being preceded by a word of rw.
  #   rp[x]    = probability of a word of cw being preceded by x.
  #   returned result = probability of a word in cw being preceded by a word in rw.
  split("", rp); split("", cp); split("", rcp)

  # Row accumulators start at zero.
  i = 1
  while (i <= nrows)
    { rp[rw[i]] = 0
      i++
    }

  # Sweep the columns, scaling each cell by its column total.
  grand = 0
  j = 1
  while (j <= ncols)
    { ckey = cw[j]
      acc = 0
      i = 1
      while (i <= nrows)
        { rkey = rw[i]
          cnt = rct[rkey,ckey]
          rcp[rkey, ckey] = int(99.9999*(cnt+0.0)/max(1,ct[ckey]))
          acc += cnt
          rp[rkey] += cnt
          i++
        }
      cp[ckey] = int(99.9999*(acc+0.0)/max(1,ct[ckey]))
      grand += ct[ckey]
      j++
    }

  # Convert the raw row sums to scaled fractions of the grand column total.
  acc = 0
  i = 1
  while (i <= nrows)
    { rkey = rw[i]
      cnt = rp[rkey]
      rp[rkey] = int(99.9999*(cnt+0.0)/max(1,grand))
      acc += cnt
      i++
    }
  return int(99.9999*(acc+0.0)/max(1,grand))
}

function computerowdistrs(t, nrows, rw, rt, ncols, cw, ct, rct, \
                           rp, cp, rcp, \
                           i, j, rkey, ckey, acc, grand, cnt )
{
  # Computes the probability distribution along each row,
  # given raw counts and totals.
  # Inputs:
  #   t     = total totalorum (not used by this function).
  #   nrows = number of rows.
  #   rw[r]   row words, indexed [1..nrows].
  #   rt[x]   row totals, indexed by word.
  #   ncols = number of columns.
  #   cw[c]   column words, indexed [1..ncols].
  #   ct[y]   column totals, indexed by word (not used by this function).
  #   rct[x,y]  word pair count, indexed by [row word,column word].
  # Outputs (integers scaled to [0..99]; arrays are cleared first):
  #   rcp[x,y] = rct[x,y]/rt[x].
  #   rp[x]    = sum{rct[x,y] : y in cw}/rt[x].
  #   cp[y]    = sum{rct[x,y] : x in rw}/sum{rt[x] : x in rw}.
  #   returned result = sum{rct[x,y] : x in rw and y in cw}/sum{rt[x] : x in rw}.
  # Every denominator is clamped to at least 1 to avoid division by zero.
  # If rct[x,y] is the count of consecutive word pairs (x,y), then
  #   rcp[x,y] = probability of x being followed by y.
  #   rp[x]    = probability of x being followed by a word of cw.
  #   cp[y]    = probability of a word of rw being followed by y.
  #   returned result = probability of a word in rw being followed by a word in cw.
  split("", rp); split("", cp); split("", rcp)

  # Column accumulators start at zero.
  j = 1
  while (j <= ncols)
    { cp[cw[j]] = 0
      j++
    }

  # Sweep the rows, scaling each cell by its row total.
  grand = 0
  i = 1
  while (i <= nrows)
    { rkey = rw[i]
      acc = 0
      j = 1
      while (j <= ncols)
        { ckey = cw[j]
          cnt = rct[rkey,ckey]
          rcp[rkey, ckey] = int(99.9999*(cnt+0.0)/max(1,rt[rkey]))
          acc += cnt
          cp[ckey] += cnt
          j++
        }
      rp[rkey] = int(99.9999*(acc+0.0)/max(1,rt[rkey]))
      grand += rt[rkey]
      i++
    }

  # Convert the raw column sums to scaled fractions of the grand row total.
  acc = 0
  j = 1
  while (j <= ncols)
    { ckey = cw[j]
      cnt = cp[ckey]
      cp[ckey] = int(99.9999*(cnt+0.0)/max(1,grand))
      acc += cnt
      j++
    }
  return int(99.9999*(acc+0.0)/max(1,grand))
}

function printcolheaders(rsz, ncols, cw, csz, esz,    i, j, len, ctfmt, top)
{
  # Prints the column headers vertically: one output line per letter
  # position, with words bottom-aligned. The first header is "TOT",
  # placed over the row-totals column.
  #   rsz = maximum row word length (width of the row-label column).
  #   ncols = number of columns.
  #   cw = list of column words, indexed [1..ncols].
  #   csz = maximum column word length.
  #   esz = nominal count width.
  
  ctfmt = ("%" (esz + 1) "s")
  # Temporarily borrow slot 0 for the totals-column header.
  cw[0] = "TOT"

  # Use enough header lines for the longest label, including "TOT" itself;
  # otherwise "TOT" would be truncated when all column words are shorter.
  top = (csz > length(cw[0]) ? csz : length(cw[0]))
  
  for(i=top;i>0;i--)
    { 
      # Blank area above the row labels, plus the two extra totals columns.
      for(j=1;j<=rsz;j++) printf " ";
      printf "  ";
      for(j=0;j<=ncols;j++)
        { len = length(cw[j])
          # Line i (counted from the bottom) shows the i-th letter from the end.
          printf ctfmt, (i > len ? " " : substr(cw[j], len-i+1, 1))
        }
      printf "\n"
    }
  # Remove the temporary entry so the caller's array is again [1..ncols].
  delete cw[0]
}

function printheaderdashes(rsz, ncols, cw, csz, esz,    i, j, line, seg)
{
  # Prints one ruler line of dashes matching the table layout:
  # rsz dashes for the row labels, a blank, esz+2 dashes for the
  # totals column, then a blank and esz dashes for each column.
  #   rsz = maximum row word length.
  #   ncols = number of columns.
  #   cw = list of column words (not used here).
  #   csz = maximum column word length (not used here).
  #   esz = nominal count width.

  # Dashes under the row-label column.
  line = ""
  for (j = 1; j <= rsz; j++) line = line "-"

  # Dashes under the totals column (two wider than a normal entry).
  line = line " "
  for (i = 0; i < esz + 2; i++) line = line "-"

  # Dashes under each regular column.
  seg = ""
  for (i = 0; i < esz; i++) seg = seg "-"
  for (j = 1; j <= ncols; j++) line = line " " seg

  printf "%s\n", line
}

function printtitle(title, rsz,   i)
{
  # Writes a blank line, the title followed by a colon, and another
  # blank line. The rsz argument is accepted but not used.
  printf "\n"
  printf "%s:\n", title
  printf "\n"
}

function ptable(title, t, nrows, rw, rsz, rt,  ncols, cw, csz, ct,  rct, esz, fmt,  \
                rword, cword, x, y, i,j, \
                cellfmt, totfmt, labfmt, val)
{
  # Prints a table of word pair data. Inputs:
  #   title = table title.
  #   t     = total totalorum (shown in the bottom-left totals cell).
  #   nrows = number of rows.
  #   rw    = row words, indexed [1..nrows].
  #   rsz   = max row word length.
  #   rt    = row totals, indexed by word.
  #   ncols = number of columns.
  #   cw    = column words, indexed [1..ncols].
  #   csz   = max col word length.
  #   ct    = column totals, indexed by word.
  #   rct   = word pair data, indexed by [row word,column word].
  #   esz   = table entry width (excluding separating whitespace).
  #   fmt   = table entry format (should print at most esz chars).
  # Zero entries in the body and in the column totals are shown as ".".

  # Title and header, framed by two ruler lines.
  printtitle(title)
  printheaderdashes(rsz, ncols, cw, csz, esz)
  printcolheaders(rsz, ncols, cw, csz, esz)
  printheaderdashes(rsz, ncols, cw, csz, esz)

  # Formats for a regular cell, the totals cell, and the left-aligned label.
  cellfmt = ("%" (esz + 1) "s")
  totfmt  = ("%" (esz + 3) "s")
  labfmt  = ("%-" rsz "s")

  # Table body: label, row total, then one cell per column.
  i = 1
  while (i <= nrows)
    { rword = rw[i]
      printf labfmt, rword
      printf totfmt, sprintf(fmt, rt[rword])
      j = 1
      while (j <= ncols)
        { val = rct[rword, cw[j]]
          printf cellfmt, (val == 0 ? "." : sprintf(fmt, val))
          j++
        }
      printf "\n"
      i++
    }

  printheaderdashes(rsz, ncols, cw, csz, esz)

  # Bottom row: grand total followed by the column totals.
  printf labfmt, "TOT"
  printf totfmt, sprintf(fmt, t)
  j = 1
  while (j <= ncols)
    { val = ct[cw[j]]
      printf cellfmt, (val == 0 ? "." : sprintf(fmt, val))
      j++
    }
  printf "\n"
}

function printfreqsprobs(t, nrows, rw, rsz, rt,  ncols, cw, csz, ct, rct,  digs, \
                prob, rp, cp, rcp, countfmt)
{
  # Prints three tables for the word pairs: raw counts, row-relative
  # probabilities, and column-relative probabilities. Inputs:
  #  t     = total totalorum.
  #  nrows = number of rows.
  #  rw    = row words, indexed [1..nrows].
  #  rsz   = max row word length.
  #  rt    = row totals, indexed by word.
  #  ncols = number of columns.
  #  cw    = column words, indexed [1..ncols].
  #  csz   = max col word length.
  #  ct    = column totals, indexed by word.
  #  rct   = word pair counts, indexed by [row word,column word].
  #  digs  = number of digits in each raw-count table entry.
  # The probability tables always use 2-character entries.

  # Make the scratch variables arrays before passing them on.
  split("", rp); split("", cp); split("", rcp)

  # Table 1: raw counts.
  countfmt = ("%" digs "d")
  ptable("Raw pair counts", t, nrows, rw, rsz, rt,  ncols, cw, csz, ct, rct, digs, countfmt)

  # Table 2: distribution along each row.
  printf "\n"
  prob = computerowdistrs(t, nrows, rw, rt,  ncols, cw, ct, rct,  rp, cp, rcp)
  ptable("Row probabilities", prob, nrows, rw, rsz, rp,  ncols, cw, csz, cp, rcp, 2, "%2d")

  # Table 3: distribution along each column.
  printf "\n"
  prob = computecoldistrs(t, nrows, rw, rt,  ncols, cw, ct, rct,  rp, cp, rcp)
  ptable("Col probabilities", prob, nrows, rw, rsz, rp,  ncols, cw, csz, cp, rcp, 2, "%2d")
}

function printfreqs(t, nrows, rw, rsz, rt,  ncols, cw, csz, ct, rct,  digs, \
                    cfmt)
{
  # Prints only the raw-count table for word pairs. Inputs:
  #  t     = total totalorum.
  #  nrows = number of rows.
  #  rw    = row words, indexed [1..nrows].
  #  rsz   = max row word length.
  #  rt    = row totals, indexed by word.
  #  ncols = number of columns.
  #  cw    = column words, indexed [1..ncols].
  #  csz   = max col word length.
  #  ct    = column totals, indexed by word.
  #  rct   = word pair counts, indexed by [row word,column word].
  #  digs  = number of digits in table entry; also the entry width.
  ptable("Raw pair counts", t, nrows, rw, rsz, rt,  ncols, cw, csz, ct, rct, digs, ("%" digs "d"))
}

# Empty records carry no pair; skip them.
/^$/ { next }

# Every non-empty record is treated as a "LEFT RIGHT" word pair.
{
  x = $1; y = $2
  xok = (x in rn)     # left word is one of the row words
  yok = (y in cn)     # right word is one of the column words

  if (xok)
    { rn[x]++
      if (yok) rcn[x,y]++
    }
  if (yok) cn[y]++

  # n counts every non-empty record, whether or not its words are tabulated.
  n++
}

END {
  # Skip the report and fail if the error flag was raised.
  if (awk_is_stupid == 1) exit 1

  # Report: blank line, the three tables, blank line.
  printf "\n"
  printfreqsprobs(n, nr, rw, rsz, rn, nc, cw, csz, cn, rcn, digits)
  printf "\n"
}