/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.multiple;

import io.improbable.keanu.tensor.TensorShapeValidation;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.NonProbabilistic;
import io.improbable.keanu.vertices.NonSaveableVertex;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.dbl.Differentiable;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.diff.PartialDerivative;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;

/**
 * A non-probabilistic {@link DoubleVertex} whose value is computed by folding the values of its
 * input vertices, in iteration order, with a binary reduce function:
 * {@code reduce(...reduce(reduce(v0, v1), v2)..., vN)}.
 *
 * <p>Auto-differentiation is delegated to optional caller-supplied lambdas. When a lambda was not
 * supplied (it is {@code null}), the corresponding differentiation mode throws
 * {@link UnsupportedOperationException}.
 */
public class ReduceVertex
extends DoubleVertex
implements Differentiable,
NonProbabilistic<DoubleTensor>,
NonSaveableVertex {
    /** Defensive copy of the input vertices, folded in this order. */
    private final List<? extends Vertex<DoubleTensor>> inputs;
    /** Binary function used to combine the running result with the next input's value. */
    private final BiFunction<DoubleTensor, DoubleTensor, DoubleTensor> reduceFunction;
    /** Optional forward-mode derivative supplier; {@code null} means unsupported. */
    private final Supplier<PartialDerivative> forwardModeAutoDiffLambda;
    /** Optional reverse-mode derivative function; {@code null} means unsupported. */
    private final Function<PartialDerivative, Map<Vertex, PartialDerivative>> reverseModeAutoDiffLambda;

    /**
     * Creates a reduce vertex with an explicit output shape.
     *
     * @param shape                     the shape of this vertex's value
     * @param inputs                    the vertices to reduce; must contain at least two elements
     * @param reduceFunction            combines the running result with the next input's value
     * @param forwardModeAutoDiffLambda supplies the forward-mode derivative, or {@code null} if unsupported
     * @param reverseModeAutoDiffLambda computes reverse-mode derivatives, or {@code null} if unsupported
     * @throws IllegalArgumentException if fewer than two inputs are given
     */
    public ReduceVertex(long[] shape, Collection<? extends Vertex<DoubleTensor>> inputs, BiFunction<DoubleTensor, DoubleTensor, DoubleTensor> reduceFunction, Supplier<PartialDerivative> forwardModeAutoDiffLambda, Function<PartialDerivative, Map<Vertex, PartialDerivative>> reverseModeAutoDiffLambda) {
        super(shape);
        if (inputs.size() < 2) {
            throw new IllegalArgumentException("ReduceVertex should have at least two input vertices, called with " + inputs.size());
        }
        // Copy so later mutation of the caller's collection cannot change the fold.
        this.inputs = new ArrayList<Vertex<DoubleTensor>>(inputs);
        this.reduceFunction = reduceFunction;
        this.forwardModeAutoDiffLambda = forwardModeAutoDiffLambda;
        this.reverseModeAutoDiffLambda = reverseModeAutoDiffLambda;
        this.setParents(inputs);
    }

    /**
     * Varargs variant of the explicit-shape constructor.
     *
     * @param shape                     the shape of this vertex's value
     * @param reduceFunction            combines the running result with the next input's value
     * @param forwardModeAutoDiffLambda supplies the forward-mode derivative, or {@code null} if unsupported
     * @param reverseModeAutoDiffLambda computes reverse-mode derivatives, or {@code null} if unsupported
     * @param input                     the vertices to reduce; must contain at least two elements
     * @throws IllegalArgumentException if fewer than two inputs are given
     */
    public ReduceVertex(long[] shape, BiFunction<DoubleTensor, DoubleTensor, DoubleTensor> reduceFunction, Supplier<PartialDerivative> forwardModeAutoDiffLambda, Function<PartialDerivative, Map<Vertex, PartialDerivative>> reverseModeAutoDiffLambda, Vertex<DoubleTensor> ... input) {
        this(shape, Arrays.asList(input), reduceFunction, forwardModeAutoDiffLambda, reverseModeAutoDiffLambda);
    }

    /**
     * Creates a reduce vertex with an explicit output shape and no auto-diff support.
     *
     * @param shape          the shape of this vertex's value
     * @param inputs         the vertices to reduce; must contain at least two elements
     * @param reduceFunction combines the running result with the next input's value
     * @throws IllegalArgumentException if fewer than two inputs are given
     */
    public ReduceVertex(long[] shape, Collection<? extends Vertex<DoubleTensor>> inputs, BiFunction<DoubleTensor, DoubleTensor, DoubleTensor> reduceFunction) {
        this(shape, inputs, reduceFunction, null, null);
    }

    /**
     * Creates a reduce vertex whose shape is derived from its inputs' shapes via
     * {@code TensorShapeValidation.checkAllShapesMatch} (presumably requiring all input shapes to
     * match; TODO confirm).
     *
     * @param reduceFunction            combines the running result with the next input's value
     * @param forwardModeAutoDiffLambda supplies the forward-mode derivative, or {@code null} if unsupported
     * @param reverseModeAutoDiffLambda computes reverse-mode derivatives, or {@code null} if unsupported
     * @param input                     the vertices to reduce; must contain at least two elements
     * @throws IllegalArgumentException if fewer than two inputs are given
     */
    public ReduceVertex(BiFunction<DoubleTensor, DoubleTensor, DoubleTensor> reduceFunction, Supplier<PartialDerivative> forwardModeAutoDiffLambda, Function<PartialDerivative, Map<Vertex, PartialDerivative>> reverseModeAutoDiffLambda, Vertex<DoubleTensor> ... input) {
        this(TensorShapeValidation.checkAllShapesMatch(Arrays.stream(input).map(Vertex::getShape).collect(Collectors.toList())), Arrays.asList(input), reduceFunction, forwardModeAutoDiffLambda, reverseModeAutoDiffLambda);
    }

    /**
     * Creates a reduce vertex without auto-diff support, with its shape derived from the inputs'
     * shapes via {@code TensorShapeValidation.checkAllShapesMatch}.
     *
     * @param inputs         the vertices to reduce; must contain at least two elements
     * @param reduceFunction combines the running result with the next input's value
     * @throws IllegalArgumentException if fewer than two inputs are given
     */
    public ReduceVertex(List<? extends Vertex<DoubleTensor>> inputs, BiFunction<DoubleTensor, DoubleTensor, DoubleTensor> reduceFunction) {
        this(TensorShapeValidation.checkAllShapesMatch(inputs.stream().map(Vertex::getShape).collect(Collectors.toList())), inputs, reduceFunction, null, null);
    }

    /**
     * Folds the current values of all inputs with the reduce function.
     *
     * @return the reduced tensor
     */
    @Override
    public DoubleTensor calculate() {
        return this.applyReduce(Vertex::getValue);
    }

    /**
     * Maps every input through {@code mapper} and folds the results left to right with the
     * reduce function.
     *
     * @param mapper extracts the tensor to reduce from each input vertex
     * @return the reduced tensor
     */
    private DoubleTensor applyReduce(Function<Vertex<DoubleTensor>, DoubleTensor> mapper) {
        Iterator<? extends Vertex<DoubleTensor>> inputIterator = this.inputs.iterator();
        // Seed with the *mapped* first input so every input goes through the same mapper;
        // the constructor guarantees at least two inputs, so next() is safe here.
        DoubleTensor result = mapper.apply(inputIterator.next());
        while (inputIterator.hasNext()) {
            result = this.reduceFunction.apply(result, mapper.apply(inputIterator.next()));
        }
        return result;
    }

    /**
     * Returns the forward-mode derivative from the supplied lambda. The parents' derivatives
     * passed in are not forwarded to the lambda.
     *
     * @param derivativeOfParentsWithRespectToInput ignored by this implementation
     * @return the derivative produced by the forward-mode supplier
     * @throws UnsupportedOperationException if no forward-mode lambda was supplied
     */
    @Override
    public PartialDerivative forwardModeAutoDifferentiation(Map<Vertex, PartialDerivative> derivativeOfParentsWithRespectToInput) {
        if (this.forwardModeAutoDiffLambda != null) {
            return this.forwardModeAutoDiffLambda.get();
        }
        throw new UnsupportedOperationException("ReduceVertex was constructed without a forward mode auto-diff lambda");
    }

    /**
     * Computes reverse-mode derivatives by delegating to the supplied lambda.
     *
     * @param derivativeOfOutputWithRespectToSelf the derivative flowing back into this vertex
     * @return the derivatives produced by the reverse-mode lambda
     * @throws UnsupportedOperationException if no reverse-mode lambda was supplied
     */
    @Override
    public Map<Vertex, PartialDerivative> reverseModeAutoDifferentiation(PartialDerivative derivativeOfOutputWithRespectToSelf) {
        // Mirror forward mode: a missing lambda is an unsupported operation, not an NPE.
        if (this.reverseModeAutoDiffLambda == null) {
            throw new UnsupportedOperationException("ReduceVertex was constructed without a reverse mode auto-diff lambda");
        }
        return this.reverseModeAutoDiffLambda.apply(derivativeOfOutputWithRespectToSelf);
    }
}

