#include "MatrixIndex.h"


/**
 * Extract one column of the matrix as a MultiIndex holding one entry per row.
 *
 * @param col zero-based column number; it is not bounds-checked.
 * @return a newly allocated MultiIndex; the caller takes ownership.
 */
MultiIndex* MatrixIndex::getColumn(int col)
{
	const int rows = this->n_rows;
	MultiIndex* column = new MultiIndex(rows-1);

	for(int row = 0; row < rows; ++row)
	{
		column->mIdx[row] = this->line[row]->mIdx[col];
	}

	return column;
}

/**
 * Row-wise product of MultiIndex::combination over two matrix indices.
 *
 * For every row i of `a`, this multiplies MultiIndex::combination(a->line[i], b->line[i])
 * into a running product that starts at 1.
 * NOTE(review): `b` is assumed to have at least a->n_rows rows; this is not checked.
 *
 * @return a newly allocated Integer; the caller takes ownership.
 */
Integer* MatrixIndex::combination(MatrixIndex* a, MatrixIndex* b)
{
	Integer* product = new Integer(1);

	for(int row = 0; row < a->n_rows; ++row)
	{
		Integer* factor = MultiIndex::combination(a->line[row], b->line[row]);
		Integer* previous = product;
		product = previous->times(factor);

		// Both operands are released after times(); the running product lives in
		// the object that times() returned.
		delete factor;
		delete previous;
	}

	return product;
}

/**
 * Copy this matrix index.
 *
 * The copy has the same shape, and each row is duplicated through MultiIndex::clone().
 *
 * @return a newly allocated MatrixIndex; the caller takes ownership.
 */
MatrixIndex* MatrixIndex::clone()
{
	MatrixIndex* copy = new MatrixIndex(this->n_rows, this->n_cols);

	for(int row = 0; row < this->n_rows; ++row)
		copy->line[row] = this->line[row]->clone();

	return copy;
}

/**
 * Create a rows x cols matrix index whose rows are not yet populated.
 *
 * Only the array of row pointers is allocated. Callers in this file (clone,
 * fromHyperIndex, fromStar) fill every line[i] afterwards.
 *
 * @param rows number of rows (one MultiIndex per row)
 * @param cols number of columns (entries per row)
 */
MatrixIndex::MatrixIndex(int rows, int cols): HyperIndex(rows)
{
	this->n_rows = rows;
	this->n_cols = cols;

	// Per-row dimension vector. fromHyperIndex reads dim->mIdx[0]+1 as the
	// column count, so each entry is presumably cols-1.
	// NOTE(review): if HyperIndex(rows) already allocates `dim` or `line`,
	// those allocations are overwritten here. Confirm against HyperIndex.
	this->dim = new MultiIndex(rows-1,cols-1);

	// Value-initialise the pointer array so any row that has not been assigned
	// is NULL rather than an indeterminate pointer.
	this->line = new MultiIndex*[rows]();
}

/**
 * Release the array of row pointers that the constructor allocated.
 *
 * NOTE(review): only the pointer array is freed. The MultiIndex rows it points
 * to, and `dim`, are not deleted here. Confirm whether ~HyperIndex owns them;
 * otherwise they leak.
 */
MatrixIndex::~MatrixIndex()
{
	delete [] line;
}





/**
 * Compute the column sums of the matrix.
 *
 * Entry `col` of the result is the sum of line[row]->mIdx[col] over all rows.
 *
 * @return a newly allocated MultiIndex with n_cols entries; the caller takes ownership.
 */
MultiIndex* MatrixIndex::getColumnMultiIndex()
{
	// Created with every entry set to 0 (second constructor argument), then accumulated into.
	MultiIndex* sums = new MultiIndex(this->n_cols-1,0);

	for(int col = 0; col < this->n_cols; ++col)
	{
		for(int row = 0; row < this->n_rows; ++row)
			sums->mIdx[col] += this->line[row]->mIdx[col];
	}

	return sums;
}


/**
 * Extract the main diagonal, i.e. entries (k, k) for k below min(n_rows, n_cols).
 *
 * @return a newly allocated MultiIndex; the caller takes ownership.
 */
MultiIndex* MatrixIndex::getDiagonal()
{
	// The diagonal is as long as the shorter side of the matrix.
	const int len = (this->n_cols < this->n_rows) ? this->n_cols : this->n_rows;

	MultiIndex* diagonal = new MultiIndex(len-1);
	for(int k = 0; k < len; ++k)
		diagonal->mIdx[k] = this->line[k]->mIdx[k];

	return diagonal;
}







/**
 * Enumerate matrix indices that satisfy `r` and `c` and also have main diagonal `diag`.
 *
 * Candidates come from getAllIndexes(r, c). A candidate is kept when its
 * getDiagonal() equals `diag`. Rejected candidates and the temporary diagonal
 * objects are deleted here, so nothing allocated by this call leaks.
 *
 * @param r    forwarded to getAllIndexes(r, c); presumably row constraints, confirm
 * @param c    required column sums (see getAllIndexes(r, c))
 * @param diag required main diagonal
 * @return the matching matrices; the caller owns every returned pointer.
 */
std::vector<MatrixIndex*> MatrixIndex::getAllIndexes(MultiIndex* r,MultiIndex* c, MultiIndex* diag)
{
	std::vector<MatrixIndex*> candidates = MatrixIndex::getAllIndexes(r,c);
	std::vector<MatrixIndex*> selected;

	for(std::vector<MatrixIndex*>::iterator ii = candidates.begin(); ii != candidates.end(); ++ii)
	{
		MatrixIndex* mtx = *ii;

		// getDiagonal() returns a fresh allocation that is only needed for this comparison.
		MultiIndex* mtx_diag = mtx->getDiagonal();
		const bool keep = diag->equals(mtx_diag);
		delete mtx_diag;

		if(keep)
			selected.push_back(mtx);
		else
			delete mtx;   // not returned, so nobody else can reach it
	}

	return selected;
}


/**
 * Enumerate matrix indices that satisfy `r` and whose column sums equal `c`.
 *
 * The column count is derived from `c` (c->dim + 1). Candidates come from
 * getAllIndexes(r, n_cols). A candidate is kept when getColumnMultiIndex()
 * equals `c`. Rejected candidates and the temporary column-sum objects are
 * deleted here, so nothing allocated by this call leaks.
 *
 * @param r forwarded to getAllIndexes(r, n_cols); presumably row constraints, confirm
 * @param c required column sums, one entry per column
 * @return the matching matrices; the caller owns every returned pointer.
 */
std::vector<MatrixIndex*> MatrixIndex::getAllIndexes(MultiIndex* r,MultiIndex* c)
{
	const int cols = c->dim+1;

	std::vector<MatrixIndex*> candidates = MatrixIndex::getAllIndexes(r,cols);
	std::vector<MatrixIndex*> selected;

	for(std::vector<MatrixIndex*>::iterator ii = candidates.begin(); ii != candidates.end(); ++ii)
	{
		MatrixIndex* mtx = *ii;

		// getColumnMultiIndex() returns a fresh allocation that is only needed for this comparison.
		MultiIndex* mtx_col = mtx->getColumnMultiIndex();
		const bool keep = c->equals(mtx_col);
		delete mtx_col;

		if(keep)
			selected.push_back(mtx);
		else
			delete mtx;   // not returned, so nobody else can reach it
	}

	return selected;
}


/**
 * Convert a HyperIndex into a MatrixIndex by cloning each of its rows.
 *
 * The column count is hi->dim->mIdx[0] + 1, so only row 0's dimension is
 * consulted and all rows are assumed to have the same length.
 *
 * @return a newly allocated MatrixIndex; the caller takes ownership. `hi` is not modified.
 */
MatrixIndex* MatrixIndex::fromHyperIndex(HyperIndex* hi)
{
	const int rows = hi->n_rows;
	const int cols = hi->dim->mIdx[0] + 1;

	MatrixIndex* result = new MatrixIndex(rows, cols);
	for(int row = 0; row < rows; ++row)
		result->line[row] = hi->line[row]->clone();

	return result;
}

/**
 * Enumerate matrix indices with r->dim+1 rows and `n_cols` columns that satisfy `r`.
 *
 * The work is delegated to HyperIndex::getAllIndexes with a uniform per-row
 * dimension of n_cols-1. Each HyperIndex result is then converted through
 * fromHyperIndex.
 *
 * NOTE(review): the `row_dims` MultiIndex and the HyperIndex objects returned
 * by HyperIndex::getAllIndexes are never deleted here. Confirm their ownership
 * before freeing them.
 *
 * @return newly allocated matrices; the caller owns every returned pointer.
 */
std::vector<MatrixIndex*> MatrixIndex::getAllIndexes(MultiIndex* r,int n_cols)
{
	MultiIndex* row_dims = new MultiIndex(r->dim,n_cols-1);
	std::vector<HyperIndex*> hyper = HyperIndex::getAllIndexes(r,row_dims);

	std::vector<MatrixIndex*> result;
	for(std::vector<HyperIndex*>::iterator it = hyper.begin(); it != hyper.end(); ++it)
		result.push_back(MatrixIndex::fromHyperIndex(*it));

	return result;
}

void MatrixIndex::print(FILE* fp)
{
	int i,j;
	fprintf(fp,"(\n");
	for(i=0;i<this->n_rows;i++)
	{	
		fprintf(fp,"[ ");
		for(j=0;j<this->n_cols;j++)
			fprintf(fp," %d",this->line[i]->mIdx[j]);
		fprintf(fp," ]\n");
	}
	fprintf(fp,");\n");
}



/**
 * Build a two-column matrix index from a pair of multi-indices.
 *
 * Row i is [ a[i], b[i] - a[i] ].
 *
 * @return a newly allocated MatrixIndex with a->dim+1 rows (caller owns it),
 *         or NULL when `a` and `b` have different dimensions.
 */
MatrixIndex* MatrixIndex::fromStar(MultiIndex* a, MultiIndex* b)
{
	// Both inputs must have the same length.
	if(a->dim != b->dim)
		return NULL;

	const int rows = a->dim + 1;
	MatrixIndex* mtx = new MatrixIndex(rows,2);

	for(int row = 0; row < rows; ++row)
	{
		MultiIndex* entry = new MultiIndex(1);   // two entries: indices 0 and 1
		entry->mIdx[0] = a->mIdx[row];
		entry->mIdx[1] = b->mIdx[row] - a->mIdx[row];
		mtx->line[row] = entry;
	}

	return mtx;
}


/**
 * Row-wise ">=" comparison against another matrix index.
 *
 * @param op the matrix to compare against
 * @return true when both matrices have the same shape and every row satisfies
 *         MultiIndex::isGreaterEqThan against the matching row of `op`;
 *         false on any shape mismatch or failing row.
 */
bool MatrixIndex::isGreaterEqThan(MatrixIndex *op )
{
	const bool same_shape = (this->n_rows == op->n_rows) && (this->n_cols == op->n_cols);
	if(!same_shape)
		return false;

	for(int row = 0; row < this->n_rows; ++row)
	{
		if(!this->line[row]->isGreaterEqThan(op->line[row]))
			return false;
	}

	return true;
}
