/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.algorithms.variational.optimizer.gradient;

import io.improbable.keanu.algorithms.ProbabilisticModelWithGradient;
import io.improbable.keanu.algorithms.VariableReference;
import io.improbable.keanu.algorithms.variational.optimizer.FitnessFunction;
import io.improbable.keanu.algorithms.variational.optimizer.FitnessFunctionGradient;
import io.improbable.keanu.algorithms.variational.optimizer.OptimizedResult;
import io.improbable.keanu.algorithms.variational.optimizer.Optimizer;
import io.improbable.keanu.algorithms.variational.optimizer.ProbabilityFitness;
import io.improbable.keanu.algorithms.variational.optimizer.gradient.ConjugateGradient;
import io.improbable.keanu.algorithms.variational.optimizer.gradient.GradientOptimizationAlgorithm;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.util.status.StatusBar;
import io.improbable.keanu.vertices.ProbabilityCalculator;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.function.BiConsumer;

public class GradientOptimizer
implements Optimizer {
    private static final double FLAT_GRADIENT = 1.0E-16;
    private final ProbabilisticModelWithGradient probabilisticModelWithGradient;
    private final GradientOptimizationAlgorithm gradientOptimizationAlgorithm;
    private final boolean checkInitialFitnessConditions;
    private final List<BiConsumer<Map<VariableReference, DoubleTensor>, Map<? extends VariableReference, DoubleTensor>>> onGradientCalculations = new ArrayList<BiConsumer<Map<VariableReference, DoubleTensor>, Map<? extends VariableReference, DoubleTensor>>>();
    private final List<BiConsumer<Map<VariableReference, DoubleTensor>, Double>> onFitnessCalculations = new ArrayList<BiConsumer<Map<VariableReference, DoubleTensor>, Double>>();

    public static GradientOptimizerBuilder builder() {
        return new GradientOptimizerBuilder();
    }

    public void addGradientCalculationHandler(BiConsumer<Map<VariableReference, DoubleTensor>, Map<? extends VariableReference, DoubleTensor>> gradientCalculationHandler) {
        this.onGradientCalculations.add(gradientCalculationHandler);
    }

    public void removeGradientCalculationHandler(BiConsumer<Map<VariableReference, DoubleTensor>, Map<? extends VariableReference, DoubleTensor>> gradientCalculationHandler) {
        this.onGradientCalculations.remove(gradientCalculationHandler);
    }

    private void handleGradientCalculation(Map<VariableReference, DoubleTensor> point, Map<? extends VariableReference, DoubleTensor> gradients) {
        for (BiConsumer<Map<VariableReference, DoubleTensor>, Map<? extends VariableReference, DoubleTensor>> gradientCalculationHandler : this.onGradientCalculations) {
            gradientCalculationHandler.accept(point, gradients);
        }
    }

    @Override
    public void addFitnessCalculationHandler(BiConsumer<Map<VariableReference, DoubleTensor>, Double> fitnessCalculationHandler) {
        this.onFitnessCalculations.add(fitnessCalculationHandler);
    }

    @Override
    public void removeFitnessCalculationHandler(BiConsumer<Map<VariableReference, DoubleTensor>, Double> fitnessCalculationHandler) {
        this.onFitnessCalculations.remove(fitnessCalculationHandler);
    }

    private void handleFitnessCalculation(Map<VariableReference, DoubleTensor> point, Double fitness) {
        for (BiConsumer<Map<VariableReference, DoubleTensor>, Double> fitnessCalculationHandler : this.onFitnessCalculations) {
            fitnessCalculationHandler.accept(point, fitness);
        }
    }

    private void assertHasLatents() {
        if (this.probabilisticModelWithGradient.getLatentVariables().isEmpty()) {
            throw new IllegalArgumentException("Cannot find MAP of network without any latent variables");
        }
    }

    @Override
    public OptimizedResult maxAPosteriori() {
        return this.optimize(ProbabilityFitness.MAP);
    }

    @Override
    public OptimizedResult maxLikelihood() {
        return this.optimize(ProbabilityFitness.MLE);
    }

    private OptimizedResult optimize(ProbabilityFitness probabilityFitness) {
        this.assertHasLatents();
        FitnessFunction fitnessFunction = probabilityFitness.getFitnessFunction(this.probabilisticModelWithGradient, this::handleFitnessCalculation);
        FitnessFunctionGradient fitnessFunctionGradient = probabilityFitness.getFitnessFunctionGradient(this.probabilisticModelWithGradient, this::handleGradientCalculation);
        return this.optimize(fitnessFunction, fitnessFunctionGradient);
    }

    private OptimizedResult optimize(FitnessFunction fitnessFunction, FitnessFunctionGradient fitnessFunctionGradient) {
        StatusBar statusBar = Optimizer.createFitnessStatusBar(this);
        if (this.checkInitialFitnessConditions) {
            Map<VariableReference, DoubleTensor> startingPoint = Optimizer.convertToMapPoint(this.probabilisticModelWithGradient.getLatentVariables());
            double initialFitness = fitnessFunction.getFitnessAt(startingPoint);
            if (ProbabilityCalculator.isImpossibleLogProb(initialFitness)) {
                throw new IllegalArgumentException("Cannot start optimizer on zero probability network");
            }
            Map<? extends VariableReference, DoubleTensor> initialGradient = fitnessFunctionGradient.getGradientsAt(startingPoint);
            GradientOptimizer.throwIfGradientIsFlat(initialGradient);
        }
        OptimizedResult result = this.gradientOptimizationAlgorithm.optimize(this.probabilisticModelWithGradient.getLatentVariables(), fitnessFunction, fitnessFunctionGradient);
        statusBar.finish();
        return result;
    }

    private static void throwIfGradientIsFlat(Map<? extends VariableReference, DoubleTensor> gradient) {
        double maxGradient = gradient.values().stream().flatMap(v -> Arrays.stream(v.asFlatDoubleArray()).boxed()).mapToDouble(v -> v).max().orElseThrow(IllegalArgumentException::new);
        if (Math.abs(maxGradient) <= 1.0E-16) {
            throw new IllegalStateException("The initial gradient is very flat. The largest gradient is " + maxGradient);
        }
    }

    private GradientOptimizer(ProbabilisticModelWithGradient probabilisticModelWithGradient, GradientOptimizationAlgorithm gradientOptimizationAlgorithm, boolean checkInitialFitnessConditions) {
        this.probabilisticModelWithGradient = probabilisticModelWithGradient;
        this.gradientOptimizationAlgorithm = gradientOptimizationAlgorithm;
        this.checkInitialFitnessConditions = checkInitialFitnessConditions;
    }

    public static class GradientOptimizerBuilder {
        private ProbabilisticModelWithGradient probabilisticModelWithGradient;
        private GradientOptimizationAlgorithm gradientOptimizationAlgorithm = ConjugateGradient.builder().build();
        private boolean checkInitialFitnessConditions = true;

        public GradientOptimizerBuilder probabilisticModel(ProbabilisticModelWithGradient probabilisticModelWithGradient) {
            this.probabilisticModelWithGradient = probabilisticModelWithGradient;
            return this;
        }

        public GradientOptimizerBuilder algorithm(GradientOptimizationAlgorithm gradientOptimizationAlgorithm) {
            this.gradientOptimizationAlgorithm = gradientOptimizationAlgorithm;
            return this;
        }

        public GradientOptimizerBuilder checkInitialFitnessConditions(boolean check) {
            this.checkInitialFitnessConditions = check;
            return this;
        }

        public GradientOptimizer build() {
            if (this.probabilisticModelWithGradient == null) {
                throw new IllegalStateException("Cannot build optimizer without specifying network to optimize.");
            }
            if (this.gradientOptimizationAlgorithm == null) {
                throw new IllegalStateException("Cannot build optimizer without specifying algorithm for optimizing.");
            }
            return new GradientOptimizer(this.probabilisticModelWithGradient, this.gradientOptimizationAlgorithm, this.checkInitialFitnessConditions);
        }
    }
}

