/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.vertices.dbl.probabilistic;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.annotation.ExportVertexToPythonBindings;
import io.improbable.keanu.distributions.continuous.Uniform;
import io.improbable.keanu.tensor.NumberTensor;
import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.LoadVertexParam;
import io.improbable.keanu.vertices.Samplable;
import io.improbable.keanu.vertices.SaveVertexParam;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.dbl.Differentiable;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.probabilistic.ProbabilisticDouble;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Probabilistic double vertex whose density is a Gaussian kernel density
 * estimate built from a fixed set of samples and a bandwidth.
 *
 * <p>The density evaluated by {@link #pdf(DoubleTensor)} is
 * {@code sum_i K((x - sample_i) / bandwidth) / (n * bandwidth)}, where
 * {@code K} is the standard normal density and {@code n} is the number of
 * stored samples.
 */
public class KDEVertex
extends DoubleVertex
implements Differentiable,
ProbabilisticDouble,
Samplable<DoubleTensor> {
    private static final String BANDWIDTH_NAME = "bandwidth";
    private static final String SAMPLES_NAME = "samples";
    /** Normalising constant of the standard Gaussian kernel, {@code 1 / sqrt(2 * pi)}. */
    private static final double GAUSSIAN_NORMALISER = 1.0 / Math.sqrt(Math.PI * 2);

    private final double bandwidth;
    private DoubleTensor samples;

    /**
     * Creates a KDE vertex from a tensor of samples and an explicit bandwidth.
     *
     * @param samples   the samples the estimate is built from; must not be empty
     * @param bandwidth the kernel bandwidth used to scale sample differences
     * @throws IllegalStateException if {@code samples} has length zero
     */
    @ExportVertexToPythonBindings
    public KDEVertex(@LoadVertexParam(value=SAMPLES_NAME) DoubleTensor samples, @LoadVertexParam(value=BANDWIDTH_NAME) double bandwidth) {
        super(Tensor.SCALAR_SHAPE);
        if (samples.getLength() == 0L) {
            throw new IllegalStateException("The provided tensor of samples is empty!");
        }
        this.samples = samples;
        this.bandwidth = bandwidth;
    }

    /**
     * Creates a KDE vertex from a tensor of samples, deriving the bandwidth
     * from the samples with {@code 1.06 * stddev * n^(-1/5)}.
     *
     * @param samples the samples the estimate is built from; must not be empty
     */
    public KDEVertex(DoubleTensor samples) {
        this(samples, KDEVertex.scottsBandwidth(samples));
    }

    /**
     * Creates a KDE vertex from a list of samples, deriving the bandwidth from
     * the samples.
     *
     * @param samples the samples the estimate is built from; must not be empty
     */
    public KDEVertex(List<Double> samples) {
        this(KDEVertex.toVector(samples));
    }

    /**
     * Creates a KDE vertex from a list of samples and an explicit bandwidth.
     *
     * @param samples   the samples the estimate is built from; must not be empty
     * @param bandwidth the kernel bandwidth used to scale sample differences
     */
    public KDEVertex(List<Double> samples, double bandwidth) {
        this(KDEVertex.toVector(samples), bandwidth);
    }

    /**
     * Copies a list of samples into a tensor with the explicit shape
     * {@code [samples.size()]}, so both list-based constructors produce the
     * same layout.
     */
    private static DoubleTensor toVector(List<Double> samples) {
        double[] values = samples.stream().mapToDouble(Double::doubleValue).toArray();
        return DoubleTensor.create(values, new long[]{values.length});
    }

    /**
     * @return the kernel bandwidth this vertex was created with
     */
    @SaveVertexParam(value=BANDWIDTH_NAME)
    public double getBandwidth() {
        return this.bandwidth;
    }

    /**
     * @return the tensor of samples currently backing the estimate (the same
     *         instance that is stored, not a copy)
     */
    @SaveVertexParam(value=SAMPLES_NAME)
    public DoubleTensor getInputSamples() {
        return this.samples;
    }

    /**
     * Computes {@code (x - sample_i) / bandwidth} for every stored sample.
     *
     * <p>{@code x} is broadcast to {@code [n, x.getShape()[0]]} and the samples
     * are reshaped to {@code [n, 1]}, where {@code n} is the sample count.
     * NOTE(review): this reads {@code x.getShape()[0]}, so it assumes
     * {@code x} has at least one dimension - confirm against callers.
     *
     * @param x the points at which the differences are evaluated
     * @return a freshly broadcast tensor holding the scaled differences
     */
    private DoubleTensor getDiffs(DoubleTensor x) {
        // Use the element count rather than the first dimension so the reshape
        // below is valid whatever layout the stored samples have.
        long sampleCount = this.samples.getLength();
        return (DoubleTensor)((Object)((DoubleTensor)((DoubleTensor)x.broadcast(sampleCount, x.getShape()[0])).minusInPlace((NumberTensor)this.samples.reshape(sampleCount, 1L))).divInPlace(this.bandwidth));
    }

    /**
     * Evaluates the kernel density estimate at the given points.
     *
     * @param x the points at which to evaluate the density
     * @return the kernel values summed over dimension 0 and divided by
     *         {@code n * bandwidth}
     */
    public DoubleTensor pdf(DoubleTensor x) {
        DoubleTensor diffs = this.getDiffs(x);
        return (DoubleTensor)((Object)((DoubleTensor)KDEVertex.gaussianKernel(diffs).sum(0)).divInPlace((double)this.samples.getLength() * this.bandwidth));
    }

    /**
     * @param x the points at which to evaluate the log density
     * @return the sum over all points of the natural log of {@link #pdf(DoubleTensor)}
     */
    @Override
    public double logProb(DoubleTensor x) {
        return (Double)((DoubleTensor)this.pdf(x).log()).sum();
    }

    /**
     * Computes the derivative of the log density with respect to this vertex.
     *
     * @param value         the points at which to evaluate the derivative
     * @param withRespectTo the vertices for which derivatives are requested
     * @return a map containing {@code dPdx(value) / pdf(value)} keyed by this
     *         vertex when this vertex is in {@code withRespectTo}; otherwise
     *         an empty map
     */
    @Override
    public Map<Vertex, DoubleTensor> dLogProb(DoubleTensor value, Set<? extends Vertex> withRespectTo) {
        Map<Vertex, DoubleTensor> partialDerivatives = new HashMap<>();
        if (withRespectTo.contains(this)) {
            DoubleTensor dlnPdfs = this.dPdx(value).divInPlace(this.pdf(value));
            partialDerivatives.put(this, dlnPdfs);
        }
        return partialDerivatives;
    }

    /**
     * Derivative of {@link #pdf(DoubleTensor)} with respect to {@code x}:
     * {@code -sum_i(u_i * K(u_i)) / (bandwidth^2 * n)} with
     * {@code u_i = (x - sample_i) / bandwidth}.
     */
    private DoubleTensor dPdx(DoubleTensor x) {
        DoubleTensor diff = this.getDiffs(x);
        return (DoubleTensor)((Object)((DoubleTensor)((DoubleTensor)KDEVertex.gaussianKernel(diff).timesInPlace(diff).unaryMinusInPlace()).sum(0)).divInPlace(this.bandwidth * this.bandwidth * (double)this.samples.getLength()));
    }

    /**
     * Standard Gaussian kernel, {@code exp(-x^2 / 2) / sqrt(2 * pi)}, applied
     * element-wise. The first operation is the non-in-place {@code pow}, and
     * {@link #dPdx(DoubleTensor)} relies on {@code x} being left unmodified
     * because it reuses the same tensor afterwards.
     */
    private static DoubleTensor gaussianKernel(DoubleTensor x) {
        DoubleTensor power = (DoubleTensor)((DoubleTensor)((Object)x.pow(2.0).timesInPlace(-0.5))).expInPlace();
        return (DoubleTensor)((Object)power.timesInPlace(GAUSSIAN_NORMALISER));
    }

    /**
     * Rule-of-thumb bandwidth: {@code 1.06 * stddev(samples) * n^(-0.2)}.
     */
    private static double scottsBandwidth(DoubleTensor samples) {
        return 1.06 * (Double)samples.standardDeviation() * Math.pow(samples.getLength(), -0.2);
    }

    /**
     * Draws values from the estimated density by picking stored samples
     * uniformly at random and adding Gaussian noise scaled by the bandwidth.
     *
     * @param nSamples the number of values to draw; must not be negative
     * @param random   the source of randomness
     * @return a tensor of shape {@code [nSamples]} holding the drawn values
     * @throws IllegalArgumentException if {@code nSamples} is negative
     */
    public DoubleTensor sample(int nSamples, KeanuRandom random) {
        if (nSamples < 0) {
            throw new IllegalArgumentException("nSamples must be non-negative but was " + nSamples);
        }
        DoubleTensor value = (DoubleTensor)Uniform.withParameters(DoubleTensor.scalar(0.0), DoubleTensor.scalar(this.samples.getLength())).sample(new long[]{nSamples}, random);
        DoubleTensor index = (DoubleTensor)value.floorInPlace();
        // Defensive clamp: guards against the floored uniform draw landing on
        // the upper bound, which would index one past the last sample.
        int lastIndex = (int)(this.samples.getLength() - 1L);
        double[] shuffledSamples = new double[nSamples];
        int j = 0;
        for (Double i : index.asFlatList()) {
            int sampleIndex = Math.min(i.intValue(), lastIndex);
            shuffledSamples[j] = (Double)this.samples.getValue(sampleIndex);
            ++j;
        }
        // Explicit shape keeps the kernel centres aligned with the [nSamples] noise tensor.
        DoubleTensor sampleMus = DoubleTensor.create(shuffledSamples, new long[]{nSamples});
        return ((DoubleTensor)((Object)random.nextGaussian(new long[]{nSamples}).timesInPlace(this.bandwidth))).plusInPlace(sampleMus);
    }

    /**
     * Draws a single value from the estimated density.
     *
     * @param random the source of randomness
     * @return the result of {@code sample(1, random)}
     */
    @Override
    public DoubleTensor sample(KeanuRandom random) {
        return this.sample(1, random);
    }

    /**
     * Replaces the stored samples with {@code nSamples} values drawn from the
     * current estimate. The bandwidth is left unchanged.
     *
     * @param nSamples the number of new samples; must be at least 1 so the
     *                 estimate never ends up with an empty sample set
     * @param random   the source of randomness
     * @throws IllegalArgumentException if {@code nSamples} is less than 1
     */
    public void resample(int nSamples, KeanuRandom random) {
        if (nSamples < 1) {
            throw new IllegalArgumentException("nSamples must be at least 1 but was " + nSamples);
        }
        this.samples = this.sample(nSamples, random);
    }

    /**
     * @return the shape of the tensor of stored samples
     */
    public long[] getSampleShape() {
        return this.samples.getShape();
    }
}

