package io.improbable.keanu.algorithms.mcmc.nuts;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.algorithms.ProbabilisticModelWithGradient;
import io.improbable.keanu.algorithms.Statistics;
import io.improbable.keanu.algorithms.Variable;
import io.improbable.keanu.algorithms.VariableReference;
import io.improbable.keanu.algorithms.mcmc.SamplingAlgorithm;
import io.improbable.keanu.algorithms.mcmc.nuts.NUTS;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/*  JADX ERROR: NullPointerException in pass: ClassModifier
    java.lang.NullPointerException: Cannot invoke "java.util.List.forEach(java.util.function.Consumer)" because "blocks" is null
    	at jadx.core.utils.BlockUtils.collectAllInsns(BlockUtils.java:1017)
    	at jadx.core.dex.visitors.ClassModifier.removeBridgeMethod(ClassModifier.java:239)
    	at jadx.core.dex.visitors.ClassModifier.removeSyntheticMethods(ClassModifier.java:154)
    	at java.base/java.util.ArrayList.forEach(ArrayList.java:1596)
    	at jadx.core.dex.visitors.ClassModifier.visit(ClassModifier.java:64)
    	at jadx.core.dex.visitors.ClassModifier.visit(ClassModifier.java:57)
    */
/* loaded from: input_file:io/improbable/keanu/algorithms/mcmc/nuts/Tree.class */
/**
 * Trajectory tree used by the No-U-Turn Sampler (NUTS).
 *
 * <p>The tree starts as a single leapfrog state. Each call to {@link #grow(int, double)} doubles
 * the tree by integrating a new subtree of equal height off either the backward edge
 * ({@code direction == -1}) or the forward edge (any other direction). The step size passed to the
 * integrator is {@code epsilon * direction}.
 *
 * <p>Every leapfrog step (leaf) carries a log weight of {@code startEnergy - energy}. The current
 * {@link Proposal} is replaced by a new subtree's proposal with a probability derived from those
 * log weights. {@link #shouldContinue()} returns {@code false} once the trajectory has diverged or
 * has started to U-turn:
 * <ul>
 *   <li>It diverges when a step's absolute energy change reaches {@code maxEnergyChange}, or the
 *       change is NaN.</li>
 *   <li>It U-turns when an edge velocity points against the summed momentum.</li>
 * </ul>
 */
class Tree {
    private final ProbabilisticModelWithGradient logProbGradientCalculator;
    private final LeapfrogIntegrator leapfrogIntegrator;
    /** Variables whose current values are captured into each leaf's {@link Proposal}. */
    private final List<? extends Variable> sampleFromVariables;
    private final KeanuRandom random;
    /** Leapfrog state at the forward edge of the trajectory. */
    private LeapfrogState forward;
    /** Leapfrog state at the backward edge of the trajectory. */
    private LeapfrogState backward;
    /** Currently selected sample among the states built so far. */
    private Proposal proposal;
    /** Sum of the momenta of the states merged into the trajectory; input to the U-turn check. */
    private Map<VariableReference, DoubleTensor> sumMomentum;
    /** Energy of the initial state; leaf weights are measured relative to it. */
    private final double startEnergy;
    /** Absolute energy change at (or beyond) which a leapfrog step is treated as divergent. */
    private final double maxEnergyChange;
    /** Log of the summed leaf weights; the initial state contributes log(1) = 0. */
    private double logSumWeight = 0.0d;
    /** Sum over all leaves of {@code min(1, exp(startEnergy - energy))}; divergent leaves add 0. */
    private double sumMetropolisAcceptanceProbability = 0.0d;
    /** Number of leapfrog steps taken so far. */
    private int treeSize = 0;
    /** Number of doublings performed so far. */
    private int treeHeight = 0;
    private boolean diverged = false;
    private boolean uTurned = false;

    /**
     * Mutable result of building part of a trajectory.
     *
     * <p>It holds the part's two edge states, the sum of its states' momenta, the proposal selected
     * within it, and its accumulated log weight and acceptance statistics. The enclosing
     * {@link Tree} reads and writes these fields directly while merging subtrees.
     */
    public static class SubTree {
        private LeapfrogState forward;
        private LeapfrogState backward;
        private Map<VariableReference, DoubleTensor> sumMomentum;
        /** Selected proposal; {@code null} for a divergent single-step subtree. */
        private Proposal proposal;
        private double logSumWeight;
        private boolean diverged;
        private boolean uTurned;
        private double sumMetropolisAcceptanceProbability;
        private int treeSize;

        public SubTree(LeapfrogState forward,
                       LeapfrogState backward,
                       Map<VariableReference, DoubleTensor> sumMomentum,
                       Proposal proposal,
                       double logSumWeight,
                       boolean diverged,
                       boolean uTurned,
                       double sumMetropolisAcceptanceProbability,
                       int treeSize) {
            this.forward = forward;
            this.backward = backward;
            this.sumMomentum = sumMomentum;
            this.proposal = proposal;
            this.logSumWeight = logSumWeight;
            this.diverged = diverged;
            this.uTurned = uTurned;
            this.sumMetropolisAcceptanceProbability = sumMetropolisAcceptanceProbability;
            this.treeSize = treeSize;
        }

        /** Returns {@code true} while this subtree has neither diverged nor U-turned. */
        public boolean shouldContinue() {
            return !diverged && !uTurned;
        }

        public LeapfrogState getForward() {
            return this.forward;
        }

        public LeapfrogState getBackward() {
            return this.backward;
        }

        public Map<VariableReference, DoubleTensor> getSumMomentum() {
            return this.sumMomentum;
        }

        public Proposal getProposal() {
            return this.proposal;
        }

        public double getLogSumWeight() {
            return this.logSumWeight;
        }

        public boolean isDiverged() {
            return this.diverged;
        }

        public boolean isUTurned() {
            return this.uTurned;
        }

        public double getSumMetropolisAcceptanceProbability() {
            return this.sumMetropolisAcceptanceProbability;
        }

        public int getTreeSize() {
            return this.treeSize;
        }

        public void setForward(LeapfrogState forward) {
            this.forward = forward;
        }

        public void setBackward(LeapfrogState backward) {
            this.backward = backward;
        }

        public void setSumMomentum(Map<VariableReference, DoubleTensor> sumMomentum) {
            this.sumMomentum = sumMomentum;
        }

        public void setProposal(Proposal proposal) {
            this.proposal = proposal;
        }

        public void setLogSumWeight(double logSumWeight) {
            this.logSumWeight = logSumWeight;
        }

        public void setDiverged(boolean diverged) {
            this.diverged = diverged;
        }

        public void setUTurned(boolean uTurned) {
            this.uTurned = uTurned;
        }

        public void setSumMetropolisAcceptanceProbability(double sumMetropolisAcceptanceProbability) {
            this.sumMetropolisAcceptanceProbability = sumMetropolisAcceptanceProbability;
        }

        public void setTreeSize(int treeSize) {
            this.treeSize = treeSize;
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof SubTree)) {
                return false;
            }
            SubTree other = (SubTree) obj;
            return other.canEqual(this)
                && Objects.equals(getForward(), other.getForward())
                && Objects.equals(getBackward(), other.getBackward())
                && Objects.equals(getSumMomentum(), other.getSumMomentum())
                && Objects.equals(getProposal(), other.getProposal())
                && Double.compare(getLogSumWeight(), other.getLogSumWeight()) == 0
                && isDiverged() == other.isDiverged()
                && isUTurned() == other.isUTurned()
                && Double.compare(getSumMetropolisAcceptanceProbability(), other.getSumMetropolisAcceptanceProbability()) == 0
                && getTreeSize() == other.getTreeSize();
        }

        /** Symmetry hook for {@link #equals(Object)} in the presence of subclasses. */
        protected boolean canEqual(Object obj) {
            return obj instanceof SubTree;
        }

        @Override
        public int hashCode() {
            final int prime = 59;
            int result = 1;
            result = result * prime + hashOrMarker(getForward());
            result = result * prime + hashOrMarker(getBackward());
            result = result * prime + hashOrMarker(getSumMomentum());
            result = result * prime + hashOrMarker(getProposal());
            result = result * prime + Double.hashCode(getLogSumWeight());
            result = result * prime + (isDiverged() ? 79 : 97);
            result = result * prime + (isUTurned() ? 79 : 97);
            result = result * prime + Double.hashCode(getSumMetropolisAcceptanceProbability());
            result = result * prime + getTreeSize();
            return result;
        }

        /** Hash of {@code value}, or the fixed marker 43 when it is {@code null}. */
        private static int hashOrMarker(Object value) {
            return value == null ? 43 : value.hashCode();
        }

        @Override
        public String toString() {
            return "Tree.SubTree(forward=" + getForward() + ", backward=" + getBackward() + ", sumMomentum=" + getSumMomentum() + ", proposal=" + getProposal() + ", logSumWeight=" + getLogSumWeight() + ", diverged=" + isDiverged() + ", uTurned=" + isUTurned() + ", sumMetropolisAcceptanceProbability=" + getSumMetropolisAcceptanceProbability() + ", treeSize=" + getTreeSize() + ")";
        }
    }

    /**
     * Creates a height-0 tree consisting only of {@code startState}.
     *
     * @param startState                the initial leapfrog state; both edges start here
     * @param startProposal             the proposal representing {@code startState}
     * @param maxEnergyChange           absolute energy change treated as a divergence
     * @param logProbGradientCalculator model passed to the leapfrog integrator
     * @param leapfrogIntegrator        integrator used to take each step
     * @param sampleFromVariables       variables captured into each new proposal
     * @param random                    source of randomness for proposal selection
     */
    public Tree(LeapfrogState startState,
                Proposal startProposal,
                double maxEnergyChange,
                ProbabilisticModelWithGradient logProbGradientCalculator,
                LeapfrogIntegrator leapfrogIntegrator,
                List<? extends Variable> sampleFromVariables,
                KeanuRandom random) {
        this.forward = startState;
        this.backward = startState;
        this.proposal = startProposal;
        this.maxEnergyChange = maxEnergyChange;
        this.logProbGradientCalculator = logProbGradientCalculator;
        this.leapfrogIntegrator = leapfrogIntegrator;
        this.sampleFromVariables = sampleFromVariables;
        this.random = random;
        // NOTE(review): the start state's momentum map is held by reference. If VariableValues.add
        // mutates its first argument, growing the tree would also modify the start state — confirm.
        this.sumMomentum = startState.getMomentum();
        this.startEnergy = startState.getEnergy();
    }

    /**
     * Doubles the tree by building a subtree of height {@link #getTreeHeight()} off one edge.
     *
     * <p>The edge and the step/size statistics are always updated. The new subtree's proposal,
     * weight and momentum are merged only if that subtree is itself still valid. Afterwards the
     * divergence flag is copied from the subtree, and the U-turn criterion is re-evaluated over
     * the whole trajectory.
     *
     * @param direction -1 to extend the backward edge, any other value to extend the forward edge
     * @param epsilon   leapfrog step size
     */
    public void grow(int direction, double epsilon) {
        SubTree subTree = buildTree(direction == -1 ? this.backward : this.forward, direction, this.treeHeight, epsilon);
        if (direction == -1) {
            this.backward = subTree.backward;
        } else {
            this.forward = subTree.forward;
        }
        this.sumMetropolisAcceptanceProbability += subTree.sumMetropolisAcceptanceProbability;
        this.treeSize += subTree.treeSize;
        if (subTree.shouldContinue()) {
            // Biased progressive sampling: favour the new subtree when it outweighs the existing tree.
            if (acceptOtherProposalWithProbability(subTree.logSumWeight - this.logSumWeight, this.random)) {
                this.proposal = subTree.proposal;
            }
            this.logSumWeight = logSumExp(this.logSumWeight, subTree.logSumWeight);
            this.sumMomentum = VariableValues.add(this.sumMomentum, subTree.sumMomentum);
        }
        this.diverged = subTree.diverged;
        if (!this.diverged) {
            this.uTurned = subTree.uTurned || isUTurning(this.forward.getVelocity(), this.backward.getVelocity(), this.sumMomentum);
        }
        this.treeHeight++;
    }

    /**
     * Recursively builds a subtree of {@code 2^height} leapfrog steps.
     *
     * <p>The steps start from {@code leapfrogState} and move in {@code direction}. The subtree is
     * built as two halves of height {@code height - 1}:
     * <ul>
     *   <li>If the first half is already invalid (diverged or U-turned), it is returned as-is and
     *       the second half is not built.</li>
     *   <li>Otherwise the second half is integrated from the outer edge of the first half and merged
     *       into it using the same rules as {@link #grow(int, double)}.</li>
     * </ul>
     * In the merge, the second half's proposal is taken with probability
     * {@code w2 / (w1 + w2)} of the halves' summed weights.
     *
     * <p>NOTE(review): reconstructed. The decompiler could not recover this method's bytecode, so
     * the body follows the merge rules visible in {@code grow}; verify it against upstream Keanu.
     *
     * @param leapfrogState the edge state to integrate from
     * @param direction     -1 to build backwards, any other value to build forwards
     * @param height        height of the subtree; 0 means a single leapfrog step
     * @param epsilon       leapfrog step size
     * @return the subtree; the first half's object is updated in place and returned
     */
    private SubTree buildTree(LeapfrogState leapfrogState, int direction, int height, double epsilon) {
        if (height == 0) {
            return treeBuilderBaseCase(leapfrogState, direction, epsilon);
        }

        SubTree firstHalf = buildTree(leapfrogState, direction, height - 1, epsilon);
        if (!firstHalf.shouldContinue()) {
            return firstHalf;
        }

        // Continue from the edge the first half reached in the build direction.
        LeapfrogState outerEdge = direction == -1 ? firstHalf.backward : firstHalf.forward;
        SubTree secondHalf = buildTree(outerEdge, direction, height - 1, epsilon);

        if (direction == -1) {
            firstHalf.backward = secondHalf.backward;
        } else {
            firstHalf.forward = secondHalf.forward;
        }
        firstHalf.sumMetropolisAcceptanceProbability += secondHalf.sumMetropolisAcceptanceProbability;
        firstHalf.treeSize += secondHalf.treeSize;

        if (secondHalf.shouldContinue()) {
            double combinedLogSumWeight = logSumExp(firstHalf.logSumWeight, secondHalf.logSumWeight);
            // Uniform progressive sampling within a subtree: P(take second) = w2 / (w1 + w2).
            if (acceptOtherProposalWithProbability(secondHalf.logSumWeight - combinedLogSumWeight, this.random)) {
                firstHalf.proposal = secondHalf.proposal;
            }
            firstHalf.logSumWeight = combinedLogSumWeight;
            firstHalf.sumMomentum = VariableValues.add(firstHalf.sumMomentum, secondHalf.sumMomentum);
        }

        firstHalf.diverged = secondHalf.diverged;
        if (!firstHalf.diverged) {
            firstHalf.uTurned = secondHalf.uTurned
                || isUTurning(firstHalf.forward.getVelocity(), firstHalf.backward.getVelocity(), firstHalf.sumMomentum);
        }
        return firstHalf;
    }

    /**
     * Numerically stable {@code log(exp(a) + exp(b))}.
     *
     * <p>If the larger argument is infinite, that value is returned directly. This covers both
     * arguments being {@code -inf} (a sum of zero weights) and either being {@code +inf}. The
     * shifted formula would evaluate {@code inf - inf} there and yield NaN. NaN inputs still
     * produce NaN.
     *
     * @param a first log value
     * @param b second log value
     * @return {@code log(exp(a) + exp(b))}
     */
    public static double logSumExp(double a, double b) {
        double max = Math.max(a, b);
        if (Double.isInfinite(max)) {
            return max;
        }
        return max + Math.log(Math.exp(a - max) + Math.exp(b - max));
    }

    /**
     * Takes one leapfrog step of size {@code epsilon * direction} and wraps it as a one-leaf subtree.
     *
     * <p>If the step's energy change relative to {@code startEnergy} is NaN, or its magnitude is
     * at least {@code maxEnergyChange}, the leaf is divergent: it gets no proposal, a log weight of
     * {@code -inf` and an acceptance probability of 0. Otherwise its log weight is the negated energy
     * change, its acceptance probability is {@code min(1, exp(logWeight))}, and its proposal
     * samples the new state.
     */
    private SubTree treeBuilderBaseCase(LeapfrogState leapfrogState, int direction, double epsilon) {
        LeapfrogState step = this.leapfrogIntegrator.step(leapfrogState, this.logProbGradientCalculator, epsilon * direction);
        double energyChange = step.getEnergy() - this.startEnergy;
        if (Math.abs(energyChange) >= this.maxEnergyChange || Double.isNaN(energyChange)) {
            return new SubTree(step, step, step.getMomentum(), null, Double.NEGATIVE_INFINITY, true, false, 0.0d, 1);
        }
        double logWeight = -energyChange;
        Proposal stepProposal = new Proposal(step.getPosition(), step.getGradient(), SamplingAlgorithm.takeSample(this.sampleFromVariables), step.getLogProb());
        return new SubTree(step, step, step.getMomentum(), stepProposal, logWeight, false, false, Math.min(1.0d, Math.exp(logWeight)), 1);
    }

    /**
     * Returns {@code true} when {@code log(u) < logProbability} for a fresh draw {@code u}.
     *
     * <p>Assuming {@code nextDouble()} is uniform on [0, 1), this accepts with probability
     * {@code min(1, exp(logProbability))}.
     */
    private static boolean acceptOtherProposalWithProbability(double logProbability, KeanuRandom random) {
        return Math.log(random.nextDouble()) < logProbability;
    }

    /**
     * U-turn criterion.
     *
     * <p>Returns {@code true} if either edge velocity has a negative dot product with the summed
     * momentum. Every key of {@code forwardVelocity} is expected to be present in the other two maps.
     */
    private static boolean isUTurning(Map<VariableReference, DoubleTensor> forwardVelocity,
                                      Map<VariableReference, DoubleTensor> backwardVelocity,
                                      Map<VariableReference, DoubleTensor> momentumSum) {
        double forwardDot = 0.0d;
        double backwardDot = 0.0d;
        for (VariableReference variable : forwardVelocity.keySet()) {
            DoubleTensor forwardTensor = forwardVelocity.get(variable);
            DoubleTensor backwardTensor = backwardVelocity.get(variable);
            DoubleTensor momentumTensor = momentumSum.get(variable);
            forwardDot += ((Double) ((DoubleTensor) forwardTensor.times(momentumTensor)).sum()).doubleValue();
            backwardDot += ((Double) ((DoubleTensor) backwardTensor.times(momentumTensor)).sum()).doubleValue();
        }
        return forwardDot < 0.0d || backwardDot < 0.0d;
    }

    /** Returns {@code true} while the trajectory has neither diverged nor U-turned. */
    public boolean shouldContinue() {
        return !this.diverged && !this.uTurned;
    }

    /** Records the selected proposal's log probability and the number of leapfrog steps taken. */
    public void save(Statistics statistics) {
        statistics.store(NUTS.Metrics.LOG_PROB, Double.valueOf(this.proposal.getLogProb()));
        statistics.store(NUTS.Metrics.TREE_SIZE, Double.valueOf(this.treeSize));
    }

    public LeapfrogState getForward() {
        return this.forward;
    }

    public LeapfrogState getBackward() {
        return this.backward;
    }

    public Proposal getProposal() {
        return this.proposal;
    }

    public Map<VariableReference, DoubleTensor> getSumMomentum() {
        return this.sumMomentum;
    }

    public double getStartEnergy() {
        return this.startEnergy;
    }

    public double getMaxEnergyChange() {
        return this.maxEnergyChange;
    }

    public double getLogSumWeight() {
        return this.logSumWeight;
    }

    public double getSumMetropolisAcceptanceProbability() {
        return this.sumMetropolisAcceptanceProbability;
    }

    public int getTreeSize() {
        return this.treeSize;
    }

    public int getTreeHeight() {
        return this.treeHeight;
    }

    public boolean isDiverged() {
        return this.diverged;
    }

    public boolean isUTurned() {
        return this.uTurned;
    }
}
