package io.improbable.keanu.algorithms.variational.optimizer.gradient;

import io.improbable.keanu.algorithms.ProbabilisticModelWithGradient;
import io.improbable.keanu.algorithms.VariableReference;
import io.improbable.keanu.algorithms.variational.optimizer.FitnessFunction;
import io.improbable.keanu.algorithms.variational.optimizer.FitnessFunctionGradient;
import io.improbable.keanu.algorithms.variational.optimizer.OptimizedResult;
import io.improbable.keanu.algorithms.variational.optimizer.Optimizer;
import io.improbable.keanu.algorithms.variational.optimizer.ProbabilityFitness;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.util.status.StatusBar;
import io.improbable.keanu.vertices.ProbabilityCalculator;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.function.BiConsumer;

/**
 * An {@link Optimizer} that drives a {@link GradientOptimizationAlgorithm} over the latent variables of a
 * {@link ProbabilisticModelWithGradient}, maximising either the posterior ({@link #maxAPosteriori()}) or the
 * likelihood ({@link #maxLikelihood()}).
 * <p>
 * Instances are created through {@link #builder()}. Handlers can be registered to observe every fitness and
 * gradient evaluation made while optimizing.
 */
public class GradientOptimizer implements Optimizer {

    /**
     * If the largest absolute component of the initial gradient is at or below this value, the starting
     * point is rejected as flat (only when initial fitness checks are enabled).
     */
    private static final double FLAT_GRADIENT = 1.0E-16d;

    private final ProbabilisticModelWithGradient probabilisticModelWithGradient;
    private final GradientOptimizationAlgorithm gradientOptimizationAlgorithm;
    private final boolean checkInitialFitnessConditions;
    private final List<BiConsumer<Map<VariableReference, DoubleTensor>, Map<? extends VariableReference, DoubleTensor>>> onGradientCalculations = new ArrayList<>();
    private final List<BiConsumer<Map<VariableReference, DoubleTensor>, Double>> onFitnessCalculations = new ArrayList<>();

    /**
     * Builder for {@link GradientOptimizer}.
     * <p>
     * A model is required. The algorithm defaults to {@code ConjugateGradient.builder().build()}, and initial
     * fitness checks default to enabled.
     */
    public static class GradientOptimizerBuilder {
        private ProbabilisticModelWithGradient probabilisticModelWithGradient;
        private GradientOptimizationAlgorithm gradientOptimizationAlgorithm = ConjugateGradient.builder().build();
        private boolean checkInitialFitnessConditions = true;

        /**
         * Sets the model whose latent variables will be optimized. Required.
         *
         * @param probabilisticModelWithGradient the model to optimize
         * @return this builder
         */
        public GradientOptimizerBuilder probabilisticModel(ProbabilisticModelWithGradient probabilisticModelWithGradient) {
            this.probabilisticModelWithGradient = probabilisticModelWithGradient;
            return this;
        }

        /**
         * Sets the gradient-based algorithm used to perform the optimization.
         *
         * @param gradientOptimizationAlgorithm the algorithm to use; must not be {@code null} at build time
         * @return this builder
         */
        public GradientOptimizerBuilder algorithm(GradientOptimizationAlgorithm gradientOptimizationAlgorithm) {
            this.gradientOptimizationAlgorithm = gradientOptimizationAlgorithm;
            return this;
        }

        /**
         * Controls whether the starting point is validated before optimizing. The starting point is rejected
         * if its fitness is an impossible log probability or its gradient is flat.
         *
         * @param checkInitialFitnessConditions {@code true} to validate the starting point
         * @return this builder
         */
        public GradientOptimizerBuilder checkInitialFitnessConditions(boolean checkInitialFitnessConditions) {
            this.checkInitialFitnessConditions = checkInitialFitnessConditions;
            return this;
        }

        /**
         * Builds the optimizer.
         *
         * @return a new {@link GradientOptimizer}
         * @throws IllegalStateException if no model was set, or the algorithm was set to {@code null}
         */
        public GradientOptimizer build() {
            if (probabilisticModelWithGradient == null) {
                throw new IllegalStateException("Cannot build optimizer without specifying network to optimize.");
            }
            if (gradientOptimizationAlgorithm == null) {
                throw new IllegalStateException("Cannot build optimizer without specifying algorithm for optimizing.");
            }
            return new GradientOptimizer(probabilisticModelWithGradient, gradientOptimizationAlgorithm, checkInitialFitnessConditions);
        }
    }

    /**
     * @return a new builder for configuring a {@link GradientOptimizer}
     */
    public static GradientOptimizerBuilder builder() {
        return new GradientOptimizerBuilder();
    }

    /**
     * Registers a handler invoked each time a gradient is calculated during optimization.
     *
     * @param gradientCalculationHandler the handler to add
     */
    public void addGradientCalculationHandler(BiConsumer<Map<VariableReference, DoubleTensor>, Map<? extends VariableReference, DoubleTensor>> gradientCalculationHandler) {
        onGradientCalculations.add(gradientCalculationHandler);
    }

    /**
     * Unregisters a handler previously added with {@link #addGradientCalculationHandler(BiConsumer)}.
     *
     * @param gradientCalculationHandler the handler to remove
     */
    public void removeGradientCalculationHandler(BiConsumer<Map<VariableReference, DoubleTensor>, Map<? extends VariableReference, DoubleTensor>> gradientCalculationHandler) {
        onGradientCalculations.remove(gradientCalculationHandler);
    }

    /**
     * Forwards a gradient evaluation to every registered gradient handler, in registration order.
     * NOTE(review): presumably {@code point} is the evaluation point and {@code gradient} the gradients
     * computed there; confirm against ProbabilityFitness.
     */
    private void handleGradientCalculation(Map<VariableReference, DoubleTensor> point,
                                           Map<? extends VariableReference, DoubleTensor> gradient) {
        onGradientCalculations.forEach(handler -> handler.accept(point, gradient));
    }

    @Override
    public void addFitnessCalculationHandler(BiConsumer<Map<VariableReference, DoubleTensor>, Double> fitnessCalculationHandler) {
        onFitnessCalculations.add(fitnessCalculationHandler);
    }

    @Override
    public void removeFitnessCalculationHandler(BiConsumer<Map<VariableReference, DoubleTensor>, Double> fitnessCalculationHandler) {
        onFitnessCalculations.remove(fitnessCalculationHandler);
    }

    /**
     * Forwards a fitness evaluation to every registered fitness handler, in registration order.
     * NOTE(review): presumably {@code point} is the point at which {@code fitness} was evaluated; confirm
     * against ProbabilityFitness.
     */
    private void handleFitnessCalculation(Map<VariableReference, DoubleTensor> point, Double fitness) {
        onFitnessCalculations.forEach(handler -> handler.accept(point, fitness));
    }

    /**
     * @throws IllegalArgumentException if the model has no latent variables to optimize
     */
    private void assertHasLatents() {
        if (probabilisticModelWithGradient.getLatentVariables().isEmpty()) {
            throw new IllegalArgumentException("Cannot find MAP of network without any latent variables");
        }
    }

    /**
     * Optimizes the model's latent variables using the {@link ProbabilityFitness#MAP} fitness.
     *
     * @throws IllegalArgumentException if there are no latent variables, or (when initial checks are enabled)
     *                                  the starting point has an impossible log probability
     * @throws IllegalStateException    if initial checks are enabled and the starting gradient is flat
     */
    @Override
    public OptimizedResult maxAPosteriori() {
        return optimize(ProbabilityFitness.MAP);
    }

    /**
     * Optimizes the model's latent variables using the {@link ProbabilityFitness#MLE} fitness.
     *
     * @throws IllegalArgumentException if there are no latent variables, or (when initial checks are enabled)
     *                                  the starting point has an impossible log probability
     * @throws IllegalStateException    if initial checks are enabled and the starting gradient is flat
     */
    @Override
    public OptimizedResult maxLikelihood() {
        return optimize(ProbabilityFitness.MLE);
    }

    /**
     * Builds the fitness function and its gradient for the given kind of fitness and runs the optimization.
     * The registered handlers are wired in so they observe every evaluation.
     */
    private OptimizedResult optimize(ProbabilityFitness probabilityFitness) {
        assertHasLatents();
        FitnessFunction fitnessFunction =
            probabilityFitness.getFitnessFunction(probabilisticModelWithGradient, this::handleFitnessCalculation);
        FitnessFunctionGradient fitnessFunctionGradient =
            probabilityFitness.getFitnessFunctionGradient(probabilisticModelWithGradient, this::handleGradientCalculation);
        return optimize(fitnessFunction, fitnessFunctionGradient);
    }

    /**
     * Optionally validates the starting point, then delegates to the configured algorithm.
     * The fitness status bar is finished even if validation or optimization throws.
     */
    private OptimizedResult optimize(FitnessFunction fitnessFunction, FitnessFunctionGradient fitnessFunctionGradient) {
        StatusBar statusBar = Optimizer.createFitnessStatusBar(this);
        try {
            if (checkInitialFitnessConditions) {
                Map<VariableReference, DoubleTensor> startingPoint =
                    Optimizer.convertToMapPoint(probabilisticModelWithGradient.getLatentVariables());

                if (ProbabilityCalculator.isImpossibleLogProb(fitnessFunction.getFitnessAt(startingPoint))) {
                    throw new IllegalArgumentException("Cannot start optimizer on zero probability network");
                }
                throwIfGradientIsFlat(fitnessFunctionGradient.getGradientsAt(startingPoint));
            }

            return gradientOptimizationAlgorithm.optimize(
                probabilisticModelWithGradient.getLatentVariables(), fitnessFunction, fitnessFunctionGradient);
        } finally {
            // Previously skipped on any exception, leaving the status bar unfinished.
            statusBar.finish();
        }
    }

    /**
     * Rejects a starting point whose gradient has no component with magnitude above {@link #FLAT_GRADIENT}.
     * <p>
     * The maximum is taken over absolute values. Taking the signed maximum and then its absolute value would
     * wrongly reject steep, all-non-positive gradients such as {@code [-5, 0]}.
     *
     * @param gradients gradient tensors keyed by variable
     * @throws IllegalArgumentException if {@code gradients} contains no elements at all
     * @throws IllegalStateException    if every gradient component is (numerically) zero
     */
    private static void throwIfGradientIsFlat(Map<? extends VariableReference, DoubleTensor> gradients) {
        double largestMagnitude = gradients.values().stream()
            .flatMapToDouble(gradient -> Arrays.stream(gradient.asFlatDoubleArray()))
            .map(Math::abs)
            .max()
            .orElseThrow(() -> new IllegalArgumentException("Cannot check flatness of an empty gradient"));

        if (largestMagnitude <= FLAT_GRADIENT) {
            throw new IllegalStateException("The initial gradient is very flat. The largest gradient is " + largestMagnitude);
        }
    }

    private GradientOptimizer(ProbabilisticModelWithGradient probabilisticModelWithGradient,
                              GradientOptimizationAlgorithm gradientOptimizationAlgorithm,
                              boolean checkInitialFitnessConditions) {
        this.probabilisticModelWithGradient = probabilisticModelWithGradient;
        this.gradientOptimizationAlgorithm = gradientOptimizationAlgorithm;
        this.checkInitialFitnessConditions = checkInitialFitnessConditions;
    }
}
