#! /usr/bin/gawk -f
# Last edited on 2001-01-08 14:06:29 by stolfi

BEGIN{
  abort = -1;
  usage = ( "compute-elem-distances \\\n" \
    "  { -v elemList='a,o,...' | -v elemTable=FILE } \\\n" \
    "  [ -v exponent=NUM ] \\\n" \
    "  < INFILE.wct > OUTFILE.tex" \
  );
  
  # Reads from standard input a table of digraph counts 
  # in the format
  # 
  #   COUNT SYMB1 SYMB2
  # 
  # Computes a dissimilarity measure d(u,v) for every ordered
  # pair of tabulated elems (u,v) -- including u == v, whose
  # distance is 0 -- and writes out a list of triples
  # 
  #   SYMB1 SYMB2 DIST
  #
  # where DIST is d(u,v) raised to the specified exponent
  # (default 1).  Only the symbols listed in the `elemList' or
  # in the file `elemTable' will be tabulated, ignoring any
  # symbol that contains one of the characters "+", "-", "/", "~".

  if (exponent == "") { exponent = 1.0; }
  # Reject non-numeric exponents; awk would otherwise silently treat
  # them as 0, turning every positive distance into 1.
  if (exponent !~ /^[ \t]*[-+]?([0-9]+[.]?[0-9]*|[.][0-9]+)([eE][-+]?[0-9]+)?[ \t]*$/)
    { arg_error("bad \"exponent\" value \"" exponent "\""); }
  exponent += 0;

  if ((elemList == "") == (elemTable == ""))
    { arg_error("must define exactly one of \"elemList\" and \"elemTable\""); }
  split("", elem);
  split("", eindex);
  split("", eclass);
  if (elemList != "") 
    { nelems = parse_explicit_elems(elemList,elem,eindex,eclass); }
  else
    { nelems = load_elems_from_file(elemTable,elem,eindex,eclass); }
    
  # Eliminate bogus elements, compacting elem[1..nelems] in place and
  # keeping eindex[] (symbol -> position) consistent with the result.
  # NOTE(review): eclass[] is filled by the loader, whose indexing is
  # not visible here; confirm whether it also needs filtering.
  k = 0;
  for (i = 1; i <= nelems; i++) 
    { ename = elem[i];
      if (ename ~ /[-\/~+]/) 
        { delete eindex[ename]; }
      else
        { k++; elem[k] = ename; eindex[ename] = k; }
    }
  # Drop the stale tail entries left over from the compaction.
  for (i = k + 1; i <= nelems; i++) { delete elem[i]; }
  nelems = k;

  split("", pairCt);
}

# Defensive: once an error status has been recorded, stop processing input.
(abort >= 0) { exit abort; }

# Skip empty/blank lines and comment lines (first non-blank char is "#").
/^[ \t]*([#]|$)/ { next; }

# Data line "COUNT SYMB1 SYMB2": accumulate COUNT into pairCt[SYMB1,SYMB2],
# ignoring pairs where either symbol contains "-", "/", "~" or "+".
/./ { 
  if (NF != 3) { data_error("bad line format"); }
  # Counts must be non-negative numbers; anything else would silently
  # become 0 or drive the smoothed probabilities out of range.
  if ($1 !~ /^[+]?([0-9]+[.]?[0-9]*|[.][0-9]+)([eE][-+]?[0-9]+)?$/)
    { data_error("bad count \"" $1 "\""); }
  ct = $1 + 0; x = $2; y = $3;
  if ((x !~ /[-\/~+]/) && (y !~ /[-\/~+]/)) { pairCt[x,y] += ct; }
  next;
}

END {
  # Propagate any error status recorded by arg_error/data_error.
  if (abort >= 0) { exit abort; }

  compute_distance_matrix();

  # One output line "SYMB1 SYMB2 DIST" per ordered pair of elems,
  # in elem-list order, diagonal pairs (u,u) included.
  for (row = 1; row <= nelems; row++)
    { u = elem[row];
      for (col = 1; col <= nelems; col++)
        { v = elem[col];
          printf "%-7s %-7s %8.5f\n", u, v, d[u,v];
        }
    }
}

function compute_distance_matrix(    \
    pprev,eprev,hprev,dprev, \
    pnext,enext,hnext,dnext, \
    tot,u,v,x,y,i,j,k,t,lden \
)
{
  # Computes the elem distance matrix from the digraph counts.
  #
  # Global input: pairCt[u,v], indexed by elems u,v; elem[1..nelems];
  #   exponent.
  # Global output: d[u,v], for every ordered pair of elems u,v
  #   (d[u,u] is always 0).
  # Side effect: writes one diagnostic line per pair to /dev/stderr.
  # 
  # Let G be the set of basic elems, and let pprev(x,u) and
  # pnext(u,y) be the probabilities of x right before u and of
  # y right after u, respectively.  Both are estimated from pairCt
  # with add-one smoothing, e.g.
  #
  #   pprev(x,u) = (pairCt[x,u] + 1)/(tot(u) + |G|)
  #
  # where tot(u) is the sum of pairCt[x,u] over x in G.  Let
  # eprev(u) = 1/(tot(u) + |G|) be the smallest possible value of
  # pprev(x,u), so pprev(x,u) lies in the range [eprev(u)..1].
  # Then we define
  # 
  #   hprev(x,u) = log(pprev(x,u)/eprev(u))/log(1/eprev(u))
  # 
  # Note that hprev(x,u) ranges in [0..1].  When eprev(u) = 1 (only
  # possible with a single elem and no data) the formula is 0/0, and
  # hprev is taken to be 0.  We then define
  # 
  #   dprev(u,v) = sqrt(sum{ (hprev(x,u) - hprev(x,v))^2 : x in G }/|G|)
  #    
  # and similarly for dnext(u,v). The elem distance is then
  # 
  #   d(u,v) = sqrt((dnext(u,v)^2 + dprev(u,v)^2)/2)
  # 
  # and the value stored in d[u,v] is d(u,v)^exponent.

  split("", pprev);  # pprev[x,u] is prob(x before u), smoothed.
  split("", eprev);  # eprev[u] is the minimum possible value of pprev[x,u].
  split("", hprev);  # hprev[x,u] is log(pprev[x,u]/eprev[u])/log(1/eprev[u]).
  
  for (i = 1; i <= nelems; i++)
    { u = elem[i]; 
      tot = 0;
      for (j = 1; j <= nelems; j++) 
        { x = elem[j]; tot += pairCt[x,u]; }
      eprev[u] = 1/(tot + nelems);
      # Loop-invariant normalizer; it is 0 exactly when eprev[u] == 1.
      lden = log(1.0/eprev[u]);
      for (j = 1; j <= nelems; j++) 
        { x = elem[j]; 
          pprev[x,u] = (pairCt[x,u] + 1)/(tot + nelems);
          hprev[x,u] = (lden == 0 ? 0 : log(pprev[x,u]/eprev[u])/lden);
        }
    }
  
  split("", pnext);  # pnext[u,y] is prob(y after u), smoothed.
  split("", enext);  # enext[u] is the minimum possible value of pnext[u,y].
  split("", hnext);  # hnext[u,y] is log(pnext[u,y]/enext[u])/log(1/enext[u]).

  for (i = 1; i <= nelems; i++)
    { u = elem[i]; 
      tot = 0;
      for (j = 1; j <= nelems; j++) 
        { y = elem[j]; tot += pairCt[u,y]; }
      enext[u] = 1/(tot + nelems);
      # Loop-invariant normalizer; it is 0 exactly when enext[u] == 1.
      lden = log(1.0/enext[u]);
      for (j = 1; j <= nelems; j++)
        { y = elem[j]; 
          pnext[u,y] = (pairCt[u,y] + 1)/(tot + nelems);
          hnext[u,y] = (lden == 0 ? 0 : log(pnext[u,y]/enext[u])/lden);
        }
    }
  
  split("", d);

  for (i = 1; i <= nelems; i++)
    { u = elem[i]; 
      for (j = 1; j <= nelems; j++)
        { v = elem[j];
          # Here dprev/dnext hold the mean squared differences, i.e. the
          # squares of dprev(u,v) and dnext(u,v) in the formulas above.
          dprev = 0; dnext = 0;
          for (k = 1; k <= nelems; k++)
            { x = elem[k];
              t = (hprev[x,u] - hprev[x,v]);
              dprev += t*t;
              t = (hnext[u,x] - hnext[v,x]);
              dnext += t*t;
            }
          dprev /= nelems; dnext /= nelems;
          # Diagnostic trace of the intermediate quantities.
          printf "%-7s %-7s hprev = %8.5f hnext = %8.5f dprev = %8.5f dnext = %8.5f\n", 
            u, v, hprev[u,v], hnext[u,v], dprev, dnext > "/dev/stderr";
          # t = d(u,v)^2, so t^(exponent/2) = d(u,v)^exponent.
          t = (dnext + dprev)/2;
          d[u,v] = (t <= 0 ? 0 : exp(log(t)*(0.5*exponent)));
        }
    }
} 

function arg_error(msg)
{ 
  # Reports a command-line argument problem, followed by the usage
  # summary, on stderr; then records status 1 in `abort' and exits
  # (the END rule re-raises that status).
  printf "%s\nusage: %s\n", msg, usage > "/dev/stderr";
  abort = 1;
  exit 1;
}

function data_error(msg,    text)
{ 
  # Reports a malformed input line, identified by its line number FNR,
  # on stderr; then records status 1 in `abort' and exits.
  text = sprintf("line %d: %s\n", FNR, msg);
  printf "%s", text > "/dev/stderr";
  abort = 1;
  exit 1;
}