/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.algorithms.mcmc.nuts;

import com.google.common.base.Preconditions;
import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.algorithms.NetworkSamples;
import io.improbable.keanu.algorithms.PosteriorSamplingAlgorithm;
import io.improbable.keanu.algorithms.ProbabilisticModel;
import io.improbable.keanu.algorithms.ProbabilisticModelWithGradient;
import io.improbable.keanu.algorithms.Statistics;
import io.improbable.keanu.algorithms.Variable;
import io.improbable.keanu.algorithms.VariableReference;
import io.improbable.keanu.algorithms.mcmc.NetworkSamplesGenerator;
import io.improbable.keanu.algorithms.mcmc.SamplingAlgorithm;
import io.improbable.keanu.algorithms.mcmc.nuts.AdaptiveQuadraticPotential;
import io.improbable.keanu.algorithms.mcmc.nuts.AdaptiveStepSize;
import io.improbable.keanu.algorithms.mcmc.nuts.NUTSSampler;
import io.improbable.keanu.algorithms.mcmc.nuts.Potential;
import io.improbable.keanu.algorithms.mcmc.nuts.Proposal;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.util.status.StatusBar;
import io.improbable.keanu.vertices.ProbabilityCalculator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * No-U-Turn Sampler (NUTS): a gradient-based MCMC {@link PosteriorSamplingAlgorithm}.
 *
 * <p>Instances are created through {@link #builder()}. Each call to
 * {@link #generatePosteriorSamples(ProbabilisticModel, List)} creates a new {@link NUTSSampler} whose
 * starting position is the current value of every continuous latent variable of the model.
 *
 * <p>NOTE(review): the same {@link Potential} and {@link Statistics} instances are handed to every
 * sampler created by one {@code NUTS} instance, and {@code potential.initialize(...)} is called each
 * time a sampler is set up. Reusing one {@code NUTS} instance for several runs therefore shares that
 * state — confirm this is intended.
 */
public class NUTS implements PosteriorSamplingAlgorithm {
    private final KeanuRandom random;
    private final double targetAcceptanceProb;
    private final long adaptCount;
    private final boolean adaptStepSizeEnabled;
    /** Starting step size, or {@code null} to compute one automatically when a sampler is set up. */
    private final Double initialStepSize;
    private final Potential potential;
    private final boolean adaptPotentialEnabled;
    private final double maxEnergyChange;
    private final int maxTreeHeight;
    private final boolean saveStatistics;
    /** Per-metric statistics, keyed by {@link Metrics}; passed to every sampler this instance creates. */
    private final Statistics statistics = new Statistics(Metrics.values());

    /**
     * Creates a builder pre-populated with this algorithm's default settings.
     *
     * @return a new {@link NUTSBuilder}
     */
    public static NUTSBuilder builder() {
        return new NUTSBuilder();
    }

    /**
     * Draws {@code sampleCount} posterior samples of the given variables.
     *
     * @param model                 the model to sample from; must be a {@link ProbabilisticModelWithGradient}
     * @param variablesToSampleFrom the variables whose values are recorded in each sample; must not be empty
     * @param sampleCount           number of samples to generate
     * @return the generated samples
     * @throws IllegalArgumentException if the model has no gradients, the variable list is empty, or the
     *                                  starting position has impossible log probability
     */
    @Override
    public NetworkSamples getPosteriorSamples(ProbabilisticModel model, List<? extends Variable> variablesToSampleFrom, int sampleCount) {
        return generatePosteriorSamples(model, variablesToSampleFrom).generate(sampleCount);
    }

    /**
     * Creates a lazy sample generator backed by a freshly set-up {@link NUTSSampler}.
     *
     * @param model         the model to sample from; must be a {@link ProbabilisticModelWithGradient}
     * @param fromVariables the variables whose values are recorded in each sample; must not be empty
     * @return a generator that produces samples on demand, reporting progress through a {@link StatusBar}
     * @throws IllegalArgumentException if the model has no gradients, the variable list is empty, or the
     *                                  starting position has impossible log probability
     */
    @Override
    public NetworkSamplesGenerator generatePosteriorSamples(ProbabilisticModel model, List<? extends Variable> fromVariables) {
        Preconditions.checkArgument(
            model instanceof ProbabilisticModelWithGradient,
            "NUTS requires a model on which gradients can be calculated.");
        NUTSSampler sampler = setupSampler((ProbabilisticModelWithGradient) model, fromVariables);
        return new NetworkSamplesGenerator(sampler, StatusBar::new);
    }

    /**
     * Builds a sampler positioned at the model's current continuous latent-variable values.
     *
     * <p>The starting log probability and gradients are evaluated once here and packaged into the
     * initial {@link Proposal}.
     */
    private NUTSSampler setupSampler(ProbabilisticModelWithGradient model, List<? extends Variable> sampleFromVariables) {
        Preconditions.checkArgument(!sampleFromVariables.isEmpty(), "List of variables to sample from is empty");

        List<Variable<DoubleTensor, ?>> latentVariables = model.getContinuousLatentVariables();
        Map<VariableReference, DoubleTensor> position = latentVariables.stream()
            .collect(Collectors.toMap(Variable::getReference, Variable::getValue));

        double initialLogOfMasterP = model.logProb(position);
        Preconditions.checkArgument(
            !ProbabilityCalculator.isImpossibleLogProb(initialLogOfMasterP),
            "Sampler starting position is invalid. Please start from a non-zero probability position.");

        // NOTE(review): logProbGradients() takes no position argument; presumably it returns the
        // gradients at the position evaluated by the logProb call above, so keep this ordering — confirm.
        Map<VariableReference, DoubleTensor> gradient = model.logProbGradients();
        Map<VariableReference, ?> startingSample = SamplingAlgorithm.takeSample(sampleFromVariables);

        // A null initial step size means "search for one"; 0.25 is the value passed to that search
        // (presumably its first guess — confirm against AdaptiveStepSize.findStartingStepSize).
        double startingStepSize = initialStepSize == null
            ? AdaptiveStepSize.findStartingStepSize(0.25, latentVariables)
            : initialStepSize;
        AdaptiveStepSize stepSize = new AdaptiveStepSize(startingStepSize, targetAcceptanceProb, adaptCount);

        // The potential is shared state of this NUTS instance; it is (re)initialised for every sampler.
        potential.initialize(position);
        Proposal initialProposal = new Proposal(position, gradient, startingSample, initialLogOfMasterP);

        return new NUTSSampler(
            sampleFromVariables,
            model,
            adaptPotentialEnabled,
            potential,
            adaptStepSizeEnabled,
            stepSize,
            adaptCount,
            maxEnergyChange,
            maxTreeHeight,
            initialProposal,
            random,
            statistics,
            saveStatistics);
    }

    /**
     * Returns the statistics object shared by every sampler created by this instance.
     *
     * @return the statistics, keyed by {@link Metrics}
     */
    public Statistics getStatistics() {
        return statistics;
    }

    private NUTS(KeanuRandom random, double targetAcceptanceProb, long adaptCount, boolean adaptStepSizeEnabled, Double initialStepSize, Potential potential, boolean adaptPotentialEnabled, double maxEnergyChange, int maxTreeHeight, boolean saveStatistics) {
        this.random = random;
        this.targetAcceptanceProb = targetAcceptanceProb;
        this.adaptCount = adaptCount;
        this.adaptStepSizeEnabled = adaptStepSizeEnabled;
        this.initialStepSize = initialStepSize;
        this.potential = potential;
        this.adaptPotentialEnabled = adaptPotentialEnabled;
        this.maxEnergyChange = maxEnergyChange;
        this.maxTreeHeight = maxTreeHeight;
        this.saveStatistics = saveStatistics;
    }

    /** @return the random source passed to every sampler */
    public KeanuRandom getRandom() {
        return random;
    }

    /** @return the target acceptance probability used by step-size adaptation */
    public double getTargetAcceptanceProb() {
        return targetAcceptanceProb;
    }

    /** @return the adapt count passed to the step-size adaptation and the sampler */
    public long getAdaptCount() {
        return adaptCount;
    }

    /**
     * Fluent builder for {@link NUTS}.
     *
     * <p>Defaults: {@link KeanuRandom#getDefaultRandom()}, adaptCount 1000, step-size adaptation on,
     * automatically chosen initial step size, {@code AdaptiveQuadraticPotential(0.0, 1.0, 10.0, 100)},
     * potential adaptation on, targetAcceptanceProb 0.8, maxEnergyChange 1000.0, maxTreeHeight 10,
     * statistics saving off.
     */
    public static class NUTSBuilder {
        private KeanuRandom random = KeanuRandom.getDefaultRandom();
        private long adaptCount = 1000L;
        private boolean adaptStepSizeEnabled = true;
        private Double initialStepSize = null;
        private Potential potential = new AdaptiveQuadraticPotential(0.0, 1.0, 10.0, 100);
        private boolean adaptPotentialEnabled = true;
        private double targetAcceptanceProb = 0.8;
        private double maxEnergyChange = 1000.0;
        private int maxTreeHeight = 10;
        private boolean saveStatistics = false;

        /**
         * Sets the random source passed to the sampler.
         *
         * @param random the random source
         * @return this builder
         */
        public NUTSBuilder random(KeanuRandom random) {
            this.random = random;
            return this;
        }

        /**
         * Sets the target acceptance probability used by step-size adaptation.
         *
         * @param targetAcceptanceProb a value in {@code [0.0, 1.0]}
         * @return this builder
         * @throws IllegalArgumentException if the value is outside {@code [0.0, 1.0]} or is NaN
         */
        public NUTSBuilder targetAcceptanceProb(double targetAcceptanceProb) {
            // Negated range check so NaN (for which every comparison is false) is rejected as well.
            if (!(targetAcceptanceProb >= 0.0 && targetAcceptanceProb <= 1.0)) {
                throw new IllegalArgumentException("Target acceptance probability must be between 0.0 and 1.");
            }
            this.targetAcceptanceProb = targetAcceptanceProb;
            return this;
        }

        /**
         * Enables or disables step-size adaptation.
         *
         * @param adaptEnabled whether the step size is adapted
         * @return this builder
         */
        public NUTSBuilder adaptStepSizeEnabled(boolean adaptEnabled) {
            this.adaptStepSizeEnabled = adaptEnabled;
            return this;
        }

        /**
         * Sets the adapt count passed to step-size adaptation and the sampler.
         *
         * @param adaptCount a non-negative count
         * @return this builder
         * @throws IllegalArgumentException if {@code adaptCount} is negative
         */
        public NUTSBuilder adaptCount(long adaptCount) {
            if (adaptCount < 0L) {
                throw new IllegalArgumentException("Adapt count must be greater than or equal to 0");
            }
            this.adaptCount = adaptCount;
            return this;
        }

        /**
         * Sets the step size the sampler starts from.
         *
         * @param initialStepSize a strictly positive step size, or {@code null} (the default) to have a
         *                        starting step size computed automatically when the sampler is set up
         * @return this builder
         * @throws IllegalArgumentException if {@code initialStepSize} is non-null and not greater than 0
         *                                  (NaN included)
         */
        public NUTSBuilder initialStepSize(Double initialStepSize) {
            // null is a legitimate value (it is the default), so it must not be auto-unboxed here:
            // doing so would throw a NullPointerException instead of restoring automatic selection.
            if (initialStepSize != null && !(initialStepSize > 0.0)) {
                throw new IllegalArgumentException("Initial step size must be greater than 0");
            }
            this.initialStepSize = initialStepSize;
            return this;
        }

        /**
         * Enables or disables adaptation of the potential.
         *
         * @param adaptPotentialEnabled whether the potential is adapted
         * @return this builder
         */
        public NUTSBuilder adaptPotentialEnabled(boolean adaptPotentialEnabled) {
            this.adaptPotentialEnabled = adaptPotentialEnabled;
            return this;
        }

        /**
         * Sets the potential used by the sampler.
         *
         * @param potential the potential; must not be null
         * @return this builder
         * @throws NullPointerException if {@code potential} is null (it is dereferenced unconditionally
         *                              when a sampler is set up)
         */
        public NUTSBuilder potential(Potential potential) {
            this.potential = Preconditions.checkNotNull(potential, "Potential must not be null");
            return this;
        }

        /**
         * Sets the maximum energy change passed to the sampler.
         *
         * @param maxEnergyChange a strictly positive value
         * @return this builder
         * @throws IllegalArgumentException if the value is not greater than 0 (NaN included)
         */
        public NUTSBuilder maxEnergyChange(double maxEnergyChange) {
            if (!(maxEnergyChange > 0.0)) {
                throw new IllegalArgumentException("Max energy change must be greater than 0");
            }
            this.maxEnergyChange = maxEnergyChange;
            return this;
        }

        /**
         * Sets the maximum tree height passed to the sampler.
         *
         * @param maxTreeHeight a strictly positive height
         * @return this builder
         * @throws IllegalArgumentException if {@code maxTreeHeight} is not greater than 0
         */
        public NUTSBuilder maxTreeHeight(int maxTreeHeight) {
            if (maxTreeHeight <= 0) {
                throw new IllegalArgumentException("Max tree height must be greater than 0");
            }
            this.maxTreeHeight = maxTreeHeight;
            return this;
        }

        /**
         * Sets whether the sampler saves statistics.
         *
         * @param saveStatistics whether statistics are saved
         * @return this builder
         */
        public NUTSBuilder saveStatistics(boolean saveStatistics) {
            this.saveStatistics = saveStatistics;
            return this;
        }

        /**
         * Creates a {@link NUTS} instance from the current settings.
         *
         * <p>The configured {@link Potential} instance is passed as-is, so several {@code NUTS} objects
         * built from the same builder share it.
         *
         * @return a new {@link NUTS}
         */
        public NUTS build() {
            return new NUTS(random, targetAcceptanceProb, adaptCount, adaptStepSizeEnabled, initialStepSize,
                potential, adaptPotentialEnabled, maxEnergyChange, maxTreeHeight, saveStatistics);
        }

        @Override
        public String toString() {
            return "NUTS.NUTSBuilder(random=" + random
                + ", adaptCount=" + adaptCount
                + ", targetAcceptanceProb=" + targetAcceptanceProb
                + ", adaptStepSizeEnabled=" + adaptStepSizeEnabled
                + ", initialStepSize=" + initialStepSize
                + ", potential=" + potential
                + ", adaptPotentialEnabled=" + adaptPotentialEnabled
                + ", maxEnergyChange=" + maxEnergyChange
                + ", maxTreeHeight=" + maxTreeHeight
                + ", saveStatistics=" + saveStatistics + ")";
        }
    }

    /** Keys of the values recorded in {@link NUTS#getStatistics()}. */
    public enum Metrics {
        STEPSIZE,
        LOG_PROB,
        MEAN_TREE_ACCEPT,
        TREE_SIZE,
        DIVERGENT_SAMPLE
    }
}

