/*
 * All content copyright (c) 2003-2012 Terracotta, Inc., except as may otherwise be noted in a separate copyright
 * notice. All rights reserved.
 */
package com.terracotta.management.servlet;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.terracotta.management.BuildInfo;

import javax.servlet.ServletConfig;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.HttpURLConnection;
import java.net.URL;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;

/**
 * Servlet for making TC branding / advertising available to the TMC.
 * <p>
 * The servlet downloads a plain-text list of ad URLs (one per line, lines starting with {@code #} are
 * ignored), fetches the content behind each URL, caches that content in memory and serves the cached ads
 * one after another on each GET request. The cache is refreshed at most once per {@link #AD_CHECK_INTERVAL}.
 * When no ad content is available the request is forwarded to the static page {@link #DEFAULT_AD}.
 * <p>
 * A servlet container may call {@code doGet} from several threads on the same instance, so the cached lists
 * are never modified in place: a new list is built locally and then published through a volatile field.
 * 
 * @author jhouse
 */
public final class AdServlet extends HttpServlet {
  private static final Logger LOG = LoggerFactory.getLogger(AdServlet.class);

  /** Name of the optional servlet init parameter that overrides the ad list location. */
  private static final String AD_LIST_LOCATION_PARAM_NAME = "adListUrl";
  /** Ad list location used when the init parameter is not set. */
  private static final String AD_LIST_LOCATION = "http://www.terracotta.org/tmc/tmcadlist.txt";
  /** Query parameter carrying the TMC version that is appended to the ad list location. */
  private static final String TMC_VERSION_PARAM = "tmc-version";
  /** Context-relative path of the static page served when no ad is available. */
  private static final String DEFAULT_AD = "/tcinfo.html";
  
  private static final long AD_CHECK_INTERVAL = 1000L * 60L * 60L * 1L; // hour

  /** Connect and read timeout, in milliseconds, applied to every outgoing HTTP fetch. */
  private static final int FETCH_TIMEOUT_MILLIS = 10000;
  
  private String adListLocation;
  
  private volatile long lastAdCheck = -1;
  /** Index of the next ad to serve; only touched inside {@link #nextAdIndex(int)}. */
  private int adOn = 0;
  private volatile List<String> adURLs;
  private volatile List<String> theAds;
    
  /**
   * Resolves the ad list location from the servlet configuration (falling back to the built-in default)
   * and appends the TMC version as a query parameter.
   *
   * @param config the servlet configuration supplied by the container
   * @throws ServletException if the superclass initialization fails
   */
  @Override
  public void init(ServletConfig config) throws ServletException {
    super.init(config);
    
    String location = config.getInitParameter(AD_LIST_LOCATION_PARAM_NAME);
    if (location == null) {
      location = AD_LIST_LOCATION;
    }

    // Use '&' when the configured location already carries a query string, otherwise start one with '?'.
    String separator = location.indexOf('?') >= 0 ? "&" : "?";
    adListLocation = location + separator + TMC_VERSION_PARAM + "=" + BuildInfo.VERSION_MAJOR + "."
                     + BuildInfo.VERSION_MINOR + "." + BuildInfo.VERSION_ITERATION;
  }

  @Override
  public void destroy() {
    super.destroy();
  }

  /**
   * Serves the next cached ad, refreshing the cache first when it has never been loaded or when the last
   * check is older than {@link #AD_CHECK_INTERVAL}. Falls back to the static default ad when the cache is
   * empty.
   *
   * @param request  the incoming request
   * @param response the response the ad content is written to
   * @throws ServletException if forwarding to the default ad fails
   * @throws IOException      if writing the response fails
   */
  @Override
  protected void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
    long now = System.currentTimeMillis();
    if (theAds == null || now - lastAdCheck > AD_CHECK_INTERVAL) {
      // Record the check time before loading so that requests arriving meanwhile do not start another refresh
      // once a cache exists.
      lastAdCheck = now;
      loadAds();
    }
    
    // Work on a snapshot: another request thread may publish a new list while this one is being served.
    List<String> ads = theAds;
    if (ads != null && !ads.isEmpty()) {
      int adToServe = nextAdIndex(ads.size());
      LOG.info("Serving ad #{}", adToServe + 1);
      response.getWriter().print(ads.get(adToServe));
      response.getWriter().flush();
      response.flushBuffer();
    }
    else {
      // fall-back to static ad
      LOG.info("Serving default ad");
      getServletContext().getRequestDispatcher(DEFAULT_AD).forward(request, response);
    }
  }
  
  @Override
  protected void doHead(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
    super.doHead(request, response);
  }

  /**
   * Always reports the current time as the last-modified time.
   *
   * @param req the incoming request (unused)
   * @return the current time in milliseconds
   */
  @Override
  protected long getLastModified(HttpServletRequest req) {
    return System.currentTimeMillis();
  }

  /**
   * Downloads the ad list and extracts the ad URLs from it. Blank lines, lines starting with {@code #} and
   * lines that are not {@code http://} / {@code https://} URLs are skipped.
   * <p>
   * When the list cannot be fetched at all, both the URL list and the cached ads are reset to empty lists.
   * When the list is fetched but contains no usable URL, the previous state is kept.
   *
   * @return {@code true} if at least one ad URL was found and the URL list was replaced, {@code false} otherwise
   */
  protected boolean loadAdList() {
    LOG.debug("Checking for available ads...");
    String rawAdList = getURLContentAsString(adListLocation);
    if (rawAdList == null) {
      LOG.debug("List of ads not available.");
      adURLs = Collections.emptyList();
      theAds = Collections.emptyList();
      return false;
    }

    List<String> newAdList = new LinkedList<String>();
    for (String line : rawAdList.split("\n")) {
      String adUrl = line.trim();
      if (adUrl.startsWith("#")) {
        continue; // comment line
      }
      // The length check must apply to both schemes; without the parentheses '&&' binds tighter than '||'
      // and a bare "https://" would be accepted.
      if (adUrl.length() > 8 && (adUrl.startsWith("http://") || adUrl.startsWith("https://"))) {
        newAdList.add(adUrl);
        LOG.debug("Found ad location: {}", adUrl);
      }
    }

    if (newAdList.isEmpty()) {
      return false;
    }
    adURLs = newAdList;
    LOG.info("Found " + newAdList.size() + " ad URLs");
    return true; // updated the list
  }
  
  /**
   * Refreshes the ad list and, if it was updated, fetches the content of every ad URL and replaces the
   * cached ads with the non-empty results.
   */
  protected void loadAds() {
    if (!loadAdList()) {
      if (theAds == null) {
        // Nothing usable was found and nothing is cached yet: remember an empty result so that doGet waits
        // for the next check interval instead of re-fetching the list on every request.
        theAds = Collections.emptyList();
      }
      return;
    }

    // Build the new cache locally and publish it in one step so readers never see a half-filled list.
    List<String> loadedAds = new LinkedList<String>();
    for (String adUrl : adURLs) {
      String ad = getURLContentAsString(adUrl);
      if (ad != null && (ad = ad.trim()).length() > 0) {
        loadedAds.add(ad);
        LOG.info("Loaded ad from location: {}", adUrl);
      }
      else {
        LOG.info("No content found at ad location: {}", adUrl);
      }
    }
    theAds = loadedAds;
  }
  
  /**
   * Fetches the given URL with an HTTP GET and returns the response body as text, with every line
   * terminated by {@code "\n"}.
   *
   * @param url the URL to fetch
   * @return the response body, or {@code null} if the response code is not 200 or any error occurs
   */
  protected String getURLContentAsString(String url) {
    HttpURLConnection connection = null;
    BufferedReader reader = null;
    
    try {
      connection = (HttpURLConnection) new URL(url).openConnection();
      
      connection.setRequestMethod("GET");
      // Bound both phases of the fetch; this runs on a request thread.
      connection.setConnectTimeout(FETCH_TIMEOUT_MILLIS);
      connection.setReadTimeout(FETCH_TIMEOUT_MILLIS);
                
      connection.connect();
      
      int responseCode = connection.getResponseCode();
      if (responseCode != HttpURLConnection.HTTP_OK) {
        LOG.info("Could not fetch ad from {} -- got HTTP response code: {}", url, responseCode);
        return null;
      }

      // NOTE(review): the body is decoded with the platform default charset -- confirm whether an explicit
      // charset (e.g. UTF-8) should be used for the ad content.
      reader = new BufferedReader(new InputStreamReader(connection.getInputStream()));
      StringBuilder content = new StringBuilder();
      String line;
      while ((line = reader.readLine()) != null) {
        content.append(line);
        content.append("\n");
      }
        
      LOG.info("Fetched url: {}", url);
      return content.toString();
    }
    catch (Exception e) {
      // Best effort: any failure (bad URL, I/O error, timeout, non-HTTP scheme) simply means "no content".
      LOG.info("Error fetching url: " + url, e);
      return null;
    }
    finally {
      if (reader != null) {
        try {
          reader.close();
        }
        catch (IOException ignored) {
          // nothing useful can be done if closing the reader fails
        }
      }
      if (connection != null) {
        connection.disconnect();
      }
    }
  }

  /**
   * Returns the index of the ad to serve next and advances the round-robin position.
   * The position is kept within {@code [0, adCount)} so it can never overflow into a negative index.
   *
   * @param adCount number of ads currently available; must be greater than zero
   * @return an index in the range {@code [0, adCount)}
   */
  private synchronized int nextAdIndex(int adCount) {
    int index = adOn % adCount;
    adOn = (index + 1) % adCount;
    return index;
  }
}
