/*
 * Decompiled with CFR 0.152.
 */
package io.substrait.relation;

import io.substrait.expression.AggregateFunctionInvocation;
import io.substrait.expression.Expression;
import io.substrait.expression.FieldReference;
import io.substrait.expression.FunctionArg;
import io.substrait.relation.AbstractUpdate;
import io.substrait.relation.Aggregate;
import io.substrait.relation.ConsistentPartitionWindow;
import io.substrait.relation.CopyOnWriteUtils;
import io.substrait.relation.Cross;
import io.substrait.relation.EmptyScan;
import io.substrait.relation.Expand;
import io.substrait.relation.ExpressionCopyOnWriteVisitor;
import io.substrait.relation.ExtensionDdl;
import io.substrait.relation.ExtensionLeaf;
import io.substrait.relation.ExtensionMulti;
import io.substrait.relation.ExtensionSingle;
import io.substrait.relation.ExtensionTable;
import io.substrait.relation.ExtensionWrite;
import io.substrait.relation.Fetch;
import io.substrait.relation.Filter;
import io.substrait.relation.ImmutableJoin;
import io.substrait.relation.Join;
import io.substrait.relation.LocalFiles;
import io.substrait.relation.NamedDdl;
import io.substrait.relation.NamedScan;
import io.substrait.relation.NamedUpdate;
import io.substrait.relation.NamedWrite;
import io.substrait.relation.Project;
import io.substrait.relation.Rel;
import io.substrait.relation.RelVisitor;
import io.substrait.relation.Set;
import io.substrait.relation.Sort;
import io.substrait.relation.VirtualTableScan;
import io.substrait.relation.physical.HashJoin;
import io.substrait.relation.physical.MergeJoin;
import io.substrait.relation.physical.NestedLoopJoin;
import io.substrait.util.EmptyVisitationContext;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;

/**
 * Copy-on-write rewriter for {@link Rel} trees.
 *
 * <p>Each {@code visit} method visits a relation's child relations and embedded expressions. If
 * none of those sub-visits reports a change (all return {@link Optional#empty()}), the method
 * returns {@link Optional#empty()} to signal "unchanged". Otherwise it returns a rebuilt copy of
 * the relation (built via {@code builder().from(original)}) in which only the changed parts are
 * replaced.
 *
 * <p>Expressions are delegated to an {@link ExpressionCopyOnWriteVisitor}, either created by the
 * no-arg constructor or supplied by the caller. Subclasses can override the public and
 * {@code protected} methods to introduce rewrites.
 *
 * @param <E> checked exception type that the visit methods may throw
 */
public class RelCopyOnWriteVisitor<E extends Exception>
implements RelVisitor<Optional<Rel>, EmptyVisitationContext, E> {
    private final ExpressionCopyOnWriteVisitor<E> expressionCopyOnWriteVisitor;

    /** Creates a visitor whose expression visitor is a new {@link ExpressionCopyOnWriteVisitor} bound to this instance. */
    public RelCopyOnWriteVisitor() {
        this.expressionCopyOnWriteVisitor = new ExpressionCopyOnWriteVisitor<>(this);
    }

    /**
     * Creates a visitor that delegates all embedded expressions to the given visitor.
     *
     * @param expressionCopyOnWriteVisitor visitor used for embedded expressions
     */
    public RelCopyOnWriteVisitor(ExpressionCopyOnWriteVisitor<E> expressionCopyOnWriteVisitor) {
        this.expressionCopyOnWriteVisitor = expressionCopyOnWriteVisitor;
    }

    /**
     * Creates a visitor whose expression visitor is produced from this instance.
     *
     * <p>NOTE: {@code fn} receives {@code this} before the expression-visitor field is assigned, so
     * it should not invoke visit methods on it during construction.
     *
     * @param fn factory that builds the expression visitor from this relation visitor
     */
    public RelCopyOnWriteVisitor(Function<RelCopyOnWriteVisitor<E>, ExpressionCopyOnWriteVisitor<E>> fn) {
        this.expressionCopyOnWriteVisitor = fn.apply(this);
    }

    /** Returns the visitor used for expressions embedded in relations. */
    protected ExpressionCopyOnWriteVisitor<E> getExpressionCopyOnWriteVisitor() {
        return this.expressionCopyOnWriteVisitor;
    }

    /** Visits the input, the grouping expressions and the measures of an aggregate. */
    @Override
    public Optional<Rel> visit(Aggregate aggregate, EmptyVisitationContext context) throws E {
        Optional<Rel> input = aggregate.getInput().accept(this, context);
        Optional<List<Aggregate.Grouping>> groupings =
                CopyOnWriteUtils.transformList(aggregate.getGroupings(), context, this::visitGrouping);
        Optional<List<Aggregate.Measure>> measures =
                CopyOnWriteUtils.transformList(aggregate.getMeasures(), context, this::visitMeasure);
        if (CopyOnWriteUtils.allEmpty(input, groupings, measures)) {
            return Optional.empty();
        }
        return Optional.of(
                Aggregate.builder()
                        .from(aggregate)
                        .input(input.orElse(aggregate.getInput()))
                        .groupings(groupings.orElse(aggregate.getGroupings()))
                        .measures(measures.orElse(aggregate.getMeasures()))
                        .build());
    }

    /** Visits the expressions of one grouping; returns a rebuilt grouping only if any changed. */
    protected Optional<Aggregate.Grouping> visitGrouping(Aggregate.Grouping grouping, EmptyVisitationContext context) throws E {
        return visitExprList(grouping.getExpressions(), context)
                .map(exprs -> Aggregate.Grouping.builder().from(grouping).expressions(exprs).build());
    }

    /** Visits a measure's optional pre-measure filter and its aggregate function invocation. */
    protected Optional<Aggregate.Measure> visitMeasure(Aggregate.Measure measure, EmptyVisitationContext context) throws E {
        Optional<Expression> preMeasureFilter = visitOptionalExpression(measure.getPreMeasureFilter(), context);
        Optional<AggregateFunctionInvocation> afi = visitAggregateFunction(measure.getFunction(), context);
        if (CopyOnWriteUtils.allEmpty(preMeasureFilter, afi)) {
            return Optional.empty();
        }
        return Optional.of(
                Aggregate.Measure.builder()
                        .from(measure)
                        .preMeasureFilter(CopyOnWriteUtils.or(preMeasureFilter, measure::getPreMeasureFilter))
                        .function(afi.orElse(measure.getFunction()))
                        .build());
    }

    /** Visits the arguments and sort fields of an aggregate function invocation. */
    protected Optional<AggregateFunctionInvocation> visitAggregateFunction(AggregateFunctionInvocation afi, EmptyVisitationContext context) throws E {
        Optional<List<FunctionArg>> arguments = visitFunctionArguments(afi.arguments(), context);
        Optional<List<Expression.SortField>> sort =
                CopyOnWriteUtils.transformList(afi.sort(), context, this::visitSortField);
        if (CopyOnWriteUtils.allEmpty(arguments, sort)) {
            return Optional.empty();
        }
        return Optional.of(
                AggregateFunctionInvocation.builder()
                        .from(afi)
                        .arguments(arguments.orElse(afi.arguments()))
                        .sort(sort.orElse(afi.sort()))
                        .build());
    }

    /** Visits the optional filter of an empty scan. */
    @Override
    public Optional<Rel> visit(EmptyScan emptyScan, EmptyVisitationContext context) throws E {
        Optional<Expression> filter = visitOptionalExpression(emptyScan.getFilter(), context);
        if (CopyOnWriteUtils.allEmpty(filter)) {
            return Optional.empty();
        }
        return Optional.of(
                EmptyScan.builder()
                        .from(emptyScan)
                        .filter(CopyOnWriteUtils.or(filter, emptyScan::getFilter))
                        .build());
    }

    /** Visits only the input relation of a fetch. */
    @Override
    public Optional<Rel> visit(Fetch fetch, EmptyVisitationContext context) throws E {
        return fetch.getInput()
                .accept(this, context)
                .map(input -> Fetch.builder().from(fetch).input(input).build());
    }

    /** Visits the input relation and the condition of a filter. */
    @Override
    public Optional<Rel> visit(Filter filter, EmptyVisitationContext context) throws E {
        Optional<Rel> input = filter.getInput().accept(this, context);
        Optional<Expression> condition = visitExpression(filter.getCondition(), context);
        if (CopyOnWriteUtils.allEmpty(input, condition)) {
            return Optional.empty();
        }
        return Optional.of(
                Filter.builder()
                        .from(filter)
                        .input(input.orElse(filter.getInput()))
                        .condition(condition.orElse(filter.getCondition()))
                        .build());
    }

    /** Visits both join inputs, the optional join condition and the optional post-join filter. */
    @Override
    public Optional<Rel> visit(Join join, EmptyVisitationContext context) throws E {
        Optional<Rel> left = join.getLeft().accept(this, context);
        Optional<Rel> right = join.getRight().accept(this, context);
        Optional<Expression> condition = visitOptionalExpression(join.getCondition(), context);
        Optional<Expression> postFilter = visitOptionalExpression(join.getPostJoinFilter(), context);
        if (CopyOnWriteUtils.allEmpty(left, right, condition, postFilter)) {
            return Optional.empty();
        }
        return Optional.of(
                ImmutableJoin.builder()
                        .from(join)
                        .left(left.orElse(join.getLeft()))
                        .right(right.orElse(join.getRight()))
                        .condition(CopyOnWriteUtils.or(condition, join::getCondition))
                        .postJoinFilter(CopyOnWriteUtils.or(postFilter, join::getPostJoinFilter))
                        .build());
    }

    /** Visits every input of a set operation. */
    @Override
    public Optional<Rel> visit(Set set, EmptyVisitationContext context) throws E {
        return CopyOnWriteUtils.transformList(set.getInputs(), context, (rel, c) -> rel.accept(this, c))
                .map(inputs -> Set.builder().from(set).inputs(inputs).build());
    }

    /** Visits the optional filter of a named scan. */
    @Override
    public Optional<Rel> visit(NamedScan namedScan, EmptyVisitationContext context) throws E {
        Optional<Expression> filter = visitOptionalExpression(namedScan.getFilter(), context);
        if (CopyOnWriteUtils.allEmpty(filter)) {
            return Optional.empty();
        }
        return Optional.of(
                NamedScan.builder()
                        .from(namedScan)
                        .filter(CopyOnWriteUtils.or(filter, namedScan::getFilter))
                        .build());
    }

    /** Visits the optional filter of a local-files scan. */
    @Override
    public Optional<Rel> visit(LocalFiles localFiles, EmptyVisitationContext context) throws E {
        Optional<Expression> filter = visitOptionalExpression(localFiles.getFilter(), context);
        if (CopyOnWriteUtils.allEmpty(filter)) {
            return Optional.empty();
        }
        return Optional.of(
                LocalFiles.builder()
                        .from(localFiles)
                        .filter(CopyOnWriteUtils.or(filter, localFiles::getFilter))
                        .build());
    }

    /** Visits the input relation and the projected expressions. */
    @Override
    public Optional<Rel> visit(Project project, EmptyVisitationContext context) throws E {
        Optional<Rel> input = project.getInput().accept(this, context);
        Optional<List<Expression>> expressions = visitExprList(project.getExpressions(), context);
        if (CopyOnWriteUtils.allEmpty(input, expressions)) {
            return Optional.empty();
        }
        return Optional.of(
                Project.builder()
                        .from(project)
                        .input(input.orElse(project.getInput()))
                        .expressions(expressions.orElse(project.getExpressions()))
                        .build());
    }

    /** Not supported by this visitor; always throws {@link UnsupportedOperationException}. */
    @Override
    public Optional<Rel> visit(Expand expand, EmptyVisitationContext context) throws E {
        throw new UnsupportedOperationException();
    }

    /** Visits only the input relation of a named write. */
    @Override
    public Optional<Rel> visit(NamedWrite write, EmptyVisitationContext context) throws E {
        Optional<Rel> input = write.getInput().accept(this, context);
        if (CopyOnWriteUtils.allEmpty(input)) {
            return Optional.empty();
        }
        return Optional.of(NamedWrite.builder().from(write).input(input.orElse(write.getInput())).build());
    }

    /** Not supported by this visitor; always throws {@link UnsupportedOperationException}. */
    @Override
    public Optional<Rel> visit(ExtensionWrite write, EmptyVisitationContext context) throws E {
        throw new UnsupportedOperationException();
    }

    /** Not supported by this visitor; always throws {@link UnsupportedOperationException}. */
    @Override
    public Optional<Rel> visit(NamedDdl ddl, EmptyVisitationContext context) throws E {
        throw new UnsupportedOperationException();
    }

    /** Not supported by this visitor; always throws {@link UnsupportedOperationException}. */
    @Override
    public Optional<Rel> visit(ExtensionDdl ddl, EmptyVisitationContext context) throws E {
        throw new UnsupportedOperationException();
    }

    /** Visits the transformation expression of an update's transform entry. */
    protected Optional<AbstractUpdate.TransformExpression> visitTransformExpression(AbstractUpdate.TransformExpression transform, EmptyVisitationContext context) throws E {
        return visitExpression(transform.getTransformation(), context)
                .map(expr -> AbstractUpdate.TransformExpression.builder().from(transform).transformation(expr).build());
    }

    /** Visits the condition and the transform expressions of a named update. */
    @Override
    public Optional<Rel> visit(NamedUpdate update, EmptyVisitationContext context) throws E {
        Optional<Expression> condition = visitExpression(update.getCondition(), context);
        Optional<List<AbstractUpdate.TransformExpression>> transformations =
                CopyOnWriteUtils.transformList(update.getTransformations(), context, this::visitTransformExpression);
        if (CopyOnWriteUtils.allEmpty(condition, transformations)) {
            return Optional.empty();
        }
        return Optional.of(
                NamedUpdate.builder()
                        .from(update)
                        .condition(condition.orElse(update.getCondition()))
                        .transformations(transformations.orElse(update.getTransformations()))
                        .build());
    }

    /** Visits the input relation and the sort fields of a sort. */
    @Override
    public Optional<Rel> visit(Sort sort, EmptyVisitationContext context) throws E {
        Optional<Rel> input = sort.getInput().accept(this, context);
        Optional<List<Expression.SortField>> sortFields =
                CopyOnWriteUtils.transformList(sort.getSortFields(), context, this::visitSortField);
        if (CopyOnWriteUtils.allEmpty(input, sortFields)) {
            return Optional.empty();
        }
        return Optional.of(
                Sort.builder()
                        .from(sort)
                        .input(input.orElse(sort.getInput()))
                        .sortFields(sortFields.orElse(sort.getSortFields()))
                        .build());
    }

    /** Visits both inputs of a cross product. */
    @Override
    public Optional<Rel> visit(Cross cross, EmptyVisitationContext context) throws E {
        Optional<Rel> left = cross.getLeft().accept(this, context);
        Optional<Rel> right = cross.getRight().accept(this, context);
        if (CopyOnWriteUtils.allEmpty(left, right)) {
            return Optional.empty();
        }
        return Optional.of(
                Cross.builder()
                        .from(cross)
                        .left(left.orElse(cross.getLeft()))
                        .right(right.orElse(cross.getRight()))
                        .build());
    }

    /** Visits the optional filter of a virtual table scan. */
    @Override
    public Optional<Rel> visit(VirtualTableScan virtualTableScan, EmptyVisitationContext context) throws E {
        Optional<Expression> filter = visitOptionalExpression(virtualTableScan.getFilter(), context);
        if (CopyOnWriteUtils.allEmpty(filter)) {
            return Optional.empty();
        }
        return Optional.of(
                VirtualTableScan.builder()
                        .from(virtualTableScan)
                        .filter(CopyOnWriteUtils.or(filter, virtualTableScan::getFilter))
                        .build());
    }

    /** Extension leaves are not inspected; always reports "unchanged". */
    @Override
    public Optional<Rel> visit(ExtensionLeaf extensionLeaf, EmptyVisitationContext context) throws E {
        return Optional.empty();
    }

    /** Visits only the input relation of a single-input extension. */
    @Override
    public Optional<Rel> visit(ExtensionSingle extensionSingle, EmptyVisitationContext context) throws E {
        return extensionSingle.getInput()
                .accept(this, context)
                .map(input -> ExtensionSingle.builder().from(extensionSingle).input(input).build());
    }

    /** Visits every input relation of a multi-input extension. */
    @Override
    public Optional<Rel> visit(ExtensionMulti extensionMulti, EmptyVisitationContext context) throws E {
        return CopyOnWriteUtils.transformList(extensionMulti.getInputs(), context, (rel, c) -> rel.accept(this, c))
                .map(inputs -> ExtensionMulti.builder().from(extensionMulti).inputs(inputs).build());
    }

    /** Visits the optional filter of an extension table. */
    @Override
    public Optional<Rel> visit(ExtensionTable extensionTable, EmptyVisitationContext context) throws E {
        Optional<Expression> filter = visitOptionalExpression(extensionTable.getFilter(), context);
        if (CopyOnWriteUtils.allEmpty(filter)) {
            return Optional.empty();
        }
        return Optional.of(
                ExtensionTable.builder()
                        .from(extensionTable)
                        .filter(CopyOnWriteUtils.or(filter, extensionTable::getFilter))
                        .build());
    }

    /** Visits both inputs, the left/right key references and the optional post-join filter. */
    @Override
    public Optional<Rel> visit(HashJoin hashJoin, EmptyVisitationContext context) throws E {
        Optional<Rel> left = hashJoin.getLeft().accept(this, context);
        Optional<Rel> right = hashJoin.getRight().accept(this, context);
        Optional<List<FieldReference>> leftKeys =
                CopyOnWriteUtils.transformList(hashJoin.getLeftKeys(), context, this::visitFieldReference);
        Optional<List<FieldReference>> rightKeys =
                CopyOnWriteUtils.transformList(hashJoin.getRightKeys(), context, this::visitFieldReference);
        Optional<Expression> postFilter = visitOptionalExpression(hashJoin.getPostJoinFilter(), context);
        if (CopyOnWriteUtils.allEmpty(left, right, leftKeys, rightKeys, postFilter)) {
            return Optional.empty();
        }
        return Optional.of(
                HashJoin.builder()
                        .from(hashJoin)
                        .left(left.orElse(hashJoin.getLeft()))
                        .right(right.orElse(hashJoin.getRight()))
                        .leftKeys(leftKeys.orElse(hashJoin.getLeftKeys()))
                        .rightKeys(rightKeys.orElse(hashJoin.getRightKeys()))
                        .postJoinFilter(CopyOnWriteUtils.or(postFilter, hashJoin::getPostJoinFilter))
                        .build());
    }

    /** Visits both inputs, the left/right key references and the optional post-join filter. */
    @Override
    public Optional<Rel> visit(MergeJoin mergeJoin, EmptyVisitationContext context) throws E {
        Optional<Rel> left = mergeJoin.getLeft().accept(this, context);
        Optional<Rel> right = mergeJoin.getRight().accept(this, context);
        Optional<List<FieldReference>> leftKeys =
                CopyOnWriteUtils.transformList(mergeJoin.getLeftKeys(), context, this::visitFieldReference);
        Optional<List<FieldReference>> rightKeys =
                CopyOnWriteUtils.transformList(mergeJoin.getRightKeys(), context, this::visitFieldReference);
        Optional<Expression> postFilter = visitOptionalExpression(mergeJoin.getPostJoinFilter(), context);
        if (CopyOnWriteUtils.allEmpty(left, right, leftKeys, rightKeys, postFilter)) {
            return Optional.empty();
        }
        return Optional.of(
                MergeJoin.builder()
                        .from(mergeJoin)
                        .left(left.orElse(mergeJoin.getLeft()))
                        .right(right.orElse(mergeJoin.getRight()))
                        .leftKeys(leftKeys.orElse(mergeJoin.getLeftKeys()))
                        .rightKeys(rightKeys.orElse(mergeJoin.getRightKeys()))
                        .postJoinFilter(CopyOnWriteUtils.or(postFilter, mergeJoin::getPostJoinFilter))
                        .build());
    }

    /** Visits both inputs and the join condition of a nested-loop join. */
    @Override
    public Optional<Rel> visit(NestedLoopJoin nestedLoopJoin, EmptyVisitationContext context) throws E {
        Optional<Rel> left = nestedLoopJoin.getLeft().accept(this, context);
        Optional<Rel> right = nestedLoopJoin.getRight().accept(this, context);
        Optional<Expression> condition = visitExpression(nestedLoopJoin.getCondition(), context);
        if (CopyOnWriteUtils.allEmpty(left, right, condition)) {
            return Optional.empty();
        }
        return Optional.of(
                NestedLoopJoin.builder()
                        .from(nestedLoopJoin)
                        .left(left.orElse(nestedLoopJoin.getLeft()))
                        .right(right.orElse(nestedLoopJoin.getRight()))
                        .condition(condition.orElse(nestedLoopJoin.getCondition()))
                        .build());
    }

    /**
     * Visits the input relation, window functions, partition expressions and sort fields of a
     * consistent-partition window.
     *
     * <p>The input is visited like that of every other single-input relation, so rewrites beneath
     * the window are not lost.
     */
    @Override
    public Optional<Rel> visit(ConsistentPartitionWindow consistentPartitionWindow, EmptyVisitationContext context) throws E {
        Optional<Rel> input = consistentPartitionWindow.getInput().accept(this, context);
        Optional<List<ConsistentPartitionWindow.WindowRelFunctionInvocation>> windowFunctions =
                CopyOnWriteUtils.transformList(consistentPartitionWindow.getWindowFunctions(), context, this::visitWindowRelFunction);
        Optional<List<Expression>> partitionExpressions =
                CopyOnWriteUtils.transformList(consistentPartitionWindow.getPartitionExpressions(), context, this::visitExpression);
        Optional<List<Expression.SortField>> sorts =
                CopyOnWriteUtils.transformList(consistentPartitionWindow.getSorts(), context, this::visitSortField);
        if (CopyOnWriteUtils.allEmpty(input, windowFunctions, partitionExpressions, sorts)) {
            return Optional.empty();
        }
        return Optional.of(
                ConsistentPartitionWindow.builder()
                        .from(consistentPartitionWindow)
                        .input(input.orElse(consistentPartitionWindow.getInput()))
                        .partitionExpressions(partitionExpressions.orElse(consistentPartitionWindow.getPartitionExpressions()))
                        .sorts(sorts.orElse(consistentPartitionWindow.getSorts()))
                        .windowFunctions(windowFunctions.orElse(consistentPartitionWindow.getWindowFunctions()))
                        .build());
    }

    /** Visits the arguments of a window function invocation. */
    protected Optional<ConsistentPartitionWindow.WindowRelFunctionInvocation> visitWindowRelFunction(ConsistentPartitionWindow.WindowRelFunctionInvocation windowRelFunctionInvocation, EmptyVisitationContext context) throws E {
        Optional<List<FunctionArg>> functionArgs = visitFunctionArguments(windowRelFunctionInvocation.arguments(), context);
        if (CopyOnWriteUtils.allEmpty(functionArgs)) {
            return Optional.empty();
        }
        return Optional.of(
                ConsistentPartitionWindow.WindowRelFunctionInvocation.builder()
                        .from(windowRelFunctionInvocation)
                        .arguments(functionArgs.orElse(windowRelFunctionInvocation.arguments()))
                        .build());
    }

    /**
     * Visits each expression in {@code exprs}.
     *
     * @return the transformed list if any element changed, otherwise empty
     */
    protected Optional<List<Expression>> visitExprList(List<Expression> exprs, EmptyVisitationContext context) throws E {
        return CopyOnWriteUtils.transformList(exprs, context, this::visitExpression);
    }

    /**
     * Visits the optional input expression of a field reference.
     *
     * <p>The rebuilt reference starts from a copy of {@code fieldReference} ({@code from(...)}), so
     * all of its other attributes are preserved and only the input expression is replaced.
     */
    public Optional<FieldReference> visitFieldReference(FieldReference fieldReference, EmptyVisitationContext context) throws E {
        Optional<Expression> inputExpression = visitOptionalExpression(fieldReference.inputExpression(), context);
        if (CopyOnWriteUtils.allEmpty(inputExpression)) {
            return Optional.empty();
        }
        return Optional.of(
                FieldReference.builder()
                        .from(fieldReference)
                        .inputExpression(inputExpression)
                        .build());
    }

    /**
     * Visits function arguments. Only arguments that are {@link Expression}s are visited; all other
     * argument kinds are always reported as unchanged.
     */
    protected Optional<List<FunctionArg>> visitFunctionArguments(List<FunctionArg> funcArgs, EmptyVisitationContext context) throws E {
        return CopyOnWriteUtils.transformList(funcArgs, context, (arg, c) -> {
            if (arg instanceof Expression) {
                return visitExpression((Expression) arg, c).map(FunctionArg.class::cast);
            }
            return Optional.empty();
        });
    }

    /** Visits the expression of a sort field, keeping its other attributes. */
    protected Optional<Expression.SortField> visitSortField(Expression.SortField sortField, EmptyVisitationContext context) throws E {
        return visitExpression(sortField.expr(), context)
                .map(expr -> Expression.SortField.builder().from(sortField).expr(expr).build());
    }

    /** Visits {@code optExpr} if present; an absent expression is reported as unchanged. */
    private Optional<Expression> visitOptionalExpression(Optional<Expression> optExpr, EmptyVisitationContext context) throws E {
        // Not written as optExpr.flatMap(...) because accept(...) can throw the checked exception E.
        if (optExpr.isPresent()) {
            return visitExpression(optExpr.get(), context);
        }
        return Optional.empty();
    }

    /** Dispatches a single expression to the expression copy-on-write visitor. */
    private Optional<Expression> visitExpression(Expression expression, EmptyVisitationContext context) throws E {
        return expression.accept(getExpressionCopyOnWriteVisitor(), context);
    }
}

