package io.improbable.keanu.algorithms.variational.optimizer.nongradient;

import io.improbable.keanu.algorithms.ProbabilisticModel;
import io.improbable.keanu.algorithms.Variable;
import io.improbable.keanu.algorithms.VariableReference;
import io.improbable.keanu.algorithms.variational.optimizer.FitnessFunction;
import io.improbable.keanu.algorithms.variational.optimizer.OptimizedResult;
import io.improbable.keanu.algorithms.variational.optimizer.Optimizer;
import io.improbable.keanu.algorithms.variational.optimizer.ProbabilityFitness;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.util.status.StatusBar;
import io.improbable.keanu.vertices.ProbabilityCalculator;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.function.BiConsumer;

/**
 * {@link Optimizer} that searches for optimal latent values of a {@link ProbabilisticModel} using a
 * {@link NonGradientOptimizationAlgorithm}. The builder defaults the algorithm to
 * {@code BOBYQA.builder().build()}.
 *
 * <p>Instances are created through {@link #builder()}.
 */
public class NonGradientOptimizer implements Optimizer {
    private final ProbabilisticModel probabilisticModel;
    private final NonGradientOptimizationAlgorithm nonGradientOptimizationAlgorithm;
    /**
     * When true, optimization refuses to start if the fitness at the current latent values is an
     * impossible log probability (see {@code ProbabilityCalculator.isImpossibleLogProb}).
     */
    private final boolean checkInitialFitnessConditions;
    /**
     * Handlers invoked by {@link #handleFitnessCalculation}. That method is passed as the callback to
     * {@code ProbabilityFitness#getFitnessFunction}.
     */
    private final List<BiConsumer<Map<VariableReference, DoubleTensor>, Double>> onFitnessCalculations;

    /**
     * Builder for {@link NonGradientOptimizer}. A probabilistic model is mandatory. The algorithm
     * defaults to BOBYQA. The initial fitness check is off unless enabled.
     */
    public static class NonGradientOptimizerBuilder {
        private ProbabilisticModel probabilisticModel;
        private NonGradientOptimizationAlgorithm nonGradientOptimizationAlgorithm = BOBYQA.builder().build();
        private boolean checkInitialFitnessConditions;

        /**
         * Sets the model whose latent variables will be optimized.
         *
         * @param probabilisticModel the model to optimize (required)
         * @return this builder
         */
        public NonGradientOptimizerBuilder probabilisticModel(ProbabilisticModel probabilisticModel) {
            this.probabilisticModel = probabilisticModel;
            return this;
        }

        /**
         * Replaces the default BOBYQA algorithm.
         *
         * @param nonGradientOptimizationAlgorithm the algorithm to run (must not be null)
         * @return this builder
         */
        public NonGradientOptimizerBuilder algorithm(NonGradientOptimizationAlgorithm nonGradientOptimizationAlgorithm) {
            this.nonGradientOptimizationAlgorithm = nonGradientOptimizationAlgorithm;
            return this;
        }

        /**
         * Enables or disables the check that rejects a starting point with impossible log probability.
         *
         * @param z true to enable the check
         * @return this builder
         */
        public NonGradientOptimizerBuilder checkInitialFitnessConditions(boolean z) {
            this.checkInitialFitnessConditions = z;
            return this;
        }

        /**
         * Builds the optimizer.
         *
         * @return a new optimizer configured from this builder
         * @throws IllegalStateException if no probabilistic model or no algorithm has been set
         */
        public NonGradientOptimizer build() {
            if (this.probabilisticModel == null) {
                throw new IllegalStateException("Cannot build optimizer without specifying network to optimize.");
            }
            // Fail fast here rather than with a NullPointerException on the first optimize call.
            if (this.nonGradientOptimizationAlgorithm == null) {
                throw new IllegalStateException("Cannot build optimizer without specifying a non-gradient optimization algorithm.");
            }
            return new NonGradientOptimizer(this.probabilisticModel, this.nonGradientOptimizationAlgorithm, this.checkInitialFitnessConditions);
        }

        @Override
        public String toString() {
            return "NonGradientOptimizer.NonGradientOptimizerBuilder(probabilisticModel=" + this.probabilisticModel + ", nonGradientOptimizationAlgorithm=" + this.nonGradientOptimizationAlgorithm + ", checkInitialFitnessConditions=" + this.checkInitialFitnessConditions + ")";
        }
    }

    /**
     * @return a new builder for configuring a {@link NonGradientOptimizer}
     */
    public static NonGradientOptimizerBuilder builder() {
        return new NonGradientOptimizerBuilder();
    }

    /**
     * Registers a handler that receives the evaluated point and its fitness value.
     *
     * @param biConsumer the handler to add
     */
    @Override
    public void addFitnessCalculationHandler(BiConsumer<Map<VariableReference, DoubleTensor>, Double> biConsumer) {
        this.onFitnessCalculations.add(biConsumer);
    }

    /**
     * Removes a previously registered fitness-calculation handler (first matching occurrence).
     *
     * @param biConsumer the handler to remove
     */
    @Override
    public void removeFitnessCalculationHandler(BiConsumer<Map<VariableReference, DoubleTensor>, Double> biConsumer) {
        this.onFitnessCalculations.remove(biConsumer);
    }

    /** Forwards one fitness evaluation to every registered handler, in registration order. */
    private void handleFitnessCalculation(Map<VariableReference, DoubleTensor> map, Double d) {
        for (BiConsumer<Map<VariableReference, DoubleTensor>, Double> handler : this.onFitnessCalculations) {
            handler.accept(map, d);
        }
    }

    /** Builds the fitness function for the given objective (MAP or MLE) and optimizes it. */
    private OptimizedResult optimize(ProbabilityFitness probabilityFitness) {
        return optimize(probabilityFitness.getFitnessFunction(this.probabilisticModel, this::handleFitnessCalculation));
    }

    /**
     * Runs the configured algorithm over the model's latent variables.
     *
     * @param fitnessFunction the objective to optimize
     * @return the algorithm's result
     * @throws IllegalArgumentException if the initial-fitness check is enabled and the starting point
     *     has impossible log probability
     */
    private OptimizedResult optimize(FitnessFunction fitnessFunction) {
        StatusBar statusBar = Optimizer.createFitnessStatusBar(this);
        // finally ensures the status bar is finished even when the precondition check or the
        // algorithm throws; previously it was left unfinished on those paths.
        try {
            List<Variable> latentVariables = this.probabilisticModel.getLatentVariables();
            if (this.checkInitialFitnessConditions
                && ProbabilityCalculator.isImpossibleLogProb(fitnessFunction.getFitnessAt(Optimizer.convertToMapPoint(latentVariables)))) {
                throw new IllegalArgumentException("Cannot start optimizer on zero probability network");
            }
            return this.nonGradientOptimizationAlgorithm.optimize(latentVariables, fitnessFunction);
        } finally {
            statusBar.finish();
        }
    }

    /** @return the result of maximizing the posterior ({@link ProbabilityFitness#MAP}) */
    @Override
    public OptimizedResult maxAPosteriori() {
        return optimize(ProbabilityFitness.MAP);
    }

    /** @return the result of maximizing the likelihood ({@link ProbabilityFitness#MLE}) */
    @Override
    public OptimizedResult maxLikelihood() {
        return optimize(ProbabilityFitness.MLE);
    }

    private NonGradientOptimizer(ProbabilisticModel probabilisticModel, NonGradientOptimizationAlgorithm nonGradientOptimizationAlgorithm, boolean z) {
        this.onFitnessCalculations = new ArrayList<>();
        this.probabilisticModel = probabilisticModel;
        this.nonGradientOptimizationAlgorithm = nonGradientOptimizationAlgorithm;
        this.checkInitialFitnessConditions = z;
    }
}
