package io.improbable.keanu.algorithms.mcmc.nuts;

import com.google.common.base.Preconditions;
import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.algorithms.VariableReference;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import java.util.HashMap;
import java.util.Map;

/**
 * A quadratic potential for NUTS whose per-element variance (used to map momentum to velocity) is
 * adapted online from the samples passed to {@link #update(Map)}.
 *
 * <p>Two running variance estimates are kept:
 * <ul>
 *   <li>the <em>forward</em> estimate, seeded in {@link #initialize(Map)} with the configured
 *       initial mean, variance and weight. The current variance is read from it after every sample.</li>
 *   <li>the <em>background</em> estimate, which starts from zeros with weight 0.</li>
 * </ul>
 * Both receive every sample. Every {@code adaptionWindowSize} samples, the background estimate
 * replaces the forward one and a fresh zero-weight background estimate is started.
 *
 * <p>{@link #initialize(Map)} must be called before {@link #update(Map)} or
 * {@link #randomMomentum(KeanuRandom)}.
 */
public class AdaptiveQuadraticPotential implements Potential {
    private final double initialWeight;
    private final double initialMean;
    private final double initialVariance;
    private final int adaptionWindowSize;
    private VarianceCalculator forwardVariance;
    private VarianceCalculator backgroundVariance;
    private long nSamples;
    private Map<VariableReference, DoubleTensor> variance;
    private Map<VariableReference, DoubleTensor> standardDeviation;

    /**
     * Creates an adaptive potential; no state is derived until {@link #initialize(Map)} is called.
     *
     * @param initialMean        value used to fill the initial mean estimate of every variable
     * @param initialVariance    value used to fill the initial variance estimate of every variable
     * @param initialWeight      weight given to the initial mean/variance in the forward estimate
     * @param adaptionWindowSize number of samples between swaps of the background estimate into the
     *                           forward estimate
     * @throws IllegalArgumentException if {@code adaptionWindowSize <= 1}
     */
    public AdaptiveQuadraticPotential(double initialMean, double initialVariance, double initialWeight, int adaptionWindowSize) {
        Preconditions.checkArgument(adaptionWindowSize > 1, "Adapt window size must be greater than 1");
        this.initialWeight = initialWeight;
        this.initialMean = initialMean;
        this.initialVariance = initialVariance;
        this.adaptionWindowSize = adaptionWindowSize;
        this.nSamples = 0L;
    }

    /**
     * Seeds the variance estimates using the shapes of {@code shapeTemplate}.
     *
     * <p>The current variance is set to the configured initial variance.
     *
     * @param shapeTemplate values whose shapes determine the shapes of the estimates (presumably
     *                      {@code VariableValues.withShape} fills each tensor's shape with the given
     *                      scalar, so verify the exact contract there)
     */
    @Override
    public void initialize(Map<VariableReference, DoubleTensor> shapeTemplate) {
        Map<VariableReference, DoubleTensor> startingVariance = VariableValues.withShape(this.initialVariance, shapeTemplate);
        Map<VariableReference, DoubleTensor> startingMean = VariableValues.withShape(this.initialMean, shapeTemplate);
        setVariance(startingVariance);
        this.forwardVariance = new VarianceCalculator(startingMean, startingVariance, this.initialWeight);
        this.backgroundVariance = newEmptyCalculator(startingMean);
    }

    /**
     * Stores a new variance and derives the standard deviation from it.
     *
     * <p>The standard deviation is taken as the element-wise power 0.5.
     */
    private void setVariance(Map<VariableReference, DoubleTensor> newVariance) {
        this.variance = newVariance;
        this.standardDeviation = VariableValues.pow(this.variance, 0.5d);
    }

    /**
     * Creates a zero-weight calculator whose mean and variance are zeros shaped like {@code shapeSource}.
     *
     * <p>Separate zero maps are created for mean and variance so the two never share storage.
     */
    private static VarianceCalculator newEmptyCalculator(Map<VariableReference, DoubleTensor> shapeSource) {
        return new VarianceCalculator(VariableValues.zeros(shapeSource), VariableValues.zeros(shapeSource), 0.0d);
    }

    /** Fails fast with a descriptive message when {@link #initialize(Map)} has not been called. */
    private void checkInitialized() {
        Preconditions.checkState(this.forwardVariance != null, "initialize must be called before using the potential");
    }

    /**
     * Feeds one sample into both variance estimates and refreshes the current variance.
     *
     * <p>The current variance is taken from the forward estimate.
     *
     * @param sample the sampled values to adapt to
     * @throws IllegalStateException if {@link #initialize(Map)} has not been called
     */
    @Override
    public void update(Map<VariableReference, DoubleTensor> sample) {
        checkInitialized();
        // At each window boundary (never before the first sample) promote the background estimate,
        // which only contains samples from the last window, and start a fresh one.
        if (this.nSamples > 0 && this.nSamples % this.adaptionWindowSize == 0) {
            this.forwardVariance = this.backgroundVariance;
            this.backgroundVariance = newEmptyCalculator(this.variance);
        }
        this.forwardVariance.addSample(sample);
        this.backgroundVariance.addSample(sample);
        setVariance(this.forwardVariance.calculateCurrentVariance());
        this.nSamples++;
    }

    /**
     * Draws a random momentum for every tracked variable.
     *
     * <p>Each value is a standard Gaussian sample (shaped like the variable) divided element-wise by
     * the current standard deviation.
     *
     * @param random source of randomness
     * @return a new map from variable to sampled momentum
     * @throws IllegalStateException if {@link #initialize(Map)} has not been called
     */
    @Override
    public Map<VariableReference, DoubleTensor> randomMomentum(KeanuRandom random) {
        checkInitialized();
        Map<VariableReference, DoubleTensor> momentum = new HashMap<>();
        for (Map.Entry<VariableReference, DoubleTensor> entry : this.standardDeviation.entrySet()) {
            DoubleTensor sd = entry.getValue();
            momentum.put(entry.getKey(), (DoubleTensor) random.nextGaussian(sd.getShape()).divInPlace(sd));
        }
        return momentum;
    }

    /**
     * Computes the element-wise product of the current variance and {@code momentum}.
     *
     * @param momentum per-variable momentum values
     * @return the velocity, i.e. {@code variance * momentum} per variable
     */
    @Override
    public Map<VariableReference, DoubleTensor> getVelocity(Map<VariableReference, DoubleTensor> momentum) {
        return VariableValues.times(this.variance, momentum);
    }

    /**
     * Returns half the dot product of {@code momentum} and {@code velocity}.
     *
     * @param momentum per-variable momentum values
     * @param velocity per-variable velocity values (presumably as returned by
     *                 {@link #getVelocity(Map)} for the same momentum)
     * @return {@code 0.5 * dot(momentum, velocity)}
     */
    @Override
    public double getKineticEnergy(Map<VariableReference, DoubleTensor> momentum, Map<VariableReference, DoubleTensor> velocity) {
        return 0.5d * VariableValues.dotProduct(momentum, velocity);
    }

    /**
     * Returns the current variance estimate.
     *
     * <p>This is the internal map, not a copy; it is {@code null} before {@link #initialize(Map)}.
     */
    public Map<VariableReference, DoubleTensor> getVariance() {
        return this.variance;
    }

    /**
     * Returns the element-wise square root of the current variance.
     *
     * <p>This is the internal map, not a copy; it is {@code null} before {@link #initialize(Map)}.
     */
    public Map<VariableReference, DoubleTensor> getStandardDeviation() {
        return this.standardDeviation;
    }
}
