/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.backend.keanu.compiled;

import io.improbable.keanu.algorithms.ProbabilisticModel;
import io.improbable.keanu.algorithms.Variable;
import io.improbable.keanu.algorithms.VariableReference;
import io.improbable.keanu.backend.ComputableGraph;
import io.improbable.keanu.backend.ProbabilisticGraphConverter;
import io.improbable.keanu.backend.VariableImpl;
import io.improbable.keanu.backend.keanu.compiled.KeanuCompiledGraphBuilder;
import io.improbable.keanu.backend.keanu.compiled.WrappedCompiledGraph;
import io.improbable.keanu.network.BayesianNetwork;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;

/**
 * A {@link ProbabilisticModel} whose log-probability and log-likelihood are evaluated by a
 * {@link ComputableGraph} compiled from a {@link BayesianNetwork}.
 *
 * <p>The graph always exposes a log-probability output. It exposes a separate log-likelihood
 * output only when the network produced an observation log-probability term.
 */
public class KeanuCompiledProbabilisticGraph
implements ProbabilisticModel {
    private final ComputableGraph computableGraph;
    /** Unmodifiable snapshot of the latent variables supplied at construction. */
    private final List<Variable> latentVariables;
    private final VariableReference logProbOp;
    /** Log-likelihood output of the graph, or {@code null} when there is no likelihood term. */
    private final VariableReference logLikelihoodOp;

    /**
     * Compiles {@code network} into a {@link KeanuCompiledProbabilisticGraph}.
     *
     * <p>All vertices of the network are converted. Then a prior log-probability output is
     * added, along with an observation log-probability (log-likelihood) output when
     * {@link ProbabilisticGraphConverter#convertLogProbObservation} yields one. The registered
     * log-probability output is the sum of the likelihood and the prior when a likelihood
     * exists, and the prior alone otherwise.
     *
     * @param network the network to compile
     * @return a probabilistic model backed by the compiled graph
     */
    public static KeanuCompiledProbabilisticGraph convert(BayesianNetwork network) {
        KeanuCompiledGraphBuilder builder = new KeanuCompiledGraphBuilder();
        builder.convert(network.getVertices());
        Optional<VariableReference> logLikelihoodReference = ProbabilisticGraphConverter.convertLogProbObservation(network, builder);
        VariableReference priorLogProbReference = ProbabilisticGraphConverter.convertLogProbPrior(network, builder);
        // logProb = logLikelihood + logPrior when there is a likelihood term, otherwise just logPrior.
        VariableReference logProbReference = logLikelihoodReference.map(ll -> builder.add(ll, priorLogProbReference)).orElse(priorLogProbReference);
        builder.registerOutput(logProbReference);
        logLikelihoodReference.ifPresent(builder::registerOutput);
        WrappedCompiledGraph computableGraph = builder.build();
        List<Variable> latentVariables = builder.getLatentVariables().stream().map(v -> new VariableImpl(computableGraph, (VariableReference)v)).collect(Collectors.toList());
        return new KeanuCompiledProbabilisticGraph(computableGraph, latentVariables, logProbReference, logLikelihoodReference.orElse(null));
    }

    /**
     * Evaluates the compiled graph on {@code inputs} and returns the log-probability output.
     *
     * @param inputs values for the graph's input variables
     * @return the scalar log-probability
     */
    @Override
    public double logProb(Map<VariableReference, ?> inputs) {
        DoubleTensor logProb = (DoubleTensor)this.computableGraph.compute(inputs).get(this.logProbOp);
        return (Double)logProb.scalar();
    }

    /**
     * Evaluates the compiled graph on {@code inputs} and returns the log-likelihood output.
     *
     * @param inputs values for the graph's input variables
     * @return the scalar log-likelihood
     * @throws IllegalStateException if this model was built without a log-likelihood output
     */
    @Override
    public double logLikelihood(Map<VariableReference, ?> inputs) {
        if (this.logLikelihoodOp == null) {
            throw new IllegalStateException("Likelihood is undefined");
        }
        DoubleTensor logLikelihood = (DoubleTensor)this.computableGraph.compute(inputs).get(this.logLikelihoodOp);
        return (Double)logLikelihood.scalar();
    }

    /**
     * Creates a model around an already-compiled graph.
     *
     * @param computableGraph the graph used to evaluate outputs
     * @param latentVariables the model's latent variables; defensively copied so later changes
     *                        to the caller's list do not affect this model (must not be null)
     * @param logProbOp       the graph output holding the log-probability
     * @param logLikelihoodOp the graph output holding the log-likelihood, or {@code null} if none
     */
    public KeanuCompiledProbabilisticGraph(ComputableGraph computableGraph, List<Variable> latentVariables, VariableReference logProbOp, VariableReference logLikelihoodOp) {
        this.computableGraph = computableGraph;
        // Copy, then wrap, so neither the caller nor users of the getter can mutate our state.
        this.latentVariables = Collections.unmodifiableList(new ArrayList<>(latentVariables));
        this.logProbOp = logProbOp;
        this.logLikelihoodOp = logLikelihoodOp;
    }

    /** @return the compiled graph backing this model */
    public ComputableGraph getComputableGraph() {
        return this.computableGraph;
    }

    /** @return an unmodifiable view of this model's latent variables */
    @Override
    public List<Variable> getLatentVariables() {
        return this.latentVariables;
    }

    /** @return the graph output holding the log-probability */
    public VariableReference getLogProbOp() {
        return this.logProbOp;
    }

    /** @return the graph output holding the log-likelihood, or {@code null} if there is none */
    public VariableReference getLogLikelihoodOp() {
        return this.logLikelihoodOp;
    }
}

