/*
 * Decompiled with CFR 0.152.
 */
package core.model;

import core.lattice.Lattice;
import core.lattice.LatticeNode;
import core.model.DecomposableModel;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.eclipse.recommenders.jayes.BayesNet;
import org.eclipse.recommenders.jayes.BayesNode;
import org.eclipse.recommenders.jayes.inference.IBayesInferer;
import org.eclipse.recommenders.jayes.inference.jtree.JunctionTreeAlgorithm;
import org.jgrapht.experimental.dag.DirectedAcyclicGraph;
import org.jgrapht.graph.DefaultEdge;

public class Inference {
    private DirectedAcyclicGraph<Integer, DefaultEdge> bn;
    private BayesNet jbn;
    private Map<Integer, BayesNode> jnodes;
    private Map<Integer, String> nodeNames;
    private Map<String, Integer> nodeIDFromName;
    private Map<BayesNode, Integer> nodesNumber;
    private IBayesInferer inferer;
    Map<BayesNode, String> evidence;

    public Inference(DecomposableModel model, String[] variableNames, String[][] outcomes) {
        try {
            this.bn = model.getBayesianNetwork();
        }
        catch (DirectedAcyclicGraph.CycleFoundException e) {
            e.printStackTrace();
        }
        this.jnodes = new HashMap<Integer, BayesNode>();
        this.nodesNumber = new HashMap<BayesNode, Integer>();
        this.nodeNames = new HashMap<Integer, String>();
        this.nodeIDFromName = new HashMap<String, Integer>();
        this.jbn = new BayesNet();
        for (Integer nodeID : this.bn.vertexSet()) {
            String name = variableNames[nodeID];
            BayesNode node = this.jbn.createNode(name);
            node.addOutcomes(outcomes[nodeID]);
            this.nodeNames.put(nodeID, name);
            this.nodeIDFromName.put(name, nodeID);
            this.nodesNumber.put(node, nodeID);
            this.jnodes.put(nodeID, node);
        }
        for (Integer nodeID : this.bn.vertexSet()) {
            BayesNode node = this.jnodes.get(nodeID);
            ArrayList<BayesNode> parents = new ArrayList<BayesNode>();
            for (DefaultEdge e : this.bn.edgesOf((Object)nodeID)) {
                if (this.bn.getEdgeTarget((Object)e) != nodeID) continue;
                BayesNode oneParent = this.jnodes.get(this.bn.getEdgeSource((Object)e));
                parents.add(oneParent);
            }
            if (parents.isEmpty()) continue;
            node.setParents(parents);
        }
    }

    public void setProbabilities(Lattice lattice) {
        for (Integer nodeID : this.jnodes.keySet()) {
            BayesNode n = this.jnodes.get(nodeID);
            List parents = n.getParents();
            ArrayList<BayesNode> parentsAndChild = new ArrayList<BayesNode>(parents);
            parentsAndChild.add(n);
            int nbParents = parents.size();
            BitSet numbers = new BitSet();
            numbers.set(nodeID);
            int[] sizes = new int[nbParents];
            int nbRowsInCPT = 1;
            int i = 0;
            while (i < parents.size()) {
                BayesNode parent = (BayesNode)parents.get(i);
                numbers.set(this.nodesNumber.get(parent));
                sizes[i] = ((BayesNode)parents.get(i)).getOutcomeCount();
                nbRowsInCPT *= sizes[i];
                ++i;
            }
            LatticeNode latticeNode = lattice.getNode(numbers);
            HashMap<Integer, Integer> fromNodeIDToPositionInSortedTable = new HashMap<Integer, Integer>();
            Integer[] variablesNumbers = new Integer[numbers.cardinality()];
            int current = 0;
            int i2 = numbers.nextSetBit(0);
            while (i2 >= 0) {
                variablesNumbers[current] = i2;
                ++current;
                i2 = numbers.nextSetBit(i2 + 1);
            }
            i2 = 0;
            while (i2 < variablesNumbers.length) {
                fromNodeIDToPositionInSortedTable.put(variablesNumbers[i2], i2);
                ++i2;
            }
            int[] counts = new int[nbRowsInCPT * n.getOutcomeCount()];
            int[] indexes4lattice = new int[parentsAndChild.size()];
            int[] indexes4Jayes = new int[parentsAndChild.size()];
            int c = 0;
            while (c < counts.length) {
                int count;
                int index = c;
                int i3 = indexes4Jayes.length - 1;
                while (i3 > 0) {
                    BayesNode associatedNode = (BayesNode)parentsAndChild.get(i3);
                    int dim = associatedNode.getOutcomeCount();
                    indexes4Jayes[i3] = index % dim;
                    index /= dim;
                    --i3;
                }
                indexes4Jayes[0] = index;
                i3 = 0;
                while (i3 < indexes4Jayes.length) {
                    BayesNode nodeInPositionI = (BayesNode)parentsAndChild.get(i3);
                    int nodeInPositionIID = this.nodesNumber.get(nodeInPositionI);
                    int indexInSortedTable = (Integer)fromNodeIDToPositionInSortedTable.get(nodeInPositionIID);
                    indexes4lattice[indexInSortedTable] = indexes4Jayes[i3];
                    ++i3;
                }
                counts[c] = count = latticeNode.getMatrixCell(indexes4lattice);
                ++c;
            }
            double mTerm = 0.5;
            double[] probas1D = new double[n.getOutcomeCount() * nbRowsInCPT];
            int s = 0;
            while (s < probas1D.length) {
                double sumOfCounts = 0.0;
                int j = 0;
                while (j < n.getOutcomeCount()) {
                    sumOfCounts += (double)counts[s + j] + mTerm;
                    ++j;
                }
                j = 0;
                while (j < n.getOutcomeCount()) {
                    probas1D[s + j] = ((double)counts[s + j] + mTerm) / sumOfCounts;
                    ++j;
                }
                s += n.getOutcomeCount();
            }
            n.setProbabilities(probas1D);
        }
        System.out.println("Compiling network for inference...");
        this.inferer = new JunctionTreeAlgorithm();
        this.inferer.setNetwork(this.jbn);
        this.evidence = new HashMap<BayesNode, String>();
        System.out.println("Compiled.");
    }

    public void addEvidence(int nodeID, String outcome) {
        this.addEvidence(this.jnodes.get(nodeID), outcome);
    }

    public void addEvidence(String nodeName, String outcome) {
        this.addEvidence(this.jnodes.get(this.nodeIDFromName.get(nodeName)), outcome);
    }

    protected void addEvidence(BayesNode node, String outcome) {
        this.evidence.put(node, outcome);
    }

    public void recordEvidence() {
        this.inferer.setEvidence(this.evidence);
    }

    public void clearEvidences() {
        this.evidence = new HashMap<BayesNode, String>();
        this.inferer.setEvidence(this.evidence);
    }

    public double[] getBelief(int nodeID) {
        return this.getBelief(this.jnodes.get(nodeID));
    }

    public double[] getBelief(BayesNode n) {
        return this.inferer.getBeliefs(n);
    }

    public double[] getBelief(String nodeName) {
        Integer nodeID = this.nodeIDFromName.get(nodeName);
        if (nodeID == null) {
            System.err.println("Cannot find a node named '" + nodeName + "'.");
            return null;
        }
        BayesNode node = this.jnodes.get(nodeID);
        if (node == null) {
            System.err.println("Cannot find a node named '" + nodeName + "'.");
            return null;
        }
        return this.getBelief(node);
    }

    public void exportDSC(File file, Lattice lattice) throws FileNotFoundException {
        ArrayList<BayesNode> parentsAndChild;
        List parents;
        BayesNode n;
        PrintWriter out = new PrintWriter(new BufferedOutputStream(new FileOutputStream(file)));
        out.println("belief network \"net\"");
        for (Integer nodeID : this.jnodes.keySet()) {
            n = this.jnodes.get(nodeID);
            out.print("node " + n.getName() + " {\n" + "\ttype : discrete [ " + n.getOutcomeCount() + " ] = { ");
            out.print("\"" + n.getOutcomeName(0) + "\"");
            int i = 1;
            while (i < n.getOutcomeCount()) {
                out.print(",\"" + n.getOutcomeName(i) + "\"");
                ++i;
            }
            out.println("},\n}");
            parents = n.getParents();
            parentsAndChild = new ArrayList(parents);
            parentsAndChild.add(n);
        }
        for (Integer nodeID : this.jnodes.keySet()) {
            n = this.jnodes.get(nodeID);
            parents = n.getParents();
            parentsAndChild = new ArrayList<BayesNode>(parents);
            parentsAndChild.add(n);
            out.print("probability ( " + n.getName());
            if (!parents.isEmpty()) {
                out.print(" | " + ((BayesNode)parents.get(0)).getName());
                int p = 1;
                while (p < parents.size()) {
                    out.print(" , " + ((BayesNode)parents.get(p)).getName());
                    ++p;
                }
            }
            out.println(" ) {");
            int nbParents = parents.size();
            BitSet numbers = new BitSet();
            numbers.set(nodeID);
            int[] sizes = new int[nbParents];
            int nbRowsInCPT = 1;
            int i = 0;
            while (i < parents.size()) {
                BayesNode parent = (BayesNode)parents.get(i);
                numbers.set(this.nodesNumber.get(parent));
                sizes[i] = ((BayesNode)parents.get(i)).getOutcomeCount();
                nbRowsInCPT *= sizes[i];
                ++i;
            }
            LatticeNode latticeNode = lattice.getNode(numbers);
            HashMap<Integer, Integer> fromNodeIDToPositionInSortedTable = new HashMap<Integer, Integer>();
            Integer[] variablesNumbers = new Integer[numbers.cardinality()];
            int current = 0;
            int i2 = numbers.nextSetBit(0);
            while (i2 >= 0) {
                variablesNumbers[current] = i2;
                ++current;
                i2 = numbers.nextSetBit(i2 + 1);
            }
            i2 = 0;
            while (i2 < variablesNumbers.length) {
                fromNodeIDToPositionInSortedTable.put(variablesNumbers[i2], i2);
                ++i2;
            }
            int[] counts = new int[nbRowsInCPT * n.getOutcomeCount()];
            int[] indexes4lattice = new int[parentsAndChild.size()];
            int[] indexes4Jayes = new int[parentsAndChild.size()];
            int c = 0;
            while (c < counts.length) {
                int count;
                int index = c;
                int i3 = indexes4Jayes.length - 1;
                while (i3 > 0) {
                    BayesNode associatedNode = (BayesNode)parentsAndChild.get(i3);
                    int dim = associatedNode.getOutcomeCount();
                    indexes4Jayes[i3] = index % dim;
                    index /= dim;
                    --i3;
                }
                indexes4Jayes[0] = index;
                i3 = 0;
                while (i3 < indexes4Jayes.length) {
                    BayesNode nodeInPositionI = (BayesNode)parentsAndChild.get(i3);
                    int nodeInPositionIID = this.nodesNumber.get(nodeInPositionI);
                    int indexInSortedTable = (Integer)fromNodeIDToPositionInSortedTable.get(nodeInPositionIID);
                    indexes4lattice[indexInSortedTable] = indexes4Jayes[i3];
                    ++i3;
                }
                counts[c] = count = latticeNode.getMatrixCell(indexes4lattice);
                ++c;
            }
            double mTerm = 0.5;
            if (parents.isEmpty()) {
                double sumOfCounts = 0.0;
                int j = 0;
                while (j < n.getOutcomeCount()) {
                    sumOfCounts += (double)counts[j] + mTerm;
                    ++j;
                }
                double p = ((double)counts[0] + mTerm) / sumOfCounts;
                out.print("\t " + p);
                int j2 = 1;
                while (j2 < n.getOutcomeCount()) {
                    p = ((double)counts[j2] + mTerm) / sumOfCounts;
                    out.print(", " + p);
                    ++j2;
                }
                out.println(";");
            } else {
                int[] indexes4Parents = new int[parents.size()];
                int r = 0;
                while (r < nbRowsInCPT) {
                    int index = r;
                    int i4 = indexes4Parents.length - 1;
                    while (i4 > 0) {
                        BayesNode associatedNode = (BayesNode)parents.get(i4);
                        int dim = associatedNode.getOutcomeCount();
                        indexes4Parents[i4] = index % dim;
                        index /= dim;
                        --i4;
                    }
                    indexes4Parents[0] = index;
                    out.print("\t(" + indexes4Parents[0]);
                    int p = 1;
                    while (p < indexes4Parents.length) {
                        out.print("," + indexes4Parents[p]);
                        ++p;
                    }
                    out.print("): ");
                    double sumOfCounts = 0.0;
                    int j = 0;
                    while (j < n.getOutcomeCount()) {
                        sumOfCounts += (double)counts[r * n.getOutcomeCount() + j] + mTerm;
                        ++j;
                    }
                    double p2 = ((double)counts[r * n.getOutcomeCount() + 0] + mTerm) / sumOfCounts;
                    out.print(p2);
                    int j3 = 1;
                    while (j3 < n.getOutcomeCount()) {
                        p2 = ((double)counts[r * n.getOutcomeCount() + j3] + mTerm) / sumOfCounts;
                        out.print(", " + p2);
                        ++j3;
                    }
                    out.println(";");
                    ++r;
                }
            }
            out.println("}");
        }
        out.flush();
        out.close();
    }
}

