/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.tensor.functions;

import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.functions.Join;
import com.yahoo.tensor.functions.Map;
import com.yahoo.tensor.functions.PrimitiveTensorFunction;
import com.yahoo.tensor.functions.Reduce;
import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.tensor.functions.ToStringContext;
import java.util.List;
import java.util.Objects;
import java.util.Optional;

/**
 * Convenience tensor function {@code euclidean_distance(a, b, dimension)}.
 * <p>
 * This is not a primitive function: it is typed and evaluated by expanding it (see {@link #toPrimitive()})
 * into {@code sqrt(reduce(map(join(a, b, subtract), square), sum, dimension))}.
 *
 * @param <NAMETYPE> the name type used by the evaluation and type contexts
 */
public class EuclideanDistance<NAMETYPE extends Name> extends TensorFunction<NAMETYPE> {

    private final TensorFunction<NAMETYPE> arg1;
    private final TensorFunction<NAMETYPE> arg2;
    private final String dimension;

    /**
     * Creates a euclidean distance function.
     *
     * @param argument1 the first tensor argument
     * @param argument2 the second tensor argument
     * @param dimension the dimension to sum the squared differences over; {@link #type} and
     *                  {@link #toString(ToStringContext)} pass it through the context's {@code resolveBinding}
     * @throws NullPointerException if any parameter is null (rejected here rather than failing later in
     *                              {@link #arguments()} or {@link #toPrimitive()})
     */
    public EuclideanDistance(TensorFunction<NAMETYPE> argument1, TensorFunction<NAMETYPE> argument2, String dimension) {
        this.arg1 = Objects.requireNonNull(argument1, "argument1");
        this.arg2 = Objects.requireNonNull(argument2, "argument2");
        this.dimension = Objects.requireNonNull(dimension, "dimension");
    }

    /** Returns the two tensor arguments, in order. */
    @Override
    public List<TensorFunction<NAMETYPE>> arguments() {
        return List.of(this.arg1, this.arg2);
    }

    /**
     * Returns a copy of this function with the given arguments and the same dimension.
     *
     * @throws IllegalArgumentException if {@code arguments} does not contain exactly two elements
     */
    @Override
    public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
        if (arguments.size() != 2) {
            throw new IllegalArgumentException("EuclideanDistance must have 2 arguments, got " + arguments.size());
        }
        return new EuclideanDistance<>(arguments.get(0), arguments.get(1), this.dimension);
    }

    /**
     * Returns the type produced by this function, after validating the argument types.
     *
     * @throws IllegalArgumentException unless both argument types contain the (binding-resolved) dimension,
     *                                  it is of type {@code indexedBound} in both, and its sizes are equal
     */
    @Override
    public TensorType type(TypeContext<NAMETYPE> context) {
        TensorType t1 = this.arg1.toPrimitive().type(context);
        TensorType t2 = this.arg2.toPrimitive().type(context);
        String resolvedDimension = context.resolveBinding(this.dimension);
        Optional<TensorType.Dimension> d1 = t1.dimension(resolvedDimension);
        Optional<TensorType.Dimension> d2 = t2.dimension(resolvedDimension);
        boolean bothPresent = d1.isPresent() && d2.isPresent();
        boolean valid = bothPresent
                        && d1.get().type() == TensorType.Dimension.Type.indexedBound
                        && d2.get().type() == TensorType.Dimension.Type.indexedBound
                        && d1.get().size().equals(d2.get().size());
        if (!valid) {
            throw new IllegalArgumentException("euclidean_distance expects both arguments to have the '" + resolvedDimension + "' dimension with same size, but input types were " + t1 + " and " + t2);
        }
        // The result type is whatever the equivalent primitive expression produces.
        return this.toPrimitive().type(context);
    }

    /** Evaluates this by evaluating its primitive expansion. */
    @Override
    public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
        return this.toPrimitive().evaluate(context);
    }

    /**
     * Expands this into primitive functions:
     * {@code sqrt(reduce(map(join(arg1, arg2, subtract), square), sum, dimension))}.
     */
    @Override
    public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
        PrimitiveTensorFunction<NAMETYPE> primitive1 = this.arg1.toPrimitive();
        PrimitiveTensorFunction<NAMETYPE> primitive2 = this.arg2.toPrimitive();
        Join<NAMETYPE> diffs = new Join<>(primitive1, primitive2, ScalarFunctions.subtract());
        Map<NAMETYPE> squaredDiffs = new Map<>(diffs, ScalarFunctions.square());
        Reduce<NAMETYPE> sumOfSquares = new Reduce<>(squaredDiffs, Reduce.Aggregator.sum, this.dimension);
        return new Map<>(sumOfSquares, ScalarFunctions.sqrt());
    }

    /** Returns {@code euclidean_distance(arg1, arg2, dimension)} with the dimension resolved through the context. */
    @Override
    public String toString(ToStringContext<NAMETYPE> context) {
        return "euclidean_distance(" + this.arg1.toString(context) + ", " + this.arg2.toString(context) + ", " + context.resolveBinding(this.dimension) + ")";
    }

    /** Hash over the function name, both arguments and the (unresolved) dimension. */
    @Override
    public int hashCode() {
        return Objects.hash("euclidean_distance", this.arg1, this.arg2, this.dimension);
    }
}

