package io.improbable.keanu.algorithms.mcmc;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.algorithms.ProbabilisticModel;
import io.improbable.keanu.algorithms.Variable;
import io.improbable.keanu.algorithms.VariableReference;
import io.improbable.keanu.algorithms.mcmc.proposal.MHStepVariableSelector;
import io.improbable.keanu.algorithms.mcmc.proposal.PriorProposalDistribution;
import io.improbable.keanu.algorithms.mcmc.proposal.ProposalDistribution;
import io.improbable.keanu.network.NetworkState;
import io.improbable.keanu.network.SimpleNetworkState;
import io.improbable.keanu.vertices.ProbabilityCalculator;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.NonNull;

/* loaded from: input_file:io/improbable/keanu/algorithms/mcmc/SimulatedAnnealing.class */
/**
 * Searches for an approximate maximum a posteriori (MAP) state of a {@link ProbabilisticModel}
 * using simulated annealing.
 *
 * <p>The search runs a sequence of {@link MetropolisHastingsStep}s. Each step is given a
 * temperature taken from an {@link AnnealingSchedule}. The latent-variable values of the
 * highest-log-probability state seen during the run are returned.
 */
public class SimulatedAnnealing {
    private static final MHStepVariableSelector DEFAULT_VARIABLE_SELECTOR = MHStepVariableSelector.SINGLE_VARIABLE_SELECTOR;

    /** Initial temperature of the schedule used by {@link #getMaxAPosteriori(ProbabilisticModel, int)}. */
    private static final double DEFAULT_START_TEMPERATURE = 2.0d;

    /** Final temperature of the schedule used by {@link #getMaxAPosteriori(ProbabilisticModel, int)}. */
    private static final double DEFAULT_END_TEMPERATURE = 0.01d;

    private final KeanuRandom random;

    @NonNull
    private final ProposalDistribution proposalDistribution;

    /** Chooses which latent variables each Metropolis-Hastings step perturbs. */
    private final MHStepVariableSelector variableSelector;

    @NonNull
    private final ProposalRejectionStrategy rejectionStrategy;

    /** Maps a step index to the temperature handed to that Metropolis-Hastings step. */
    @FunctionalInterface
    public interface AnnealingSchedule {
        double getTemperature(int i);
    }

    /** Fluent builder for {@link SimulatedAnnealing}; every option starts at a usable default. */
    public static class SimulatedAnnealingBuilder {
        private KeanuRandom random = KeanuRandom.getDefaultRandom();
        private ProposalDistribution proposalDistribution = new PriorProposalDistribution();
        private MHStepVariableSelector variableSelector = SimulatedAnnealing.DEFAULT_VARIABLE_SELECTOR;
        private ProposalRejectionStrategy rejectionStrategy = new RollBackToCachedValuesOnRejection();

        public SimulatedAnnealingBuilder random(KeanuRandom random) {
            this.random = random;
            return this;
        }

        public SimulatedAnnealingBuilder proposalDistribution(ProposalDistribution proposalDistribution) {
            this.proposalDistribution = proposalDistribution;
            return this;
        }

        public SimulatedAnnealingBuilder variableSelector(MHStepVariableSelector variableSelector) {
            this.variableSelector = variableSelector;
            return this;
        }

        public SimulatedAnnealingBuilder rejectionStrategy(ProposalRejectionStrategy rejectionStrategy) {
            this.rejectionStrategy = rejectionStrategy;
            return this;
        }

        /**
         * Creates the optimizer from the current builder settings.
         *
         * @throws NullPointerException if the proposal distribution or rejection strategy is null
         */
        public SimulatedAnnealing build() {
            return new SimulatedAnnealing(this.random, this.proposalDistribution, this.variableSelector, this.rejectionStrategy);
        }

        @Override
        public String toString() {
            return "SimulatedAnnealing.SimulatedAnnealingBuilder(random=" + this.random + ", proposalDistribution=" + this.proposalDistribution + ", variableSelector=" + this.variableSelector + ", rejectionStrategy=" + this.rejectionStrategy + ")";
        }
    }

    public static SimulatedAnnealingBuilder builder() {
        return new SimulatedAnnealingBuilder();
    }

    /**
     * Runs the search with an exponential schedule that cools from {@value #DEFAULT_START_TEMPERATURE}
     * to {@value #DEFAULT_END_TEMPERATURE} over {@code sampleCount} steps.
     *
     * @param probabilisticModel model whose latent variables are optimized in place
     * @param sampleCount        number of Metropolis-Hastings steps to run
     * @return the latent-variable values of the best state found
     */
    public NetworkState getMaxAPosteriori(ProbabilisticModel probabilisticModel, int sampleCount) {
        AnnealingSchedule schedule = exponentialSchedule(sampleCount, DEFAULT_START_TEMPERATURE, DEFAULT_END_TEMPERATURE);
        return getMaxAPosteriori(probabilisticModel, sampleCount, schedule);
    }

    /**
     * Runs {@code sampleCount} Metropolis-Hastings steps. Each step perturbs the variables chosen by
     * the configured {@link MHStepVariableSelector}, at the temperature given by
     * {@code annealingSchedule}.
     *
     * <p>The model's variables are left at whatever state the final step produced. The best state
     * seen is returned separately.
     *
     * @param probabilisticModel model to optimize
     * @param sampleCount        number of steps to run; values {@code <= 0} run no steps
     * @param annealingSchedule  temperature for each step index
     * @return the latent-variable values of the highest-log-probability state seen, including the
     *     starting state
     * @throws IllegalArgumentException if the model's starting state has impossible log probability
     */
    public NetworkState getMaxAPosteriori(ProbabilisticModel probabilisticModel, int sampleCount, AnnealingSchedule annealingSchedule) {
        // One evaluation serves as both the validity check and the chain's starting log probability.
        double currentLogProb = probabilisticModel.logProb();
        if (ProbabilityCalculator.isImpossibleLogProb(currentLogProb)) {
            throw new IllegalArgumentException("Cannot start optimizer on zero probability network");
        }

        Map<VariableReference, Object> maxSamples = new HashMap<>();
        List<? extends Variable> latentVariables = probabilisticModel.getLatentVariables();
        double maxLogProb = currentLogProb;
        setSamplesAsMax(maxSamples, latentVariables);

        if (latentVariables.isEmpty()) {
            // Nothing to perturb, so the starting state is the answer. This also avoids a
            // modulo-by-zero in round-robin selectors.
            return new SimpleNetworkState(maxSamples);
        }

        MetropolisHastingsStep mhStep = new MetropolisHastingsStep(probabilisticModel, this.proposalDistribution, this.rejectionStrategy, this.random);
        for (int sampleNumber = 0; sampleNumber < sampleCount; sampleNumber++) {
            double temperature = annealingSchedule.getTemperature(sampleNumber);
            currentLogProb = mhStep.step(
                this.variableSelector.select(latentVariables, sampleNumber),
                currentLogProb,
                temperature
            ).getLogProbabilityAfterStep();

            if (currentLogProb > maxLogProb) {
                maxLogProb = currentLogProb;
                // Snapshot the values now; later steps mutate the variables in place.
                setSamplesAsMax(maxSamples, latentVariables);
            }
        }
        return new SimpleNetworkState(maxSamples);
    }

    /** Records every variable's current value into {@code samples}, overwriting older entries. */
    private static void setSamplesAsMax(Map<VariableReference, Object> samples, List<? extends Variable> variables) {
        variables.forEach(variable -> setSampleForVariable((Variable<?, ?>) variable, samples));
    }

    /**
     * Stores {@code variable}'s current value in {@code map} under its reference.
     *
     * <p>The caller must supply a map whose real value type can hold the variable's value type.
     * The wildcard parameter type cannot enforce this.
     */
    @SuppressWarnings("unchecked")
    public static <T> void setSampleForVariable(Variable<T, ?> variable, Map<VariableReference, ?> map) {
        // A Map<VariableReference, ?> rejects put(); narrow it explicitly. This is unchecked and
        // relies on the caller contract above.
        ((Map<VariableReference, T>) map).put(variable.getReference(), variable.getValue());
    }

    /**
     * Builds an exponentially cooling schedule.
     *
     * <p>The temperature at step {@code n} is
     * {@code startTemperature * (endTemperature / startTemperature)^(n / iterations)}. It therefore
     * starts at {@code startTemperature} and reaches {@code endTemperature} at step
     * {@code iterations}.
     *
     * <p>Both temperatures should be positive and {@code iterations} should be positive. Otherwise
     * the logarithm or the division yields NaN or infinite temperatures.
     *
     * @param iterations       step index at which {@code endTemperature} is reached
     * @param startTemperature temperature at step 0
     * @param endTemperature   temperature at step {@code iterations}
     */
    public static AnnealingSchedule exponentialSchedule(int iterations, double startTemperature, double endTemperature) {
        final double logDecayPerStep = Math.log(endTemperature / startTemperature) / iterations;
        return step -> startTemperature * Math.exp(logDecayPerStep * step);
    }

    private SimulatedAnnealing(KeanuRandom random, @NonNull ProposalDistribution proposalDistribution, MHStepVariableSelector variableSelector, @NonNull ProposalRejectionStrategy rejectionStrategy) {
        if (proposalDistribution == null) {
            throw new NullPointerException("proposalDistribution");
        }
        if (rejectionStrategy == null) {
            throw new NullPointerException("rejectionStrategy");
        }
        this.random = random;
        this.proposalDistribution = proposalDistribution;
        // The selector used to be ignored, so null was harmless. Keep that working by treating
        // null as "use the default".
        this.variableSelector = variableSelector != null ? variableSelector : DEFAULT_VARIABLE_SELECTOR;
        this.rejectionStrategy = rejectionStrategy;
    }

    public KeanuRandom getRandom() {
        return this.random;
    }

    @NonNull
    public ProposalDistribution getProposalDistribution() {
        return this.proposalDistribution;
    }

    public MHStepVariableSelector getVariableSelector() {
        return this.variableSelector;
    }

    @NonNull
    public ProposalRejectionStrategy getRejectionStrategy() {
        return this.rejectionStrategy;
    }
}
