#! /usr/bin/gawk -f
# Last edited on 1999-07-28 02:02:31 by stolfi

BEGIN {
  abort = -1;
  
  usage = ( \
    "compute-dist-matrix \\\n" \
    "    -v nItems=NITEMS \\\n" \
    "    -v countDir=DIR \\\n" \
    "    -v units='U1 U2...UN'" \
  );

  # Computes a distance matrix for a set of text units
  # (bifolios, folios, pages, paragraphs, etc.) given 
  # a set of item count vectors, one for each unit.
  #
  # The units are assumed to be named U1, U2, ... UN. The item count
  # vector for a unit named U is assumed to be stored in file
  # DIR/U.frq. It should contain one or more records of the form COUNT
  # FREQ ITEM, where COUNT is the number of occurrences of ITEM in the
  # unit. The FREQ field is ignored, as well as blank lines,
  # #-comments, and entries with COUNT=0.
  #
  # The ITEMs listed in those data files are assumed to be subsets
  # of a global "vocabulary" with NITEMS elements.  The count and
  # frequency vectors for every unit are implicitly completed with 
  # zero entries, so that they all have NITEMS elements.
  #
  # The matrix entry corresponding to a pair of units U,V is the
  # expected value of the square of the Euclidean distance between the
  # item frequency vectors for those two units, as estimated from the
  # given item counts.
  #
  # The program writes the matrix to standard output, with row and
  # column headings at top and left, as expected by the optimization
  # tools.  All diagnostics go to "/dev/stderr".
  
  # Force "nItems" to a number so that an unset or non-numeric value
  # is caught by the zero test below.
  nItems += 0;
  if (nItems == 0) { error("must define \"nItems\""); }
  if (countDir == "") { error("must define \"countDir\""); }
  if (units == "") { error("must define \"units\""); }
  
  printf "nItems = %s\n", nItems > "/dev/stderr";
  
  nUnits = split(units, uName);  # "uName[1..nUnits]" are the names of the units.
  printf "nUnits = %d\n", nUnits > "/dev/stderr";
  
  # Start with empty arrays: "dm" is the distance matrix, "P" and "Q"
  # hold the item counts of the two units currently being compared.
  split("", dm);
  split("", P);
  split("", Q);
  
  # "nameWd" is the width of the longest unit name; "write_matrix"
  # reads this global to align the row labels.
  nameWd = 1;
  for (uP=1; uP<= nUnits; uP++) 
    { lng = length(uName[uP]); if (lng>nameWd) { nameWd = lng;} }

  # Fill the upper triangle pair by pair and mirror each entry into
  # the lower triangle.  The counts of unit "uP" are read once per
  # row; those of unit "uQ" are re-read for every pair.
  for (uP=1; uP<= nUnits; uP++)
    { 
      
      printf "=== %s ====================\n", uName[uP] > "/dev/stderr";
      read_counts(countDir, uName[uP], P);
      SP = hist_sum(P, nItems);
      SPP = hist_sum_sq(P, nItems);
      printf "SP = %.5f SPP = %.5f\n", SP, SPP > "/dev/stderr";
      
      dm[uP,uP] = 0;
      
      for (uQ=uP+1; uQ<=nUnits; uQ++)
        { 
          printf "== %s ==\n", uName[uQ] > "/dev/stderr";
          read_counts(countDir, uName[uQ], Q);
          SQ = hist_sum(Q, nItems);
          SQQ = hist_sum_sq(Q, nItems);
          printf "SQ = %.5f SQQ = %.5f\n", SQ, SQQ > "/dev/stderr";

          dm[uP,uQ] = hist_dist_sqr(P, SP, SPP, Q, SQ, SQQ, nItems);
          # Report the value just stored; "Ed2" is local to
          # "hist_dist_sqr" and is not visible here.
          printf "dm[%d,%d] = %.8f\n", uP,uQ, dm[uP,uQ] > "/dev/stderr";
          dm[uQ,uP] = dm[uP,uQ];
        }
      printf "\n" > "/dev/stderr";
    }
  
  write_matrix(dm, uName, nUnits);
}

function read_counts(dir, uName, P,   fname,lin,fld,nfld,Pk,k,status) 
{
  # Reads item counts "P[it]" from file "dir/uName.frq".
  #
  # Arguments:
  #   dir    directory that holds the count files.
  #   uName  name of the unit; the file read is "dir/uName.frq".
  #   P      (output) array, cleared on entry, then filled with
  #          "P[ITEM] = COUNT" for every entry whose COUNT is positive.
  #
  # Each data line must have exactly three blank-separated fields,
  # COUNT FREQ ITEM; the FREQ field is ignored.  Lines that begin
  # with "#" and lines consisting only of spaces are skipped.
  # Every accepted entry is echoed to "/dev/stderr".
  #
  # Calls "error" (which terminates the program) on a malformed line,
  # on a repeated ITEM, or when the file cannot be read.
  
  fname = (dir "/" uName ".frq");
  split("", P);
  # Keep the "getline" status so that end-of-file (0) can be told
  # apart from an open/read failure (negative) after the loop.
  while ((status = (getline lin < fname)) > 0)
    { if (! match(lin, /^([#]|[ ]*$)/))
        { nfld = split(lin, fld, " ");
          if (nfld != 3) error(("bad counts entry = \"" fname ":" lin "\""));
          Pk = fld[1]; k = fld[3];
          if (Pk > 0)
            { if (k in P) error(("repeated item = \"" fname ":" lin "\""));
              printf "%7d %s\n", Pk, k > "/dev/stderr";
              P[k] = Pk;
            }
        }
    }
  # Test the "getline" status rather than the contents of ERRNO:
  # the value ERRNO holds when no error has occurred is not the same
  # in all gawk versions, whereas a negative status always means failure.
  if (status < 0) { error((fname ": " ERRNO)); }
  close (fname);  
}

function hist_sum(P,nItems,  item,total)
{
  # Returns the sum of "P[item]+1" taken over all "nItems" items of
  # the vocabulary.  Items that do not appear in "P" have an implied
  # count of zero, so together the "+1" terms contribute "nItems".
  total = 0;
  for (item in P)
    { total = total + P[item]; }
  return nItems + total;
}

function hist_sum_sq(P,nItems,  item,acc,cnt)
{
  # Returns the sum of "(P[item]+1)*(P[item]+2)" taken over all
  # "nItems" items of the vocabulary.  Expanding the product gives
  # "P*(P+3) + 2": the first part is accumulated only for the items
  # present in "P", and the constant 2 is added once per vocabulary
  # item at the end.
  acc = 0;
  for (item in P)
    { cnt = P[item];
      acc = acc + cnt*(cnt+3);
    }
  return 2*nItems + acc;
}

function hist_dist_sqr(P,SP,SPP,Q,SQ,SQQ,nItems,    item,seen,qTerm,cross,termP,termPQ,termQ,d2)
{
  # Returns the expected squared distance between the two frequency
  # vectors over "nItems" items, given the sample counts "P[item]" and
  # "Q[item]" observed for each distribution.
  #
  # The caller must supply the biased sums of each count vector:
  # "SP" and "SPP" for "P" (as returned by "hist_sum" and
  # "hist_sum_sq"), and likewise "SQ" and "SQQ" for "Q".
  #
  # The value is also logged to "/dev/stderr".  Calls "error" if the
  # two vectors together mention more than "nItems" distinct items.
  
  # "cross" accumulates "(P[item]+1)*(Q[item]+1)" over the vocabulary;
  # "seen" counts the distinct items that occur in "P" or in "Q".
  seen = 0;
  cross = 0;

  # Items present in "P"; a missing "Q" entry counts as zero.
  for (item in P)
    { qTerm = ((item in Q) ? Q[item]+1 : 1);
      cross += (P[item]+1)*qTerm;
      seen++;
    }

  # Items present only in "Q"; the "P" factor is then 1.
  for (item in Q)
    { if (item in P) { continue; }
      cross += (Q[item]+1);
      seen++;
    }

  if (seen > nItems) { error(("too many items " seen " " nItems)); }

  # Every vocabulary item absent from both vectors contributes 1*1.
  cross += (nItems - seen);

  termP = SPP/(SP*(SP+1));
  termPQ = 2*cross/(SP*SQ);
  termQ = SQQ/(SQ*(SQ+1));
  d2 = termP - termPQ + termQ;
  printf "SPQ = %.5f Ed2 = %.8f\n", cross, d2 > "/dev/stderr";
  return d2;
}
  
function write_matrix(dm,uName,nUnits,  row,col)
{
  # Prints the "nUnits" by "nUnits" matrix "dm" to standard output,
  # preceded by a line of column headings and a line of dashes.
  # Each row starts with its unit name, left-justified in a field
  # whose width is taken from the global variable "nameWd".

  # Column headings, after a blank label field.
  printf "%-*s ", nameWd, "";
  col = 1;
  while (col <= nUnits)
    { printf " %8s", uName[col];
      col++;
    }
  printf "\n";

  # Separator line of dashes under the headings.
  printf "%-*s ", nameWd, "";
  col = 1;
  while (col <= nUnits)
    { printf " %8s", "--------";
      col++;
    }
  printf "\n";

  # One line per unit: the row label, then the matrix entries.
  row = 1;
  while (row <= nUnits)
    { printf "%-*s ", nameWd, uName[row];
      col = 1;
      while (col <= nUnits)
        { printf " %8.6f", dm[row,col];
          col++;
        }
      printf "\n";
      row++;
    }
}
    
function error(msg)
{ 
  # Reports "msg" on "/dev/stderr", prefixed with the current value
  # of NR, sets the global flag "abort" to 1, and terminates the
  # program with exit status 1.
  # NOTE(review): "abort" is presumably examined by an END rule that
  # is not visible in this part of the file -- confirm.
  abort = 1;
  printf "line %d: %s\n", NR, msg >> "/dev/stderr";
  exit 1;
}