/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.tensor.functions;

import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.functions.Join;
import com.yahoo.tensor.functions.Map;
import com.yahoo.tensor.functions.PrimitiveTensorFunction;
import com.yahoo.tensor.functions.Reduce;
import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.tensor.functions.ToStringContext;
import java.util.List;
import java.util.Objects;
import java.util.Optional;

/**
 * Cosine similarity between two tensor arguments along one named dimension.
 *
 * <p>This is a composite function: both typing and evaluation are done by expanding it into
 * primitive tensor operations (see {@link #toPrimitive()}), computing
 * {@code sum(a*b) / sqrt(sum(a*a) * sum(b*b))} where each {@code sum} reduces over
 * {@code dimension}.
 *
 * @param <NAMETYPE> the name type used by the type and evaluation contexts
 */
public class CosineSimilarity<NAMETYPE extends Name>
extends TensorFunction<NAMETYPE> {
    private final TensorFunction<NAMETYPE> arg1;
    private final TensorFunction<NAMETYPE> arg2;
    private final String dimension;

    /**
     * Creates a cosine similarity function.
     *
     * @param argument1 the first tensor argument
     * @param argument2 the second tensor argument
     * @param dimension the name of the dimension to compute the similarity along
     * @throws NullPointerException if any argument is null
     */
    public CosineSimilarity(TensorFunction<NAMETYPE> argument1, TensorFunction<NAMETYPE> argument2, String dimension) {
        this.arg1 = Objects.requireNonNull(argument1, "argument1");
        this.arg2 = Objects.requireNonNull(argument2, "argument2");
        this.dimension = Objects.requireNonNull(dimension, "dimension");
    }

    /** Returns the two tensor arguments, in order. */
    @Override
    public List<TensorFunction<NAMETYPE>> arguments() {
        return List.of(this.arg1, this.arg2);
    }

    /**
     * Returns a copy of this function with the given arguments, keeping the same dimension.
     *
     * @throws IllegalArgumentException if {@code arguments} does not contain exactly two elements
     */
    @Override
    public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
        if (arguments.size() != 2) {
            throw new IllegalArgumentException("CosineSimilarity must have 2 arguments, got " + arguments.size());
        }
        return new CosineSimilarity<>(arguments.get(0), arguments.get(1), this.dimension);
    }

    /**
     * Resolves the result type after validating the inputs.
     *
     * @throws IllegalArgumentException unless both argument types have {@code dimension} as an
     *         {@code indexedBound} dimension of the same size
     */
    @Override
    public TensorType type(TypeContext<NAMETYPE> context) {
        TensorType t1 = this.arg1.toPrimitive().type(context);
        TensorType t2 = this.arg2.toPrimitive().type(context);
        Optional<TensorType.Dimension> d1 = t1.dimension(this.dimension);
        Optional<TensorType.Dimension> d2 = t2.dimension(this.dimension);
        // Compare sizes by value: size() yields boxed Longs, and `!=` on those compares
        // references, which spuriously fails for equal sizes outside the Long cache (-128..127).
        if (!isIndexedBound(d1) || !isIndexedBound(d2) || !d1.get().size().equals(d2.get().size())) {
            throw new IllegalArgumentException("cosine_similarity expects both arguments to have the '" + this.dimension + "' dimension with same size, but input types were " + t1 + " and " + t2);
        }
        // The result type is whatever the primitive expansion produces.
        return this.toPrimitive().type(context);
    }

    /** Returns whether the dimension is present and of type {@code indexedBound}. */
    private static boolean isIndexedBound(Optional<TensorType.Dimension> dimension) {
        return dimension.isPresent() && dimension.get().type() == TensorType.Dimension.Type.indexedBound;
    }

    /** Evaluates this by evaluating its primitive expansion. */
    @Override
    public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
        return this.toPrimitive().evaluate(context);
    }

    /**
     * Expands this into primitive operations:
     * {@code join(reduce(a*b, sum, dim), map(join(reduce(a*a, sum, dim), reduce(b*b, sum, dim), *), sqrt), /)}.
     */
    @Override
    public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
        PrimitiveTensorFunction<NAMETYPE> a = this.arg1.toPrimitive();
        PrimitiveTensorFunction<NAMETYPE> b = this.arg2.toPrimitive();
        // Element-wise products needed for the dot product and the two squared norms.
        Join<NAMETYPE> aa = new Join<>(a, a, ScalarFunctions.multiply());
        Join<NAMETYPE> ab = new Join<>(a, b, ScalarFunctions.multiply());
        Join<NAMETYPE> bb = new Join<>(b, b, ScalarFunctions.multiply());
        // Sum over the similarity dimension only.
        Reduce<NAMETYPE> dotAa = new Reduce<>(aa, Reduce.Aggregator.sum, this.dimension);
        Reduce<NAMETYPE> dotAb = new Reduce<>(ab, Reduce.Aggregator.sum, this.dimension);
        Reduce<NAMETYPE> dotBb = new Reduce<>(bb, Reduce.Aggregator.sum, this.dimension);
        // sqrt(|a|^2 * |b|^2) == |a| * |b|, computed with a single sqrt.
        Join<NAMETYPE> aabb = new Join<>(dotAa, dotBb, ScalarFunctions.multiply());
        Map<NAMETYPE> sqrtAabb = new Map<>(aabb, ScalarFunctions.sqrt());
        return new Join<>(dotAb, sqrtAabb, ScalarFunctions.divide());
    }

    /** Renders this as {@code cosine_similarity(arg1, arg2, dimension)}. */
    @Override
    public String toString(ToStringContext<NAMETYPE> context) {
        return "cosine_similarity(" + this.arg1.toString(context) + ", " + this.arg2.toString(context) + ", " + this.dimension + ")";
    }

    // NOTE(review): hashCode is structural, but equals is not overridden in this class.
    // Confirm whether TensorFunction supplies a matching structural equals.
    @Override
    public int hashCode() {
        return Objects.hash("cosine_similarity", this.arg1, this.arg2, this.dimension);
    }
}

