/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.vertices.dbl.nonprobabilistic;

import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.NonProbabilistic;
import io.improbable.keanu.vertices.NonSaveableVertex;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.dbl.Differentiable;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.diff.PartialDerivative;
import io.improbable.keanu.vertices.generic.nonprobabilistic.CPTCondition;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * A double-valued vertex whose value is picked from a lookup table keyed by the
 * current scalar values of a list of input vertices.
 *
 * <p>On every evaluation a {@link CPTCondition} is built from the inputs' current
 * scalar values and looked up in {@code conditions}. If the table has a (non-null)
 * entry for that condition, the mapped vertex is the "selected" vertex; otherwise
 * {@code defaultResult} is selected. The value and both auto-differentiation modes
 * simply forward to the selected vertex.
 *
 * <p>The inputs, every vertex in the table and the default vertex are all
 * registered as parents of this vertex. The shape of this vertex is taken from
 * {@code defaultResult}.
 */
public class DoubleCPTVertex
extends DoubleVertex
implements Differentiable,
NonProbabilistic<DoubleTensor>,
NonSaveableVertex {
    private final List<Vertex<? extends Tensor<?, ?>>> inputs;
    private final Map<CPTCondition, DoubleVertex> conditions;
    private final DoubleVertex defaultResult;

    /**
     * Creates the vertex and registers all of its parents.
     *
     * @param inputs        vertices whose current scalar values form the lookup key
     * @param conditions    table mapping a condition to the vertex that supplies the result
     * @param defaultResult vertex used when the table has no entry for the current condition;
     *                      also supplies the shape of this vertex
     */
    public DoubleCPTVertex(List<Vertex<? extends Tensor<?, ?>>> inputs, Map<CPTCondition, DoubleVertex> conditions, DoubleVertex defaultResult) {
        super(defaultResult.getShape());
        this.inputs = inputs;
        this.conditions = conditions;
        this.defaultResult = defaultResult;
        this.addParents(inputs);
        this.addParents(conditions.values());
        this.addParent(defaultResult);
    }

    /**
     * Returns the current value of the selected vertex.
     *
     * @return the value of the table entry matching the inputs, or of the default vertex
     */
    @Override
    public DoubleTensor calculate() {
        return (DoubleTensor)this.selectedVertex().getValue();
    }

    /**
     * Forward mode: this vertex's derivative is whatever the supplied map holds for
     * the selected vertex.
     *
     * @param derivativeOfParentsWithRespectToInput derivatives already computed for the parents
     * @return the map's entry for the selected vertex; {@code null} if the map has none
     */
    @Override
    public PartialDerivative forwardModeAutoDifferentiation(Map<Vertex, PartialDerivative> derivativeOfParentsWithRespectToInput) {
        return derivativeOfParentsWithRespectToInput.get(this.selectedVertex());
    }

    /**
     * Reverse mode: the incoming derivative is handed, unchanged, to the selected
     * vertex only. No other parent receives an entry.
     *
     * @param derivativeOfOutputWithRespectToSelf derivative of the output with respect to this vertex
     * @return a new mutable map with a single entry for the selected vertex
     */
    @Override
    public Map<Vertex, PartialDerivative> reverseModeAutoDifferentiation(PartialDerivative derivativeOfOutputWithRespectToSelf) {
        Map<Vertex, PartialDerivative> partials = new HashMap<>();
        // The selected vertex is already known from the table lookup, so there is
        // no need to scan every table value to find it again.
        partials.put(this.selectedVertex(), derivativeOfOutputWithRespectToSelf);
        return partials;
    }

    /**
     * Resolves which parent currently drives this vertex.
     *
     * @return the table entry for the condition built from the inputs' current scalar
     *         values, or {@code defaultResult} when the lookup yields {@code null}
     */
    private DoubleVertex selectedVertex() {
        CPTCondition condition = CPTCondition.from(this.inputs, v -> ((Tensor)v.getValue()).scalar());
        DoubleVertex matched = this.conditions.get(condition);
        return matched == null ? this.defaultResult : matched;
    }
}

