#! /usr/bin/gawk -f
# Last edited on 2004-10-29 09:44:19 by stolfi

BEGIN {
  # {abort} is the pending exit status; negative means "no error so far".
  abort = -1;
  # All "-v" assignments are optional (see the defaults below):
  usage = ( ARGV[0] " [-v includeBlack=BOOL] [-v cr=NUM] [-v cg=NUM] [-v cb=NUM] < IMG.hist" );
  
  # Reads a color image histogram, as produced by "ppmhist".
  # Finds the plane that minimizes the mean square deviation 
  # of all pixels. Uses {(cr, cg, cb)} as the starting
  # guess for the plane's coefficients. If {includeBlack} is TRUE, 
  # includes zero pixels in the analysis, otherwise omits them.
  
  # Defaults for parameters that were not given with "-v":
  if (cr == "") { cr = 0; }
  if (cg == "") { cg = 0; }
  if (cb == "") { cb = 0; }
  if (includeBlack == "") { includeBlack = 0; }
  
  # Accumulators filled in by the histogram-entry rule:
  SR = 0; SG = 0; SB = 0; # Sum of all pixels (each color weighted by its count)
  SN = 0; # Total number of pixels

  nc = 0; # Number of distinct colors
  # Indexed [0..nc-1]:
  split("", R); split("", G); split("", B); # {R[i],G[i],B[i]} are the components of color {i}.
  split("", N); # {N[i]} is the number of pixels with color {i}.
  
}

# Once an error status has been recorded in {abort}, stop reading input:
abort >= 0 { exit abort }

# Histogram entry: any line whose first non-blank character is a digit.
# Fields 1-3 are taken as the R, G, B components and field 5 as the
# pixel count; field 4 is ignored (presumably the luminance column
# printed by "ppmhist" -- confirm against the ppmhist version in use).
/^ *[0-9]/{
  if (NF != 5)
    { # Malformed entry: report it and stop.  The END rule sees
      # {abort >= 0} and exits with that status without computing anything.
      printf "%s:%d: ** bad NF = %d\n", FILENAME, FNR, NF > "/dev/stderr";
      abort = 1; exit abort;
    }
  r = $1; g = $2; b = $3; n = $5;
  # Skip pure black unless the user asked to include it:
  if ((! includeBlack) && (r+0 == 0) && (g+0 == 0) && (b+0 == 0)) { next; }
  # Remember the color and accumulate the count-weighted sums:
  R[nc] = r; G[nc] = g; B[nc] = b; N[nc] = n;
  SR += n*r; SG += n*g; SB += n*b; SN += n;
  nc++;
}

END {
  # Propagate any error status recorded while reading the input:
  if (abort >= 0) { exit abort; }

  # With no counted pixels the barycenter is undefined, and {SR/SN}
  # below would be a fatal division by zero:
  if (SN <= 0)
    { printf "** no pixels to analyze (empty histogram, or only black pixels)\n" > "/dev/stderr";
      exit 1;
    }

  # Compute barycenter of colors:
  CR = SR/SN; CG = SG/SN; CB = SB/SN;
  # Compute moment matrix (count-weighted second moments about the barycenter):
  split("", M);
  for (i = 0; i < nc; i++)
    { r = R[i] - CR; g = G[i] - CG; b = B[i] - CB; wt = N[i]/SN;
      M[0,0] += wt*r*r; M[0,1] += wt*r*g; M[0,2] += wt*r*b;
      M[1,1] += wt*g*g; M[1,2] += wt*g*b;
      M[2,2] += wt*b*b;
    }
  # Fill in the lower triangle by symmetry:
  M[1,0] = M[0,1]; 
  M[2,0] = M[0,2]; M[2,1] = M[1,2];

  # Build unit normal vector {u} from the user's initial guess:
  split("", u);
  u[0] = cr+0; u[1] = cg+0; u[2] = cb+0;

  # Make sure initial guess is not all zeros:
  if ((u[0] == 0) && (u[1] == 0) && (u[2] == 0)) { u[1] = 1.0; }
  normalize(u);
  
  # {optimize} refines {u} and fills in {v} and {w}:
  split("", v);
  split("", w);

  sdev = optimize(u,v,w,M);
  
  # Barycenter:
  split ("", o);
  o[0] = CR; o[1] = CG; o[2] = CB;
  printf "o = %05.1f %05.1f %05.1f / 255\n", o[0], o[1], o[2];
  # {u,v,w} coordinates of barycenter:
  uC = u[0]*o[0] + u[1]*o[1] + u[2]*o[2];
  vC = v[0]*o[0] + v[1]*o[1] + v[2]*o[2];
  wC = w[0]*o[0] + w[1]*o[1] + w[2]*o[2];
  # Unit normal vector and two orthogonal vectors on the plane:
  printf "u = %0+6.4f %0+6.4f %0+6.4f  uC = %0+6.1f\n", u[0], u[1], u[2], uC;
  printf "v = %0+6.4f %0+6.4f %0+6.4f  vC = %0+6.1f\n", v[0], v[1], v[2], vC;
  printf "w = %0+6.4f %0+6.4f %0+6.4f  wC = %0+6.1f\n", w[0], w[1], w[2], wC;
  # Root mean square deviation
  printf "sdev = %06.4f\n", sdev;
  # Neutral color: the point where the gray diagonal from black
  # (0,0,0) to white (255,255,255) crosses the plane, if it does.
  split("", bg);
  uwh = 255*(u[0]+u[1]+u[2]) - uC; # Signed distance from white to the plane.
  ubk = -uC;                       # Signed distance from black to the plane.
  if ((uwh*ubk < 0) && (uwh != 0))
    { # White and black lie on opposite sides: interpolate between them.
      twh = abs(uwh); tbk = abs(ubk);
      for (r = 0; r < 3; r++) 
        { bg[r] = (tbk*255 + twh*0)/(tbk + twh); }
      printf "bg = %05.1f %05.1f %05.1f / 255\n", bg[0], bg[1], bg[2];
    }
  # Projections of white and black onto the plane, along {u}:
  split("", pwh); split("", pbk); 
  for (r = 0; r < 3; r++) 
    { pwh[r] = 255 - uwh*u[r]; pbk[r] = - ubk*u[r]; }
  printf "pwh = %05.1f %05.1f %05.1f / 255\n", pwh[0], pwh[1], pwh[2];
  printf "pbk = %05.1f %05.1f %05.1f / 255\n", pbk[0], pbk[1], pbk[2];
  # Pure hue colors: walk the six edges of the hue hexagon
  # (red-yellow, yellow-green, green-cyan, cyan-blue, blue-magenta,
  # magenta-red) and print the point where each edge crosses the plane.
  split("", ha); ha[0] = 255; ha[1] = 000; ha[2] = 000;
  split("", hb); hb[0] = 255; hb[1] = 255; hb[2] = 000;
  split("", h);
  for (k = 0; k < 6; k++)
    { # Signed distances of the two edge endpoints to the plane:
      ta = ha[0]*u[0] + ha[1]*u[1] + ha[2]*u[2] - uC;
      tb = hb[0]*u[0] + hb[1]*u[1] + hb[2]*u[2] - uC;
      if ((ta*tb <= 0.0) && (tb != 0.0))
        { ta = abs(ta); tb = abs(tb);
          for (r = 0; r < 3; r++) { h[r] = (tb*ha[r] + ta*hb[r])/(ta+tb); }
          printf "hue = %05.1f %05.1f %05.1f / 255\n", h[0], h[1], h[2];
        }
      # Cycle {ha,hb} to next pure hue pair: {ha} becomes the old {hb},
      # and {hb} becomes the old {ha} with its components rotated one place.
      ht = ha[2];
      for (r = 0; r < 3; r++) { hs = ha[r]; ha[r] = hb[r]; hb[r] = ht; ht = hs; }
    }
  # Plane equation, with coefficients scaled by 1000 (the "set NAME = ..."
  # lines look like csh assignments -- presumably meant to be sourced by
  # a calling script; confirm with the caller):
  printf "set cr = \"%0+6.1f\"\n", u[0]*1000;
  printf "set cg = \"%0+6.1f\"\n", u[1]*1000;
  printf "set cb = \"%0+6.1f\"\n", u[2]*1000;
  printf "set ct = \"%0+6.1f\"\n", -uC*1000;
}

function optimize(u,v,w,M,  \
  un,vn,r,r1,r2,rmax,s,pi, \
  phi,tau,ctau,stau,taumag,taured, \
  bet,cbet,sbet, \
  iter,maxiter,nfail,maxfail, \
  D2,D2min \
)
{ 
  # Updates the unit vector {u} to minimize the mean-square error
  # quadric {u M u'} where { u' = transpose(u)}. Also computes two
  # unit vectors {v,w} orthogonal to {u} and to each other. Returns
  # the root-mean-square deviation.
  #
  # Parameters: {u} (in/out) is the current unit normal, indexed [0..2];
  # {v,w} (out) are arrays that receive the two in-plane directions;
  # {M} (in) is the 3x3 moment matrix, indexed {M[r,s]}. The remaining
  # names in the parameter list are local variables.
  #
  # Method: randomized-looking but deterministic direct search. At each
  # step {u} is tilted by angle {tau} towards a direction {vn} obtained
  # by spinning {v} around {u}; the tilt is kept if the quadric
  # decreases. A progress trace is written to "/dev/stderr".
  
  split("", un);
  split("", vn);

  pi = 3.1415926; # Close enough...
  phi = (sqrt(5)-1)/2; # Golden ratio
  taumag = 1.2;                     # Growth factor of {tau} after a success.
  taured = exp(-log(taumag)*phi);   # Shrink factor of {tau} after a failure.
  maxiter = 300;  # Max total iterations
  maxfail = 30;  # Max consecutive failed iterations

  # Find dominant coordinate {rmax} of {u}
  rmax = 0;
  for (r = 1; r < 3; r++)
    { if (abs(u[r]) > abs(u[rmax])) { rmax = r; } }
  if (u[rmax] == 0) { u[0] = 1; rmax = 0; }

  # Find an orthogonal unit vector {v} on plane
  r1 = (rmax + 1) % 3;
  r2 = (rmax + 2) % 3;
  v[rmax] = -u[r1]; v[r1] = u[rmax]; v[r2] = 0;
  normalize(v);
  
  # Rotation angle
  bet = 2*pi*phi; # In radians
  cbet = cos(bet); sbet = sin(bet);

  # Start from the deviation of the initial guess, so that the
  # search works whatever the magnitude of the entries of {M}:
  D2min = 0;
  for (r = 0; r < 3; r++)
    for (s = 0; s < 3; s++)
      { D2min += u[r]*M[r,s]*u[s]; }

  # Main loop:
  nfail = 0;        # Num consecutive failed iterations.
  tau = 0.001*pi/4; # Tilt angle (initial).
  printf "[" > "/dev/stderr";
  for (iter = 0; 1; iter++)
    { # Find the third orthogonal direction {w}:
      w[0] = u[1]*v[2] - u[2]*v[1];
      w[1] = u[2]*v[0] - u[0]*v[2];
      w[2] = u[0]*v[1] - u[1]*v[0];
      normalize(w);
      
      # Stop after enough iterations:
      if ((iter >= maxiter) || (nfail > maxfail))
        { printf "]\n" > "/dev/stderr"; 
          # Guard against a tiny negative value due to roundoff:
          return sqrt(D2min > 0 ? D2min : 0);
        } 

      # Rotate {v} by {bet} around {u} towards {w}:
      vn[0] = cbet*v[0] + sbet*w[0];
      vn[1] = cbet*v[1] + sbet*w[1];
      vn[2] = cbet*v[2] + sbet*w[2];
      normalize(vn);
      # Rotate {u} by {tau} towards {vn}:
      ctau = cos(tau); stau = sin(tau);
      un[0] = ctau*u[0] + stau*vn[0];
      un[1] = ctau*u[1] + stau*vn[1];
      un[2] = ctau*u[2] + stau*vn[2];
      normalize(un);
      # Compute total squared deviation in direction {un}:
      D2 = 0;
      for (r = 0; r < 3; r++)
        for (s = 0; s < 3; s++)
          { D2 += un[r]*M[r,s]*un[s]; }
      # Did it improve?
      if (D2 < D2min)
        { # Improved - update {D2min}
          printf "%6.4f", D2 > "/dev/stderr";
          D2min = D2;
          # Bend {vn} too by {tau} away from {u}, update {v = vn, u = un}.
          # The minus sign keeps the new {v} orthogonal to the new {u}:
          # {(ctau*u + stau*vn).(ctau*vn - stau*u) = 0}. Note that the
          # old {u[r]} must be read before it is overwritten.
          for (r = 0; r < 3; r++)
            { v[r] = ctau*vn[r] - stau*u[r]; u[r] = un[r]; }
          normalize(v);
          # Increase {tau}, up to {pi/4}:
          tau = taumag*tau; if (tau > pi/4) { tau = pi/4; }
          # Reset failed test count:
          nfail = 0;
          printf "!" > "/dev/stderr";
        }
      else
        { # Did not improve - update only {v = vn}
          for (r = 0; r < 3; r++) { v[r] = vn[r]; }
          # Decrease {tau}
          tau = taured*tau;
          # Count one more failed test
          nfail++;
          printf ":" > "/dev/stderr";
        }
    }
}

function abs(x)
{
  # Returns the absolute value of {x}.
  if (x < 0) { return -x; }
  return x;
}

function normalize(x,  m)
{
  # Scales the 3-vector {x[0..2]} in place by the reciprocal of its
  # Euclidean length {m}.
  m = x[0]*x[0];
  m = m + x[1]*x[1];
  m = m + x[2]*x[2];
  m = sqrt(m);
  x[0] = x[0]/m;
  x[1] = x[1]/m;
  x[2] = x[2]/m;
}