/*
 * Decompiled with CFR 0.152.
 */
package io.trino.plugin.postgresql.rule;

import com.google.common.collect.ImmutableList;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.plugin.base.expression.ConnectorExpressionPatterns;
import io.trino.plugin.base.projection.ProjectFunctionRule;
import io.trino.plugin.jdbc.JdbcExpression;
import io.trino.plugin.jdbc.JdbcTypeHandle;
import io.trino.plugin.jdbc.expression.ParameterizedExpression;
import io.trino.plugin.postgresql.rule.RewriteVectorDistanceFunction;
import io.trino.spi.connector.ConnectorTableHandle;
import io.trino.spi.expression.Call;
import io.trino.spi.expression.ConnectorExpression;
import io.trino.spi.expression.FunctionName;
import io.trino.spi.type.DoubleType;
import java.sql.Types;
import java.util.List;
import java.util.Optional;
import java.util.stream.Stream;

/**
 * Projection pushdown rule that rewrites {@code $negate(dot_product(left, right))} into the
 * PostgreSQL expression {@code left <#> right}.
 *
 * <p>The rule only fires when the negation returns {@code DOUBLE}, the inner call is
 * {@code dot_product} with exactly two arguments, and both arguments satisfy
 * {@link RewriteVectorDistanceFunction#isArrayTypeWithRealOrDouble}. The operands themselves are
 * rendered by {@code RewriteVectorDistanceFunction.rewrite}. If either operand cannot be
 * rendered, the rule declines and the expression is evaluated by the engine.
 */
public final class RewriteDotProductFunction
implements ProjectFunctionRule<JdbcExpression, ParameterizedExpression> {
    private static final FunctionName NEGATE = new FunctionName("$negate");
    private static final FunctionName DOT_PRODUCT = new FunctionName("dot_product");
    private static final Capture<ConnectorExpression> CALL = Capture.newCapture();
    // Matches: $negate(<captured dot_product(left, right)>) of type DOUBLE.
    // The negation is part of the pattern because the emitted operator (<#>) yields the
    // negated inner product, so it stands in for the whole "-dot_product(...)" expression.
    private static final Pattern<Call> PATTERN = ConnectorExpressionPatterns.call()
            .with(ConnectorExpressionPatterns.functionName().equalTo(NEGATE))
            .with(ConnectorExpressionPatterns.type().matching(type -> type == DoubleType.DOUBLE))
            .with(ConnectorExpressionPatterns.argumentCount().equalTo(1))
            .with(ConnectorExpressionPatterns.argument(0).matching(
                    ConnectorExpressionPatterns.expression()
                            .capturedAs(CALL)
                            .matching(RewriteDotProductFunction::isDotProductOfVectors)));

    /**
     * Returns whether {@code expression} is a two-argument {@code dot_product} call whose
     * arguments are both array types accepted by
     * {@link RewriteVectorDistanceFunction#isArrayTypeWithRealOrDouble}.
     */
    private static boolean isDotProductOfVectors(ConnectorExpression expression) {
        return expression instanceof Call call
                && call.getFunctionName().equals(DOT_PRODUCT)
                && call.getArguments().size() == 2
                && call.getArguments().stream()
                        .allMatch(argument -> RewriteVectorDistanceFunction.isArrayTypeWithRealOrDouble(argument.getType()));
    }

    @Override
    public Pattern<? extends ConnectorExpression> getPattern() {
        return PATTERN;
    }

    /**
     * Renders the captured {@code dot_product(left, right)} call as {@code left <#> right}.
     *
     * @return the pushed-down expression typed as JDBC {@code DOUBLE} ({@code "double"}), or
     *         empty if either operand cannot be rewritten
     */
    @Override
    public Optional<JdbcExpression> rewrite(ConnectorTableHandle handle, ConnectorExpression projectionExpression, Captures captures, ProjectFunctionRule.RewriteContext<ParameterizedExpression> context) {
        // The pattern guarantees CALL is dot_product with exactly two children.
        ConnectorExpression call = captures.get(CALL);

        Optional<ParameterizedExpression> leftExpression = RewriteVectorDistanceFunction.rewrite(call.getChildren().getFirst(), context);
        if (leftExpression.isEmpty()) {
            return Optional.empty();
        }
        Optional<ParameterizedExpression> rightExpression = RewriteVectorDistanceFunction.rewrite(call.getChildren().get(1), context);
        if (rightExpression.isEmpty()) {
            return Optional.empty();
        }

        ParameterizedExpression left = leftExpression.get();
        ParameterizedExpression right = rightExpression.get();
        return Optional.of(new JdbcExpression(
                "%s <#> %s".formatted(left.expression(), right.expression()),
                // Bind parameters in the same order the placeholders appear in the SQL text:
                // all of the left operand's, then all of the right operand's.
                Stream.concat(left.parameters().stream(), right.parameters().stream())
                        .collect(ImmutableList.toImmutableList()),
                new JdbcTypeHandle(Types.DOUBLE, Optional.of("double"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty())));
    }
}

