/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.algorithms.mcmc.nuts;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.algorithms.NetworkSample;
import io.improbable.keanu.algorithms.ProbabilisticModelWithGradient;
import io.improbable.keanu.algorithms.Statistics;
import io.improbable.keanu.algorithms.Variable;
import io.improbable.keanu.algorithms.VariableReference;
import io.improbable.keanu.algorithms.mcmc.SamplingAlgorithm;
import io.improbable.keanu.algorithms.mcmc.nuts.AdaptiveStepSize;
import io.improbable.keanu.algorithms.mcmc.nuts.LeapfrogIntegrator;
import io.improbable.keanu.algorithms.mcmc.nuts.LeapfrogState;
import io.improbable.keanu.algorithms.mcmc.nuts.NUTS;
import io.improbable.keanu.algorithms.mcmc.nuts.Potential;
import io.improbable.keanu.algorithms.mcmc.nuts.Proposal;
import io.improbable.keanu.algorithms.mcmc.nuts.Tree;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Sampling algorithm that advances a {@link Proposal} one step at a time by growing a
 * {@link Tree} from the current proposal and taking the tree's resulting proposal as the
 * next state.
 *
 * <p>Each call to {@link #step()} optionally adapts the step size and the potential, and
 * records divergent trees that occur once the adaption window ({@code adaptCount} steps)
 * has passed.
 *
 * <p>The sampler keeps mutable state ({@code proposal}, {@code stepCount}) without any
 * synchronization, so an instance should be driven from a single thread.
 */
class NUTSSampler
implements SamplingAlgorithm {
    private static final Logger log = LoggerFactory.getLogger(NUTSSampler.class);
    private final List<? extends Variable> sampleFromVariables;
    private final ProbabilisticModelWithGradient logProbGradientCalculator;
    private final LeapfrogIntegrator leapfrogIntegrator;
    private final boolean adaptPotentialEnabled;
    private final Potential potential;
    private final boolean adaptStepSizeEnabled;
    private final AdaptiveStepSize stepSize;
    private final long adaptCount;
    // Number of completed calls to step(); also the zero-based index of the step in progress.
    private long stepCount;
    private final int maxTreeHeight;
    // Current state of the chain; replaced by the tree's proposal on every step.
    private Proposal proposal;
    private final KeanuRandom random;
    private final double maxEnergyChange;
    private final Statistics statistics;
    private final boolean saveStatistics;

    /**
     * @param sampleFromVariables       variables handed to each {@link Tree} built by this sampler
     * @param logProbGradientCalculator model handed to each {@link Tree} built by this sampler
     * @param adaptPotentialEnabled     whether the potential is updated during the first {@code adaptCount} steps
     * @param potential                 source of random momentum; also used to build the leapfrog integrator
     * @param adaptStepSizeEnabled      whether {@code stepSize} is asked to adapt after each step
     * @param stepSize                  provides the step size used when growing the tree
     * @param adaptCount                number of initial steps regarded as the adaption window
     * @param maxEnergyChange           passed through to each {@link Tree}
     * @param maxTreeHeight             tree growth stops once the tree reaches this height
     * @param initialProposal           starting state of the chain
     * @param random                    source of randomness for momentum and build direction
     * @param statistics                sink for step size, tree and divergence statistics
     * @param saveStatistics            whether step size and tree statistics are saved on each step
     */
    public NUTSSampler(List<? extends Variable> sampleFromVariables, ProbabilisticModelWithGradient logProbGradientCalculator, boolean adaptPotentialEnabled, Potential potential, boolean adaptStepSizeEnabled, AdaptiveStepSize stepSize, long adaptCount, double maxEnergyChange, int maxTreeHeight, Proposal initialProposal, KeanuRandom random, Statistics statistics, boolean saveStatistics) {
        this.sampleFromVariables = sampleFromVariables;
        this.logProbGradientCalculator = logProbGradientCalculator;
        this.leapfrogIntegrator = new LeapfrogIntegrator(potential);
        this.adaptPotentialEnabled = adaptPotentialEnabled;
        this.potential = potential;
        this.adaptStepSizeEnabled = adaptStepSizeEnabled;
        this.stepSize = stepSize;
        this.adaptCount = adaptCount;
        this.stepCount = 0L;
        this.maxEnergyChange = maxEnergyChange;
        this.maxTreeHeight = maxTreeHeight;
        this.proposal = initialProposal;
        this.random = random;
        this.statistics = statistics;
        this.saveStatistics = saveStatistics;
    }

    /**
     * Performs one step and appends the resulting sample and its log probability to the
     * supplied collections.
     *
     * @param samples                   per-variable sample lists; a list is created for any variable not yet present
     * @param logOfMasterPForEachSample receives the log probability of the new proposal
     */
    @Override
    public void sample(Map<VariableReference, List<?>> samples, List<Double> logOfMasterPForEachSample) {
        this.step();
        NUTSSampler.addSampleFromCache(samples, this.proposal.getSample());
        logOfMasterPForEachSample.add(this.proposal.getLogProb());
    }

    /**
     * Performs one step and returns the resulting proposal as a {@link NetworkSample}.
     *
     * @return the sample and log probability of the proposal reached by this step
     */
    @Override
    public NetworkSample sample() {
        this.step();
        return new NetworkSample(this.proposal.getSample(), this.proposal.getLogProb());
    }

    /**
     * Advances the chain by one step: draws a random momentum, grows a tree from the current
     * proposal until the tree says to stop or reaches {@code maxTreeHeight}, adopts the
     * tree's proposal, then handles statistics, adaption and divergence reporting.
     */
    @Override
    public void step() {
        Map<VariableReference, DoubleTensor> initialMomentum = this.potential.randomMomentum(this.random);
        LeapfrogState startState = new LeapfrogState(this.proposal.getPosition(), initialMomentum, this.proposal.getGradient(), this.proposal.getLogProb(), this.potential);
        Tree tree = new Tree(startState, this.proposal, this.maxEnergyChange, this.logProbGradientCalculator, this.leapfrogIntegrator, this.sampleFromVariables, this.random);

        while (tree.shouldContinue() && tree.getTreeHeight() < this.maxTreeHeight) {
            // Pick forwards (+1) or backwards (-1) with equal probability for this growth.
            int buildDirection = this.random.nextBoolean() ? 1 : -1;
            tree.grow(buildDirection, this.stepSize.getStepSize());
        }
        this.proposal = tree.getProposal();

        // Statistics are saved before adaption so they describe the step size used for this tree.
        if (this.saveStatistics) {
            this.stepSize.save(this.statistics);
            tree.save(this.statistics);
        }
        if (this.adaptStepSizeEnabled) {
            this.stepSize.adaptStepSize(tree);
        }

        boolean withinAdaptionWindow = this.stepCount < this.adaptCount;
        if (withinAdaptionWindow && this.adaptPotentialEnabled) {
            this.potential.update(this.proposal.getPosition());
        }
        // Steps 0..adaptCount-1 form the adaption window, so the step with index adaptCount is
        // already "after adaption ended" and its divergences must be reported too.
        if (!withinAdaptionWindow && tree.isDiverged()) {
            this.statistics.store(NUTS.Metrics.DIVERGENT_SAMPLE, Double.valueOf(this.stepCount));
            log.warn("Divergent NUTS sample after adaption ended. Increase the number or samples to adapt for or the max energy change.");
        }
        ++this.stepCount;
    }

    /**
     * Appends every value of {@code cachedSample} to the list kept for its variable in
     * {@code samples}.
     */
    private static void addSampleFromCache(Map<VariableReference, List<?>> samples, Map<VariableReference, ?> cachedSample) {
        for (Map.Entry<VariableReference, ?> sampleEntry : cachedSample.entrySet()) {
            NUTSSampler.addSampleForVariable(sampleEntry.getKey(), sampleEntry.getValue(), samples);
        }
    }

    /**
     * Appends {@code value} to the list stored under {@code id}, creating the list on first use.
     */
    private static <T> void addSampleForVariable(VariableReference id, T value, Map<VariableReference, List<?>> samples) {
        // The map's value type is List<?>; by convention each list only ever holds the sample
        // type of its own variable, so narrowing to List<T> is safe here.
        @SuppressWarnings("unchecked")
        List<T> samplesForVariable = (List<T>) samples.computeIfAbsent(id, v -> new ArrayList<T>());
        samplesForVariable.add(value);
    }
}

