/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.algorithms.particlefiltering;

import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.ProbabilityCalculator;
import io.improbable.keanu.vertices.Vertex;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * A single particle: a snapshot of values for a set of latent vertices, together with the
 * observed vertices whose log probability is scored against that snapshot.
 */
public class Particle {
    /** Latent vertex -> value to apply to that vertex. Values are written only via {@link #addLatentVertex}. */
    private final Map<Vertex, Object> latentVertices = new HashMap<>();
    /** Observed vertices whose log probability contributes to this particle's score. */
    private final List<Vertex> observedVertices = new ArrayList<>();
    /**
     * Cached result of the last {@link #updateSumLogPOfSubgraph()} call.
     * NOTE(review): 1.0 is a placeholder until the first update; it is not a meaningful log probability.
     */
    private double sumLogPOfSubgraph = 1.0;

    /**
     * Returns this particle's latent vertex values.
     *
     * @return the particle's internal map (not a copy); mutations through it affect this particle
     */
    public Map<Vertex, Object> getLatentVertices() {
        return this.latentVertices;
    }

    /**
     * Returns the cached log probability of this particle's subgraph.
     *
     * @return the value computed by the last {@link #updateSumLogPOfSubgraph()} call
     *     (or the value carried over by {@link #shallowCopy()}), 1.0 if never computed
     */
    public double logProb() {
        return this.sumLogPOfSubgraph;
    }

    /**
     * Returns the scalar value stored for a tensor-valued latent vertex.
     *
     * @param vertex a latent vertex previously added to this particle
     * @return the scalar of the stored {@link DoubleTensor}
     * @throws NullPointerException if no value is stored for {@code vertex}
     */
    public double getScalarValueOfVertex(Vertex<DoubleTensor> vertex) {
        return (Double) ((DoubleTensor) this.latentVertices.get(vertex)).scalar();
    }

    /**
     * Returns the value stored for a latent vertex.
     *
     * @param vertex the latent vertex to look up
     * @param <T> the vertex's value type
     * @return the stored value, or {@code null} if the vertex is not a latent of this particle
     */
    @SuppressWarnings("unchecked") // safe: addLatentVertex only stores a T for a Vertex<T>
    public <T> T getValueOfVertex(Vertex<T> vertex) {
        return (T) this.latentVertices.get(vertex);
    }

    /** Records {@code value} as this particle's value for {@code vertex}, replacing any previous value. */
    <T> void addLatentVertex(Vertex<T> vertex, T value) {
        this.latentVertices.put(vertex, value);
    }

    /** Adds {@code vertex} to the observed vertices scored by {@link #updateSumLogPOfSubgraph()}. */
    <T> void addObservedVertex(Vertex<T> vertex) {
        this.observedVertices.add(vertex);
    }

    /**
     * Applies this particle's latent values to the graph, then recomputes and caches the log
     * probability of the latent and observed vertices.
     *
     * @return the newly computed sum of log probabilities (also returned by {@link #logProb()})
     */
    double updateSumLogPOfSubgraph() {
        this.applyLatentVertexValues();
        double sumLogPOfLatents = ProbabilityCalculator.calculateLogProbFor(this.latentVertices.keySet());
        double sumLogPOfObservables = ProbabilityCalculator.calculateLogProbFor(this.observedVertices);
        this.sumLogPOfSubgraph = sumLogPOfLatents + sumLogPOfObservables;
        return this.sumLogPOfSubgraph;
    }

    /**
     * Creates a copy with its own latent map and observed list. The vertices and values themselves
     * are shared, not cloned. The cached log probability is carried over so the copy reports the
     * same {@link #logProb()} as this particle until it is recomputed.
     *
     * @return a new particle with the same contents as this one
     */
    Particle shallowCopy() {
        Particle clone = new Particle();
        clone.latentVertices.putAll(this.latentVertices);
        clone.observedVertices.addAll(this.observedVertices);
        clone.sumLogPOfSubgraph = this.sumLogPOfSubgraph;
        return clone;
    }

    /** Comparator ordering particles by {@link #logProb()}, highest first. */
    static int sortDescending(Particle a, Particle b) {
        return Double.compare(b.logProb(), a.logProb());
    }

    /** Pushes every stored latent value onto its vertex via {@code setAndCascade}. */
    private void applyLatentVertexValues() {
        this.latentVertices.keySet().forEach(this::applyLatentVertexValue);
    }

    /**
     * Sets {@code vertex} to its stored value and cascades the change. Only called for keys of
     * {@code latentVertices}, so a stored entry always exists.
     */
    @SuppressWarnings("unchecked") // safe: addLatentVertex only stores a T for a Vertex<T>
    private <T> void applyLatentVertexValue(Vertex<T> vertex) {
        T value = (T) this.latentVertices.get(vertex);
        vertex.setAndCascade(value);
    }
}

