package io.improbable.keanu.algorithms.mcmc.nuts;

import com.google.common.base.Preconditions;
import io.improbable.keanu.algorithms.VariableReference;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import java.util.HashMap;
import java.util.Map;

/* loaded from: input_file:io/improbable/keanu/algorithms/mcmc/nuts/VarianceCalculator.class */
/**
 * Incrementally estimates a per-variable mean and variance from a stream of samples using
 * Welford's online update, seeded with a prior mean and variance.
 *
 * <p>The prior is treated as if it had already been observed with weight {@code initialWeight}:
 * the running sum of squared deviations ({@code m2}) starts at {@code initialVariance * initialWeight},
 * and every call to {@link #addSample(Map)} adds one to the total weight. {@link
 * #calculateCurrentVariance()} returns {@code m2 / count}, i.e. the squared deviations divided by
 * the total weight (not by {@code count - 1}).
 *
 * <p>Instances are mutable and their methods are not synchronized.
 */
public class VarianceCalculator {

    /** Total weight: the prior weight plus the number of samples added so far. */
    private double count;

    /** Running mean per tracked variable. This instance owns the map (copied in the constructor). */
    private final Map<VariableReference, DoubleTensor> mean;

    /** Running sum of squared deviations from the mean per tracked variable (Welford's M2). */
    private final Map<VariableReference, DoubleTensor> m2;

    /**
     * Creates a calculator seeded with a prior.
     *
     * @param initialMean     prior mean of each tracked variable; copied, so the caller's map is
     *                        never modified by this instance
     * @param initialVariance prior variance of each tracked variable; multiplied by
     *                        {@code initialWeight} to seed the sum of squared deviations
     * @param initialWeight   how many pseudo-samples the prior counts as; must be {@code >= 0}
     * @throws IllegalArgumentException if {@code initialWeight} is negative or NaN
     * @throws NullPointerException     if {@code initialMean} or {@code initialVariance} is null
     */
    public VarianceCalculator(Map<VariableReference, DoubleTensor> initialMean,
                              Map<VariableReference, DoubleTensor> initialVariance,
                              double initialWeight) {
        Preconditions.checkArgument(initialWeight >= 0.0d, "Initial weight must be greater than or equal to 0");
        Preconditions.checkNotNull(initialMean, "initialMean");
        Preconditions.checkNotNull(initialVariance, "initialVariance");
        this.count = initialWeight;
        // Defensive copies: addSample() replaces entries in these maps. Copying keeps the caller's
        // map untouched and avoids failures when an immutable map is passed in.
        this.mean = new HashMap<>(initialMean);
        this.m2 = new HashMap<>(VariableValues.times(initialVariance, initialWeight));
    }

    /**
     * Folds one sample into the running mean and sum of squared deviations of each variable it
     * contains, and increases the total weight by one.
     *
     * <p>NOTE(review): the total weight is shared by all variables. A tracked variable that is
     * absent from {@code sample} is not updated, yet its variance is still divided by the larger
     * weight. Confirm that callers always pass complete samples.
     *
     * @param sample value of each variable in this sample; every key must be a tracked variable
     * @throws IllegalArgumentException if {@code sample} contains a variable that was not present in
     *                                  both the initial mean and initial variance; in that case no
     *                                  state is changed
     */
    public void addSample(Map<VariableReference, DoubleTensor> sample) {
        // Validate first so a bad sample cannot leave the weight bumped with only some variables updated.
        for (VariableReference variable : sample.keySet()) {
            Preconditions.checkArgument(
                mean.containsKey(variable) && m2.containsKey(variable),
                "Sample contains untracked variable %s", variable);
        }

        count += 1.0d;
        for (Map.Entry<VariableReference, DoubleTensor> entry : sample.entrySet()) {
            VariableReference variable = entry.getKey();
            DoubleTensor x = entry.getValue();
            DoubleTensor oldMean = mean.get(variable);

            // Welford update: delta is taken against the old mean, the M2 increment against the new one.
            DoubleTensor delta = (DoubleTensor) x.minus(oldMean);
            // NOTE(review): div2 is the method name as it appears in the decompiled class;
            // presumably element-wise division by a scalar.
            DoubleTensor newMean = (DoubleTensor) oldMean.plus(delta.div2(count));
            DoubleTensor deviationFromNewMean = (DoubleTensor) x.minus(newMean);
            DoubleTensor newM2 = (DoubleTensor) m2.get(variable).plus((DoubleTensor) delta.times(deviationFromNewMean));

            mean.put(variable, newMean);
            m2.put(variable, newM2);
        }
    }

    /**
     * Returns the current variance estimate of each tracked variable, computed as {@code m2 / count}.
     *
     * <p>NOTE(review): if the initial weight was 0 and no sample has been added, {@code count} is 0
     * and this divides by zero. Confirm how {@code VariableValues.divide} handles that (likely NaN).
     *
     * @return the result of {@code VariableValues.divide(m2, count)}
     */
    public Map<VariableReference, DoubleTensor> calculateCurrentVariance() {
        return VariableValues.divide(m2, count);
    }
}
