/* See pst_height_map.h */
/* Last edited on 2010-05-04 03:38:52 by stolfi */

#include <assert.h>
#include <affirm.h>
#include <math.h>

#include <pst_height_map.h>
#include <float_image.h>
#include <float_image_mscale.h>

/* Returns a newly allocated single-channel image with {NXI} columns and
  {NYI} rows, obtained by expanding the single-channel height map {JZ}.
  Only {expOrder == 2} is accepted.  The size of {JZ} must be
  {NXI/2 + 1} by {NYI/2 + 1}.

  Each output sample {(x,y)} is twice the mean of the four {JZ} samples
  at columns {x/2,(x+1)/2} and rows {y/2,(y+1)/2}; when two of those
  indices coincide, the shared sample is simply counted more than once. */
float_image_t *pst_height_map_expand_old(float_image_t *JZ, int NXI, int NYI, int expOrder)
  {
    /* !!! Generalize and move to {float_image_mscale.h} !!! */

    demand(expOrder == 2, "{expOrder} other than 2 is not supported yet");

    /* Check the coarse image against the requested output size: */
    assert(JZ->sz[0] == 1);
    int NXJ = JZ->sz[1];
    assert(NXI/2 + 1 == NXJ);
    int NYJ = JZ->sz[2];
    assert(NYI/2 + 1 == NYJ);

    float_image_t *IZ = float_image_new(1, NXI, NYI);
    for (int row = 0; row < NYI; row++)
      { for (int col = 0; col < NXI; col++)
          { /* Low and high coarse indices that bracket this fine sample: */
            int xLo = col/2;
            int xJ1 = (col + 1)/2;
            int yLo = row/2;
            int yJ1 = (row + 1)/2;
            assert(xJ1 < NXJ);
            assert(yJ1 < NYJ);
            int sameCol = (xJ1 == xLo);
            int sameRow = (yJ1 == yLo);

            /* Fetch the four coarse samples, re-using the base one
              whenever an index pair coincides: */
            double zBase = float_image_get_sample(JZ, 0, xLo, yLo);
            double zNextRow = zBase;
            if (! sameRow) { zNextRow = float_image_get_sample(JZ, 0, xLo, yJ1); }
            double zNextCol = zBase;
            if (! sameCol) { zNextCol = float_image_get_sample(JZ, 0, xJ1, yLo); }
            double zDiag = zBase;
            if ((! sameCol) || (! sameRow)) { zDiag = float_image_get_sample(JZ, 0, xJ1, yJ1); }

            /* Mean of the four samples (sum/4) times the {Z} scale factor 2: */
            float zOut = (float)((zBase + zNextRow + zNextCol + zDiag)/2);
            float_image_set_sample(IZ, 0, col, row, zOut);
          }
      }
    return IZ;
  }

/* Returns a newly allocated single-channel image with {NXI} columns and
  {NYI} rows, obtained by expanding the single-channel height map {JZ}.
  Only {expOrder == 2} is accepted.  The size of {JZ} must be
  {NXI/2 + 1} by {NYI/2 + 1}.

  Each output sample {(x,y)} is twice the weighted mean of the four {JZ}
  samples at columns {x/2,(x+1)/2} and rows {y/2,(y+1)/2}.  The weights
  are read from channel 0 of {JW} at the same positions; if {JW} is
  {NULL}, or if the four weights add up to zero, all four weights are
  taken to be 1. */
float_image_t *pst_height_map_expand(float_image_t *JZ, float_image_t *JW, int NXI, int NYI, int expOrder)
  {
    /* !!! Generalize and move to {float_image_mscale.h} !!! */

    demand(expOrder == 2, "{expOrder} other than 2 is not supported yet");

    /* Check the coarse image against the requested output size: */
    assert(JZ->sz[0] == 1);
    int NXJ = JZ->sz[1];
    assert(NXI/2 + 1 == NXJ);
    int NYJ = JZ->sz[2];
    assert(NYI/2 + 1 == NYJ);

    float_image_t *IZ = float_image_new(1, NXI, NYI);
    for (int row = 0; row < NYI; row++)
      { for (int col = 0; col < NXI; col++)
          { /* Low and high coarse indices that bracket this fine sample: */
            int xLo = col/2;
            int xJ1 = (col + 1)/2;
            int yLo = row/2;
            int yJ1 = (row + 1)/2;
            assert(xJ1 < NXJ);
            assert(yJ1 < NYJ);
            int sameCol = (xJ1 == xLo);
            int sameRow = (yJ1 == yLo);
            int sameBoth = (sameCol && sameRow);

            /* Heights of the four coarse samples: */
            double zBase = float_image_get_sample(JZ, 0, xLo, yLo);
            double zNextRow = (sameRow ? zBase : float_image_get_sample(JZ, 0, xLo, yJ1));
            double zNextCol = (sameCol ? zBase : float_image_get_sample(JZ, 0, xJ1, yLo));
            double zDiag = (sameBoth ? zBase : float_image_get_sample(JZ, 0, xJ1, yJ1));

            /* Weights of the same four samples (uniform when there is no weight map): */
            double wBase = 1.0;
            double wNextRow = 1.0;
            double wNextCol = 1.0;
            double wDiag = 1.0;
            if (JW != NULL)
              { wBase = float_image_get_sample(JW, 0, xLo, yLo);
                wNextRow = (sameRow ? wBase : float_image_get_sample(JW, 0, xLo, yJ1));
                wNextCol = (sameCol ? wBase : float_image_get_sample(JW, 0, xJ1, yLo));
                wDiag = (sameBoth ? wBase : float_image_get_sample(JW, 0, xJ1, yJ1));
              }

            /* A zero total weight would make the mean undefined; use a plain mean instead: */
            double wTotal = wBase + wNextRow + wNextCol + wDiag;
            if (wTotal == 0)
              { wBase = 1.0; wNextRow = 1.0; wNextCol = 1.0; wDiag = 1.0;
                wTotal = 4.0;
              }

            /* Weighted mean of the samples times the {Z} scale factor 2: */
            float zOut = (float)(2*((zBase*wBase) + (zNextRow*wNextRow) + (zNextCol*wNextCol) + (zDiag*wDiag))/wTotal);
            float_image_set_sample(IZ, 0, col, row, zOut);
          }
      }
    return IZ;
  }


/* Returns a reduced version of the height map {IZ} with
  {IZ->sz[1]/2 + 1} columns and {IZ->sz[2]/2 + 1} rows.  The reduction
  itself is delegated to {float_image_mscale_shrink}, called with no
  weight image, offsets {(avgWidth-1)/2} on both axes, and averaging
  width {avgWidth}.  Afterwards every sample of channel 0 of the result
  is divided by 2, so that the heights follow the change of scale. */
float_image_t *pst_height_map_shrink(float_image_t *IZ, int avgWidth)
  {
    int colsOut = IZ->sz[1]/2 + 1;
    int rowsOut = IZ->sz[2]/2 + 1;
    int offset = (avgWidth - 1)/2;

    float_image_t *JZ = float_image_mscale_shrink(IZ, NULL, colsOut, rowsOut, offset, offset, avgWidth);

    /* Rescale the heights to match the halved horizontal resolution: */
    for (int col = 0; col < JZ->sz[1]; col++)
      { for (int row = 0; row < JZ->sz[2]; row++)
          { double z = float_image_get_sample(JZ, 0, col, row);
            float_image_set_sample(JZ, 0, col, row, (float)(z/2));
          }
      }
    return JZ;
  }

/* Compares the single-channel height maps {AZ} and {BZ}, which must
  have the same size, and returns a newly allocated single-channel
  image {EZ} of that size with the differences {AZ - BZ}.

  If {W} is not {NULL} it must be a single-channel image of the same
  size; its samples are used as the weights of the corresponding pixels
  in all averages below.  If {W} is {NULL}, every pixel has weight 1.

  If {zero_mean} is true, the weighted mean of {AZ - BZ} is subtracted
  from every sample of {EZ}.

  The following statistics are stored through the pointers that are not
  {NULL}:

    {*sAZP}  weighted RMS deviation of {AZ} from its weighted mean;
    {*sBZP}  weighted RMS deviation of {BZ} from its weighted mean;
    {*sEZP}  weighted RMS value of the samples stored in {EZ};
    {*sreP}  {sEZ} divided by {sqrt((sAZ^2 + sBZ^2)/2)}.

  When the total weight is zero, the three RMS values are reported as 0
  instead of the NaN that {0/0} would give.  The relative error is
  reported as 0 whenever {sEZ} is zero, even if the denominator is zero
  too; a nonzero {sEZ} over a zero denominator still yields infinity. */
float_image_t *pst_height_map_compare
  ( float_image_t *AZ,
    float_image_t *BZ,
    float_image_t *W,
    bool_t zero_mean,
    double *sAZP,
    double *sBZP,
    double *sEZP,
    double *sreP
  )
  { 
    assert(AZ->sz[0] == 1);
    assert(BZ->sz[0] == 1);
    int NX = AZ->sz[1]; assert(BZ->sz[1] == NX);
    int NY = AZ->sz[2]; assert(BZ->sz[2] == NY);
    
    if (W != NULL)
      { assert(W->sz[0] == 1);
        assert(W->sz[1] == NX);
        assert(W->sz[2] == NY);
      }
      
    /* First pass: compute the weighted mean values of {AZ,BZ,EZ}: */
    double sum_AZW = 0;
    double sum_BZW = 0;
    double sum_EZW = 0;
    double sum_W = 0;
    int x, y;
    for(y = 0; y < NY; y++)
      { for(x = 0; x < NX; x++)
          { /* Get relevant samples from original images: */
            double vA = float_image_get_sample(AZ, 0, x, y);
            double vB = float_image_get_sample(BZ, 0, x, y);
            double vW = (W != NULL ? float_image_get_sample(W, 0, x, y) : 1.0);
            double vE = vA - vB;
            /* Accumulate values: */
            sum_AZW += vA*vW;
            sum_BZW += vB*vW;
            sum_EZW += vE*vW;
            sum_W += vW;
          }
      }
    /* With zero total weight the means are undefined; use fixed defaults: */
    bool_t no_weight = (sum_W == 0);
    double avgA = (no_weight ? 0.5 : sum_AZW/sum_W);
    double avgB = (no_weight ? 0.5 : sum_BZW/sum_W);
    double avgE = (no_weight ? 0.0 : sum_EZW/sum_W);
      
    /* Second pass: fill {EZ} and accumulate the weighted squares: */
    float_image_t *EZ = float_image_new(1, NX, NY);
    double sum_AdZ2W = 0.0;
    double sum_BdZ2W = 0.0;
    double sum_EZ2W = 0.0;
    for(y = 0; y < NY; y++)
      { for(x = 0; x < NX; x++)
          { /* Get relevant samples from original images: */
            double vA = float_image_get_sample(AZ, 0, x, y);
            double vB = float_image_get_sample(BZ, 0, x, y);
            double vW = (W != NULL ? float_image_get_sample(W, 0, x, y) : 1.0);
            double vE = vA - vB;
            if (zero_mean) { vE = vE - avgE; }
            /* Store difference in error image: */
            float_image_set_sample(EZ, 0, x, y, (float)vE);
            /* Accumulate squares of the deviations from the means: */
            vA = vA - avgA;
            vB = vB - avgB;
            sum_AdZ2W += vW*vA*vA;
            sum_BdZ2W += vW*vB*vB;
            sum_EZ2W += vW*vE*vE;
          }
      }

    /* Compute the RMS values, guarding against a zero total weight
      in the same way as the mean computation above: */
    double sAZ = (no_weight ? 0.0 : sqrt(sum_AdZ2W/sum_W));
    double sBZ = (no_weight ? 0.0 : sqrt(sum_BdZ2W/sum_W));
    double sEZ = (no_weight ? 0.0 : sqrt(sum_EZ2W/sum_W));
    /* Quadratic mean of the two deviations, used as the reference magnitude: */
    double sMZ = hypot(sAZ,sBZ)/M_SQRT2;
    /* A zero error is a zero relative error, even when {sMZ} is zero too: */
    double sre = (sEZ == 0 ? 0.0 : sEZ/sMZ);
    
    /* Return results: */
    if (sAZP != NULL) { (*sAZP) = sAZ; }
    if (sBZP != NULL) { (*sBZP) = sBZ; }
    if (sEZP != NULL) { (*sEZP) = sEZ; }
    if (sreP != NULL) { (*sreP) = sre; }

    return EZ;
  }

/* Writes the computed height map {CZ} and, when a reference map {RZ}
  is given, the comparison of {CZ} against {RZ}.

  {CZ} must have a single channel.  {RZ} and {U}, when not {NULL}, must
  be single-channel images with the same size as {CZ}; {U} is passed to
  {pst_height_map_compare} as the weight image.

  File names are built from {filePrefix}, {level} and {iter}; the level
  is replaced by {-1} in the file-name calls when {levelTag} is false,
  and likewise the iteration when {iterTag} is false.

  If {writeImages} is true, {CZ} is written with tag "Z" and, when {RZ}
  is given, the zero-mean error image is written with tag "eZ".

  If {writeError} is true and {RZ} is given, a one-line text file with
  tag "eZ" and extension "txt" is written, containing {level}, the
  image size, {iter}, {change}, and the four statistics returned by
  {pst_height_map_compare}.  The procedure fails through {demand} if
  that file cannot be opened.

  Progress messages are printed to {stderr}, indented by {indent}. */
void pst_height_map_level_analyze_and_write
  ( char *filePrefix,
    int level,
    bool_t levelTag,
    int iter,
    bool_t iterTag,
    double change,
    float_image_t *CZ,
    float_image_t *RZ,
    float_image_t *U, 
    bool_t writeImages,
    bool_t writeError,
    int indent
  )
  {
    demand(CZ->sz[0] == 1, "bad CZ channels");
    int NX = CZ->sz[1]; 
    int NY = CZ->sz[2]; 
    
    if (RZ != NULL)
      { demand(RZ->sz[0] == 1, "bad RZ channels");
        demand(NX == RZ->sz[1], "bad RZ cols"); 
        demand(NY == RZ->sz[2], "bad RZ rows"); 
      }
    
    if (U != NULL)
      { demand(U->sz[0] == 1, "bad U channels");
        demand(NX == U->sz[1], "bad U cols"); 
        demand(NY == U->sz[2], "bad U rows"); 
      }
      
    /* A negative level or iteration is what gets passed when its tag is off: */
    int levelQ = (levelTag ? level : -1);
    int iterQ = (iterTag ? iter : -1);
    
    if (writeImages)
      { float_image_mscale_write_file(CZ, filePrefix, levelQ, iterQ, "Z", indent); }
      
    if (RZ != NULL)
      {
        double sAZ, sBZ, sEZ, sre;
        float_image_t *EZ = pst_height_map_compare(CZ, RZ, U, TRUE, &sAZ, &sBZ, &sEZ, &sre);
        if (writeImages)
          { float_image_mscale_write_file(EZ, filePrefix, levelQ, iterQ, "eZ", indent); }
        if (writeError)
          { char *fileName = float_image_mscale_file_name(filePrefix, levelQ, iterQ, "eZ", "txt");
            fprintf(stderr, "%*swriting %s ...", indent, "", fileName);
            FILE *wr = fopen(fileName, "wt");
            /* Use {demand} rather than {assert} so that the check survives {NDEBUG}: */
            demand(wr != NULL, "could not open the error summary file for writing");
            fprintf
              ( wr, "%2d %6d %6d %9d %14.11f  %14.11f %14.11f  %14.11f %14.11f\n",
                level, NX, NY, iter, change, sAZ, sBZ, sEZ, sre
              );
            /* {wr} always comes from {fopen} here, so it is simply closed: */
            fclose(wr);
            fprintf(stderr, "\n");
            free(fileName);
          }
        /* The error image is owned by this procedure; release it: */
        float_image_free(EZ);
      }
  }
