/*
 * Decompiled with CFR 0.152.
 */
package io.substrait.expression.proto;

import io.substrait.expression.Expression;
import io.substrait.expression.ExpressionCreator;
import io.substrait.expression.FieldReference;
import io.substrait.expression.FunctionArg;
import io.substrait.expression.FunctionOption;
import io.substrait.expression.WindowBound;
import io.substrait.extension.ExtensionLookup;
import io.substrait.extension.SimpleExtension;
import io.substrait.proto.ConsistentPartitionWindowRel;
import io.substrait.proto.Expression;
import io.substrait.proto.FunctionArgument;
import io.substrait.proto.SortField;
import io.substrait.relation.ConsistentPartitionWindow;
import io.substrait.relation.ProtoRelConverter;
import io.substrait.relation.Rel;
import io.substrait.type.Type;
import io.substrait.type.TypeVisitor;
import io.substrait.type.proto.ProtoTypeConverter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class ProtoExpressionConverter {
    public static final Type.Struct EMPTY_TYPE = Type.Struct.builder().nullable(false).build();
    private final ExtensionLookup lookup;
    private final SimpleExtension.ExtensionCollection extensions;
    private final Type.Struct rootType;
    private final ProtoTypeConverter protoTypeConverter;
    private final ProtoRelConverter protoRelConverter;

    public ProtoExpressionConverter(ExtensionLookup lookup, SimpleExtension.ExtensionCollection extensions, Type.Struct rootType, ProtoRelConverter relConverter) {
        this.lookup = lookup;
        this.extensions = extensions;
        this.rootType = Objects.requireNonNull(rootType, "rootType");
        this.protoTypeConverter = new ProtoTypeConverter(lookup, extensions);
        this.protoRelConverter = relConverter;
    }

    public FieldReference from(Expression.FieldReference reference) {
        Expression.FieldReference.ReferenceTypeCase refTypeCase = reference.getReferenceTypeCase();
        if (refTypeCase == Expression.FieldReference.ReferenceTypeCase.MASKED_REFERENCE) {
            throw new IllegalArgumentException("Unsupported type: " + (Object)((Object)refTypeCase));
        }
        if (refTypeCase != Expression.FieldReference.ReferenceTypeCase.DIRECT_REFERENCE) {
            throw new IllegalArgumentException("Unhandled type: " + (Object)((Object)refTypeCase));
        }
        switch (reference.getRootTypeCase()) {
            case EXPRESSION: {
                return FieldReference.ofExpression(this.from(reference.getExpression()), this.getDirectReferenceSegments(reference.getDirectReference()));
            }
            case ROOT_REFERENCE: {
                return FieldReference.ofRoot(this.rootType, this.getDirectReferenceSegments(reference.getDirectReference()));
            }
            case OUTER_REFERENCE: {
                return FieldReference.newRootStructOuterReference(reference.getDirectReference().getStructField().getField(), this.rootType, reference.getOuterReference().getStepsOut());
            }
        }
        throw new IllegalArgumentException("Unhandled type: " + (Object)((Object)reference.getRootTypeCase()));
    }

    private List<FieldReference.ReferenceSegment> getDirectReferenceSegments(Expression.ReferenceSegment segment) {
        ArrayList<FieldReference.ReferenceSegment> results = new ArrayList<FieldReference.ReferenceSegment>();
        while (segment != Expression.ReferenceSegment.getDefaultInstance()) {
            FieldReference.ReferenceSegment mappedSegment;
            switch (segment.getReferenceTypeCase()) {
                case MAP_KEY: {
                    Expression.ReferenceSegment.MapKey mapKey = segment.getMapKey();
                    segment = mapKey.getChild();
                    mappedSegment = FieldReference.MapKey.of(this.from(mapKey.getMapKey()));
                    break;
                }
                case STRUCT_FIELD: {
                    Expression.ReferenceSegment.StructField structField = segment.getStructField();
                    segment = structField.getChild();
                    mappedSegment = FieldReference.StructField.of(structField.getField());
                    break;
                }
                case LIST_ELEMENT: {
                    Expression.ReferenceSegment.ListElement listElement = segment.getListElement();
                    segment = listElement.getChild();
                    mappedSegment = FieldReference.ListElement.of(listElement.getOffset());
                    break;
                }
                default: {
                    throw new IllegalArgumentException("Unhandled type: " + (Object)((Object)segment.getReferenceTypeCase()));
                }
            }
            results.add(mappedSegment);
        }
        Collections.reverse(results);
        return results;
    }

    public Expression from(io.substrait.proto.Expression expr) {
        switch (expr.getRexTypeCase()) {
            case LITERAL: {
                return this.from(expr.getLiteral());
            }
            case SELECTION: {
                return this.from(expr.getSelection());
            }
            case SCALAR_FUNCTION: {
                Expression.ScalarFunction scalarFunction = expr.getScalarFunction();
                int functionReference = scalarFunction.getFunctionReference();
                SimpleExtension.ScalarFunctionVariant declaration = this.lookup.getScalarFunction(functionReference, this.extensions);
                FunctionArg.ProtoFrom pF = new FunctionArg.ProtoFrom(this, this.protoTypeConverter);
                List args = IntStream.range(0, scalarFunction.getArgumentsCount()).mapToObj(i -> pF.convert(declaration, i, scalarFunction.getArguments(i))).collect(Collectors.toList());
                List options = scalarFunction.getOptionsList().stream().map(ProtoExpressionConverter::fromFunctionOption).collect(Collectors.toList());
                return Expression.ScalarFunctionInvocation.builder().addAllArguments(args).declaration(declaration).outputType(this.protoTypeConverter.from(scalarFunction.getOutputType())).options(options).build();
            }
            case WINDOW_FUNCTION: {
                return this.fromWindowFunction(expr.getWindowFunction());
            }
            case IF_THEN: {
                Expression.IfThen ifThen = expr.getIfThen();
                List clauses = ifThen.getIfsList().stream().map(t -> ExpressionCreator.ifThenClause(this.from(t.getIf()), this.from(t.getThen()))).collect(Collectors.toList());
                return ExpressionCreator.ifThenStatement(this.from(ifThen.getElse()), clauses);
            }
            case SWITCH_EXPRESSION: {
                Expression.SwitchExpression switchExpr = expr.getSwitchExpression();
                List clauses = switchExpr.getIfsList().stream().map(t -> ExpressionCreator.switchClause(this.from(t.getIf()), this.from(t.getThen()))).collect(Collectors.toList());
                return ExpressionCreator.switchStatement(this.from(switchExpr.getMatch()), this.from(switchExpr.getElse()), clauses);
            }
            case SINGULAR_OR_LIST: {
                Expression.SingularOrList orList = expr.getSingularOrList();
                List values = orList.getOptionsList().stream().map(this::from).collect(Collectors.toList());
                return Expression.SingleOrList.builder().condition(this.from(orList.getValue())).addAllOptions(values).build();
            }
            case MULTI_OR_LIST: {
                Expression.MultiOrList multiOrList = expr.getMultiOrList();
                List values = multiOrList.getOptionsList().stream().map(t -> Expression.MultiOrListRecord.builder().addAllValues(t.getFieldsList().stream().map(this::from).collect(Collectors.toList())).build()).collect(Collectors.toList());
                return Expression.MultiOrList.builder().addAllOptionCombinations(values).addAllConditions(multiOrList.getValueList().stream().map(this::from).collect(Collectors.toList())).build();
            }
            case CAST: {
                return ExpressionCreator.cast(this.protoTypeConverter.from(expr.getCast().getType()), this.from(expr.getCast().getInput()), Expression.FailureBehavior.fromProto(expr.getCast().getFailureBehavior()));
            }
            case SUBQUERY: {
                switch (expr.getSubquery().getSubqueryTypeCase()) {
                    case SET_PREDICATE: {
                        Rel rel = this.protoRelConverter.from(expr.getSubquery().getSetPredicate().getTuples());
                        return Expression.SetPredicate.builder().tuples(rel).predicateOp(Expression.PredicateOp.fromProto(expr.getSubquery().getSetPredicate().getPredicateOp())).build();
                    }
                    case SCALAR: {
                        Rel rel = this.protoRelConverter.from(expr.getSubquery().getScalar().getInput());
                        return Expression.ScalarSubquery.builder().input(rel).type(rel.getRecordType().accept(new TypeVisitor.TypeThrowsVisitor<Type, RuntimeException>("Expected struct field"){

                            @Override
                            public Type visit(Type.Struct type) throws RuntimeException {
                                if (type.fields().size() != 1) {
                                    throw new UnsupportedOperationException("Scalar subquery must have exactly one field");
                                }
                                return type.fields().get(0);
                            }
                        })).build();
                    }
                    case IN_PREDICATE: {
                        Rel rel = this.protoRelConverter.from(expr.getSubquery().getInPredicate().getHaystack());
                        List needles = expr.getSubquery().getInPredicate().getNeedlesList().stream().map(e -> this.from((io.substrait.proto.Expression)e)).collect(Collectors.toList());
                        return Expression.InPredicate.builder().haystack(rel).needles(needles).build();
                    }
                    case SET_COMPARISON: {
                        throw new UnsupportedOperationException("Unsupported subquery type: " + (Object)((Object)expr.getSubquery().getSubqueryTypeCase()));
                    }
                }
                throw new IllegalArgumentException("Unknown subquery type: " + (Object)((Object)expr.getSubquery().getSubqueryTypeCase()));
            }
            case ENUM: {
                throw new UnsupportedOperationException("Unsupported type: " + (Object)((Object)expr.getRexTypeCase()));
            }
        }
        throw new IllegalArgumentException("Unknown type: " + (Object)((Object)expr.getRexTypeCase()));
    }

    public Expression.WindowFunctionInvocation fromWindowFunction(Expression.WindowFunction windowFunction) {
        int functionReference = windowFunction.getFunctionReference();
        SimpleExtension.WindowFunctionVariant declaration = this.lookup.getWindowFunction(functionReference, this.extensions);
        FunctionArg.ProtoFrom argVisitor = new FunctionArg.ProtoFrom(this, this.protoTypeConverter);
        List<FunctionArg> args = ProtoExpressionConverter.fromFunctionArgumentList(windowFunction.getArgumentsCount(), argVisitor, declaration, windowFunction::getArguments);
        List partitionExprs = windowFunction.getPartitionsList().stream().map(this::from).collect(Collectors.toList());
        List sortFields = windowFunction.getSortsList().stream().map(this::fromSortField).collect(Collectors.toList());
        List options = windowFunction.getOptionsList().stream().map(ProtoExpressionConverter::fromFunctionOption).collect(Collectors.toList());
        WindowBound lowerBound = this.toWindowBound(windowFunction.getLowerBound());
        WindowBound upperBound = this.toWindowBound(windowFunction.getUpperBound());
        return Expression.WindowFunctionInvocation.builder().arguments(args).declaration(declaration).outputType(this.protoTypeConverter.from(windowFunction.getOutputType())).aggregationPhase(Expression.AggregationPhase.fromProto(windowFunction.getPhase())).partitionBy(partitionExprs).sort(sortFields).boundsType(Expression.WindowBoundsType.fromProto(windowFunction.getBoundsType())).lowerBound(lowerBound).upperBound(upperBound).invocation(Expression.AggregationInvocation.fromProto(windowFunction.getInvocation())).options(options).build();
    }

    public ConsistentPartitionWindow.WindowRelFunctionInvocation fromWindowRelFunction(ConsistentPartitionWindowRel.WindowRelFunction windowRelFunction) {
        int functionReference = windowRelFunction.getFunctionReference();
        SimpleExtension.WindowFunctionVariant declaration = this.lookup.getWindowFunction(functionReference, this.extensions);
        FunctionArg.ProtoFrom argVisitor = new FunctionArg.ProtoFrom(this, this.protoTypeConverter);
        List<FunctionArg> args = ProtoExpressionConverter.fromFunctionArgumentList(windowRelFunction.getArgumentsCount(), argVisitor, declaration, windowRelFunction::getArguments);
        List options = windowRelFunction.getOptionsList().stream().map(ProtoExpressionConverter::fromFunctionOption).collect(Collectors.toList());
        WindowBound lowerBound = this.toWindowBound(windowRelFunction.getLowerBound());
        WindowBound upperBound = this.toWindowBound(windowRelFunction.getUpperBound());
        return ConsistentPartitionWindow.WindowRelFunctionInvocation.builder().arguments(args).declaration(declaration).outputType(this.protoTypeConverter.from(windowRelFunction.getOutputType())).aggregationPhase(Expression.AggregationPhase.fromProto(windowRelFunction.getPhase())).boundsType(Expression.WindowBoundsType.fromProto(windowRelFunction.getBoundsType())).lowerBound(lowerBound).upperBound(upperBound).invocation(Expression.AggregationInvocation.fromProto(windowRelFunction.getInvocation())).options(options).build();
    }

    private WindowBound toWindowBound(Expression.WindowFunction.Bound bound) {
        switch (bound.getKindCase()) {
            case PRECEDING: {
                return WindowBound.Preceding.of(bound.getPreceding().getOffset());
            }
            case FOLLOWING: {
                return WindowBound.Following.of(bound.getFollowing().getOffset());
            }
            case CURRENT_ROW: {
                return WindowBound.CURRENT_ROW;
            }
            case UNBOUNDED: {
                return WindowBound.UNBOUNDED;
            }
            case KIND_NOT_SET: {
                return WindowBound.UNBOUNDED;
            }
        }
        throw new IllegalArgumentException("Unsupported bound kind: " + (Object)((Object)bound.getKindCase()));
    }

    public Expression.Literal from(Expression.Literal literal) {
        switch (literal.getLiteralTypeCase()) {
            case BOOLEAN: {
                return ExpressionCreator.bool(literal.getNullable(), literal.getBoolean());
            }
            case I8: {
                return ExpressionCreator.i8(literal.getNullable(), literal.getI8());
            }
            case I16: {
                return ExpressionCreator.i16(literal.getNullable(), literal.getI16());
            }
            case I32: {
                return ExpressionCreator.i32(literal.getNullable(), literal.getI32());
            }
            case I64: {
                return ExpressionCreator.i64(literal.getNullable(), literal.getI64());
            }
            case FP32: {
                return ExpressionCreator.fp32(literal.getNullable(), literal.getFp32());
            }
            case FP64: {
                return ExpressionCreator.fp64(literal.getNullable(), literal.getFp64());
            }
            case STRING: {
                return ExpressionCreator.string(literal.getNullable(), literal.getString());
            }
            case BINARY: {
                return ExpressionCreator.binary(literal.getNullable(), literal.getBinary());
            }
            case TIMESTAMP: {
                return ExpressionCreator.timestamp(literal.getNullable(), literal.getTimestamp());
            }
            case TIMESTAMP_TZ: {
                return ExpressionCreator.timestampTZ(literal.getNullable(), literal.getTimestampTz());
            }
            case PRECISION_TIMESTAMP: {
                return ExpressionCreator.precisionTimestamp(literal.getNullable(), literal.getPrecisionTimestamp().getValue(), literal.getPrecisionTimestamp().getPrecision());
            }
            case PRECISION_TIMESTAMP_TZ: {
                return ExpressionCreator.precisionTimestampTZ(literal.getNullable(), literal.getPrecisionTimestampTz().getValue(), literal.getPrecisionTimestampTz().getPrecision());
            }
            case DATE: {
                return ExpressionCreator.date(literal.getNullable(), literal.getDate());
            }
            case TIME: {
                return ExpressionCreator.time(literal.getNullable(), literal.getTime());
            }
            case INTERVAL_YEAR_TO_MONTH: {
                return ExpressionCreator.intervalYear(literal.getNullable(), literal.getIntervalYearToMonth().getYears(), literal.getIntervalYearToMonth().getMonths());
            }
            case INTERVAL_DAY_TO_SECOND: {
                int precision = literal.getIntervalDayToSecond().hasPrecision() ? literal.getIntervalDayToSecond().getPrecision() : 6;
                long subseconds = literal.getIntervalDayToSecond().hasPrecision() ? literal.getIntervalDayToSecond().getSubseconds() : (long)literal.getIntervalDayToSecond().getMicroseconds();
                return ExpressionCreator.intervalDay(literal.getNullable(), literal.getIntervalDayToSecond().getDays(), literal.getIntervalDayToSecond().getSeconds(), subseconds, precision);
            }
            case INTERVAL_COMPOUND: {
                if (!literal.getIntervalCompound().getIntervalDayToSecond().hasPrecision()) {
                    throw new UnsupportedOperationException("Interval compound with deprecated version of interval day (ie. no precision) is not supported");
                }
                return ExpressionCreator.intervalCompound(literal.getNullable(), literal.getIntervalCompound().getIntervalYearToMonth().getYears(), literal.getIntervalCompound().getIntervalYearToMonth().getMonths(), literal.getIntervalCompound().getIntervalDayToSecond().getDays(), literal.getIntervalCompound().getIntervalDayToSecond().getSeconds(), literal.getIntervalCompound().getIntervalDayToSecond().getSubseconds(), literal.getIntervalCompound().getIntervalDayToSecond().getPrecision());
            }
            case FIXED_CHAR: {
                return ExpressionCreator.fixedChar(literal.getNullable(), literal.getFixedChar());
            }
            case VAR_CHAR: {
                return ExpressionCreator.varChar(literal.getNullable(), literal.getVarChar().getValue(), literal.getVarChar().getLength());
            }
            case FIXED_BINARY: {
                return ExpressionCreator.fixedBinary(literal.getNullable(), literal.getFixedBinary());
            }
            case DECIMAL: {
                return ExpressionCreator.decimal(literal.getNullable(), literal.getDecimal().getValue(), literal.getDecimal().getPrecision(), literal.getDecimal().getScale());
            }
            case STRUCT: {
                return ExpressionCreator.struct(literal.getNullable(), literal.getStruct().getFieldsList().stream().map(this::from).collect(Collectors.toList()));
            }
            case MAP: {
                return ExpressionCreator.map(literal.getNullable(), literal.getMap().getKeyValuesList().stream().collect(Collectors.toMap(kv -> this.from(kv.getKey()), kv -> this.from(kv.getValue()))));
            }
            case EMPTY_MAP: {
                Type.Map mapType = this.protoTypeConverter.fromMap(literal.getEmptyMap());
                return ExpressionCreator.emptyMap(mapType.nullable(), mapType.key(), mapType.value());
            }
            case UUID: {
                return ExpressionCreator.uuid(literal.getNullable(), literal.getUuid());
            }
            case NULL: {
                return ExpressionCreator.typedNull(this.protoTypeConverter.from(literal.getNull()));
            }
            case LIST: {
                return ExpressionCreator.list(literal.getNullable(), literal.getList().getValuesList().stream().map(this::from).collect(Collectors.toList()));
            }
            case EMPTY_LIST: {
                Type.ListType listType = this.protoTypeConverter.fromList(literal.getEmptyList());
                return ExpressionCreator.emptyList(listType.nullable(), listType.elementType());
            }
            case USER_DEFINED: {
                Expression.Literal.UserDefined userDefinedLiteral = literal.getUserDefined();
                SimpleExtension.Type type = this.lookup.getType(userDefinedLiteral.getTypeReference(), this.extensions);
                return ExpressionCreator.userDefinedLiteral(literal.getNullable(), type.urn(), type.name(), userDefinedLiteral.getValue());
            }
        }
        throw new IllegalStateException("Unexpected value: " + (Object)((Object)literal.getLiteralTypeCase()));
    }

    private static List<FunctionArg> fromFunctionArgumentList(int argumentsCount, FunctionArg.ProtoFrom argVisitor, SimpleExtension.Function declaration, Function<Integer, FunctionArgument> argFunction) {
        return IntStream.range(0, argumentsCount).mapToObj(i -> argVisitor.convert(declaration, i, (FunctionArgument)argFunction.apply(i))).collect(Collectors.toList());
    }

    public Expression.SortField fromSortField(SortField s) {
        return Expression.SortField.builder().direction(Expression.SortDirection.fromProto(s.getDirection())).expr(this.from(s.getExpr())).build();
    }

    public static FunctionOption fromFunctionOption(io.substrait.proto.FunctionOption o) {
        return FunctionOption.builder().name(o.getName()).addAllValues((Iterable<String>)o.getPreferenceList()).build();
    }
}

