package io.improbable.keanu.algorithms.particlefiltering;

import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.ProbabilityCalculator;
import io.improbable.keanu.vertices.Vertex;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * A single particle: a snapshot of values for a set of latent vertices, plus the observed vertices
 * whose log probability is included when the particle is scored.
 *
 * <p>Not synchronized; the stored vertex-to-value map is mutable and is exposed directly by
 * {@link #getLatentVertices()}.
 */
public class Particle {

    /** Latent vertex to the value this particle assigns it. */
    private Map<Vertex, Object> latentVertices = new HashMap<>();

    /** Observed vertices included in this particle's log probability. */
    private List<Vertex> observedVertices = new ArrayList<>();

    /**
     * Log probability computed by the last call to {@link #updateSumLogPOfSubgraph()}.
     * NOTE(review): starts at 1.0, which is an unusual initial value for a log probability
     * (log p = 1 implies p > 1) — confirm this sentinel is intended.
     */
    private double sumLogPOfSubgraph = 1.0d;

    /**
     * Returns the live map of latent vertices to the values this particle holds for them.
     *
     * @return the internal (mutable) latent-vertex map; changes made through it affect this particle
     */
    public Map<Vertex, Object> getLatentVertices() {
        return this.latentVertices;
    }

    /**
     * Returns the log probability cached by the last {@link #updateSumLogPOfSubgraph()} call.
     *
     * @return the cached log probability, or the initial 1.0 if it was never updated
     */
    public double logProb() {
        return this.sumLogPOfSubgraph;
    }

    /**
     * Returns the scalar held by this particle for a {@code DoubleTensor}-valued vertex.
     *
     * @param vertex a latent vertex previously added to this particle
     * @return the scalar value of the stored tensor
     * @throws NullPointerException if no value is stored for {@code vertex}
     */
    public double getScalarValueOfVertex(Vertex<DoubleTensor> vertex) {
        DoubleTensor tensor = (DoubleTensor) this.latentVertices.get(vertex);
        return ((Double) tensor.scalar()).doubleValue();
    }

    /**
     * Returns the value this particle holds for the given vertex.
     *
     * @param vertex the vertex to look up
     * @param <T> the vertex's value type
     * @return the stored value, or {@code null} if none is stored
     */
    @SuppressWarnings("unchecked") // values are only inserted via addLatentVertex(Vertex<T>, T)
    public <T> T getValueOfVertex(Vertex<T> vertex) {
        return (T) this.latentVertices.get(vertex);
    }

    /**
     * Records a value for a latent vertex, replacing any previous value for that vertex.
     *
     * @param vertex the latent vertex
     * @param t the value to associate with it
     * @param <T> the vertex's value type
     */
    public <T> void addLatentVertex(Vertex<T> vertex, T t) {
        this.latentVertices.put(vertex, t);
    }

    /**
     * Adds a vertex to the observed set whose log probability is included in scoring.
     * Duplicates are not filtered.
     *
     * @param vertex the observed vertex
     * @param <T> the vertex's value type
     */
    public <T> void addObservedVertex(Vertex<T> vertex) {
        this.observedVertices.add(vertex);
    }

    /**
     * Pushes every stored latent value onto its vertex, then recomputes and caches the summed
     * log probability of the latent and observed vertices.
     *
     * @return the newly computed log probability (also returned by {@link #logProb()} afterwards)
     */
    public double updateSumLogPOfSubgraph() {
        // Vertices must carry this particle's values before their log probability is evaluated.
        applyLatentVertexValues();
        double latentLogProb = ProbabilityCalculator.calculateLogProbFor(this.latentVertices.keySet());
        double observedLogProb = ProbabilityCalculator.calculateLogProbFor(this.observedVertices);
        this.sumLogPOfSubgraph = latentLogProb + observedLogProb;
        return this.sumLogPOfSubgraph;
    }

    /**
     * Creates a copy with its own latent map and observed list. The vertex and value objects are
     * shared, not copied. The cached log probability is NOT copied: the copy starts at the initial
     * value until {@link #updateSumLogPOfSubgraph()} is called on it.
     *
     * @return a new particle holding the same vertex/value entries
     */
    public Particle shallowCopy() {
        Particle copy = new Particle();
        copy.latentVertices = new HashMap<>(this.latentVertices);
        copy.observedVertices = new ArrayList<>(this.observedVertices);
        return copy;
    }

    /**
     * Comparator-style function that orders particles by cached log probability, highest first.
     *
     * @param particle the first particle
     * @param particle2 the second particle
     * @return a negative value if {@code particle} has the higher log probability, positive if
     *     {@code particle2} does, zero if they are equal
     */
    public static int sortDescending(Particle particle, Particle particle2) {
        // Arguments are swapped relative to natural order to produce a descending sort.
        return Double.compare(particle2.logProb(), particle.logProb());
    }

    /** Applies every stored latent value to its vertex. */
    private void applyLatentVertexValues() {
        this.latentVertices.keySet().forEach(this::applyLatentVertexValue);
    }

    /**
     * Sets the stored value onto {@code vertex} via {@code setAndCascade}, if a value is stored.
     *
     * @param vertex the latent vertex to update
     * @param <T> the vertex's value type
     */
    @SuppressWarnings("unchecked") // values are only inserted via addLatentVertex(Vertex<T>, T)
    private <T> void applyLatentVertexValue(Vertex<T> vertex) {
        if (this.latentVertices.containsKey(vertex)) {
            // Explicit cast needed: the map stores Object, setAndCascade expects T.
            T value = (T) this.latentVertices.get(vertex);
            vertex.setAndCascade(value);
        }
    }
}
