#include <ls_system.h>
#include <stdlib.h>
#include <assert.h>
#include <gauss_elim.h>
#include <rmxn.h>

/*
 * Assemble the normal-equation matrix of a weighted least-squares fit.
 *
 * A is written as a basis_size x basis_size row-major array with
 *
 *     A[r][s] = sum_{i=0}^{n-1} phi(r, X[i]) * phi(s, X[i]) * w[i] * wpos[i]
 *
 * The matrix is symmetric, so only the upper triangle (s >= r) is evaluated
 * and each value is stored in both A[r][s] and A[s][r]. This halves the
 * number of phi evaluations and gives bit-identical results, because
 * phi_r*phi_s == phi_s*phi_r in IEEE arithmetic. It assumes phi has no side
 * effects beyond returning a value.
 *
 * A      - output, at least basis_size*basis_size doubles.
 * phi    - basis function, evaluated as phi(k, x, l_data).
 * X      - n sample points; X[i] is passed to phi unchanged.
 * w      - per-sample dynamic weights (n entries).
 * wpos   - per-sample a-priori weights (n entries).
 * l_data - opaque model data forwarded to phi.
 */
void computeLeastSquaresMatrix(double* A,phi_function* phi, int basis_size, double** X,int n, double w[], double wpos[],void* l_data){
	int r, s, i;

	for (r = 0; r < basis_size; r++) {
		for (s = r; s < basis_size; s++) {
			double sum = 0.0;
			for (i = 0; i < n; i++) {
				double *x = X[i];
				double wi = w[i] * wpos[i];
				double phiR = phi(r, x, l_data);
				double phiS = phi(s, x, l_data);
				sum += (phiR * phiS) * wi;
			}
			/* Mirror into the lower triangle (a no-op overwrite when r == s). */
			A[(size_t)r * basis_size + s] = sum;
			A[(size_t)s * basis_size + r] = sum;
		}
	}
}

/*
 * Build the right-hand side of the weighted least-squares normal equations:
 *
 *     b[k] = sum_{p=0}^{n-1} phi(k, X[p]) * F[p] * w[p] * wpos[p],
 *     k = 0 .. basis_size-1
 *
 * b      - output, at least basis_size doubles.
 * phi    - basis function, evaluated as phi(k, x, l_data).
 * X, F   - n sample points and the sampled function values at them.
 * w      - per-sample dynamic weights; wpos - per-sample a-priori weights.
 * l_data - opaque model data forwarded to phi.
 */
void computeLeastSquaresRHSVector(double* b,phi_function* phi, int basis_size,double** X,double* F,int n , double w[], double wpos[],void* l_data){
	int k;

	for (k = 0; k < basis_size; ++k) {
		double acc = 0.0;
		int p = 0;

		while (p < n) {
			double weight = w[p] * wpos[p];
			double sample = F[p];
			double basis_val = phi(k, X[p], l_data);

			acc += (basis_val * sample) * weight;
			++p;
		}
		b[k] = acc;
	}
}

/*
 * Solve the weighted least-squares system for the coefficients c, repeatedly
 * pruning basis elements that validateResults rejects.
 *
 * A      - basis_size x basis_size row-major normal matrix (only read here).
 * b      - scratch buffer of at least basis_size doubles; on return it holds
 *          the reduced right-hand side handed to the last gsel_solve call.
 * c      - output, basis_size coefficients; entries of rejected elements are 0.
 * phi    - basis function phi(k, x, l_data).
 * validateResults - called as validateResults(c, basis_size, valid_basis)
 *          after each solve. Its return value becomes the new number of valid
 *          elements; it is expected to equal the number of TRUE entries left in
 *          valid_basis (the assert below relies on that).
 * X, F, n, w, wpos, l_data - samples, values, weights and model data,
 *          forwarded to computeLeastSquaresRHSVector.
 *
 * The loop stops when a pass leaves the number of valid elements unchanged,
 * or when it drops to zero.
 */
void computeLSTerms(double* A,
	double* b ,
	double* c,
	phi_function* phi,
	int basis_size,
	validate_function* validateResults,
	double** X,
	double* F,
	int n,
	double* w,
	double* wpos,
	void* l_data
	){
	size_t bs = (basis_size > 0) ? (size_t)basis_size : 0;
	double *Ared = (double *)malloc(bs * bs * sizeof(double));       /* matrix for the reduced basis */
	double *bfull = (double *)malloc(bs * sizeof(double));           /* RHS for the full basis */
	bool_t *valid_basis = (bool_t *)malloc(bs * sizeof(bool_t));     /* marks elements of the reduced basis */
	int *oldix = (int *)malloc(bs * sizeof(int));                    /* reduced index -> full index */
	int i, j, k;
	int valids, last_valids;

	if (bs > 0 && (Ared == NULL || bfull == NULL || valid_basis == NULL || oldix == NULL)) {
		abort();
	}

	/* Initially every basis element is valid and oldix is the identity map. */
	for (j = 0; j < basis_size; j++) {
		valid_basis[j] = TRUE;
		oldix[j] = j;
	}
	valids = basis_size;

	/* The RHS entry of basis element k depends only on k, not on which other
	   elements are still valid. The full RHS is therefore evaluated once, and
	   each pass gathers the entries of the surviving elements from it. */
	computeLeastSquaresRHSVector(bfull, phi, basis_size, X, F, n, w, wpos, l_data);

	do {
		/* (Re)build the reduced-index -> full-index table. */
		int nred = 0;
		for (j = 0; j < basis_size; j++) {
			if (valid_basis[j]) {
				oldix[nred++] = j;
			}
		}
		last_valids = valids;

		/* Copy the rows and columns of A that belong to valid elements. */
		k = 0;
		for (i = 0; i < basis_size; i++) {
			if (!valid_basis[i]) {
				continue;
			}
			for (j = 0; j < basis_size; j++) {
				if (valid_basis[j]) {
					Ared[k++] = A[(size_t)i * basis_size + j];
				}
			}
		}
		assert(k == valids * valids);

		/* Reduced RHS: the full-basis entries of the surviving elements. */
		for (k = 0; k < valids; k++) {
			b[k] = bfull[oldix[k]];
		}

		/* NOTE(review): presumably solves Ared * c = b (valids x valids, one
		   RHS column); Ared and b are rebuilt each pass in case it overwrites
		   them - confirm against gsel_elim.h. */
		gsel_solve(valids, valids, Ared, 1, b, c);

		/* Scatter the reduced coefficients back to their full-basis slots.
		   Walking backwards means c[k] (k <= j) is read before being overwritten. */
		k = valids;
		for (j = basis_size - 1; j >= 0; j--) {
			if (valid_basis[j]) {
				k--;
				c[j] = c[k];
			} else {
				c[j] = 0.0;
			}
		}
		valids = validateResults(c, basis_size, valid_basis);
	} while ((last_valids != valids) && (valids != 0));

	/* validateResults may have rejected elements in the final pass. */
	for (i = 0; i < basis_size; i++) {
		if (!valid_basis[i]) {
			c[i] = 0.0;
		}
	}

	free(Ared);
	free(bfull);
	free(valid_basis);
	free(oldix);
}

	
/*
 * Fit the model described by lm to the samples (X[i], F[i]), i = 0..n-1, by
 * weighted least squares with iterative reweighting.
 *
 * For each of update_steps outer steps:
 *   - The normal matrix is rebuilt and the system solved (computeLSTerms).
 *     The new coefficients are pushed into l_data via lm->retrieve_components.
 *     This repeats until lm->compare reports a change <= epsilon between
 *     consecutive model states, or MAX_ITER iterations have run.
 *   - lm->update_weights then adjusts the per-sample weights w. They are
 *     rescaled so they sum to (approximately) n.
 *
 * Weights:
 *   - If lm->weights is NULL, a new array is allocated and initialized to 1.0.
 *   - If lm->wpos is NULL, a new array is filled from lm->compute_pos_weights.
 *   - On return, lm->weights, lm->wpos and lm->num_weights describe the arrays
 *     that were used.
 *   - Nothing here frees them; presumably lm owns them afterwards.
 *
 * l_data - model state, updated in place; also passed to phi and the lm hooks.
 */
void fitModelToFunction(double** X, double* F, int n,
		    ls_model_t* lm,
		     void* l_data,
		      int update_steps
		     ){
	const int MAX_ITER = 100;
	/* NOTE(review): 10e-5 == 1e-4; confirm 1e-5 was not intended. */
	const double epsilon = 10e-5;
	double sigma = 0.2;   /* passed to (and possibly adjusted by) lm->update_weights */
	int basis_size = lm->get_num_components(l_data);
	size_t bs = (basis_size > 0) ? (size_t)basis_size : 0;
	size_t ns = (n > 0) ? (size_t)n : 0;
	int i, step;

	double *A = (double *)malloc(sizeof(double) * bs * bs);
	double *b = (double *)malloc(sizeof(double) * bs);
	double *c = (double *)malloc(sizeof(double) * bs);
	/* wpos: a-priori weights; w: dynamically adjusted weights
	   (probability of goodness of each sample). */
	double *wpos = (lm->wpos != NULL) ? lm->wpos : (double *)malloc(sizeof(double) * ns);
	double *w = (lm->weights != NULL) ? lm->weights : (double *)malloc(sizeof(double) * ns);

	if ((bs > 0 && (A == NULL || b == NULL || c == NULL)) ||
	    (ns > 0 && (w == NULL || wpos == NULL))) {
		abort();
	}

	/* Initialize only the weight arrays that were allocated above. */
	for (i = 0; i < n; i++) {
		if (lm->weights == NULL) {
			w[i] = 1.0;
		}
		if (lm->wpos == NULL) {
			wpos[i] = lm->compute_pos_weights(X[i], l_data);
		}
	}

	for (step = 0; step < update_steps; step++) {
		int num_iters = 0;
		double diff = 1.0;   /* > epsilon, so at least one solve runs */
		double sumw;

		while ((num_iters < MAX_ITER) && (diff > epsilon)) {
			void *l_data_old;

			/* phi receives l_data, which retrieve_components updates, so the
			   matrix is rebuilt on every iteration. */
			computeLeastSquaresMatrix(A, lm->phi, basis_size, X, n, w, wpos, l_data);
			l_data_old = lm->copy_data(l_data);
			computeLSTerms(A,
			               b,
			               c,
			               lm->phi,
			               basis_size,
			               lm->validate_results,
			               X, F, n,
			               w,
			               wpos,
			               l_data);

			/* Update the model and measure how much it moved. */
			lm->retrieve_components(c, basis_size, l_data);
			diff = lm->compare(l_data, l_data_old);
			if (lm->type == 3) {
				lm->write_param(stderr, l_data);
			}
			lm->release_data(l_data_old);
			num_iters++;
		}

		lm->update_weights(l_data, X, F, w, n, &sigma);

		/* Rescale w so it sums to ~n; the tiny seed avoids division by zero. */
		sumw = 1.0e-200;
		for (i = 0; i < n; i++) {
			sumw += w[i];
		}
		for (i = 0; i < n; i++) {
			w[i] = (w[i] / sumw) * n;
		}
	}

	lm->num_weights = n;
	lm->weights = w;
	lm->wpos = wpos;

	free(A);
	free(b);
	free(c);
}