/*
 * Decompiled with CFR 0.152.
 */
package org.openrewrite.java.testing.mockito;

import java.util.Collections;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import org.openrewrite.Cursor;
import org.openrewrite.ExecutionContext;
import org.openrewrite.Preconditions;
import org.openrewrite.Recipe;
import org.openrewrite.Tree;
import org.openrewrite.TreeVisitor;
import org.openrewrite.internal.ListUtils;
import org.openrewrite.java.JavaIsoVisitor;
import org.openrewrite.java.JavaTemplate;
import org.openrewrite.java.MethodMatcher;
import org.openrewrite.java.VariableNameUtils;
import org.openrewrite.java.search.UsesMethod;
import org.openrewrite.java.tree.Flag;
import org.openrewrite.java.tree.J;
import org.openrewrite.java.tree.Statement;

public class MockitoWhenOnStaticToMockStatic
extends Recipe {
    private static final MethodMatcher MOCKITO_WHEN = new MethodMatcher("org.mockito.Mockito when(..)");

    public String getDisplayName() {
        return "Replace `Mockito.when` on static (non mock) with try-with-resource with MockedStatic";
    }

    public String getDescription() {
        return "Replace `Mockito.when` on static (non mock) with try-with-resource with MockedStatic as Mockito4 no longer allows this.";
    }

    public TreeVisitor<?, ExecutionContext> getVisitor() {
        return Preconditions.check((TreeVisitor)new UsesMethod(MOCKITO_WHEN), (TreeVisitor)new JavaIsoVisitor<ExecutionContext>(){

            public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, ExecutionContext ctx) {
                J.MethodDeclaration m = super.visitMethodDeclaration(method, (Object)ctx);
                if (m.getBody() == null) {
                    return m;
                }
                List<Statement> newStatements = this.maybeWrapStatementsInTryWithResourcesMockedStatic(m, m.getBody().getStatements());
                return (J.MethodDeclaration)this.maybeAutoFormat((J)m, (J)m.withBody(m.getBody().withStatements(newStatements)), ctx);
            }

            private List<Statement> maybeWrapStatementsInTryWithResourcesMockedStatic(J.MethodDeclaration m, List<Statement> remainingStatements) {
                AtomicBoolean restInTry = new AtomicBoolean(false);
                return ListUtils.flatMap(remainingStatements, (index, statement) -> {
                    J.MethodInvocation whenArg;
                    J.MethodInvocation when;
                    if (restInTry.get()) {
                        return Collections.emptyList();
                    }
                    if (statement instanceof J.MethodInvocation && MOCKITO_WHEN.matches(((J.MethodInvocation)statement).getSelect()) && (when = (J.MethodInvocation)((J.MethodInvocation)statement).getSelect()) != null && when.getArguments().get(0) instanceof J.MethodInvocation && (whenArg = (J.MethodInvocation)when.getArguments().get(0)).getMethodType() != null && whenArg.getMethodType().hasFlags(new Flag[]{Flag.Static})) {
                        J.Identifier clazz;
                        J.FieldAccess fieldAccess;
                        if (whenArg.getSelect() instanceof J.Identifier) {
                            J.Identifier clazz2 = (J.Identifier)whenArg.getSelect();
                            if (clazz2.getType() != null) {
                                return this.tryWithMockedStatic(m, remainingStatements, (Integer)index, (Statement)statement, clazz2.getSimpleName(), whenArg, restInTry);
                            }
                        } else if (whenArg.getSelect() instanceof J.FieldAccess && (fieldAccess = (J.FieldAccess)whenArg.getSelect()).getTarget() instanceof J.Identifier && (clazz = (J.Identifier)fieldAccess.getTarget()).getType() != null) {
                            return this.tryWithMockedStatic(m, remainingStatements, (Integer)index, (Statement)statement, clazz.getSimpleName(), whenArg, restInTry);
                        }
                    }
                    return statement;
                });
            }

            private J.Try tryWithMockedStatic(J.MethodDeclaration m, List<Statement> remainingStatements, Integer index, Statement statement, String simpleName, J.MethodInvocation whenArg, AtomicBoolean restInTry) {
                String mockName = VariableNameUtils.generateVariableName((String)("mock" + simpleName), (Cursor)this.updateCursor((Tree)m), (VariableNameUtils.GenerationStrategy)VariableNameUtils.GenerationStrategy.INCREMENT_NUMBER);
                this.maybeAddImport("org.mockito.MockedStatic", false);
                this.maybeAddImport("org.mockito.Mockito", "mockStatic");
                String template = String.format("try(MockedStatic<%1$s> %2$s = mockStatic(%1$s.class)) {\n    %2$s.when(#{any()}).thenReturn(#{any()});\n}", simpleName, mockName);
                J.Try try_ = (J.Try)((J.MethodDeclaration)JavaTemplate.builder((String)template).contextSensitive().imports(new String[]{"org.mockito.MockedStatic"}).staticImports(new String[]{"org.mockito.Mockito.mockStatic"}).build().apply(this.getCursor(), m.getCoordinates().replaceBody(), new Object[]{whenArg, ((J.MethodInvocation)statement).getArguments().get(0)})).getBody().getStatements().get(0);
                restInTry.set(true);
                List<Statement> precedingStatements = remainingStatements.subList(0, index);
                return try_.withBody(try_.getBody().withStatements(ListUtils.concatAll((List)try_.getBody().getStatements(), this.maybeWrapStatementsInTryWithResourcesMockedStatic(m.withBody(m.getBody().withStatements(ListUtils.concat(precedingStatements, (Object)try_))), remainingStatements.subList(index + 1, remainingStatements.size()))))).withPrefix(statement.getPrefix());
            }
        });
    }
}

