/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.continuous;

import dr.evolution.tree.MutableTreeModel;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.continuous.AbstractMultivariateTraitLikelihood;
import dr.evomodel.continuous.MissingTraits;
import dr.evomodel.continuous.MultivariateDiffusionModel;
import dr.evomodel.continuous.RestrictedPartials;
import dr.inference.loggers.LogColumn;
import dr.inference.model.AbstractModelLikelihood;
import dr.inference.model.CompoundParameter;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.math.distributions.MultivariateNormalDistribution;
import dr.math.distributions.WishartSufficientStatistics;
import dr.math.matrixAlgebra.Matrix;
import dr.math.matrixAlgebra.SymmetricMatrix;
import dr.math.matrixAlgebra.Vector;
import dr.util.Author;
import dr.util.Citation;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public abstract class IntegratedMultivariateTraitLikelihood
extends AbstractMultivariateTraitLikelihood {
    /** Constant {@code 0.5 * log(2 * pi)} used by the Gaussian log-density terms in this class. */
    public static final double LOG_SQRT_2_PI = 0.5 * Math.log(Math.PI * 2);
    /** Mean-cache strategy selected in the constructor (drift, OU or plain helper). */
    protected final CacheHelper cacheHelper;
    /** Set by {@link #redrawAncestralStates()}; cleared by {@link #calculateLogLikelihood()} and {@link #makeDirty()}. */
    protected boolean areStatesRedrawn = false;
    /** Flat per-partial mean values, addressed as {@code dim * index + offset} by the peeling code. */
    protected double[] meanCache;
    /** Flat per-partial corrected means; allocated by the OU cache helper, written by the corrected-average methods. */
    protected double[] correctedMeanCache;
    /** One scalar per partial, written by the post-order traversal (not set for the root by doPeel). */
    protected double[] upperPrecisionCache;
    /** One scalar per partial, written by the post-order traversal; the root entry is read as the root precision. */
    protected double[] lowerPrecisionCache;
    /** One log-remainder term per partial; summed by sumLogRemainders(). */
    private double[] logRemainderDensityCache;
    // NOTE(review): allocation of meanCache/storedMeanCache is not visible in this part of the file.
    protected double[] storedMeanCache;
    // The three stored* arrays below are allocated only when the constructor's cache flag is true.
    private double[] storedUpperPrecisionCache;
    private double[] storedLowerPrecisionCache;
    private double[] storedLogRemainderDensityCache;
    /** Flat array of sampled trait values ({@code dim} entries per partial), filled by the pre-order sampler. */
    protected double[] drawnStates;
    // Always true; not read by any code visible in this part of the file.
    protected final boolean integrateRoot = true;
    // Diagnostic switches: DEBUG enables stderr tracing, DEBUG_PREORDER forces drawn root values to 1.0,
    // DEBUG_PNAS enables the checkLogLikelihood hook and reporting of very negative remainders.
    protected static boolean DEBUG = false;
    protected static boolean DEBUG_PREORDER = false;
    protected static boolean DEBUG_PNAS = false;
    // Allocated with length dim in the constructor; not otherwise used in the visible code.
    private double[] zeroDimVector;
    /** Recreated on every likelihood evaluation when getComputeWishartSufficientStatistics() is true. */
    protected WishartSufficientStatistics wishartStatistics;
    // Scratch buffers sized by dimTrait, reused across likelihood evaluation and sampling.
    protected double[] Ay;
    protected double[][] tmpM;
    protected double[] tmp2;
    /** Missing-tip bookkeeping; constructed as MissingTraits.CompletelyMissing. */
    protected final MissingTraits missingTraits;
    /** Restricted partials keyed by their tip bit set; null until the first one is added. */
    protected Map<BitSet, RestrictedPartials> clampList = null;
    /** Restricted partials keyed by the internal node whose descendant tips match; rebuilt by setupClamps(). */
    protected Map<NodeRef, RestrictedPartials> nodeToClampMap = null;
    // Tree node count, plus one slot per restricted partial, plus one spare slot when clamps were supplied.
    private int partialsCount;
    // Index of the spare slot used as the temporary destination when peeling a clamp.
    private int spareIndex;
    /** True when setupClamps() mapped at least one restricted partial to a node. */
    protected boolean anyClamps = false;

    public IntegratedMultivariateTraitLikelihood(String string, MutableTreeModel mutableTreeModel, MultivariateDiffusionModel multivariateDiffusionModel, CompoundParameter compoundParameter, Parameter parameter, List<Integer> list, boolean bl, boolean bl2, boolean bl3, BranchRateModel branchRateModel, List<BranchRateModel> list2, List<BranchRateModel> list3, BranchRateModel branchRateModel2, Model model, List<RestrictedPartials> list4, boolean bl4, boolean bl5) {
        super(string, mutableTreeModel, multivariateDiffusionModel, compoundParameter, parameter, list, bl, bl2, bl3, branchRateModel, list2, list3, branchRateModel2, model, bl4, bl5);
        this.partialsCount = mutableTreeModel.getNodeCount();
        if (list4 != null) {
            for (RestrictedPartials restrictedPartials : list4) {
                restrictedPartials.setIndex(this.partialsCount);
                this.addRestrictedPartials(restrictedPartials);
                ++this.partialsCount;
            }
            this.spareIndex = this.partialsCount++;
            this.setupClamps();
        }
        this.cacheHelper = list2 != null ? new DriftCacheHelper(this.dim * this.partialsCount, bl) : (list3 != null ? new OUCacheHelper(this.dim * this.partialsCount, bl) : new CacheHelper(this.dim * this.partialsCount, bl));
        this.drawnStates = new double[this.dim * this.partialsCount];
        this.upperPrecisionCache = new double[this.partialsCount];
        this.lowerPrecisionCache = new double[this.partialsCount];
        this.logRemainderDensityCache = new double[this.partialsCount];
        if (bl) {
            this.storedUpperPrecisionCache = new double[this.partialsCount];
            this.storedLowerPrecisionCache = new double[this.partialsCount];
            this.storedLogRemainderDensityCache = new double[this.partialsCount];
        }
        this.Ay = new double[this.dimTrait];
        this.tmpM = new double[this.dimTrait][this.dimTrait];
        this.tmp2 = new double[this.dimTrait];
        this.zeroDimVector = new double[this.dim];
        this.missingTraits = new MissingTraits.CompletelyMissing(mutableTreeModel, list, this.dim);
        this.setTipDataValuesForAllNodes();
    }

    /**
     * Loads the trait values of every external node into the mean cache, then lets the
     * missing-trait tracker process the tips.
     */
    private void setTipDataValuesForAllNodes() {
        int tip = 0;
        while (tip < this.treeModel.getExternalNodeCount()) {
            this.setTipDataValuesForNode(this.treeModel.getExternalNode(tip));
            tip++;
        }
        this.missingTraits.handleMissingTips();
    }

    /**
     * Returns the lower-precision cache entry of the root node.
     *
     * <p>{@code getLogLikelihood()} is called first so that the cache is read after a
     * likelihood evaluation has been requested.
     *
     * @return the cached lower precision at the root
     */
    public double getTotalTreePrecision() {
        this.getLogLikelihood();
        final NodeRef root = this.treeModel.getRoot();
        return this.lowerPrecisionCache[root.getNumber()];
    }

    /**
     * Copies one tip's trait-parameter values into the mean cache through the cache helper.
     *
     * @param nodeRef the tip whose values are loaded
     * @throws RuntimeException if the tip's parameter holds fewer than {@code dim} values
     */
    private void setTipDataValuesForNode(NodeRef nodeRef) {
        final int tipNumber = nodeRef.getNumber();
        final double[] tipValues = this.traitParameter.getParameter(tipNumber).getParameterValues();
        if (tipValues.length >= this.dim) {
            this.cacheHelper.setTipMeans(tipValues, this.dim, tipNumber, nodeRef);
            return;
        }
        throw new RuntimeException("The trait parameter for the tip with index, " + tipNumber + ", is too short");
    }

    /**
     * Returns a copy of the {@code dim} cached mean values stored for partial {@code n}.
     *
     * @param n partial (node) index
     * @return a fresh array of length {@code dim}
     */
    public double[] getTipDataValues(int n) {
        final double[] means = this.cacheHelper.getMeanCache();
        final double[] copy = new double[this.dim];
        final int start = this.dim * n;
        System.arraycopy(means, start, copy, 0, this.dim);
        return copy;
    }

    /**
     * Overwrites the cached tip means for index {@code n} via the cache helper and marks
     * this likelihood dirty.
     *
     * @param n      tip index passed to the cache helper
     * @param dArray new trait values for that tip
     */
    public void setTipDataValuesForNode(int n, double[] dArray) {
        this.cacheHelper.setTipMeans(dArray, this.dim, n);
        this.makeDirty();
    }

    /** Returns the fixed report line stating that internal node traits are not sampled. */
    @Override
    protected String extraInfo() {
        return "\tSample internal node traits: false\n";
    }

    /** Reports this likelihood under the trait-models citation category. */
    @Override
    public Citation.Category getCategory() {
        return Citation.Category.TRAIT_MODELS;
    }

    /** Appends the integrated-internal-traits note to the superclass description. */
    @Override
    public String getDescription() {
        return super.getDescription() + " (first citation) with efficiently integrated internal traits (second citation)";
    }

    /**
     * Returns the superclass citations followed by the Pybus et al. (2012) PNAS reference.
     *
     * @return a new list; the superclass list itself is not modified
     */
    @Override
    public List<Citation> getCitations() {
        Author[] authors = new Author[]{
                new Author("OG", "Pybus"),
                new Author("MA", "Suchard"),
                new Author("P", "Lemey"),
                new Author("F", "Bernadin"),
                new Author("A", "Rambaut"),
                new Author("FW", "Crawford"),
                new Author("RR", "Gray"),
                new Author("N", "Arinaminpathy"),
                new Author("S", "Stramer"),
                new Author("MP", "Busch"),
                new Author("E", "Delwart")
        };
        Citation pnasPaper = new Citation(authors, "Unifying the spatial epidemiology and evolution of emerging epidemics", 2012, "Proceedings of the National Academy of Sciences", 109, 15066, 15071, Citation.Status.PUBLISHED);
        ArrayList<Citation> citations = new ArrayList<Citation>(super.getCitations());
        citations.add(pnasPaper);
        return citations;
    }

    /** Returns the same value as {@code getLogLikelihood()}. */
    @Override
    public double getLogDataLikelihood() {
        return this.getLogLikelihood();
    }

    /**
     * Rebuilds the node-to-clamp map from scratch by walking the tree from the root, then
     * records whether any clamp was matched to a node.
     */
    private void setupClamps() {
        if (this.nodeToClampMap != null) {
            this.nodeToClampMap.clear();
        } else {
            this.nodeToClampMap = new HashMap<NodeRef, RestrictedPartials>();
        }
        BitSet allTips = new BitSet();
        this.recursiveSetupClamp(this.treeModel, this.treeModel.getRoot(), allTips);
        this.anyClamps = !this.nodeToClampMap.isEmpty();
    }

    /**
     * Collects the tip numbers below {@code nodeRef} into {@code bitSet} and, for every internal
     * node whose tip set equals the key of a registered restricted partial, binds that partial
     * to the node.
     *
     * <p>{@code clampList} stays {@code null} until the first restricted partial is added (for
     * example when an empty list was supplied), so it is checked before the lookup instead of
     * being dereferenced unconditionally.
     *
     * @param tree    tree being traversed
     * @param nodeRef current node
     * @param bitSet  output: receives the numbers of all tips below {@code nodeRef}
     */
    private void recursiveSetupClamp(Tree tree, NodeRef nodeRef, BitSet bitSet) {
        if (tree.isExternal(nodeRef)) {
            bitSet.set(nodeRef.getNumber());
            return;
        }
        for (int i = 0; i < tree.getChildCount(nodeRef); ++i) {
            // Each child fills its own set so sibling subtrees do not see each other's tips.
            BitSet childTips = new BitSet();
            this.recursiveSetupClamp(tree, tree.getChild(nodeRef, i), childTips);
            bitSet.or(childTips);
        }
        if (this.clampList == null) {
            // No restricted partials registered: nothing can match this node.
            return;
        }
        RestrictedPartials clamp = this.clampList.get(bitSet);
        if (clamp != null) {
            clamp.setNode(nodeRef);
            this.nodeToClampMap.put(nodeRef, clamp);
        }
    }

    /**
     * Tells {@link #calculateLogLikelihood()} whether to create a fresh
     * {@code WishartSufficientStatistics} and accumulate outer products while peeling.
     *
     * @return {@code true} to collect Wishart sufficient statistics during the next evaluation
     */
    public abstract boolean getComputeWishartSufficientStatistics();

    /**
     * Evaluates the log-likelihood by a post-order traversal of the tree followed by the
     * root contribution for every datum and the sum of the cached log-remainder terms.
     *
     * <p>Side effects: may rebuild the clamp map, replaces {@code wishartStatistics} when
     * sufficient statistics are requested, overwrites the precision/remainder caches and the
     * scratch buffers {@code tmp2}/{@code Ay}/{@code tmpM}, and clears {@code areStatesRedrawn}.
     *
     * @return the log-likelihood value
     */
    @Override
    public double calculateLogLikelihood() {
        int n;
        // Re-map clamps to nodes if the restricted-partials flag was raised since the last evaluation.
        if (this.updateRestrictedNodePartials) {
            if (this.clampList != null) {
                this.setupClamps();
            }
            this.updateRestrictedNodePartials = false;
        }
        // d accumulates the total log-likelihood.
        double d = 0.0;
        // dArray: diffusion precision matrix; d2: log of its determinant.
        double[][] dArray = this.diffusionModel.getPrecisionmatrix();
        double d2 = Math.log(this.diffusionModel.getDeterminantPrecisionMatrix());
        // dArray2 aliases the tmp2 scratch buffer and holds the root mean of the current datum.
        double[] dArray2 = this.tmp2;
        boolean bl = this.getComputeWishartSufficientStatistics();
        if (bl) {
            this.wishartStatistics = new WishartSufficientStatistics(this.dimTrait);
        }
        // Fill the mean, precision and remainder caches bottom-up.
        this.postOrderTraverse(this.treeModel, this.treeModel.getRoot(), dArray, d2, bl);
        if (DEBUG) {
            System.err.println("mean: " + new Vector(this.cacheHelper.getMeanCache()));
            System.err.println("correctedMean: " + new Vector(this.cacheHelper.getCorrectedMeanCache()));
            System.err.println("upre: " + new Vector(this.upperPrecisionCache));
            System.err.println("lpre: " + new Vector(this.lowerPrecisionCache));
            System.err.println("cach: " + new Vector(this.logRemainderDensityCache));
        }
        // n2: root node number; d3: lower precision cached at the root.
        int n2 = this.treeModel.getRoot().getNumber();
        double d3 = this.lowerPrecisionCache[n2];
        for (n = 0; n < this.numData; ++n) {
            // d4: contribution of datum n before the remainder terms are added.
            double d4 = 0.0;
            System.arraycopy(this.cacheHelper.getMeanCache(), n2 * this.dim + n * this.dimTrait, dArray2, 0, this.dimTrait);
            if (DEBUG) {
                System.err.println("Datum #" + n);
                System.err.println("root mean: " + new Vector(dArray2));
                System.err.println("root prec: " + d3);
                System.err.println("diffusion prec: " + new Matrix(dArray));
            }
            // Fills Ay with (root precision * precision matrix * root mean) and returns mean' * Ay.
            double d5 = IntegratedMultivariateTraitLikelihood.computeWeightedAverageAndSumOfSquares(dArray2, this.Ay, dArray, this.dimTrait, d3);
            // The Gaussian term is skipped when the root precision is exactly zero (log(0) would be -infinity).
            if (d3 != 0.0) {
                d4 += -LOG_SQRT_2_PI * (double)this.dimTrait + 0.5 * (d2 + (double)this.dimTrait * Math.log(d3) - d5);
            }
            if (DEBUG) {
                double[][] dArray3 = new double[this.dimTrait][this.dimTrait];
                for (int i = 0; i < this.dimTrait; ++i) {
                    for (int j = 0; j < this.dimTrait; ++j) {
                        dArray3[i][j] = dArray[i][j] * d3;
                    }
                }
                System.err.println("Conditional root MVN precision = \n" + new Matrix(dArray3));
                System.err.println("Conditional root MVN density = " + MultivariateNormalDistribution.logPdf(dArray2, new double[this.dimTrait], dArray3, Math.log(MultivariateNormalDistribution.calculatePrecisionMatrixDeterminate(dArray3)), 1.0));
            }
            // Subclass-specific root term, added unconditionally.
            d4 += this.integrateLogLikelihoodAtRoot(dArray2, this.Ay, this.tmpM, dArray, d3);
            if (DEBUG) {
                System.err.println("yAy = " + d5);
                System.err.println("logLikelihood (before remainders) = " + d4 + " (should match conditional root MVN density when root not integrated out)");
            }
            d += d4;
        }
        // Add the per-partial remainder terms accumulated during peeling.
        d += this.sumLogRemainders();
        if (DEBUG) {
            System.out.println("logLikelihood is " + d);
        }
        if (DEBUG) {
            System.err.println("logLikelihood (final) = " + d);
        }
        if (DEBUG_PNAS) {
            // Hook for subclasses, then report any remainder entry below -1e10.
            this.checkLogLikelihood(d, this.sumLogRemainders(), dArray2, d3, dArray);
            for (n = 0; n < this.logRemainderDensityCache.length; ++n) {
                if (!(this.logRemainderDensityCache[n] < -1.0E10)) continue;
                System.err.println(this.logRemainderDensityCache[n] + " @ " + n);
            }
        }
        // Any previously drawn ancestral states no longer correspond to the caches.
        this.areStatesRedrawn = false;
        return d;
    }

    /**
     * Diagnostic hook invoked by {@link #calculateLogLikelihood()} only when {@code DEBUG_PNAS}
     * is set; the default implementation does nothing.
     *
     * @param d       the final log-likelihood
     * @param d2      the sum of the log-remainder terms
     * @param dArray  the root-mean scratch buffer (holds the last datum processed)
     * @param d3      the root lower precision
     * @param dArray2 the diffusion precision matrix
     */
    protected void checkLogLikelihood(double d, double d2, double[] dArray, double d3, double[][] dArray2) {
    }

    /**
     * Keeps the tip mean cache in sync with the trait parameter, then forwards the event to
     * the superclass.
     *
     * <p>For a change in the trait parameter, a single changed index is copied into the mean
     * cache; an index of {@code -1} refreshes every value of the parameter. Valid single
     * indices are {@code 0 .. dimTrait * numData * tipCount - 1}; the first index past that
     * range is rejected as well, since it does not address a tip value.
     *
     * @param variable   the variable that changed
     * @param n          changed index, or {@code -1} for "all values"
     * @param changeType kind of change, passed through to the superclass
     * @throws RuntimeException if {@code n} lies at or beyond the number of tip values
     */
    @Override
    protected void handleVariableChangedEvent(Variable variable, int n, Variable.ChangeType changeType) {
        if (variable == this.traitParameter) {
            final int tipValueCount = this.dimTrait * this.numData * this.treeModel.getExternalNodeCount();
            // '>=' rather than '>': index == tipValueCount is already one past the last tip slot.
            if (n >= tipValueCount) {
                throw new RuntimeException("Attempting to update an invalid index");
            }
            if (n != -1) {
                this.cacheHelper.setMeanCache(n, this.traitParameter.getValue(n));
            } else {
                for (int i = 0; i < this.traitParameter.getDimension(); ++i) {
                    this.cacheHelper.setMeanCache(i, this.traitParameter.getValue(i));
                }
            }
            this.likelihoodKnown = false;
        }
        super.handleVariableChangedEvent(variable, n, changeType);
    }

    /**
     * Computes {@code Ay = scale * A * y} into {@code dArray2} and returns {@code y' * Ay}.
     *
     * @param dArray  vector {@code y} (at least {@code n} entries)
     * @param dArray2 output vector {@code Ay}; its first {@code n} entries are overwritten
     * @param dArray3 matrix {@code A} (at least {@code n x n})
     * @param n       number of rows/columns to use
     * @param d       scalar {@code scale} applied to every product
     * @return the quadratic form {@code sum_i y[i] * Ay[i]}
     */
    protected static double computeWeightedAverageAndSumOfSquares(double[] dArray, double[] dArray2, double[][] dArray3, int n, double d) {
        double quadraticForm = 0.0;
        for (int row = 0; row < n; ++row) {
            // Accumulate directly in the output slot, starting from zero.
            dArray2[row] = 0.0;
            for (int col = 0; col < n; ++col) {
                dArray2[row] += dArray3[row][col] * dArray[col] * d;
            }
            quadraticForm += dArray[row] * dArray2[row];
        }
        return quadraticForm;
    }

    /**
     * Adds up every entry of the log-remainder cache, in index order.
     *
     * @return the plain left-to-right sum of the cache
     */
    private double sumLogRemainders() {
        final double[] remainders = this.logRemainderDensityCache;
        double total = 0.0;
        for (int i = 0; i < remainders.length; ++i) {
            total += remainders[i];
        }
        return total;
    }

    /**
     * Subclass-specific root contribution, added once per datum by
     * {@link #calculateLogLikelihood()}.
     *
     * @param var1 root mean for the current datum
     * @param var2 the {@code Ay} buffer filled by computeWeightedAverageAndSumOfSquares for that datum
     * @param var3 the {@code tmpM} scratch matrix ({@code dimTrait x dimTrait})
     * @param var4 diffusion precision matrix
     * @param var5 lower precision cached at the root
     * @return value added to the datum's log-likelihood
     */
    protected abstract double integrateLogLikelihoodAtRoot(double[] var1, double[] var2, double[][] var3, double[][] var4, double var5);

    /** Marks the likelihood dirty via the superclass and invalidates any drawn ancestral states. */
    @Override
    public void makeDirty() {
        super.makeDirty();
        this.areStatesRedrawn = false;
    }

    /**
     * Recursively peels the tree bottom-up, filling the precision, mean and log-remainder
     * caches for {@code nodeRef} and everything below it.
     *
     * <p>Only children 0 and 1 of an internal node are visited.
     *
     * @param mutableTreeModel tree being traversed
     * @param nodeRef          current node
     * @param dArray           diffusion precision matrix
     * @param d                log-determinant of the precision matrix
     * @param bl               whether outer products for the Wishart statistics are accumulated
     */
    void postOrderTraverse(MutableTreeModel mutableTreeModel, NodeRef nodeRef, double[][] dArray, double d, boolean bl) {
        int n = nodeRef.getNumber();
        if (mutableTreeModel.isExternal(nodeRef)) {
            if (this.missingTraits.isCompletelyMissing(n)) {
                // A completely missing tip gets zero precision in both caches.
                this.upperPrecisionCache[n] = 0.0;
                this.lowerPrecisionCache[n] = 0.0;
            } else {
                // Observed tip: infinite lower precision; upper precision from the helper's branch factor times the squared OU factor.
                this.upperPrecisionCache[n] = this.cacheHelper.getUpperPrecFactor(nodeRef) * Math.pow(this.cacheHelper.getOUFactor(nodeRef), 2.0);
                this.lowerPrecisionCache[n] = Double.POSITIVE_INFINITY;
            }
            return;
        }
        NodeRef nodeRef2 = mutableTreeModel.getChild(nodeRef, 0);
        NodeRef nodeRef3 = mutableTreeModel.getChild(nodeRef, 1);
        this.postOrderTraverse(mutableTreeModel, nodeRef2, dArray, d, bl);
        this.postOrderTraverse(mutableTreeModel, nodeRef3, dArray, d, bl);
        int n2 = nodeRef2.getNumber();
        int n3 = nodeRef3.getNumber();
        // n4/n5/n6: offsets of child 0, child 1 and this node in the flat mean arrays.
        int n4 = this.dim * n2;
        int n5 = this.dim * n3;
        int n6 = this.dim * n;
        // d2/d3: upper precisions of the two children; d4: their sum, stored as this node's lower precision by doPeel.
        double d2 = this.upperPrecisionCache[n2];
        double d3 = this.upperPrecisionCache[n3];
        double d4 = d2 + d3;
        double d5 = this.cacheHelper.getOUFactor(nodeRef2);
        double d6 = this.cacheHelper.getOUFactor(nodeRef3);
        this.doPeel(n, n6, n4, n5, d4, d2, d3, this.missingTraits, n, dArray, d, d5, d6, bl, nodeRef, nodeRef2, nodeRef3, true, false);
        if (this.nodeToClampMap != null && this.nodeToClampMap.containsKey(nodeRef)) {
            // A restricted partial is attached to this node: combine it with the node's own partial.
            RestrictedPartials restrictedPartials = this.nodeToClampMap.get(nodeRef);
            int n7 = restrictedPartials.getIndex();
            int n8 = this.dim * n7;
            int n9 = this.dim * this.spareIndex;
            // Load the clamp's values into its own slot of the mean cache.
            for (int i = 0; i < this.dim; ++i) {
                this.meanCache[n8 + i] = restrictedPartials.getPartial(i);
            }
            // d7: this node's lower precision; d8: clamp prior sample size divided by rescaleLength(1.0).
            double d7 = this.lowerPrecisionCache[n];
            double d8 = restrictedPartials.getPriorSampleSize() / this.rescaleLength(1.0);
            double d9 = d7 + d8;
            // Peel node and clamp into the spare slot; the remainder term is stored under the clamp's index.
            this.doPeel(this.spareIndex, n9, n6, n8, d9, d7, d8, this.missingTraits, n7, dArray, d, 1.0, 1.0, bl, nodeRef, null, null, true, false);
            // Copy the combined result from the spare slot back onto this node.
            this.lowerPrecisionCache[n] = this.lowerPrecisionCache[this.spareIndex];
            this.upperPrecisionCache[n] = this.upperPrecisionCache[this.spareIndex];
            for (int i = 0; i < this.dim; ++i) {
                this.meanCache[n6 + i] = this.meanCache[n9 + i];
            }
        }
    }

    /**
     * Performs one peeling step: stores the combined lower precision, lets the cache helper
     * compute the combined means, derives the upper precision for non-root nodes and resets
     * (and possibly recomputes) the log-remainder term.
     *
     * @param n             slot that receives the lower/upper precision
     * @param n2            mean-array offset of the destination
     * @param n3            mean-array offset of the first input
     * @param n4            mean-array offset of the second input
     * @param d             combined (total) precision
     * @param d2            precision of the first input
     * @param d3            precision of the second input
     * @param missingTraits forwarded to the cache helper
     * @param n5            slot of the log-remainder cache to write
     * @param dArray        diffusion precision matrix
     * @param d4            log-determinant of the precision matrix
     * @param d5            OU factor of the first input
     * @param d6            OU factor of the second input
     * @param bl            whether outer products are accumulated
     * @param nodeRef       node being peeled (used for the root test and branch factors)
     * @param nodeRef2      first child node, may be {@code null}
     * @param nodeRef3      second child node, may be {@code null}
     * @param bl2           remainder densities are only computed when this is {@code true}
     * @param bl3           not read by this method
     */
    private void doPeel(int n, int n2, int n3, int n4, double d, double d2, double d3, MissingTraits missingTraits, int n5, double[][] dArray, double d4, double d5, double d6, boolean bl, NodeRef nodeRef, NodeRef nodeRef2, NodeRef nodeRef3, boolean bl2, boolean bl3) {
        this.lowerPrecisionCache[n] = d;
        this.cacheHelper.computeMeanCaches(n2, n3, n4, d, d2, d3, missingTraits, nodeRef, nodeRef2, nodeRef3);

        if (!this.treeModel.isRoot(nodeRef)) {
            final double branchPrecision = this.cacheHelper.getUpperPrecFactor(nodeRef);
            if (Double.isInfinite(branchPrecision)) {
                // An infinite branch factor leaves the total precision unchanged.
                this.upperPrecisionCache[n] = d;
            } else {
                this.upperPrecisionCache[n] = d * branchPrecision / (d + branchPrecision) * Math.pow(this.cacheHelper.getOUFactor(nodeRef), 2.0);
            }
        }

        this.logRemainderDensityCache[n5] = 0.0;
        final boolean bothInputsInformative = d2 != 0.0 && d3 != 0.0;
        if (bothInputsInformative && bl2) {
            this.incrementRemainderDensities(dArray, d4, n5, n2, n3, n4, d2, d3, d5, d6, bl);
        }
    }

    /**
     * Adds, for every datum, the log-remainder term of one peeling step to
     * {@code logRemainderDensityCache[n]}, optionally accumulating Wishart outer products first.
     *
     * @param dArray diffusion precision matrix
     * @param d      log-determinant of the precision matrix
     * @param n      slot of the log-remainder cache to increment
     * @param n2     mean-array offset of the combined (parent) partial
     * @param n3     mean-array offset of the first input
     * @param n4     mean-array offset of the second input
     * @param d2     precision of the first input
     * @param d3     precision of the second input
     * @param d4     OU factor of the first input
     * @param d5     OU factor of the second input
     * @param bl     whether to accumulate outer products into the Wishart statistics
     */
    private void incrementRemainderDensities(double[][] dArray, double d, int n, int n2, int n3, int n4, double d2, double d3, double d4, double d5, boolean bl) {
        // d6: combined "remainder" precision p0 * p1 / (p0 + p1) (labelled rP in the debug output).
        double d6 = d2 * d3 / (d2 + d3);
        if (bl) {
            this.incrementOuterProducts(n2, n3, n4, d2, d3);
        }
        for (int i = 0; i < this.numData; ++i) {
            int n5;
            // d7/d8: precision-weighted quadratic forms of the two inputs; d9: cross term with the combined mean.
            double d7 = 0.0;
            double d8 = 0.0;
            double d9 = 0.0;
            for (n5 = 0; n5 < this.dimTrait; ++n5) {
                double d10 = this.cacheHelper.getCorrectedMeanCache()[n3 + i * this.dimTrait + n5] * d2;
                double d11 = this.cacheHelper.getCorrectedMeanCache()[n4 + i * this.dimTrait + n5] * d3;
                for (int j = 0; j < this.dimTrait; ++j) {
                    double d12 = this.cacheHelper.getCorrectedMeanCache()[n3 + i * this.dimTrait + j];
                    double d13 = this.cacheHelper.getCorrectedMeanCache()[n4 + i * this.dimTrait + j];
                    d7 += d10 * dArray[n5][j] * d12;
                    d8 += d11 * dArray[n5][j] * d13;
                    d9 += (d10 + d11) * dArray[n5][j] * this.cacheHelper.getMeanCache()[n2 + i * this.dimTrait + j];
                }
            }
            int n6 = n;
            // Gaussian normalisation, quadratic-form difference, and the log OU-factor correction.
            this.logRemainderDensityCache[n6] = this.logRemainderDensityCache[n6] + ((double)(-this.dimTrait) * LOG_SQRT_2_PI + 0.5 * ((double)this.dimTrait * Math.log(d6) + d) - 0.5 * (d7 + d8 - d9) - (double)this.dimTrait * (Math.log(d4) + Math.log(d5)));
            // Debug-only sanity check: a remainder above 100 dumps diagnostics and terminates the JVM.
            if (!DEBUG || !(this.logRemainderDensityCache[n] > 100.0)) continue;
            System.err.println(n);
            System.err.println(this.logRemainderDensityCache[n]);
            System.err.println("rP = " + d6);
            System.err.println("p0 = " + d2);
            System.err.println("p1 = " + d3 + "\n");
            System.err.println(new Matrix(dArray));
            System.err.println(d7);
            System.err.println(d8);
            System.err.println(d9);
            // Note: this dump always prints the first datum's values (datum index 0).
            for (n5 = 0; n5 < this.dimTrait; ++n5) {
                System.err.println("\t" + this.cacheHelper.getCorrectedMeanCache()[n3 + 0 * this.dimTrait + n5] + " " + this.cacheHelper.getCorrectedMeanCache()[n4 + 0 * this.dimTrait + n5]);
            }
            System.exit(-1);
        }
    }

    /**
     * Adds the precision-weighted outer product of the difference between the two input
     * partials to the Wishart scale matrix (stored row-major as a flat array) for every datum,
     * then raises the degrees of freedom by {@code numData}.
     *
     * <p>Zero or very small precisions are reported on stderr but do not stop the computation.
     *
     * @param n  mean-array offset of the combined partial (not read here)
     * @param n2 mean-array offset of the first input
     * @param n3 mean-array offset of the second input
     * @param d  precision of the first input
     * @param d2 precision of the second input
     */
    private void incrementOuterProducts(int n, int n2, int n3, double d, double d2) {
        final double[] scaleMatrix = this.wishartStatistics.getScaleMatrix();
        if (d == 0.0 || d2 == 0.0) {
            System.err.println("ZERO PRECISION");
        }
        if (d < 1.0E-16 || d2 < 1.0E-16) {
            System.err.println("LOW PRECISION");
        }
        final double remainderPrecision = d * d2 / (d + d2);
        for (int datum = 0; datum < this.numData; ++datum) {
            final int base0 = n2 + datum * this.dimTrait;
            final int base1 = n3 + datum * this.dimTrait;
            for (int row = 0; row < this.dimTrait; ++row) {
                final double first0 = this.cacheHelper.getCorrectedMeanCache()[base0 + row];
                final double first1 = this.cacheHelper.getCorrectedMeanCache()[base1 + row];
                for (int col = 0; col < this.dimTrait; ++col) {
                    final double second0 = this.cacheHelper.getCorrectedMeanCache()[base0 + col];
                    final double second1 = this.cacheHelper.getCorrectedMeanCache()[base1 + col];
                    final double increment = (first0 - first1) * (second0 - second1) * remainderPrecision;
                    scaleMatrix[row * this.dimTrait + col] += increment;
                }
            }
        }
        this.wishartStatistics.incrementDf(this.numData);
    }

    /** Returns the drawn trait values at the root, obtained through {@link #getTraitForNode}. */
    @Override
    protected double[] getRootNodeTrait() {
        return this.getTraitForNode(this.treeModel, this.treeModel.getRoot(), this.traitName);
    }

    /**
     * Returns a copy of the drawn trait values for one node, redrawing all ancestral states
     * first if the current draw is stale.
     *
     * @param tree    not read by this implementation
     * @param nodeRef node whose drawn values are returned
     * @param string  not read by this implementation
     * @return a fresh array of length {@code dim}
     */
    @Override
    public double[] getTraitForNode(Tree tree, NodeRef nodeRef, String string) {
        // Requested before checking the flag, because evaluating the likelihood clears it.
        this.getLogLikelihood();
        if (!this.areStatesRedrawn) {
            this.redrawAncestralStates();
        }
        final int offset = nodeRef.getNumber() * this.dim;
        final double[] drawn = new double[this.dim];
        System.arraycopy(this.drawnStates, offset, drawn, 0, this.dim);
        return drawn;
    }

    /**
     * Samples new trait values for the tree by a pre-order traversal from the root and marks
     * the drawn states as current.
     */
    public void redrawAncestralStates() {
        final double[][] precision = this.diffusionModel.getPrecisionmatrix();
        // The sampler needs the variance as well, i.e. the inverse of the precision matrix.
        final SymmetricMatrix precisionMatrix = new SymmetricMatrix(precision);
        final double[][] variance = precisionMatrix.inverse().toComponents();
        this.preOrderTraverseSample(this.treeModel, this.treeModel.getRoot(), 0, precision, variance);
        if (DEBUG) {
            System.err.println("all draws = " + new Vector(this.drawnStates));
        }
        this.areStatesRedrawn = true;
    }

    /**
     * Stores the superclass state and, when branches are cached, copies the cache helper's
     * state and the three per-partial caches into their stored counterparts.
     */
    @Override
    public void storeState() {
        super.storeState();
        if (!this.cacheBranches) {
            return;
        }
        this.cacheHelper.store();
        final int upperLength = this.upperPrecisionCache.length;
        final int lowerLength = this.lowerPrecisionCache.length;
        final int remainderLength = this.logRemainderDensityCache.length;
        System.arraycopy(this.upperPrecisionCache, 0, this.storedUpperPrecisionCache, 0, upperLength);
        System.arraycopy(this.lowerPrecisionCache, 0, this.storedLowerPrecisionCache, 0, lowerLength);
        System.arraycopy(this.logRemainderDensityCache, 0, this.storedLogRemainderDensityCache, 0, remainderLength);
    }

    /**
     * Restores the superclass state and, when branches are cached, restores the cache helper
     * and swaps each live per-partial cache with its stored counterpart (no copying).
     */
    @Override
    public void restoreState() {
        super.restoreState();
        if (!this.cacheBranches) {
            return;
        }
        this.cacheHelper.restore();

        final double[] liveUpper = this.upperPrecisionCache;
        this.upperPrecisionCache = this.storedUpperPrecisionCache;
        this.storedUpperPrecisionCache = liveUpper;

        final double[] liveLower = this.lowerPrecisionCache;
        this.lowerPrecisionCache = this.storedLowerPrecisionCache;
        this.storedLowerPrecisionCache = liveLower;

        final double[] liveRemainder = this.logRemainderDensityCache;
        this.logRemainderDensityCache = this.storedLogRemainderDensityCache;
        this.storedLogRemainderDensityCache = liveRemainder;
    }

    /**
     * Computes the bilinear form {@code x' * M * y} over the leading {@code n x n} block.
     *
     * @param dArray  left vector {@code x}
     * @param dArray2 matrix {@code M}
     * @param dArray3 right vector {@code y}
     * @param n       number of rows/columns to use
     * @return {@code sum_{i,j} x[i] * M[i][j] * y[j]}, accumulated row by row
     */
    protected static double computeQuadraticProduct(double[] dArray, double[][] dArray2, double[] dArray3, int n) {
        double total = 0.0;
        int row = 0;
        while (row < n) {
            int col = 0;
            while (col < n) {
                total += dArray[row] * dArray2[row][col] * dArray3[col];
                col++;
            }
            row++;
        }
        return total;
    }

    /**
     * Writes the element-wise weighted average of two array segments into a third.
     *
     * @param dArray  first source array
     * @param n       start offset in the first source
     * @param d       weight of the first source
     * @param dArray2 second source array
     * @param n2      start offset in the second source
     * @param d2      weight of the second source
     * @param dArray3 destination array
     * @param n3      start offset in the destination
     * @param n4      number of elements to process
     */
    public static void computeWeightedAverage(double[] dArray, int n, double d, double[] dArray2, int n2, double d2, double[] dArray3, int n3, int n4) {
        int k = 0;
        while (k < n4) {
            final double weightedSum = dArray[n + k] * d + dArray2[n2 + k] * d2;
            dArray3[n3 + k] = weightedSum / (d + d2);
            k++;
        }
    }

    /**
     * Combines the corrected means of two children into their parent's mean, applying
     * branch-length shifts.
     *
     * <p>For an external child the corrected mean is first set to its mean minus the child's
     * shift. The parent's mean is the precision-weighted average of the two corrected child
     * means; the parent's corrected mean is that value minus the parent's own shift, or the
     * value itself at the root.
     *
     * @param n        offset of the first child in the flat caches
     * @param d        precision of the first child
     * @param nodeRef  first child node
     * @param n2       offset of the second child
     * @param d2       precision of the second child
     * @param nodeRef2 second child node
     * @param n3       offset of the parent
     * @param n4       number of entries to process
     * @param nodeRef3 parent node
     */
    protected void computeCorrectedWeightedAverage(int n, double d, NodeRef nodeRef, int n2, double d2, NodeRef nodeRef2, int n3, int n4, NodeRef nodeRef3) {
        final double inverseTotal = 1.0 / (d + d2);
        final boolean parentIsRoot = this.treeModel.isRoot(nodeRef3);
        // The root has no branch above it, hence no shift.
        final double[] parentShift = parentIsRoot ? null : this.getShiftForBranchLength(nodeRef3);
        final double[] shift0 = this.getShiftForBranchLength(nodeRef);
        final double[] shift1 = this.getShiftForBranchLength(nodeRef2);

        if (this.treeModel.isExternal(nodeRef)) {
            for (int k = 0; k < n4; ++k) {
                this.correctedMeanCache[n + k] = this.meanCache[n + k] - shift0[k];
            }
        }
        if (this.treeModel.isExternal(nodeRef2)) {
            for (int k = 0; k < n4; ++k) {
                this.correctedMeanCache[n2 + k] = this.meanCache[n2 + k] - shift1[k];
            }
        }

        for (int k = 0; k < n4; ++k) {
            this.meanCache[n3 + k] = (this.correctedMeanCache[n + k] * d + this.correctedMeanCache[n2 + k] * d2) * inverseTotal;
            if (parentIsRoot) {
                this.correctedMeanCache[n3 + k] = this.meanCache[n3 + k];
            } else {
                this.correctedMeanCache[n3 + k] = this.meanCache[n3 + k] - parentShift[k];
            }
        }
    }

    /**
     * OU variant of the corrected weighted average.
     *
     * <p>For an external child the corrected mean is set to
     * {@code exp(s) * mean - (exp(s) - 1) * optimum}, with {@code s} the child's time-scaled
     * selection. The parent's mean is the precision-weighted average of the corrected child
     * means; its corrected mean applies the same transform with the parent's own values, or is
     * the mean itself at the root.
     *
     * @param n        offset of the first child in the flat caches
     * @param d        precision of the first child
     * @param nodeRef  first child node
     * @param n2       offset of the second child
     * @param d2       precision of the second child
     * @param nodeRef2 second child node
     * @param n3       offset of the parent
     * @param n4       number of entries to process
     * @param nodeRef3 parent node
     */
    protected void computeCorrectedOUWeightedAverage(int n, double d, NodeRef nodeRef, int n2, double d2, NodeRef nodeRef2, int n3, int n4, NodeRef nodeRef3) {
        final double inverseTotal = 1.0 / (d + d2);
        final boolean parentIsRoot = this.treeModel.isRoot(nodeRef3);

        double[] parentOptimum = null;
        double parentSelection = 1.0;
        if (!parentIsRoot) {
            parentOptimum = this.getOptimalValue(nodeRef3);
            parentSelection = this.getTimeScaledSelection(nodeRef3);
        }
        final double[] optimum0 = this.getOptimalValue(nodeRef);
        final double[] optimum1 = this.getOptimalValue(nodeRef2);
        final double selection0 = this.getTimeScaledSelection(nodeRef);
        final double selection1 = this.getTimeScaledSelection(nodeRef2);

        if (this.treeModel.isExternal(nodeRef)) {
            for (int k = 0; k < n4; ++k) {
                this.correctedMeanCache[n + k] = Math.exp(selection0) * this.meanCache[n + k] - (Math.exp(selection0) - 1.0) * optimum0[k];
            }
        }
        if (this.treeModel.isExternal(nodeRef2)) {
            for (int k = 0; k < n4; ++k) {
                this.correctedMeanCache[n2 + k] = Math.exp(selection1) * this.meanCache[n2 + k] - (Math.exp(selection1) - 1.0) * optimum1[k];
            }
        }

        for (int k = 0; k < n4; ++k) {
            this.meanCache[n3 + k] = (this.correctedMeanCache[n + k] * d + this.correctedMeanCache[n2 + k] * d2) * inverseTotal;
            if (parentIsRoot) {
                this.correctedMeanCache[n3 + k] = this.meanCache[n3 + k];
            } else {
                this.correctedMeanCache[n3 + k] = Math.exp(parentSelection) * this.meanCache[n3 + k] - (Math.exp(parentSelection) - 1.0) * parentOptimum[k];
            }
        }
    }

    /**
     * Supplies the variance matrix used to draw the root state during pre-order sampling.
     *
     * <p>The caller passes the root mean in {@code var1} and afterwards draws from a normal with
     * mean {@code var1} and the returned variance.
     * NOTE(review): implementations presumably update {@code var1} in place to the marginal
     * mean — confirm against the subclasses.
     *
     * @param var1 root mean for the current datum
     * @param var2 diffusion precision matrix
     * @param var3 diffusion variance matrix (inverse of {@code var2})
     * @param var4 lower precision cached at the root
     * @return variance matrix for the root draw
     */
    protected abstract double[][] computeMarginalRootMeanAndVariance(double[] var1, double[][] var2, double[][] var3, double var4);

    /**
     * Fills {@code drawnStates} top-down: draws the root from its marginal, copies observed
     * values for nodes that are not missing, and draws every other node conditionally on its
     * parent's draw.
     *
     * @param mutableTreeModel tree being traversed
     * @param nodeRef          current node
     * @param n                node number of the parent (0 is passed for the root, where it is unused)
     * @param dArray           diffusion precision matrix
     * @param dArray2          diffusion variance matrix
     * @throws RuntimeException if a node is flagged as partially missing
     */
    private void preOrderTraverseSample(MutableTreeModel mutableTreeModel, NodeRef nodeRef, int n, double[][] dArray, double[][] dArray2) {
        int n2 = nodeRef.getNumber();
        if (mutableTreeModel.isRoot(nodeRef)) {
            // Root: draw each datum from the marginal supplied by the subclass.
            double[] dArray3 = new double[this.dimTrait];
            int n3 = mutableTreeModel.getRoot().getNumber();
            double d = this.lowerPrecisionCache[n3];
            for (int i = 0; i < this.numData; ++i) {
                System.arraycopy(this.cacheHelper.getMeanCache(), n2 * this.dim + i * this.dimTrait, dArray3, 0, this.dimTrait);
                double[][] dArray4 = this.computeMarginalRootMeanAndVariance(dArray3, dArray, dArray2, d);
                double[] dArray5 = MultivariateNormalDistribution.nextMultivariateNormalVariance(dArray3, dArray4);
                if (DEBUG_PREORDER) {
                    // Debug aid: replace the random draw by a vector of ones.
                    Arrays.fill(dArray5, 1.0);
                }
                System.arraycopy(dArray5, 0, this.drawnStates, n3 * this.dim + i * this.dimTrait, this.dimTrait);
                if (!DEBUG) continue;
                System.err.println("Root mean: " + new Vector(dArray3));
                System.err.println("Root var : " + new Matrix(dArray4));
                System.err.println("Root draw: " + new Vector(dArray5));
            }
        } else if (!this.missingTraits.isCompletelyMissing(n2) && !this.missingTraits.isPartiallyMissing(n2)) {
            // Node not flagged as missing: its cached mean is used as the "draw".
            System.arraycopy(this.cacheHelper.getMeanCache(), n2 * this.dim, this.drawnStates, n2 * this.dim, this.dim);
        } else {
            if (this.missingTraits.isPartiallyMissing(n2)) {
                throw new RuntimeException("Partially missing values are not yet implemented");
            }
            // d: precision contributed by the branch to the parent; d2: precision from below; d3: total.
            double d = 1.0 / this.getRescaledBranchLengthForPrecision(nodeRef);
            double d2 = this.lowerPrecisionCache[n2];
            double d3 = d2 + d;
            // Reuse the shared scratch buffers for the conditional mean and variance.
            double[] dArray6 = this.Ay;
            double[][] dArray7 = this.tmpM;
            for (int i = 0; i < this.numData; ++i) {
                double[] dArray8;
                // n4: offset of the parent's draw; n5: offset of this node's slot.
                int n4 = n * this.dim + i * this.dimTrait;
                int n5 = n2 * this.dim + i * this.dimTrait;
                if (DEBUG) {
                    dArray8 = new double[this.dimTrait];
                    System.arraycopy(this.drawnStates, n4, dArray8, 0, this.dimTrait);
                    System.err.println("Parent draw: " + new Vector(dArray8));
                    if (dArray8[0] != this.drawnStates[n4]) {
                        throw new RuntimeException("Error in setting indices");
                    }
                }
                for (int j = 0; j < this.dimTrait; ++j) {
                    // Conditional mean: precision-weighted average of (parent draw + shift) and the cached mean.
                    dArray6[j] = ((this.drawnStates[n4 + j] + this.cacheHelper.getShift(nodeRef)[j]) * d + this.cacheHelper.getMeanCache()[n5 + j] * d2) / d3;
                    // Conditional variance: diffusion variance scaled by the total precision.
                    for (int k = 0; k < this.dimTrait; ++k) {
                        dArray7[j][k] = dArray2[j][k] / d3;
                    }
                }
                dArray8 = MultivariateNormalDistribution.nextMultivariateNormalVariance(dArray6, dArray7);
                System.arraycopy(dArray8, 0, this.drawnStates, n5, this.dimTrait);
                if (!DEBUG) continue;
                System.err.println("Int prec: " + d3);
                System.err.println("Int mean: " + new Vector(dArray6));
                System.err.println("Int var : " + new Matrix(dArray7));
                System.err.println("Int draw: " + new Vector(dArray8));
                System.err.println("");
            }
        }
        // Descend into children 0 and 1 unless a subclass disabled peeling.
        if (this.peel() && !mutableTreeModel.isExternal(nodeRef)) {
            this.preOrderTraverseSample(mutableTreeModel, mutableTreeModel.getChild(nodeRef, 0), n2, dArray, dArray2);
            this.preOrderTraverseSample(mutableTreeModel, mutableTreeModel.getChild(nodeRef, 1), n2, dArray, dArray2);
        }
    }

    /**
     * Reacts to a change in one of the sub-models.
     *
     * <p>If the event comes from a drift model, an optimal-value model or the
     * strength-of-selection model, every node is flagged for update (when branches are cached)
     * or the likelihood is simply marked unknown. Every other event is delegated to the
     * superclass.
     *
     * <p>The strength-of-selection branch is taken only when the event actually originates
     * from that model. Testing only that the model is configured would swallow every unrelated
     * event as soon as an OU process is in use, so the superclass handler would never run.
     *
     * @param model  the model that changed
     * @param object event payload, passed through to the superclass
     * @param n      event index, passed through to the superclass
     */
    @Override
    protected void handleModelChangedEvent(Model model, Object object, int n) {
        final boolean fromBranchSpecificModel =
                (this.driftModels != null && this.driftModels.contains(model))
                        || (this.optimalValues != null && this.optimalValues.contains(model))
                        || (this.strengthOfSelection != null && this.strengthOfSelection.equals(model));
        if (!fromBranchSpecificModel) {
            super.handleModelChangedEvent(model, object, n);
            return;
        }
        if (this.cacheBranches) {
            this.updateAllNodes();
        } else {
            this.likelihoodKnown = false;
        }
    }

    /**
     * Controls whether the pre-order sampler descends into child nodes; always {@code true}
     * here, overridable by subclasses.
     */
    protected boolean peel() {
        return true;
    }

    /** Returns a single log column reporting this likelihood, labelled with this object's id. */
    @Override
    public LogColumn[] getColumns() {
        return new LogColumn[]{new AbstractModelLikelihood.LikelihoodColumn(this, this.getId())};
    }

    /**
     * Creates the cache helper matching a diffusion type: the plain helper for {@code PLAIN}
     * and {@code SCALED}, the drift helper for {@code DRIFT} and the OU helper for {@code OU}.
     *
     * @param integratedDiffusionType diffusion type to build a helper for
     * @param n                       cache length passed to the helper
     * @param bl                      flag passed to the helper
     * @return the new helper
     */
    private CacheHelper createCacheHelper(IntegratedDiffusionType integratedDiffusionType, int n, boolean bl) {
        switch (integratedDiffusionType) {
            case PLAIN:
            case SCALED:
                return new CacheHelper(n, bl);
            case DRIFT:
                return new DriftCacheHelper(n, bl);
            case OU:
                return new OUCacheHelper(n, bl);
            default:
                return null;
        }
    }

    /**
     * Registers a set of restricted partials (a "clamp").
     * <p>
     * The clamp map is created lazily on the first call. The partials are stored under their
     * tip bit set, added to this likelihood as a sub-model, and a notice is printed to
     * standard error.
     *
     * @param restrictedPartials the partials to register
     */
    @Override
    protected void addRestrictedPartials(RestrictedPartials restrictedPartials) {
        if (clampList == null) {
            clampList = new HashMap<BitSet, RestrictedPartials>();
        }
        final BitSet tipSet = restrictedPartials.getTipBitSet();
        clampList.put(tipSet, restrictedPartials);
        addModel(restrictedPartials);
        System.err.println("Added a CLAMP!");
    }

    /**
     * Diffusion variants supported by this likelihood. {@code createCacheHelper} switches on
     * this value to pick the cache-helper implementation noted on each constant.
     */
    public static enum IntegratedDiffusionType {
        /** Uses the base {@code CacheHelper}. */
        PLAIN,
        /** Uses the base {@code CacheHelper} (same helper as {@code PLAIN}). */
        SCALED,
        /** Uses {@code DriftCacheHelper}. */
        DRIFT,
        /** Uses {@code OUCacheHelper}. */
        OU;

    }

    /**
     * Cache helper used for the {@code OU} diffusion type (presumably Ornstein-Uhlenbeck --
     * confirm against the model documentation).
     * <p>
     * In addition to the base mean cache it allocates the enclosing likelihood's
     * {@code correctedMeanCache}, and it derives its per-branch factors from
     * {@code strengthOfSelection} and {@code getTimeScaledSelection}.
     */
    class OUCacheHelper
    extends CacheHelper {
        /**
         * @param n  length of the mean cache and of the corrected mean cache
         * @param bl branch-caching flag, forwarded to the base helper
         */
        public OUCacheHelper(int n, boolean bl) {
            super(n, bl);
            final IntegratedMultivariateTraitLikelihood outer = IntegratedMultivariateTraitLikelihood.this;
            outer.correctedMeanCache = new double[n];
        }

        /** @return the enclosing likelihood's dedicated corrected-mean cache */
        @Override
        public double[] getCorrectedMeanCache() {
            final IntegratedMultivariateTraitLikelihood outer = IntegratedMultivariateTraitLikelihood.this;
            return outer.correctedMeanCache;
        }

        /** @return {@code exp(-s)}, where {@code s} is the time-scaled selection of the node */
        @Override
        public double getOUFactor(NodeRef nodeRef) {
            final double scaledSelection = IntegratedMultivariateTraitLikelihood.this.getTimeScaledSelection(nodeRef);
            return Math.exp(-scaledSelection);
        }

        /**
         * @return {@code 2a / (1 - exp(-2s))}, where {@code a} is the selection-strength branch
         *         rate of the node and {@code s} its time-scaled selection
         */
        @Override
        public double getUpperPrecFactor(NodeRef nodeRef) {
            final IntegratedMultivariateTraitLikelihood outer = IntegratedMultivariateTraitLikelihood.this;
            final double selection = outer.strengthOfSelection.getBranchRate(outer.treeModel, nodeRef);
            final double decay = Math.exp(-2.0 * outer.getTimeScaledSelection(nodeRef));
            return 2.0 * selection / (1.0 - decay);
        }

        /**
         * Copies the first {@code n} entries of {@code dArray} into the mean cache, starting at
         * offset {@code n * n2}. The node argument is not used.
         */
        @Override
        public void setTipMeans(double[] dArray, int n, int n2, NodeRef nodeRef) {
            final double[] target = IntegratedMultivariateTraitLikelihood.this.meanCache;
            System.arraycopy(dArray, 0, target, n * n2, n);
        }

        /**
         * Fills the mean-cache slice starting at offset {@code n}.
         * <p>
         * When {@code d} is exactly zero the slice is overwritten with the zero vector; otherwise
         * the work is delegated to {@code computeCorrectedOUWeightedAverage}. The
         * {@code missingTraits} argument is not used by this helper.
         * (Presumably {@code d} is a total precision and {@code n2}/{@code n3} are the child
         * offsets -- confirm against the caller.)
         */
        @Override
        public void computeMeanCaches(int n, int n2, int n3, double d, double d2, double d3, MissingTraits missingTraits, NodeRef nodeRef, NodeRef nodeRef2, NodeRef nodeRef3) {
            final IntegratedMultivariateTraitLikelihood outer = IntegratedMultivariateTraitLikelihood.this;
            if (d != 0.0) {
                outer.computeCorrectedOUWeightedAverage(n2, d2, nodeRef2, n3, d3, nodeRef3, n, outer.dim, nodeRef);
                return;
            }
            System.arraycopy(outer.zeroDimVector, 0, outer.meanCache, n, outer.dim);
        }
    }

    /**
     * Cache helper used for the {@code DRIFT} diffusion type.
     * <p>
     * It allocates the enclosing likelihood's {@code correctedMeanCache}, obtains per-branch
     * shifts from {@code getShiftForBranchLength}, and computes internal-node means through
     * {@code computeCorrectedWeightedAverage}.
     */
    class DriftCacheHelper
    extends CacheHelper {
        /**
         * @param n  length of the mean cache and of the corrected mean cache
         * @param bl branch-caching flag, forwarded to the base helper
         */
        public DriftCacheHelper(int n, boolean bl) {
            super(n, bl);
            final IntegratedMultivariateTraitLikelihood outer = IntegratedMultivariateTraitLikelihood.this;
            outer.correctedMeanCache = new double[n];
        }

        /** @return the shift reported by the enclosing likelihood for the node's branch */
        @Override
        public double[] getShift(NodeRef nodeRef) {
            final double[] shift = IntegratedMultivariateTraitLikelihood.this.getShiftForBranchLength(nodeRef);
            return shift;
        }

        /** @return the enclosing likelihood's dedicated corrected-mean cache */
        @Override
        public double[] getCorrectedMeanCache() {
            final IntegratedMultivariateTraitLikelihood outer = IntegratedMultivariateTraitLikelihood.this;
            return outer.correctedMeanCache;
        }

        /** @return always {@code 1.0}; the node is ignored */
        @Override
        public double getOUFactor(NodeRef nodeRef) {
            return 1.0;
        }

        /** @return the reciprocal of the node's rescaled branch length for precision */
        @Override
        public double getUpperPrecFactor(NodeRef nodeRef) {
            final double rescaledLength = IntegratedMultivariateTraitLikelihood.this.getRescaledBranchLengthForPrecision(nodeRef);
            return 1.0 / rescaledLength;
        }

        /**
         * Copies the first {@code n} entries of {@code dArray} into the mean cache, starting at
         * offset {@code n * n2}. The node argument is not used.
         */
        @Override
        public void setTipMeans(double[] dArray, int n, int n2, NodeRef nodeRef) {
            final double[] target = IntegratedMultivariateTraitLikelihood.this.meanCache;
            System.arraycopy(dArray, 0, target, n * n2, n);
        }

        /**
         * Fills the mean-cache slice starting at offset {@code n}.
         * <p>
         * When {@code d} is exactly zero the slice is overwritten with the zero vector; otherwise
         * the work is delegated to {@code computeCorrectedWeightedAverage}. The
         * {@code missingTraits} argument is not used by this helper.
         * (Presumably {@code d} is a total precision and {@code n2}/{@code n3} are the child
         * offsets -- confirm against the caller.)
         */
        @Override
        public void computeMeanCaches(int n, int n2, int n3, double d, double d2, double d3, MissingTraits missingTraits, NodeRef nodeRef, NodeRef nodeRef2, NodeRef nodeRef3) {
            final IntegratedMultivariateTraitLikelihood outer = IntegratedMultivariateTraitLikelihood.this;
            if (d != 0.0) {
                outer.computeCorrectedWeightedAverage(n2, d2, nodeRef2, n3, d3, nodeRef3, n, outer.dim, nodeRef);
                return;
            }
            System.arraycopy(outer.zeroDimVector, 0, outer.meanCache, n, outer.dim);
        }
    }

    /**
     * Cache helper whose mean cache is sized as {@code dim * nodeCount} and whose tip means are
     * written element by element through {@link #setMeanCache(int, double)}.
     * <p>
     * NOTE(review): despite the name, values are currently stored unchanged -- no
     * standardisation is applied anywhere in this class.
     */
    class StandarizedCacheHelper
    extends CacheHelper {
        /** First constructor argument (per-node dimension used to size the cache). */
        private final int dim;
        /** Second constructor argument (node count used to size the cache). */
        private final int nodeCount;

        /**
         * @param n  per-node dimension
         * @param n2 node count; the mean cache is allocated with {@code n * n2} entries
         * @param bl branch-caching flag, forwarded to the base helper
         */
        public StandarizedCacheHelper(int n, int n2, boolean bl) {
            super(n * n2, bl);
            this.dim = n;
            this.nodeCount = n2;
        }

        /**
         * Writes the first {@code n} entries of {@code dArray} into the mean cache at offset
         * {@code n * n2}, one element at a time via {@code setMeanCache}. The node is not used.
         */
        @Override
        public void setTipMeans(double[] dArray, int n, int n2, NodeRef nodeRef) {
            this.writeTipMeans(dArray, n, n2);
        }

        /**
         * Writes the first {@code n} entries of {@code dArray} into the mean cache at offset
         * {@code n * n2}, one element at a time via {@code setMeanCache}.
         */
        @Override
        public void setTipMeans(double[] dArray, int n, int n2) {
            this.writeTipMeans(dArray, n, n2);
        }

        /** Shared element-wise copy behind both {@code setTipMeans} overloads. */
        private void writeTipMeans(double[] values, int length, int tipIndex) {
            final int offset = length * tipIndex;
            for (int i = 0; i < length; ++i) {
                this.setMeanCache(offset + i, values[i]);
            }
        }

        /**
         * Stores {@code d} at position {@code n} of the mean cache.
         * <p>
         * The previous implementation also evaluated {@code n % dim} into an unused local; that
         * dead computation has been dropped (its only observable effect was an
         * {@code ArithmeticException} when {@code dim} is zero).
         */
        @Override
        public void setMeanCache(int n, double d) {
            IntegratedMultivariateTraitLikelihood.this.meanCache[n] = d;
        }
    }

    /**
     * Base helper that owns the enclosing likelihood's {@code meanCache} (and, when branch
     * caching is enabled, its {@code storedMeanCache}) and provides the default per-branch
     * factors. Subclasses override individual hooks for other diffusion types.
     */
    class CacheHelper {
        /** Branch-caching flag supplied at construction. */
        protected boolean cacheBranches;

        /**
         * Allocates the mean cache, and the stored copy as well when {@code bl} is true.
         *
         * @param n  length of the cache(s)
         * @param bl whether branches are cached
         */
        public CacheHelper(int n, boolean bl) {
            final IntegratedMultivariateTraitLikelihood outer = IntegratedMultivariateTraitLikelihood.this;
            outer.meanCache = new double[n];
            this.cacheBranches = bl;
            if (!bl) {
                return;
            }
            outer.storedMeanCache = new double[n];
        }

        /**
         * @return a fresh array of length {@code dimTrait * numData} whose first {@code dim}
         *         entries are explicitly set to zero; the node is ignored
         */
        public double[] getShift(NodeRef nodeRef) {
            final IntegratedMultivariateTraitLikelihood outer = IntegratedMultivariateTraitLikelihood.this;
            final double[] shift = new double[outer.dimTrait * outer.numData];
            for (int k = 0; k < outer.dim; ++k) {
                shift[k] = 0.0;
            }
            return shift;
        }

        /** @return the enclosing likelihood's mean cache (the live array, not a copy) */
        public double[] getMeanCache() {
            final IntegratedMultivariateTraitLikelihood outer = IntegratedMultivariateTraitLikelihood.this;
            return outer.meanCache;
        }

        /** @return the mean cache itself; this base helper keeps no separate corrected cache */
        public double[] getCorrectedMeanCache() {
            final IntegratedMultivariateTraitLikelihood outer = IntegratedMultivariateTraitLikelihood.this;
            return outer.meanCache;
        }

        /**
         * Copies the mean cache into the stored cache, reallocating the stored cache first if
         * the two lengths differ. Requires the stored cache to exist already (it is only
         * allocated by the constructor when branch caching is on).
         */
        public void store() {
            final IntegratedMultivariateTraitLikelihood outer = IntegratedMultivariateTraitLikelihood.this;
            final int length = outer.meanCache.length;
            if (outer.storedMeanCache.length != length) {
                outer.storedMeanCache = new double[length];
            }
            System.arraycopy(outer.meanCache, 0, outer.storedMeanCache, 0, length);
        }

        /** Swaps the mean cache and the stored cache references; nothing is copied. */
        public void restore() {
            final IntegratedMultivariateTraitLikelihood outer = IntegratedMultivariateTraitLikelihood.this;
            final double[] current = outer.meanCache;
            outer.meanCache = outer.storedMeanCache;
            outer.storedMeanCache = current;
        }

        /** @return always {@code 1.0}; the node is ignored */
        public double getOUFactor(NodeRef nodeRef) {
            return 1.0;
        }

        /** @return the reciprocal of the node's rescaled branch length for precision */
        public double getUpperPrecFactor(NodeRef nodeRef) {
            final double rescaledLength = IntegratedMultivariateTraitLikelihood.this.getRescaledBranchLengthForPrecision(nodeRef);
            return 1.0 / rescaledLength;
        }

        /**
         * Fills the mean-cache slice starting at offset {@code n}.
         * <p>
         * When {@code d} is exactly zero the slice is overwritten with the zero vector; otherwise
         * {@code missingTraits.computeWeightedAverage} writes into the mean cache. The three node
         * arguments are not used by this base helper.
         * (Presumably {@code d} is a total precision and {@code n2}/{@code n3} are the child
         * offsets -- confirm against the caller.)
         */
        public void computeMeanCaches(int n, int n2, int n3, double d, double d2, double d3, MissingTraits missingTraits, NodeRef nodeRef, NodeRef nodeRef2, NodeRef nodeRef3) {
            final IntegratedMultivariateTraitLikelihood outer = IntegratedMultivariateTraitLikelihood.this;
            if (d != 0.0) {
                missingTraits.computeWeightedAverage(outer.meanCache, n2, d2, n3, d3, n, outer.dim);
                return;
            }
            System.arraycopy(outer.zeroDimVector, 0, outer.meanCache, n, outer.dim);
        }

        /**
         * Copies the first {@code n} entries of {@code dArray} into the mean cache, starting at
         * offset {@code n * n2}. The node argument is not used.
         */
        public void setTipMeans(double[] dArray, int n, int n2, NodeRef nodeRef) {
            final double[] target = IntegratedMultivariateTraitLikelihood.this.meanCache;
            System.arraycopy(dArray, 0, target, n * n2, n);
        }

        /**
         * Copies the first {@code n} entries of {@code dArray} into the mean cache, starting at
         * offset {@code n * n2}.
         */
        public void setTipMeans(double[] dArray, int n, int n2) {
            final double[] target = IntegratedMultivariateTraitLikelihood.this.meanCache;
            System.arraycopy(dArray, 0, target, n * n2, n);
        }

        /** Copies {@code n2} entries from the start of {@code dArray} into the mean cache at offset {@code n}. */
        public void copyToMeanCache(double[] dArray, int n, int n2) {
            final double[] target = IntegratedMultivariateTraitLikelihood.this.meanCache;
            System.arraycopy(dArray, 0, target, n, n2);
        }

        /** Stores {@code d} at position {@code n} of the mean cache. */
        public void setMeanCache(int n, double d) {
            final double[] target = IntegratedMultivariateTraitLikelihood.this.meanCache;
            target[n] = d;
        }
    }
}

