#! /usr/bin/gawk -f
# Last edited on 2000-05-23 01:55:23 by stolfi

BEGIN {
  abort = -1;
  usage = ( ARGV[0] " \\\n" \
    "  [ -v wordcounts=LANG.frq ] \\\n" \
    "  [ -v axiom=AXIOMSYMB ] \\\n" \
    "  [ -v separator=Z ]\\\n" \
    "  [ -v ignorecounts=BOOL ] \\\n" \
    "  [ -v eps=NUMBER ] \\\n" \
    "  [ -v maxderivs=NUM ] \\\n" \
    "  [ -v countprec=NUM ] \\\n" \
    "  [ -v terse=BOOL ] \\\n" \
    "  < INFILE.grx \\\n" \
    "  > OUTFILE.grx" \
  );
  
  # Parses the words in LANG.frq according to the probabilistic
  # grammar INFILE.grx, then recomputes the probabilities of the
  # latter and writes a new grammar OUTFILE.grx.
  # 
  # The data file LANG.frq must be in the format produced by
  # 'compute-freqs'. Only the count fields are used, not the
  # probabilities.
  # 
  # The grammar files INFILE.grx and OUTFILE.grx must contain one or
  # more rules of the form
  # 
  #  SYMBOL:
  #     OLDCOUNT1 OTHER1... PROD1
  #     OLDCOUNT2 OTHER2... PROD2
  #     ...       ...       ...
  # 
  # where SYMBOL is a non-terminal symbol, each OLDCOUNTi is an
  # integer or fractional count, the OTHERi fields are zero or more
  # numeric fields, and each PRODi is an alternative for SYMBOL.
  # 
  # On input, the counts OLDCOUNTi of each SYMBOL are scaled to add to
  # 1 and interpreted as `a priori' rule probabilities. Each rule
  # PRODi must be a sequence of symbols separated by the "separator"
  # character (default "."). Symbols that are not defined by any rule
  # are assumed to be terminal. The fields OTHERi are ignored.
  # 
  # After inhaling the grammar INFILE.grx, the program will read each
  # entry (COUNT, WORD) from LANG.frq, and enumerate all possible
  # derivations of WORD (up to a certain maximum). Each derivation
  # will be assigned a probability PDERIV based on the `a priori' rule
  # probabilities. Then each rule used K times in a given derivation
  # will have its `new' count bumped by K*COUNT*PDERIV.
  # 
  # After processing all of LANG.frq in this fashion, the grammar
  # is written to OUTFILE.grx, with the new counts replacing
  # the old ones.  The output rules will have two OTHER fields:
  # the normalized probability of the rule, and the cumulative
  # probability of this and all preceding rules.  Comments and
  # blank lines are preserved.
  # 
  # NOTE(review): this comment used to promise rules sorted by
  # decreasing count, but write_new_grammar below emits the rules of
  # each symbol in input order -- confirm which behavior is intended.
  # 
  # A non-terminal symbol may contain any of the characters 
  # [A-Za-z0-9_()*+] and must begin [A-Z(].
  # 
  # The start symbol AXIOMSYMB defaults to the first non-terminal symbol
  # if not given.
  # 
  # The default data file LANG.frq may be specified in the grammar file 
  # itself, by a comment of the form 
  # 
  #    # Data-File: LANG.frq
  # 
  # The command line option has precedence over this comment. 
  # 
  # If "ignorecounts" is set to 1, the input counts are ignored and 
  # the rules of each symbol will have uniform `a priori' probability.
  # 
  # The "eps" parameter is a fudge factor to be added to all input
  # counts before computing the probabilities.
  # 
  # If "countprec" is positive, the COUNT fields of the output grammar
  # will be printed as fractions with that many decimal fraction digits.
  # Otherwise the COUNT field will be rounded to the nearest integer.
  # 
  # Words from LANG.frq that cannot be parsed, or have too many
  # parses, are reported to the error output.
  
  # Arguments (each "-v" variable gets a default when left unset):

  if (eps == "") { eps = 0; }
  
  if (ignorecounts == "") { ignorecounts = 0; }
  
  if (terse == "") { terse = 0; }
  
  if (maxderivs == "") { maxderivs = 1; }
  
  # The separator must be a single character and must not be one of the
  # characters allowed inside symbols (nor a regex-bracket special).
  if (separator == "")
    { separator = "."; }
  else 
    { if ( \
        (length(separator) != 1) || 
        match(separator, /[A-Za-z0-9_()*+\[\]\\ ]/) \
      ) { arg_error("bad symbol separator"); }
    }
  # "spat" is a one-character bracket expression matching the separator.
  spat = ("[" separator "]");
  
  if (countprec == "") { countprec = 0; }
  
  # Other global variables:

  nwords = 0;          # Number of lines read from the "wordcounts" file.
  
  nsymb = 0;           # Number of non-terminal symbols.
  split("", symbol);   # "symbol[i]"  = the "i"th non-terminal symbol.
  split("", comment);  # "comment[s]" = the concatenated comments of symbol "s".
                       # "comment[s,k]" = the concatenated comments of rule "[s,k]".
  split("", nprod);    # "nprod[s]"   = the number of rules for symbol "s".
  split("", oldct);    # "oldct[s,k]" = the old count of rule "k" of symbol "s".
  split("", oldpr);    # "oldpr[s,k]" = the old prob of rule "k" of symbol "s".
  split("", prod);     # "prod[s,k]"  = rule "k" of symbol "s".
  split("", newct);    # "newct[s,k]" = the new count of rule "k" of symbol "s".
  split("", newpr);    # "newpr[s,k]" = the new prob of rule "k" of symbol "s".

  # Variables used while inhaling grammar:
  nextcmt = "";  # Comments for next symbol or rule.
  cursymb = "";  # Current non-terminal symbol
}

# Inhale input grammar:

# Once an error handler has set "abort", skip every remaining input
# line by jumping straight to the END block.
(abort >= 0) { exit abort; }

/^ *$/ {
  # A blank line is remembered as part of the pending comment text, so
  # that blank lines are preserved in the rewritten grammar.
  nextcmt = nextcmt "\n";
  next;
}

/^[#][ ]*Data-File[ ]*[:]/ {
  # A "# Data-File: NAME" comment names the default word-count file;
  # a "-v wordcounts=..." given on the command line takes precedence.
  # No "next" here: the line also reaches the generic comment rule
  # below, so the directive is preserved in the output.
  fn = $0;
  sub(/^[#][ ]*Data-File[ ]*[:][ ]*/, "", fn);
  sub(/[ ]*$/, "", fn);
  if (wordcounts == "") { wordcounts = fn; }
}

/^[#]/ {
  # Accumulate comment lines verbatim; they get attached to whatever
  # symbol or rule comes next.
  nextcmt = nextcmt $0 "\n";
  next;
}

/^[A-Z(][A-Za-z0-9_()*+]*[ ]*[:][ ]*$/ {
  # A "SYMBOL:" header opens the rule list of a new non-terminal.
  cursymb = $0;
  sub(/[ ]*[:][ ]*$/, "", cursymb);
  if (cursymb in nprod) { grammar_error(("repeated symbol \"" cursymb "\"")); }
  # The first non-terminal becomes the axiom unless one was given:
  if (axiom == "") { axiom = cursymb; }
  comment[cursymb] = nextcmt; nextcmt = "";
  nprod[cursymb] = 0;
  symbol[nsymb++] = cursymb;
  next;
}

/^ *[0-9.]/ {
  # A rule line "COUNT OTHER... PROD" for the current symbol.
  if (cursymb == "") { grammar_error("rule without head symbol"); }
  if (NF < 2) { grammar_error("bad rule format"); }
  # Anchored validation: the whole field must be an integer or a
  # fraction with at least one digit.  The previous, unanchored match
  # accepted any field merely containing a digit (e.g. "1.2.3").
  if (! match($1, /^[0-9]*([0-9]|([0-9][.]|[.][0-9]))[0-9]*$/))
    { grammar_error("bad rule count field"); }
  k = nprod[cursymb];
  # With "ignorecounts" every rule starts from the same a-priori weight.
  if (ignorecounts) { ct = 1; } else { ct = $1 + eps; }
  comment[cursymb,k] = nextcmt; nextcmt = "";
  oldct[cursymb,k] = ct;
  # Only the last field is the production; OTHER fields are ignored.
  def = $(NF);
  # Ensure a trailing separator, eliminate double and leading separators:
  def = (def separator); 
  gsub((spat spat), separator, def);
  while(substr(def,1,1) == separator) { def = substr(def,2); }
  prod[cursymb,k] = def;
  nprod[cursymb]++;
  next;
}

{
  # Any line not claimed by the rules above is malformed.
  grammar_error("bad line format"); 
}

END {
  # Main driver, run after the whole grammar has been inhaled.
  if (abort >= 0) { exit abort; }
  if (nsymb == 0) { grammar_error("empty grammar"); }
  
  # Comments at end of grammar (drop trailing blank lines):
  final_comment = nextcmt;
  gsub(/[\n]*$/, "", final_comment);

  # Sanity checks before any parsing starts:
  if (wordcounts == "") { arg_error("must define \"wordcounts\""); }
  if (axiom == "") { prog_error("axiom was not specified"); }
  if (! (axiom in nprod)) { prog_error("axiom must be a non-terminal symbol"); }

  # Turn the old counts into `a priori' probabilities, tally the word
  # file against them, then renormalize and write the new grammar.
  normalize_probs(oldct,oldpr);
  clear_counts(newct);
  parse_and_tally_file(oldpr,newct);
  normalize_probs(newct,newpr);
  write_new_grammar(newct,newpr);
}

function normalize_probs(ct,pr,   i,s,m,k,tot)
{
  # Scale the counts "ct[s,k]" of each symbol "s" into probabilities
  # "pr[s,k]" adding to 1.  A symbol whose counts total zero gets a
  # warning on stderr and a uniform distribution over its rules.
  for (i = 0; i < nsymb; i++)
    {
      s = symbol[i];
      m = nprod[s];
      tot = 0;
      for (k = 0; k < m; k++) { tot += ct[s,k]; }
      if (tot == 0)
        { printf "warning: zero total count for \"%s\"\n", s > "/dev/stderr"; }
      for (k = 0; k < m; k++)
        { if (tot == 0)
            { pr[s,k] = 1.0/m; }
          else
            { pr[s,k] = ct[s,k]/tot; }
        }
    }
}

function clear_counts(ct,   i,s,m,k)
{
  # Reset every rule count of every non-terminal symbol to zero.
  for (i = 0; i < nsymb; i++)
    {
      s = symbol[i];
      for (k = nprod[s] - 1; k >= 0; k--) { ct[s,k] = 0; }
    }
}

function parse_and_tally_file(oldpr,newct,  lin,ct,wd,totct)
{
  # Reads each "COUNT OTHER WORD" entry from the "wordcounts" file,
  # parses WORD, and accumulates the scaled rule usage into "newct".
  # "#"-comment lines in the data file are skipped.
  if (! terse) { printf "reading %s ...\n", wordcounts > "/dev/stderr"; }
  nwords = 0; totct = 0;
  while ((getline lin < wordcounts) > 0 )
    { nwords++;
      if (! match(lin, /^[#]/))
        { nfld = split(lin, fld, " ");
          if (nfld != 3) { word_error(("bad word entry = \"" lin "\"")); }
          ct = fld[1]; wd = fld[3];
          # Anchored so malformed counts like "1.2.3" are rejected (the
          # old unanchored match accepted any field containing a digit).
          if (! match(ct, /^[0-9]*([0-9]|([0-9][.]|[.][0-9]))[0-9]*$/))
            { word_error("bad word count field"); }
          # A lone separator stands for the empty word:
          if (wd == separator)
            { wd = ""; }
          else if (match(wd, spat))
            { word_error("word contains separator character"); }
          totct += ct;
          parse_and_tally_word(wd,ct,oldpr,newct);
        }
    }
  # gawk sets ERRNO only when getline fails; on a clean EOF it keeps
  # its initial value "".  The old test (ERRNO != "0") therefore
  # reported a phantom error on every successful run under gawks that
  # initialize ERRNO to the empty string.  Accept both "" and "0".
  if ((ERRNO != "") && (ERRNO != "0")) { word_error((wordcounts ": " ERRNO)); }
  close (wordcounts);
  if (nwords == 0) { arg_error(("file \"" wordcounts "\" empty or missing")); }
  if (! terse) { printf "read %d counts of %6d words\n", totct, nwords > "/dev/stderr"; }
}

function parse_and_tally_word(wd,ct,oldpr,newct)
{
  # Enumerates the derivations of "wd" from the axiom and bumps the
  # new counts "newct" of the rules they use by "ct" times each
  # derivation's normalized probability.

  # Global variables for parser:
  
  # Fix: this used to clear "rymb" (a typo), so "rsymb" kept stale
  # entries from the previous word between calls.
  split("", rsymb);
  split("", ralt);
  split("", rprev);
  rfree = 0;  # "{rsymb,ralt,rprev}[0..rfree-1]" are the confirmed derivations.
  
  # A derivation is represented by the list of rules in left-first order.
  # An element of the list is stored in "rsymb[j]" (the non-terminal 
  # symbol) and "ralt[j]" (the index of the rule relative to that symbol).
  # The index of the *previous* element of the list is "rprev[j]".
  
  nderivs = 0;         # Number of derivations for this word.
  split("", derivix);  # "derivix[t]" is the index of the last rule of deriv number "t".
  split("", derivpr);  # "derivpr[t]" is the corresponding probability.
  
  # Fix: the start string used a hard-coded "." instead of the
  # user-selectable separator, which broke any run with "-v separator=Z"
  # (enum_derivs would fail with "lost separator in rule").
  enum_derivs(wd,(axiom separator),-1,1.000000);
  # printf "nderivs = %d\n", nderivs > "/dev/stderr";
  tally_derivs(wd,ct,newct);
}

function enum_derivs(wd,def,rbase,pbase,  s,k,r)
  {
    # Enumerates all leftmost derivations for "wd" from the mixed
    # terminal/non-terminal string "def" (which must end with the
    # separator), stacks them in "derivix[]" and puts their
    # corresponding probabilities in "derivpr[]". Each generated
    # derivation gets prefixed with the fixed derivation starting with
    # "rbase", whose probability is "pbase".

    # Stop expanding once the derivation budget is exhausted; callers
    # detect the overflow by nderivs exceeding maxderivs.
    if (nderivs > maxderivs) { return; }
    
    # printf "et(\"%s\",\"%s\",%d,%9.7f)\n", wd,def,rbase,pbase > "/dev/stderr";
        
    while (def != "") 
      { # Get first component of "def":
        if (! match(def, spat)) { prog_error("lost separator in rule"); }
        s = substr(def,1,RSTART-1);
        def = substr(def,RSTART+1);
        if (s == "") { prog_error("empty component in rule"); }
        if (s in nprod)
          { # Non-terminal symbol - try all alternatives:
            for (k = 0; k < nprod[s]; k++)
              { # Record this rule at the first slot past both the fixed
                # prefix and any already-confirmed derivations:
                r = rbase+1; if (r < rfree) { r = rfree; }
                rsymb[r] = s;
                ralt[r] = k;
                rprev[r] = rbase;
                enum_derivs(wd, (prod[s,k] def), r, oldpr[s,k]*pbase);
              }
            return;
          }
        else if (s != substr(wd,1,length(s)))
          { # Terminal symbol does not match the front of "wd":
            return;
          }
        else
          { # Terminal matched; consume it from "wd":
            wd = substr(wd,length(s)+1);
          }
      }
    if (wd == "") 
      { # Found another derivation for the word, hooray!
        derivix[nderivs] = rbase;
        derivpr[nderivs] = pbase;
        nderivs++;
        # Freeze the rules of this derivation so sibling branches
        # allocate fresh slots.  (Removed the dead assignment
        # "rtemp = rfree" -- "rtemp" was never read anywhere.)
        rfree = rbase+1;
      }
  }
  
function dump_deriv(r,p, s,k)
  {
    # Prints to stderr the derivation whose last rule has index "r" and
    # whose probability is "p", following the "rprev" links backwards.
    printf "  --------------------\n" > "/dev/stderr";
    printf "  prob = %9.7f\n", p > "/dev/stderr";
    for ( ; r >= 0; r = rprev[r])
      { s = rsymb[r];
        k = ralt[r];
        printf "    %d: %s -> %s\n", r,s,make_vis(prod[s,k]) > "/dev/stderr";
      }
  }

function tally_derivs(wd,ct,newct,  t,tot,p,r,s,k)
  {
    # For each rule "s,k" used in each deriv "t",
    # adds to "newct[s,k]" the count "ct" times "derivpr[t]",
    # after normalizing "derivpr[]" to unit sum.

    if (nderivs == 0) 
      { # Word not generated by the grammar: graft it into the grammar
        # as a brand-new rule of the axiom carrying the word's count.
        printf "not in language, added: %7d %s\n", ct, make_vis(wd) >> "/dev/stderr";
        s = axiom; k = nprod[axiom];
        prod[s,k] = ( wd separator );
        newct[s,k] = ct;
        nprod[s]++;
        return;
      }

    if (nderivs > maxderivs) 
      { # Enumeration stopped early, so "nderivs" is only a lower bound
        # on the true number of derivations.
        printf "%s has %d or more derivations\n", make_vis(wd), nderivs >> "/dev/stderr"; 
        if (! terse) 
          { for (t = 0; t < nderivs; t++)
              { dump_deriv(derivix[t], derivpr[t]); }
          }
      }

    # Normalize the derivation probabilities:
    tot = 0.00000000;
    for(t = 0; t < nderivs; t++) { tot += derivpr[t]; }
    if (tot == 0) 
      { printf "warning: \"%s\" has probability zero\n", wd >> "/dev/stderr"; }
    for(t = 0; t < nderivs; t++)
      { # dump_deriv(derivix[t], derivpr[t]/tot);
        # Zero total: fall back to a uniform split over the derivations.
        p = (tot == 0 ? 1.0/nderivs : derivpr[t]/tot);
        r = derivix[t];
        # Walk the derivation back through the "rprev" links, bumping
        # the new count of every rule it uses:
        while (r >= 0)
          { s = rsymb[r];
            k = ralt[r];
            r = rprev[r];
            newct[s,k] += p*ct;
          }
      }
  }

function write_new_grammar(ct,pr,   i,s,k,m,def,np,c,p,pcum)
  {
    # Writes the grammar to stdout with counts "ct", per-rule
    # probabilities "pr", and cumulative probabilities, re-attaching
    # the comments collected while the grammar was read.
    for (i = 0; i < nsymb; i++)
      { s = symbol[i];
        printf "%s", comment[s];
        printf "%s:\n", s;
        np = nprod[s];
        pcum = 0;
        for (k = 0; k < np; k++)
          { printf "%s", comment[s,k];
            # Strip trailing separators, keeping at least one character:
            def = (prod[s,k] separator);
            for (m = length(def); (m > 1) && (substr(def,m,1) == separator); m--)
              { def = substr(def,1,m-1); }
            c = ct[s,k];
            p = pr[s,k];
            pcum += p;
            if (countprec > 0)
              { printf "%12.*f %7.5f %7.5f %s\n", countprec, c, p, pcum, def; }
            else
              { printf "%7d %7.5f %7.5f %s\n", int(c+0.5), p, pcum, def; }
          }
      }
    printf "%s", final_comment;
    printf "\n";
    fflush("/dev/stdout");
  }
  
function make_vis(wd)
{
  # Printable form of "wd": the separator character stands in for the
  # empty word.
  if (wd == "") { return separator; }
  return wd;
}

function arg_error(msg)
  {
    # Reports a bad command-line argument: prints "msg" and the usage
    # string to stderr, then aborts the run with exit status 1.
    printf "%s\n", msg > "/dev/stderr";
    printf "usage: %s\n", usage > "/dev/stderr";
    abort = 1;
    exit abort;
  }

function grammar_error(msg)
  {
    # Reports a syntax error at line FNR of the grammar read on stdin,
    # then aborts the run with exit status 1.
    printf "input grammar, line %d: %s\n", FNR, msg > "/dev/stderr";
    abort = 1;
    exit abort;
  }

function word_error(msg)
  {
    # Reports a problem at line "nwords" of the word-count file, then
    # aborts the run with exit status 1.
    printf "file %s, line %d: %s\n", wordcounts, nwords, msg > "/dev/stderr";
    abort = 1;
    exit abort;
  }

function prog_error(msg)
  {
    # Reports an internal inconsistency (a bug in this script), then
    # aborts the run with exit status 1.
    printf "*** program error: %s\n", msg > "/dev/stderr";
    abort = 1;
    exit abort;
  }