package io.improbable.keanu.algorithms.mcmc.nuts;

import com.google.common.base.Preconditions;
import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.algorithms.NetworkSamples;
import io.improbable.keanu.algorithms.PosteriorSamplingAlgorithm;
import io.improbable.keanu.algorithms.ProbabilisticModel;
import io.improbable.keanu.algorithms.ProbabilisticModelWithGradient;
import io.improbable.keanu.algorithms.Statistics;
import io.improbable.keanu.algorithms.Variable;
import io.improbable.keanu.algorithms.VariableReference;
import io.improbable.keanu.algorithms.mcmc.NetworkSamplesGenerator;
import io.improbable.keanu.algorithms.mcmc.SamplingAlgorithm;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.util.status.StatusBar;
import io.improbable.keanu.vertices.ProbabilityCalculator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * No-U-Turn Sampler (NUTS) implementation of {@link PosteriorSamplingAlgorithm}.
 *
 * <p>Instances are created through {@link #builder()}. Sampling requires a
 * {@link ProbabilisticModelWithGradient}; the sampler starts from the current values of the
 * model's continuous latent variables. Step-size and potential adaptation are controlled by the
 * {@code adaptStepSizeEnabled}, {@code adaptPotentialEnabled} and {@code adaptCount} settings,
 * which are passed through to {@link NUTSSampler}.
 */
public class NUTS implements PosteriorSamplingAlgorithm {
    private final KeanuRandom random;
    private final double targetAcceptanceProb;
    private final long adaptCount;
    private final boolean adaptStepSizeEnabled;
    /** Explicit starting step size, or {@code null} to have one found automatically. */
    private final Double initialStepSize;
    private final Potential potential;
    private final boolean adaptPotentialEnabled;
    private final double maxEnergyChange;
    private final int maxTreeHeight;
    private final boolean saveStatistics;
    /** Shared by every sampler created from this instance (see {@link #setupSampler}). */
    private final Statistics statistics;

    /**
     * Metrics tracked by the {@link Statistics} object returned from {@link #getStatistics()}.
     * NOTE(review): presumably populated by {@link NUTSSampler} when {@code saveStatistics} is
     * enabled — confirm against NUTSSampler.
     */
    public enum Metrics {
        STEPSIZE,
        LOG_PROB,
        MEAN_TREE_ACCEPT,
        TREE_SIZE,
        DIVERGENT_SAMPLE
    }

    /**
     * Builder for {@link NUTS}. Every setter validates its argument eagerly so that bad
     * configuration fails at the call site rather than during sampling.
     */
    public static class NUTSBuilder {
        private KeanuRandom random = KeanuRandom.getDefaultRandom();
        private long adaptCount = 1000;
        private boolean adaptStepSizeEnabled = true;
        private Double initialStepSize = null;
        private Potential potential = new AdaptiveQuadraticPotential(0.0d, 1.0d, 10.0d, 100);
        private boolean adaptPotentialEnabled = true;
        private double targetAcceptanceProb = 0.8d;
        private double maxEnergyChange = 1000.0d;
        private int maxTreeHeight = 10;
        private boolean saveStatistics = false;

        /**
         * Sets the random source used by the sampler (default: {@link KeanuRandom#getDefaultRandom()}).
         *
         * @param keanuRandom the random source
         * @return this builder
         * @throws NullPointerException if {@code keanuRandom} is null
         */
        public NUTSBuilder random(KeanuRandom keanuRandom) {
            this.random = Preconditions.checkNotNull(keanuRandom, "Random must not be null");
            return this;
        }

        /**
         * Sets the acceptance probability the step-size adaptation aims for (default 0.8).
         *
         * @param targetAcceptanceProb a value in the closed range [0, 1]
         * @return this builder
         * @throws IllegalArgumentException if the value is outside [0, 1] or is NaN
         */
        public NUTSBuilder targetAcceptanceProb(double targetAcceptanceProb) {
            // Negated range check so that NaN (for which every comparison is false) is rejected too.
            if (!(targetAcceptanceProb >= 0.0d && targetAcceptanceProb <= 1.0d)) {
                throw new IllegalArgumentException("Target acceptance probability must be between 0.0 and 1.");
            }
            this.targetAcceptanceProb = targetAcceptanceProb;
            return this;
        }

        /**
         * Enables or disables step-size adaptation (default: enabled).
         *
         * @param adaptStepSizeEnabled whether to adapt the step size
         * @return this builder
         */
        public NUTSBuilder adaptStepSizeEnabled(boolean adaptStepSizeEnabled) {
            this.adaptStepSizeEnabled = adaptStepSizeEnabled;
            return this;
        }

        /**
         * Sets the adaptation count passed to {@link AdaptiveStepSize} and {@link NUTSSampler}
         * (default 1000).
         *
         * @param adaptCount a non-negative count
         * @return this builder
         * @throws IllegalArgumentException if {@code adaptCount} is negative
         */
        public NUTSBuilder adaptCount(long adaptCount) {
            if (adaptCount < 0) {
                throw new IllegalArgumentException("Adapt count must be greater than or equal to 0");
            }
            this.adaptCount = adaptCount;
            return this;
        }

        /**
         * Sets the starting step size. Passing {@code null} (the default) restores automatic
         * selection of a starting step size when the sampler is set up.
         *
         * @param initialStepSize a strictly positive step size, or {@code null} for automatic
         * @return this builder
         * @throws IllegalArgumentException if the value is non-null and not strictly positive
         *     (including NaN)
         */
        public NUTSBuilder initialStepSize(Double initialStepSize) {
            // null is a legitimate value (it is the default), so only validate non-null input.
            if (initialStepSize != null && !(initialStepSize > 0.0d)) {
                throw new IllegalArgumentException("Initial step size must be greater than 0");
            }
            this.initialStepSize = initialStepSize;
            return this;
        }

        /**
         * Enables or disables adaptation of the potential (default: enabled).
         *
         * @param adaptPotentialEnabled whether to adapt the potential
         * @return this builder
         */
        public NUTSBuilder adaptPotentialEnabled(boolean adaptPotentialEnabled) {
            this.adaptPotentialEnabled = adaptPotentialEnabled;
            return this;
        }

        /**
         * Sets the potential used by the sampler (default: an {@link AdaptiveQuadraticPotential}).
         *
         * @param potential the potential
         * @return this builder
         * @throws NullPointerException if {@code potential} is null
         */
        public NUTSBuilder potential(Potential potential) {
            this.potential = Preconditions.checkNotNull(potential, "Potential must not be null");
            return this;
        }

        /**
         * Sets the maximum energy change passed to {@link NUTSSampler} (default 1000).
         *
         * @param maxEnergyChange a strictly positive value
         * @return this builder
         * @throws IllegalArgumentException if the value is not strictly positive (including NaN)
         */
        public NUTSBuilder maxEnergyChange(double maxEnergyChange) {
            // Negated check so that NaN is rejected as well.
            if (!(maxEnergyChange > 0.0d)) {
                throw new IllegalArgumentException("Max energy change must be greater than 0");
            }
            this.maxEnergyChange = maxEnergyChange;
            return this;
        }

        /**
         * Sets the maximum tree height passed to {@link NUTSSampler} (default 10).
         *
         * @param maxTreeHeight a strictly positive height
         * @return this builder
         * @throws IllegalArgumentException if {@code maxTreeHeight} is not strictly positive
         */
        public NUTSBuilder maxTreeHeight(int maxTreeHeight) {
            if (maxTreeHeight <= 0) {
                throw new IllegalArgumentException("Max tree height must be greater than 0");
            }
            this.maxTreeHeight = maxTreeHeight;
            return this;
        }

        /**
         * Sets whether the sampler should record {@link Metrics} (default: false).
         *
         * @param saveStatistics whether to save statistics
         * @return this builder
         */
        public NUTSBuilder saveStatistics(boolean saveStatistics) {
            this.saveStatistics = saveStatistics;
            return this;
        }

        /**
         * Creates a {@link NUTS} from the current builder state.
         *
         * @return a new NUTS instance
         */
        public NUTS build() {
            return new NUTS(random, targetAcceptanceProb, adaptCount, adaptStepSizeEnabled, initialStepSize,
                potential, adaptPotentialEnabled, maxEnergyChange, maxTreeHeight, saveStatistics);
        }

        /** Describes every configurable field of this builder. */
        @Override
        public String toString() {
            return "NUTS.NUTSBuilder(random=" + random
                + ", adaptCount=" + adaptCount
                + ", targetAcceptanceProb=" + targetAcceptanceProb
                + ", adaptStepSizeEnabled=" + adaptStepSizeEnabled
                + ", initialStepSize=" + initialStepSize
                + ", potential=" + potential
                + ", adaptPotentialEnabled=" + adaptPotentialEnabled
                + ", maxEnergyChange=" + maxEnergyChange
                + ", maxTreeHeight=" + maxTreeHeight
                + ", saveStatistics=" + saveStatistics + ")";
        }
    }

    /**
     * Returns a new builder pre-populated with default settings.
     *
     * @return a fresh {@link NUTSBuilder}
     */
    public static NUTSBuilder builder() {
        return new NUTSBuilder();
    }

    /**
     * Draws {@code sampleCount} samples of the given variables.
     *
     * @param model a model that must be a {@link ProbabilisticModelWithGradient}
     * @param variablesToSampleFrom the variables whose values are recorded; must be non-empty
     * @param sampleCount number of samples to generate
     * @return the generated samples
     * @throws IllegalArgumentException if the model has no gradients, the variable list is empty,
     *     or the starting position has impossible log probability
     */
    @Override
    public NetworkSamples getPosteriorSamples(ProbabilisticModel model, List<? extends Variable> variablesToSampleFrom, int sampleCount) {
        return generatePosteriorSamples(model, variablesToSampleFrom).generate(sampleCount);
    }

    /**
     * Creates a generator that produces samples of the given variables on demand.
     *
     * @param model a model that must be a {@link ProbabilisticModelWithGradient}
     * @param variablesToSampleFrom the variables whose values are recorded; must be non-empty
     * @return a {@link NetworkSamplesGenerator} backed by a freshly set-up {@link NUTSSampler}
     * @throws IllegalArgumentException if the model has no gradients, the variable list is empty,
     *     or the starting position has impossible log probability
     */
    @Override
    public NetworkSamplesGenerator generatePosteriorSamples(ProbabilisticModel model, List<? extends Variable> variablesToSampleFrom) {
        Preconditions.checkArgument(model instanceof ProbabilisticModelWithGradient, "NUTS requires a model on which gradients can be calculated.");
        ProbabilisticModelWithGradient gradientModel = (ProbabilisticModelWithGradient) model;
        return new NetworkSamplesGenerator(setupSampler(gradientModel, variablesToSampleFrom), StatusBar::new);
    }

    /**
     * Builds a {@link NUTSSampler} whose starting position is the current value of every
     * continuous latent variable in the model.
     */
    private NUTSSampler setupSampler(ProbabilisticModelWithGradient model, List<? extends Variable> variablesToSampleFrom) {
        Preconditions.checkArgument(!variablesToSampleFrom.isEmpty(), "List of variables to sample from is empty");

        List<? extends Variable<DoubleTensor, ?>> latentVariables = model.getContinuousLatentVariables();

        // Snapshot of the current latent values: this is the sampler's starting position.
        Map<VariableReference, DoubleTensor> startingPosition = latentVariables.stream()
            .collect(Collectors.toMap(v -> v.getReference(), v -> v.getValue()));

        double startingLogProb = model.logProb(startingPosition);
        Preconditions.checkArgument(!ProbabilityCalculator.isImpossibleLogProb(startingLogProb), "Sampler starting position is invalid. Please start from a non-zero probability position.");

        // NOTE(review): logProbGradients() takes no position and is called right after
        // logProb(startingPosition); it presumably reports gradients at that last-evaluated
        // position, so keep this call ordering.
        Map<VariableReference, DoubleTensor> startingGradients = model.logProbGradients();
        Map<VariableReference, ?> startingSample = SamplingAlgorithm.takeSample(variablesToSampleFrom);

        // NOTE(review): 0.25 is the value handed to findStartingStepSize; its exact meaning
        // (presumably an initial guess) is defined in AdaptiveStepSize — confirm there.
        double startingStepSize = initialStepSize == null
            ? AdaptiveStepSize.findStartingStepSize(0.25d, latentVariables)
            : initialStepSize;
        AdaptiveStepSize adaptiveStepSize = new AdaptiveStepSize(startingStepSize, targetAcceptanceProb, adaptCount);

        // The potential instance belongs to this NUTS object, so every sampler set up from it
        // re-initializes and shares the same potential.
        potential.initialize(startingPosition);

        Proposal startingProposal = new Proposal(startingPosition, startingGradients, startingSample, startingLogProb);
        return new NUTSSampler(variablesToSampleFrom, model, adaptPotentialEnabled, potential, adaptStepSizeEnabled,
            adaptiveStepSize, adaptCount, maxEnergyChange, maxTreeHeight, startingProposal, random, statistics,
            saveStatistics);
    }

    /**
     * Returns the statistics object passed to every sampler created by this instance.
     *
     * @return the shared {@link Statistics}, keyed by {@link Metrics}
     */
    public Statistics getStatistics() {
        return statistics;
    }

    private NUTS(KeanuRandom random, double targetAcceptanceProb, long adaptCount, boolean adaptStepSizeEnabled,
                 Double initialStepSize, Potential potential, boolean adaptPotentialEnabled,
                 double maxEnergyChange, int maxTreeHeight, boolean saveStatistics) {
        this.statistics = new Statistics(Metrics.values());
        this.random = random;
        this.targetAcceptanceProb = targetAcceptanceProb;
        this.adaptCount = adaptCount;
        this.adaptStepSizeEnabled = adaptStepSizeEnabled;
        this.initialStepSize = initialStepSize;
        this.potential = potential;
        this.adaptPotentialEnabled = adaptPotentialEnabled;
        this.maxEnergyChange = maxEnergyChange;
        this.maxTreeHeight = maxTreeHeight;
        this.saveStatistics = saveStatistics;
    }

    /** @return the random source used by samplers created from this instance */
    public KeanuRandom getRandom() {
        return random;
    }

    /** @return the target acceptance probability used for step-size adaptation */
    public double getTargetAcceptanceProb() {
        return targetAcceptanceProb;
    }

    /** @return the adaptation count passed to the step-size adapter and sampler */
    public long getAdaptCount() {
        return adaptCount;
    }
}
