#include "Bernstein.h"
#include<vector>



// Default constructor: performs no work of its own; members receive their
// default initialization.
Bernstein::Bernstein() {}

// Destructor: releases nothing explicitly.
// NOTE(review): hIdx_v / coef_v hold pointers produced by the precompute_*
// methods and are not deleted here; confirm who owns them before adding cleanup
// (the class may be copied implicitly, so deleting here could double-free).
Bernstein::~Bernstein() {}


//Number* Bernstein::C_03(MultiIndex* delta,MultiIndex* epsilon, MultiIndex* alpha, HyperIndex* Lambda, HyperIndex* Omega, DomainMapping* DM)

/*
 * Xi' for several Omegas.
 *
 * Returns one C_03 coefficient per entry of Omega_v, in the same order. Each
 * Omega has load_MultiIndex_r() called on it before its coefficient is
 * computed, and every coefficient value is echoed to stderr. The returned
 * Number objects are owned by the caller.
 */
std::vector<Number*> Bernstein::get_C_03s_for_each_Omega(MultiIndex* delta, MultiIndex* epsilon, MultiIndex* alpha, HyperIndex* Lambda, std::vector<HyperIndex*> Omega_v, DomainMapping* DM)
{
	std::vector<Number*> C_03_v;
	C_03_v.reserve(Omega_v.size());

	std::vector<HyperIndex*>::iterator Omega_ptr;
	for(Omega_ptr = Omega_v.begin(); Omega_ptr != Omega_v.end(); ++Omega_ptr)
	{
		HyperIndex* Omega = *Omega_ptr;
		Omega->load_MultiIndex_r();

		// Xi'
		Number* xi_prime = Bernstein::C_03(delta, epsilon, alpha, Lambda, Omega, DM);
		fprintf(stderr,"c_03-- %g -- \n",xi_prime->val);
		C_03_v.push_back(xi_prime);
	}
	return C_03_v;
}


/*
 * Returns one C_02 coefficient per entry of Upsilon_v, in the same order.
 *
 * For each Upsilon, load_MultiIndex_r() is called and a private copy of
 * Upsilon->r is passed to C_02 as nu. C_02 only borrows nu, so the copy is
 * released right after the call (it used to leak once per Upsilon). Every
 * coefficient value is echoed to stderr. The returned Number objects are owned
 * by the caller.
 */
std::vector<Number*> Bernstein::get_C_02s_for_each_Upsilon(MultiIndex* delta, MultiIndex* epsilon, MultiIndex* alpha, HyperIndex* Lambda, std::vector<HyperIndex*> Upsilon_v, DomainMapping* DM)
{
	std::vector<Number*> C_02_v;
	C_02_v.reserve(Upsilon_v.size());

	std::vector<HyperIndex*>::iterator Upsilon_ptr;
	for(Upsilon_ptr = Upsilon_v.begin(); Upsilon_ptr != Upsilon_v.end(); ++Upsilon_ptr)
	{
		HyperIndex* Upsilon = *Upsilon_ptr;
		Upsilon->load_MultiIndex_r();
		MultiIndex* nu = Upsilon->r->clone();

		Number* C_02 = Bernstein::C_02(delta, epsilon, alpha, nu, Lambda, Upsilon, DM);
		delete nu;	// C_02 does not keep nu; release the clone made above.

		fprintf(stderr,"-- %g -- \n",C_02->val);
		C_02_v.push_back(C_02);
	}
	return C_02_v;
}


std::vector<Number*> Bernstein::get_C_01s_for_each_Omega(int d, MultiIndex* kappa, MultiIndex* delta, std::vector<HyperIndex*> omega_v, DomainMapping* DM)
{
	std::vector<Number*> C_01_v;
	std::vector<HyperIndex*>::iterator omega_ptr;
	HyperIndex* omega;

	for(omega_ptr = omega_v.begin();omega_ptr!=omega_v.end();omega_ptr++)
	{
		omega = (HyperIndex*)(*omega_ptr);
		omega->load_MultiIndex_r();
		C_01_v.push_back(Bernstein::C_01(d,kappa,delta,omega,DM));
	}
	return C_01_v;
}

/*
 * In terms of the formulation presented in the technical report:
 *  C_01 = g!/mi! * C_00   (= g!/mi! * phi())
 * where g = |kappa| (kappa->sum()), mi = omega->r, and phi() is defined by an
 * equation of the report (equation number not recorded here).
 *
 * C_00 is evaluated with row index j = 0, i.e. against DM->IM[0].
 * omega->r must already be loaded (callers invoke load_MultiIndex_r()).
 * The returned Number is owned by the caller.
 */
Number* Bernstein::C_01(int d, MultiIndex* kappa, MultiIndex* delta, HyperIndex* omega, DomainMapping* DM)
{
	// Multinomial prefactor g!/mi!.
	Integer* numerator = Integer::factorial(kappa->sum());
	Integer* denominator = omega->r->factorial();
	Number* prefactor = numerator->div_by(denominator);
	delete numerator;
	delete denominator;

	Number* phi = Bernstein::C_00(d, kappa, delta, omega, DM, 0);

	Number* result = prefactor->times(phi);
	delete phi;
	delete prefactor;
	return result;
}

/*
 * In terms of the formulation presented in the technical report:
 *  C_00 = phi()
 * where phi() is defined by an equation of the report (number not recorded here).
 *
 * As implemented, the result is the sum over every matrix index Lambda returned
 * by MatrixIndex::getAllIndexes(kappa, omega->toMultiIdx()) of
 *     omega! / Lambda! * prod_{i = 0..d} DM->IM[j][i]->pow_elements(Lambda_i)
 * where Lambda_i is row i of Lambda converted by HyperIndex::fromMultiIdx(., delta).
 * The sum is finally multiplied by a unit factor. The returned Number is owned
 * by the caller.
 *
 * NOTE(review): the MultiIndex returned by omega->toMultiIdx() and the
 * MatrixIndex objects from getAllIndexes are not freed here; confirm their
 * ownership before releasing them.
 */
Number* Bernstein::C_00(int d, MultiIndex* kappa, MultiIndex* delta, HyperIndex* omega, DomainMapping* DM, int j)
{
	Number* unit_scale = Number::getUnitary();
	Integer* omega_factorial = omega->factorial();
	MultiIndex* omega_flat = omega->toMultiIdx();
	std::vector<MatrixIndex*> lambdas = MatrixIndex::getAllIndexes(kappa, omega_flat);

	Number* total = Number::getNULL();
	for (std::vector<MatrixIndex*>::size_type k = 0; k < lambdas.size(); ++k)
	{
		MatrixIndex* lambda = lambdas[k];

		// omega! / Lambda!
		Integer* lambda_factorial = lambda->factorial();
		Number* term = omega_factorial->div_by(lambda_factorial);
		delete lambda_factorial;

		// Multiply in DM->IM[j][row] raised elementwise to row `row` of Lambda.
		for (int row = 0; row <= d; ++row)
		{
			MultiIndex* lambda_row = lambda->line[row]->clone();
			HyperIndex* lambda_row_delta = HyperIndex::fromMultiIdx(lambda_row, delta);
			Number* factor = DM->IM[j][row]->pow_elements(lambda_row_delta);

			Number* before = term;
			term = before->times(factor);

			delete lambda_row;
			delete lambda_row_delta;
			delete factor;
			delete before;
		}

		Number* running = total;
		total = running->add(term);
		delete term;
		delete running;
	}

	Number* result = total->times(unit_scale);
	delete omega_factorial;
	delete total;
	delete unit_scale;
	return result;
}


/*
 * Xi' = \sum_{\nu \in \I_m^{|\alpha|}} \sum_{\Upsilon \in \HI_{delta}^{\nu}; \Upsilon \le \Omega}
 *           C02 * comb(Omega,Upsilon) / comb(beta,nu)
 *
 * Iterates over the Upsilons from get_Upsilons_for_composition_Adelta_to_Aepsilon,
 * keeps those with compareHyperIndices_le(Upsilon, Omega), and accumulates
 * C_02(..., nu = Upsilon->r, ...) * HyperIndex::combination(Omega, Upsilon)
 *                                 / MultiIndex::combination(beta, nu),
 * with beta = MultiIndex(m, |alpha|). The returned Number is owned by the caller.
 */
Number* Bernstein::C_03(MultiIndex* delta,MultiIndex* epsilon, MultiIndex* alpha, HyperIndex* Lambda, HyperIndex* Omega, DomainMapping* DM)
{
	int alpha_sum = alpha->sum();
	int m = alpha->dim;

	// NOTE(review): the formula above ranges over Upsilon in HI_delta, and
	// precompute_composition_Adelta_to_Aepsilon builds its Upsilons from
	// (delta->dim, delta), whereas here they are built from (alpha->dim, epsilon).
	// The two agree only when epsilon and delta have the same shape; confirm
	// which is intended.
	// NOTE(review): the HyperIndex objects in upsilon_vec are not freed here;
	// confirm their ownership before releasing them.
	std::vector<HyperIndex*> upsilon_vec = get_Upsilons_for_composition_Adelta_to_Aepsilon(alpha_sum,m,epsilon);

	// beta is allocated here, only passed to MultiIndex::combination, and
	// released before returning (it previously leaked on every call).
	MultiIndex* beta = new MultiIndex(m,alpha_sum);

	Number* C_03 = Number::getNULL();

	std::vector<HyperIndex*>::iterator hyperIter;
	for(hyperIter = upsilon_vec.begin(); hyperIter != upsilon_vec.end(); ++hyperIter)
	{
		HyperIndex* Upsilon = *hyperIter;
		if(!HyperIndex::compareHyperIndices_le(Upsilon,Omega))
			continue;

		Upsilon->load_MultiIndex_r();
		MultiIndex* nu = Upsilon->r;

		// term = C_02 * comb(Omega,Upsilon) / comb(beta,nu)
		Number* c02 = Bernstein::C_02(delta,epsilon,alpha,nu,Lambda,Upsilon,DM);
		Number* comb_hyper = HyperIndex::combination(Omega,Upsilon);
		Number* comb_multi = MultiIndex::combination(beta,nu);
		Number* numerator = c02->times(comb_hyper);
		delete c02;
		delete comb_hyper;
		Number* term = numerator->div_by(comb_multi);
		delete comb_multi;
		delete numerator;

		Number* updated = C_03->add(term);
		delete C_03;
		delete term;
		C_03 = updated;
	}

	delete beta;
	return C_03;
}



/*
 * In terms of the formulation presented in the technical report:
 *  C_02 = alpha! Upsilon! / nu! * Xi
 * for the composition A^delta --> A^epsilon, where Xi() is defined by an
 * equation of the report (number not recorded here).
 *
 * As implemented, Xi is the sum over every matrix index mho returned by
 * MatrixIndex::getAllIndexes(alpha, Upsilon->toMultiIdx()) of
 *     (1 / mho!) * prod_{i = 0..epsilon->dim}
 *         C_00(epsilon->mIdx[i], Lambda->line[i], delta, mho_i, DM, i)
 * where mho_i is row i of mho converted by HyperIndex::fromMultiIdx(., delta).
 *
 * nu is only borrowed. The returned Number is owned by the caller.
 *
 * NOTE(review): the MultiIndex from Upsilon->toMultiIdx() and the MatrixIndex
 * objects in mho_v are not freed here; confirm their ownership before
 * releasing them.
 */
Number* Bernstein::C_02(MultiIndex* delta, MultiIndex* epsilon, MultiIndex* alpha, MultiIndex* nu, HyperIndex* Lambda, HyperIndex* Upsilon, DomainMapping* DM)
{
	int i;
	Number* tmp01N;
	Number* ac_sum;

	Number* ac_prod;
	Number* it_prod;
	Number* tmp_prod;

	//-------- coef = alpha!Upsilon!/nu! -------------------------
	Integer* alpha_fac = alpha->factorial();
	Integer* Upsilon_fac = Upsilon->factorial();
	Integer* nu_fac = nu->factorial();

	tmp01N = alpha_fac->times(Upsilon_fac);
	Number* coef = tmp01N->div_by(nu_fac);

	delete alpha_fac;
	delete Upsilon_fac;
	delete nu_fac;
	delete tmp01N;

	// An earlier revision also scaled coef by the number of matrix indices in
	// MatrixIndex::getAllIndexes(alpha, nu). That factor had been fixed to 1,
	// yet the enumeration was still performed (its result neither used nor
	// freed), so both the enumeration and the multiplication by one are gone.
	//-------  (end of) coef = alpha!Upsilon!/nu! -------------------------

	MultiIndex* Upsilon_lin = Upsilon->toMultiIdx();

	Number* unit = Number::getUnitary();

	std::vector<MatrixIndex*> mho_v = MatrixIndex::getAllIndexes(alpha,Upsilon_lin);
	std::vector<MatrixIndex*>::iterator mho_ptr;

	ac_sum = Number::getNULL();
	for(mho_ptr = mho_v.begin(); mho_ptr != mho_v.end(); ++mho_ptr)
	{
		MatrixIndex* mho = *mho_ptr;

		// 1 / mho!
		Integer* mho_fac = mho->factorial();
		ac_prod = unit->div_by(mho_fac);
		delete mho_fac;

		for(i=0;i<=epsilon->dim;i++)
		{
			MultiIndex* mho_i = mho->line[i]->clone();
			HyperIndex* mho_i_delta = HyperIndex::fromMultiIdx(mho_i,delta);

			it_prod = Bernstein::C_00(epsilon->mIdx[i],Lambda->line[i],delta,mho_i_delta,DM,i);
			tmp_prod = ac_prod;
			ac_prod = tmp_prod->times(it_prod);
			delete mho_i;
			delete mho_i_delta;
			delete it_prod;
			delete tmp_prod;
		}
		tmp01N = ac_sum;
		ac_sum = tmp01N->add(ac_prod);
		delete ac_prod;
		delete tmp01N;
	}

	tmp01N = ac_sum->times(coef);
	delete ac_sum;
	delete coef;
	delete unit;

	return tmp01N;
}

/*
 * Collects, in order, every HyperIndex produced by
 * HyperIndex::getAllIndexes(mi, delta) for each multi-index mi in
 * MultiIndex::getMultiIdxSet(g, m), concatenated into a single vector.
 *
 * NOTE(review): the MultiIndex objects from getMultiIdxSet are not freed here;
 * the returned HyperIndex objects may reference them, so confirm ownership
 * before releasing them.
 */
std::vector<HyperIndex*> Bernstein::get_Omegas_for_composition_Adelta_to_Ad(int m, MultiIndex* delta, int g)
{
	std::vector<HyperIndex*> omega_v;
	std::vector<MultiIndex*> mi_v = MultiIndex::getMultiIdxSet(g,m);

	std::vector<MultiIndex*>::iterator mi_ptr;
	for(mi_ptr = mi_v.begin(); mi_ptr != mi_v.end(); ++mi_ptr)
	{
		std::vector<HyperIndex*> omega_mi_v = HyperIndex::getAllIndexes(*mi_ptr,delta);
		omega_v.insert(omega_v.end(),omega_mi_v.begin(),omega_mi_v.end());
	}

	return omega_v;
}

/*
 * Collects, in order, every HyperIndex produced by
 * HyperIndex::getAllIndexes(nu, delta) for each multi-index nu in
 * MultiIndex::getMultiIdxSet(alpha_sum, m), concatenated into a single vector.
 *
 * NOTE(review): the MultiIndex objects from getMultiIdxSet are not freed here;
 * the returned HyperIndex objects may reference them, so confirm ownership
 * before releasing them.
 */
std::vector<HyperIndex*> Bernstein::get_Upsilons_for_composition_Adelta_to_Aepsilon(int alpha_sum, int m, MultiIndex* delta)
{
	std::vector<HyperIndex*> mho_v;
	std::vector<MultiIndex*> nu_v = MultiIndex::getMultiIdxSet(alpha_sum,m);

	std::vector<MultiIndex*>::iterator nu_ptr;
	for(nu_ptr = nu_v.begin(); nu_ptr != nu_v.end(); ++nu_ptr)
	{
		std::vector<HyperIndex*> mho_nu_v = HyperIndex::getAllIndexes(*nu_ptr,delta);
		mho_v.insert(mho_v.end(),mho_nu_v.begin(),mho_nu_v.end());
	}

	return mho_v;
}

std::vector<HyperIndex*> Bernstein::get_Omegas_for_composition_Adelta_to_Aepsilon_fixed_mdegree(int alpha_sum, MultiIndex* delta)
{

	std::vector<HyperIndex*> mho_v;

	MultiIndex* beta = new MultiIndex(delta->dim,alpha_sum);
	mho_v = HyperIndex::getAllIndexes(beta,delta);

	return mho_v;

}


/*
 * Evaluates sum_k coef_v[k] * Bernstein::eval(hIdx_v[k]->r, hIdx_v[k], U).
 *
 * Pairs are consumed index by index until either vector runs out; a length
 * mismatch between hIdx_v and coef_v is silently ignored. The returned Number
 * is owned by the caller.
 */
Number* Bernstein::eval_summation_over_vector_of_indices(std::vector<HyperIndex*> hIdx_v, std::vector<Number*> coef_v,DomainPoint* U)
{
	Number* total = Number::getNULL();

	for (std::vector<HyperIndex*>::size_type k = 0; k < hIdx_v.size() && k < coef_v.size(); ++k)
	{
		HyperIndex* index = hIdx_v[k];
		Number* weight = coef_v[k];

		Number* basis = Bernstein::eval(index->r, index, U);
		Number* weighted = basis->times(weight);
		Number* previous = total;
		total = previous->add(weighted);

		delete basis;
		delete weighted;
		delete previous;
	}

	return total;
}


void Bernstein::precompute_composition_Adelta_to_Ad(int g, MultiIndex* kappa, MultiIndex* delta, int d,DomainMapping* DM)
{
	hIdx_v  = Bernstein::get_Omegas_for_composition_Adelta_to_Ad(delta->dim,delta,g);
	coef_v =  Bernstein::get_C_01s_for_each_Omega(d,kappa,delta,hIdx_v,DM);

}

/*
 * Stores in this object every Upsilon for the A^delta --> A^epsilon composition
 * (hIdx_v, built from |alpha|, delta->dim and delta) together with the
 * matching C_02 coefficient for each (coef_v).
 * NOTE(review): any previous contents of hIdx_v / coef_v are overwritten
 * without being freed.
 */
void Bernstein::precompute_composition_Adelta_to_Aepsilon(MultiIndex* alpha, MultiIndex* delta, MultiIndex* epsilon, HyperIndex* Lambda,DomainMapping* DM)
{
	int degree = alpha->sum();
	this->hIdx_v = Bernstein::get_Upsilons_for_composition_Adelta_to_Aepsilon(degree, delta->dim, delta);
	this->coef_v = Bernstein::get_C_02s_for_each_Upsilon(delta, epsilon, alpha, Lambda, this->hIdx_v, DM);
}


/*
 * Fixed multi-degree variant: stores every Omega of multi-degree
 * MultiIndex(delta->dim, |alpha|) (hIdx_v) together with the matching C_03
 * coefficient for each (coef_v).
 * NOTE(review): any previous contents of hIdx_v / coef_v are overwritten
 * without being freed.
 */
void Bernstein::precompute_composition_Adelta_to_Aepsilon_fixed_mdegree(MultiIndex* alpha, MultiIndex* delta, MultiIndex* epsilon, HyperIndex* Lambda,DomainMapping* DM)
{
	int degree = alpha->sum();
	this->hIdx_v = Bernstein::get_Omegas_for_composition_Adelta_to_Aepsilon_fixed_mdegree(degree, delta);
	this->coef_v = Bernstein::get_C_03s_for_each_Omega(delta, epsilon, alpha, Lambda, this->hIdx_v, DM);
}


/*
 * Evaluates the stored expansion (hIdx_v, coef_v) from one of the
 * precompute_* methods at U. The returned Number is owned by the caller.
 */
Number* Bernstein::eval_summation_over_precomputed_indices_and_coefs(DomainPoint* U)
{
	return Bernstein::eval_summation_over_vector_of_indices(this->hIdx_v, this->coef_v, U);
}


/*
 * Tensor-product evaluation: the product over i = 0..hIdx->dim->dim of
 * Bernstein::eval(mdeg->mIdx[i], hIdx->line[i], row i of U).
 * The returned Number is owned by the caller.
 *
 * NOTE(review): the object returned by U->getRow(i) is not freed here; confirm
 * whether getRow returns a view or a fresh allocation.
 */
Number* Bernstein::eval(MultiIndex* mdeg, HyperIndex* hIdx, DomainPoint* U)
{
	Number* product = Number::getUnitary();

	for (int factor = 0; factor <= hIdx->dim->dim; ++factor)
	{
		DomainPoint* row = (DomainPoint*)U->getRow(factor);
		Number* value = Bernstein::eval(mdeg->mIdx[factor], hIdx->line[factor], row);
		Number* previous = product;
		product = previous->times(value);
		delete value;
		delete previous;
	}

	return product;
}

/*
 * Single-simplex Bernstein basis value:
 *     g! / mIdx! * prod_{i = 0..mIdx->dim} U->data[0][i] ^ mIdx->mIdx[i]
 * The returned Number is owned by the caller.
 */
Number* Bernstein::eval(int g, MultiIndex* mIdx, DomainPoint* U)
{
	// Both factorials are viewed as Number* so that Number::div_by is used.
	Integer* degree_factorial = Integer::factorial(g);
	Integer* index_factorial = mIdx->factorial();
	Number* value = ((Number*)degree_factorial)->div_by((Number*)index_factorial);
	delete degree_factorial;
	delete index_factorial;

	for (int k = 0; k <= mIdx->dim; ++k)
	{
		Number* power = U->data[0][k]->pow_int(mIdx->mIdx[k]);
		Number* previous = value;
		value = previous->times(power);
		delete power;
		delete previous;
	}

	return value;
}
