package io.improbable.keanu.algorithms.mcmc.nuts;

import io.improbable.keanu.algorithms.ProbabilisticModelWithGradient;
import io.improbable.keanu.algorithms.VariableReference;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import java.util.HashMap;
import java.util.Map;

/* loaded from: input_file:io/improbable/keanu/algorithms/mcmc/nuts/LeapfrogIntegrator.class */
/**
 * Performs single leapfrog integration steps over a position/momentum state, as used by the
 * NUTS sampler in this package.
 *
 * <p>Each step is: half-step momentum, full-step position (using the velocity the
 * {@link Potential} derives from the half-stepped momentum), then a second momentum half-step
 * using the gradient at the new position.
 */
public class LeapfrogIntegrator {
    private final Potential potential;

    /**
     * Advances {@code leapfrogState} by one leapfrog step of size {@code d}.
     *
     * @param leapfrogState the state to step from; its momentum, gradient and position are read
     * @param probabilisticModelWithGradient model used to evaluate the log-probability gradient
     *     and log-probability at the new position
     * @param d the step size; momentum is advanced by {@code d / 2} twice, position by {@code d}
     * @return a new {@link LeapfrogState} holding the stepped position, momentum, gradient and
     *     log-probability, together with this integrator's potential
     */
    public LeapfrogState step(LeapfrogState leapfrogState, ProbabilisticModelWithGradient probabilisticModelWithGradient, double d) {
        final double halfStepSize = d / 2.0d;

        // First half step of momentum, driven by the gradient at the starting position.
        Map<VariableReference, DoubleTensor> halfStepMomentum =
            stepMomentum(halfStepSize, leapfrogState.getMomentum(), leapfrogState.getGradient());

        // Full step of position along the velocity implied by the half-stepped momentum.
        Map<VariableReference, DoubleTensor> nextPosition =
            stepPosition(d, this.potential.getVelocity(halfStepMomentum), leapfrogState.getPosition());

        // Gradient at the new position, used for the second momentum half step.
        Map<VariableReference, DoubleTensor> nextGradient =
            probabilisticModelWithGradient.logProbGradients(nextPosition);

        Map<VariableReference, DoubleTensor> nextMomentum =
            stepMomentum(halfStepSize, halfStepMomentum, nextGradient);

        // NOTE(review): logProb() takes no position argument; presumably it reports the value at
        // the position just passed to logProbGradients(...), so it must stay after that call.
        double nextLogProb = probabilisticModelWithGradient.logProb();

        return new LeapfrogState(nextPosition, nextMomentum, nextGradient, nextLogProb, this.potential);
    }

    /**
     * Computes {@code position + timeStep * velocity} for every variable in {@code position}.
     *
     * @param timeStep scale applied to each velocity tensor
     * @param velocity per-variable velocity; must contain every key of {@code position}
     * @param position per-variable current position; its key set drives the iteration
     * @return a new map from each position key to its stepped value
     */
    private static Map<VariableReference, DoubleTensor> stepPosition(double timeStep, Map<VariableReference, DoubleTensor> velocity, Map<VariableReference, DoubleTensor> position) {
        Map<VariableReference, DoubleTensor> nextPosition = new HashMap<>();
        for (VariableReference variable : position.keySet()) {
            // times2(...) yields a fresh tensor, so plusInPlace does not mutate the velocity input.
            // NOTE(review): relies on times2 not aliasing its receiver; confirm in DoubleTensor.
            DoubleTensor stepped = (DoubleTensor) velocity.get(variable).times2(timeStep).plusInPlace(position.get(variable));
            nextPosition.put(variable, stepped);
        }
        return nextPosition;
    }

    /**
     * Computes {@code momentum + timeStep * gradient} for every variable in {@code momentum}.
     *
     * @param timeStep scale applied to each gradient tensor
     * @param momentum per-variable current momentum; its entries drive the iteration
     * @param gradient per-variable gradient; must contain every key of {@code momentum}
     * @return a new map from each momentum key to its stepped value
     */
    private static Map<VariableReference, DoubleTensor> stepMomentum(double timeStep, Map<? extends VariableReference, DoubleTensor> momentum, Map<? extends VariableReference, DoubleTensor> gradient) {
        Map<VariableReference, DoubleTensor> nextMomentum = new HashMap<>();
        for (Map.Entry<? extends VariableReference, DoubleTensor> entry : momentum.entrySet()) {
            VariableReference variable = entry.getKey();
            DoubleTensor stepped = (DoubleTensor) gradient.get(variable).times2(timeStep).plusInPlace(entry.getValue());
            nextMomentum.put(variable, stepped);
        }
        return nextMomentum;
    }

    /**
     * Creates an integrator that uses {@code potential} to turn momentum into velocity.
     *
     * @param potential the potential whose velocity mapping drives position updates
     */
    public LeapfrogIntegrator(Potential potential) {
        this.potential = potential;
    }
}
