#! /usr/bin/gawk -f

# Usage: $0 [ -v chars=XXXX ] < NGRAMFILE > PTABLES

# Prints to stdout a table of transition counts.
# Also computes several derived tables (probabilities, entropies, etc.).
# Each line of input should consist of N>1 characters, of which 
# the first N-1 are interpreted as a "state" and the last one as a "transition".

# Each table has one row for each state and one column for each transition.

# Externally defined parameters (with "-vVAR=VALUE"): 
#   chars            string specifying the ordering of transition chars

# Global variables:
#   n          total n-gram (transition) count
#   xn[x]      num occurrences of state x
#   yn[y]      num occurrences of transition y
#   xyn[x,y]   num occurrences of pair (x,y)
#   maxslen    max state length
#   mark[y]    set for every transition char declared in "chars" (or warned about)

BEGIN {
  # Start with no n-grams counted and empty state/transition tallies.
  n = 0;
  delete xn;
  delete yn;
  # Record which transition characters were declared via -v chars=...
  # (setmarks exits on a duplicate, so it runs before maxslen is set,
  # exactly as in the original ordering).
  setmarks(chars, mark);
  # The state column of every table is at least one character wide.
  maxslen = 1;
}

function checkstate(x,   t, len)
{
  # Registers state x the first time it is seen: creates xn[x] = 0,
  # a zero cell xyn[x,t] for every transition t seen so far, and
  # widens maxslen if x is the longest state so far.
  # Does nothing if x is already known.
  if (x in xn) return;
  for (t in yn) xyn[x,t] = 0;
  xn[x] = 0;
  len = length(x);
  if (len > maxslen) maxslen = len;
}

function checktrans(y,  z)
{
  # Registers transition character y the first time it is seen:
  # creates yn[y] = 0 and a zero cell xyn[z,y] for every state z seen so far.
  if (!(y in yn))
    { for (z in xn)
        { xyn[z,y] = 0; }
      yn[y] = 0
    }
  # Warn once per character that was not declared in "chars".
  # The warning goes to stderr so it does not end up inside the
  # tables written to stdout.
  if (!(y in mark))
    {
      printf "Warning: undeclared character '%s'\n", y > "/dev/stderr";
      mark[y] = 1;
    }
}

function setmarks(chars, mark,   i,y)
{
  # Clears mark[] and sets mark[y] = 1 for every character y in chars.
  # A character listed twice is a usage error: it is reported on stderr
  # (not stdout, which carries the tables) and the script exits with
  # status 1. Note that "exit" inside BEGIN still runs the END rule.
  split("", mark)
  for (i = 1; i <= length(chars); i++)
    { y = substr(chars, i, 1);
      if (y in mark)
        { printf "duplicate char '%s'\n", y > "/dev/stderr"; exit 1 }
      mark[y] = 1;
    }
}

function fixchars(chars, yt, rows,    temp, i, y)
{
  # Sets rows[y] = 1 for every character y that has an entry in yt.
  # Returns a string consisting only of the characters present in yt:
  # first those listed in "chars", in that order, then any characters
  # not listed in "chars" (appended in "for (y in yt)" order, which awk
  # leaves unspecified).
  # "y" is declared local so the caller's/global y is not clobbered.
  temp = ""
  # "<=" so the last declared character keeps its declared position
  # (previously it was skipped here and appended in arbitrary order).
  for (i = 1; i <= length(chars); i++)
    { y = substr(chars, i, 1);
      if (y in yt)
        { temp = (temp y); rows[y] = 1 }
    }
  for (y in yt)
    { if (!(y in rows))
        { rows[y] = 1; temp = (temp y); }
    }
  return temp;
}

function ptable(n, h, xn, xh, yt, xyt, fmt, esz, chars,    \
    x, y, rows, nchars, i,j, temp, p)
{
  # Prints one transition table to stdout.
  # Inputs:
  #   n      = total transition count
  #   h      = total transition entropy (printed in the TOT row)
  #   xn     = transition counts per state (one table row per key)
  #   xh     = transition entropy per state
  #   yt     = column totals (printed in the TOT row)
  #   xyt    = elements, indexed xyt[state, transition char]
  #   fmt    = sprintf format spec for elements and column totals
  #   esz    = printf field width of each element column
  #   chars  = preferred ordering of the transition chars used as column indices
  # Also reads the global maxslen (width of the state column).
  # Locals "i" and "temp" are declared but not used.
  #
  # Layout: a header line, a dashed separator, one line per state, another
  # separator, and a TOT line. Per-state columns are count = xn[x],
  # freq = xn[x]/n, ntrpy = xh[x], pntpy = freq*xh[x]; then one column per
  # transition char, where a zero cell is shown as ".".
  
  # Reduce/complete "chars" to include all transitions that occurred one or more times:
  split("", rows);
  chars = fixchars(chars, yt, rows);
  nchars = length(chars)
  
  # Header line: blank state column, summary column names, one label per char.
  printf "  %*.*s", maxslen, maxslen, "";
  printf " %5s %5s %5s %5s", "count", "freq", "ntrpy", "pntpy";
  for(j=1;j<=nchars;j++)
    { y = substr(chars,j,1); printf " %*s", esz, y; }
  printf "\n";

  # Separator under the header. The dash strings are truncated to the
  # column widths by the "%*.*s" precision (state column: at most 30 dashes).
  printf "  %*.*s", maxslen, maxslen, "------------------------------";
  printf " %5.5s %5.5s %5.5s %5.5s", "---------", "---------", "---------", "---------";
  for(j=1;j<=nchars;j++) printf " %*.*s", esz, esz, "---------";
  printf "\n";

  # One line per state. Row order is that of "for (x in xn)", which awk
  # leaves unspecified. n > 0 here whenever xn is non-empty.
  for(x in xn)
    { 
      printf "  %*s", maxslen, x;
      p = xn[x]/n;
      printf " %5d %5.3f %5.3f %5.3f", xn[x], p, xh[x], p*xh[x];
      for(j=1;j<=nchars;j++) 
        { y = substr(chars,j,1);
          if (xyt[x,y] == 0) 
            { printf " %*s", esz, "."; }
          else
            { printf " %*s", esz, sprintf(fmt, xyt[x,y]); }
        }
      printf "\n";
    }

  # Separator above the totals (state column left blank).
  printf "  %*.*s", maxslen, maxslen, "";
  printf " %5.5s %5.5s %5.5s %5.5s", "---------", "---------", "---------", "---------";
  for(j=1;j<=nchars;j++) printf " %*.*s", esz, esz, "---------";
  printf "\n";

  # TOT line: total count, freq 1.0, h in both entropy columns, then column totals.
  printf "%*.*s", maxslen+2, maxslen+2, "TOT";
  printf " %5d %5.3f %5.3f %5.3f", n, 1.0, h, h;
  for(j=1;j<=nchars;j++) 
    { y = substr(chars,j,1);
      printf " %*s", esz, sprintf(fmt, yt[y]); 
    }
  printf "\n";
}

function entropy(p)
{
  # One term -p * log2(p) of a Shannon entropy sum, in bits.
  # By convention the term is 0 when p == 0.
  return (p == 0) ? 0.0 : (- p * log(p) / log(2.0));
}

function pscale(p,m)
{
  # Scale p by m and round to the nearest integer (halves round up;
  # intended for non-negative p).
  return int(p * m + 0.5);
}

function sscale(v,m)
{
  # Scale v by m and round half away from zero for positive v
  # (v <= 0 is shifted by -0.5 before truncation).
  if (v > 0)
    return int(m*v + 0.5);
  return int(m*v - 0.5);
}

# Each non-empty input line is one n-gram: its last character is the
# transition, everything before it is the state.
/./ {
  m = length($0);
  x = substr($0, 1, m - 1);
  checkstate(x);
  y = substr($0, m, 1);
  checktrans(y);
  n++;
  xn[x]++;
  yn[y]++;
  xyn[x,y]++;
  next;
}

# Any line reaching this rule did not match /./ (i.e. it is empty): skip it.
// { next; }

END {
  # Print the summary entropies (in bits), then the count table and the
  # scaled-probability table.

  # Entropy of the state distribution xn[x]/n.
  printf "State entropy: "
  h = 0.000
  for (x in xn)
    { h += entropy(xn[x]/n); }
  printf "%.3f\n\n", h;

  # Entropy of the transition-character distribution yn[y]/n.
  printf "Transition entropy: "
  h = 0.000
  for (y in yn)
    { h += entropy(yn[y]/n); }
  printf "%.3f\n\n", h;

  # xh[x] = entropy of the transitions out of state x; h = the average of
  # xh[x] weighted by state frequency (entropy of transition given state).
  for (x in xn) xh[x] = 0.000;
  h = 0.000
  for (x in xn)
    { 
      for (y in yn)
        { xh[x] += entropy(xyn[x,y]/xn[x]); }
      h += xh[x] * (xn[x]/n);
    }

  printf "Transition counts:\n";
  printf "\n";
  ptable(n, h, xn, xh, yn, xyn, "%d", 5, chars)
  printf "\n";
  
  # Same layout with every cell and column total scaled by 99 and rounded.
  printf "Transition probabilities (× 99):\n";
  printf "\n";
  for (y in yn) yp[y] = pscale(yn[y]/n, 99);
  for (x in xn)
    for (y in yn) 
      { xyp[x,y] = pscale(xyn[x,y]/xn[x], 99); }
  ptable(n, h, xn, xh, yp, xyp, "%d", 2, chars)
  printf "\n";

  # Deliberately no "exit 0" here. An exit status supplied earlier, e.g.
  # "exit 1" from setmarks() on a duplicate char in BEGIN (after which awk
  # still runs this END rule), would otherwise be overwritten with success.
  # Normal runs still exit with status 0.
}