/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.algorithms.mcmc.nuts;

import io.improbable.keanu.algorithms.ProbabilisticModelWithGradient;
import io.improbable.keanu.algorithms.VariableReference;
import io.improbable.keanu.algorithms.mcmc.nuts.LeapfrogState;
import io.improbable.keanu.algorithms.mcmc.nuts.Potential;
import io.improbable.keanu.tensor.NumberTensor;
import io.improbable.keanu.tensor.dbl.DoubleTensor;

import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

/**
 * Leapfrog integrator used by the NUTS sampler in this package.
 *
 * <p>Each call to {@link #step} advances a {@link LeapfrogState} by one time step. It does so in three parts:
 * <ol>
 *   <li>a half step of momentum;</li>
 *   <li>a full step of position, along the velocity the {@link Potential} derives from that momentum;</li>
 *   <li>a second half step of momentum, using the gradient at the new position.</li>
 * </ol>
 * The input state's maps are not modified; fresh maps are returned.
 */
public class LeapfrogIntegrator {
    private final Potential potential;

    /**
     * Creates an integrator backed by the given potential.
     *
     * @param potential converts momentum into velocity and is attached to every produced {@link LeapfrogState}
     * @throws NullPointerException if {@code potential} is null
     */
    public LeapfrogIntegrator(Potential potential) {
        // Fail at construction rather than on the first step() call.
        this.potential = Objects.requireNonNull(potential, "potential");
    }

    /**
     * Performs one leapfrog step starting from {@code fromState}.
     *
     * @param fromState                 the state to advance; its momentum, gradient and position maps are read, not modified
     * @param logProbGradientCalculator model used to evaluate the gradient and log probability at the new position
     * @param timeStep                  size of the full position step; momentum moves in two half steps of {@code timeStep / 2}
     * @return a new state holding the new position, momentum, gradient and log probability
     */
    public LeapfrogState step(LeapfrogState fromState, ProbabilisticModelWithGradient logProbGradientCalculator, double timeStep) {
        final double halfTimeStep = timeStep / 2.0;

        // First half step of momentum, using the gradient at the starting position.
        final Map<VariableReference, DoubleTensor> halfStepMomentum =
            stepMomentum(halfTimeStep, fromState.getMomentum(), fromState.getGradient());

        // Full step of position along the velocity implied by the half-stepped momentum.
        final Map<VariableReference, DoubleTensor> nextVelocity = potential.getVelocity(halfStepMomentum);
        final Map<VariableReference, DoubleTensor> nextPosition =
            stepPosition(timeStep, nextVelocity, fromState.getPosition());

        // logProb() takes no position argument. This presumably reports the value for the position most recently
        // passed to logProbGradients, so keep these two calls in this order.
        // NOTE(review): confirm against ProbabilisticModelWithGradient.
        final Map<VariableReference, DoubleTensor> nextPositionGradient =
            logProbGradientCalculator.logProbGradients(nextPosition);
        final double nextPositionLogProb = logProbGradientCalculator.logProb();

        // Second half step of momentum, using the gradient at the new position.
        final Map<VariableReference, DoubleTensor> nextMomentum =
            stepMomentum(halfTimeStep, halfStepMomentum, nextPositionGradient);

        return new LeapfrogState(nextPosition, nextMomentum, nextPositionGradient, nextPositionLogProb, potential);
    }

    /**
     * Computes {@code position + dt * velocity} for every variable in {@code position}.
     *
     * @param dt       step size
     * @param velocity velocity per variable; must contain every key of {@code position}
     * @param position current position per variable
     * @return a new map of updated positions, keyed like {@code position}
     * @throws NullPointerException if {@code velocity} has no entry for some variable in {@code position}
     */
    private static Map<VariableReference, DoubleTensor> stepPosition(double dt, Map<VariableReference, DoubleTensor> velocity, Map<VariableReference, DoubleTensor> position) {
        final Map<VariableReference, DoubleTensor> nextPosition = new HashMap<>();
        for (Map.Entry<VariableReference, DoubleTensor> positionEntry : position.entrySet()) {
            final VariableReference variable = positionEntry.getKey();
            final DoubleTensor variableVelocity =
                Objects.requireNonNull(velocity.get(variable), () -> "No velocity for variable " + variable);
            // Scale first, then add in place. This presumably mutates only the fresh tensor returned by times(),
            // not the caller's velocity or position tensors.
            // NOTE(review): confirm that times() allocates.
            nextPosition.put(variable, variableVelocity.times(dt).plusInPlace(positionEntry.getValue()));
        }
        return nextPosition;
    }

    /**
     * Computes {@code momentum + dt * gradient} for every variable in {@code momentum}.
     *
     * @param dt       step size (a half step when called from {@link #step})
     * @param momentum current momentum per variable
     * @param gradient gradient per variable, as returned by {@code logProbGradients}; must contain every key of {@code momentum}
     * @return a new map of updated momenta, keyed like {@code momentum}
     * @throws NullPointerException if {@code gradient} has no entry for some variable in {@code momentum}
     */
    private static Map<VariableReference, DoubleTensor> stepMomentum(double dt, Map<? extends VariableReference, DoubleTensor> momentum, Map<? extends VariableReference, DoubleTensor> gradient) {
        final Map<VariableReference, DoubleTensor> nextMomentum = new HashMap<>();
        for (Map.Entry<? extends VariableReference, DoubleTensor> momentumEntry : momentum.entrySet()) {
            final VariableReference variable = momentumEntry.getKey();
            final DoubleTensor variableGradient =
                Objects.requireNonNull(gradient.get(variable), () -> "No gradient for variable " + variable);
            // Same operand order as stepPosition: add in place on the scaled copy, never on the input momentum.
            nextMomentum.put(variable, variableGradient.times(dt).plusInPlace(momentumEntry.getValue()));
        }
        return nextMomentum;
    }
}

