/*
 * Decompiled with CFR 0.152.
 */
package moa.classifiers.meta;

import com.github.javacliparser.FlagOption;
import com.github.javacliparser.IntOption;
import com.yahoo.labs.samoa.instances.Instance;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.Classifier;
import moa.classifiers.MultiClassClassifier;
import moa.core.DoubleVector;
import moa.core.Measurement;
import moa.core.MiscUtils;
import moa.options.ClassOption;

/**
 * Adaptable Diversity-based Online Boosting (ADOB).
 *
 * <p>An online boosting ensemble in which every member keeps two running sums of
 * the instance weight ("lambda") it has been trained with: {@link #scms} for
 * instances it classified correctly and {@link #swms} for instances it got wrong.
 *
 * <p>Training: before each instance the members are ranked by accuracy
 * {@code scms / (scms + swms)} (see {@link #orderPosition}). The instance is then
 * passed through every member exactly once, starting with the least accurate
 * member; after each member the next one is the most accurate unused member if
 * the current member classified the instance correctly, otherwise the least
 * accurate unused member. After each member, lambda is rescaled by
 * {@code trainingWeightSeenByModel / (2 * s)}, where {@code s} is that member's
 * updated correct (or wrong) sum.
 *
 * <p>Prediction: members vote with weight {@code log((1 - e) / e)}, where
 * {@code e = swms / (scms + swms)}; members with a zero correct or zero wrong sum,
 * or with {@code e > 0.5}, get weight 0 and do not vote.
 */
public class ADOB
extends AbstractClassifier
implements MultiClassClassifier {
    private static final long serialVersionUID = 1L;
    public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l', "Classifier to train.", Classifier.class, "drift.SingleClassifierDrift -l trees.HoeffdingTree -d ADWINChangeDetector");
    public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's', "The number of models to boost.", 10, 1, Integer.MAX_VALUE);
    public FlagOption pureBoostOption = new FlagOption("pureBoost", 'p', "Boost with weights only; no poisson.");

    /** The ensemble members; array index is a stable member id, not a training order. */
    protected Classifier[] ensemble;
    /** Member ids ranked from most to least accurate; re-sorted before every training instance. */
    protected int[] orderPosition;
    /** Per member: summed lambda of training instances the member classified correctly. */
    protected double[] scms;
    /** Per member: summed lambda of training instances the member classified wrongly. */
    protected double[] swms;

    @Override
    public String getPurposeString() {
        return "Adaptable Diversity-based Online Boosting (ADOB)";
    }

    /**
     * Builds {@code ensembleSize} fresh copies of the configured base learner and
     * clears all per-member statistics; the initial ranking is the identity order.
     */
    @Override
    public void resetLearningImpl() {
        int size = this.ensembleSizeOption.getValue();
        this.ensemble = new Classifier[size];
        this.orderPosition = new int[size];
        Classifier baseLearner = (Classifier) this.getPreparedClassOption(this.baseLearnerOption);
        baseLearner.resetLearning();
        for (int i = 0; i < size; i++) {
            this.ensemble[i] = baseLearner.copy();
            this.orderPosition[i] = i;
        }
        this.scms = new double[size];
        this.swms = new double[size];
    }

    /**
     * Trains every member once on {@code inst}, visiting members in the
     * accuracy-driven order described in the class documentation.
     *
     * @param inst the training instance; it is copied before its weight is changed
     */
    @Override
    public void trainOnInstanceImpl(Instance inst) {
        this.sortByAccuracy();

        // Two cursors into the ranking: maxAcc walks down from the most accurate
        // member, minAcc walks up from the least accurate one. Together they visit
        // every member exactly once. Starting with correct == false means the first
        // member trained is the least accurate one.
        boolean correct = false;
        double lambdaD = 1.0;
        int maxAcc = 0;
        int minAcc = this.ensemble.length - 1;
        for (int i = 0; i < this.ensemble.length; i++) {
            int pos = correct ? this.orderPosition[maxAcc++] : this.orderPosition[minAcc--];

            // pureBoost trains with the raw lambda as weight; otherwise the weight
            // multiplier is drawn from Poisson(lambda), as in online bagging/boosting.
            double k = this.pureBoostOption.isSet()
                    ? lambdaD
                    : (double) MiscUtils.poisson(lambdaD, this.classifierRandom);
            if (k > 0.0) {
                Instance weightedInst = inst.copy();
                weightedInst.setWeight(inst.weight() * k);
                this.ensemble[pos].trainOnInstance(weightedInst);
            }

            // lambdaD starts at 1.0 and stays positive, so the sum updated here is
            // strictly positive and the division below is safe.
            // trainingWeightSeenByModel is inherited from AbstractClassifier
            // (presumably the cumulative training weight seen so far).
            if (this.ensemble[pos].correctlyClassifies(inst)) {
                this.scms[pos] += lambdaD;
                lambdaD *= this.trainingWeightSeenByModel / (2.0 * this.scms[pos]);
                correct = true;
            } else {
                this.swms[pos] += lambdaD;
                lambdaD *= this.trainingWeightSeenByModel / (2.0 * this.swms[pos]);
                correct = false;
            }
        }
    }

    /**
     * Re-orders {@link #orderPosition} from the highest to the lowest accuracy
     * {@code scms / (scms + swms)}; members with no accumulated weight count as
     * accuracy 0. Insertion sort is used because it is stable, so tied members
     * keep their relative order from the previous call.
     */
    private void sortByAccuracy() {
        int size = this.ensemble.length;
        double[] acc = new double[size];
        for (int i = 0; i < size; i++) {
            acc[i] = this.accuracyOf(this.orderPosition[i]);
        }
        for (int i = 1; i < size; i++) {
            int keyPosition = this.orderPosition[i];
            double keyAcc = acc[i];
            // j must outlive the shifting loop: it marks where the key is inserted.
            int j = i - 1;
            while (j >= 0 && acc[j] < keyAcc) {
                this.orderPosition[j + 1] = this.orderPosition[j];
                acc[j + 1] = acc[j];
                j--;
            }
            this.orderPosition[j + 1] = keyPosition;
            acc[j + 1] = keyAcc;
        }
    }

    /** Returns member {@code m}'s accuracy {@code scms / (scms + swms)}, or 0 when both sums are 0. */
    private double accuracyOf(int m) {
        double total = this.scms[m] + this.swms[m];
        return total == 0.0 ? 0.0 : this.scms[m] / total;
    }

    /**
     * Returns the voting weight of member {@code i}: {@code log((1 - e) / e)} with
     * {@code e = swms / (scms + swms)}.
     *
     * @param i index into {@link #ensemble}
     * @return the weight, or 0 when the member has no correct weight, no wrong
     *         weight (e would be 0 and the log unbounded), or {@code e > 0.5}
     */
    protected double getEnsembleMemberWeight(int i) {
        double correctSum = this.scms[i];
        double wrongSum = this.swms[i];
        if (!(correctSum > 0.0 && wrongSum > 0.0)) {
            return 0.0;
        }
        double em = wrongSum / (correctSum + wrongSum);
        if (!(em <= 0.5)) {
            return 0.0;
        }
        double bm = em / (1.0 - em);
        return Math.log(1.0 / bm);
    }

    /**
     * Combines the members' normalized votes, each scaled by
     * {@link #getEnsembleMemberWeight(int)}. Members with a non-positive weight or
     * an all-zero vote are skipped.
     *
     * @param inst the instance to classify
     * @return the summed class votes (possibly empty when no member voted)
     */
    @Override
    public double[] getVotesForInstance(Instance inst) {
        DoubleVector combinedVote = new DoubleVector();
        for (int i = 0; i < this.ensemble.length; i++) {
            double memberWeight = this.getEnsembleMemberWeight(i);
            if (memberWeight <= 0.0) {
                // The array index does not reflect the boosting order (that order is
                // recomputed per instance in trainOnInstanceImpl), so a zero-weight
                // member is skipped rather than ending the vote for all later ones.
                continue;
            }
            DoubleVector vote = new DoubleVector(this.ensemble[i].getVotesForInstance(inst));
            if (vote.sumOfValues() > 0.0) {
                vote.normalize();
                vote.scaleValues(memberWeight);
                combinedVote.addValues(vote);
            }
        }
        return combinedVote.getArrayRef();
    }

    /** Returns true: the Poisson draws in training use {@code classifierRandom}. */
    @Override
    public boolean isRandomizable() {
        return true;
    }

    @Override
    public void getModelDescription(StringBuilder out, int indent) {
        // Intentionally empty: no textual model description is produced.
    }

    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        return new Measurement[]{new Measurement("ensemble size", this.ensemble != null ? (double) this.ensemble.length : 0.0)};
    }

    /**
     * Returns a shallow copy of the ensemble array, or {@code null} if the
     * ensemble has not been built yet (mirrors the null guard in
     * {@link #getModelMeasurementsImpl()}).
     */
    @Override
    public Classifier[] getSubClassifiers() {
        return this.ensemble == null ? null : this.ensemble.clone();
    }
}

