/*
 * Decompiled with CFR 0.152.
 */
package org.openrewrite.java.security;

import org.openrewrite.Cursor;
import org.openrewrite.ExecutionContext;
import org.openrewrite.Recipe;
import org.openrewrite.TreeVisitor;
import org.openrewrite.java.JavaIsoVisitor;
import org.openrewrite.java.JavaVisitor;
import org.openrewrite.java.MethodMatcher;
import org.openrewrite.java.search.UsesType;
import org.openrewrite.java.tree.J;
import org.openrewrite.java.tree.JavaCoordinates;
import org.openrewrite.java.tree.Statement;
import org.openrewrite.java.tree.TypeUtils;

/**
 * Hardens {@code javax.xml.stream.XMLInputFactory} usages against XML external entity (XXE) attacks.
 * <p>
 * For each class that creates an {@code XMLInputFactory}, the recipe checks whether
 * {@code IS_SUPPORTING_EXTERNAL_ENTITIES} and {@code SUPPORT_DTD} are already passed to {@code setProperty}.
 * It then inserts {@code <factory>.setProperty(XMLInputFactory.<PROPERTY>, false);} for whichever of the two is missing.
 */
public class XmlParserXXEVulnerability
extends Recipe {
    /** Matches any no-arg {@code XMLInputFactory} method whose name starts with {@code new} (e.g. {@code newInstance()}). */
    private static final MethodMatcher XML_PARSER_FACTORY_INSTANCE = new MethodMatcher("javax.xml.stream.XMLInputFactory new*()");
    /** Matches {@code XMLInputFactory.setProperty(String, ..)}. */
    private static final MethodMatcher XML_PARSER_FACTORY_SET_PROPERTY = new MethodMatcher("javax.xml.stream.XMLInputFactory setProperty(java.lang.String, ..)");
    private static final String XML_FACTORY_FQN = "javax.xml.stream.XMLInputFactory";
    /** Simple name of the XMLInputFactory constant; also used as a cursor message key. */
    private static final String SUPPORTING_EXTERNAL_ENTITIES_PROPERTY_NAME = "IS_SUPPORTING_EXTERNAL_ENTITIES";
    /** Simple name of the XMLInputFactory constant; also used as a cursor message key. */
    private static final String SUPPORT_DTD_PROPERTY_NAME = "SUPPORT_DTD";
    /** String value of {@code XMLInputFactory.IS_SUPPORTING_EXTERNAL_ENTITIES}, for calls that pass the literal. */
    private static final String SUPPORTING_EXTERNAL_ENTITIES_PROPERTY_VALUE = "javax.xml.stream.isSupportingExternalEntities";
    /** String value of {@code XMLInputFactory.SUPPORT_DTD}, for calls that pass the literal. */
    private static final String SUPPORT_DTD_PROPERTY_VALUE = "javax.xml.stream.supportDTD";
    /** Cursor message key: block cursor enclosing the factory creation call. */
    private static final String XML_PARSER_INITIALIZATION_METHOD = "xml-parser-initialization-method";
    /** Cursor message key: simple name of a variable typed as XMLInputFactory. */
    private static final String XML_FACTORY_VARIABLE_NAME = "xml-factory-variable-name";

    public String getDisplayName() {
        return "XML parser XXE vulnerability";
    }

    public String getDescription() {
        return "Avoid exposing dangerous features of the XML parser by setting XMLInputFactory `IS_SUPPORTING_EXTERNAL_ENTITIES` and `SUPPORT_DTD` properties to `false`.";
    }

    /** Only visit source files that reference {@code javax.xml.stream.XMLInputFactory}. */
    protected JavaVisitor<ExecutionContext> getSingleSourceApplicableTest() {
        return new UsesType<ExecutionContext>(XML_FACTORY_FQN);
    }

    /**
     * Returns a visitor that gathers facts per class declaration, then schedules an insertion.
     * The facts are: where the factory is created, its variable name, and which properties are already set.
     * The scheduled {@link XmlFactoryInsertPropertyStatementVisitor} adds the missing {@code setProperty} calls.
     */
    protected TreeVisitor<?, ExecutionContext> getVisitor() {
        return new JavaIsoVisitor<ExecutionContext>() {

            @Override
            public J.ClassDeclaration visitClassDeclaration(J.ClassDeclaration classDecl, ExecutionContext executionContext) {
                J.ClassDeclaration cd = super.visitClassDeclaration(classDecl, executionContext);
                // Child visits posted these messages onto this class declaration's cursor.
                Cursor supportsExternalCursor = (Cursor) getCursor().getMessage(SUPPORTING_EXTERNAL_ENTITIES_PROPERTY_NAME);
                Cursor supportsDTDCursor = (Cursor) getCursor().getMessage(SUPPORT_DTD_PROPERTY_NAME);
                Cursor initializationCursor = (Cursor) getCursor().getMessage(XML_PARSER_INITIALIZATION_METHOD);
                String xmlFactoryVariableName = (String) getCursor().getMessage(XML_FACTORY_VARIABLE_NAME);

                boolean needsExternalEntitiesDisabled = supportsExternalCursor == null;
                boolean needsSupportsDtdDisabled = supportsDTDCursor == null;

                // Pick the block to insert into:
                // - neither property set: the block that creates the factory;
                // - exactly one set: the block that already sets the other one;
                // - both set: nothing to do.
                Cursor setPropertyBlockCursor;
                if (needsExternalEntitiesDisabled && needsSupportsDtdDisabled) {
                    setPropertyBlockCursor = initializationCursor;
                } else if (needsExternalEntitiesDisabled) {
                    setPropertyBlockCursor = supportsDTDCursor;
                } else if (needsSupportsDtdDisabled) {
                    setPropertyBlockCursor = supportsExternalCursor;
                } else {
                    setPropertyBlockCursor = null;
                }

                if (setPropertyBlockCursor != null && xmlFactoryVariableName != null) {
                    doAfterVisit(new XmlFactoryInsertPropertyStatementVisitor(
                            (J.Block) setPropertyBlockCursor.getValue(),
                            xmlFactoryVariableName,
                            needsExternalEntitiesDisabled,
                            needsSupportsDtdDisabled));
                }
                return cd;
            }

            @Override
            public J.VariableDeclarations.NamedVariable visitVariable(J.VariableDeclarations.NamedVariable variable, ExecutionContext executionContext) {
                J.VariableDeclarations.NamedVariable v = super.visitVariable(variable, executionContext);
                if (TypeUtils.isOfClassType(v.getType(), XML_FACTORY_FQN)) {
                    // If several factory variables exist in one class, the last one visited wins.
                    getCursor().putMessageOnFirstEnclosing(J.ClassDeclaration.class, XML_FACTORY_VARIABLE_NAME, v.getSimpleName());
                }
                return v;
            }

            @Override
            public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext executionContext) {
                J.MethodInvocation m = super.visitMethodInvocation(method, executionContext);
                if (XML_PARSER_FACTORY_INSTANCE.matches(m)) {
                    recordEnclosingBlock(XML_PARSER_INITIALIZATION_METHOD);
                } else if (XML_PARSER_FACTORY_SET_PROPERTY.matches(m)) {
                    // NOTE(review): the property's value is not inspected, so setProperty(SUPPORT_DTD, true)
                    // also counts as "already set". Consider checking for a literal false.
                    String propertyName = propertyNameOf(m.getArguments().get(0));
                    if (propertyName != null) {
                        recordEnclosingBlock(propertyName);
                    }
                }
                return m;
            }

            /** Stores the cursor of the nearest enclosing block under {@code key} on the enclosing class declaration. */
            private void recordEnclosingBlock(String key) {
                getCursor().putMessageOnFirstEnclosing(J.ClassDeclaration.class, key, getCursor().dropParentUntil(J.Block.class::isInstance));
            }
        };
    }

    /**
     * Maps the first argument of a {@code setProperty} call to one of the two XXE-related constant names.
     * Accepts {@code XMLInputFactory.X}, a statically imported {@code X}, or the constant's string literal.
     *
     * @param argument the property-name argument
     * @return {@link #SUPPORTING_EXTERNAL_ENTITIES_PROPERTY_NAME}, {@link #SUPPORT_DTD_PROPERTY_NAME}, or {@code null} when unrecognized
     */
    private static String propertyNameOf(J argument) {
        if (argument instanceof J.Literal) {
            Object value = ((J.Literal) argument).getValue();
            if (SUPPORTING_EXTERNAL_ENTITIES_PROPERTY_VALUE.equals(value)) {
                return SUPPORTING_EXTERNAL_ENTITIES_PROPERTY_NAME;
            }
            if (SUPPORT_DTD_PROPERTY_VALUE.equals(value)) {
                return SUPPORT_DTD_PROPERTY_NAME;
            }
            return null;
        }
        String simpleName;
        if (argument instanceof J.FieldAccess) {
            simpleName = ((J.FieldAccess) argument).getSimpleName();
        } else if (argument instanceof J.Identifier) {
            simpleName = ((J.Identifier) argument).getSimpleName();
        } else {
            return null;
        }
        if (SUPPORTING_EXTERNAL_ENTITIES_PROPERTY_NAME.equals(simpleName) || SUPPORT_DTD_PROPERTY_NAME.equals(simpleName)) {
            return simpleName;
        }
        return null;
    }

    /**
     * Inserts the missing {@code setProperty(..., false)} statements into a single target block.
     * The target is identified by {@link J.Block#isScope}.
     */
    private static class XmlFactoryInsertPropertyStatementVisitor
    extends JavaIsoVisitor<ExecutionContext> {
        private final J.Block scope;
        /** Statements to insert; built once so repeated visits do not alter it. */
        private final String propertyTemplate;

        public XmlFactoryInsertPropertyStatementVisitor(J.Block scope, String factoryVariableName, boolean needsExternalEntitiesDisabled, boolean needsSupportsDtdDisabled) {
            this.scope = scope;
            StringBuilder template = new StringBuilder();
            if (needsExternalEntitiesDisabled) {
                template.append(factoryVariableName).append(".setProperty(XMLInputFactory.IS_SUPPORTING_EXTERNAL_ENTITIES, false);");
            }
            if (needsSupportsDtdDisabled) {
                template.append(factoryVariableName).append(".setProperty(XMLInputFactory.SUPPORT_DTD, false);");
            }
            this.propertyTemplate = template.toString();
        }

        @Override
        public J.Block visitBlock(J.Block block, ExecutionContext executionContext) {
            J.Block b = super.visitBlock(block, executionContext);
            if (!b.isScope(scope)) {
                return b;
            }

            // Insert right after the first statement that creates the factory or calls setProperty.
            // The last statement is not examined; when nothing earlier matches, append at the end.
            Statement beforeStatement = null;
            int lastCandidate = b.getStatements().size() - 2;
            for (int i = 0; i <= lastCandidate; i++) {
                if (isFactorySetupStatement(b.getStatements().get(i))) {
                    beforeStatement = b.getStatements().get(i + 1);
                    break;
                }
            }

            // A class body cannot hold bare statements, so wrap them in an initializer block there.
            String code = propertyTemplate;
            Cursor parent = getCursor().getParent();
            if (parent != null && parent.getValue() instanceof J.ClassDeclaration) {
                code = "{\n" + code + "}";
            }

            JavaCoordinates insertCoordinates = beforeStatement != null
                    ? beforeStatement.getCoordinates().before()
                    : b.getCoordinates().lastStatement();
            return b.withTemplate(template(code).build(), insertCoordinates);
        }

        /**
         * Returns whether {@code statement} is one of the following:
         * a bare factory-creation or {@code setProperty} call, or a variable declaration
         * whose first variable is initialized by a factory-creation call.
         */
        private static boolean isFactorySetupStatement(Statement statement) {
            if (statement instanceof J.MethodInvocation) {
                J.MethodInvocation invocation = (J.MethodInvocation) statement;
                return XML_PARSER_FACTORY_INSTANCE.matches(invocation) || XML_PARSER_FACTORY_SET_PROPERTY.matches(invocation);
            }
            if (statement instanceof J.VariableDeclarations) {
                J initializer = ((J.VariableDeclarations) statement).getVariables().get(0).getInitializer();
                return initializer instanceof J.MethodInvocation
                        && XML_PARSER_FACTORY_INSTANCE.matches((J.MethodInvocation) initializer);
            }
            return false;
        }
    }
}

