/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.algorithms.mcmc.nuts;

import io.improbable.keanu.algorithms.Statistics;
import io.improbable.keanu.algorithms.Variable;
import io.improbable.keanu.algorithms.mcmc.nuts.NUTS;
import io.improbable.keanu.algorithms.mcmc.nuts.Tree;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import java.util.List;

/**
 * Tunes the NUTS leapfrog step size with a dual-averaging scheme (Hoffman &amp; Gelman 2014,
 * Algorithm 6).
 *
 * <p>For the first {@code adaptCount} calls to {@link #adaptStepSize(Tree)} the step size is
 * updated so that each tree's mean Metropolis acceptance probability is pushed towards the target
 * {@code sigma}. On call {@code adaptCount + 1} the step size is frozen at
 * {@code exp(logStepSizeBar)} (the running average of the log step size), and it is left unchanged
 * on every later call.
 *
 * <p>Instances hold mutable state and are not synchronized.
 */
class AdaptiveStepSize {

    /** Offset added to the iteration count in the averaging weight 1 / (m + t0); damps early iterations. */
    private static final double T0 = 10.0;
    /** Scale of the shrinkage of the log step size towards {@code mu}. */
    private static final double GAMMA = 0.05;
    /** Exponent in the weight m^-kappa used when averaging the log step size. */
    private static final double KAPPA = 0.75;
    /** Multiplier on the initial step size that defines the shrinkage target mu = log(10 * stepSize). */
    private static final double MU_SCALE = 10.0;

    /** Point the log step size is shrunk towards. */
    private final double mu;
    /** Number of calls to {@link #adaptStepSize(Tree)} during which the step size is tuned. */
    private final long adaptCount;
    /** Target mean acceptance probability. */
    private final double sigma;

    private double stepSize;
    /** Running weighted average of (sigma - acceptRate). */
    private double hBar;
    /** Mean Metropolis acceptance probability of the most recent tree passed to adaptStepSize. */
    private double acceptRate;
    /** Running weighted average of the log step size; used as the final step size. */
    private double logStepSizeBar;
    private double logStepSize;
    /** 1-based index of the next call to adaptStepSize. */
    private long stepNum;

    /**
     * @param stepSize   initial step size; its logarithm is taken, so it is presumably expected to be
     *                   positive (NOTE(review): not validated here — confirm callers guarantee this)
     * @param sigma      target mean acceptance probability
     * @param adaptCount number of {@link #adaptStepSize(Tree)} calls during which the step size adapts
     */
    AdaptiveStepSize(double stepSize, double sigma, long adaptCount) {
        this.sigma = sigma;
        this.stepSize = stepSize;
        this.hBar = 0.0;
        this.logStepSize = Math.log(stepSize);
        this.logStepSizeBar = this.logStepSize;
        this.adaptCount = adaptCount;
        this.mu = Math.log(MU_SCALE * stepSize);
        this.stepNum = 1L;
    }

    /**
     * Computes an initial step size that shrinks with the total number of sampled dimensions.
     *
     * <p>Returns {@code stepScale / N^(1/4)}, where {@code N} is the summed length of every
     * variable's current value. If {@code N} is 0 (no variables, or only empty tensors) this divides
     * by zero and yields an infinite (or NaN) result.
     *
     * @param stepScale scale of the step size before dimension correction
     * @param variables variables whose current values determine {@code N}
     * @return the starting step size
     */
    public static double findStartingStepSize(double stepScale, List<? extends Variable<DoubleTensor, ?>> variables) {
        long dimensionCount = variables.stream()
            .mapToLong(v -> ((DoubleTensor) v.getValue()).getLength())
            .sum();
        return stepScale / Math.pow(dimensionCount, 0.25);
    }

    /**
     * Records the tree's mean acceptance probability and, if still adapting, updates the step size.
     *
     * <p>Calls 1..adaptCount apply a dual-averaging update. Call adaptCount + 1 fixes the step size
     * to {@code exp(logStepSizeBar)}. Later calls leave the step size unchanged. The acceptance
     * statistic reported by {@link #save(Statistics)} is refreshed on every call.
     *
     * @param tree the most recently built tree
     * @return the step size to use next
     */
    public double adaptStepSize(Tree tree) {
        // Refresh on every call, not only while adapting, so the MEAN_TREE_ACCEPT metric written by
        // save() tracks the current tree during sampling instead of freezing at its last
        // adaptation-phase value.
        acceptRate = meanAcceptanceProbability(tree);
        if (stepNum <= adaptCount) {
            stepSize = Math.exp(updateLogStepSize());
        } else if (stepNum == adaptCount + 1L) {
            // Adaptation finished: switch to the averaged step size and keep it from now on.
            stepSize = Math.exp(logStepSizeBar);
        }
        ++stepNum;
        return stepSize;
    }

    /**
     * Returns the tree's summed Metropolis acceptance probability divided by its tree size.
     * NOTE(review): a tree size of 0 would make this non-finite; presumably trees always have at
     * least one node — confirm.
     */
    private static double meanAcceptanceProbability(Tree tree) {
        double summedAcceptance = tree.getSumMetropolisAcceptanceProbability();
        double treeSize = tree.getTreeSize();
        return summedAcceptance / treeSize;
    }

    /**
     * Applies one dual-averaging update using the already-stored {@link #acceptRate}, and updates
     * the running average of the log step size.
     *
     * @return the new (non-averaged) log step size
     */
    private double updateLogStepSize() {
        double m = stepNum;
        double w = 1.0 / (m + T0);
        hBar = (1.0 - w) * hBar + w * (sigma - acceptRate);
        logStepSize = mu - Math.sqrt(m) / GAMMA * hBar;
        // Weight m^-kappa decays towards zero, so later iterates change the average less and less.
        double averagingWeight = Math.pow(m, -KAPPA);
        logStepSizeBar = averagingWeight * logStepSize + (1.0 - averagingWeight) * logStepSizeBar;
        return logStepSize;
    }

    /** @return the current step size */
    public double getStepSize() {
        return stepSize;
    }

    /**
     * Stores the current step size under {@code STEPSIZE} and the latest tree's mean acceptance
     * probability under {@code MEAN_TREE_ACCEPT}.
     *
     * @param statistics sink the metrics are written to
     */
    public void save(Statistics statistics) {
        statistics.store(NUTS.Metrics.STEPSIZE, stepSize);
        statistics.store(NUTS.Metrics.MEAN_TREE_ACCEPT, acceptRate);
    }
}

