// Finds the correspondence matrix between two images, given
// 8 or more pairs of corresponding points.  The matrix {M} is such that
// {M} applied twice maps points of image 1 onto points of image 2.
// Last edited on 2011-12-21 14:03:23 by stolfi

#include <stdlib.h>
#include <math.h>
#include <stdio.h>
#include <assert.h>

#include <r2.h>
#include <r2x2.h>
#include <r3.h>
#include <r3x3.h>
#include <sve_minn.h>

// PROTOTYPES

void image_stitch_read_data(r2_vec_t *p1, r2_vec_t *p2);
  /* Reads from {stdin} a list of point pairs {p1[i],p2[i]},
    where {p1[i]} is in image 1 and {p2[i]} is in image 2.
    Each pair should be in the format "( h1 v1 ) = ( h2 v2 )" where
    {h1,h2} are column indices and {v1,v2} are row indices,
    both counted from 0.  Each pair is echoed to {stderr} as it is read.
    The vectors {p1,p2} are allocated by the procedure. */

void image_stitch_map_point(r2_t *p, r3x3_t *M, r2_t *q);
  /* Maps the Cartesian point {p} through the projective map with
    homogeneous matrix {M}, stores result in {*q}.  The point is lifted
    to the homogeneous row vector {(1,x,y)} (weight first) and mapped
    with {r3x3_map_row}.  Stores {(INF,INF)} if the resulting point is
    at infinity or beyond (weight non-positive or negligible relative
    to the Cartesian coordinates). */

void image_stitch_bar(r2_vec_t *p, r2_t *bar);
  /* Computes the barycenter {bar} of the point set {*p}. */

double image_stitch_translation_scale(r2_vec_t *p1, r2_vec_t *p2);
  /* Computes a suitable scale factor {tscale} for the translation elements
    {M[0,1]} and {M[0,2]} of {M} so that an uncertainty of {eps} in the other elements
    is roughly equivalent to an uncertainty of {eps*tscale} in the translation
    elements.  Currently {tscale} is the largest distance of any point of
    {p1} or {p2} from the barycenter of its own set. */

r3x3_t image_stitch_initial_matrix_guess(r2_vec_t *p1, r2_vec_t *p2);
  /* Computes an affine matrix {M} such that {p1} mapped TWICE by {M}
    matches {p2}, approximately.  The linear part is the square root of
    the least-squares linear map between the centered point sets; the
    translation makes the barycenter of {p1}, mapped twice, coincide
    with the barycenter of {p2}.  Prints diagnostics to {stderr}. */

r3x3_t image_stitch_optimize_matrix(r3x3_t *M0, r2_vec_t *p1, r2_vec_t *p2);
  /* Computes the projective matrix {M} such that {p1} mapped TWICE by {M}
    is as close as possible to {p2}, by nonlinear optimization starting
    from {*M0}.  The quantity actually minimized is the mean squared
    distance between {p1} mapped by {M} and {p2} mapped by the inverse
    of {M}. */

double image_stitch_mean_err_sqr(r2_vec_t *p1, r3x3_t *M1, r2_vec_t *p2, r3x3_t *M2);
  /* Computes the mean squared error between the points of {p1} mapped by
    {M1} and the corresponding points of {p2} mapped by {M2}. */

void image_stitch_output_matrices(r3x3_t *M, r3x3_t *N);
  /* Writes (to stderr) the matrix {M}, its inverse {N}, and their squares. */

void image_stitch_check_matrices(r2_vec_t *p1, r3x3_t *M1, r2_vec_t *p2, r3x3_t *M2);
  /* Compares (to stderr) the points of {p1} mapped by
    {M1} to the corresponding points of {p2} mapped by {M2}.  For each
    pair it prints {p1[k]}, its image, the difference and the distance
    between the two images, the image of {p2[k]}, and {p2[k]}. */

int main(int argc, char **argv)
  {
    /* Load the matched point pairs from {stdin}: */
    r2_vec_t p1, p2;  /* Matched points in image 1 and image 2, respectively. */
    image_stitch_read_data(&p1, &p2);
    int np = p1.ne;
    assert(np == p2.ne);
    if (np < 3)
      { fprintf(stderr, "too few point pairs (%d) for affine mapping\n", np);
        exit(1);
      }

    /* Start from an affine guess, then refine it into a projective map: */
    r3x3_t Mguess = image_stitch_initial_matrix_guess(&p1, &p2);
    r3x3_t Mbest = image_stitch_optimize_matrix(&Mguess, &p1, &p2);
    r3x3_t Mbest_inv;
    r3x3_inv(&Mbest, &Mbest_inv);

    /* Report the matrices, then the per-pair residuals: */
    fprintf(stderr, "final result:\n");
    image_stitch_output_matrices(&Mbest, &Mbest_inv);
    fprintf(stderr, "mapped points:\n");
    image_stitch_check_matrices(&p1, &Mbest, &p2, &Mbest_inv);

    return 0;
  }

void image_stitch_read_data(r2_vec_t *p1, r2_vec_t *p2)
  {
    /* Reads pairs "( x1 y1 ) = ( x2 y2 )" from {stdin} until end of file,
      echoing each pair to {stderr}.  Any text that does not match that
      format is reported as an error, rather than being silently treated
      as the end of the data (which would drop all later pairs). */
    (*p1) = r2_vec_new(20);
    (*p2) = r2_vec_new(20);
    int np = 0; /* Number of data pairs read so far. */

    while (TRUE)
      { double x1, y1, x2, y2;
        int nscan = fscanf(stdin, " ( %lf %lf ) = ( %lf %lf )", &x1, &y1, &x2, &y2);
        if (nscan == EOF)
          { /* Input ended (or failed) before the next pair started: */
            if (ferror(stdin))
              { fprintf(stderr, "error reading input after %d pairs\n", np); exit(1); }
            break;
          }
        /* {nscan == 0} is a matching failure (garbage in the input), not end of file: */
        if (nscan != 4) { fprintf(stderr, "line %d: bad input format (nscan = %d)\n", np+1, nscan); exit(1); }
        fprintf(stderr, " ( %6.1f %6.1f ) = ( %6.1f %6.1f )\n", x1,y1,x2,y2);
        r2_vec_expand(p1, np); p1->e[np] = (r2_t){{x1, y1}};
        r2_vec_expand(p2, np); p2->e[np] = (r2_t){{x2, y2}};
        np++;
      }
    r2_vec_trim(p1, np);
    r2_vec_trim(p2, np);
  }
  

void image_stitch_map_point(r2_t *p, r3x3_t *M, r2_t *q)
  {
    /* Lift {p} to the homogeneous row vector {(1,x,y)} and apply {M}: */
    r3_t hin = (r3_t){{ 1, p->c[0], p->c[1] }};
    r3_t hout;
    r3x3_map_row(&hin, M, &hout);

    /* Project back to Cartesian, unless the weight is non-positive or
      negligible compared with the other two coordinates: */
    double wt = hout.c[0];
    double mag = fmax(fabs(hout.c[1]), fabs(hout.c[2]));
    (*q) = (wt <= mag*1e-200)
      ? (r2_t){{ INF, INF }}
      : (r2_t){{ hout.c[1]/wt, hout.c[2]/wt }};
  }
    

double image_stitch_mean_err_sqr(r2_vec_t *p1, r3x3_t *M1, r2_vec_t *p2, r3x3_t *M2)
  {
    int np = p1->ne;
    assert(np == p2->ne);

    /* Sum, over all pairs, the squared distance between {M1(p1[i])} and {M2(p2[i])}: */
    double total = 0.0;
    for (int i = 0; i < np; i++)
      { r2_t img1, img2;
        image_stitch_map_point(&(p1->e[i]), M1, &img1);
        image_stitch_map_point(&(p2->e[i]), M2, &img2);
        total += r2_dist_sqr(&img1, &img2);
      }
    return total/np;
  }

r3x3_t image_stitch_initial_matrix_guess(r2_vec_t *p1, r2_vec_t *p2)
  {
    int np = p1->ne;
    assert(np == p2->ne);

    fprintf(stderr, "--- computing the affine matrix ---\n");

    /* Barycenters of the two point sets: */
    r2_t bar1, bar2;
    image_stitch_bar(p1, &bar1);
    r2_gen_print(stderr, &bar1, "%10.4f", "  bar1 = ( ", " ", " )\n");
    image_stitch_bar(p2, &bar2);
    r2_gen_print(stderr, &bar2, "%10.4f", "  bar2 = ( ", " ", " )\n");

    /* Least-squares linear map {S} with {d2 ~ d1*S} (row-vector convention),
      from the moment matrix {Mom = SUM d1^T*d1} and the cross-moment matrix
      {Crs = SUM d1^T*d2} of the centered points {d1 = p1[k]-bar1},
      {d2 = p2[k]-bar2}: */
    r2x2_t Mom, Crs;
    r2x2_zero(&Mom);
    r2x2_zero(&Crs);
    for (int k = 0; k < np; k++)
      { r2_t d1, d2;
        r2_sub(&(p1->e[k]), &bar1, &d1);
        r2_sub(&(p2->e[k]), &bar2, &d2);
        for (int r = 0; r < 2; r++)
          { for (int s = 0; s < 2; s++)
              { Mom.c[r][s] += d1.c[r]*d1.c[s];
                Crs.c[r][s] += d1.c[r]*d2.c[s];
              }
          }
      }
    r2x2_t MomInv, S;
    r2x2_inv(&Mom, &MomInv);
    r2x2_mul(&MomInv, &Crs, &S);
    fprintf(stderr, "  linear matrix:\n");
    r2x2_gen_print(stderr, &S, "%13.6e", "", "\n", "\n", "    [ ", " ", " ]");

    /* Half-step linear map {R} with {R*R = S}: */
    r2x2_t R;
    r2x2_sqrt(&S, &R);
    fprintf(stderr, "  square root of linear matrix:\n");
    r2x2_gen_print(stderr, &R, "%13.6e", "", "\n", "\n", "    [ ", " ", " ]");

    /* Translation {u} such that {x -> x*R + u}, applied twice, takes {bar1}
      to {bar2}:  {bar1*R*R + u*(R+I) = bar2}, so
      {u = (bar2 - bar1*R*R)*(R+I)^{-1}}. */
    r2x2_t RplusI = R;
    RplusI.c[0][0] += 1;
    RplusI.c[1][1] += 1;
    r2x2_t RplusIinv;
    r2x2_inv(&RplusI, &RplusIinv);
    r2_t bR, bRR, gap, u;
    r2x2_map_row(&bar1, &R, &bR);
    r2x2_map_row(&bR, &R, &bRR);
    r2_sub(&bar2, &bRR, &gap);
    r2x2_map_row(&gap, &RplusIinv, &u);
    r2_gen_print(stderr, &u, "%10.4f", "  u =   ( ", " ", " )\n");

    /* Homogeneous affine matrix, weight first, row-vector convention:
      {(1,x,y)*M = (1, (x,y)*R + u)}. */
    r3x3_t M;
    M.c[0][0] = 1.0;  M.c[0][1] = u.c[0];     M.c[0][2] = u.c[1];
    M.c[1][0] = 0.0;  M.c[1][1] = R.c[0][0];  M.c[1][2] = R.c[0][1];
    M.c[2][0] = 0.0;  M.c[2][1] = R.c[1][0];  M.c[2][2] = R.c[1][1];
    return M;
  }
        
void image_stitch_bar(r2_vec_t *p, r2_t *bar)
  {
    /* Sum all points of {*p}, then divide the sum by their count: */
    r2_t acc = (r2_t){{ 0, 0 }};
    r2_t *pt = p->e;
    r2_t *stop = p->e + p->ne;
    while (pt < stop)
      { r2_add(&acc, pt, &acc);
        pt++;
      }
    double inv_count = 1.0/((double)p->ne);
    r2_scale(inv_count, &acc, bar);
  }

r3x3_t image_stitch_optimize_matrix(r3x3_t *M0, r2_vec_t *p1, r2_vec_t *p2)
  {
    /* Refines the initial guess {*M0} by nonlinear minimization (with
      {sve_minn_iterate}) of the mean squared distance between each point
      {p1[k]} mapped by {M} and the corresponding {p2[k]} mapped by the
      inverse of {M}.  The returned matrix has {M[0,0] == 1}.

      With {M[0,0]} fixed, a projective map has {nx = 8} free parameters.
      Each point pair contributes 2 scalar equations, so at least {nx/2}
      pairs are needed.  With fewer pairs the optimization is skipped and
      {*M0} is returned unchanged. */

    int np = p1->ne;
    assert(np == p2->ne);

    /* Choose scale factors for the matrix parameters: */
    double tscale = image_stitch_translation_scale(p1, p2); /* Translation. */
    double pscale = 1.0; /* Projective distortion. */
    double dscale = 1.0; /* Diagonal terms of linear map. */
    double sscale = 1.0; /* Non-diagonal terms of linear map. */

    double maxErr = 0.5;  /* Max root-mean-square position error. */

    int nx = 8; /* Number of packed parameters. */
    int minPairs = nx/2; /* Min number of pairs that can determine {nx} parameters. */

    auto void unpack_parameters(int nx, double x[], r3x3_t *M);
      /* Unpacks the parameter vector {x[0..nx-1]} into the
        homogeneous projective matrix {M}, undoing the scale factors.
        The matrix will have {M[0,0] == 1}. Requires {nx==8}. */

    auto void pack_parameters(r3x3_t *M, int nx, double x[]);
      /* Packs the homogeneous projective matrix {M} into the parameter
        vector {x[0..nx-1]}.  All elements are divided by {M[0,0]}
        (which must be positive) and by their scale factors, so the
        packed matrix is implicitly normalized to {M[0,0] == 1}.
        Requires {nx==8}. */

    auto double goalf(int nx, double x[]);
      /* Computes the mean squared error in the positions of the
         mapped points, given the packed parameters {x[0..nx-1]}. */

    auto bool_t is_ok(int nx, double x[], double Fx);
      /* Returns true if the squared error {Fx} is small enough. */

    r3x3_t M; /* Optimized matrix. */

    if (np >= minPairs)
      { /* Nonlinear optimization to compute {M}: */
        double x[nx]; /* Initial guess and final optimum parameters. */
        double dMax = 1.0;              /* Nominal uncertainty in {xk}. */
        double rIni = 0.250*dMax;       /* Initial probe radius. */
        double rMin = 0.5/tscale;       /* Minimum probe radius (about half a pixel of translation). */
        double rMax = 0.500*dMax;       /* Maximum probe radius. */
        int ns = (nx+1)*(nx+2)/2;       /* Number of sampling points. */
        int maxIters = 3;               /* Max number of major iterations. */
        int maxEvals = maxIters*(ns+1); /* Max number of goal evaluations. */
        bool_t debug = TRUE;

        pack_parameters(M0, nx, x);
        sve_minn_iterate(nx, &goalf, &is_ok, x, dMax, rIni, rMin, rMax, maxEvals, debug);
        unpack_parameters(nx, x, &M);
      }
    else
      { fprintf(stderr, "too few point pairs (%d) for projective map, using affine map\n", np);
        M = (*M0);
      }

    return M;

    /* --- internal procs ----------------------------------------------------- */

    void pack_parameters(r3x3_t *M, int nx, double x[])
      { assert(nx == 8);
        int i, j;
        int k = 0;
        double w = M->c[0][0];
        assert(w > 1.0e-15);
        for (i = 0; i < 3; i++)
          { for (j = 0; j < 3; j++)
              {
                if ((i == 0) && (j == 0))
                  { assert(M->c[i][j] == w); }
                else if (i == 0)
                  { x[k] = M->c[i][j]/w/tscale; k++; }
                else if (j == 0)
                  { x[k] = M->c[i][j]/w/pscale; k++; }
                else if (i == j)
                  { x[k] = M->c[i][j]/w/dscale; k++; }
                else
                  { x[k] = M->c[i][j]/w/sscale; k++; }
              }
          }
      }

    void unpack_parameters(int nx, double x[], r3x3_t *M)
      { assert(nx == 8);
        int i, j;
        int k = 0;
        for (i = 0; i < 3; i++)
          { for (j = 0; j < 3; j++)
              {
                if ((i == 0) && (j == 0))
                  { M->c[i][j] = 1.0; }
                else if (i == 0)
                  { M->c[i][j] = x[k]*tscale; k++; }
                else if (j == 0)
                  { M->c[i][j] = x[k]*pscale; k++; }
                else if (i == j)
                  { M->c[i][j] = x[k]*dscale; k++; }
                else
                  { M->c[i][j] = x[k]*sscale; k++; }
              }
          }
      }

    double goalf(int nx, double x[])
      {
        /* Map {p1} forward by {M} and {p2} backward by {M^{-1}}; the two
          images should coincide if {M} applied twice takes {p1} to {p2}. */
        r3x3_t M1, M2;
        unpack_parameters(nx, x, &M1);
        r3x3_inv(&M1, &M2);
        double esq = image_stitch_mean_err_sqr(p1, &M1, p2, &M2);
        return esq;
      }

    bool_t is_ok(int nx, double x[], double Fx)
      {
        return Fx < maxErr*maxErr;
      }
  }

double image_stitch_translation_scale(r2_vec_t *p1, r2_vec_t *p2)
  {
    int np = p1->ne;
    assert(np == p2->ne);

    /* The scale is the largest distance of any point from the barycenter
      of its own set, taken over both sets: */
    r2_t cen1, cen2;
    image_stitch_bar(p1, &cen1);
    image_stitch_bar(p2, &cen2);

    double far2 = 0.0; /* Largest squared distance seen so far. */
    for (int i = 0; i < np; i++)
      { r2_t off1, off2;
        r2_sub(&(p1->e[i]), &cen1, &off1);
        r2_sub(&(p2->e[i]), &cen2, &off2);
        far2 = fmax(far2, r2_norm_sqr(&off1));
        far2 = fmax(far2, r2_norm_sqr(&off2));
      }
    return sqrt(far2);
  }

void image_stitch_output_matrices(r3x3_t *M, r3x3_t *N)
  {
    /* Writes to {stderr} the matrix {M}, its inverse {N}, and the squares
      {M*M} and {N*N}, each preceded by a label line. */
    fprintf(stderr, "matrix:\n");
    r3x3_gen_print(stderr, M, "%13.6e", "\n  ", "\n  ", "\n\n", "[ ", " ", " ]");

    fprintf(stderr, "matrix inverse:\n");
    r3x3_gen_print(stderr, N, "%13.6e", "\n  ", "\n  ", "\n\n", "[ ", " ", " ]");

    fprintf(stderr, "matrix squared:\n");
    r3x3_t Msq; r3x3_mul(M, M, &Msq);
    r3x3_gen_print(stderr, &Msq, "%13.6e", "\n  ", "\n  ", "\n\n", "[ ", " ", " ]");

    fprintf(stderr, "matrix inverse squared:\n");
    r3x3_t Nsq; r3x3_mul(N, N, &Nsq);
    /* Print the inverse squared (previously {Msq} was printed here by mistake): */
    r3x3_gen_print(stderr, &Nsq, "%13.6e", "\n  ", "\n  ", "\n\n", "[ ", " ", " ]");
  }

void image_stitch_check_matrices(r2_vec_t *p1, r3x3_t *M1, r2_vec_t *p2, r3x3_t *M2)
  {
    int np = p1->ne;
    assert(np == p2->ne);

    /* One line per pair:  "p1 -> M1(p1)  diff dist  M2(p2) <- p2". */
    for (int i = 0; i < np; i++)
      { r2_t *src1 = &(p1->e[i]);
        r2_t *src2 = &(p2->e[i]);
        r2_t img1, img2;
        image_stitch_map_point(src1, M1, &img1);
        image_stitch_map_point(src2, M2, &img2);
        r2_t delta;
        r2_sub(&img1, &img2, &delta);
        double dist = r2_dist(&img1, &img2);

        r2_gen_print(stderr, src1, "%8.2f", "( ", " ", " )");
        fputs(" -> ", stderr);
        r2_gen_print(stderr, &img1, "%8.2f", "( ", " ", " )");
        fputs("  ", stderr);
        r2_gen_print(stderr, &delta, "%+6.2f", "( ", " ", " )");
        fprintf(stderr, "%6.2f", dist);
        fputs("  ", stderr);
        r2_gen_print(stderr, &img2, "%8.2f", "( ", " ", " )");
        fputs(" <- ", stderr);
        r2_gen_print(stderr, src2, "%8.2f", "( ", " ", " )");
        fputs("\n", stderr);
      }
  }


