package JKernelMachines; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import JKernelMachines.TrainingSample; /** * Simple multithreaded implementation over a given Kernel. The multithreading comes only when * computing the Gram matrix.
* Number of Threads is function of available processors. * @author dpicard * * @param */ public class ThreadedKernel extends Kernel { /** * */ private static final long serialVersionUID = -2193768216118832033L; protected Kernel k; // private double[][] matrix; /** * MultiThread the given kernel * @param kernel */ public ThreadedKernel(Kernel kernel) { this.k = kernel; } @Override public double valueOf(T t1, T t2) { return k.valueOf(t1, t2); } @Override public double valueOf(T t1) { return k.valueOf(t1); } /* (non-Javadoc) * @see JKernelMachines.Kernel#getKernelMatrix(java.util.ArrayList) */ @Override public double[][] getKernelMatrix(List> e) { double[][] matrix = new double[e.size()][e.size()]; //heuristic choice of number of threads : about as much as the available processors int nbc = ((int)Math.sqrt(Runtime.getRuntime().availableProcessors()+1)); int icrem = e.size()/nbc ; ArrayList threads = new ArrayList(); for(int i = 0 ; i < e.size() ; i+=icrem) for(int j = 0 ; j < e.size() ; j+=icrem) { MatrixThread t = new MatrixThread(matrix, e, i, i+icrem, j, j+icrem); threads.add(t); t.start(); } boolean cont = true; while(cont) { cont = false; for(MatrixThread t : threads) if(!t.hasFinished() && t.isAlive()) cont = true; Thread.yield(); } return matrix; } private class MatrixThread extends Thread { double[][] m; List> e; int mini, maxi, minj, maxj; boolean finished = false; /** * @param m * @param e2 * @param mini * @param maxi * @param minj * @param maxj */ public MatrixThread(double[][] m, List> e2, int mini, int maxi, int minj, int maxj) { this.m = m; this.e = e2; this.mini = mini; this.maxi = Math.min(maxi, e2.size()); this.minj = minj; this.maxj = Math.min(maxj, e2.size()); } public void run() { finished = false; for (int i = mini; i < maxi; i++) { for (int j = minj; j < maxj; j++) { T t1 = e.get(i).sample; T t2 = e.get(j).sample; double v = valueOf(t1, t2); if(!Double.isNaN(v)) { m[i][j] = valueOf(t1, t2); } else { System.err.println("NAN : v="+v); System.err.println("t1="+Arrays.toString((double[])t1)); System.err.println("t1="+Arrays.toString((double[])t2)); System.exit(0); } } } finished = true; } public boolean hasFinished() { return finished; } } }