/*
 * Decompiled with CFR 0.152.
 */
package moa.classifiers.meta;

import com.github.javacliparser.FloatOption;
import com.github.javacliparser.IntOption;
import com.yahoo.labs.samoa.instances.Instance;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.Classifier;
import moa.classifiers.MultiClassClassifier;
import moa.core.Measurement;
import moa.core.Utils;
import moa.options.ClassOption;

/**
 * Online coordinate boosting ensemble for two-class data streams.
 *
 * <p>The ensemble keeps one voting weight ({@link #alpha}) per member together with two
 * lower-triangular tables of accumulated example weight ({@link #wpos} / {@link #wneg}).
 * Each training instance is passed through the members in order; every member updates
 * its row of the tables, recomputes its voting weight, rescales the running instance
 * weight and is then trained on the re-weighted instance.
 *
 * <p>Prediction is a sign vote: each member votes {@code -1} (class index 0) or its
 * arg-max class index, scaled by its voting weight; a positive total selects class 1,
 * anything else selects class 0.
 *
 * <p>NOTE(review): the smoothing option accepts {@code 0.0}; with that value the weight
 * tables start at zero and the ratios computed in {@link #trainOnInstanceImpl(Instance)}
 * become {@code 0/0}. Prefer a strictly positive smoothing value.
 */
public class OCBoost
extends AbstractClassifier
implements MultiClassClassifier {
    private static final long serialVersionUID = 1L;

    /** Base learner that is copied once per ensemble member. */
    public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l', "Classifier to train.", Classifier.class, "trees.HoeffdingTree");

    /** Number of ensemble members. */
    public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's', "The number of models to boost.", 10, 1, Integer.MAX_VALUE);

    /** Initial value of every entry of {@link #wpos} and {@link #wneg}. */
    public FloatOption smoothingOption = new FloatOption("smoothingParameter", 'e', "Smoothing parameter.", 0.5, 0.0, 100.0);

    /** Ensemble members, in boosting order. */
    protected Classifier[] ensemble;

    /** Current voting weight of each member. */
    protected double[] alpha;

    /** Change of each member's voting weight produced by its latest update. */
    protected double[] alphainc;

    /** Per-member rescaling factor applied to the member's row of {@link #wpos}. */
    protected double[] pipos;

    /** Per-member rescaling factor applied to the member's row of {@link #wneg}. */
    protected double[] pineg;

    /** Accumulated weight where members {@code j} and {@code k} were both correct. */
    protected double[][] wpos;

    /** Accumulated weight where members {@code j} and {@code k} were both wrong. */
    protected double[][] wneg;

    @Override
    public String getPurposeString() {
        return "Online Coordinate boosting for two classes evolving data streams.";
    }

    /**
     * Allocates all per-member state, fills the weight tables with the smoothing value
     * and populates the ensemble with fresh copies of the configured base learner.
     */
    @Override
    public void resetLearningImpl() {
        int size = this.ensembleSizeOption.getValue();
        this.ensemble = new Classifier[size];
        this.alpha = new double[size];
        this.alphainc = new double[size];
        this.pipos = new double[size];
        this.pineg = new double[size];
        this.wpos = new double[size][size];
        this.wneg = new double[size][size];

        Classifier baseLearner = (Classifier)this.getPreparedClassOption(this.baseLearnerOption);
        baseLearner.resetLearning();
        double smoothing = this.smoothingOption.getValue();
        for (int i = 0; i < size; ++i) {
            this.ensemble[i] = baseLearner.copy();
            this.alpha[i] = 0.0;
            this.alphainc[i] = 0.0;
            for (int j = 0; j < size; ++j) {
                this.wpos[i][j] = smoothing;
                this.wneg[i][j] = smoothing;
            }
        }
    }

    /**
     * Runs one boosting pass over the ensemble for a single instance.
     *
     * <p>For every member {@code j}, in order:
     * <ol>
     *   <li>record whether the member classifies the instance correctly
     *       ({@code +1}) or not ({@code -1});</li>
     *   <li>compute the rescaling factors {@code pipos[j]} / {@code pineg[j]} from the
     *       weight changes of the earlier members;</li>
     *   <li>rescale row {@code j} of the weight tables and add the running instance
     *       weight where both members agree on being correct / wrong;</li>
     *   <li>recompute the voting weight and remember how much it changed;</li>
     *   <li>rescale the running instance weight and, if it is still positive, train the
     *       member on a copy of the instance carrying that weight.</li>
     * </ol>
     *
     * @param inst the training instance; it is not modified, members receive copies
     */
    @Override
    public void trainOnInstanceImpl(Instance inst) {
        // Running weight of the instance as it travels down the ensemble.
        double d = 1.0;
        // +1 where the member was correct on this instance, -1 otherwise.
        int[] m = new int[this.ensemble.length];

        for (int j = 0; j < this.ensemble.length; ++j) {
            this.pipos[j] = 1.0;
            this.pineg[j] = 1.0;
            m[j] = this.ensemble[j].correctlyClassifies(inst) ? 1 : -1;

            double[] posRow = this.wpos[j];
            double[] negRow = this.wneg[j];

            // The diagonal entries are not touched until the second loop, so the
            // ratios below all use the values from before this instance.
            for (int k = 0; k < j; ++k) {
                double shrink = Math.exp(-this.alphainc[k]);
                double grow = Math.exp(this.alphainc[k]);
                double posRatio = posRow[k] / posRow[j];
                double negRatio = negRow[k] / negRow[j];
                this.pipos[j] *= posRatio * shrink + (1.0 - posRatio) * grow;
                this.pineg[j] *= negRatio * shrink + (1.0 - negRatio) * grow;
            }

            double correctJ = indicator(m[j] == 1);
            double wrongJ = indicator(m[j] == -1);
            for (int k = 0; k <= j; ++k) {
                posRow[k] = posRow[k] * this.pipos[j] + d * indicator(m[k] == 1) * correctJ;
                negRow[k] = negRow[k] * this.pineg[j] + d * indicator(m[k] == -1) * wrongJ;
            }

            // alphainc[j] ends up as (new alpha - old alpha).
            this.alphainc[j] = -this.alpha[j];
            this.alpha[j] = 0.5 * Math.log(posRow[j] / negRow[j]);
            this.alphainc[j] += this.alpha[j];

            d *= Math.exp(-this.alpha[j] * (double)m[j]);
            if (d > 0.0) {
                Instance weightedInst = inst.copy();
                weightedInst.setWeight(inst.weight() * d);
                this.ensemble[j].trainOnInstance(weightedInst);
            }
        }
    }

    /** Maps a condition to {@code 1.0} (true) or {@code 0.0} (false). */
    private static double indicator(boolean condition) {
        return condition ? 1.0 : 0.0;
    }

    /**
     * Returns the voting weight of one ensemble member.
     *
     * @param i index of the member
     * @return the member's current {@code alpha} value
     */
    protected double getEnsembleMemberWeight(int i) {
        return this.alpha[i];
    }

    /**
     * Predicts by weighted sign vote.
     *
     * @param inst the instance to classify
     * @return a two-element array with {@code 1.0} at the predicted class index and
     *         {@code 0.0} at the other; class 1 is chosen only when the weighted sum of
     *         member votes is strictly positive
     */
    @Override
    public double[] getVotesForInstance(Instance inst) {
        double combinedVote = 0.0;
        for (int i = 0; i < this.ensemble.length; ++i) {
            int predicted = Utils.maxIndex(this.ensemble[i].getVotesForInstance(inst));
            // Class index 0 counts as a negative vote; any other index is used as-is.
            int signedVote = predicted == 0 ? -1 : predicted;
            combinedVote += (double)signedVote * this.getEnsembleMemberWeight(i);
        }
        double[] output = new double[2];
        output[combinedVote > 0.0 ? 1 : 0] = 1.0;
        return output;
    }

    @Override
    public boolean isRandomizable() {
        return true;
    }

    /** Intentionally produces no description. */
    @Override
    public void getModelDescription(StringBuilder out, int indent) {
    }

    /**
     * Reports the ensemble size, or {@code 0} when no ensemble has been created yet.
     */
    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        double size = this.ensemble != null ? (double)this.ensemble.length : 0.0;
        return new Measurement[]{new Measurement("ensemble size", size)};
    }

    /**
     * Returns a shallow copy of the ensemble array.
     *
     * <p>When no ensemble has been created yet an empty array is returned instead of
     * failing with a {@link NullPointerException}, matching the null tolerance of
     * {@link #getModelMeasurementsImpl()}.
     *
     * @return a new array holding the current members; never {@code null}
     */
    @Override
    public Classifier[] getSubClassifiers() {
        if (this.ensemble == null) {
            return new Classifier[0];
        }
        return this.ensemble.clone();
    }
}

