/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.vertices.dbl.nonprobabilistic;

import io.improbable.keanu.annotation.ExportVertexToPythonBindings;
import io.improbable.keanu.tensor.TensorShapeValidation;
import io.improbable.keanu.tensor.bool.BooleanTensor;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.LoadVertexParam;
import io.improbable.keanu.vertices.NonProbabilistic;
import io.improbable.keanu.vertices.SaveVertexParam;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.bool.BooleanVertex;
import io.improbable.keanu.vertices.dbl.Differentiable;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.diff.PartialDerivative;
import java.util.HashMap;
import java.util.Map;

/**
 * A non-probabilistic double vertex whose value is taken from one of two double parents
 * according to a boolean predicate parent.
 *
 * <p>The value is computed with {@link BooleanTensor#doubleWhere}: {@code then} where the
 * predicate is true and {@code else} where it is false. In both autodiff directions the
 * predicate's double mask selects the {@code then} contribution, and the negated mask selects
 * the {@code else} contribution. The predicate itself receives no partial derivative.
 */
public class DoubleIfVertex
extends DoubleVertex
implements Differentiable,
NonProbabilistic<DoubleTensor> {
    protected static final String PREDICATE_NAME = "predicate";
    protected static final String THEN_NAME = "then";
    protected static final String ELSE_NAME = "else";

    private final BooleanVertex predicate;
    private final DoubleVertex thn;
    private final DoubleVertex els;

    /**
     * Creates the vertex and registers all three inputs as its parents.
     *
     * <p>The vertex shape is the value returned by
     * {@link TensorShapeValidation#checkTernaryConditionShapeIsValid} for the three parent shapes.
     * Any shape-compatibility error that helper raises propagates from this constructor.
     *
     * @param predicate boolean vertex deciding which branch supplies each value
     * @param thn       vertex supplying values where the predicate is true
     * @param els       vertex supplying values where the predicate is false
     */
    @ExportVertexToPythonBindings
    public DoubleIfVertex(@LoadVertexParam(value = PREDICATE_NAME) BooleanVertex predicate,
                          @LoadVertexParam(value = THEN_NAME) DoubleVertex thn,
                          @LoadVertexParam(value = ELSE_NAME) DoubleVertex els) {
        super(TensorShapeValidation.checkTernaryConditionShapeIsValid(predicate.getShape(), thn.getShape(), els.getShape()));
        this.predicate = predicate;
        this.thn = thn;
        this.els = els;
        this.setParents(predicate, thn, els);
    }

    /** @return the boolean predicate vertex */
    @SaveVertexParam(value = PREDICATE_NAME)
    public BooleanVertex getPredicate() {
        return this.predicate;
    }

    /** @return the vertex used where the predicate is true */
    @SaveVertexParam(value = THEN_NAME)
    public DoubleVertex getThn() {
        return this.thn;
    }

    /** @return the vertex used where the predicate is false */
    @SaveVertexParam(value = ELSE_NAME)
    public DoubleVertex getEls() {
        return this.els;
    }

    /**
     * Forward-mode derivative: each branch's partial is masked by where that branch was selected,
     * and the two masked partials are summed.
     *
     * @param derivativeOfParentsWithRespectToInput partials of the parents; a missing branch is
     *                                              treated as {@link PartialDerivative#EMPTY}
     * @return the partial derivative of this vertex with respect to the input
     */
    @Override
    public PartialDerivative forwardModeAutoDifferentiation(Map<Vertex, PartialDerivative> derivativeOfParentsWithRespectToInput) {
        PartialDerivative thnPartial = derivativeOfParentsWithRespectToInput.getOrDefault(this.thn, PartialDerivative.EMPTY);
        PartialDerivative elsPartial = derivativeOfParentsWithRespectToInput.getOrDefault(this.els, PartialDerivative.EMPTY);
        BooleanTensor predicateValue = (BooleanTensor) this.predicate.getValue();
        // Shortcuts: if only one branch is ever selected, its partial passes through unmasked.
        if (predicateValue.allTrue()) {
            return thnPartial;
        }
        if (predicateValue.allFalse()) {
            return elsPartial;
        }
        PartialDerivative maskedThn = thnPartial.multiplyAlongOfDimensions(predicateValue.toDoubleMask());
        PartialDerivative maskedEls = elsPartial.multiplyAlongOfDimensions(predicateValue.not().toDoubleMask());
        return maskedThn.add(maskedEls);
    }

    /**
     * Recomputes the value from the parents' current values.
     *
     * @return {@code then} values where the predicate is true, {@code else} values elsewhere
     */
    @Override
    public DoubleTensor calculate() {
        return this.op((BooleanTensor) this.predicate.getValue(), (DoubleTensor) this.thn.getValue(), (DoubleTensor) this.els.getValue());
    }

    /** Selects between the two branch values according to the predicate value. */
    private DoubleTensor op(BooleanTensor predicate, DoubleTensor thn, DoubleTensor els) {
        return predicate.doubleWhere(thn, els);
    }

    /**
     * Reverse-mode derivative: routes the incoming partial to each branch, masked by where that
     * branch was selected.
     *
     * <p>If {@code then} and {@code else} are the same vertex, that vertex receives the sum of
     * both masked contributions instead of only the last one written.
     *
     * @param derivativeOfOutputWithRespectToSelf partial of the output with respect to this vertex
     * @return partials keyed by the branch vertices; the predicate has no entry
     */
    @Override
    public Map<Vertex, PartialDerivative> reverseModeAutoDifferentiation(PartialDerivative derivativeOfOutputWithRespectToSelf) {
        BooleanTensor predicateValue = (BooleanTensor) this.predicate.getValue();
        PartialDerivative toThn = derivativeOfOutputWithRespectToSelf.multiplyAlongWrtDimensions(predicateValue.toDoubleMask());
        PartialDerivative toEls = derivativeOfOutputWithRespectToSelf.multiplyAlongWrtDimensions(predicateValue.not().toDoubleMask());

        Map<Vertex, PartialDerivative> partials = new HashMap<>();
        partials.put(this.thn, toThn);
        // merge (not put) so that a vertex used as both branches keeps both contributions.
        partials.merge(this.els, toEls, (existing, addition) -> existing.add(addition));
        return partials;
    }
}

