#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <assert.h>
#include <jsfile.h>
#include <rn.h>
#include <float_image.h>


/* Zeroes, in place, every sample of channel 0 of `im` whose weight
 * (channel 0 of `wm`, same X/Y extent assumed) is below 0.01.
 * Channels other than 0 are left untouched. */
void discard_undesired_heights(float_image_t* im,float_image_t* wm){
	const int width = im->sz[1];
	const int height = im->sz[2];
	const double min_weight = 0.01;

	for (int px = 0; px < width; px++){
		for (int py = 0; py < height; py++){
			if (float_image_get_sample(wm,0,px,py) < min_weight){
				float_image_set_sample(im,0,px,py,0);
			}
		}
	}
}

float_image_t* RMS_normalize_heights(float_image_t* im,float_image_t* wm);

/* Returns a new copy of `im` in which channel 0 has its weighted mean removed:
 *
 *   Nim(0,x,y) = im(0,x,y) - sum(w * im(0,.,.)) / sum(w)
 *
 * with the weights w read from channel 0 of `wm` (same X/Y extent assumed).
 * All other channels of `im` are copied unchanged.
 * NOTE(review): multi-channel callers only get channel 0 centred — confirm
 * this is intended.
 *
 * If the weights sum to zero the weighted mean is undefined. It is then taken
 * as 0, so the copy is returned unshifted instead of filled with NaN.
 * The caller owns the returned image. */
float_image_t* RMS_normalize_heights(float_image_t* im,float_image_t* wm){
	int NX = im->sz[1];
	int NY = im->sz[2];
	int x,y;

	/* Single pass: weighted sum of heights and total weight, accumulated in
	 * long double for extra precision. */
	long double sumVW = 0;
	long double sumW = 0;
	for(x = 0; x < NX; x++){
		for(y = 0; y < NY; y++){
			double val = float_image_get_sample(im,0,x,y);
			double w = float_image_get_sample(wm,0,x,y);
			sumVW += (long double)val * w;
			sumW += w;
		}
	}

	/* Guard the 0/0 case instead of propagating NaN into every sample. */
	double mean = (sumW != 0) ? (double)(sumVW / sumW) : 0.0;

	float_image_t* Nim = float_image_copy(im);
	for(x = 0; x < NX; x++){
		for(y = 0; y < NY; y++){
			double val = float_image_get_sample(im,0,x,y);
			float_image_set_sample(Nim,0,x,y,val - mean);
		}
	}
	return Nim;
}

/* Weighted root-mean-square of channel 0 of `im1`:
 *
 *   sqrt( sum(w * z^2) / sum(w) )
 *
 * where z is channel 0 of `im1` and w is channel 0 of `wimg` (same X/Y
 * extent assumed). Other channels of `im1` are ignored.
 * No guard is applied: if the weights sum to zero, the result comes from a
 * division by zero (NaN or inf). */
double compute_RMS_simple(float_image_t* im1, float_image_t* wimg){
  const int width = im1->sz[1];
  const int height = im1->sz[2];

  double weighted_sq = 0;
  double total_w = 0;

  int ix = 0;
  while (ix < width){
    for (int iy = 0; iy < height; iy++){
      const double z = float_image_get_sample(im1,0,ix,iy);
      const double wt = float_image_get_sample(wimg,0,ix,iy);
      weighted_sq += (z*z)*wt;
      total_w += wt;
    }
    ix++;
  }

  return sqrt(weighted_sq/total_w);
}

/* Compares two single-channel height maps after removing each map's
 * weighted mean height (see RMS_normalize_heights).
 *
 *   im1, im2  : height maps, 1 channel, identical NX x NY.
 *   wimg      : weights, 1 channel, same size; channel 0 is the weight.
 *   RMS_diffP : out, weighted RMS of (Nim1 - Nim2) over pixels with w > 0.
 *   RMS_relP  : out, hypot(RMS1, RMS2) / sqrt(2).
 *   RMS1/RMS2 : out, weighted RMS of each mean-removed map.
 *
 * Returns a newly allocated 1-channel image holding Nim1 - Nim2 where the
 * weight is positive and 0 elsewhere. The caller owns it. */
float_image_t* compute_RMS_dual (float_image_t* im1, float_image_t* im2, float_image_t* wimg, double* RMS_diffP,double* RMS_relP, double* RMS1, double* RMS2 ){
  int NC = im1->sz[0];
  int NX = im1->sz[1];
  int NY = im1->sz[2];

  assert(NC == 1);
  assert( NC == im2->sz[0]);
  assert( NX == im2->sz[1]);
  assert( NY == im2->sz[2]);

  assert( NC == wimg->sz[0]);
  assert( NX == wimg->sz[1]);
  assert( NY == wimg->sz[2]);

  float_image_t* Nim1 = RMS_normalize_heights(im1,wimg);
  float_image_t* Nim2 = RMS_normalize_heights(im2,wimg);
  float_image_t* RMS_img = float_image_new(NC,NX,NY);

  for (int x = 0; x < NX; x++){
    for (int y = 0; y < NY; y++){
      double w = float_image_get_sample(wimg,0,x,y);
      double d = 0.0;
      if (w > 0){
        d = float_image_get_sample(Nim1,0,x,y) - float_image_get_sample(Nim2,0,x,y);
      }
      float_image_set_sample(RMS_img,0,x,y,d);
    }
  }

  double e = compute_RMS_simple(RMS_img,wimg);
  double rms1 = compute_RMS_simple(Nim1,wimg);
  double rms2 = compute_RMS_simple(Nim2,wimg);

  /* The normalized copies are temporaries owned by this function. */
  float_image_free(Nim1);
  float_image_free(Nim2);

  *RMS_diffP = e;
  *RMS_relP = hypot(rms1,rms2)/sqrt(2.0);
  *RMS1 = rms1;
  *RMS2 = rms2;

  return RMS_img;
}

/* Multi-channel variant of compute_RMS_dual. Both images are mean-normalized
 * with RMS_normalize_heights (which only centres channel 0). Then, per pixel,
 * the Euclidean distance (rn_dist) between the two NC-channel sample vectors
 * is stored.
 *
 *   im1, im2  : images with identical NC x NX x NY.
 *   wimg      : weights; only channel 0 is read, so its channel count may
 *               differ from NC.
 *   RMS_diffP : out, weighted RMS of the per-pixel distance image.
 *   RMS_relP  : out, hypot(RMS1, RMS2) / sqrt(2).
 *   RMS1/RMS2 : out, from compute_RMS_simple, which reads channel 0 only.
 *               NOTE(review): RMS_diffP uses every channel while RMS1/RMS2
 *               do not — confirm the printed ratio is meaningful.
 *
 * Returns a newly allocated 1-channel image (distance where w > 0, else 0).
 * The caller owns it. */
float_image_t* compute_RMS_image(float_image_t* im1, float_image_t* im2, float_image_t* wimg, double* RMS_diffP,double* RMS_relP, double* RMS1, double* RMS2 ){
  int NC = im1->sz[0];
  int NX = im1->sz[1];
  int NY = im1->sz[2];

  assert( NC > 0 ); /* a zero-length VLA below would be undefined behaviour */
  assert( NC == im2->sz[0]);
  assert( NX == im2->sz[1]);
  assert( NY == im2->sz[2]);

  assert( NX == wimg->sz[1]);
  assert( NY == wimg->sz[2]);

  float_image_t* Nim1 = RMS_normalize_heights(im1,wimg);
  float_image_t* Nim2 = RMS_normalize_heights(im2,wimg);
  float_image_t* RMS_img = float_image_new(1,NX,NY);

  /* Per-pixel channel vectors, declared once rather than once per pixel. */
  double z1[NC];
  double z2[NC];

  for (int x = 0; x < NX; x++){
    for (int y = 0; y < NY; y++){
      double w = float_image_get_sample(wimg,0,x,y);
      double diff = 0.0;
      if (w > 0){
        for (int c = 0; c < NC; c++){
          z1[c] = float_image_get_sample(Nim1,c,x,y);
          z2[c] = float_image_get_sample(Nim2,c,x,y);
        }
        diff = rn_dist(NC,z1,z2);
      }
      float_image_set_sample(RMS_img,0,x,y,diff);
    }
  }

  double e = compute_RMS_simple(RMS_img,wimg);
  double rms1 = compute_RMS_simple(Nim1,wimg);
  double rms2 = compute_RMS_simple(Nim2,wimg);

  /* The normalized copies are temporaries owned by this function. */
  float_image_free(Nim1);
  float_image_free(Nim2);

  *RMS_diffP = e;
  *RMS_relP = hypot(rms1,rms2)/sqrt(2.0);
  *RMS1 = rms1;
  *RMS2 = rms2;

  return RMS_img;
}
// float_image_t* compute_RMS(float_image_t* im1, float_image_t* im2, float_image_t* wimg, double* RMS_diffP);
//
// float_image_t* compute_RMS(float_image_t* im1, float_image_t* im2, float_image_t* wimg, double* RMS_diffP, double* RMS_relP){
// 	int NX,NY,NC;
// 	NC = im1->sz[0];
// 	NX = im1->sz[1];
// 	NY = im1->sz[2];
//
// 	assert(NC == 1);
//
// 	assert( NC == im2->sz[0]);
// 	assert( NX == im2->sz[1]);
// 	assert( NY == im2->sz[2]);
//
// 	assert( NC == wimg->sz[0]);
// 	assert( NX == wimg->sz[1]);
// 	assert( NY == wimg->sz[2]);
//
// 	double sum_diff = 0;
// 	double sum_w = 0;
//
// 	float_image_t* Nim1 = RMS_normalize_heights(im1,wimg);
// 	float_image_t* Nim2 = RMS_normalize_heights(im2,wimg);
// 	discard_undesired_heights(Nim1,wimg);
// 	discard_undesired_heights(Nim2,wimg);
// 	float_image_t* RMS_img = float_image_new(NC,NX,NY);
//
// 	int x,y;
// 	for( x =0; x < NX; x++){
// 		for(y = 0; y < NY; y++){
// 			double z1,z2,w;
// 			z1 = float_image_get_sample(Nim1,0,x,y);
// 			z2 = float_image_get_sample(Nim2,0,x,y);
// 			w = float_image_get_sample(wimg,0,x,y);
// 			double d = z1 - z2;
//
// 			sum_diff+= w*d*d;
// 			sum_w += w;
//
// 			float_image_set_sample(RMS_img,0,x,y,z1-z2);
// 		}
// 	}
//
// 	float_image_free(Nim1);
// 	float_image_free(Nim2);
//
// 	double RMS_diff = sqrt(sum_diff/sum_w);
//
// 	*RMS_diffP = RMS_diff;
// 	return RMS_img;
// }
//
//


/* Returns a new image one pixel smaller in X and Y. Output sample (x,y) of
 * every channel is the mean of the 2x2 block of `im` spanning
 * [x, x+1] x [y, y+1]. The caller owns the result. */
float_image_t* reduceMap(float_image_t* im){
  const int chans = im->sz[0];
  const int outX = im->sz[1] - 1;
  const int outY = im->sz[2] - 1;

  float_image_t* out = float_image_new(chans,outX,outY);
  for (int ch = 0; ch < chans; ch++){
    for (int ix = 0; ix < outX; ix++){
      for (int iy = 0; iy < outY; iy++){
        double s = float_image_get_sample(im,ch,ix,iy);   /* (x,   y)   */
        s += float_image_get_sample(im,ch,ix,iy+1);       /* (x,   y+1) */
        s += float_image_get_sample(im,ch,ix+1,iy);       /* (x+1, y)   */
        s += float_image_get_sample(im,ch,ix+1,iy+1);     /* (x+1, y+1) */
        float_image_set_sample(out,ch,ix,iy,s/4.0);
      }
    }
  }
  return out;
}

/* Returns a new image one pixel larger in X and Y. For every channel, output
 * sample (x,y) combines the up-to-four input samples at (x-1..x, y-1..y) that
 * lie inside `im`:
 *   box == FALSE : their average (neighbours outside `im` are not counted);
 *   box != FALSE : the product of all four, with neighbours outside `im`
 *                  taken as 0 (so every border sample becomes 0).
 * The caller owns the result. */
float_image_t* expandMap(float_image_t* im,bool_t box){
  const int chans = im->sz[0];
  const int inX = im->sz[1];
  const int inY = im->sz[2];

  float_image_t* out = float_image_new(chans,inX + 1,inY + 1);
  for (int ch = 0; ch < chans; ch++){
    for (int ix = 0; ix <= inX; ix++){
      const int hasPrevX = (ix > 0);
      const int hasNextX = (ix < inX);
      for (int iy = 0; iy <= inY; iy++){
        const int hasPrevY = (iy > 0);
        const int hasNextY = (iy < inY);

        double a = 0, b = 0, c = 0, d = 0;
        double count = 0;
        if (hasPrevX && hasPrevY){ a = float_image_get_sample(im,ch,ix-1,iy-1); count += 1; }
        if (hasNextX && hasPrevY){ b = float_image_get_sample(im,ch,ix,iy-1);   count += 1; }
        if (hasPrevX && hasNextY){ c = float_image_get_sample(im,ch,ix-1,iy);   count += 1; }
        if (hasNextX && hasNextY){ d = float_image_get_sample(im,ch,ix,iy);     count += 1; }

        const double v = box ? (a*b*c*d) : ((a + b + c + d)/count);
        float_image_set_sample(out,ch,ix,iy,v);
      }
    }
  }
  return out;
}

/* Makes *im1 and *im2 the same X/Y size.
 *   - Same size: nothing happens.
 *   - One image is exactly one pixel larger in both X and Y: that pointer is
 *     replaced by reduceMap() of the image (2x2 averaging). The replaced
 *     image is NOT freed.
 *   - Any other mismatch: a message is printed and the program exits with
 *     a failure status. */
void correct_images(float_image_t** im1 ,float_image_t** im2){
  float_image_t* orig_im1 = *im1;
  float_image_t* orig_im2 = *im2;

  int NX1 = orig_im1->sz[1];
  int NY1 = orig_im1->sz[2];

  int NX2 = orig_im2->sz[1];
  int NY2 = orig_im2->sz[2];

  if((NX1 == NX2) && (NY1 == NY2)){ return; }

  if((NX1 == (NX2+1)) && (NY1 == (NY2+1))){
    fprintf(stderr,"Image 1 is one pixel bigger than 2\n");
    *im1 = reduceMap(orig_im1);
    return;
  }

  if((NX2 == (NX1+1)) && (NY2 == (NY1+1))){
    fprintf(stderr,"Image 2 is one pixel bigger than 1\n");
    *im2 = reduceMap(orig_im2);
    return;
  }

  fprintf(stderr,"Incompatible heightmaps !\n");
  /* assert(FALSE) is compiled out under -DNDEBUG: the function would then
   * return and processing would continue with mismatched images. Terminate
   * unconditionally instead. */
  exit(EXIT_FAILURE);
}

float_image_t* readFNI(char* filename);

/* Opens `filename` with open_read, parses one float image from it with
 * float_image_read, and closes the stream. The caller owns the result. */
float_image_t* readFNI(char* filename){
	FILE* in = open_read(filename,TRUE);
	float_image_t* result = float_image_read(in);
	fclose(in);
	return result;
}

void writeFNI(char* filename, float_image_t* img);

/* Writes `img` to `filename` (opened with open_write) via float_image_write.
 * Buffered data is flushed by fclose(), so a failing fclose (e.g. disk full)
 * means the file is incomplete. That failure is reported and the program
 * exits with a failure status instead of silently leaving a truncated
 * output. */
void writeFNI(char* filename, float_image_t* img){
	FILE* arq = open_write(filename,TRUE);
	float_image_write(arq,img);
	if(fclose(arq) != 0){
		perror(filename);
		exit(EXIT_FAILURE);
	}
}

int main(int argc, char** argv){
	if(argc < 4){
		fprintf(stderr,"program usage\ncompute_rms <height_map1> <height_map2> <output_file> [weight_image]\n");
		return 0;
	}
	char* hm1_filename = argv[1];
	char* hm2_filename = argv[2];
	char* output_filename = argv[3];
	char* wm_filename = NULL;
	if(argc == 5){
		wm_filename = argv[4];
	}

	float_image_t* im1 = readFNI(hm1_filename);
	float_image_t* im2 = readFNI(hm2_filename);

	assert(im1->sz[0] == im2->sz[0]);

	
	double RMS_abs,RMS_rel,RMS1,RMS2;
//	discard_undesired_heights(im1,wimg);
//	discard_undesired_heights(im2,wimg);
	correct_images(&im1,&im2);

	float_image_t* wimg;
	if(wm_filename == NULL){
		wimg = float_image_copy(im1);
		float_image_fill_channel(wimg,0,1.0);
	}else{
		wimg = readFNI(wm_filename);
	}

	if(wimg->sz[1] < im1->sz[1]){
	  wimg = expandMap(wimg,1);
	}

	
	int NC = im1->sz[0];
	
	float_image_t* img_diff;
	if(NC == 1){
	  img_diff = compute_RMS_dual(im1,im2,wimg,&RMS_abs,&RMS_rel,&RMS1,&RMS2);
	  fprintf(stdout,"%lf %lf %lf %lf\n",RMS_abs,RMS_abs/RMS_rel,RMS1,RMS2);
	}else{
	  img_diff = compute_RMS_image(im1,im2,wimg,&RMS_abs,&RMS_rel,&RMS1,&RMS2);
	  fprintf(stdout,"%lf %lf %lf %lf\n",RMS_abs,RMS_abs/RMS_rel,RMS1,RMS2);
	}
	/*
	int NX = im1->sz[1];
	int NY = im1->sz[2];
	int c;
	
	float_image_t* img_diff_final = float_image_new(NC,NX,NY);
	float_image_t* img_cmp1 = float_image_new(1,NX,NY);
	float_image_t* img_cmp2 = float_image_new(1,NX,NY);
	
	
	for( c = 0 ; c < NC; c++){
	  float_image_set_channel(img_cmp1,0,im1,c);
	  float_image_set_channel(img_cmp2,0,im2,c);
// 	  float_image_t* img_diff = compute_RMS_dual(im1,im2,wimg,&RMS_abs,&RMS_rel,&RMS1,&RMS2);
	  float_image_t* img_diff = compute_RMS_dual(img_cmp1,img_cmp2,wimg,&RMS_abs,&RMS_rel,&RMS1,&RMS2);
	  float_image_set_channel(img_diff_final,c,img_diff,0);
	  fprintf(stdout,"%lf %lf %lf %lf\n",RMS_abs,RMS_abs/RMS_rel,RMS1,RMS2);
	  float_image_free(img_diff);
	}*/
	
// 	writeFNI(output_filename,img_diff_final);
	writeFNI(output_filename,img_diff);


	return 0;
}