package JKernelMachines;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Random;

import JKernelMachines.DoublePegasosSVM;
import JKernelMachines.SMOSVM;
import JKernelMachines.Kernel;
import JKernelMachines.DoubleLinear;
import JKernelMachines.TrainingSample;

public class TestSimpleSVM {

	/**
	 * Toy benchmark comparing two linear SVM solvers (SMO and Pegasos) on
	 * synthetic data: positive samples drawn from an isotropic Gaussian
	 * centered at the origin, negative samples from the same Gaussian
	 * shifted by +3 on every coordinate. Prints per-sample decisions,
	 * error counts, and the learned hyperplanes to stderr.
	 *
	 * @param args unused
	 */
	public static void main(String[] args)
	{
		int dimension = 3;
		int nbPosTrain = 100;
		int nbNegTrain = 100;
		int nbPosTest = 400;
		int nbNegTest = 400;
		
		Random ran = new Random(System.currentTimeMillis());
		
		//1.-2. generate positive and negative train samples
		ArrayList<TrainingSample<double[]>> train = new ArrayList<TrainingSample<double[]>>();
		addGaussianSamples(train, ran, dimension, nbPosTrain, 0., 1);
		addGaussianSamples(train, ran, dimension, nbNegTrain, 3., -1);
		
		//3. train svm with SMO
		Kernel<double[]> k = new DoubleLinear();
		SMOSVM<double[]> svm = new SMOSVM<double[]>(k);
		svm.setC(1e3);
		svm.train(train);
		
		//3.1 train pegasos
		DoublePegasosSVM peg = new DoublePegasosSVM();
		peg.setK(25);
		peg.setT(50000);
		peg.train(train);
		
		//4.-5. generate positive and negative test samples
		ArrayList<TrainingSample<double[]>> test = new ArrayList<TrainingSample<double[]>>();
		addGaussianSamples(test, ran, dimension, nbPosTest, 0., 1);
		addGaussianSamples(test, ran, dimension, nbNegTest, 3., -1);
		
		//6. evaluate both classifiers on the test set
		int nbErr = 0;
		int pegErr = 0;
		for(TrainingSample<double[]> t : test)
		{
			int y = t.label;
			double value = svm.valueOf(t.sample);
			// sign disagreement between label and decision value => error
			if(y*value < 0)
				nbErr++;
			double pegVal = peg.valueOf(t.sample);
			if(y*pegVal < 0)
				pegErr++;
			
			System.err.println("y : "+y+" value : "+value+" nbErr : "+nbErr+" pegVal : "+pegVal+" pegErr : "+pegErr);
		}
		
		//7. alphas from svm
		System.err.println("smo : alphas : "+Arrays.toString(svm.getAlphas()));
		
		//7.1 reconstruct the primal vector for smo: w = sum_t alpha_t * y_t * x_t
		double w[] = new double[dimension];
		double alpha[] = svm.getAlphas();
		for(int t = 0 ; t < train.size(); t++)
		{
			double d[] = train.get(t).sample;
			int y = train.get(t).label;
			for(int i = 0 ; i < dimension; i++)
			{
				w[i] += alpha[t] * y * d[i];
			}
		}
		System.err.println("smo : w : "+Arrays.toString(w));
		System.err.println("smo : bias : "+svm.getB());
		// FIX: k.valueOf(w, w) is <w,w> = ||w||^2; take the square root
		// so the printed value actually matches the "||w||" label.
		System.err.println("smo : ||w|| : "+Math.sqrt(k.valueOf(w, w)));
		
		//7.2 w from pegasos
		System.err.println("peg : w : "+Arrays.toString(peg.getW()));
		System.err.println("peg : bias : "+peg.getB());
		// FIX: same as above — norm, not squared norm.
		System.err.println("peg : ||w|| : "+Math.sqrt(k.valueOf(peg.getW(), peg.getW())));
		
		//8. cosine similarity between the two hyperplane normals
		// (1.0 means smo and pegasos found the same direction)
		System.err.println("< smo, peg > : "+(k.valueOf(w, peg.getW())/Math.sqrt(k.valueOf(w, w)*k.valueOf(peg.getW(), peg.getW()))));
		
	}
	
	/**
	 * Appends {@code n} labeled samples to {@code list}, each drawn from an
	 * isotropic Gaussian with every coordinate shifted by {@code offset}.
	 *
	 * @param list      destination list, modified in place
	 * @param ran       random source
	 * @param dimension sample dimensionality
	 * @param n         number of samples to generate
	 * @param offset    mean added to every coordinate (0 for positives, 3 for negatives)
	 * @param label     class label to attach (+1 or -1)
	 */
	private static void addGaussianSamples(ArrayList<TrainingSample<double[]>> list,
			Random ran, int dimension, int n, double offset, int label)
	{
		for(int i = 0 ; i < n; i++)
		{
			double[] t = new double[dimension];
			for(int x = 0 ; x < dimension; x++)
			{
				t[x] = offset + ran.nextGaussian();
			}
			list.add(new TrainingSample<double[]>(t, label));
		}
	}
	
}
