/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.planner.ExpressionSymbolInliner;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.SymbolAllocator;
import com.facebook.presto.sql.tree.BindExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.ExpressionRewriter;
import com.facebook.presto.sql.tree.ExpressionTreeRewriter;
import com.facebook.presto.sql.tree.Identifier;
import com.facebook.presto.sql.tree.LambdaArgumentDeclaration;
import com.facebook.presto.sql.tree.LambdaExpression;
import com.facebook.presto.sql.tree.SymbolReference;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;

public class LambdaCaptureDesugaringRewriter {
    public static Expression rewrite(Expression expression, Map<Symbol, Type> symbolTypes, SymbolAllocator symbolAllocator) {
        return ExpressionTreeRewriter.rewriteWith((ExpressionRewriter)new Visitor(symbolTypes, symbolAllocator), (Expression)expression, (Object)new Context());
    }

    private LambdaCaptureDesugaringRewriter() {
    }

    private static class Context {
        LinkedHashSet<Symbol> referencedSymbols;

        public Context() {
            this(new LinkedHashSet<Symbol>());
        }

        private Context(LinkedHashSet<Symbol> referencedSymbols) {
            this.referencedSymbols = referencedSymbols;
        }

        public LinkedHashSet<Symbol> getReferencedSymbols() {
            return this.referencedSymbols;
        }

        public Context withReferencedSymbols(LinkedHashSet<Symbol> symbols) {
            return new Context(symbols);
        }
    }

    private static class Visitor
    extends ExpressionRewriter<Context> {
        private final Map<Symbol, Type> symbolTypes;
        private final SymbolAllocator symbolAllocator;

        public Visitor(Map<Symbol, Type> symbolTypes, SymbolAllocator symbolAllocator) {
            this.symbolTypes = Objects.requireNonNull(symbolTypes, "symbolTypes is null");
            this.symbolAllocator = Objects.requireNonNull(symbolAllocator, "symbolAllocator is null");
        }

        public Expression rewriteLambdaExpression(LambdaExpression node, Context context, ExpressionTreeRewriter<Context> treeRewriter) {
            LinkedHashSet<Symbol> referencedSymbols = new LinkedHashSet<Symbol>();
            Expression rewrittenBody = treeRewriter.rewrite(node.getBody(), (Object)context.withReferencedSymbols(referencedSymbols));
            List lambdaArguments = (List)node.getArguments().stream().map(LambdaArgumentDeclaration::getName).map(Identifier::getValue).map(Symbol::new).collect(ImmutableList.toImmutableList());
            referencedSymbols.removeAll(lambdaArguments);
            LinkedHashSet<Symbol> captureSymbols = referencedSymbols;
            ImmutableMap.Builder captureSymbolToExtraSymbol = ImmutableMap.builder();
            ImmutableList.Builder newLambdaArguments = ImmutableList.builder();
            for (Symbol captureSymbol : captureSymbols) {
                Symbol extraSymbol = this.symbolAllocator.newSymbol(captureSymbol.getName(), this.symbolTypes.get(captureSymbol));
                captureSymbolToExtraSymbol.put((Object)captureSymbol, (Object)extraSymbol);
                newLambdaArguments.add((Object)new LambdaArgumentDeclaration(new Identifier(extraSymbol.getName())));
            }
            newLambdaArguments.addAll((Iterable)node.getArguments());
            ImmutableMap symbolsMap = captureSymbolToExtraSymbol.build();
            ExpressionSymbolInliner inliner = new ExpressionSymbolInliner(x -> ((Symbol)symbolsMap.getOrDefault(x, x)).toSymbolReference());
            LambdaExpression rewrittenExpression = new LambdaExpression((List)newLambdaArguments.build(), inliner.rewrite(rewrittenBody));
            if (captureSymbols.size() != 0) {
                List capturedValues = (List)captureSymbols.stream().map(symbol -> new SymbolReference(symbol.getName())).collect(ImmutableList.toImmutableList());
                rewrittenExpression = new BindExpression(capturedValues, (Expression)rewrittenExpression);
            }
            context.getReferencedSymbols().addAll(captureSymbols);
            return rewrittenExpression;
        }

        public Expression rewriteSymbolReference(SymbolReference node, Context context, ExpressionTreeRewriter<Context> treeRewriter) {
            context.getReferencedSymbols().add(new Symbol(node.getName()));
            return null;
        }
    }
}

