package JKernelMachines;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Hashtable;
import java.util.Map;
import cern.colt.Arrays;
import JKernelMachines.Kernel;
import JKernelMachines.TrainingSample;
/**
* Major kernel computed as a weighted sum of minor kernels :
* K = w_i * k_i
* Computation of the kernel matrix is done by running a thread on sub matrices.
* The number of threads is choosen as function of the number of available cpus.
* @author dpicard
*
* @param
*/
public class ThreadedSumKernel extends Kernel {
/**
*
*/
private static final long serialVersionUID = 7780445301175174296L;
private Hashtable, Double> kernels;
private transient HashMap, double[][]> matrixMap;
protected int numThread = 0;
public ThreadedSumKernel()
{
kernels = new Hashtable, Double>();
}
/**
* Sets the weights to h. Beware! It does not make a copy of h!
* @param h
*/
public ThreadedSumKernel(Hashtable, Double> h)
{
kernels = new Hashtable, Double>();
kernels.putAll(h);
}
/**
* adds a kernel to the sum with weight 1.0
* @param k
*/
public void addKernel(Kernel k)
{
kernels.put(k, 1.0);
}
/**
* adds a kernel to the sum with weight d
* @param k
* @param d
*/
public void addKernel(Kernel k , double d)
{
kernels.put(k, d);
}
/**
* removes kernel k from the sum
* @param k
*/
public void removeKernel(Kernel k)
{
kernels.remove(k);
}
/**
* gets the weights of kernel k
* @param k
* @return the weight associated with k
*/
public double getWeight(Kernel k)
{
Double d = kernels.get(k);
if(d == null)
return 0.;
return d.doubleValue();
}
/**
* Sets the weight of kernel k
* @param k
* @param d
*/
public void setWeight(Kernel k, Double d)
{
kernels.put(k, d);
}
@Override
public double valueOf(T t1, T t2) {
double sum = 0.;
for(Kernel k : kernels.keySet())
sum += kernels.get(k)*k.valueOf(t1, t2);
return sum;
}
@Override
public double valueOf(T t1) {
double sum = 0.;
for(Kernel k : kernels.keySet())
sum += kernels.get(k)*k.valueOf(t1);
return sum;
}
/**
* get the list of kernels and associated weights.
* @return hashtable containing kernels as keys and weights as values.
*/
public Hashtable, Double> getWeights()
{
return kernels;
}
public double[][] getKernelMatrix(ArrayList> e)
{
double matrix[][] = new double[e.size()][e.size()];
numThread = 0;
//computing each matrix and storing them
matrixMap = new HashMap, double[][]>();
for(Kernel k : kernels.keySet())
{
double[][] m = k.getKernelMatrix(e);
matrixMap.put(k, m);
}
int nbc = ((int)Math.sqrt(2*Runtime.getRuntime().availableProcessors()))+1;
int icrem = e.size()/nbc ;
ArrayList threads = new ArrayList();
for(int i = 0 ; i < e.size() ; i+=icrem)
for(int j = 0 ; j < e.size() ; j+=icrem)
{
MatrixThread t = new MatrixThread(matrix, e, i, i+icrem, j, j+icrem);
threads.add(t);
t.start();
}
boolean cont = true;
while(cont)
{
cont = false;
for(MatrixThread t : threads)
if(!t.hasFinished())
cont = true;
Thread.yield();
}
return matrix;
}
private class MatrixThread extends Thread
{
double[][] m;
ArrayList> e;
int mini, maxi, minj, maxj;
boolean finished = false;
/**
* @param m
* @param e
* @param mini
* @param maxi
* @param minj
* @param maxj
*/
public MatrixThread(double[][] m, ArrayList> e, int mini, int maxi,
int minj, int maxj) {
this.m = m;
this.e = e;
this.mini = mini;
this.maxi = Math.min(maxi, e.size());
this.minj = minj;
this.maxj = Math.min(maxj, e.size());
}
public void run() {
finished = false;
ArrayList> listOfK = new ArrayList>();
synchronized(kernels)
{
listOfK.addAll(kernels.keySet());
}
for(Kernel k : listOfK)
{
double[][] matrix = null;
double w = 0.;
synchronized(kernels)
{
w = kernels.get(k);
if(Double.isNaN(w) || Double.isInfinite(w))
{
System.err.println("w error : "+w+" kernels:"+kernels);
System.exit(3);
}
}
if(w != 0)
{
synchronized(matrixMap)
{
matrix = matrixMap.get(k);
}
for (int i = mini; i < maxi; i++) {
for (int j = minj; j < maxj; j++) {
if(!Double.isNaN(matrix[i][j]))
{
synchronized(m[i])
{
m[i][j] += w*matrix[i][j];
}
if(Double.isNaN(m[i][j]) || Double.isInfinite(m[i][j]))
{
System.err.println(i+":"+j+" NaN!!! "+k+" "+m[i][j]+" w "+w+" matrixij"+matrix[i][j]);
System.exit(2);
}
}
}
}
}
}
finished = true;
}
public boolean hasFinished()
{
return finished;
}
}
}