/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.algorithms.mcmc.nuts;

import com.google.common.base.Preconditions;
import io.improbable.keanu.algorithms.VariableReference;
import io.improbable.keanu.algorithms.mcmc.nuts.VariableValues;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import java.util.HashMap;
import java.util.Map;

/**
 * Incrementally tracks a per-variable running mean and population variance of a stream of
 * samples using Welford's online algorithm.
 *
 * <p>The estimate can be seeded with a prior mean and variance that are treated as carrying
 * {@code initialWeight} pseudo-samples' worth of weight. The class has no synchronization;
 * NOTE(review): confine each instance to a single thread.
 */
public class VarianceCalculator {
    /** Total weight seen so far: the initial pseudo-count plus 1 for every {@link #addSample} call. */
    private double count;
    /** Running mean per variable. Private copy, so the caller's map is never mutated. */
    private final Map<VariableReference, DoubleTensor> mean;
    /** Running sum of squared deviations from the mean (Welford's M2) per variable. Private copy. */
    private final Map<VariableReference, DoubleTensor> M2;

    /**
     * Creates a calculator seeded with a prior estimate.
     *
     * @param initialMean     starting mean for each variable; copied, not retained
     * @param initialVariance starting variance for each variable; scaled by {@code initialWeight}
     *                        to form the initial M2
     * @param initialWeight   how many pseudo-samples the prior counts as; must be {@code >= 0}
     * @throws IllegalArgumentException if {@code initialWeight} is negative or NaN
     */
    public VarianceCalculator(Map<VariableReference, DoubleTensor> initialMean, Map<VariableReference, DoubleTensor> initialVariance, double initialWeight) {
        // NaN also fails this comparison, so it is rejected along with negative weights.
        Preconditions.checkArgument(initialWeight >= 0.0, "Initial weight must be greater than or equal to 0");
        this.count = initialWeight;
        // Copy so that addSample's in-place updates never leak into, or fail on
        // (e.g. an immutable map), the caller's map.
        this.mean = new HashMap<>(initialMean);
        // M2 = variance * weight, so that for a positive weight M2 / count gives back the
        // initial variance before any sample is added (assuming VariableValues.times/divide
        // are elementwise by a scalar -- TODO confirm). Copied for the same reason as mean.
        this.M2 = new HashMap<>(VariableValues.times(initialVariance, this.count));
    }

    /**
     * Folds one sample into the running mean and M2 using Welford's update. The count grows by 1
     * per call, regardless of how many variables the sample contains.
     *
     * @param sampleForLatents sample value per variable; every key must have been present in both
     *                         the initial mean and the initial variance
     * @throws IllegalArgumentException if a sampled variable has no running mean or M2
     */
    public void addSample(Map<VariableReference, DoubleTensor> sampleForLatents) {
        this.count += 1.0;
        for (Map.Entry<VariableReference, DoubleTensor> sampleForVariable : sampleForLatents.entrySet()) {
            VariableReference v = sampleForVariable.getKey();
            DoubleTensor sample = sampleForVariable.getValue();
            DoubleTensor oldMean = this.mean.get(v);
            DoubleTensor oldM2 = this.M2.get(v);
            // Fail with a descriptive message instead of an opaque NPE inside tensor arithmetic.
            Preconditions.checkArgument(oldMean != null && oldM2 != null,
                "No initial mean/variance for variable %s", v);
            // Welford: delta uses the old mean, delta2 the updated mean; their product is the
            // numerically stable increment to the sum of squared deviations.
            DoubleTensor delta = sample.minus(oldMean);
            DoubleTensor newMean = oldMean.plus(delta.div(this.count));
            DoubleTensor delta2 = sample.minus(newMean);
            DoubleTensor newM2 = oldM2.plus(delta.times(delta2));
            this.mean.put(v, newMean);
            this.M2.put(v, newM2);
        }
    }

    /**
     * Returns the current population variance estimate (M2 divided by the total weight, not by
     * weight - 1) for each tracked variable.
     *
     * <p>If no weight has accumulated (initial weight 0 and no samples), this divides by zero.
     * NOTE(review): presumably yields NaN elements; confirm against {@code VariableValues.divide}.
     *
     * @return variance per variable, as produced by {@code VariableValues.divide}
     */
    public Map<VariableReference, DoubleTensor> calculateCurrentVariance() {
        return VariableValues.divide(this.M2, this.count);
    }
}

