/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.network;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.improbable.keanu.algorithms.ProbabilisticModel;
import io.improbable.keanu.algorithms.Variable;
import io.improbable.keanu.algorithms.VariableReference;
import io.improbable.keanu.algorithms.graphtraversal.VertexValuePropagation;
import io.improbable.keanu.network.BayesianNetwork;
import io.improbable.keanu.network.LambdaSectionSnapshot;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.ProbabilityCalculator;
import io.improbable.keanu.vertices.Vertex;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * A {@link ProbabilisticModel} whose variables are the latent and observed vertices of a
 * {@link BayesianNetwork}.
 *
 * <p>Inputs are keyed by {@link VariableReference}. Each reference is resolved to a vertex by its id,
 * the new value is set on that vertex, and the change is propagated with
 * {@link VertexValuePropagation#cascadeUpdate} before log probabilities are computed. The evaluation
 * methods therefore mutate the values held by the underlying vertices.
 */
public class KeanuProbabilisticModel implements ProbabilisticModel {

    /** Latent and observed vertices of the network, keyed by their id. */
    private final Map<VariableReference, Vertex> vertexLookup;
    private final List<Vertex> latentVertices;
    private final List<Vertex> observedVertices;
    private final List<Vertex> latentOrObservedVertices;
    /** Used by {@link #logProbAfter} to evaluate the lambda section of the changed vertices. */
    private final LambdaSectionSnapshot lambdaSectionSnapshot;

    /**
     * Creates a model from the given vertices by first wrapping them in a {@link BayesianNetwork}.
     *
     * @param variables the vertices making up the network
     * @throws IllegalArgumentException if the network has no latent or observed vertices
     */
    public KeanuProbabilisticModel(Collection<? extends Vertex> variables) {
        this(new BayesianNetwork(variables));
    }

    /**
     * Creates a model over the latent and observed vertices of {@code bayesianNetwork}.
     *
     * <p>The vertex lists are copied into immutable lists, and the values of the observed vertices
     * are cascaded through the graph.
     *
     * @param bayesianNetwork the network to model
     * @throws IllegalArgumentException if the network has no latent or observed vertices
     */
    public KeanuProbabilisticModel(BayesianNetwork bayesianNetwork) {
        this.vertexLookup = bayesianNetwork.getLatentOrObservedVertices().stream()
            .collect(Collectors.toMap(Vertex::getId, v -> v));
        this.latentVertices = ImmutableList.copyOf(bayesianNetwork.getLatentVertices());
        this.observedVertices = ImmutableList.copyOf(bayesianNetwork.getObservedVertices());
        this.latentOrObservedVertices = ImmutableList.copyOf(bayesianNetwork.getLatentOrObservedVertices());
        this.lambdaSectionSnapshot = new LambdaSectionSnapshot();
        // Reject an empty model before touching any vertex state.
        this.checkBayesNetInHealthyState();
        this.resetModelToObservedState();
    }

    /**
     * Computes the log probability over all latent and observed vertices.
     *
     * @param inputs new values to apply and cascade before evaluating; pass an empty map to
     *               evaluate the current state
     * @return the log probability calculated for every latent and observed vertex
     * @throws IllegalArgumentException if a key does not reference a latent or observed vertex of
     *                                  this model
     */
    @Override
    public double logProb(Map<VariableReference, ?> inputs) {
        if (!inputs.isEmpty()) {
            this.cascadeValues(inputs);
        }
        return ProbabilityCalculator.calculateLogProbFor(this.latentOrObservedVertices);
    }

    /**
     * Incrementally computes the log probability after applying {@code newValues}.
     *
     * <p>Only the lambda section of the changed vertices is re-evaluated. Its log probability is
     * taken before and after the new values are cascaded, and the difference is added to
     * {@code logProbBefore}. The result is presumably only meaningful when {@code logProbBefore} is
     * the model's log probability immediately before this call. That is the caller's
     * responsibility and is not checked here.
     *
     * @param newValues     new values keyed by the vertices they apply to
     * @param logProbBefore the log probability of the model before the update
     * @return {@code logProbBefore} adjusted by the change in the lambda section's log probability
     * @throws IllegalArgumentException if a key does not reference a latent or observed vertex of
     *                                  this model; thrown before any vertex value is modified
     */
    @Override
    public double logProbAfter(Map<VariableReference, Object> newValues, double logProbBefore) {
        ImmutableSet.Builder<Vertex> affectedVerticesBuilder = ImmutableSet.builder();
        for (VariableReference reference : newValues.keySet()) {
            // ImmutableSet rejects nulls. Resolve through lookupVertex so an unknown reference is
            // reported the same way cascadeValues reports it, instead of as a bare NPE.
            affectedVerticesBuilder.add(this.lookupVertex(reference));
        }
        Set<Vertex> affectedVertices = affectedVerticesBuilder.build();

        double lambdaSectionLogProbBefore = this.lambdaSectionSnapshot.logProb(affectedVertices);
        this.cascadeValues(newValues);
        double lambdaSectionLogProbAfter = this.lambdaSectionSnapshot.logProb(affectedVertices);
        double deltaLogProb = lambdaSectionLogProbAfter - lambdaSectionLogProbBefore;
        return logProbBefore + deltaLogProb;
    }

    /**
     * Computes the log probability over the observed vertices only.
     *
     * @param inputs new values to apply and cascade before evaluating; pass an empty map to
     *               evaluate the current state
     * @return the log probability calculated for the observed vertices
     * @throws IllegalArgumentException if a key does not reference a latent or observed vertex of
     *                                  this model
     */
    @Override
    public double logLikelihood(Map<VariableReference, ?> inputs) {
        if (!inputs.isEmpty()) {
            this.cascadeValues(inputs);
        }
        return ProbabilityCalculator.calculateLogProbFor(this.observedVertices);
    }

    /**
     * Returns the latent vertices viewed as {@link Variable}s.
     *
     * <p>{@code ImmutableList.copyOf} returns an existing {@code ImmutableList} as-is, so this is a
     * type-safe view of the immutable latent-vertex list rather than a fresh copy.
     *
     * @return an immutable list of the latent vertices
     */
    @Override
    public List<Variable> getLatentVariables() {
        return ImmutableList.<Variable>copyOf(this.latentVertices);
    }

    /** @return an immutable list of the network's latent vertices */
    public List<Vertex> getLatentVertices() {
        return this.latentVertices;
    }

    /** @return an immutable list of the network's latent and observed vertices */
    public List<Vertex> getLatentOrObservedVertices() {
        return this.latentOrObservedVertices;
    }

    /**
     * Returns the latent variables whose current value is a {@link DoubleTensor}.
     *
     * @return a new list containing those variables, in latent-variable order
     */
    @SuppressWarnings("unchecked") // guarded by the instanceof DoubleTensor filter on the value
    public List<Variable<DoubleTensor, ?>> getContinuousLatentVariables() {
        return this.getLatentVariables().stream()
            .filter(v -> v.getValue() instanceof DoubleTensor)
            .map(v -> (Variable<DoubleTensor, ?>) v)
            .collect(Collectors.toList());
    }

    /** Throws if the network contains no latent or observed vertices. */
    private void checkBayesNetInHealthyState() {
        if (this.latentOrObservedVertices.isEmpty()) {
            throw new IllegalArgumentException("Cannot create model without latent or observed variables");
        }
    }

    /** Cascades the current values of the observed vertices through the graph. */
    private void resetModelToObservedState() {
        VertexValuePropagation.cascadeUpdate(this.observedVertices);
    }

    /**
     * Sets each input value on its vertex and cascades the changes through the graph.
     *
     * <p>All references are resolved before any value is set. An unknown reference therefore leaves
     * every vertex untouched, rather than leaving the model partially updated and not cascaded.
     *
     * @param inputs new values keyed by the vertices they apply to
     * @throws IllegalArgumentException if a key does not reference a latent or observed vertex of
     *                                  this model
     */
    protected void cascadeValues(Map<VariableReference, ?> inputs) {
        List<Vertex> updatedVertices = new ArrayList<>(inputs.size());
        List<Object> updatedValues = new ArrayList<>(inputs.size());
        for (Map.Entry<VariableReference, ?> input : inputs.entrySet()) {
            updatedVertices.add(this.lookupVertex(input.getKey()));
            updatedValues.add(input.getValue());
        }
        for (int i = 0; i < updatedVertices.size(); i++) {
            updatedVertices.get(i).setValue(updatedValues.get(i));
        }
        VertexValuePropagation.cascadeUpdate(updatedVertices);
    }

    /** Resolves {@code reference} to its vertex, failing with a descriptive error if it is unknown. */
    private Vertex lookupVertex(VariableReference reference) {
        Vertex vertex = this.vertexLookup.get(reference);
        if (vertex == null) {
            throw new IllegalArgumentException("Cannot cascade update for input: " + reference);
        }
        return vertex;
    }
}

