/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package com.terracotta.management.security.shiro.realm;

import org.apache.shiro.realm.Realm;
import org.apache.shiro.realm.ldap.LdapUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.naming.NamingEnumeration;
import javax.naming.NamingException;
import javax.naming.directory.Attribute;
import javax.naming.directory.Attributes;
import javax.naming.directory.SearchControls;
import javax.naming.directory.SearchResult;
import javax.naming.ldap.LdapContext;
import java.util.*;

/**
 * A {@link Realm} that authenticates with an Active Directory LDAP server to determine the roles for a particular
 * user. This implementation queries for the user's groups and then maps the group names to roles using the
 * {@link #groupRolesMap}.
 *
 * <p>This Terracotta version is a fork of Shiro's
 * {@code org.apache.shiro.realm.activedirectory.ActiveDirectoryRealm}; it extends {@link LdapRealm} instead of
 * Shiro's {@code AbstractLdapRealm}. Unlike the Shiro original, a group is identified by the {@code CN} values of its
 * distinguished name (not by the full DN), and only groups whose {@code DC} components match the {@code DC}
 * components of the configured search base are taken into account.
 */
public class ActiveDirectoryRealm extends LdapRealm {

  private static final Logger log = LoggerFactory.getLogger(ActiveDirectoryRealm.class);

  protected static final String CN = "CN";

  /** RDN type of a domain component in a distinguished name. */
  private static final String DC = "DC";

  /** Attribute of a user entry that lists the DNs of the groups the user belongs to. */
  private static final String MEMBER_OF = "memberOf";

  /**
   * Searches {@code searchBase} (whole subtree) for entries whose {@code CN} equals {@code username} and translates
   * the values of their {@code memberOf} attribute into role names.
   *
   * @param username    the user's common name
   * @param ldapContext the context used to run the search
   * @return the role names found for the user; never {@code null}, possibly empty
   * @throws NamingException if the LDAP search or attribute enumeration fails
   */
  @Override
  protected Set<String> getRoleNamesForUser(String username, LdapContext ldapContext) throws NamingException {
    Set<String> roleNames = new LinkedHashSet<String>();

    SearchControls searchCtls = new SearchControls();
    searchCtls.setSearchScope(SearchControls.SUBTREE_SCOPE);

    // SHIRO-115 - prevent LDAP filter injection: the username is supplied as a filter argument ({0}) so that JNDI
    // escapes it, instead of being concatenated into the filter string.
    String searchFilter = "(&(objectClass=*)(" + CN + "={0}))";
    Object[] searchArguments = new Object[]{ username };

    NamingEnumeration<SearchResult> answer = ldapContext.search(searchBase, searchFilter, searchArguments, searchCtls);
    try {
      // hasMoreElements() (unlike hasMore()) does not declare NamingException; kept as in the original iteration.
      while (answer.hasMoreElements()) {
        SearchResult sr = answer.next();
        log.debug("Retrieving group names for user [{}]", sr.getName());
        addRolesFromMemberOf(username, sr.getAttributes(), roleNames);
      }
    } finally {
      // Release the search results even if iteration fails part-way.
      closeQuietly(answer);
    }

    return roleNames;
  }

  /**
   * Adds to {@code roleNames} the roles derived from every {@code memberOf} attribute contained in {@code attrs}.
   *
   * @param username  the user being resolved (used for logging only)
   * @param attrs     the attributes of a search result; {@code null} is ignored
   * @param roleNames the set receiving the role names
   * @throws NamingException if enumerating the attributes or their values fails
   */
  private void addRolesFromMemberOf(String username, Attributes attrs, Set<String> roleNames) throws NamingException {
    if (attrs == null) {
      return;
    }
    NamingEnumeration<? extends Attribute> ae = attrs.getAll();
    try {
      while (ae.hasMore()) {
        Attribute attr = ae.next();
        // LDAP attribute descriptions are case-insensitive.
        if (MEMBER_OF.equalsIgnoreCase(attr.getID())) {
          Collection<String> groupNames = LdapUtils.getAllAttributeValues(attr);
          log.debug("Groups found for user [{}]: {}", username, groupNames);
          roleNames.addAll(getRoleNamesForGroups(groupNames));
        }
      }
    } finally {
      closeQuietly(ae);
    }
  }

  /**
   * Closes {@code enumeration}, logging instead of throwing on failure, so that a close error cannot mask an
   * exception already propagating from the caller.
   *
   * @param enumeration the enumeration to close; {@code null} is ignored
   */
  private static void closeQuietly(NamingEnumeration<?> enumeration) {
    if (enumeration == null) {
      return;
    }
    try {
      enumeration.close();
    } catch (NamingException e) {
      log.warn("Unable to close LDAP naming enumeration", e);
    }
  }

  /**
   * This method is called by the default implementation to translate Active Directory group names to role names.
   * Each group DN is reduced to its {@code CN} values, and groups outside the search base's domain are ignored. The
   * resulting names are then mapped through the {@link #groupRolesMap}, or returned unchanged when no map is
   * configured.
   *
   * @param groupNames
   *          the group distinguished names that apply to the current user; {@code null} is treated as empty.
   * @return a collection of roles that are implied by the given group names; never {@code null}.
   */
  protected Collection<String> getRoleNamesForGroups(Collection<String> groupNames) {
    Collection<String> allAdGroups = new HashSet<String>();
    if (groupNames != null && !groupNames.isEmpty()) {
      // The search base is the same for every group: parse it once.
      List<String> domains = parseDomainsInSearchBase();
      for (String groupName : groupNames) {
        allAdGroups.addAll(processGroup(groupName, domains));
      }
    }
    return translateGroups(allAdGroups);
  }

  /**
   * Extracts the {@code CN} values from a group distinguished name. For example,
   * {@code "CN=Domain Admins,CN=Users,DC=mykene,DC=rndlab,DC=loc"} yields {@code {"Domain Admins", "Users"}}.
   * The group is ignored (empty result) unless its {@code DC} components are, in order, exactly {@code domains}.
   *
   * <p>NOTE(review): every CN component is collected, so container CNs such as {@code Users} are reported as group
   * names too. Confirm this is intended.
   *
   * @param groupDn the group's distinguished name; {@code null} yields an empty result
   * @param domains the DC values of the search base, in order; not modified
   * @return the CN values of the group, or an empty set if the group is not in the search base's domain
   */
  private Collection<String> processGroup(String groupDn, List<String> domains) {
    if (groupDn == null) {
      return Collections.emptySet();
    }
    Collection<String> result = new HashSet<String>();
    int matchedDomains = 0;

    for (String component : groupDn.split(ROLE_NAMES_DELIMETER)) {
      String[] keyValue = splitRdn(component);
      if (keyValue == null) {
        // Not a "key=value" component (e.g. a fragment of a value that itself contained the delimiter): skip it.
        continue;
      }
      String key = keyValue[0];
      String value = keyValue[1];

      if (DC.equalsIgnoreCase(key)) {
        // Domain components must match the search base's, in order; DC values compare case-insensitively.
        if (matchedDomains >= domains.size() || !domains.get(matchedDomains).equalsIgnoreCase(value)) {
          return Collections.emptySet();
        }
        matchedDomains++;
      } else if (CN.equalsIgnoreCase(key)) {
        result.add(value);
      }
    }

    // Not all parts of the search base's domain matched: ignore this group.
    if (matchedDomains != domains.size()) {
      return Collections.emptySet();
    }
    return result;
  }

  /**
   * Returns the {@code DC} values of {@code searchBase}, in order. For example, {@code "CN=Users,DC=mykene,DC=loc"}
   * yields {@code [mykene, loc]}. Non-DC components are ignored, so a search base located below the domain root
   * still matches the groups of that domain.
   *
   * @return the domain components; empty if {@code searchBase} is {@code null} or contains none
   */
  private List<String> parseDomainsInSearchBase() {
    List<String> domains = new ArrayList<String>();
    if (searchBase == null) {
      return domains;
    }
    for (String component : searchBase.split(ROLE_NAMES_DELIMETER)) {
      String[] keyValue = splitRdn(component);
      if (keyValue != null && DC.equalsIgnoreCase(keyValue[0])) {
        domains.add(keyValue[1]);
      }
    }
    return domains;
  }

  /**
   * Splits a single {@code key=value} DN component at its first {@code '='}, trimming surrounding whitespace.
   * Splitting at the first {@code '='} only keeps values that themselves contain {@code '='} intact.
   *
   * @param component one component of a distinguished name
   * @return {@code {key, value}}, or {@code null} if the component contains no {@code '='}
   */
  private static String[] splitRdn(String component) {
    int separator = component.indexOf('=');
    if (separator < 0) {
      return null;
    }
    return new String[]{ component.substring(0, separator).trim(), component.substring(separator + 1).trim() };
  }

  /**
   * Maps group names to role names through the {@link #groupRolesMap}. A group without a mapping contributes no
   * roles. When no map is configured, the group names themselves are returned as role names.
   *
   * @param allAdGroups the group (CN) names to translate
   * @return the translated role names; never {@code null}
   */
  private Collection<String> translateGroups(Collection<String> allAdGroups) {
    if (groupRolesMap == null) {
      return allAdGroups;
    }
    Collection<String> roles = new HashSet<String>();
    for (String adGroup : allAdGroups) {
      Set<String> mappedRoles = groupRolesMap.get(adGroup);
      if (mappedRoles != null) {
        roles.addAll(mappedRoles);
      }
    }
    return roles;
  }

  /**
   * This method exists because we use the com.terracotta.management.security.shiro.realm.TCJndiLdapContextFactory
   * for the LdapRealm, and it always reads the systemUserName from the ContextFactory. The same value is set as both
   * the system username and the simple system username of that factory.
   *
   * @param systemUsername "simple" version of the username
   * @throws ClassCastException if the configured context factory is not a {@code TCJndiLdapContextFactory}
   */
  public void setSystemUsername(String systemUsername) {
    TCJndiLdapContextFactory contextFactory = (TCJndiLdapContextFactory) getContextFactory();
    contextFactory.setSystemUsername(systemUsername);
    contextFactory.setSimpleSystemUsername(systemUsername);
  }

}
