package io.improbable.keanu.algorithms.mcmc;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.algorithms.ProbabilisticModel;
import io.improbable.keanu.algorithms.Variable;
import io.improbable.keanu.algorithms.mcmc.proposal.Proposal;
import io.improbable.keanu.algorithms.mcmc.proposal.ProposalDistribution;
import io.improbable.keanu.vertices.ProbabilityCalculator;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:io/improbable/keanu/algorithms/mcmc/MetropolisHastingsStep.class */
/**
 * Performs a single Metropolis-Hastings step.
 *
 * <p>Each step draws a proposal from the {@link ProposalDistribution}, asks the
 * {@link ProbabilisticModel} for the log probability at the proposed point, and accepts or rejects
 * the proposal using the (optionally temperature-scaled) Metropolis-Hastings acceptance ratio.
 *
 * <p>On rejection, both the proposal distribution and the {@link ProposalRejectionStrategy} are
 * notified. The rejection strategy also receives the rejected proposal.
 */
public class MetropolisHastingsStep {
    private static final Logger log = LoggerFactory.getLogger(MetropolisHastingsStep.class);

    /** Temperature used by {@link #step(Set, double)}; 1.0 leaves the log-likelihood ratio unscaled. */
    private static final double DEFAULT_TEMPERATURE = 1.0d;

    private final ProbabilisticModel model;
    private final ProposalDistribution proposalDistribution;
    private final ProposalRejectionStrategy rejectionStrategy;
    private final KeanuRandom random;

    /** Immutable outcome of a single Metropolis-Hastings step. */
    public static final class StepResult {
        private final boolean accepted;
        private final double logProbabilityAfterStep;

        /**
         * @param accepted whether the proposal was accepted
         * @param logProbabilityAfterStep log probability of the state after the step: the proposed
         *     state's log probability if accepted, otherwise the pre-step log probability
         */
        public StepResult(boolean accepted, double logProbabilityAfterStep) {
            this.accepted = accepted;
            this.logProbabilityAfterStep = logProbabilityAfterStep;
        }

        public boolean isAccepted() {
            return this.accepted;
        }

        public double getLogProbabilityAfterStep() {
            return this.logProbabilityAfterStep;
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof StepResult)) {
                return false;
            }
            StepResult other = (StepResult) obj;
            // Double.compare gives a reflexive equality for NaN and distinguishes 0.0 from -0.0,
            // which keeps equals consistent with the doubleToLongBits-based hashCode below.
            return isAccepted() == other.isAccepted()
                && Double.compare(getLogProbabilityAfterStep(), other.getLogProbabilityAfterStep()) == 0;
        }

        @Override
        public int hashCode() {
            int result = (1 * 59) + (isAccepted() ? 79 : 97);
            long logProbBits = Double.doubleToLongBits(getLogProbabilityAfterStep());
            return (result * 59) + ((int) ((logProbBits >>> 32) ^ logProbBits));
        }

        @Override
        public String toString() {
            return "MetropolisHastingsStep.StepResult(accepted=" + isAccepted() + ", logProbabilityAfterStep=" + getLogProbabilityAfterStep() + ")";
        }
    }

    /**
     * @param probabilisticModel model used to evaluate the log probability of proposed states
     * @param proposalDistribution source of proposals and of the proposal log densities used in the
     *     Hastings correction
     * @param proposalRejectionStrategy notified when a proposal is created and when it is rejected
     * @param keanuRandom random source for both proposal generation and the accept/reject draw
     */
    public MetropolisHastingsStep(ProbabilisticModel probabilisticModel, ProposalDistribution proposalDistribution, ProposalRejectionStrategy proposalRejectionStrategy, KeanuRandom keanuRandom) {
        this.model = probabilisticModel;
        this.proposalDistribution = proposalDistribution;
        this.rejectionStrategy = proposalRejectionStrategy;
        this.random = keanuRandom;
    }

    /**
     * Performs one step at {@link #DEFAULT_TEMPERATURE} (no annealing).
     *
     * @param chosenVariables variables handed to the proposal distribution to build the proposal
     * @param logProbabilityBeforeStep log probability of the current (pre-step) state
     * @return the step outcome; see {@link #step(Set, double, double)}
     */
    public StepResult step(Set<Variable> chosenVariables, double logProbabilityBeforeStep) {
        return step(chosenVariables, logProbabilityBeforeStep, DEFAULT_TEMPERATURE);
    }

    /**
     * Performs one Metropolis-Hastings step, scaling the log-likelihood ratio by {@code 1 / temperature}.
     *
     * <p>A proposal whose log probability is impossible (per
     * {@link ProbabilityCalculator#isImpossibleLogProb}) is rejected without drawing a random number.
     * Otherwise it is accepted when
     * {@code exp((logPAfter - logPBefore) / temperature + logQ(from|to) - logQ(to|from)) >= u},
     * where {@code u} is drawn from {@link KeanuRandom#nextDouble()}.
     *
     * @param chosenVariables variables handed to the proposal distribution to build the proposal
     * @param logProbabilityBeforeStep log probability of the current (pre-step) state; returned
     *     unchanged in the result if the proposal is rejected
     * @param temperature annealing temperature; larger values make it more likely that a proposal
     *     with a lower log probability is accepted. NOTE(review): assumed positive and not validated
     *     here. A value of 0 yields an infinite anneal factor (and NaN for an unchanged log
     *     probability, which then rejects).
     * @return whether the proposal was accepted, together with the resulting log probability
     */
    public StepResult step(Set<Variable> chosenVariables, double logProbabilityBeforeStep, double temperature) {
        Proposal proposal = this.proposalDistribution.getProposal(chosenVariables, this.random);
        this.rejectionStrategy.onProposalCreated(proposal);

        double logProbabilityAfterStep = this.model.logProbAfter(proposal.getProposalTo(), logProbabilityBeforeStep);

        if (!ProbabilityCalculator.isImpossibleLogProb(logProbabilityAfterStep)) {
            double logLikelihoodRatio = logProbabilityAfterStep - logProbabilityBeforeStep;
            // The anneal factor is always 1 / temperature; it must not be tied to DEFAULT_TEMPERATURE.
            double annealFactor = 1.0 / temperature;
            // Hastings correction: log q(from | to) - log q(to | from), for asymmetric proposals.
            double logProbFromGivenTo = this.proposalDistribution.logProbAtFromGivenTo(proposal);
            double logProbToGivenFrom = this.proposalDistribution.logProbAtToGivenFrom(proposal);
            double hastingsCorrection = logProbFromGivenTo - logProbToGivenFrom;

            double acceptanceRatio = Math.exp(annealFactor * logLikelihoodRatio + hastingsCorrection);
            if (acceptanceRatio >= this.random.nextDouble()) {
                return new StepResult(true, logProbabilityAfterStep);
            }
        }

        this.proposalDistribution.onProposalRejected();
        this.rejectionStrategy.onProposalRejected(proposal);
        return new StepResult(false, logProbabilityBeforeStep);
    }
}
