package com.atlassian.confluence.compat.struts2.servletactioncontext;

import com.atlassian.confluence.api.service.exceptions.ServiceException;
import com.atlassian.core.filters.ServletContextThreadLocal;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.servlet.ServletConfig;
import javax.servlet.ServletContext;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Collections;
import java.util.Enumeration;

class ServletActionContextStruts2AndWWCompat implements ServletActionContextCompat {

    private static final Logger log = LoggerFactory.getLogger(ServletActionContextStruts2AndWWCompat.class);

    public static final String STRUTS_2_SERVLET_ACTION_CONTEXT = "org.apache.struts2.ServletActionContext";
    private final Method setRequest;
    private final Method getRequest;
    private final Method setResponse;
    private final Method getResponse;
    private Method getServletContext;
    private Method setServletContext;
    private Method getServletConfig;
    private Method setServletConfig;
    private static String sacClass;

    ServletActionContextStruts2AndWWCompat(String sacClass, ClassLoader classLoader) throws ReflectiveOperationException {
        this.sacClass = sacClass;
        setRequest = getSACStruts2Method("setRequest", classLoader, HttpServletRequest.class);
        getRequest = getSACStruts2Method("getRequest", classLoader);
        setResponse = getSACStruts2Method("setResponse", classLoader, HttpServletResponse.class);
        getResponse = getSACStruts2Method("getResponse", classLoader);
        if (sacClass.equals(STRUTS_2_SERVLET_ACTION_CONTEXT)) {
            getServletContext = getSACStruts2Method("getServletContext", classLoader);
            setServletContext = getSACStruts2Method("setServletContext", classLoader, ServletContext.class);
        } else {
            getServletConfig = getSACStruts2Method("getServletConfig", classLoader);
            setServletConfig = getSACStruts2Method("setServletConfig", classLoader, ServletConfig.class);
        }
    }

    @Override
    public void setRequest(HttpServletRequest request) {
        try {
            setRequest.invoke(null, request);
        } catch (ReflectiveOperationException e) {
            throw new ServiceException("Couldn't set ServletActionContext request", e);
        }
    }

    @Override
    public HttpServletRequest getRequest() {
        try {
            return (HttpServletRequest) getRequest.invoke(null) == null ? ServletContextThreadLocal.getRequest() : (HttpServletRequest) getRequest.invoke(null);
        } catch (NullPointerException e) {
            return ServletContextThreadLocal.getRequest();
        } catch (ReflectiveOperationException e) {
            throw new ServiceException("Couldn't get ServletActionContext request", e);
        }
    }

    @Override
    public void setResponse(HttpServletResponse response) {
        try {
            setResponse.invoke(null, response);
        } catch (ReflectiveOperationException e) {
            throw new ServiceException("Couldn't set ServletActionContext response", e);
        }
    }

    @Override
    public HttpServletResponse getResponse() {
        try {
            return (HttpServletResponse) getResponse.invoke(null);
        } catch (ReflectiveOperationException e) {
            throw new ServiceException("Couldn't get ServletActionContext response", e);
        }
    }

    @Override
    public void setServletConfig(ServletConfig config) {
        try {
            if (this.sacClass.equals(STRUTS_2_SERVLET_ACTION_CONTEXT)) {
                setServletContext.invoke(null, config.getServletContext());
            } else {
                setServletConfig.invoke(null, config);
            }
        } catch (ReflectiveOperationException e) {
            throw new ServiceException("Couldn't get ServletActionContext response", e);
        }
    }

    @Override
    public ServletConfig getServletConfig() {
        if (this.sacClass.equals(STRUTS_2_SERVLET_ACTION_CONTEXT)) {
            return new ServletConfig() {

                @Override
                public String getServletName() {
                    return null;
                }

                @Override
                public ServletContext getServletContext() {
                    ServletContext context = null;
                    try {
                        context = (ServletContext) getServletContext.invoke(null);
                    } catch (ReflectiveOperationException e) {
                        throw new ServiceException("Couldn't get ServletActionContext response", e);
                    }
                    return context;
                }

                @Override
                public String getInitParameter(String s) {
                    return null;
                }

                @Override
                public Enumeration<String> getInitParameterNames() {
                    return null;
                }
            };
        } else {
            try {
                return (ServletConfig) getServletConfig.invoke(null);
            } catch (ReflectiveOperationException e) {
                throw new ServiceException("Couldn't get ServletActionContext response", e);
            }
        }
    }

    @Override
    public ServletContext getServletContext() {
        try {
            if (this.sacClass.equals(STRUTS_2_SERVLET_ACTION_CONTEXT)) {
                return (ServletContext) getServletContext.invoke(null);
            } else {
                ServletConfig config = (ServletConfig) getServletConfig.invoke(null);
                return config != null ? config.getServletContext() : null;
            }
        } catch (ReflectiveOperationException e) {
            throw new ServiceException("Couldn't get ServletActionContext response", e);
        }
    }

    private Method getSACStruts2Method(String methodName, ClassLoader classLoader, Class<?>... parameterTypes) throws ReflectiveOperationException {
        return Class.forName(this.sacClass, false, classLoader)
                .getMethod(methodName, parameterTypes);
    }
}
