/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.algorithms.mcmc.nuts;

import com.google.common.base.Preconditions;
import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.algorithms.VariableReference;
import io.improbable.keanu.algorithms.mcmc.nuts.Potential;
import io.improbable.keanu.algorithms.mcmc.nuts.VariableValues;
import io.improbable.keanu.algorithms.mcmc.nuts.VarianceCalculator;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import java.util.HashMap;
import java.util.Map;

/**
 * A {@link Potential} whose per-variable variance is re-estimated from the positions handed to
 * {@link #update(Map)}.
 *
 * <p>Every sample is fed to two {@link VarianceCalculator}s. The "forward" calculator supplies the
 * variance currently in use; once every {@code adaptionWindowSize} samples it is replaced by the
 * "background" calculator and a new zero-initialised background calculator is started.
 *
 * <p>{@link #initialize(Map)} creates the calculators and the variance / standard deviation maps,
 * so it has to be called before any of the other methods.
 */
public class AdaptiveQuadraticPotential implements Potential {
    private final double initialWeight;
    private final double initialMean;
    private final double initialVariance;
    private final int adaptionWindowSize;
    private VarianceCalculator forwardVariance;
    private VarianceCalculator backgroundVariance;
    private long nSamples;
    private Map<VariableReference, DoubleTensor> variance;
    private Map<VariableReference, DoubleTensor> standardDeviation;

    /**
     * Stores the configuration; no tensors are allocated until {@link #initialize(Map)} runs.
     *
     * @param initialMean        value used to fill the starting mean handed to the forward calculator
     * @param initialVariance    value used to fill the starting variance
     * @param initialWeight      weight passed to the forward calculator together with the starting
     *                           mean and variance
     * @param adaptionWindowSize number of samples between swaps of the forward and background
     *                           calculators; must be greater than 1
     * @throws IllegalArgumentException if {@code adaptionWindowSize} is not greater than 1
     */
    public AdaptiveQuadraticPotential(double initialMean, double initialVariance, double initialWeight, int adaptionWindowSize) {
        Preconditions.checkArgument(adaptionWindowSize > 1, "Adapt window size must be greater than 1");
        this.initialMean = initialMean;
        this.initialVariance = initialVariance;
        this.initialWeight = initialWeight;
        this.adaptionWindowSize = adaptionWindowSize;
        this.nSamples = 0L;
    }

    /**
     * Builds the starting variance and both variance calculators from the given shapes.
     *
     * @param shapeLike map whose entries determine the variables and shapes to allocate for
     */
    @Override
    public void initialize(Map<VariableReference, DoubleTensor> shapeLike) {
        final Map<VariableReference, DoubleTensor> startVariance = VariableValues.withShape(initialVariance, shapeLike);
        final Map<VariableReference, DoubleTensor> startMean = VariableValues.withShape(initialMean, shapeLike);

        adoptVariance(startVariance);

        // The forward calculator starts from the configured prior; the background one starts empty.
        forwardVariance = new VarianceCalculator(startMean, startVariance, initialWeight);
        backgroundVariance = new VarianceCalculator(
            VariableValues.zeros(startMean),
            VariableValues.zeros(startMean),
            0.0
        );
    }

    /**
     * Records a new variance and refreshes the cached standard deviation (variance to the power 0.5).
     */
    private void adoptVariance(Map<VariableReference, DoubleTensor> newVariance) {
        variance = newVariance;
        standardDeviation = VariableValues.pow(variance, 0.5);
    }

    /**
     * Feeds a sample to both calculators and replaces the current variance with the forward
     * calculator's estimate.
     *
     * @param position the sample to add
     */
    @Override
    public void update(Map<VariableReference, DoubleTensor> position) {
        final boolean windowComplete = nSamples > 0L && nSamples % (long) adaptionWindowSize == 0L;

        if (windowComplete) {
            // Promote the background estimate and start collecting a fresh one from zero.
            forwardVariance = backgroundVariance;
            backgroundVariance = new VarianceCalculator(
                VariableValues.zeros(variance),
                VariableValues.zeros(variance),
                0.0
            );
        }

        forwardVariance.addSample(position);
        backgroundVariance.addSample(position);
        adoptVariance(forwardVariance.calculateCurrentVariance());

        nSamples++;
    }

    /**
     * Draws a momentum for every variable: a Gaussian sample with the shape of that variable's
     * standard deviation, divided in place by the standard deviation.
     *
     * @param random source of the Gaussian samples
     * @return a new map from each variable to its drawn momentum
     */
    @Override
    public Map<VariableReference, DoubleTensor> randomMomentum(KeanuRandom random) {
        final Map<VariableReference, DoubleTensor> momentum = new HashMap<>();

        for (Map.Entry<VariableReference, DoubleTensor> entry : standardDeviation.entrySet()) {
            final DoubleTensor sd = entry.getValue();
            final DoubleTensor gaussian = random.nextGaussian(sd.getShape());
            momentum.put(entry.getKey(), gaussian.divInPlace(sd));
        }

        return momentum;
    }

    /**
     * Computes the velocity as the current variance times the given momentum.
     *
     * @param momentum momentum per variable
     * @return the result of {@code VariableValues.times(variance, momentum)}
     */
    @Override
    public Map<VariableReference, DoubleTensor> getVelocity(Map<VariableReference, DoubleTensor> momentum) {
        return VariableValues.times(variance, momentum);
    }

    /**
     * Computes the kinetic energy as half the dot product of momentum and velocity.
     *
     * @param momentum momentum per variable
     * @param velocity velocity per variable
     * @return {@code 0.5 * dotProduct(momentum, velocity)}
     */
    @Override
    public double getKineticEnergy(Map<VariableReference, DoubleTensor> momentum, Map<VariableReference, DoubleTensor> velocity) {
        final double momentumDotVelocity = VariableValues.dotProduct(momentum, velocity);
        return 0.5 * momentumDotVelocity;
    }

    /**
     * @return the variance map currently held by this potential (the internal reference, not a copy)
     */
    public Map<VariableReference, DoubleTensor> getVariance() {
        return variance;
    }

    /**
     * @return the standard deviation map currently held by this potential (the internal reference,
     *         not a copy)
     */
    public Map<VariableReference, DoubleTensor> getStandardDeviation() {
        return standardDeviation;
    }
}

