/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.branchratemodel;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.branchratemodel.ArbitraryBranchRates;
import dr.evomodel.tree.TreeModel;
import dr.inference.distribution.ParametricMultivariateDistributionModel;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.model.AbstractModelLikelihood;
import dr.inference.model.GradientProvider;
import dr.inference.model.Likelihood;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.util.Author;
import dr.util.Citable;
import dr.util.Citation;
import dr.xml.Reportable;
import java.util.Collections;
import java.util.List;

public class AutoCorrelatedBranchRatesDistribution
extends AbstractModelLikelihood
implements GradientWrtParameterProvider,
Citable,
Reportable {
    /** Branch-rate model supplying per-node untransformed rates and node-to-parameter indices. */
    private final ArbitraryBranchRates branchRateModel;
    /** Multivariate distribution whose log-density is evaluated on {@link #increments}. */
    private final ParametricMultivariateDistributionModel distribution;
    /** How each increment is rescaled using the branch length. */
    private final BranchVarianceScaling scaling;
    /** Transform applied to each untransformed branch rate before increments are formed. */
    private final BranchRateUnits units;
    /** Tree obtained from {@link #branchRateModel} at construction. */
    private final Tree tree;
    /** Rate parameter obtained from {@link #branchRateModel} at construction. */
    private final Parameter rateParameter;
    /** True while {@link #increments} and {@link #logJacobian} reflect the current model state. */
    private boolean incrementsKnown = false;
    /** Snapshot of {@link #incrementsKnown} taken by storeState(). */
    private boolean savedIncrementsKnown;
    /** True while {@link #logLikelihood} reflects the current model state. */
    private boolean likelihoodKnown = false;
    /** Snapshot of {@link #likelihoodKnown} taken by storeState(). */
    private boolean savedLikelihoodKnown;
    /** Cached result of calculateLogLikelihood(). */
    private double logLikelihood;
    /** Snapshot of {@link #logLikelihood} taken by storeState(). */
    private double savedLogLikelihood;
    /** Sum of transform log-Jacobian terms accumulated by recursePreOrder(). */
    private double logJacobian;
    /** Snapshot of {@link #logJacobian} taken by storeState(). */
    private double savedLogJacobian;
    /** Dimension of the rate parameter at construction; length of both increment arrays. */
    private final int dim;
    /** Current increments, indexed by branchRateModel.getParameterIndexFromNode(node). */
    private double[] increments;
    /** Snapshot of {@link #increments}; swapped back in by restoreState(). */
    private double[] savedIncrements;
    /** Placeholder citation: no authors, status IN_PREPARATION. */
    public static Citation CITATION = new Citation(new Author[0], Citation.Status.IN_PREPARATION);

    public AutoCorrelatedBranchRatesDistribution(String string, ArbitraryBranchRates arbitraryBranchRates, ParametricMultivariateDistributionModel parametricMultivariateDistributionModel, BranchVarianceScaling branchVarianceScaling, boolean bl) {
        super(string);
        this.branchRateModel = arbitraryBranchRates;
        this.distribution = parametricMultivariateDistributionModel;
        this.scaling = branchVarianceScaling;
        this.units = bl ? BranchRateUnits.STRICTLY_POSITIVE : BranchRateUnits.REAL_LINE;
        this.tree = arbitraryBranchRates.getTree();
        this.rateParameter = arbitraryBranchRates.getRateParameter();
        this.addModel(arbitraryBranchRates);
        this.addModel(parametricMultivariateDistributionModel);
        if (this.tree instanceof TreeModel) {
            this.addModel((TreeModel)this.tree);
        }
        this.dim = arbitraryBranchRates.getRateParameter().getDimension();
        this.increments = new double[this.dim];
        this.savedIncrements = new double[this.dim];
        if (this.dim != parametricMultivariateDistributionModel.getMean().length) {
            throw new RuntimeException("Dimension mismatch in AutoCorrelatedRatesDistribution. " + this.dim + " != " + parametricMultivariateDistributionModel.getMean().length);
        }
    }

    /**
     * @return the multivariate distribution supplied at construction
     */
    public ParametricMultivariateDistributionModel getPrior() {
        return distribution;
    }

    /**
     * @return this object; the class is its own likelihood
     */
    @Override
    public Likelihood getLikelihood() {
        return this;
    }

    /**
     * @return the rate parameter taken from the branch-rate model at construction
     */
    @Override
    public Parameter getParameter() {
        return rateParameter;
    }

    /**
     * @return the current dimension reported by the rate parameter
     */
    @Override
    public int getDimension() {
        final Parameter rates = rateParameter;
        return rates.getDimension();
    }

    /**
     * Assembles a gradient array of length {@code dim} in four steps: fetch the gradient with
     * respect to the increments, rescale it in place with the branch-variance scaling,
     * accumulate it over the tree in pre-order, then apply the units' gradient transform.
     *
     * @return a freshly allocated gradient array of length {@code dim}
     */
    @Override
    public double[] getGradientLogDensity() {
        final double[] gradientWrtIncrements = getGradientWrtIncrements();
        // Mutates the array returned above.
        rescaleGradientWrtIncrements(gradientWrtIncrements);

        final double[] gradient = new double[dim];
        recurseGradientPreOrder(tree.getRoot(), gradient, gradientWrtIncrements);
        addJacobianTerm(gradient);
        return gradient;
    }

    /**
     * Asks the distribution for its log-density gradient evaluated at the current increments.
     *
     * @return the gradient array produced by the distribution
     * @throws RuntimeException if the distribution is not a {@link GradientProvider}
     */
    double[] getGradientWrtIncrements() {
        if (distribution instanceof GradientProvider) {
            final GradientProvider provider = (GradientProvider) ((Object) distribution);
            // Make sure the increments are current before they are handed out.
            checkIncrements();
            return provider.getGradientLogDensity(increments);
        }
        throw new RuntimeException("Not yet implemented");
    }

    /**
     * @return the tree taken from the branch-rate model at construction
     */
    Tree getTree() {
        return tree;
    }

    /**
     * @return the rate units selected at construction
     */
    BranchRateUnits getUnits() {
        return units;
    }

    /**
     * @return the branch-variance scaling supplied at construction
     */
    BranchVarianceScaling getScaling() {
        return scaling;
    }

    /**
     * @return the branch-rate model supplied at construction
     */
    ArbitraryBranchRates getBranchRateModel() {
        return branchRateModel;
    }

    /**
     * Applies the branch-variance scaling, in place, to the gradient entry of every non-root
     * node.
     *
     * @param gradientWrtIncrements gradient with respect to the increments, indexed by
     *                              {@code branchRateModel.getParameterIndexFromNode(node)}
     */
    private void rescaleGradientWrtIncrements(double[] gradientWrtIncrements) {
        // Iterate over all node numbers, not just the first dim of them: the root is skipped
        // explicitly, so bounding the loop by dim would drop the highest-numbered node
        // whenever the root is not that node.
        final int nodeCount = this.tree.getNodeCount();
        for (int i = 0; i < nodeCount; ++i) {
            final NodeRef node = this.tree.getNode(i);
            if (this.tree.isRoot(node)) {
                continue;
            }
            final int index = this.branchRateModel.getParameterIndexFromNode(node);
            gradientWrtIncrements[index] = this.scaling.rescaleIncrement(gradientWrtIncrements[index], this.tree.getBranchLength(node));
        }
    }

    /**
     * Replaces, in place, the gradient entry of every non-root node with
     * {@code units.transformGradient(entry, untransformedRate)}.
     *
     * @param gradient gradient array indexed by
     *                 {@code branchRateModel.getParameterIndexFromNode(node)}
     */
    private void addJacobianTerm(double[] gradient) {
        // Iterate over all node numbers, not just the first dim of them: the root is skipped
        // explicitly, so bounding the loop by dim would drop the highest-numbered node
        // whenever the root is not that node.
        final int nodeCount = this.tree.getNodeCount();
        for (int i = 0; i < nodeCount; ++i) {
            final NodeRef node = this.tree.getNode(i);
            if (this.tree.isRoot(node)) {
                continue;
            }
            final int index = this.branchRateModel.getParameterIndexFromNode(node);
            gradient[index] = this.units.transformGradient(gradient[index], this.branchRateModel.getUntransformedBranchRate(this.tree, node));
        }
    }

    /**
     * Invalidates both caches and propagates the change notification, whichever sub-model
     * changed.
     */
    @Override
    protected void handleModelChangedEvent(Model model, Object object, int n) {
        likelihoodKnown = false;
        incrementsKnown = false;
        fireModelChanged();
    }

    /**
     * Invalidates both caches and propagates the change notification, whichever variable
     * changed.
     */
    @Override
    protected void handleVariableChangedEvent(Variable variable, int n, Variable.ChangeType changeType) {
        likelihoodKnown = false;
        incrementsKnown = false;
        fireModelChanged();
    }

    /**
     * Snapshots the cached likelihood, Jacobian, validity flags and a copy of the increments.
     */
    @Override
    protected void storeState() {
        // Scalars first, then the array copy; the assignments are independent of each other.
        savedLogJacobian = logJacobian;
        savedLogLikelihood = logLikelihood;
        savedLikelihoodKnown = likelihoodKnown;
        savedIncrementsKnown = incrementsKnown;
        System.arraycopy(increments, 0, savedIncrements, 0, dim);
    }

    /**
     * Reinstates the snapshot taken by {@code storeState()}. The two increment arrays are
     * swapped rather than copied.
     */
    @Override
    protected void restoreState() {
        final double[] discarded = increments;
        increments = savedIncrements;
        savedIncrements = discarded;

        incrementsKnown = savedIncrementsKnown;
        likelihoodKnown = savedLikelihoodKnown;
        logLikelihood = savedLogLikelihood;
        logJacobian = savedLogJacobian;
    }

    /**
     * Intentionally does nothing.
     */
    @Override
    protected void acceptState() {
    }

    /**
     * Returns one increment, recomputing all increments first if they are stale.
     *
     * @param n index into the increments array
     * @return the increment stored at {@code n}
     */
    public double getIncrement(int n) {
        checkIncrements();
        final double[] current = increments;
        return current[n];
    }

    /**
     * @return this object; the class is its own model
     */
    @Override
    public Model getModel() {
        return this;
    }

    /**
     * Returns the cached log-likelihood, computing and caching it first when it is stale.
     *
     * @return the log-likelihood for the current state
     */
    @Override
    public double getLogLikelihood() {
        if (likelihoodKnown) {
            return logLikelihood;
        }
        logLikelihood = calculateLogLikelihood();
        likelihoodKnown = true;
        return logLikelihood;
    }

    /**
     * Marks both the increments and the log-likelihood as stale so they are recomputed on
     * next use.
     */
    @Override
    public void makeDirty() {
        incrementsKnown = false;
        likelihoodKnown = false;
    }

    /**
     * @return always {@code null}; no citation category is assigned.
     *         NOTE(review): consumers of {@link Citable} may expect a non-null category — confirm.
     */
    @Override
    public Citation.Category getCategory() {
        return null;
    }

    /**
     * @return always {@code null}; no description is provided.
     *         NOTE(review): consumers of {@link Citable} may expect a non-null description — confirm.
     */
    @Override
    public String getDescription() {
        return null;
    }

    /**
     * @return a single-element list holding {@link #CITATION}
     */
    @Override
    public List<Citation> getCitations() {
        return Collections.singletonList(CITATION);
    }

    /**
     * Recomputes the increments and the accumulated log-Jacobian by traversing the tree from
     * the root, unless they are already current.
     */
    private void checkIncrements() {
        if (incrementsKnown) {
            return;
        }
        // The root has no incoming branch; the traversal starts with a reference value of 0.
        logJacobian = recursePreOrder(tree.getRoot(), 0.0);
        incrementsKnown = true;
    }

    /**
     * Computes the log-likelihood as the accumulated log-Jacobian plus the distribution's
     * log-density at the current increments.
     *
     * @return the freshly computed log-likelihood
     */
    private double calculateLogLikelihood() {
        checkIncrements();
        final double jacobian = logJacobian;
        final double incrementLogDensity = distribution.logPdf(increments);
        return jacobian + incrementLogDensity;
    }

    /**
     * Pre-order traversal that fills {@code increments} for {@code node} and its descendants
     * and sums the transform log-Jacobian terms of the visited non-root nodes.
     *
     * <p>Only children 0 and 1 of an internal node are visited.
     *
     * @param node subtree root to process
     * @param parentValue transformed rate of the parent branch ({@code 0.0} when called on the root)
     * @return the sum of the units and scaling log-Jacobian terms over the subtree
     */
    private double recursePreOrder(NodeRef node, double parentValue) {
        double logJacobianSum = 0.0;
        // Children of the root inherit the value passed in, since the root has no rate of its own.
        double valueForChildren = parentValue;

        if (!tree.isRoot(node)) {
            final double untransformedRate = branchRateModel.getUntransformedBranchRate(tree, node);
            final double transformedRate = units.transform(untransformedRate);
            final double branchLength = tree.getBranchLength(node);

            logJacobianSum += units.getTransformLogJacobian(untransformedRate) + scaling.getTransformLogJacobian(branchLength);

            final int index = branchRateModel.getParameterIndexFromNode(node);
            increments[index] = scaling.rescaleIncrement(transformedRate - parentValue, branchLength);
            valueForChildren = transformedRate;
        }

        if (!tree.isExternal(node)) {
            for (int child = 0; child < 2; ++child) {
                logJacobianSum += recursePreOrder(tree.getChild(node, child), valueForChildren);
            }
        }
        return logJacobianSum;
    }

    /**
     * Pre-order traversal that accumulates into {@code gradient}: each non-root node gains its
     * own increment gradient and, if internal, loses the increment gradients of its children
     * 0 and 1.
     *
     * @param node subtree root to process
     * @param gradient output array, updated in place
     * @param gradientWrtIncrements gradient with respect to the increments (read only here)
     */
    private void recurseGradientPreOrder(NodeRef node, double[] gradient, double[] gradientWrtIncrements) {
        final int index = branchRateModel.getParameterIndexFromNode(node);
        final boolean hasOwnRate = !tree.isRoot(node);

        if (hasOwnRate) {
            gradient[index] += gradientWrtIncrements[index];
        }
        if (tree.isExternal(node)) {
            return;
        }

        final NodeRef firstChild = tree.getChild(node, 0);
        final NodeRef secondChild = tree.getChild(node, 1);
        if (hasOwnRate) {
            // This node's value enters each child's increment with a negative sign.
            gradient[index] -= gradientWrtIncrements[branchRateModel.getParameterIndexFromNode(firstChild)];
            gradient[index] -= gradientWrtIncrements[branchRateModel.getParameterIndexFromNode(secondChild)];
        }
        recurseGradientPreOrder(firstChild, gradient, gradientWrtIncrements);
        recurseGradientPreOrder(secondChild, gradient, gradientWrtIncrements);
    }

    /**
     * Delegates to {@code GradientWrtParameterProvider.getReportAndCheckForError}, passing this
     * provider, the values {@code 0.0} and {@code Double.POSITIVE_INFINITY}, and {@code null}.
     * NOTE(review): the two numbers look like lower/upper parameter bounds — confirm against
     * the helper's signature.
     *
     * @return the report string produced by the helper
     */
    @Override
    public String getReport() {
        return GradientWrtParameterProvider.getReportAndCheckForError(this, 0.0, Double.POSITIVE_INFINITY, null);
    }

    /**
     * Strategies for rescaling an increment using the length of the branch it belongs to.
     */
    public enum BranchVarianceScaling {
        /** Leaves increments unchanged; contributes no log-Jacobian. */
        NONE("none") {
            @Override
            double rescaleIncrement(double increment, double branchLength) {
                return increment;
            }

            @Override
            double inverseRescaleIncrement(double increment, double branchLength) {
                return increment;
            }

            @Override
            double getTransformLogJacobian(double branchLength) {
                return 0.0;
            }
        },

        /** Divides increments by the square root of the branch length. */
        BY_TIME("byTime") {
            @Override
            double rescaleIncrement(double increment, double branchLength) {
                return increment / Math.sqrt(branchLength);
            }

            @Override
            double inverseRescaleIncrement(double increment, double branchLength) {
                return increment * Math.sqrt(branchLength);
            }

            @Override
            double getTransformLogJacobian(double branchLength) {
                return -0.5 * Math.log(branchLength);
            }
        };

        private final String name;

        BranchVarianceScaling(String name) {
            this.name = name;
        }

        /**
         * @param increment value to rescale
         * @param branchLength length of the branch the increment belongs to
         * @return the rescaled increment
         */
        abstract double rescaleIncrement(double increment, double branchLength);

        /**
         * @param increment rescaled value
         * @param branchLength length of the branch the increment belongs to
         * @return the value that {@link #rescaleIncrement} would map to {@code increment}
         */
        abstract double inverseRescaleIncrement(double increment, double branchLength);

        /**
         * @param branchLength length of the branch
         * @return the log-Jacobian term contributed by the rescaling for that branch
         */
        abstract double getTransformLogJacobian(double branchLength);

        /**
         * @return the textual name of this constant (e.g. {@code "byTime"})
         */
        public String getName() {
            return this.name;
        }

        /**
         * Looks up a constant by its textual name, ignoring case.
         *
         * @param string name to look up; may be {@code null}
         * @return the matching constant, or {@code null} when none matches
         */
        public static BranchVarianceScaling parse(String string) {
            for (BranchVarianceScaling candidate : BranchVarianceScaling.values()) {
                if (candidate.getName().equalsIgnoreCase(string)) {
                    return candidate;
                }
            }
            return null;
        }
    }

    /**
     * Transforms applied to an untransformed branch rate before increments are formed.
     */
    public enum BranchRateUnits {
        /** Identity transform; contributes no log-Jacobian and leaves gradients unchanged. */
        REAL_LINE("realLine") {
            @Override
            double transform(double value) {
                return value;
            }

            @Override
            double getTransformLogJacobian(double value) {
                return 0.0;
            }

            @Override
            double inverseTransform(double value) {
                return value;
            }

            @Override
            double transformGradient(double gradient, double value) {
                return gradient;
            }

            @Override
            double inverseTransformGradient(double gradient, double value) {
                return gradient;
            }

            @Override
            boolean needsIncrementCorrection() {
                return false;
            }
        },

        /** Natural-log transform. */
        STRICTLY_POSITIVE("strictlyPositive") {
            @Override
            double transform(double value) {
                return Math.log(value);
            }

            @Override
            double getTransformLogJacobian(double value) {
                return -Math.log(value);
            }

            @Override
            double inverseTransform(double value) {
                return Math.exp(value);
            }

            @Override
            double transformGradient(double gradient, double value) {
                // The "- 1.0" matches the derivative of getTransformLogJacobian, -log(value).
                return (gradient - 1.0) / value;
            }

            @Override
            double inverseTransformGradient(double gradient, double value) {
                // NOTE(review): this is not the algebraic inverse of transformGradient
                // (that would be gradient * value + 1.0) — confirm against callers.
                return gradient * value;
            }

            @Override
            boolean needsIncrementCorrection() {
                return true;
            }
        };

        private final String name;

        BranchRateUnits(String name) {
            this.name = name;
        }

        /**
         * @return the textual name of this constant (e.g. {@code "strictlyPositive"})
         */
        public String getName() {
            return this.name;
        }

        /**
         * @param value untransformed rate
         * @return the transformed rate
         */
        abstract double transform(double value);

        /**
         * @param gradient gradient entry to convert
         * @param value untransformed rate at which the gradient is evaluated
         * @return the converted gradient entry
         */
        abstract double transformGradient(double gradient, double value);

        /**
         * @param value untransformed rate
         * @return the log-Jacobian term contributed by the transform at {@code value}
         */
        abstract double getTransformLogJacobian(double value);

        /**
         * @param value transformed rate
         * @return the untransformed rate that {@link #transform} maps to {@code value}
         */
        abstract double inverseTransform(double value);

        /**
         * @param gradient gradient entry to convert
         * @param value second operand of the conversion (multiplied in for STRICTLY_POSITIVE)
         * @return the converted gradient entry
         */
        abstract double inverseTransformGradient(double gradient, double value);

        /**
         * @return {@code true} for STRICTLY_POSITIVE, {@code false} for REAL_LINE
         */
        abstract boolean needsIncrementCorrection();
    }
}

