/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.algorithms.variational.optimizer;

import com.google.common.primitives.Ints;
import io.improbable.keanu.algorithms.Variable;
import io.improbable.keanu.algorithms.VariableReference;
import io.improbable.keanu.algorithms.variational.optimizer.OptimizedResult;
import io.improbable.keanu.tensor.NumberTensor;
import io.improbable.keanu.tensor.TensorShape;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.util.status.AverageTimeComponent;
import io.improbable.keanu.util.status.StatusBar;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiConsumer;
import java.util.stream.Collectors;

public interface Optimizer {

    /**
     * Registers a handler that is notified of fitness calculations performed by this optimizer.
     *
     * <p>The handler receives a map from variable reference to tensor value, and a {@code Double}
     * fitness value. NOTE(review): {@link #createFitnessStatusBar(Optimizer)} treats these as the
     * evaluated position and its log probability — confirm against the implementations.
     *
     * @param fitnessCalculationHandler the handler to add
     */
    void addFitnessCalculationHandler(BiConsumer<Map<VariableReference, DoubleTensor>, Double> fitnessCalculationHandler);

    /**
     * Unregisters a handler previously passed to {@link #addFitnessCalculationHandler(BiConsumer)}.
     *
     * @param fitnessCalculationHandler the handler to remove
     */
    void removeFitnessCalculationHandler(BiConsumer<Map<VariableReference, DoubleTensor>, Double> fitnessCalculationHandler);

    /**
     * Runs the optimizer with a maximum a posteriori objective (as the name indicates; the exact
     * objective is defined by the implementation).
     *
     * @return the optimization result
     */
    OptimizedResult maxAPosteriori();

    /**
     * Runs the optimizer with a maximum likelihood objective (as the name indicates; the exact
     * objective is defined by the implementation).
     *
     * @return the optimization result
     */
    OptimizedResult maxLikelihood();

    /**
     * Flattens the current values of the given variables into one contiguous array.
     *
     * <p>Values are laid out in list order. Each variable contributes its
     * {@code asFlatDoubleArray()} contents.
     *
     * @param latentVariables the variables whose values are concatenated
     * @return a new array holding all variable values back to back
     * @throws IllegalArgumentException if the total number of dimensions (computed from the
     *     variables' shapes) exceeds {@link Integer#MAX_VALUE}
     */
    static double[] convertToArrayPoint(List<? extends Variable<? extends NumberTensor, ?>> latentVariables) {
        List<long[]> shapes = latentVariables.stream().map(Variable::getShape).collect(Collectors.toList());
        long totalLatentDimensions = Optimizer.totalNumberOfLatentDimensions(shapes);
        if (totalLatentDimensions > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("Greater than 2147483647 latent dimensions not supported");
        }
        int position = 0;
        double[] point = new double[(int) totalLatentDimensions];
        // The element type must be the wildcard type of the list; a concrete NumberTensor
        // type argument is not assignable from `? extends NumberTensor`.
        for (Variable<? extends NumberTensor, ?> variable : latentVariables) {
            double[] values = variable.getValue().asFlatDoubleArray();
            System.arraycopy(values, 0, point, position, values.length);
            position += values.length;
        }
        return point;
    }

    /**
     * Builds a map from each variable's reference to its current {@link DoubleTensor} value.
     *
     * @param variables the variables to read
     * @return a map keyed by {@code Variable#getReference()}
     * @throws UnsupportedOperationException if any variable's value is not a {@link DoubleTensor}
     * @throws IllegalStateException if two variables share the same reference
     *     (from {@link Collectors#toMap})
     */
    static Map<VariableReference, DoubleTensor> convertToMapPoint(List<? extends Variable> variables) {
        return variables.stream().collect(Collectors.toMap(Variable::getReference, v -> Optimizer.toDoubleTensorVariable(v).getValue()));
    }

    /**
     * Splits a flat array back into one {@link DoubleTensor} per variable.
     *
     * <p>This is the inverse layout of {@link #convertToArrayPoint(List)}. Each variable consumes
     * as many consecutive entries as its shape's length, in list order, and the slice is reshaped
     * to that variable's shape. Any entries of {@code point} beyond those consumed are ignored.
     *
     * @param point the flat array of values
     * @param latentVariables the variables that define the slice sizes, shapes and keys
     * @return a map from each variable's reference to its reconstructed tensor
     * @throws IllegalArgumentException if {@code point} is too short for the variables' shapes,
     *     or a single variable has more than {@link Integer#MAX_VALUE} elements
     */
    static Map<VariableReference, DoubleTensor> convertFromPoint(double[] point, List<? extends Variable> latentVariables) {
        Map<VariableReference, DoubleTensor> tensors = new HashMap<>();
        int position = 0;
        for (Variable variable : latentVariables) {
            long[] shape = variable.getShape();
            int dimensions = Ints.checkedCast(Optimizer.numDimensions(shape));
            // Written as a subtraction so that position + dimensions cannot overflow int.
            if (dimensions > point.length - position) {
                throw new IllegalArgumentException("Point of length " + point.length
                    + " is too short: variable " + variable.getReference() + " needs " + dimensions
                    + " values starting at index " + position);
            }
            double[] values = new double[dimensions];
            System.arraycopy(point, position, values, 0, dimensions);
            tensors.put(variable.getReference(), DoubleTensor.create(values, shape));
            position += dimensions;
        }
        return tensors;
    }

    /**
     * Sums the element counts of all given shapes.
     *
     * @param continuousLatentVariableShapes the shapes to total
     * @return the sum of {@link #numDimensions(long[])} over all shapes
     */
    static long totalNumberOfLatentDimensions(List<long[]> continuousLatentVariableShapes) {
        return continuousLatentVariableShapes.stream().mapToLong(Optimizer::numDimensions).sum();
    }

    /**
     * Returns the number of elements in a tensor of the given shape.
     *
     * @param shape the tensor shape
     * @return the value of {@code TensorShape.getLength(shape)}
     */
    static long numDimensions(long[] shape) {
        return TensorShape.getLength(shape);
    }

    /**
     * Checks that every variable holds a {@link DoubleTensor} and returns them typed accordingly.
     *
     * @param variables the variables to check
     * @return the same variables, in order, viewed as double-tensor variables
     * @throws UnsupportedOperationException if any variable's value is not a {@link DoubleTensor}
     */
    static List<Variable<? extends DoubleTensor, ?>> getAsDoubleTensors(List<? extends Variable> variables) {
        return variables.stream().map(v -> Optimizer.toDoubleTensorVariable(v)).collect(Collectors.toList());
    }

    /**
     * Returns {@code v} typed as a {@link DoubleTensor} variable if its current value is one.
     *
     * @param v the variable to check
     * @return {@code v} itself
     * @throws UnsupportedOperationException if {@code v}'s current value is not a
     *     {@link DoubleTensor}
     */
    @SuppressWarnings("unchecked")
    static Variable<DoubleTensor, ?> toDoubleTensorVariable(Variable<?, ?> v) {
        if (v.getValue() instanceof DoubleTensor) {
            // Safe only to the extent the runtime check above holds: the value was a DoubleTensor
            // at the time of the call.
            return (Variable<DoubleTensor, ?>) v;
        }
        throw new UnsupportedOperationException("Optimization unsupported on networks containing discrete latents. Discrete latent : " + v.getReference() + " found.");
    }

    /**
     * Creates a {@link StatusBar} that reports fitness evaluations of the given optimizer.
     *
     * <p>A handler is registered on the optimizer. On every fitness calculation it shows a running
     * evaluation count and the fitness value, formatted as log probability, and steps an
     * {@link AverageTimeComponent}. The handler is unregistered from the optimizer by a finish
     * handler on the returned status bar.
     *
     * @param optimizerThatNeedsStatusBar the optimizer to observe
     * @return the new status bar
     */
    static StatusBar createFitnessStatusBar(Optimizer optimizerThatNeedsStatusBar) {
        AtomicInteger evalCount = new AtomicInteger(0);
        StatusBar statusBar = new StatusBar();
        AverageTimeComponent averageTimeComponent = new AverageTimeComponent();
        statusBar.addComponent(averageTimeComponent);
        BiConsumer<Map<VariableReference, DoubleTensor>, Double> statusBarFitnessCalculationHandler = (position, logProb) -> {
            statusBar.setMessage(String.format("Fitness Evaluation #%d LogProb: %.2f", evalCount.incrementAndGet(), logProb));
            averageTimeComponent.step();
        };
        optimizerThatNeedsStatusBar.addFitnessCalculationHandler(statusBarFitnessCalculationHandler);
        // Detach from the optimizer once the status bar finishes so the handler does not outlive it.
        statusBar.addFinishHandler(() -> optimizerThatNeedsStatusBar.removeFitnessCalculationHandler(statusBarFitnessCalculationHandler));
        return statusBar;
    }
}

