/**
 * A clusterer backed by a {@link MixtureDistribution} of
 * {@link MultivariateGaussian} components.
 * <p>
 * {@link #estimate(DataSet)} seeds the mixture from a {@link KMeansClusterer}
 * run and then repeatedly calls {@code mixture.estimate(set)} until the
 * average log likelihood of the data stops changing by more than the
 * tolerance, or the iteration limit is reached.
 * <p>
 * {@link #distributionFor(Instance)}, {@link #value(Instance)} and
 * {@link #getMixture()} are only meaningful after {@code estimate} has been
 * called; before that the mixture is {@code null}.
 */
public class EMClusterer extends AbstractConditionalDistribution implements FunctionApproximater {

/**
 * The default convergence tolerance
 */
private static final double TOLERANCE = 1E-6;
/**
 * The default maximum number of iterations
 */
private static final int MAX_ITERATIONS = 1000;
/**
 * The mixture distribution; null until estimate has been called
 */
private MixtureDistribution mixture;
/**
 * The number of clusters
 */
private final int k;
/**
 * The convergence threshold on the change in average log likelihood
 */
private final double tolerance;

/**
 * The max iterations
 */
private final int maxIterations;

/**
 * How many iterations the last call to estimate took
 */
private int iterations;

/**
 * Whether to print progress to standard output
 */
private boolean debug = false;

/**
 * Make a new em clusterer
 * @param k the number of clusters, must be at least 1
 * @param tolerance estimation stops once the average log likelihood
 * changes by less than this amount between two consecutive iterations
 * @param maxIterations the maximum number of iterations to run;
 * at least one iteration is always performed
 * @throws IllegalArgumentException if k is less than 1
 */
public EMClusterer(int k, double tolerance, int maxIterations) {
    // fail fast: a non positive k would otherwise only surface later, inside
    // estimate, as an array size or array index error
    if (k < 1) {
        throw new IllegalArgumentException("k must be at least 1, but was " + k);
    }
    this.k = k;
    this.tolerance = tolerance;
    this.maxIterations = maxIterations;
}

/**
 * Make a new clusterer with 2 clusters, the default tolerance
 * and the default maximum number of iterations
 */
public EMClusterer() {
    this(2, TOLERANCE, MAX_ITERATIONS);
}

/**
 * Compute a distribution over the clusters for an instance.
 * The log density of the instance under each mixture component is
 * exponentiated (after subtracting the largest log density, so that the
 * largest term is exp(0) and the exponentials cannot all underflow)
 * and the results are normalized to sum to one.
 * <p>
 * NOTE(review): only the component densities are used here, the mixture's
 * prior over components is not applied - confirm that this is intended.
 * @param instance the instance to compute cluster probabilities for
 * @return a discrete distribution with one probability per component
 */
public Distribution distributionFor(Instance instance) {
    // calculate the log probs
    double[] probs = new double[mixture.getComponents().length];
    double maxLog = Double.NEGATIVE_INFINITY;
    for (int i = 0; i < probs.length; i++) {
        probs[i] = mixture.getComponents()[i].logp(instance);
        maxLog = Math.max(maxLog, probs[i]);
    }
    // every component gives the instance zero density, so there is nothing
    // to rank the components by; without this guard the subtraction below
    // would be (-inf) - (-inf) = NaN and every probability would be NaN
    if (maxLog == Double.NEGATIVE_INFINITY) {
        Arrays.fill(probs, 1.0 / probs.length);
        return new DiscreteDistribution(probs);
    }
    // turn into real probs
    double sum = 0;
    for (int i = 0; i < probs.length; i++) {
        probs[i] = Math.exp(probs[i] - maxLog);
        sum += probs[i];
    }
    // normalize
    for (int i = 0; i < probs.length; i++) {
        probs[i] /= sum;
    }
    return new DiscreteDistribution(probs);
}

/**
 * Fit the mixture to a data set. The mixture is initialized from
 * a k-means clustering of the data and then reestimated until the
 * average log likelihood changes by less than the tolerance, or
 * the maximum number of iterations has been reached.
 * @param set the data set to fit
 * @see func.FunctionApproximater#estimate(shared.DataSet)
 */
public void estimate(DataSet set) {
    mixture = initializeFromKMeans(set);
    // reestimate
    boolean done = false;
    double lastLogLikelihood = 0;
    iterations = 0;
    while (!done) {
        if (debug) {
            System.out.println("On iteration " + iterations);
            System.out.println(mixture);
        }
        mixture.estimate(set);
        double logLikelihood = averageLogLikelihood(set);
        // the first pass has no previous likelihood to compare against,
        // so it can only end the loop through the iteration limit
        boolean converged = iterations > 0
            && Math.abs(logLikelihood - lastLogLikelihood) < tolerance;
        boolean outOfIterations = iterations + 1 >= maxIterations;
        done = converged || outOfIterations;
        lastLogLikelihood = logLikelihood;
        iterations++;
    }
}

/**
 * Build the initial mixture: run k-means, fit one gaussian to the
 * instances assigned to each cluster, and use each cluster's share of
 * the total instance weight as the prior of its component.
 * @param set the data set to initialize from
 * @return the initial mixture
 */
private MixtureDistribution initializeFromKMeans(DataSet set) {
    // kmeans initialization
    KMeansClusterer kmeans = new KMeansClusterer(k);
    kmeans.estimate(set);
    int size = set.size();
    double[] prior = new double[k];
    double weightSum = 0;
    int[] counts = new int[k];
    int[] classifications = new int[size];
    for (int i = 0; i < size; i++) {
        Instance instance = set.get(i);
        int cluster = kmeans.value(instance).getDiscrete();
        double weight = instance.getWeight();
        classifications[i] = cluster;
        counts[cluster]++;
        prior[cluster] += weight;
        weightSum += weight;
    }
    // create data sets for each of the classes
    Instance[][] instances = new Instance[k][];
    for (int i = 0; i < instances.length; i++) {
        instances[i] = new Instance[counts[i]];
    }
    // the counts are reused as the next free slot within each class
    Arrays.fill(counts, 0);
    for (int i = 0; i < size; i++) {
        int cluster = classifications[i];
        instances[cluster][counts[cluster]] = set.get(i);
        counts[cluster]++;
    }
    MultivariateGaussian[] initial = new MultivariateGaussian[k];
    for (int i = 0; i < initial.length; i++) {
        initial[i] = new MultivariateGaussian();
        initial[i].setDebug(debug);
        // NOTE(review): a cluster that k-means left empty is estimated
        // from an empty data set here - confirm MultivariateGaussian copes
        initial[i].estimate(new DataSet(instances[i]));
        prior[i] /= weightSum;
    }
    return new MixtureDistribution(initial, prior);
}

/**
 * Compute the mean, over all instances in a data set, of the log
 * probability the current mixture assigns to each instance.
 * @param set the data set to score
 * @return the average log likelihood
 */
private double averageLogLikelihood(DataSet set) {
    int size = set.size();
    double logLikelihood = 0;
    for (int j = 0; j < size; j++) {
        logLikelihood += mixture.logp(set.get(j));
    }
    return logLikelihood / size;
}

/**
 * Get the mode of the cluster distribution for an instance
 * @param i the instance to assign to a cluster
 * @return the mode of the distribution over clusters
 * @see func.FunctionApproximater#value(shared.Instance)
 */
public Instance value(Instance i) {
    return distributionFor(i).mode();
}

/**
 * Get the number of iterations the last call to estimate took
 * @return the number
 */
public int getIterations() {
    return iterations;
}

/**
 * Is debug mode on
 * @return true if it is
 */
public boolean isDebug() {
    return debug;
}

/**
 * Set debug mode on or off. When on, estimate prints the iteration
 * number and the current mixture before every iteration, and the
 * flag is passed on to the initial gaussians.
 * @param b the debug mode
 */
public void setDebug(boolean b) {
    debug = b;
}

/**
 * Get the mixture
 * @return the mixture, or null if estimate has not been called yet
 */
public MixtureDistribution getMixture() {
    return mixture;
}

/**
 * Describe the clusterer by its mixture. A short placeholder is
 * returned when no mixture has been estimated yet, so that printing
 * or logging an untrained clusterer does not throw.
 * @see java.lang.Object#toString()
 */
public String toString() {
    if (mixture == null) {
        return "EMClusterer (not yet estimated, k = " + k + ")";
    }
    return mixture.toString();
}

}