package io.improbable.keanu.vertices.dbl.nonprobabilistic;

import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.NonProbabilistic;
import io.improbable.keanu.vertices.NonSaveableVertex;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.dbl.Differentiable;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.diff.PartialDerivative;
import io.improbable.keanu.vertices.generic.nonprobabilistic.CPTCondition;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:io/improbable/keanu/vertices/dbl/nonprobabilistic/DoubleCPTVertex.class */
/**
 * Deterministic double vertex that chooses one of several {@link DoubleVertex} results according to the
 * current scalar values of its input vertices, falling back to a default result when no condition matches.
 *
 * <p>The selected branch is recomputed from the inputs' current values on every call to
 * {@link #calculate()} and to both auto-differentiation methods.
 */
public class DoubleCPTVertex extends DoubleVertex implements Differentiable, NonProbabilistic<DoubleTensor>, NonSaveableVertex {
    private final List<Vertex<? extends Tensor<?, ?>>> inputs;
    private final Map<CPTCondition, DoubleVertex> conditions;
    private final DoubleVertex defaultResult;

    /**
     * Creates the vertex and registers the inputs, every condition result and the default result as parents.
     *
     * @param list         vertices whose scalar values form the lookup key (stored by reference, not copied)
     * @param map          result vertex for each condition (stored by reference, not copied)
     * @param doubleVertex result used when the current inputs match no condition; also supplies this vertex's shape
     */
    public DoubleCPTVertex(List<Vertex<? extends Tensor<?, ?>>> list, Map<CPTCondition, DoubleVertex> map, DoubleVertex doubleVertex) {
        // NOTE(review): shape is taken from the default result only; assumes every condition vertex shares it.
        super(doubleVertex.getShape());
        this.inputs = list;
        this.conditions = map;
        this.defaultResult = doubleVertex;
        addParents(list);
        addParents(map.values());
        addParent(doubleVertex);
    }

    /**
     * Returns the current value of whichever result vertex the inputs currently select.
     */
    @Override
    public DoubleTensor calculate() {
        return selectedResult().getValue();
    }

    /**
     * Forwards the partial derivative of the selected result vertex unchanged.
     *
     * @param map partial derivatives of the parents, keyed by parent vertex
     * @return the entry of {@code map} for the selected result vertex, or {@code null} if it has none
     */
    @Override
    public PartialDerivative forwardModeAutoDifferentiation(Map<Vertex, PartialDerivative> map) {
        return map.get(selectedResult());
    }

    /**
     * Routes the incoming derivative to the selected result vertex only.
     *
     * @param partialDerivative derivative of the output with respect to this vertex
     * @return a mutable map containing exactly one entry: the selected result vertex mapped to
     *         {@code partialDerivative}
     */
    @Override
    public Map<Vertex, PartialDerivative> reverseModeAutoDifferentiation(PartialDerivative partialDerivative) {
        // Unselected branches and the input vertices (which only choose the branch) receive no entry.
        Map<Vertex, PartialDerivative> partials = new HashMap<>();
        partials.put(selectedResult(), partialDerivative);
        return partials;
    }

    /**
     * Builds the lookup key from the inputs' current scalar values and resolves it against the condition table.
     *
     * @return the matching condition's result vertex, or {@link #defaultResult} when there is no match
     */
    private DoubleVertex selectedResult() {
        CPTCondition key = CPTCondition.from(this.inputs, vertex -> ((Tensor<?, ?>) vertex.getValue()).scalar());
        DoubleVertex match = this.conditions.get(key);
        return match == null ? this.defaultResult : match;
    }
}
