/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.algorithms.particlefiltering;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.algorithms.particlefiltering.LatentIncrementSort;
import io.improbable.keanu.algorithms.particlefiltering.Particle;
import io.improbable.keanu.algorithms.particlefiltering.ParticleFilterBuilder;
import io.improbable.keanu.vertices.Probabilistic;
import io.improbable.keanu.vertices.Vertex;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Set;

public class ParticleFilter {
    private final Collection<? extends Vertex> vertices;
    private final int numParticles;
    private final int resamplingCycles;
    private final double resamplingProportion;
    private final KeanuRandom random;
    /** Final particle population produced by running the filter in the constructor. */
    private final List<Particle> particles;

    /**
     * Starts building a particle filter over the whole graph connected to {@code vertex}.
     *
     * @param vertex any vertex of the graph to filter
     * @return a builder for the filter
     */
    public static ParticleFilterBuilder ofVertexInGraph(Vertex vertex) {
        return new ParticleFilterBuilder(vertex.getConnectedGraph());
    }

    /**
     * Starts building a particle filter over the given vertices.
     *
     * @param vertices the vertices of the graph to filter
     * @return a builder for the filter
     */
    public static ParticleFilterBuilder ofGraph(Collection<? extends Vertex> vertices) {
        return new ParticleFilterBuilder(vertices);
    }

    /**
     * Creates the filter and immediately runs it, so the particle population is available
     * as soon as construction returns.
     *
     * @param vertices             vertices of the graph to filter
     * @param numParticles         size of the particle population kept after every step
     * @param resamplingCycles     how many remove-worst / resample rounds follow each observation
     * @param resamplingProportion fraction of the population discarded in each resampling round;
     *                             values outside [0, 1] are clamped to that range
     * @param random               source of randomness for sampling latent values and resampling
     */
    public ParticleFilter(Collection<? extends Vertex> vertices, int numParticles, int resamplingCycles, double resamplingProportion, KeanuRandom random) {
        this.vertices = vertices;
        this.numParticles = numParticles;
        this.resamplingCycles = resamplingCycles;
        this.resamplingProportion = resamplingProportion;
        this.random = random;
        // runFilter is private (not overridable) and reads only the fields assigned above.
        this.particles = this.runFilter();
    }

    /**
     * Sorts the particle population in place, most probable first, and returns it.
     *
     * @return the filter's internal (mutable) particle list, sorted most probable first
     */
    public List<Particle> getSortedMostProbableParticles() {
        this.particles.sort(Particle::sortDescending);
        return this.particles;
    }

    /**
     * Returns the most probable particle in the population.
     *
     * <p>The population is sorted first. After resampling, the tail of the list holds freshly
     * sampled particles that were never ranked, so the head of the unsorted list is not
     * guaranteed to be the best one.
     *
     * @return the most probable particle
     * @throws IndexOutOfBoundsException if the population is empty
     */
    public Particle getMostProbableParticle() {
        return this.getSortedMostProbableParticles().get(0);
    }

    /**
     * Returns the particle population in its current order. The list is only sorted if
     * {@link #getSortedMostProbableParticles()} or {@link #getMostProbableParticle()} has been
     * called; use the former when a ranking is required.
     *
     * @return the filter's internal (mutable) particle list
     */
    public List<Particle> getMostProbableParticles() {
        return this.particles;
    }

    /**
     * Runs the filter by incorporating each observed vertex in turn.
     *
     * <p>Each observed vertex is added together with the latent vertices
     * {@link LatentIncrementSort} associates with it. Presumably that map's iteration order is
     * the intended order of incorporation; confirm against {@code LatentIncrementSort}.
     *
     * @return the final particle population
     */
    private List<Particle> runFilter() {
        Map<Vertex, Set<Vertex>> obsVertIncrDependencies = LatentIncrementSort.sort(this.vertices);
        List<Particle> population = this.createEmptyParticles(this.numParticles);
        // entrySet iterates in the same order as keySet, so observations are visited as before.
        for (Map.Entry<Vertex, Set<Vertex>> step : obsVertIncrDependencies.entrySet()) {
            population = this.updateParticles(step.getKey(), step.getValue(), population);
        }
        return population;
    }

    /**
     * Advances the population by one observation.
     *
     * <p>It resamples from the previous population, adds the observation plus freshly sampled
     * latent values, then repeatedly drops the worst particles and refills from the previous
     * population.
     *
     * @param nextObservedVertex observed vertex to incorporate
     * @param vertexDeps         latent vertices to sample for this observation
     * @param particles          population before this observation
     * @return a new population of {@code numParticles} particles
     */
    private List<Particle> updateParticles(Vertex<?> nextObservedVertex, Set<Vertex> vertexDeps, List<Particle> particles) {
        List<Particle> updatedParticles = this.sampleAndCopy(particles, this.numParticles);
        this.addObservedVertexToParticles(updatedParticles, nextObservedVertex, vertexDeps);
        for (int i = 0; i < this.resamplingCycles; ++i) {
            updatedParticles = this.removeWorstParticles(updatedParticles);
            int numToSample = this.numParticles - updatedParticles.size();
            // Refill from the previous population, not the survivors, to keep diversity.
            List<Particle> sampledParticles = this.sampleAndCopy(particles, numToSample);
            this.addObservedVertexToParticles(sampledParticles, nextObservedVertex, vertexDeps);
            updatedParticles.addAll(sampledParticles);
        }
        return updatedParticles;
    }

    /**
     * Creates {@code number} fresh particles that contain no vertices yet.
     *
     * @param number how many particles to create; non-positive values yield an empty list
     * @return a mutable list of new particles
     */
    private List<Particle> createEmptyParticles(int number) {
        List<Particle> emptyParticles = new ArrayList<>();
        for (int i = 0; i < number; ++i) {
            emptyParticles.add(new Particle());
        }
        return emptyParticles;
    }

    /**
     * Adds the observed vertex to every particle. It also samples a value for each dependent
     * latent vertex, then refreshes each particle's subgraph log-probability.
     *
     * @param particles          particles to extend (modified in place)
     * @param observedVertex     observed vertex to add
     * @param vertexDependencies latent vertices whose values are sampled per particle
     */
    private void addObservedVertexToParticles(List<Particle> particles, Vertex<?> observedVertex, Set<Vertex> vertexDependencies) {
        for (Particle particle : particles) {
            particle.addObservedVertex(observedVertex);
            for (Vertex latentVertex : vertexDependencies) {
                this.sampleValueAndAddToParticle(latentVertex, particle);
            }
            particle.updateSumLogPOfSubgraph();
        }
    }

    /**
     * Draws a value for {@code vertex} and records it in {@code particle}.
     *
     * @throws ClassCastException if {@code vertex} is not {@link Probabilistic}
     */
    private <T> void sampleValueAndAddToParticle(Vertex<T> vertex, Particle particle) {
        Object sample = ((Probabilistic) vertex).sample(this.random);
        particle.addLatentVertex(vertex, sample);
    }

    /**
     * Sorts {@code particles} in place, most probable first, and returns a copy of the best
     * {@code (1 - resamplingProportion)} fraction of them.
     *
     * @param particles population to prune (reordered in place)
     * @return a new mutable list with the kept particles
     */
    private List<Particle> removeWorstParticles(List<Particle> particles) {
        particles.sort(Particle::sortDescending);
        int numberToKeep = (int) ((double) particles.size() * (1.0 - this.resamplingProportion));
        // Clamp so a proportion outside [0, 1] cannot make subList throw.
        numberToKeep = Math.max(0, Math.min(particles.size(), numberToKeep));
        return new ArrayList<>(particles.subList(0, numberToKeep));
    }

    /**
     * Draws {@code numToSample} particles with replacement, with probability proportional to
     * {@code exp(logProb)}, and returns shallow copies of them.
     *
     * @param particles   population to sample from
     * @param numToSample number of draws; non-positive values yield an empty list
     * @return a new mutable list of shallow copies
     * @throws IllegalStateException if draws are requested from an empty population
     */
    private List<Particle> sampleAndCopy(List<Particle> particles, int numToSample) {
        List<Particle> sampledParticles = new ArrayList<>();
        if (numToSample <= 0) {
            return sampledParticles;
        }
        if (particles.isEmpty()) {
            throw new IllegalStateException("Cannot resample from an empty particle population");
        }
        // Computed once per call instead of once per draw.
        double[] cumulativeWeights = cumulativeWeights(particles);
        double totalWeight = cumulativeWeights[cumulativeWeights.length - 1];
        for (int i = 0; i < numToSample; ++i) {
            Particle sampledParticle = particles.get(this.weightedRandomIndex(cumulativeWeights, totalWeight));
            sampledParticles.add(sampledParticle.shallowCopy());
        }
        return sampledParticles;
    }

    /**
     * Computes running totals of the particles' relative weights {@code exp(logProb - maxLogProb)}.
     *
     * <p>Subtracting the maximum log-probability keeps the largest weight at 1. Without it,
     * {@code exp(logProb)} underflows to 0 for large negative log-probabilities, which
     * collapses resampling onto a single particle. When every particle has the same extreme
     * log-probability (e.g. all {@code -Infinity}), each gets weight 1, i.e. uniform sampling.
     *
     * @param particles non-empty population
     * @return non-decreasing cumulative weights, one entry per particle
     */
    private static double[] cumulativeWeights(List<Particle> particles) {
        double[] logProbs = particles.stream().mapToDouble(Particle::logProb).toArray();
        double maxLogProb = Double.NEGATIVE_INFINITY;
        for (double logProb : logProbs) {
            if (logProb > maxLogProb) {
                maxLogProb = logProb;
            }
        }
        double[] cumulative = new double[logProbs.length];
        double runningTotal = 0.0;
        for (int i = 0; i < logProbs.length; ++i) {
            // The equality branch avoids exp(inf - inf) = NaN when the maximum is infinite.
            runningTotal += logProbs[i] == maxLogProb ? 1.0 : Math.exp(logProbs[i] - maxLogProb);
            cumulative[i] = runningTotal;
        }
        return cumulative;
    }

    /**
     * Picks an index with probability proportional to its weight. It returns the first index
     * whose cumulative weight strictly exceeds a uniform draw in {@code [0, totalWeight)}.
     * If rounding leaves no such index, it returns the last index.
     *
     * @param cumulativeWeights non-empty, non-decreasing running totals
     * @param totalWeight       the last entry of {@code cumulativeWeights}
     * @return the chosen index
     */
    private int weightedRandomIndex(double[] cumulativeWeights, double totalWeight) {
        // Assumes KeanuRandom.nextDouble() is uniform on [0, 1) like java.util.Random.
        double target = this.random.nextDouble() * totalWeight;
        int lo = 0;
        int hi = cumulativeWeights.length - 1;
        while (lo < hi) {
            int mid = (lo + hi) >>> 1;
            if (cumulativeWeights[mid] > target) {
                hi = mid;
            } else {
                lo = mid + 1;
            }
        }
        return lo;
    }
}

