package io.improbable.keanu.algorithms.particlefiltering;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.vertices.Probabilistic;
import io.improbable.keanu.vertices.Vertex;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:io/improbable/keanu/algorithms/particlefiltering/ParticleFilter.class */
/**
 * Searches for a high-probability assignment of latent vertices with a particle filter.
 *
 * <p>{@link LatentIncrementSort#sort} maps each key vertex to a set of vertices. Keys are
 * processed in the map's iteration order. At each step, every particle is resampled from the
 * previous generation. It then records the key with {@code addObservedVertex}, draws a fresh
 * value for every vertex in the key's set, and refreshes its subgraph log probability. This is
 * followed by {@code resamplingCycles} rounds. Each round discards the worst
 * {@code resamplingProportion} of the particles and replaces them with new draws from the
 * previous generation.
 *
 * <p>All work happens eagerly in the constructor. The accessors only read the final particles,
 * which are kept sorted with the most probable particle first.
 */
public class ParticleFilter {
    private final Collection<? extends Vertex> vertices;
    private final int numParticles;
    private final int resamplingCycles;
    private final double resamplingProportion;
    private final KeanuRandom random;
    /** Final particle population, sorted with {@code Particle::sortDescending} (most probable first). */
    private final List<Particle> particles;

    public static ParticleFilterBuilder ofVertexInGraph(Vertex vertex) {
        return new ParticleFilterBuilder(vertex.getConnectedGraph());
    }

    public static ParticleFilterBuilder ofGraph(Collection<? extends Vertex> vertices) {
        return new ParticleFilterBuilder(vertices);
    }

    /**
     * Runs the particle filter over {@code vertices} immediately.
     *
     * @param vertices             the graph vertices to filter over
     * @param numParticles         population size kept at every step
     * @param resamplingCycles     number of discard-and-replace rounds per step
     * @param resamplingProportion fraction of the population discarded per round; the number kept
     *                             is {@code (int) (numParticles * (1 - resamplingProportion))}, so
     *                             values outside [0, 1] make resampling throw
     * @param random               source of randomness for sampling values and choosing particles
     */
    public ParticleFilter(Collection<? extends Vertex> vertices, int numParticles, int resamplingCycles,
                          double resamplingProportion, KeanuRandom random) {
        this.vertices = vertices;
        this.numParticles = numParticles;
        this.resamplingCycles = resamplingCycles;
        this.resamplingProportion = resamplingProportion;
        this.random = random;
        this.particles = runFilter();
    }

    /**
     * Returns the final particles, most probable first.
     *
     * <p>The list is a fresh copy. Mutating it does not affect this filter.
     */
    public List<Particle> getSortedMostProbableParticles() {
        List<Particle> sorted = new ArrayList<>(particles);
        // Re-sort the copy so the result reflects the particles' current log probabilities.
        sorted.sort(Particle::sortDescending);
        return sorted;
    }

    /**
     * Returns the most probable final particle.
     *
     * @throws IndexOutOfBoundsException if the filter was built with zero particles
     */
    public Particle getMostProbableParticle() {
        // Valid because runFilter() leaves the population sorted, most probable first.
        return particles.get(0);
    }

    /** Returns a copy of the final particles, most probable first. */
    public List<Particle> getMostProbableParticles() {
        return new ArrayList<>(particles);
    }

    /** Processes every key of the latent-increment ordering and returns the sorted final population. */
    private List<Particle> runFilter() {
        Map<Vertex, Set<Vertex>> latentsByObserved = LatentIncrementSort.sort(this.vertices);
        List<Particle> population = createEmptyParticles(numParticles);
        for (Map.Entry<Vertex, Set<Vertex>> step : latentsByObserved.entrySet()) {
            population = updateParticles(step.getKey(), step.getValue(), population);
        }
        // Resampling appends unsorted replacements, so sort once here. That keeps
        // getMostProbableParticle() correct.
        population.sort(Particle::sortDescending);
        return population;
    }

    /**
     * Advances the population by one step: resample from {@code previous}, add {@code observed},
     * sample {@code latents}. Then run {@code resamplingCycles} rounds, each replacing the worst
     * particles with new draws from {@code previous}.
     */
    private List<Particle> updateParticles(Vertex<?> observed, Set<Vertex> latents, List<Particle> previous) {
        List<Particle> next = sampleAndCopy(previous, numParticles);
        addObservedVertexToParticles(next, observed, latents);
        for (int cycle = 0; cycle < resamplingCycles; cycle++) {
            next = removeWorstParticles(next);
            List<Particle> replacements = sampleAndCopy(previous, numParticles - next.size());
            addObservedVertexToParticles(replacements, observed, latents);
            next.addAll(replacements);
        }
        return next;
    }

    /** Creates {@code count} new, empty particles. */
    private List<Particle> createEmptyParticles(int count) {
        List<Particle> empty = new ArrayList<>(Math.max(count, 0));
        for (int i = 0; i < count; i++) {
            empty.add(new Particle());
        }
        return empty;
    }

    /**
     * For each particle: record {@code observed}, draw and record a value for every vertex in
     * {@code latents}, then refresh the particle's subgraph log probability.
     */
    @SuppressWarnings("unchecked")
    private void addObservedVertexToParticles(List<Particle> targets, Vertex<?> observed, Set<Vertex> latents) {
        for (Particle particle : targets) {
            particle.addObservedVertex(observed);
            for (Vertex latent : latents) {
                sampleValueAndAddToParticle(latent, particle);
            }
            particle.updateSumLogPOfSubgraph();
        }
    }

    /**
     * Samples a value for {@code vertex} and stores it in {@code particle}.
     *
     * <p>NOTE(review): assumes every latent vertex implements {@link Probabilistic}. Otherwise
     * the cast throws {@link ClassCastException}.
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private <T> void sampleValueAndAddToParticle(Vertex<T> vertex, Particle particle) {
        particle.addLatentVertex(vertex, ((Probabilistic) vertex).sample(this.random));
    }

    /**
     * Sorts {@code population} in place (most probable first). Returns a new mutable list that
     * holds the first {@code (int) (size * (1 - resamplingProportion))} particles.
     */
    private List<Particle> removeWorstParticles(List<Particle> population) {
        population.sort(Particle::sortDescending);
        int keep = (int) (population.size() * (1.0d - this.resamplingProportion));
        return new ArrayList<>(population.subList(0, keep));
    }

    /**
     * Draws {@code count} particles from {@code source} with replacement. Each particle is chosen
     * with probability proportional to {@code exp(logProb())}. Returns shallow copies of the draws.
     *
     * @throws IllegalArgumentException if {@code count > 0} and {@code source} is empty
     */
    private List<Particle> sampleAndCopy(List<Particle> source, int count) {
        if (count <= 0) {
            return new ArrayList<>();
        }
        if (source.isEmpty()) {
            throw new IllegalArgumentException("Cannot sample " + count + " particles from an empty population");
        }
        // Compute the weights once per batch instead of once per draw.
        double[] cumulative = cumulativeWeights(source);
        List<Particle> copies = new ArrayList<>(count);
        for (int i = 0; i < count; i++) {
            copies.add(source.get(weightedRandomIndex(cumulative)).shallowCopy());
        }
        return copies;
    }

    /**
     * Returns running sums of weights proportional to {@code exp(logProb())} for each particle.
     *
     * <p>Log probabilities are shifted by their maximum before exponentiating. Without the shift,
     * plain {@code exp} underflows to 0 for every particle once log probabilities fall below
     * about -745, and the draw degenerates to always picking the last particle. The
     * {@code lp == max} branch gives weight 1 to the maximal particle(s). That also covers an
     * all {@code -Infinity} population (uniform weights) and a {@code +Infinity} maximum, where
     * {@code lp - max} would be NaN. NaN log probabilities get weight 0.
     */
    private static double[] cumulativeWeights(List<Particle> source) {
        double[] logProbs = new double[source.size()];
        double maxLogProb = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < logProbs.length; i++) {
            logProbs[i] = source.get(i).logProb();
            if (logProbs[i] > maxLogProb) {
                maxLogProb = logProbs[i];
            }
        }
        double[] cumulative = new double[logProbs.length];
        double running = 0.0d;
        for (int i = 0; i < logProbs.length; i++) {
            double weight = logProbs[i] == maxLogProb ? 1.0d : Math.exp(logProbs[i] - maxLogProb);
            if (Double.isNaN(weight)) {
                weight = 0.0d;
            }
            running += weight;
            cumulative[i] = running;
        }
        return cumulative;
    }

    /**
     * Picks the first index whose cumulative weight exceeds {@code nextDouble() * total}. If no
     * index qualifies (for example, every weight is 0), it returns the last index.
     * {@code cumulative} must be non-empty and non-decreasing.
     */
    private int weightedRandomIndex(double[] cumulative) {
        double target = this.random.nextDouble() * cumulative[cumulative.length - 1];
        int lo = 0;
        int hi = cumulative.length - 1;
        while (lo < hi) {
            int mid = (lo + hi) >>> 1;
            if (cumulative[mid] > target) {
                hi = mid;
            } else {
                lo = mid + 1;
            }
        }
        return lo;
    }
}
