package io.improbable.keanu.vertices.dbl.probabilistic;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.annotation.ExportVertexToPythonBindings;
import io.improbable.keanu.distributions.continuous.Uniform;
import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.LoadVertexParam;
import io.improbable.keanu.vertices.Samplable;
import io.improbable.keanu.vertices.SaveVertexParam;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.dbl.Differentiable;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:io/improbable/keanu/vertices/dbl/probabilistic/KDEVertex.class */
/**
 * A scalar probabilistic vertex whose density is a Gaussian kernel density estimate (KDE)
 * built from a fixed collection of samples.
 *
 * <p>The density at a point {@code x} is {@code (1 / (n * h)) * sum_i K((x - s_i) / h)}, where
 * {@code s_i} are the {@code n} samples, {@code h} is the bandwidth and {@code K} is the standard
 * normal density. When no bandwidth is given, the rule of thumb {@code 1.06 * sigma * n^(-1/5)}
 * is used.
 *
 * <p>The bandwidth is not validated. A zero or NaN bandwidth (for example, the default rule
 * applied to a single sample) makes the density non-finite, although sampling still works.
 */
public class KDEVertex extends DoubleVertex implements Differentiable, ProbabilisticDouble, Samplable<DoubleTensor> {

    private static final String BANDWIDTH_NAME = "bandwidth";
    private static final String SAMPLES_NAME = "samples";

    /** Normalising constant of the standard normal density, {@code 1 / sqrt(2 * pi)}. */
    private static final double INV_SQRT_TWO_PI = 1.0 / Math.sqrt(2.0 * Math.PI);

    private final double bandwidth;

    /** Kernel centres; replaced by {@link #resample(int, KeanuRandom)}. */
    private DoubleTensor samples;

    /**
     * Creates a KDE over {@code samples} with an explicit kernel bandwidth.
     *
     * @param samples   the kernel centres; must contain at least one value
     * @param bandwidth the standard deviation of the Gaussian kernel placed on each sample
     * @throws IllegalStateException if {@code samples} is empty
     */
    @ExportVertexToPythonBindings
    public KDEVertex(@LoadVertexParam(SAMPLES_NAME) DoubleTensor samples,
                     @LoadVertexParam(BANDWIDTH_NAME) double bandwidth) {
        super(Tensor.SCALAR_SHAPE);
        requireNonEmpty(samples);
        this.samples = samples;
        this.bandwidth = bandwidth;
    }

    /**
     * Creates a KDE over {@code samples}, choosing the bandwidth with the
     * {@code 1.06 * sigma * n^(-1/5)} rule of thumb.
     *
     * @param samples the kernel centres; must contain at least one value
     * @throws IllegalStateException if {@code samples} is empty
     */
    public KDEVertex(DoubleTensor samples) {
        this(samples, scottsBandwidth(samples));
    }

    /**
     * Creates a KDE over a list of values, using the rule-of-thumb bandwidth.
     *
     * @param samples the kernel centres; must be non-empty and contain no nulls
     * @throws IllegalStateException if {@code samples} is empty
     */
    public KDEVertex(List<Double> samples) {
        this(toTensor(samples));
    }

    /**
     * Creates a KDE over a list of values with an explicit bandwidth.
     *
     * @param samples   the kernel centres; must be non-empty and contain no nulls
     * @param bandwidth the standard deviation of each Gaussian kernel
     * @throws IllegalStateException if {@code samples} is empty
     */
    public KDEVertex(List<Double> samples, double bandwidth) {
        this(toTensor(samples), bandwidth);
    }

    /** Packs a list of values into a rank-1 tensor of the same length. */
    private static DoubleTensor toTensor(List<Double> values) {
        double[] flat = values.stream().mapToDouble(Double::doubleValue).toArray();
        return DoubleTensor.create(flat, flat.length);
    }

    /** Rejects an empty sample tensor; shared by the constructor and the bandwidth rule. */
    private static void requireNonEmpty(DoubleTensor samples) {
        if (samples.getLength() == 0) {
            throw new IllegalStateException("The provided tensor of samples is empty!");
        }
    }

    /** @return the kernel bandwidth (standard deviation of each Gaussian kernel) */
    @SaveVertexParam(BANDWIDTH_NAME)
    public double getBandwidth() {
        return this.bandwidth;
    }

    /** @return the tensor of kernel centres currently backing this estimate */
    @SaveVertexParam(SAMPLES_NAME)
    public DoubleTensor getInputSamples() {
        return this.samples;
    }

    /**
     * Computes standardised differences between every sample and every evaluation point.
     *
     * <p>Both {@code x} and the samples are treated as flat lists, so rank-0 (scalar) and
     * higher-rank inputs work. The result has shape {@code [sampleCount, pointCount]}, and
     * entry {@code [i, j]} equals {@code (x_j - s_i) / bandwidth}. {@code x} is never mutated.
     */
    private DoubleTensor getDiffs(DoubleTensor x) {
        long sampleCount = samples.getLength();
        long pointCount = x.getLength();
        DoubleTensor points = x.reshape(1, pointCount).broadcast(sampleCount, pointCount);
        // Non-in-place minus: broadcast may hand back a tensor sharing storage with x.
        return points.minus(samples.reshape(sampleCount, 1)).divInPlace(bandwidth);
    }

    /**
     * Evaluates the estimated density element-wise.
     *
     * @param x the points to evaluate at
     * @return a tensor with the same shape as {@code x}, holding the density at each element
     */
    public DoubleTensor pdf(DoubleTensor x) {
        DoubleTensor density = gaussianKernel(getDiffs(x))
            .sum(0)
            .divInPlace(samples.getLength() * bandwidth);
        return density.reshape(x.getShape());
    }

    /**
     * Returns the natural log of the joint density of all elements of {@code x}.
     * This is the sum of the element-wise log densities.
     */
    @Override
    public double logProb(DoubleTensor x) {
        // Natural log, so this agrees with dLogProb, which returns (dp/dx) / p = d(ln p)/dx.
        return pdf(x).logInPlace().sum();
    }

    /**
     * Computes the gradient of {@link #logProb(DoubleTensor)} with respect to the value of this vertex.
     *
     * @param x             the value at which to differentiate
     * @param withRespectTo the vertices a partial is requested for; only {@code this} is supported
     * @return a mutable map containing {@code this -> d(ln p)/dx} (same shape as {@code x}) if
     *         {@code this} is requested, otherwise an empty map
     */
    @Override
    public Map<Vertex, DoubleTensor> dLogProb(DoubleTensor x, Set<? extends Vertex> withRespectTo) {
        Map<Vertex, DoubleTensor> partials = new HashMap<>();
        if (withRespectTo.contains(this)) {
            partials.put(this, dPdx(x).divInPlace(pdf(x)));
        }
        return partials;
    }

    /**
     * Identical to {@link #dLogProb(DoubleTensor, Set)}. Kept only for callers that bound to the
     * decompiler-generated name.
     *
     * @deprecated use {@link #dLogProb(DoubleTensor, Set)}
     */
    @Deprecated
    public Map<Vertex, DoubleTensor> dLogProb2(DoubleTensor x, Set<? extends Vertex> withRespectTo) {
        return dLogProb(x, withRespectTo);
    }

    /**
     * Computes the derivative of the density with respect to each element of {@code x}.
     * It uses {@code K'(u) = -u * K(u)} for the standard normal kernel and the chain-rule
     * factor {@code 1 / bandwidth}.
     */
    private DoubleTensor dPdx(DoubleTensor x) {
        DoubleTensor diffs = getDiffs(x);
        DoubleTensor slopes = gaussianKernel(diffs).timesInPlace(diffs).unaryMinusInPlace();
        return slopes.sum(0)
            .divInPlace(bandwidth * bandwidth * samples.getLength())
            .reshape(x.getShape());
    }

    /** Standard normal density applied element-wise; returns a new tensor and leaves {@code u} untouched. */
    private static DoubleTensor gaussianKernel(DoubleTensor u) {
        return u.times(u).timesInPlace(-0.5).expInPlace().timesInPlace(INV_SQRT_TWO_PI);
    }

    /**
     * Computes the rule-of-thumb bandwidth {@code 1.06 * sigma * n^(-1/5)}.
     *
     * @throws IllegalStateException if {@code samples} is empty
     */
    private static double scottsBandwidth(DoubleTensor samples) {
        requireNonEmpty(samples);
        return 1.06 * samples.standardDeviation() * Math.pow(samples.getLength(), -0.2);
    }

    /**
     * Draws values from the KDE. Each draw picks a sample uniformly at random and adds Gaussian
     * noise with standard deviation equal to the bandwidth.
     *
     * @param sampleCount how many values to draw
     * @param random      source of randomness
     * @return a rank-1 tensor of length {@code sampleCount}
     */
    public DoubleTensor sample(int sampleCount, KeanuRandom random) {
        double[] data = samples.asFlatDoubleArray();
        double[] picks = Uniform.withParameters(DoubleTensor.scalar(0.0), DoubleTensor.scalar(data.length))
            .sample(new long[]{sampleCount}, random)
            .floorInPlace()
            .asFlatDoubleArray();

        double[] centres = new double[sampleCount];
        for (int i = 0; i < sampleCount; i++) {
            // Clamp in case floating-point rounding lands exactly on the exclusive upper bound.
            int index = Math.min((int) picks[i], data.length - 1);
            centres[i] = data[index];
        }

        return random.nextGaussian(new long[]{sampleCount})
            .timesInPlace(bandwidth)
            .plusInPlace(DoubleTensor.create(centres));
    }

    /** Draws a single value from the KDE, returned as a length-1 tensor. */
    @Override
    public DoubleTensor sample(KeanuRandom random) {
        return sample(1, random);
    }

    /**
     * Replaces the kernel centres with {@code sampleCount} draws from the current estimate.
     * The bandwidth is left unchanged.
     */
    public void resample(int sampleCount, KeanuRandom random) {
        this.samples = sample(sampleCount, random);
    }

    /** @return the shape of the tensor of kernel centres */
    public long[] getSampleShape() {
        return this.samples.getShape();
    }
}
