/*
 * Decompiled with CFR 0.152.
 */
package org.apache.cassandra.cql3.functions;

import java.nio.ByteBuffer;
import java.util.List;
import org.apache.cassandra.cql3.CQL3Type;
import org.apache.cassandra.cql3.functions.Arguments;
import org.apache.cassandra.cql3.functions.FunctionArguments;
import org.apache.cassandra.cql3.functions.FunctionFactory;
import org.apache.cassandra.cql3.functions.FunctionParameter;
import org.apache.cassandra.cql3.functions.NativeFunction;
import org.apache.cassandra.cql3.functions.NativeFunctions;
import org.apache.cassandra.cql3.functions.NativeScalarFunction;
import org.apache.cassandra.db.marshal.AbstractType;
import org.apache.cassandra.db.marshal.FloatType;
import org.apache.cassandra.db.marshal.VectorType;
import org.apache.cassandra.exceptions.InvalidRequestException;
import org.apache.cassandra.transport.ProtocolVersion;
import org.apache.lucene.index.VectorSimilarityFunction;

/**
 * Native CQL scalar functions that compute a similarity score between two float vectors.
 * <p>
 * The scoring itself is delegated to Lucene's {@link VectorSimilarityFunction}. This class handles:
 * <ul>
 *   <li>declaring the function parameters;</li>
 *   <li>checking that both arguments have the same dimension;</li>
 *   <li>propagating {@code null} arguments;</li>
 *   <li>rejecting all-zero vectors for measures that cannot handle them.</li>
 * </ul>
 */
public class VectorFcts {
    /**
     * Registers {@code similarity_cosine}, {@code similarity_euclidean} and {@code similarity_dot_product}
     * in the given native function registry.
     *
     * @param functions the registry the function factories are added to
     */
    public static void addFunctionsTo(NativeFunctions functions) {
        // Cosine similarity divides by the vectors' magnitudes, so it is undefined when either vector is
        // all zeros; the other two measures accept such vectors.
        functions.add(VectorFcts.createSimilarityFunctionFactory("similarity_cosine", VectorSimilarityFunction.COSINE, false));
        functions.add(VectorFcts.createSimilarityFunctionFactory("similarity_euclidean", VectorSimilarityFunction.EUCLIDEAN, true));
        functions.add(VectorFcts.createSimilarityFunctionFactory("similarity_dot_product", VectorSimilarityFunction.DOT_PRODUCT, true));
    }

    /**
     * Builds a factory for a two-argument similarity function over float vectors.
     * <p>
     * Both parameters are declared as float vectors and linked to each other via
     * {@code FunctionParameter.sameAs}.
     *
     * @param name                the CQL function name
     * @param luceneFunction      the Lucene measure used to score the two vectors
     * @param supportsZeroVectors whether all-zero vectors are accepted as arguments
     * @return a factory that creates the concrete function once argument types are known
     */
    private static FunctionFactory createSimilarityFunctionFactory(String name, final VectorSimilarityFunction luceneFunction, final boolean supportsZeroVectors) {
        FunctionParameter[] vectorParameters = new FunctionParameter[]{
            FunctionParameter.sameAs(1, false, FunctionParameter.vector(CQL3Type.Native.FLOAT)),
            FunctionParameter.sameAs(0, false, FunctionParameter.vector(CQL3Type.Native.FLOAT))
        };
        return new FunctionFactory(name, vectorParameters) {

            @Override
            protected NativeFunction doGetOrCreateFunction(List<AbstractType<?>> argTypes, AbstractType<?> receiverType) {
                // The parameters are declared as float vectors above, so the element type is Float.
                @SuppressWarnings("unchecked")
                VectorType<Float> firstArgType = (VectorType<Float>) argTypes.get(0);
                int dimensions = firstArgType.dimension;
                // The created function is typed on the first argument for both inputs, so every argument
                // must share its dimension.
                if (!argTypes.stream().allMatch(t -> ((VectorType<?>) t).dimension == dimensions)) {
                    throw new InvalidRequestException("All arguments must have the same vector dimensions");
                }
                return VectorFcts.createSimilarityFunction(this.name.name, firstArgType, luceneFunction, supportsZeroVectors);
            }
        };
    }

    /**
     * Creates the concrete scalar function that scores two vectors of the given type.
     *
     * @param name                the CQL function name
     * @param type                the vector type of both arguments
     * @param f                   the Lucene measure used to score the two vectors
     * @param supportsZeroVectors whether all-zero vectors are accepted as arguments
     * @return a function that returns a CQL {@code float} score
     */
    private static NativeFunction createSimilarityFunction(String name, final VectorType<Float> type, final VectorSimilarityFunction f, final boolean supportsZeroVectors) {
        return new NativeScalarFunction(name, FloatType.instance, new AbstractType<?>[]{type, type}) {

            @Override
            public Arguments newArguments(ProtocolVersion version) {
                // Deserialize both arguments directly into float[] using the vector type.
                return new FunctionArguments(version, (v, b) -> type.composeAsFloat(b), (v, b) -> type.composeAsFloat(b));
            }

            /**
             * Scores the two vector arguments.
             *
             * @return {@code null} if either argument is null; otherwise the Lucene score serialized as a CQL float
             * @throws InvalidRequestException if zero vectors are unsupported and either argument is all zeros
             */
            @Override
            public ByteBuffer execute(Arguments arguments) throws InvalidRequestException {
                if (arguments.containsNulls()) {
                    return null;
                }
                float[] v1 = (float[]) arguments.get(0);
                float[] v2 = (float[]) arguments.get(1);
                if (!supportsZeroVectors && (VectorFcts.isAllZero(v1) || VectorFcts.isAllZero(v2))) {
                    throw new InvalidRequestException("Function " + this.name + " doesn't support all-zero vectors.");
                }
                return FloatType.instance.decompose(f.compare(v1, v2));
            }
        };
    }

    /**
     * Returns whether every component of {@code vector} equals zero.
     * <p>
     * {@code -0.0f} counts as zero, NaN does not, and an empty array counts as all zeros.
     */
    private static boolean isAllZero(float[] vector) {
        for (float component : vector) {
            if (component != 0.0f) {
                return false;
            }
        }
        return true;
    }
}

