/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.algorithms.mcmc;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.algorithms.ProbabilisticModel;
import io.improbable.keanu.algorithms.Variable;
import io.improbable.keanu.algorithms.VariableReference;
import io.improbable.keanu.algorithms.mcmc.MetropolisHastingsStep;
import io.improbable.keanu.algorithms.mcmc.ProposalRejectionStrategy;
import io.improbable.keanu.algorithms.mcmc.RollBackToCachedValuesOnRejection;
import io.improbable.keanu.algorithms.mcmc.proposal.MHStepVariableSelector;
import io.improbable.keanu.algorithms.mcmc.proposal.PriorProposalDistribution;
import io.improbable.keanu.algorithms.mcmc.proposal.ProposalDistribution;
import io.improbable.keanu.network.NetworkState;
import io.improbable.keanu.network.SimpleNetworkState;
import io.improbable.keanu.vertices.ProbabilityCalculator;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.NonNull;

/**
 * Searches for a maximum a posteriori (MAP) assignment of a model's latent variables by running
 * Metropolis-Hastings steps under a temperature supplied by an {@link AnnealingSchedule}, and
 * remembering the highest-log-probability state visited along the way.
 *
 * <p>Instances are created through {@link #builder()}; all configuration fields are final.
 */
public class SimulatedAnnealing {
    private static final MHStepVariableSelector DEFAULT_VARIABLE_SELECTOR = MHStepVariableSelector.SINGLE_VARIABLE_SELECTOR;

    private final KeanuRandom random;
    @NonNull
    private final ProposalDistribution proposalDistribution;
    // NOTE(review): this selector is stored and exposed via getVariableSelector(), but
    // getMaxAPosteriori does not consult it -- latent variables are visited round-robin instead.
    // Confirm whether that is intended.
    private final MHStepVariableSelector variableSelector;
    @NonNull
    private final ProposalRejectionStrategy rejectionStrategy;

    private SimulatedAnnealing(KeanuRandom random, @NonNull ProposalDistribution proposalDistribution, MHStepVariableSelector variableSelector, @NonNull ProposalRejectionStrategy rejectionStrategy) {
        if (proposalDistribution == null) {
            throw new NullPointerException("proposalDistribution");
        }
        if (rejectionStrategy == null) {
            throw new NullPointerException("rejectionStrategy");
        }
        this.random = random;
        this.proposalDistribution = proposalDistribution;
        this.variableSelector = variableSelector;
        this.rejectionStrategy = rejectionStrategy;
    }

    /**
     * Returns a builder pre-populated with the default random source, a prior proposal
     * distribution, the single-variable selector and roll-back-on-rejection strategy.
     *
     * @return a new builder
     */
    public static SimulatedAnnealingBuilder builder() {
        return new SimulatedAnnealingBuilder();
    }

    /**
     * Runs simulated annealing using an exponential schedule that cools from 2.0 to 0.01 over
     * {@code sampleCount} iterations.
     *
     * @param model       the model whose latent variables are optimized
     * @param sampleCount number of Metropolis-Hastings steps to take
     * @return the latent-variable values of the highest-log-probability state encountered
     * @throws IllegalArgumentException if the model starts in a zero-probability state
     */
    public NetworkState getMaxAPosteriori(ProbabilisticModel model, int sampleCount) {
        AnnealingSchedule schedule = exponentialSchedule(sampleCount, 2.0, 0.01);
        return getMaxAPosteriori(model, sampleCount, schedule);
    }

    /**
     * Runs simulated annealing with a caller-supplied temperature schedule.
     *
     * <p>Each iteration proposes a change to exactly one latent variable, cycling through the
     * latent variables in list order, at the temperature returned by the schedule for that
     * iteration. Whenever the resulting log probability strictly exceeds the best seen so far,
     * the current latent values are recorded.
     *
     * @param model             the model whose latent variables are optimized
     * @param sampleCount       number of Metropolis-Hastings steps to take; non-positive means none
     * @param annealingSchedule maps an iteration index to a temperature
     * @return the latent-variable values of the highest-log-probability state encountered
     *         (the starting state if no step improved on it, or if there are no latent variables)
     * @throws IllegalArgumentException if the model starts in a zero-probability state
     */
    public NetworkState getMaxAPosteriori(ProbabilisticModel model, int sampleCount, AnnealingSchedule annealingSchedule) {
        double currentLogProb = model.logProb();
        if (ProbabilityCalculator.isImpossibleLogProb(currentLogProb)) {
            throw new IllegalArgumentException("Cannot start optimizer on zero probability network");
        }

        List<? extends Variable> latentVariables = model.getLatentVariables();
        Map<VariableReference, Object> maxSamplesByVariable = new HashMap<>();
        double maxLogP = currentLogProb;
        setSamplesAsMax(maxSamplesByVariable, latentVariables);

        // Nothing to perturb; also avoids a modulo-by-zero when choosing the variable below.
        if (latentVariables.isEmpty()) {
            return new SimpleNetworkState(maxSamplesByVariable);
        }

        MetropolisHastingsStep mhStep = new MetropolisHastingsStep(model, proposalDistribution, rejectionStrategy, random);
        for (int sampleNum = 0; sampleNum < sampleCount; sampleNum++) {
            Variable chosenVariable = latentVariables.get(sampleNum % latentVariables.size());
            double temperature = annealingSchedule.getTemperature(sampleNum);
            currentLogProb = mhStep
                .step(Collections.singleton(chosenVariable), currentLogProb, temperature)
                .getLogProbabilityAfterStep();

            if (currentLogProb > maxLogP) {
                maxLogP = currentLogProb;
                setSamplesAsMax(maxSamplesByVariable, latentVariables);
            }
        }
        return new SimpleNetworkState(maxSamplesByVariable);
    }

    /**
     * Copies each variable's current value into {@code samples}, keyed by the variable's reference.
     */
    private static void setSamplesAsMax(Map<VariableReference, Object> samples, List<? extends Variable> fromVariables) {
        for (Variable variable : fromVariables) {
            samples.put(variable.getReference(), variable.getValue());
        }
    }

    /**
     * Creates a schedule {@code T(n) = startT * exp(-k * n)} with {@code k} chosen so that
     * {@code T(iterations) == endT}.
     *
     * <p>When {@code iterations <= 0} there is no interval to decay over, so the schedule is the
     * constant {@code startT} (previously this divided by zero and produced NaN at {@code n = 0}).
     *
     * @param iterations number of iterations over which to decay from {@code startT} to {@code endT}
     * @param startT     temperature at iteration 0; must be positive
     * @param endT       temperature reached at iteration {@code iterations}; must be positive
     * @return the exponential annealing schedule
     * @throws IllegalArgumentException if {@code startT} or {@code endT} is not positive (or is NaN)
     */
    public static AnnealingSchedule exponentialSchedule(int iterations, double startT, double endT) {
        if (!(startT > 0.0) || !(endT > 0.0)) {
            throw new IllegalArgumentException(
                "Temperatures must be positive but were startT=" + startT + ", endT=" + endT);
        }
        double minusK = iterations > 0 ? Math.log(endT / startT) / (double) iterations : 0.0;
        return n -> startT * Math.exp(minusK * (double) n);
    }

    /** @return the random source used for proposals and acceptance decisions */
    public KeanuRandom getRandom() {
        return this.random;
    }

    /** @return the proposal distribution used for each Metropolis-Hastings step */
    @NonNull
    public ProposalDistribution getProposalDistribution() {
        return this.proposalDistribution;
    }

    /** @return the configured variable selector (not currently used by {@code getMaxAPosteriori}) */
    public MHStepVariableSelector getVariableSelector() {
        return this.variableSelector;
    }

    /** @return the strategy applied when a proposal is rejected */
    @NonNull
    public ProposalRejectionStrategy getRejectionStrategy() {
        return this.rejectionStrategy;
    }

    /**
     * Fluent builder for {@link SimulatedAnnealing}. Every setting has a default, so
     * {@code SimulatedAnnealing.builder().build()} is valid.
     */
    public static class SimulatedAnnealingBuilder {
        private KeanuRandom random = KeanuRandom.getDefaultRandom();
        private ProposalDistribution proposalDistribution = new PriorProposalDistribution();
        private MHStepVariableSelector variableSelector = DEFAULT_VARIABLE_SELECTOR;
        private ProposalRejectionStrategy rejectionStrategy = new RollBackToCachedValuesOnRejection();

        /** Sets the random source. */
        public SimulatedAnnealingBuilder random(KeanuRandom random) {
            this.random = random;
            return this;
        }

        /** Sets the proposal distribution; must be non-null by the time {@link #build()} is called. */
        public SimulatedAnnealingBuilder proposalDistribution(ProposalDistribution proposalDistribution) {
            this.proposalDistribution = proposalDistribution;
            return this;
        }

        /** Sets the variable selector. */
        public SimulatedAnnealingBuilder variableSelector(MHStepVariableSelector variableSelector) {
            this.variableSelector = variableSelector;
            return this;
        }

        /** Sets the rejection strategy; must be non-null by the time {@link #build()} is called. */
        public SimulatedAnnealingBuilder rejectionStrategy(ProposalRejectionStrategy rejectionStrategy) {
            this.rejectionStrategy = rejectionStrategy;
            return this;
        }

        /**
         * @return a new optimizer with the configured settings
         * @throws NullPointerException if the proposal distribution or rejection strategy is null
         */
        public SimulatedAnnealing build() {
            return new SimulatedAnnealing(this.random, this.proposalDistribution, this.variableSelector, this.rejectionStrategy);
        }

        @Override
        public String toString() {
            return "SimulatedAnnealing.SimulatedAnnealingBuilder(random=" + this.random + ", proposalDistribution=" + this.proposalDistribution + ", variableSelector=" + this.variableSelector + ", rejectionStrategy=" + this.rejectionStrategy + ")";
        }
    }

    /** Maps an iteration index to the temperature used for that Metropolis-Hastings step. */
    @FunctionalInterface
    public interface AnnealingSchedule {
        /**
         * @param iteration zero-based iteration index
         * @return the temperature for that iteration
         */
        double getTemperature(int iteration);
    }
}

