package JKernelMachines;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import JKernelMachines.TrainingSample;
/**
* Simple multithreaded implementation over a given Kernel. The multithreading comes only when
* computing the Gram matrix.
* Number of Threads is function of available processors.
* @author dpicard
*
* @param
*/
public class ThreadedKernel extends Kernel {
/**
*
*/
private static final long serialVersionUID = -2193768216118832033L;
protected Kernel k;
// private double[][] matrix;
/**
* MultiThread the given kernel
* @param kernel
*/
public ThreadedKernel(Kernel kernel)
{
this.k = kernel;
}
@Override
public double valueOf(T t1, T t2) {
return k.valueOf(t1, t2);
}
@Override
public double valueOf(T t1) {
return k.valueOf(t1);
}
/* (non-Javadoc)
* @see JKernelMachines.Kernel#getKernelMatrix(java.util.ArrayList)
*/
@Override
public double[][] getKernelMatrix(List> e) {
double[][] matrix = new double[e.size()][e.size()];
//heuristic choice of number of threads : about as much as the available processors
int nbc = ((int)Math.sqrt(Runtime.getRuntime().availableProcessors()+1));
int icrem = e.size()/nbc ;
ArrayList threads = new ArrayList();
for(int i = 0 ; i < e.size() ; i+=icrem)
for(int j = 0 ; j < e.size() ; j+=icrem)
{
MatrixThread t = new MatrixThread(matrix, e, i, i+icrem, j, j+icrem);
threads.add(t);
t.start();
}
boolean cont = true;
while(cont)
{
cont = false;
for(MatrixThread t : threads)
if(!t.hasFinished() && t.isAlive())
cont = true;
Thread.yield();
}
return matrix;
}
private class MatrixThread extends Thread
{
double[][] m;
List> e;
int mini, maxi, minj, maxj;
boolean finished = false;
/**
* @param m
* @param e2
* @param mini
* @param maxi
* @param minj
* @param maxj
*/
public MatrixThread(double[][] m, List> e2, int mini, int maxi,
int minj, int maxj) {
this.m = m;
this.e = e2;
this.mini = mini;
this.maxi = Math.min(maxi, e2.size());
this.minj = minj;
this.maxj = Math.min(maxj, e2.size());
}
public void run() {
finished = false;
for (int i = mini; i < maxi; i++) {
for (int j = minj; j < maxj; j++) {
T t1 = e.get(i).sample;
T t2 = e.get(j).sample;
double v = valueOf(t1, t2);
if(!Double.isNaN(v))
{
m[i][j] = valueOf(t1, t2);
}
else
{
System.err.println("NAN : v="+v);
System.err.println("t1="+Arrays.toString((double[])t1));
System.err.println("t1="+Arrays.toString((double[])t2));
System.exit(0);
}
}
}
finished = true;
}
public boolean hasFinished()
{
return finished;
}
}
}