#include "IrregularMatrix.h"
#include<stdlib.h>


// Element-wise sum of this matrix and `op`.
//
// Returns a newly allocated matrix whose element (r,c) is
// op(r,c)->add(this(r,c)), or NULL when the two shapes differ.
// The result is built on this matrix's MultiIndex pointer (it is not
// cloned), so the result and `this` share the same shape object.
IrregularMatrix* IrregularMatrix::add(IrregularMatrix* op)
{
	const bool sameShape = this->mDim->equals(op->mDim);
	if(!sameShape)
		return NULL;

	IrregularMatrix* sum = new IrregularMatrix(this->mDim);
	sum->allocate();

	for(int r=0; r<=this->mDim->dim; ++r)
	{
		Number** lhsRow = op->data[r];
		Number** rhsRow = this->data[r];
		for(int c=0; c<=this->mDim->mIdx[r]; ++c)
			sum->data[r][c] = lhsRow[c]->add(rhsRow[c]);
	}

	return sum;
}

// Returns a newly allocated matrix whose element (r,c) is a->times(this(r,c)).
// Like add(), the result shares this matrix's MultiIndex pointer.
IrregularMatrix* IrregularMatrix::multiply_by_scalar(Number* a)
{
	IrregularMatrix* scaled = new IrregularMatrix(this->mDim);
	scaled->allocate();

	for(int r=0; r<=this->mDim->dim; ++r)
	{
		Number** src = this->data[r];
		for(int c=0; c<=this->mDim->mIdx[r]; ++c)
			scaled->data[r][c] = a->times(src[c]);
	}

	return scaled;
}

// Builds a matrix shaped by `mDim`. The MultiIndex pointer is stored as-is
// (not cloned) and is not deleted by the destructor. No element storage is
// allocated here: call allocate() / allocate_and_zero() before using `data`.
// `data` starts out NULL so that destroying a never-allocated matrix is safe.
IrregularMatrix::IrregularMatrix(MultiIndex* mDim)
{
	this->mDim = mDim;
	this->data = NULL;
}

// Releases the row storage obtained from allocate() (see deallocate()).
IrregularMatrix::~IrregularMatrix()
{
	this->deallocate();
}


// Allocates row storage with every element pointer set to NULL.
// calloc already zero-fills, so this is exactly allocate().
void IrregularMatrix::allocate_and_zero()
{
	this->allocate();
}

// Allocates one array of row pointers (dim+1 rows) and, for each row i,
// an array of mIdx[i]+1 element pointers, all zero-filled by calloc.
// NOTE(review): storage from a previous allocate() is not released here.
void IrregularMatrix::allocate()
{
	
	this->data = (Number***) calloc(this->mDim->dim+1,sizeof(Number**));
	
	int i;
	for(i=0;i<=this->mDim->dim;i++)
		this->data[i] = (Number**) calloc(this->mDim->mIdx[i]+1,sizeof(Number*));
	
}

// Frees the row arrays and the row-pointer array created by allocate().
// Both were obtained with calloc, so they must be returned with free()
// (not delete). `data` must not be re-allocated before freeing, or the real
// rows are lost. Assumes mDim still describes the shape used at allocation.
//
// The Number objects referenced by the rows are not freed here; element
// ownership is not established by this file.
// NOTE(review): now that storage is really freed, copying an
// IrregularMatrix by value (implicit copy ctor/assignment) would double free.
void IrregularMatrix::deallocate()
{
	if(this->data == NULL)
		return;

	int i;
	for(i=0;i<=this->mDim->dim;i++)
		free(this->data[i]);

	free(this->data);
	this->data = NULL;
}


// Product over all elements of this(i,j)->powN(exp(i,j)).
//
// Returns NULL when the shapes differ. If any base equals Number::getNULL()
// (the value dotProduct() uses as its additive starting point), a newly
// allocated Number::getNULL() value is returned instead. Otherwise it returns
// a newly allocated Number holding the product.
//
// The comparison sentinel is created once and reused as the early-exit
// result. Creating it per element leaked one Number per iteration. The
// running product is released on the early-exit path.
// NOTE(review): temporaries are released with free(), as elsewhere in this
// function family, while rowSum() uses delete; confirm how Numbers are
// allocated.
Number* IrregularMatrix::pow_elements(IrregularMatrix* exp)
{
	if(!this->mDim->equals(exp->mDim))
		return NULL;
	
	int i,j;
	Number* tmp01;
	Number* tmp02;
	Number* a = Number::getUnitary();
	Number* zero = Number::getNULL();
	
	for(i=0;i<=this->mDim->dim;i++)
		for(j=0;j<=this->mDim->mIdx[i];j++)
			{
				if(this->data[i][j]->equals(zero))
				{
					free(a);
					return zero;
				}
				
				tmp01 = this->data[i][j]->powN(exp->data[i][j]);
				tmp02 = a->times(tmp01);
				free(a);
				free(tmp01);
				a = tmp02;
			}
	
	free(zero);
	return a;
}

// Product over all elements of this(i,j)->pow_int(exp->line[i]->mIdx[j]).
//
// Returns NULL when this matrix's shape does not equal exp->dim. Returns a
// newly allocated Number::getNULL() value on the early-exit condition below.
// Otherwise it returns a newly allocated Number holding the product.
//
// The comparison sentinel is created once and reused as the early-exit
// result. Creating it per element leaked one Number per iteration. The
// running product is released on the early-exit path.
Number* IrregularMatrix::pow_elements(HyperIndex* exp)
{
	if(!this->mDim->equals(exp->dim))
		return NULL;
	
	int i,j;
	Number* tmp01;
	Number* tmp02;
	Number* a = Number::getUnitary();
	Number* zero = Number::getNULL();
	
	for(i=0;i<=this->mDim->dim;i++)
		for(j=0;j<=this->mDim->mIdx[i];j++)
			{
				// NOTE(review): "equals zero but val != 0" can only be true if
				// Number::equals() is tolerance-based; confirm this is the
				// intended early-exit condition.
				if( (this->data[i][j]->equals(zero))&&
						(this->data[i][j]->val!=0 ) )
				{
					free(a);
					return zero;
				}
				
				tmp01 = this->data[i][j]->pow_int(exp->line[i]->mIdx[j]);
				tmp02 = a->times(tmp01);
				free(a);
				free(tmp01);
				a = tmp02;
			}
	
	free(zero);
	return a;
}


// Product over all (i,j) of this(i,j) * op(i,j).
//
// Returns NULL when the shapes differ. If any element of THIS matrix equals
// Number::getNULL(), a newly allocated Number::getNULL() value is returned.
// Otherwise it returns a newly allocated Number holding the product.
//
// The comparison sentinel is created once and reused as the early-exit
// result. Creating it per element leaked one Number per iteration. The
// running product is released on the early-exit path.
Number* IrregularMatrix::multiply_elements(IrregularMatrix* op)
{
	if(!this->mDim->equals(op->mDim))
		return NULL;
	
	int i,j;
	Number* tmp01;
	Number* tmp02;
	Number* a = Number::getUnitary();
	Number* zero = Number::getNULL();
	
	for(i=0;i<=this->mDim->dim;i++)
		for(j=0;j<=this->mDim->mIdx[i];j++)
			{
				if(this->data[i][j]->equals(zero))
				{
					free(a);
					return zero;
				}
				
				tmp01 = this->get(i,j)->times(op->get(i,j));
				tmp02 = a->times(tmp01);
				free(a);
				free(tmp01);
				a = tmp02;
			}
	
	free(zero);
	return a;
}

// Sum over all (r,c) of this(r,c)->times(op(r,c)), starting from
// Number::getNULL().
// Returns NULL if the shapes differ, otherwise a newly allocated Number.
// Each replaced partial sum and each product term is released with free().
Number* IrregularMatrix::dotProduct(IrregularMatrix* op)
{
	if(!this->mDim->equals(op->mDim))
		return NULL;

	Number* acc = Number::getNULL();
	for(int r=0; r<=this->mDim->dim; ++r)
	{
		for(int c=0; c<=this->mDim->mIdx[r]; ++c)
		{
			Number* term = this->data[r][c]->times(op->data[r][c]);
			Number* next = acc->add(term);
			free(acc);
			free(term);
			acc = next;
		}
	}
	return acc;
}


// True when row i lies in [0, dim] and column j lies in [0, mIdx[i]].
// The row is checked first because the column bound is read from mIdx[i].
bool IrregularMatrix::validIdxs(int i,int j)
{
	const bool rowInRange = (i >= 0) && (i <= this->mDim->dim);
	if(!rowInRange)
		return false;

	return (j >= 0) && (j <= this->mDim->mIdx[i]);
}

// Returns the stored element pointer at (i,j), or NULL for out-of-range
// indices. The pointer is not cloned; the matrix keeps referencing it.
Number* IrregularMatrix::get(int i,int j)
{
	Number* found = NULL;
	if(validIdxs(i,j))
		found = this->data[i][j];
	return found;
}

void IrregularMatrix::setVal(int i,int j,Number* n)
{
	if(!validIdxs(i,j))
		return;
	
	this->data[i][j]->val = n->val;
	return;
}

// Builds a new matrix from the rows listed in ridx[0..nrows-1], in that
// order, cloning every element. The shape comes from
// mDim->getSubIndex(nrows-1, ridx).
// NOTE(review): ridx entries are not range-checked here.
IrregularMatrix* IrregularMatrix::getRows(int nrows,int* ridx)
{
	MultiIndex* subDim = this->mDim->getSubIndex(nrows-1,ridx);
	IrregularMatrix* picked = new IrregularMatrix(subDim);
	picked->allocate();

	for(int r=0; r<nrows; ++r)
	{
		Number** src = this->data[ridx[r]];
		Number** dst = picked->data[r];
		for(int c=0; c<=subDim->mIdx[r]; ++c)
			dst[c] = src[c]->clone();
	}

	return picked;
}
// Returns row `ridx` as a new single-row matrix with cloned elements,
// or NULL when validIdxs(ridx, 0) fails.
IrregularMatrix* IrregularMatrix::getRow(int ridx)
{
	if(validIdxs(ridx,0))
	{
		const int rowLength = this->mDim->mIdx[ridx] + 1;
		return IrregularMatrix::fromNumber1DVector(rowLength, this->data[ridx]);
	}
	return NULL;
}

// Wraps v[0..length-1] into a new single-row matrix, cloning each element.
// The shape is a freshly allocated MultiIndex(0, length-1).
IrregularMatrix* IrregularMatrix::fromNumber1DVector(int length,Number** v)
{
	MultiIndex* rowDim = new MultiIndex(0,length-1);
	IrregularMatrix* rowMatrix = new IrregularMatrix(rowDim);
	rowMatrix->allocate();

	Number** dst = rowMatrix->data[0];
	const int last = rowDim->mIdx[0];
	for(int k=0; k<=last; ++k)
		dst[k] = v[k]->clone();

	return rowMatrix;
}

	
void IrregularMatrix::setRow(int ridx,IrregularMatrix* r)
{
	if(r->mDim->mIdx[0]!=this->mDim->mIdx[ridx])
		return;
	
	int j;
	for(j=0;j<=this->mDim->mIdx[ridx];j++)
		this->data[ridx][j] = r->data[0][j]->clone();
	
	return;
}


// Returns a new matrix equal to this one with a clone of `n` inserted at
// position j of row i. The old elements j.. of that row shift one place
// right. Every element is cloned; `this` is left untouched.
// j may equal the row's current length, which appends at the end of the row.
// Returns NULL when row i does not exist or j is outside [0, mIdx[i]+1].
IrregularMatrix* IrregularMatrix::insertElement(int i,int j, Number* n)
{
	if(!validIdxs(i,0))
		return NULL;
	const bool columnOk = (j >= 0) && (j <= this->mDim->mIdx[i] + 1);
	if(!columnOk)
		return NULL;

	// Same shape, except row i grows by one element.
	MultiIndex* grownDim = this->mDim->clone();
	grownDim->mIdx[i] += 1;

	IrregularMatrix* grown = new IrregularMatrix(grownDim);
	grown->allocate();

	for(int row=0; row<=grownDim->dim; ++row)
	{
		Number** src = this->data[row];
		Number** dst = grown->data[row];

		if(row != i)
		{
			// untouched rows are copied verbatim
			for(int col=0; col<=this->mDim->mIdx[row]; ++col)
				dst[col] = src[col]->clone();
			continue;
		}

		// row i: prefix, inserted element, then the shifted tail
		for(int col=0; col<j; ++col)
			dst[col] = src[col]->clone();
		dst[j] = n->clone();
		for(int col=j+1; col<=grownDim->mIdx[i]; ++col)
			dst[col] = src[col-1]->clone();
	}

	return grown;
}

// Returns a new matrix equal to this one with clones of the first row of
// `row` inserted as row i. Rows i.. of this matrix shift down by one.
// `this` is left untouched.
//
// Returns NULL when `row` is NULL or i is outside [0, dim+1] (dim+1 appends).
// Without this guard, an i beyond dim+1 made the copy loops read
// this->data out of bounds.
IrregularMatrix* IrregularMatrix::insertRow(int i, IrregularMatrix* row)
{
	if(row == NULL || i < 0 || i > this->mDim->dim + 1)
		return NULL;
	
	MultiIndex* newmDim = this->mDim->insertElement(i,row->mDim->mIdx[0]);
	//newmDim->print(stderr);
	
	IrregularMatrix* newM = new IrregularMatrix(newmDim);
	newM->allocate();
		
	int j=0,k;
	
	// rows before i are copied unchanged
	for(j=0;j<i;j++)
		for(k=0;k<=newmDim->mIdx[j];k++)
			newM->data[j][k] = this->data[j][k]->clone();
	
	// the inserted row
	for(k=0;k<=newmDim->mIdx[i];k++)
		newM->data[i][k] = row->data[0][k]->clone();
	
	// rows from i onward in `this` land one position lower
	for(j=i+1;j<=newM->mDim->dim;j++)
		for(k=0;k<=newmDim->mIdx[j];k++)
			newM->data[j][k] = this->data[j-1][k]->clone();
	
	return newM;
	
}

void IrregularMatrix::print(FILE* fp)
{
	int i,j;
	fprintf(fp,"\n{");
	for(i=0;i<=this->mDim->dim;i++)
	{
		fprintf(fp,"(");
		for(j=0;j<=this->mDim->mIdx[i];j++)
			fprintf(fp,"%g, ",this->data[i][j]->val);
		fprintf(fp,")");
	}
	fprintf(fp,"}\n");
			
}

void IrregularMatrix::printPlot(FILE* fp)
{
	int i,j;
	fprintf(fp,"\n");
	for(i=0;i<=this->mDim->dim;i++)
	{
		for(j=0;j< this->mDim->mIdx[i];j++)
			fprintf(fp,"%g ",this->data[i][j]->val);
		
	}
	fprintf(fp,"\n");
			
}


// Sum of all elements of row `row_n`, starting from Number::getNULL().
// Returns a newly allocated Number, or NULL when `row_n` is not a row index
// in [0, dim]. Without that check, mIdx[row_n] was read out of bounds.
// An empty row (mIdx[row_n] == -1) still yields the starting value.
// NOTE(review): partial sums are released with delete here, while the other
// reductions in this file use free(); confirm how Numbers are allocated.
Number* IrregularMatrix::rowSum(int row_n)
{
	if((row_n<0)||(row_n>this->mDim->dim))
		return NULL;
	
	Number* sum = Number::getNULL();
	Number* tmp;
	
	int i;
	const int ncols = this->mDim->mIdx[row_n]+1;
	for(i=0;i<ncols;i++)
	{
		tmp = sum;
		sum = tmp->add(this->get(row_n,i));
		delete(tmp);
	}
	
	return sum;
}


// For every matrix in `v`, appends m->insertElement(i, j, n) to (a copy of)
// `accum` and returns it. Entries may be NULL when insertElement rejects
// the indices. The input matrices are not modified.
std::vector<IrregularMatrix*> IrregularMatrix::insertElement(int i,int j, Number* n, std::vector<IrregularMatrix*> v, std::vector<IrregularMatrix*> accum)
{
	for(std::vector<IrregularMatrix*>::size_type k=0; k<v.size(); ++k)
		accum.push_back(v[k]->insertElement(i,j,n));

	return accum;
}

// For every matrix in `v`, appends m->insertRow(i, row) to (a copy of)
// `accum` and returns it. The input matrices are not modified.
std::vector<IrregularMatrix*> IrregularMatrix::insertRow(int i,IrregularMatrix* row, std::vector<IrregularMatrix*> v, std::vector<IrregularMatrix*> accum)
{
	for(std::vector<IrregularMatrix*>::size_type k=0; k<v.size(); ++k)
		accum.push_back(v[k]->insertRow(i,row));

	return accum;
}

