package JKernelMachines;

import java.util.Hashtable;
import java.util.List;
import java.util.Map;

import JKernelMachines.Kernel;
import JKernelMachines.TrainingSample;

/**
 * Major kernel computed as a weighted sum of minor kernels:
 * K = sum_i w_i * k_i
 *
 * @author dpicard
 *
 * @param <T> type of the input samples the minor kernels operate on
 */
public class WeightedSumKernel<T> extends Kernel<T> {

    /**
     *
     */
    private static final long serialVersionUID = 4590492743843223113L;

    /** Minor kernels (keys) and their weights (values). */
    private Hashtable<Kernel<T>, Double> kernels;

    /**
     * Creates an empty weighted sum (no minor kernels).
     */
    public WeightedSumKernel() {
        kernels = new Hashtable<Kernel<T>, Double>();
    }

    /**
     * Sets the weights to h. Beware! It does not make a copy of h: later
     * modifications of h are visible to this kernel and vice versa.
     * @param h table mapping minor kernels to their weights
     */
    public WeightedSumKernel(Hashtable<Kernel<T>, Double> h) {
        kernels = h;
    }

    /**
     * Adds a kernel to the sum with weight 1.0.
     * If k is already present, its weight is replaced by 1.0.
     * @param k the minor kernel to add
     */
    public void addKernel(Kernel<T> k) {
        kernels.put(k, 1.0);
    }

    /**
     * Adds a kernel to the sum with weight d.
     * If k is already present, its weight is replaced by d.
     * @param k the minor kernel to add
     * @param d its weight
     */
    public void addKernel(Kernel<T> k, double d) {
        kernels.put(k, d);
    }

    /**
     * Removes kernel k from the sum (no effect if k is absent).
     * @param k the minor kernel to remove
     */
    public void removeKernel(Kernel<T> k) {
        kernels.remove(k);
    }

    /**
     * Gets the weight of kernel k.
     * @param k the minor kernel
     * @return the weight associated with k, or 0 if k is not part of the sum
     */
    public double getWeight(Kernel<T> k) {
        Double d = kernels.get(k);
        if (d == null) {
            return 0.;
        }
        return d.doubleValue();
    }

    /**
     * Sets the weight of kernel k, adding k to the sum if it is absent.
     * @param k the minor kernel
     * @param d its weight (must not be null: the backing Hashtable rejects null values)
     */
    public void setWeight(Kernel<T> k, Double d) {
        kernels.put(k, d);
    }

    /**
     * Computes sum_i w_i * k_i(t1, t2).
     */
    @Override
    public double valueOf(T t1, T t2) {
        double sum = 0.;
        for (Map.Entry<Kernel<T>, Double> entry : kernels.entrySet()) {
            sum += entry.getValue() * entry.getKey().valueOf(t1, t2);
        }
        return sum;
    }

    /**
     * Computes sum_i w_i * k_i(t1).
     */
    @Override
    public double valueOf(T t1) {
        double sum = 0.;
        for (Map.Entry<Kernel<T>, Double> entry : kernels.entrySet()) {
            sum += entry.getValue() * entry.getKey().valueOf(t1);
        }
        return sum;
    }

    /**
     * Gets the list of kernels and associated weights.
     * The returned table is the live internal one, not a copy.
     * @return hashtable containing kernels as keys and weights as values.
     */
    public Hashtable<Kernel<T>, Double> getWeights() {
        return kernels;
    }

    /**
     * Computes the Gram matrix as the weighted sum of the Gram matrices of
     * all minor kernels.
     *
     * @param e list of training samples
     * @return an e.size() x e.size() matrix; all zeros if no minor kernel is set
     * @see JKernelMachines.Kernel#getKernelMatrix(java.util.List)
     */
    @Override
    public double[][] getKernelMatrix(List<TrainingSample<T>> e) {
        final int n = e.size();
        double[][] matrix = new double[n][n];
        for (Map.Entry<Kernel<T>, Double> entry : kernels.entrySet()) {
            // the weight is used as-is (no rescaling) to avoid floating-point drift
            final double w = entry.getValue();
            final double[][] m = entry.getKey().getKernelMatrix(e);
            for (int i = 0; i < n; i++) {
                for (int j = 0; j < n; j++) {
                    matrix[i][j] += w * m[i][j];
                }
            }
        }
        return matrix;
    }
}