package com.opensymphony.xwork.util;

import com.google.common.cache.CacheBuilder;
import com.opensymphony.xwork.config.ConfigurationManager;
import ognl.Node;
import ognl.OgnlContext;
import ognl.OgnlException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import javax.lang.model.SourceVersion;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;

import static java.lang.String.format;
import static java.util.Optional.ofNullable;

/**
 * XWork OGNL security handling utility class. It supports blocking double evaluation along with a set of
 * dangerous OGNL node types. It also supports a blocklist and an allowlist to restrict risky Java packages,
 * classes and methods.
 * Note that this class used to live in the Atlassian fork of WebWork, but has been moved here for use by other
 * frameworks.
 * To configure the lists in xwork.xml, set &lt;constant&gt; tags with comma-separated values for
 * xwork.excludedClasses, xwork.excludedPackageNames and xwork.allowedClasses, e.g.:
 * <ul>
 *     <li>&lt;constant name=&quot;xwork.excludedClasses&quot; value=&quot;java.lang.Object,java.lang.Runtime&quot;/&gt;</li>
 *     <li>&lt;constant name=&quot;xwork.excludedPackageNames&quot; value=&quot;ognl,java.io&quot;/&gt;</li>
 *     <li>&lt;constant name=&quot;xwork.allowedClasses&quot; value=&quot;java.math.BigDecimal,java.lang.System&quot;/&gt;</li>
 * </ul>
 * The configured lists are read once, when an instance is constructed. Verdicts are remembered per instance in
 * two size-bounded caches (safe and unsafe expressions).
 */
class SafeExpressionUtil {
    private static final Set<String> UNSAFE_VARIABLE_NAMES;
    private static final Set<String> UNSAFE_NODE_TYPES;
    private static final Optional<Method> OGNL_METHOD_GET_METHOD;
    private static final Optional<Field> OGNL_METHOD_GET_CLASS_STATIC_FIELD;

    // Declared before the static initializers below because they log reflection failures.
    private static final Log log = LogFactory.getLog(SafeExpressionUtil.class);

    // Per-instance verdict caches keyed by the raw expression text; entries may be evicted once full.
    private final Set<String> SAFE_EXPRESSIONS_CACHE = Collections.newSetFromMap(
            CacheBuilder.newBuilder()
                    .maximumSize(10000)
                    .<String, Boolean>build().asMap());
    private final Set<String> UNSAFE_EXPRESSIONS_CACHE = Collections.newSetFromMap(
            CacheBuilder.newBuilder()
                    .maximumSize(1000)
                    .<String, Boolean>build().asMap());

    private final Set<String> unsafePropertyNames;
    private final Set<String> unsafePackageNames;
    private final Set<String> unsafeMethodNames;
    private final Set<String> allowedClassNames;

    // OGNL node classes that are rejected wherever they appear in a parsed expression.
    static {
        final Set<String> set = new HashSet<>();
        set.add("ognl.ASTStaticField");
        set.add("ognl.ASTCtor");
        set.add("ognl.ASTAssign");
        UNSAFE_NODE_TYPES = Collections.unmodifiableSet(set);
    }

    // Context variable references (as rendered by ASTVarRef#toString) that are rejected.
    static {
        final Set<String> set = new HashSet<>();
        set.add("#" + OgnlContext.MEMBER_ACCESS_CONTEXT_KEY);
        set.add("#" + OgnlContext.CONTEXT_CONTEXT_KEY);
        set.add("#request");
        set.add("#parameters");
        set.add("#session");
        set.add("#application");
        set.add("#attr");
        UNSAFE_VARIABLE_NAMES = Collections.unmodifiableSet(set);
    }

    // Reflective handle on ognl.ASTStaticMethod#className, used to read the target class of a static call.
    static {
        Field classNameField;
        try {
            final Class<?> astStaticMethodClass = Class.forName("ognl.ASTStaticMethod");
            classNameField = astStaticMethodClass.getDeclaredField("className");
            classNameField.setAccessible(true);
        } catch (Exception e) {
            // Without the field the class name resolves to null, which never matches a configured allowlist entry.
            classNameField = null;
            log.debug("Cannot access ognl.ASTStaticMethod.className via reflection", e);
        }
        OGNL_METHOD_GET_CLASS_STATIC_FIELD = ofNullable(classNameField);
    }

    // Reflective handle on ognl.ASTMethod#getMethodName, used to read the name of an invoked method.
    static {
        Method method;
        try {
            final Class<?> aClass = Class.forName("ognl.ASTMethod");
            method = aClass.getMethod("getMethodName");
            method.setAccessible(true);
        } catch (Exception e) {
            // Without the accessor the method name resolves to null, so the method-name blocklist cannot match.
            method = null;
            log.warn("Cannot access ognl.ASTMethod.getMethodName via reflection; method name checks are inactive", e);
        }
        OGNL_METHOD_GET_METHOD = ofNullable(method);
    }

    /**
     * Creates a checker whose blocklists and allowlist are snapshotted from the current configuration.
     */
    public SafeExpressionUtil() {
        this.unsafePropertyNames = getUnsafePropertyNames();
        this.unsafePackageNames = getUnsafePackageNames();
        this.unsafeMethodNames = getUnsafeMethodNames();
        this.allowedClassNames = getAllowedClassNames();
    }

    /**
     * @return the configured excluded classes plus the built-in "class" and "classLoader" names
     */
    private Set<String> getUnsafePropertyNames() {
        final Set<String> set = new HashSet<>(ConfigurationManager.getConfiguration().getExcludedClasses());
        set.add("class");
        set.add("classLoader");
        return Collections.unmodifiableSet(set);
    }

    /**
     * @return a read-only copy of the configured excluded package names
     */
    private Set<String> getUnsafePackageNames() {
        final Set<String> blockedPackages =
                new HashSet<>(ConfigurationManager.getConfiguration().getExcludedPackageNames());
        return Collections.unmodifiableSet(blockedPackages);
    }

    /**
     * @return the built-in set of method names that must never be invoked from an expression
     */
    private Set<String> getUnsafeMethodNames() {
        final Set<String> set = new HashSet<>();
        set.add("getClass");
        set.add("getClassLoader");
        return Collections.unmodifiableSet(set);
    }

    /**
     * @return a read-only copy of the configured allowed class names
     */
    private Set<String> getAllowedClassNames() {
        final Set<String> allowed = new HashSet<>(ConfigurationManager.getConfiguration().getAllowedClasses());
        return Collections.unmodifiableSet(allowed);
    }

    /**
     * Returns true if the given expression is considered "safe".
     * An expression that OGNL fails to parse is also reported as safe.
     *
     * @param expression the OGNL expression text to check
     * @return {@code true} when no unsafe construct was found, {@code false} otherwise
     */
    public boolean isSafeExpression(String expression) {
        return isSafeExpressionInternal(expression, new HashSet<>());
    }

    /**
     * Computes (or looks up) the verdict for one expression.
     *
     * @param expression         the expression text to check
     * @param visitedExpressions constant values already examined during this top-level check; guards the
     *                           recursive constant evaluation against revisiting the same text
     * @return {@code true} when the expression is considered safe
     */
    private boolean isSafeExpressionInternal(String expression, Set<String> visitedExpressions) {
        if (SAFE_EXPRESSIONS_CACHE.contains(expression)) {
            return true;
        }
        if (UNSAFE_EXPRESSIONS_CACHE.contains(expression)) {
            return false;
        }
        if (isUnSafeClass(expression)) {
            UNSAFE_EXPRESSIONS_CACHE.add(expression);
            return false;
        }

        // A bare, explicitly allowed class name needs no parsing.
        final String unquoted = trimQuotes(expression);
        if (SourceVersion.isName(unquoted) && allowedClassNames.contains(unquoted)) {
            SAFE_EXPRESSIONS_CACHE.add(expression);
            return true;
        }

        // Verdicts are returned directly rather than re-read from the caches, whose entries can be evicted.
        try {
            final Object parsedExpression = OgnlUtil.compile(expression);
            if (!(parsedExpression instanceof Node)) {
                // Nothing to inspect; not cached and not reported as safe.
                return false;
            }
            if (containsUnsafeExpression((Node) parsedExpression, visitedExpressions)) {
                UNSAFE_EXPRESSIONS_CACHE.add(expression);
                log.debug(format("Unsafe clause found in [\" %s \"]", expression));
                return false;
            }
            SAFE_EXPRESSIONS_CACHE.add(expression);
            return true;
        } catch (OgnlException | RuntimeException ex) {
            // Existing policy: an expression that cannot be parsed/inspected is recorded as safe.
            SAFE_EXPRESSIONS_CACHE.add(expression);
            log.debug("Cannot verify safety of OGNL expression", ex);
            return true;
        }
    }

    /**
     * Walks the parsed tree depth-first and reports whether any node is an unsafe construct.
     *
     * @param node               the node to inspect (together with its children)
     * @param visitedExpressions constant values already examined during this top-level check
     * @return {@code true} as soon as one unsafe node is found
     */
    private boolean containsUnsafeExpression(Node node, Set<String> visitedExpressions) {
        final String nodeClassName = node.getClass().getName();

        if (UNSAFE_NODE_TYPES.contains(nodeClassName)) {
            return true;
        } else if ("ognl.ASTStaticMethod".equals(nodeClassName) && !allowedClassNames.contains(getClassNameFromStaticMethod(node))) {
            // Static calls are allowed only on explicitly allowlisted classes.
            return true;
        } else if ("ognl.ASTProperty".equals(nodeClassName) && isUnSafeClass(node.toString())) {
            return true;
        } else if ("ognl.ASTMethod".equals(nodeClassName) && unsafeMethodNames.contains(getMethodInOgnlExp(node))) {
            return true;
        } else if ("ognl.ASTVarRef".equals(nodeClassName) && UNSAFE_VARIABLE_NAMES.contains(node.toString())) {
            return true;
        } else if ("ognl.ASTConst".equals(nodeClassName) && !isSafeConstantExpressionNode(node, visitedExpressions)) {
            return true;
        } else {
            for (int i = 0; i < node.jjtGetNumChildren(); i++) {
                final Node childNode = node.jjtGetChild(i);
                if (childNode != null && containsUnsafeExpression(childNode, visitedExpressions)) {
                    return true;
                }
            }
            return false;
        }
    }

    /**
     * Checks a constant node by treating its textual value as an expression in its own right, so that an
     * unsafe expression hidden inside a string constant is still detected.
     *
     * @param node               the constant node
     * @param visitedExpressions constant values already examined; updated with this node's value
     * @return {@code true} when the constant is null, empty, already visited, cannot be evaluated, or its
     *         value is itself a safe expression
     */
    private boolean isSafeConstantExpressionNode(Node node, Set<String> visitedExpressions) {
        try {
            final Object constant = node.getValue(new OgnlContext(), null);
            if (constant == null) {
                // e.g. the OGNL null literal: nothing to re-evaluate. Checking this before toString() avoids a
                // NullPointerException that would abort the walk and mark the whole expression as safe.
                return true;
            }
            final String value = constant.toString();
            if (value == null || value.isEmpty() || visitedExpressions.contains(value)) {
                return true;
            }
            visitedExpressions.add(value);
            return isSafeExpressionInternal(value, visitedExpressions);
        } catch (OgnlException e) {
            log.debug("Cannot verify safety of OGNL expression", e);
            return true;
        }
    }

    /**
     * @param node an ognl.ASTStaticMethod node
     * @return the class name targeted by the static call, or {@code null} when it cannot be read reflectively
     */
    private static String getClassNameFromStaticMethod(Node node) {
        try {
            if (OGNL_METHOD_GET_CLASS_STATIC_FIELD.isPresent()) {
                return (String) OGNL_METHOD_GET_CLASS_STATIC_FIELD.get().get(node);
            }
        } catch (IllegalAccessException e) {
            log.debug("Method can't be accessed for introspection", e);
        }
        return null;
    }

    /**
     * @param node an ognl.ASTMethod node
     * @return the invoked method name, or {@code null} when it cannot be read reflectively
     */
    private static String getMethodInOgnlExp(Node node) {
        try {
            if (OGNL_METHOD_GET_METHOD.isPresent()) {
                return (String) OGNL_METHOD_GET_METHOD.get().invoke(node);
            }
        } catch (IllegalAccessException | InvocationTargetException e) {
            log.debug("Method can't be accessed for introspection", e);
        }
        return null;
    }

    /**
     * Strips surrounding whitespace and any number of matching outer quote pairs (double or single).
     *
     * @param value the text to unwrap
     * @return the trimmed text with all enclosing quote pairs removed
     */
    private String trimQuotes(String value) {
        String result = value.trim();
        while (isQuoted(result, '"') || isQuoted(result, '\'')) {
            result = result.substring(1, result.length() - 1).trim();
        }
        return result;
    }

    /**
     * Tells whether the text both starts and ends with the given quote character. At least two characters
     * are required so that a lone quote is not treated as an (un-strippable) quoted string.
     */
    private static boolean isQuoted(String text, char quote) {
        return text.length() >= 2 && text.charAt(0) == quote && text.charAt(text.length() - 1) == quote;
    }

    /**
     * Checks the unquoted text against the blocked names and, when it is a syntactically valid qualified
     * name, against the blocked packages (including every parent package).
     *
     * @param expression the expression or property text to check
     * @return {@code true} when the text names a blocked class/property or lives in a blocked package
     */
    private boolean isUnSafeClass(String expression) {
        final String trimmedClassName = trimQuotes(expression);
        if (unsafePropertyNames.contains(trimmedClassName)) {
            return true;
        }
        if (SourceVersion.isName(trimmedClassName)) {
            final List<String> parentPackageNames = populateParentPackages(trimmedClassName, new ArrayList<>());
            return parentPackageNames.stream().anyMatch(unsafePackageNames::contains);
        }
        return false;
    }

    /**
     * Collects every parent package of a dotted name, most specific first.
     *
     * @param name of a class for direct call
     * @param packages list to append the parent packages to
     * @return the same {@code packages} list with all parent package names appended.
     * Ex: "com.atlassian.foo.bar.SomeClass" yields
     * ["com.atlassian.foo.bar", "com.atlassian.foo", "com.atlassian", "com"]
     */
    private List<String> populateParentPackages(String name, List<String> packages) {
        int dotPos = name.lastIndexOf('.');
        while (dotPos != -1) {
            packages.add(name.substring(0, dotPos));
            // Continue with the next dot to the left; a negative start index yields -1 and ends the loop.
            dotPos = name.lastIndexOf('.', dotPos - 1);
        }
        return packages;
    }
}
