package io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.multiple;

import io.improbable.keanu.tensor.TensorShapeValidation;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.NonProbabilistic;
import io.improbable.keanu.vertices.NonSaveableVertex;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.dbl.Differentiable;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.diff.PartialDerivative;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;

/* loaded from: input_file:io/improbable/keanu/vertices/dbl/nonprobabilistic/operators/multiple/ReduceVertex.class */
/**
 * A non-probabilistic {@link DoubleVertex} whose value is produced by folding a binary
 * {@code reduceFunction} left-to-right over the values of its input vertices, i.e.
 * {@code f(...f(f(v0, v1), v2)..., vn)}.
 *
 * <p>Auto-differentiation is delegated to caller-supplied lambdas. When a lambda was not
 * supplied (passed as {@code null}), the corresponding auto-diff mode throws
 * {@link UnsupportedOperationException}.
 */
public class ReduceVertex extends DoubleVertex implements Differentiable, NonProbabilistic<DoubleTensor>, NonSaveableVertex {

    /** Input vertices in reduction order; a private copy of the collection given at construction. */
    private final List<? extends Vertex<DoubleTensor>> inputs;

    /** Binary function applied as {@code accumulator = reduceFunction(accumulator, next)}. */
    private final BiFunction<DoubleTensor, DoubleTensor, DoubleTensor> reduceFunction;

    /** Forward-mode auto-diff provider; may be {@code null}, meaning forward mode is unsupported. */
    private final Supplier<PartialDerivative> forwardModeAutoDiffLambda;

    /** Reverse-mode auto-diff provider; may be {@code null}, meaning reverse mode is unsupported. */
    private final Function<PartialDerivative, Map<Vertex, PartialDerivative>> reverseModeAutoDiffLambda;

    /**
     * Creates a reduce vertex with an explicit output shape.
     *
     * @param shape                   shape passed to the {@link DoubleVertex} super constructor
     * @param inputs                  vertices to reduce, in reduction order; copied internally
     * @param reduceFunction          binary function folded over the input values
     * @param forwardModeAutoDiff     forward-mode auto-diff provider, or {@code null} if unsupported
     * @param reverseModeAutoDiff     reverse-mode auto-diff provider, or {@code null} if unsupported
     * @throws IllegalArgumentException if fewer than two inputs are given
     */
    public ReduceVertex(long[] shape,
                        Collection<? extends Vertex<DoubleTensor>> inputs,
                        BiFunction<DoubleTensor, DoubleTensor, DoubleTensor> reduceFunction,
                        Supplier<PartialDerivative> forwardModeAutoDiff,
                        Function<PartialDerivative, Map<Vertex, PartialDerivative>> reverseModeAutoDiff) {
        super(shape);
        if (inputs.size() < 2) {
            throw new IllegalArgumentException("ReduceVertex should have at least two input vertices, called with " + inputs.size());
        }
        this.inputs = new ArrayList<>(inputs);
        this.reduceFunction = reduceFunction;
        this.forwardModeAutoDiffLambda = forwardModeAutoDiff;
        this.reverseModeAutoDiffLambda = reverseModeAutoDiff;
        setParents(inputs);
    }

    /**
     * Varargs variant of
     * {@link #ReduceVertex(long[], Collection, BiFunction, Supplier, Function)}.
     */
    @SafeVarargs
    public ReduceVertex(long[] shape,
                        BiFunction<DoubleTensor, DoubleTensor, DoubleTensor> reduceFunction,
                        Supplier<PartialDerivative> forwardModeAutoDiff,
                        Function<PartialDerivative, Map<Vertex, PartialDerivative>> reverseModeAutoDiff,
                        Vertex<DoubleTensor>... inputs) {
        this(shape, Arrays.asList(inputs), reduceFunction, forwardModeAutoDiff, reverseModeAutoDiff);
    }

    /**
     * Creates a reduce vertex with an explicit output shape and no auto-diff support.
     */
    public ReduceVertex(long[] shape,
                        Collection<? extends Vertex<DoubleTensor>> inputs,
                        BiFunction<DoubleTensor, DoubleTensor, DoubleTensor> reduceFunction) {
        this(shape, inputs, reduceFunction, null, null);
    }

    /**
     * Creates a reduce vertex whose shape is the common shape of all inputs, as returned by
     * {@code TensorShapeValidation.checkAllShapesMatch}.
     */
    @SafeVarargs
    public ReduceVertex(BiFunction<DoubleTensor, DoubleTensor, DoubleTensor> reduceFunction,
                        Supplier<PartialDerivative> forwardModeAutoDiff,
                        Function<PartialDerivative, Map<Vertex, PartialDerivative>> reverseModeAutoDiff,
                        Vertex<DoubleTensor>... inputs) {
        this(commonShape(Arrays.asList(inputs)), Arrays.asList(inputs), reduceFunction, forwardModeAutoDiff, reverseModeAutoDiff);
    }

    /**
     * Creates a reduce vertex whose shape is the common shape of all inputs, with no auto-diff
     * support.
     */
    public ReduceVertex(List<? extends Vertex<DoubleTensor>> inputs,
                        BiFunction<DoubleTensor, DoubleTensor, DoubleTensor> reduceFunction) {
        this(commonShape(inputs), inputs, reduceFunction, null, null);
    }

    /** Collects the shapes of {@code vertices} and delegates to the shape validator. */
    private static long[] commonShape(Collection<? extends Vertex<DoubleTensor>> vertices) {
        List<long[]> shapes = vertices.stream()
            .map(Vertex::getShape)
            .collect(Collectors.toList());
        return TensorShapeValidation.checkAllShapesMatch(shapes);
    }

    /** Folds {@code reduceFunction} over the current values of the inputs. */
    @Override
    public DoubleTensor calculate() {
        return applyReduce(Vertex::getValue);
    }

    /**
     * Maps each input through {@code mapper} and folds the results left-to-right with
     * {@code reduceFunction}. The constructor guarantees at least two inputs, so the first
     * {@code next()} call is safe.
     */
    private DoubleTensor applyReduce(Function<? super Vertex<DoubleTensor>, DoubleTensor> mapper) {
        Iterator<? extends Vertex<DoubleTensor>> remaining = inputs.iterator();
        // The mapper must be applied to the first input as well, not just the rest.
        DoubleTensor result = mapper.apply(remaining.next());
        while (remaining.hasNext()) {
            result = reduceFunction.apply(result, mapper.apply(remaining.next()));
        }
        return result;
    }

    /**
     * Returns the result of the supplied forward-mode lambda. The {@code derivativeOfParentsWithRespectToInputs}
     * map is not consulted; the lambda is expected to capture whatever it needs.
     *
     * @throws UnsupportedOperationException if no forward-mode lambda was supplied
     */
    @Override
    public PartialDerivative forwardModeAutoDifferentiation(Map<Vertex, PartialDerivative> derivativeOfParentsWithRespectToInputs) {
        if (forwardModeAutoDiffLambda == null) {
            throw new UnsupportedOperationException("Forward mode auto-diff was not supplied to this ReduceVertex");
        }
        return forwardModeAutoDiffLambda.get();
    }

    /**
     * Applies the supplied reverse-mode lambda to {@code derivativeOfOutputWithRespectToSelf}.
     *
     * @throws UnsupportedOperationException if no reverse-mode lambda was supplied (previously
     *                                       this surfaced as a {@link NullPointerException})
     */
    @Override
    public Map<Vertex, PartialDerivative> reverseModeAutoDifferentiation(PartialDerivative derivativeOfOutputWithRespectToSelf) {
        if (reverseModeAutoDiffLambda == null) {
            throw new UnsupportedOperationException("Reverse mode auto-diff was not supplied to this ReduceVertex");
        }
        return reverseModeAutoDiffLambda.apply(derivativeOfOutputWithRespectToSelf);
    }
}
