package io.improbable.keanu.algorithms.variational.optimizer;

import com.google.common.primitives.Ints;
import io.improbable.keanu.algorithms.Variable;
import io.improbable.keanu.algorithms.VariableReference;
import io.improbable.keanu.tensor.NumberTensor;
import io.improbable.keanu.tensor.TensorShape;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.util.status.AverageTimeComponent;
import io.improbable.keanu.util.status.StatusBar;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiConsumer;
import java.util.stream.Collectors;

/* loaded from: input_file:io/improbable/keanu/algorithms/variational/optimizer/Optimizer.class */
/**
 * An optimizer over the latent variables of a network, plus static helpers for converting between
 * a list of variables, a flat {@code double[]} "point", and a {@code VariableReference -> DoubleTensor}
 * map.
 */
public interface Optimizer {

    /**
     * Registers a handler to be invoked with a point and its fitness value.
     *
     * <p>NOTE(review): from {@link #createFitnessStatusBar(Optimizer)} the Double appears to be the
     * log probability of the evaluated point — confirm against implementations.
     *
     * @param fitnessCalculationHandler handler receiving the evaluated point and its fitness
     */
    void addFitnessCalculationHandler(BiConsumer<Map<VariableReference, DoubleTensor>, Double> fitnessCalculationHandler);

    /**
     * Unregisters a handler previously passed to {@link #addFitnessCalculationHandler(BiConsumer)}.
     *
     * @param fitnessCalculationHandler the handler to remove
     */
    void removeFitnessCalculationHandler(BiConsumer<Map<VariableReference, DoubleTensor>, Double> fitnessCalculationHandler);

    /**
     * Runs the optimization; presumably toward the maximum a posteriori point (inferred from the
     * name — see implementations).
     *
     * @return the optimization result
     */
    OptimizedResult maxAPosteriori();

    /**
     * Runs the optimization; presumably toward the maximum likelihood point (inferred from the
     * name — see implementations).
     *
     * @return the optimization result
     */
    OptimizedResult maxLikelihood();

    /**
     * Flattens the current values of {@code variables} into one array, concatenated in list order.
     *
     * <p>The array is sized from the variables' shapes, while the data copied comes from each
     * value's {@code asFlatDoubleArray()}. NOTE(review): assumes those two lengths agree for every
     * variable.
     *
     * @param variables variables whose current values form the point
     * @return the concatenated flat values
     * @throws IllegalArgumentException if the total length exceeds {@link Integer#MAX_VALUE}
     */
    static double[] convertToArrayPoint(List<? extends Variable<? extends NumberTensor, ?>> variables) {
        List<long[]> shapes = variables.stream()
            .map(Variable::getShape)
            .collect(Collectors.toList());
        long totalLength = totalNumberOfLatentDimensions(shapes);
        // A Java array cannot hold more than Integer.MAX_VALUE elements.
        if (totalLength > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("Greater than 2147483647 latent dimensions not supported");
        }

        double[] point = new double[(int) totalLength];
        int position = 0;
        for (Variable<? extends NumberTensor, ?> variable : variables) {
            double[] values = variable.getValue().asFlatDoubleArray();
            System.arraycopy(values, 0, point, position, values.length);
            position += values.length;
        }
        return point;
    }

    /**
     * Builds a map from each variable's reference to its current value.
     *
     * @param variables variables to read; each must hold a {@link DoubleTensor} value
     * @return map of reference to current value
     * @throws UnsupportedOperationException if any variable's value is not a {@link DoubleTensor}
     * @throws IllegalStateException if two variables share the same reference
     *     (raised by {@link Collectors#toMap})
     */
    static Map<VariableReference, DoubleTensor> convertToMapPoint(List<? extends Variable> variables) {
        return variables.stream()
            .map(Optimizer::toDoubleTensorVariable)
            .collect(Collectors.toMap(Variable::getReference, Variable::getValue));
    }

    /**
     * Splits a flat point back into per-variable tensors.
     *
     * <p>This is the inverse of the layout produced by
     * {@link #convertToArrayPoint(List)}: consecutive slices of {@code point}, in list order, each
     * sized by its variable's shape.
     *
     * @param point flat values; must contain at least as many entries as the variables need in
     *     total (extra trailing entries are ignored)
     * @param variables variables defining the slice order and shapes
     * @return map of each variable's reference to a tensor built from its slice
     * @throws IndexOutOfBoundsException if {@code point} is too short (from {@link System#arraycopy})
     */
    static Map<VariableReference, DoubleTensor> convertFromPoint(double[] point, List<? extends Variable> variables) {
        Map<VariableReference, DoubleTensor> values = new HashMap<>();
        int position = 0;
        for (Variable<?, ?> variable : variables) {
            long[] shape = variable.getShape();
            // Fails fast if a single variable's length does not fit in an int.
            int length = Ints.checkedCast(TensorShape.getLength(shape));
            double[] slice = new double[length];
            System.arraycopy(point, position, slice, 0, length);
            values.put(variable.getReference(), DoubleTensor.create(slice, shape));
            position += length;
        }
        return values;
    }

    /**
     * Sums {@link #numDimensions(long[])} over all the given shapes.
     *
     * @param shapes tensor shapes
     * @return total number of dimensions; 0 for an empty list
     */
    static long totalNumberOfLatentDimensions(List<long[]> shapes) {
        return shapes.stream().mapToLong(Optimizer::numDimensions).sum();
    }

    /**
     * Returns {@code TensorShape.getLength(shape)}.
     *
     * <p>NOTE(review): presumably the element count, i.e. the product of the dimensions — confirm
     * in {@code TensorShape}.
     *
     * @param shape tensor shape
     * @return the length reported by {@code TensorShape.getLength}
     */
    static long numDimensions(long[] shape) {
        return TensorShape.getLength(shape);
    }

    /**
     * Views each variable as a {@code DoubleTensor}-valued variable.
     *
     * @param variables variables to check and convert
     * @return the same variables, typed as {@code DoubleTensor}-valued
     * @throws UnsupportedOperationException if any value is not a {@link DoubleTensor}
     */
    static List<Variable<? extends DoubleTensor, ?>> getAsDoubleTensors(List<? extends Variable> variables) {
        return variables.stream()
            .<Variable<? extends DoubleTensor, ?>>map(Optimizer::toDoubleTensorVariable)
            .collect(Collectors.toList());
    }

    /**
     * Returns {@code variable} typed as {@code DoubleTensor}-valued, after checking its current
     * value.
     *
     * @param variable the variable to check
     * @return the same instance, cast
     * @throws UnsupportedOperationException if the current value is not a {@link DoubleTensor}
     */
    @SuppressWarnings("unchecked") // guarded by the runtime instanceof check below
    static Variable<DoubleTensor, ?> toDoubleTensorVariable(Variable<?, ?> variable) {
        if (variable.getValue() instanceof DoubleTensor) {
            return (Variable<DoubleTensor, ?>) variable;
        }
        throw new UnsupportedOperationException("Optimization unsupported on networks containing discrete latents. Discrete latent : " + variable.getReference() + " found.");
    }

    /**
     * Creates a status bar that tracks fitness evaluations of {@code optimizer}.
     *
     * <p>For each evaluation it:
     * <ul>
     *   <li>increments a running evaluation counter;</li>
     *   <li>sets the status bar message to that counter and the reported value;</li>
     *   <li>steps an {@link AverageTimeComponent}.</li>
     * </ul>
     * The handler is deregistered from the optimizer when the status bar's finish handlers run.
     *
     * @param optimizer optimizer to observe
     * @return the new status bar
     */
    static StatusBar createFitnessStatusBar(Optimizer optimizer) {
        AtomicInteger evaluationCount = new AtomicInteger(0);
        StatusBar statusBar = new StatusBar();
        AverageTimeComponent averageTime = new AverageTimeComponent();
        statusBar.addComponent(averageTime);

        BiConsumer<Map<VariableReference, DoubleTensor>, Double> onFitnessCalculation = (point, logProb) -> {
            statusBar.setMessage(String.format("Fitness Evaluation #%d LogProb: %.2f", evaluationCount.incrementAndGet(), logProb));
            averageTime.step();
        };

        optimizer.addFitnessCalculationHandler(onFitnessCalculation);
        // Avoid leaking the handler onto the optimizer once the status bar is finished.
        statusBar.addFinishHandler(() -> optimizer.removeFitnessCalculationHandler(onFitnessCalculation));
        return statusBar;
    }
}
