/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.algorithms.mcmc.nuts;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.algorithms.ProbabilisticModelWithGradient;
import io.improbable.keanu.algorithms.Statistics;
import io.improbable.keanu.algorithms.Variable;
import io.improbable.keanu.algorithms.VariableReference;
import io.improbable.keanu.algorithms.mcmc.SamplingAlgorithm;
import io.improbable.keanu.algorithms.mcmc.nuts.LeapfrogIntegrator;
import io.improbable.keanu.algorithms.mcmc.nuts.LeapfrogState;
import io.improbable.keanu.algorithms.mcmc.nuts.NUTS;
import io.improbable.keanu.algorithms.mcmc.nuts.Proposal;
import io.improbable.keanu.algorithms.mcmc.nuts.VariableValues;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import java.util.List;
import java.util.Map;
import java.util.Objects;

class Tree {
    private final ProbabilisticModelWithGradient logProbGradientCalculator;
    private final LeapfrogIntegrator leapfrogIntegrator;
    private final List<? extends Variable> sampleFromVariables;
    private final KeanuRandom random;
    private LeapfrogState forward;
    private LeapfrogState backward;
    private Proposal proposal;
    private Map<VariableReference, DoubleTensor> sumMomentum;
    private final double startEnergy;
    private final double maxEnergyChange;
    private double logSumWeight;
    private double sumMetropolisAcceptanceProbability;
    private int treeSize;
    private int treeHeight;
    private boolean diverged;
    private boolean uTurned;

    public Tree(LeapfrogState startState, Proposal proposal, double maxEnergyChange, ProbabilisticModelWithGradient logProbGradientCalculator, LeapfrogIntegrator leapfrogIntegrator, List<? extends Variable> sampleFromVariables, KeanuRandom random) {
        this.forward = startState;
        this.backward = startState;
        this.proposal = proposal;
        this.maxEnergyChange = maxEnergyChange;
        this.logProbGradientCalculator = logProbGradientCalculator;
        this.leapfrogIntegrator = leapfrogIntegrator;
        this.sampleFromVariables = sampleFromVariables;
        this.random = random;
        this.sumMomentum = startState.getMomentum();
        this.startEnergy = startState.getEnergy();
        this.logSumWeight = 0.0;
        this.sumMetropolisAcceptanceProbability = 0.0;
        this.treeSize = 0;
        this.treeHeight = 0;
        this.diverged = false;
        this.uTurned = false;
    }

    public void grow(int buildDirection, double timeStep) {
        SubTree otherHalfTree = this.buildTree(buildDirection == -1 ? this.backward : this.forward, buildDirection, this.treeHeight, timeStep);
        if (buildDirection == -1) {
            this.backward = otherHalfTree.backward;
        } else {
            this.forward = otherHalfTree.forward;
        }
        this.sumMetropolisAcceptanceProbability += otherHalfTree.sumMetropolisAcceptanceProbability;
        this.treeSize += otherHalfTree.treeSize;
        if (otherHalfTree.shouldContinue()) {
            if (Tree.acceptOtherProposalWithProbability(otherHalfTree.getLogSumWeight() - this.logSumWeight, this.random)) {
                this.proposal = otherHalfTree.proposal;
            }
            this.logSumWeight = Tree.logSumExp(this.logSumWeight, otherHalfTree.logSumWeight);
            this.sumMomentum = VariableValues.add(this.sumMomentum, otherHalfTree.sumMomentum);
        }
        this.diverged = otherHalfTree.diverged;
        if (!this.diverged) {
            this.uTurned = otherHalfTree.uTurned || Tree.isUTurning(this.forward.getVelocity(), this.backward.getVelocity(), this.sumMomentum);
        }
        ++this.treeHeight;
    }

    private SubTree buildTree(LeapfrogState buildFrom, int buildDirection, int treeHeight, double timeStep) {
        if (treeHeight == 0) {
            return this.treeBuilderBaseCase(buildFrom, buildDirection, timeStep);
        }
        SubTree subTree = this.buildTree(buildFrom, buildDirection, treeHeight - 1, timeStep);
        if (subTree.shouldContinue()) {
            SubTree extendedSubTree = this.buildTree(buildDirection == -1 ? subTree.backward : subTree.forward, buildDirection, treeHeight - 1, timeStep);
            if (buildDirection == -1) {
                subTree.backward = extendedSubTree.backward;
            } else {
                subTree.forward = extendedSubTree.forward;
            }
            subTree.diverged = extendedSubTree.diverged;
            subTree.uTurned = extendedSubTree.uTurned;
            if (extendedSubTree.shouldContinue()) {
                subTree.sumMomentum = VariableValues.add(subTree.sumMomentum, extendedSubTree.sumMomentum);
                subTree.uTurned = Tree.isUTurning(subTree.forward.getVelocity(), subTree.backward.getVelocity(), subTree.sumMomentum);
                double totalLogSumWeight = Tree.logSumExp(subTree.logSumWeight, extendedSubTree.logSumWeight);
                subTree.logSumWeight = totalLogSumWeight;
                if (Tree.acceptOtherProposalWithProbability(extendedSubTree.logSumWeight - totalLogSumWeight, this.random)) {
                    subTree.proposal = extendedSubTree.proposal;
                }
            }
            SubTree subTree2 = subTree;
            subTree2.sumMetropolisAcceptanceProbability = subTree2.sumMetropolisAcceptanceProbability + extendedSubTree.sumMetropolisAcceptanceProbability;
            subTree2 = subTree;
            subTree2.treeSize = subTree2.treeSize + extendedSubTree.treeSize;
        }
        return subTree;
    }

    public static double logSumExp(double a, double b) {
        double max = Math.max(a, b);
        return max + Math.log(Math.exp(a - max) + Math.exp(b - max));
    }

    private SubTree treeBuilderBaseCase(LeapfrogState leapfrogState, int buildDirection, double timeStep) {
        boolean isDivergent;
        LeapfrogState leapfrogStateAfterStep = this.leapfrogIntegrator.step(leapfrogState, this.logProbGradientCalculator, timeStep * (double)buildDirection);
        double energyAfterStep = leapfrogStateAfterStep.getEnergy();
        double energyChange = energyAfterStep - this.startEnergy;
        boolean bl = isDivergent = Math.abs(energyChange) >= this.maxEnergyChange || Double.isNaN(energyChange);
        if (isDivergent) {
            return new SubTree(leapfrogStateAfterStep, leapfrogStateAfterStep, leapfrogStateAfterStep.getMomentum(), null, Double.NEGATIVE_INFINITY, true, false, 0.0, 1);
        }
        double logSumWeight = -energyChange;
        double metropolisAcceptanceProbability = Math.min(1.0, Math.exp(logSumWeight));
        Map<VariableReference, ?> sample = SamplingAlgorithm.takeSample(this.sampleFromVariables);
        Proposal proposal = new Proposal(leapfrogStateAfterStep.getPosition(), leapfrogStateAfterStep.getGradient(), sample, leapfrogStateAfterStep.getLogProb());
        return new SubTree(leapfrogStateAfterStep, leapfrogStateAfterStep, leapfrogStateAfterStep.getMomentum(), proposal, logSumWeight, false, false, metropolisAcceptanceProbability, 1);
    }

    private static boolean acceptOtherProposalWithProbability(double probability, KeanuRandom random) {
        return Math.log(random.nextDouble()) < probability;
    }

    private static boolean isUTurning(Map<VariableReference, DoubleTensor> velocityForward, Map<VariableReference, DoubleTensor> velocityBackward, Map<VariableReference, DoubleTensor> rho) {
        double forward = 0.0;
        double backward = 0.0;
        for (VariableReference latentId : velocityForward.keySet()) {
            DoubleTensor vForward = velocityForward.get(latentId);
            DoubleTensor vBackward = velocityBackward.get(latentId);
            DoubleTensor rhoForLatent = rho.get(latentId);
            forward += ((Double)vForward.times(rhoForLatent).sum()).doubleValue();
            backward += ((Double)vBackward.times(rhoForLatent).sum()).doubleValue();
        }
        return forward < 0.0 || backward < 0.0;
    }

    public boolean shouldContinue() {
        return !this.diverged && !this.uTurned;
    }

    public void save(Statistics statistics) {
        statistics.store(NUTS.Metrics.LOG_PROB, this.proposal.getLogProb());
        statistics.store(NUTS.Metrics.TREE_SIZE, Double.valueOf(this.treeSize));
    }

    public LeapfrogState getForward() {
        return this.forward;
    }

    public LeapfrogState getBackward() {
        return this.backward;
    }

    public Proposal getProposal() {
        return this.proposal;
    }

    public Map<VariableReference, DoubleTensor> getSumMomentum() {
        return this.sumMomentum;
    }

    public double getStartEnergy() {
        return this.startEnergy;
    }

    public double getMaxEnergyChange() {
        return this.maxEnergyChange;
    }

    public double getLogSumWeight() {
        return this.logSumWeight;
    }

    public double getSumMetropolisAcceptanceProbability() {
        return this.sumMetropolisAcceptanceProbability;
    }

    public int getTreeSize() {
        return this.treeSize;
    }

    public int getTreeHeight() {
        return this.treeHeight;
    }

    public boolean isDiverged() {
        return this.diverged;
    }

    public boolean isUTurned() {
        return this.uTurned;
    }

    private static class SubTree {
        private LeapfrogState forward;
        private LeapfrogState backward;
        private Map<VariableReference, DoubleTensor> sumMomentum;
        private Proposal proposal;
        private double logSumWeight;
        private boolean diverged;
        private boolean uTurned;
        private double sumMetropolisAcceptanceProbability;
        private int treeSize;

        public boolean shouldContinue() {
            return !this.diverged && !this.uTurned;
        }

        public LeapfrogState getForward() {
            return this.forward;
        }

        public LeapfrogState getBackward() {
            return this.backward;
        }

        public Map<VariableReference, DoubleTensor> getSumMomentum() {
            return this.sumMomentum;
        }

        public Proposal getProposal() {
            return this.proposal;
        }

        public double getLogSumWeight() {
            return this.logSumWeight;
        }

        public boolean isDiverged() {
            return this.diverged;
        }

        public boolean isUTurned() {
            return this.uTurned;
        }

        public double getSumMetropolisAcceptanceProbability() {
            return this.sumMetropolisAcceptanceProbability;
        }

        public int getTreeSize() {
            return this.treeSize;
        }

        public void setForward(LeapfrogState forward) {
            this.forward = forward;
        }

        public void setBackward(LeapfrogState backward) {
            this.backward = backward;
        }

        public void setSumMomentum(Map<VariableReference, DoubleTensor> sumMomentum) {
            this.sumMomentum = sumMomentum;
        }

        public void setProposal(Proposal proposal) {
            this.proposal = proposal;
        }

        public void setLogSumWeight(double logSumWeight) {
            this.logSumWeight = logSumWeight;
        }

        public void setDiverged(boolean diverged) {
            this.diverged = diverged;
        }

        public void setUTurned(boolean uTurned) {
            this.uTurned = uTurned;
        }

        public void setSumMetropolisAcceptanceProbability(double sumMetropolisAcceptanceProbability) {
            this.sumMetropolisAcceptanceProbability = sumMetropolisAcceptanceProbability;
        }

        public void setTreeSize(int treeSize) {
            this.treeSize = treeSize;
        }

        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof SubTree)) {
                return false;
            }
            SubTree other = (SubTree)o;
            if (!other.canEqual(this)) {
                return false;
            }
            LeapfrogState this$forward = this.getForward();
            LeapfrogState other$forward = other.getForward();
            if (this$forward == null ? other$forward != null : !((Object)this$forward).equals(other$forward)) {
                return false;
            }
            LeapfrogState this$backward = this.getBackward();
            LeapfrogState other$backward = other.getBackward();
            if (this$backward == null ? other$backward != null : !((Object)this$backward).equals(other$backward)) {
                return false;
            }
            Map<VariableReference, DoubleTensor> this$sumMomentum = this.getSumMomentum();
            Map<VariableReference, DoubleTensor> other$sumMomentum = other.getSumMomentum();
            if (this$sumMomentum == null ? other$sumMomentum != null : !((Object)this$sumMomentum).equals(other$sumMomentum)) {
                return false;
            }
            Proposal this$proposal = this.getProposal();
            Proposal other$proposal = other.getProposal();
            if (this$proposal == null ? other$proposal != null : !((Object)this$proposal).equals(other$proposal)) {
                return false;
            }
            if (Double.compare(this.getLogSumWeight(), other.getLogSumWeight()) != 0) {
                return false;
            }
            if (this.isDiverged() != other.isDiverged()) {
                return false;
            }
            if (this.isUTurned() != other.isUTurned()) {
                return false;
            }
            if (Double.compare(this.getSumMetropolisAcceptanceProbability(), other.getSumMetropolisAcceptanceProbability()) != 0) {
                return false;
            }
            return this.getTreeSize() == other.getTreeSize();
        }

        protected boolean canEqual(Object other) {
            return other instanceof SubTree;
        }

        public int hashCode() {
            int PRIME = 59;
            int result = 1;
            LeapfrogState $forward = this.getForward();
            result = result * 59 + ($forward == null ? 43 : ((Object)$forward).hashCode());
            LeapfrogState $backward = this.getBackward();
            result = result * 59 + ($backward == null ? 43 : ((Object)$backward).hashCode());
            Map<VariableReference, DoubleTensor> $sumMomentum = this.getSumMomentum();
            result = result * 59 + ($sumMomentum == null ? 43 : ((Object)$sumMomentum).hashCode());
            Proposal $proposal = this.getProposal();
            result = result * 59 + ($proposal == null ? 43 : ((Object)$proposal).hashCode());
            long $logSumWeight = Double.doubleToLongBits(this.getLogSumWeight());
            result = result * 59 + (int)($logSumWeight >>> 32 ^ $logSumWeight);
            result = result * 59 + (this.isDiverged() ? 79 : 97);
            result = result * 59 + (this.isUTurned() ? 79 : 97);
            long $sumMetropolisAcceptanceProbability = Double.doubleToLongBits(this.getSumMetropolisAcceptanceProbability());
            result = result * 59 + (int)($sumMetropolisAcceptanceProbability >>> 32 ^ $sumMetropolisAcceptanceProbability);
            result = result * 59 + this.getTreeSize();
            return result;
        }

        public String toString() {
            return "Tree.SubTree(forward=" + this.getForward() + ", backward=" + this.getBackward() + ", sumMomentum=" + this.getSumMomentum() + ", proposal=" + this.getProposal() + ", logSumWeight=" + this.getLogSumWeight() + ", diverged=" + this.isDiverged() + ", uTurned=" + this.isUTurned() + ", sumMetropolisAcceptanceProbability=" + this.getSumMetropolisAcceptanceProbability() + ", treeSize=" + this.getTreeSize() + ")";
        }

        public SubTree(LeapfrogState forward, LeapfrogState backward, Map<VariableReference, DoubleTensor> sumMomentum, Proposal proposal, double logSumWeight, boolean diverged, boolean uTurned, double sumMetropolisAcceptanceProbability, int treeSize) {
            this.forward = forward;
            this.backward = backward;
            this.sumMomentum = sumMomentum;
            this.proposal = proposal;
            this.logSumWeight = logSumWeight;
            this.diverged = diverged;
            this.uTurned = uTurned;
            this.sumMetropolisAcceptanceProbability = sumMetropolisAcceptanceProbability;
            this.treeSize = treeSize;
        }
    }
}

