package io.improbable.keanu.algorithms.variational;

import com.google.common.collect.ImmutableList;
import io.improbable.keanu.Keanu;
import io.improbable.keanu.algorithms.Samples;
import io.improbable.keanu.algorithms.Variable;
import io.improbable.keanu.network.KeanuProbabilisticModel;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.probabilistic.KDEVertex;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

/* loaded from: input_file:io/improbable/keanu/algorithms/variational/GaussianKDE.class */
public class GaussianKDE {

    /**
     * Builds a {@link KDEVertex} from a collection of scalar samples.
     *
     * @param samples the samples to build the density estimate from; every sample must be a
     *     scalar tensor
     * @return a {@link KDEVertex} constructed from the scalar value of each sample, in the order
     *     returned by {@code samples.asList()}
     * @throws NullPointerException if {@code samples} is null
     * @throws IllegalArgumentException if any sample is not a scalar
     */
    public static KDEVertex approximate(Samples<DoubleTensor> samples) {
        Objects.requireNonNull(samples, "samples must not be null");
        List<Double> values = samples.asList().stream()
            .map(GaussianKDE::checkIfScalar)
            .map(DoubleTensor::scalar)
            .collect(Collectors.toList());
        return new KDEVertex(values);
    }

    /**
     * Draws posterior samples of {@code doubleVertex} and builds a {@link KDEVertex} from them.
     *
     * <p>Sampling uses Metropolis-Hastings with its default configuration, run over a
     * {@link KeanuProbabilisticModel} of the graph connected to {@code doubleVertex}.
     *
     * @param doubleVertex the vertex whose posterior is approximated; its samples must be scalars
     * @param num the number of posterior samples to draw
     * @return a {@link KDEVertex} built from the drawn samples
     * @throws NullPointerException if {@code doubleVertex} or {@code num} is null
     * @throws IllegalArgumentException if the sampled values are not scalars
     */
    public static KDEVertex approximate(DoubleVertex doubleVertex, Integer num) {
        Objects.requireNonNull(doubleVertex, "doubleVertex must not be null");
        // Check the boxed count before building the model, so a null fails fast with a clear message
        // instead of an anonymous NPE after the connected graph has already been traversed.
        int sampleCount = Objects.requireNonNull(num, "num must not be null");

        KeanuProbabilisticModel model = new KeanuProbabilisticModel(doubleVertex.getConnectedGraph());
        Samples<DoubleTensor> posterior = Keanu.Sampling.MetropolisHastings.withDefaultConfig()
            .getPosteriorSamples(model, ImmutableList.of(doubleVertex), sampleCount)
            .getDoubleTensorSamples(doubleVertex);
        return approximate(posterior);
    }

    /**
     * Returns {@code doubleTensor} unchanged if it is a scalar.
     *
     * @param doubleTensor the sample to validate
     * @return the same tensor instance
     * @throws IllegalArgumentException if the tensor is not a scalar; the message includes its shape
     */
    private static DoubleTensor checkIfScalar(DoubleTensor doubleTensor) {
        if (!doubleTensor.isScalar()) {
            throw new IllegalArgumentException("The provided samples are not scalars, but have shape " + Arrays.toString(doubleTensor.getShape()));
        }
        return doubleTensor;
    }
}
