/*
 * Decompiled with CFR 0.152.
 */
package io.trino.plugin.postgresql.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.plugin.base.expression.ConnectorExpressionPatterns;
import io.trino.plugin.base.projection.ProjectFunctionRule;
import io.trino.plugin.jdbc.JdbcColumnHandle;
import io.trino.plugin.jdbc.JdbcExpression;
import io.trino.plugin.jdbc.JdbcTypeHandle;
import io.trino.plugin.jdbc.expression.ParameterizedExpression;
import io.trino.spi.block.Block;
import io.trino.spi.connector.ConnectorTableHandle;
import io.trino.spi.expression.Call;
import io.trino.spi.expression.ConnectorExpression;
import io.trino.spi.expression.Constant;
import io.trino.spi.expression.FunctionName;
import io.trino.spi.expression.StandardFunctions;
import io.trino.spi.expression.Variable;
import io.trino.spi.type.ArrayType;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.RealType;
import io.trino.spi.type.Type;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.StringJoiner;

/**
 * Projection pushdown rule that rewrites a two-argument vector distance function call
 * into a binary operator expression of the form {@code <left> <operator> <right>}.
 *
 * <p>The rule matches calls to the configured function name that return {@code DOUBLE}
 * and take exactly two arguments, each of which is an array of {@code REAL} or
 * {@code DOUBLE}. Each argument is translated independently; if either one cannot be
 * translated, the rule produces no rewrite.
 */
public final class RewriteVectorDistanceFunction
        implements ProjectFunctionRule<JdbcExpression, ParameterizedExpression> {
    private static final Capture<ConnectorExpression> LEFT_ARGUMENT = Capture.newCapture();
    private static final Capture<ConnectorExpression> RIGHT_ARGUMENT = Capture.newCapture();

    /** JDBC type name a column must report to be usable as a vector operand. */
    private static final String VECTOR_TYPE_NAME = "vector";

    private final Pattern<Call> pattern;
    private final String operator;

    /**
     * @param functionName name of the function call this rule matches; must not be null
     * @param operator operator placed between the two rewritten operands; must not be null
     * @throws NullPointerException if either argument is null
     */
    public RewriteVectorDistanceFunction(String functionName, String operator) {
        Objects.requireNonNull(functionName, "functionName is null");
        this.pattern = ConnectorExpressionPatterns.call()
                .with(ConnectorExpressionPatterns.functionName().equalTo(new FunctionName(functionName)))
                .with(ConnectorExpressionPatterns.type().matching(type -> type == DoubleType.DOUBLE))
                .with(ConnectorExpressionPatterns.argumentCount().equalTo(2))
                .with(ConnectorExpressionPatterns.argument(0).matching(vectorArgument(LEFT_ARGUMENT)))
                .with(ConnectorExpressionPatterns.argument(1).matching(vectorArgument(RIGHT_ARGUMENT)));
        this.operator = Objects.requireNonNull(operator, "operator is null");
    }

    /** Builds the pattern for one operand: any expression typed as array of REAL/DOUBLE, captured under {@code capture}. */
    private static Pattern<ConnectorExpression> vectorArgument(Capture<ConnectorExpression> capture) {
        return ConnectorExpressionPatterns.expression()
                .capturedAs(capture)
                .with(ConnectorExpressionPatterns.type().matching(RewriteVectorDistanceFunction::isArrayTypeWithRealOrDouble));
    }

    @Override
    public Pattern<? extends ConnectorExpression> getPattern() {
        return this.pattern;
    }

    /**
     * Rewrites the matched call into {@code <left> <operator> <right>}.
     *
     * @return the rewritten expression typed as JDBC {@code DOUBLE}, carrying the parameters of the
     *         left operand followed by those of the right operand; empty if either operand
     *         cannot be translated
     */
    @Override
    public Optional<JdbcExpression> rewrite(ConnectorTableHandle handle, ConnectorExpression projectionExpression, Captures captures, ProjectFunctionRule.RewriteContext<ParameterizedExpression> context) {
        Optional<ParameterizedExpression> leftExpression = rewrite(captures.get(LEFT_ARGUMENT), context);
        if (leftExpression.isEmpty()) {
            return Optional.empty();
        }
        Optional<ParameterizedExpression> rightExpression = rewrite(captures.get(RIGHT_ARGUMENT), context);
        if (rightExpression.isEmpty()) {
            return Optional.empty();
        }
        ParameterizedExpression left = leftExpression.get();
        ParameterizedExpression right = rightExpression.get();
        return Optional.of(new JdbcExpression(
                "%s %s %s".formatted(left.expression(), this.operator, right.expression()),
                // Parameter order must follow the order of the placeholders in the SQL text: left, then right.
                ImmutableList.copyOf(Iterables.concat(left.parameters(), right.parameters())),
                new JdbcTypeHandle(java.sql.Types.DOUBLE, Optional.of("double"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty())));
    }

    /**
     * Translates a single operand of the distance function.
     *
     * <ul>
     *   <li>A constant array becomes a quoted literal such as {@code '[1.0,2.0]'}.</li>
     *   <li>A {@code CAST} applied directly to a column whose JDBC type name is {@code vector}
     *       becomes the quoted column name, i.e. the cast is dropped.</li>
     *   <li>Any other {@code CAST} is not translated.</li>
     *   <li>Everything else is delegated to {@code context.rewriteExpression}.</li>
     * </ul>
     *
     * @param expression operand to translate; expected to have an array type when it is a constant
     * @param context rewrite context used to resolve column assignments and nested expressions
     * @return the translated operand, or empty when it cannot be translated
     */
    public static Optional<ParameterizedExpression> rewrite(ConnectorExpression expression, ProjectFunctionRule.RewriteContext<ParameterizedExpression> context) {
        if (expression instanceof Constant) {
            return rewriteConstant((Constant) expression);
        }
        if (expression instanceof Call) {
            Call call = (Call) expression;
            if (call.getFunctionName().equals(StandardFunctions.CAST_FUNCTION_NAME)) {
                return rewriteCast(call, context);
            }
        }
        return context.rewriteExpression(expression);
    }

    /** Renders a constant array as a {@code '[v1,v2,...]'} literal; empty if it is null or holds an unsupported element. */
    private static Optional<ParameterizedExpression> rewriteConstant(Constant constant) {
        Type elementType = ((ArrayType) constant.getType()).getElementType();
        Block value = (Block) constant.getValue();
        if (value == null) {
            // A null array constant has no literal form; previously this dereferenced null.
            return Optional.empty();
        }
        StringJoiner vector = new StringJoiner(",", "'[", "]'");
        for (int position = 0; position < value.getPositionCount(); position++) {
            if (value.isNull(position)) {
                return Optional.empty();
            }
            // NOTE(review): elements are read via Type#getDouble even when the element type is REAL;
            // confirm that REAL supports getDouble, otherwise array(real) constants need a separate read path.
            double element = elementType.getDouble(value, position);
            if (!isSupportedVector(element)) {
                return Optional.empty();
            }
            vector.add(Double.toString(element));
        }
        return Optional.of(new ParameterizedExpression(vector.toString(), ImmutableList.of()));
    }

    /** Translates {@code CAST(column)} to the quoted column name when the column's JDBC type name is {@code vector}. */
    private static Optional<ParameterizedExpression> rewriteCast(Call cast, ProjectFunctionRule.RewriteContext<ParameterizedExpression> context) {
        ConnectorExpression argument = Iterables.getOnlyElement(cast.getArguments());
        if (!(argument instanceof Variable)) {
            return Optional.empty();
        }
        Variable variable = (Variable) argument;
        JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(variable.getName());
        JdbcTypeHandle typeHandle = columnHandle.getJdbcTypeHandle();
        if (typeHandle.jdbcTypeName().filter(VECTOR_TYPE_NAME::equals).isEmpty()) {
            return Optional.empty();
        }
        return Optional.of(new ParameterizedExpression(quoted(columnHandle.getColumnName()), ImmutableList.of()));
    }

    /**
     * @return {@code true} if {@code type} is an array whose element type is {@code REAL} or {@code DOUBLE}
     */
    public static boolean isArrayTypeWithRealOrDouble(Type type) {
        if (!(type instanceof ArrayType)) {
            return false;
        }
        Type elementType = ((ArrayType) type).getElementType();
        return elementType == RealType.REAL || elementType == DoubleType.DOUBLE;
    }

    /**
     * Tells whether a vector component can be rendered into the literal: it must be finite and either
     * zero or have a magnitude within the {@code float} range
     * [{@link Float#MIN_VALUE}, {@link Float#MAX_VALUE}].
     *
     * <p>{@link Float#MIN_VALUE} is the smallest <em>positive</em> float, so the bound has to be applied
     * to the magnitude; comparing the signed value against it rejected every zero and negative component.
     * The float bounds presumably reflect single-precision storage on the remote side (verify against
     * the remote vector type).
     */
    private static boolean isSupportedVector(double value) {
        if (Double.isNaN(value) || Double.isInfinite(value)) {
            return false;
        }
        if (value == 0) {
            return true;
        }
        double magnitude = Math.abs(value);
        return magnitude >= Float.MIN_VALUE && magnitude <= Float.MAX_VALUE;
    }

    /** Wraps {@code name} in double quotes, doubling any embedded double quote. */
    private static String quoted(String name) {
        return "\"" + name.replace("\"", "\"\"") + "\"";
    }
}

