package io.improbable.keanu.algorithms.mcmc.nuts;

import io.improbable.keanu.algorithms.Statistics;
import io.improbable.keanu.algorithms.Variable;
import io.improbable.keanu.algorithms.mcmc.nuts.NUTS;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import java.util.List;

/* JADX INFO: Access modifiers changed from: package-private */
/* loaded from: input_file:io/improbable/keanu/algorithms/mcmc/nuts/AdaptiveStepSize.class */
/**
 * Adapts the leapfrog step size of the No-U-Turn Sampler using the dual-averaging scheme of
 * Hoffman &amp; Gelman (2014), "The No-U-Turn Sampler", Algorithm 5.
 *
 * <p>For the first {@code adaptCount} calls to {@link #adaptStepSize(Tree)} the step size is
 * tuned so that the mean Metropolis acceptance probability of each tree moves towards the
 * target acceptance probability. On the following call the step size is frozen at the
 * exponential of the running weighted average of the log step sizes, and it stays there afterwards.
 *
 * <p>Instances are mutable and hold no synchronization.
 */
public class AdaptiveStepSize {

    /** Offset that damps the dual-averaging updates during early iterations (t0). */
    private static final double T0 = 10.0d;
    /** Controls how strongly the log step size is shrunk towards {@code mu} (gamma). */
    private static final double GAMMA = 0.05d;
    /** Decay exponent of the weights used to average the log step size (kappa). */
    private static final double KAPPA = 0.75d;
    /**
     * Multiplier applied to the initial step size to obtain the shrinkage point
     * {@code mu = log(MU_STEP_SIZE_FACTOR * initialStepSize)}. Numerically equal to {@link #T0},
     * but a separate constant in the dual-averaging scheme.
     */
    private static final double MU_STEP_SIZE_FACTOR = 10.0d;

    /** Point the log step size is shrunk towards during adaptation. */
    private final double mu;
    /** Number of calls to {@link #adaptStepSize(Tree)} during which the step size is tuned. */
    private final long adaptCount;
    /** Target mean Metropolis acceptance probability per tree. */
    private final double targetAcceptanceProb;

    private double stepSize;
    /** Mean Metropolis acceptance probability of the most recently seen tree. */
    private double acceptRate;
    /** Running weighted average of the log step size; becomes the final step size. */
    private double logStepSizeBar;
    private double logStepSize;
    /** Running average of (target acceptance - observed acceptance). */
    private double hBar = 0.0d;
    /** 1-based index of the next call to {@link #adaptStepSize(Tree)}. */
    private long stepNum = 1;

    /**
     * @param initialStepSize      step size to start from; also defines {@code mu}
     * @param targetAcceptanceProb mean Metropolis acceptance probability the adaptation aims for
     * @param adaptCount           number of {@link #adaptStepSize(Tree)} calls that tune the step size
     */
    public AdaptiveStepSize(double initialStepSize, double targetAcceptanceProb, long adaptCount) {
        this.targetAcceptanceProb = targetAcceptanceProb;
        this.stepSize = initialStepSize;
        this.logStepSize = Math.log(initialStepSize);
        // With stepNum == 1 the first averaging weight is 1, so this initial value is overwritten
        // by the first adaptation; it only matters when adaptCount == 0.
        this.logStepSizeBar = this.logStepSize;
        this.adaptCount = adaptCount;
        this.mu = Math.log(MU_STEP_SIZE_FACTOR * initialStepSize);
    }

    /**
     * Scales a starting step size down by the fourth root of the total number of scalar elements
     * held by the given variables' values.
     *
     * <p>NOTE(review): if the variables hold no elements in total, this divides by zero and
     * returns an infinite (or NaN) step size.
     *
     * @param startingStepSize unscaled starting step size
     * @param variables        variables whose value lengths are summed
     * @return {@code startingStepSize / totalLength^(1/4)}
     */
    public static double findStartingStepSize(double startingStepSize,
                                              List<? extends Variable<DoubleTensor, ?>> variables) {
        long totalLength = variables.stream()
            .mapToLong(variable -> variable.getValue().getLength())
            .sum();
        return startingStepSize / Math.pow(totalLength, 0.25d);
    }

    /**
     * Records the acceptance rate of {@code tree} and returns the step size to use next.
     *
     * <p>During the first {@code adaptCount} calls the step size is updated by dual averaging.
     * On call {@code adaptCount + 1} it is fixed to the averaged step size, and it is left
     * unchanged on every call after that.
     *
     * @param tree the most recently built tree
     * @return the current step size
     */
    public double adaptStepSize(Tree tree) {
        // Computed on every call (not only during adaptation) so that save() reports the
        // acceptance rate of the current tree rather than a stale warm-up value.
        double sumAcceptanceProbability = tree.getSumMetropolisAcceptanceProbability();
        double treeSize = tree.getTreeSize();
        this.acceptRate = sumAcceptanceProbability / treeSize;

        if (this.stepNum <= this.adaptCount) {
            this.stepSize = Math.exp(updateLogStepSize(this.acceptRate));
        } else if (this.stepNum == this.adaptCount + 1) {
            this.stepSize = Math.exp(this.logStepSizeBar);
        }
        this.stepNum++;
        return this.stepSize;
    }

    /** Performs one dual-averaging update and returns the new (non-averaged) log step size. */
    private double updateLogStepSize(double acceptanceRate) {
        double learningRate = 1.0d / (this.stepNum + T0);
        this.hBar = ((1.0d - learningRate) * this.hBar)
            + (learningRate * (this.targetAcceptanceProb - acceptanceRate));
        this.logStepSize = this.mu - ((Math.sqrt(this.stepNum) / GAMMA) * this.hBar);
        double averagingWeight = Math.pow(this.stepNum, -KAPPA);
        this.logStepSizeBar = (averagingWeight * this.logStepSize)
            + ((1.0d - averagingWeight) * this.logStepSizeBar);
        return this.logStepSize;
    }

    /** @return the current step size */
    public double getStepSize() {
        return this.stepSize;
    }

    /**
     * Stores the current step size and the mean acceptance probability of the most recent tree.
     *
     * @param statistics destination for the metrics
     */
    public void save(Statistics statistics) {
        statistics.store(NUTS.Metrics.STEPSIZE, Double.valueOf(this.stepSize));
        statistics.store(NUTS.Metrics.MEAN_TREE_ACCEPT, Double.valueOf(this.acceptRate));
    }
}
