package com.atlassian.multitenant.servlet;

import com.atlassian.multitenant.MultiTenantContext;
import com.atlassian.multitenant.Tenant;
import org.apache.log4j.Logger;

import java.io.IOException;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;

/**
 * Servlet filter that resolves the tenant for each HTTP request and makes it the current tenant
 * (via {@link MultiTenantContext#getTenantReference()}) while the rest of the filter chain runs.
 * <p>
 * Requests for which no tenant can be found are rejected with {@code 404}. Requests whose HTTP session is
 * already tagged with a different tenant are rejected with {@code 403}, and that session is invalidated.
 * Non-HTTP requests are passed through unchanged.
 */
public class MultiTenantServletFilter implements Filter
{
    private static final Logger log = Logger.getLogger(MultiTenantServletFilter.class);

    /**
     * Session attribute under which the name of the tenant that owns an HTTP session is stored.
     * <p>
     * Even though this is public you should almost certainly be using {@link #getTenantName(HttpSession)}
     * rather than reading this attribute directly.
     */
    public static final String TENANT_SESSION_KEY = "multitenant.tenant";

    /**
     * No initialisation is required; this filter has no configuration.
     *
     * @param filterConfig ignored
     */
    public void init(final FilterConfig filterConfig) throws ServletException
    {
    }

    /**
     * Instead of using the public field {@link #TENANT_SESSION_KEY} you should use this method. It makes it clear
     * what you actually get from the session and allows easier future refactoring.
     *
     * @param session the session to read the tenant name from
     * @return the name of the tenant the session is associated with
     * @throws IllegalStateException if the session has no tenant attribute
     */
    public static String getTenantName(final HttpSession session) throws IllegalStateException
    {
        final Object attribute = session.getAttribute(TENANT_SESSION_KEY);
        if (attribute == null)
        {
            throw new IllegalStateException("No tenant found in session.");
        }
        return (String) attribute;
    }

    /**
     * Resolves the tenant for an HTTP request and runs the rest of the chain with that tenant set.
     * <p>
     * Steps:
     * <ol>
     * <li>If no tenant matches the request, a {@code 404} is sent and the chain is not invoked.</li>
     * <li>If the request's existing session belongs to another tenant, the session is invalidated, a {@code 403}
     * is sent and the chain is not invoked.</li>
     * <li>Otherwise the tenant is set on the tenant reference for the duration of the chain and removed afterwards,
     * even if the chain throws.</li>
     * </ol>
     * Non-HTTP requests are passed straight down the chain.
     *
     * @param request the incoming request
     * @param response the outgoing response; assumed to be an {@link HttpServletResponse} whenever
     *        {@code request} is an {@link HttpServletRequest}
     * @param filterChain the remaining filters / target servlet
     * @throws IOException if sending an error or a downstream filter fails
     * @throws ServletException if a downstream filter fails
     */
    public void doFilter(final ServletRequest request, final ServletResponse response, final FilterChain filterChain)
            throws IOException, ServletException
    {
        if (!(request instanceof HttpServletRequest))
        {
            // Tenants are matched from HTTP request data, so there is nothing to do for other request types.
            filterChain.doFilter(request, response);
            return;
        }

        final HttpServletRequest httpRequest = (HttpServletRequest) request;
        final HttpServletResponse httpResponse = (HttpServletResponse) response;

        final Tenant tenant = resolveTenant(httpRequest, httpResponse);
        if (tenant == null)
        {
            httpResponse.sendError(HttpServletResponse.SC_NOT_FOUND, "No tenant found to handle request, please check your hostname");
            return;
        }

        // Only inspect an existing session; never create one just to perform this check.
        if (!bindSessionToTenant(httpRequest.getSession(false), tenant))
        {
            httpResponse.sendError(HttpServletResponse.SC_FORBIDDEN, "This session was already associated with another tenant");
            return;
        }

        // NOTE(review): the meaning of the boolean passed to set(...) is defined by the tenant reference
        // implementation, which is not visible here.
        MultiTenantContext.getTenantReference().set(tenant, false);
        try
        {
            filterChain.doFilter(request, response);
        }
        finally
        {
            // Always clear the tenant so it cannot leak into whatever handles the next request.
            MultiTenantContext.getTenantReference().remove();
        }
    }

    /**
     * Determines which tenant should handle the request.
     *
     * @param request the HTTP request
     * @param response the HTTP response, passed through to the tenant matcher
     * @return the system tenant in single-tenant mode, otherwise whatever the matcher returns (may be {@code null})
     */
    private static Tenant resolveTenant(final HttpServletRequest request, final HttpServletResponse response)
    {
        if (MultiTenantContext.getManager().isSingleTenantMode())
        {
            return MultiTenantContext.getSystemTenant();
        }
        return MultiTenantContext.getMatcher().getTenantForRequest(request, response);
    }

    /**
     * Ensures the session (if there is one) is not associated with a tenant other than {@code tenant}.
     * <p>
     * An untagged session is tagged with the tenant's name, and a warning is logged because this indicates
     * the session listener is not doing its job. A session tagged with a different tenant is invalidated,
     * because the data cached in it belongs to that other tenant.
     *
     * @param session the request's existing session, or {@code null} if there is none
     * @param tenant the tenant resolved for the current request
     * @return {@code true} if the request may proceed; {@code false} if the session belonged to another tenant
     *         (the session has then already been invalidated)
     */
    private static boolean bindSessionToTenant(final HttpSession session, final Tenant tenant)
    {
        if (session == null)
        {
            return true;
        }

        final String owner = (String) session.getAttribute(TENANT_SESSION_KEY);
        if (owner == null)
        {
            log.warn("Session found without a tenant, is the MultiTenantSessionListener configured? If not, this instance is vulnerable to session fixation.");
            session.setAttribute(TENANT_SESSION_KEY, tenant.getName());
            return true;
        }
        if (owner.equals(tenant.getName()))
        {
            return true;
        }

        // The data cached in the session belongs to another tenant, so discard the whole session.
        session.invalidate();
        return false;
    }

    /**
     * No resources are held, so there is nothing to release.
     */
    public void destroy()
    {
    }
}
