/*
 * Copyright (c) MuleSoft, Inc.  All rights reserved.  http://www.mulesoft.com
 * The software in this package is published under the terms of the CPAL v1.0
 * license, a copy of which has been included with this distribution in the
 * LICENSE.txt file.
 */
package org.mule.soap.internal.client;

import com.google.common.collect.ImmutableList;
import org.apache.cxf.binding.Binding;
import org.apache.cxf.binding.soap.interceptor.CheckFaultInterceptor;
import org.apache.cxf.binding.soap.interceptor.Soap11FaultInInterceptor;
import org.apache.cxf.binding.soap.interceptor.Soap12FaultInInterceptor;
import org.apache.cxf.endpoint.Client;
import org.apache.cxf.interceptor.Interceptor;
import org.apache.cxf.message.Message;
import org.apache.cxf.phase.PhaseInterceptor;
import org.apache.cxf.ws.security.wss4j.WSS4JInInterceptor;
import org.apache.cxf.ws.security.wss4j.WSS4JOutInterceptor;
import org.apache.cxf.wsdl.interceptors.WrappedOutInterceptor;
import org.mule.soap.api.SoapVersion;
import org.mule.soap.api.SoapWebServiceConfiguration;
import org.mule.soap.api.security.SecurityStrategy;
import org.mule.soap.internal.interceptor.OutputMtomSoapAttachmentsInterceptor;
import org.mule.soap.internal.interceptor.OutputSoapHeadersInterceptor;
import org.mule.soap.internal.interceptor.SoapActionInterceptor;
import org.mule.soap.internal.interceptor.StreamClosingInterceptor;
import org.mule.soap.internal.security.callback.CompositeCallbackHandler;

import javax.security.auth.callback.CallbackHandler;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.StringJoiner;

import static java.util.Collections.emptyMap;
import static org.apache.commons.lang3.StringUtils.isNotBlank;
import static org.apache.cxf.message.Message.MTOM_ENABLED;
import static org.apache.wss4j.common.ConfigurationConstants.ACTION;
import static org.apache.wss4j.common.ConfigurationConstants.PW_CALLBACK_REF;
import static org.mule.soap.api.security.SecurityStrategy.SecurityStrategyType.ALWAYS;

/**
 * Object that creates CXF specific clients based on a {@link SoapWebServiceConfiguration} setting all the required CXF properties.
 * <p>
 * The created client aims to be the CXF client used in the {@link AbstractSoapCxfClient}.
 *
 * @since 1.0
 */
class CxfClientProvider {

  private final CxfClientFactory factory = new CxfClientFactory();

  /**
   * Creates a CXF {@link Client} for the given configuration.
   * <p>
   * The client is created for the configured address and SOAP version. Then:
   * <ul>
   * <li>WS-Security interceptors are registered for the configured security strategies;</li>
   * <li>the custom request/response interceptors are registered;</li>
   * <li>the MTOM flag is set on the endpoint;</li>
   * <li>some default CXF binding interceptors are removed.</li>
   * </ul>
   *
   * @param configuration the web service configuration to build the client from.
   * @return the configured CXF client.
   */
  Client getClient(SoapWebServiceConfiguration configuration) {
    boolean isMtom = configuration.isMtomEnabled();
    String address = configuration.getAddress();
    SoapVersion version = configuration.getVersion();
    Client client = factory.createClient(address, version.getNumber());
    addSecurityInterceptors(client, configuration.getSecurities());
    addRequestInterceptors(client);
    addResponseInterceptors(client, isMtom);
    client.getEndpoint().put(MTOM_ENABLED, isMtom);
    removeUnnecessaryCxfInterceptors(client);
    return client;
  }

  /**
   * Registers the WS-Security interceptors for the given strategies.
   * <p>
   * A {@link WSS4JOutInterceptor} is added for the outgoing (and always-applied) strategies. A {@link WSS4JInInterceptor} is
   * added for the incoming (and always-applied) strategies. Each is added only when its properties define at least one
   * WS-Security action.
   *
   * @param client             the client to register the interceptors in.
   * @param securityStrategies the configured security strategies, may be empty.
   */
  private void addSecurityInterceptors(Client client, List<SecurityStrategy> securityStrategies) {
    Map<String, Object> requestProps =
        buildSecurityProperties(securityStrategies, SecurityStrategy.SecurityStrategyType.OUTGOING);
    if (hasSecurityActions(requestProps)) {
      client.getOutInterceptors().add(new WSS4JOutInterceptor(requestProps));
    }

    Map<String, Object> responseProps =
        buildSecurityProperties(securityStrategies, SecurityStrategy.SecurityStrategyType.INCOMING);
    if (hasSecurityActions(responseProps)) {
      client.getInInterceptors().add(new WSS4JInInterceptor(responseProps));
    }
  }

  /**
   * Checks whether the given WSS4J properties contain a non-blank {@code ACTION} entry.
   *
   * @param props the WSS4J properties built by {@link #buildSecurityProperties(List, SecurityStrategy.SecurityStrategyType)}.
   * @return {@code true} if there is at least one WS-Security action to perform.
   */
  private boolean hasSecurityActions(Map<String, Object> props) {
    return !props.isEmpty() && isNotBlank((String) props.get(ACTION));
  }

  /**
   * Builds the WSS4J configuration properties for the strategies that apply to the given direction.
   * <p>
   * A strategy applies if its type equals {@code type} or is {@code ALWAYS}. For each applying strategy:
   * <ul>
   * <li>its properties are merged (later strategies override earlier keys);</li>
   * <li>its non-blank action is appended to the space-separated {@code ACTION} value;</li>
   * <li>its password callback handler, if present, is collected into a single {@link CompositeCallbackHandler}.</li>
   * </ul>
   *
   * @param strategies the configured security strategies, may be {@code null} or empty.
   * @param type       the direction the properties are being built for.
   * @return the WSS4J properties. When no strategies are configured, an immutable empty map is returned. Callers only hand the
   *         map to CXF when it defines an action, so that case is safe.
   */
  private Map<String, Object> buildSecurityProperties(List<SecurityStrategy> strategies,
                                                      SecurityStrategy.SecurityStrategyType type) {
    if (strategies == null || strategies.isEmpty()) {
      return emptyMap();
    }

    Map<String, Object> props = new HashMap<>();
    StringJoiner actionsJoiner = new StringJoiner(" ");

    ImmutableList.Builder<CallbackHandler> callbackHandlersBuilder = ImmutableList.builder();
    strategies.stream()
        .filter(strategy -> strategy.securityType().equals(type) || strategy.securityType().equals(ALWAYS))
        .forEach(securityStrategy -> {
          props.putAll(securityStrategy.buildSecurityProperties());
          if (isNotBlank(securityStrategy.securityAction())) {
            actionsJoiner.add(securityStrategy.securityAction());
          }
          securityStrategy.buildPasswordCallbackHandler().ifPresent(callbackHandlersBuilder::add);
        });

    List<CallbackHandler> handlers = callbackHandlersBuilder.build();
    if (!handlers.isEmpty()) {
      props.put(PW_CALLBACK_REF, new CompositeCallbackHandler(handlers));
    }

    String actions = actionsJoiner.toString();
    if (isNotBlank(actions)) {
      // WSS4J expects the list of actions as a single String with the action names separated by a blank space.
      props.put(ACTION, actions);
    }

    // This Map needs to be mutable, CXF will add properties if needed.
    return props;
  }

  /**
   * Registers the custom interceptors that run on outgoing requests.
   *
   * @param client the client to register the interceptors in.
   */
  private void addRequestInterceptors(Client client) {
    List<Interceptor<? extends Message>> outInterceptors = client.getOutInterceptors();
    outInterceptors.add(new SoapActionInterceptor());
  }

  /**
   * Registers the custom interceptors that run on incoming responses.
   * <p>
   * The MTOM attachments interceptor is added only when MTOM is enabled.
   *
   * @param client      the client to register the interceptors in.
   * @param mtomEnabled whether MTOM is enabled for this client.
   */
  private void addResponseInterceptors(Client client, boolean mtomEnabled) {
    List<Interceptor<? extends Message>> inInterceptors = client.getInInterceptors();
    inInterceptors.add(new StreamClosingInterceptor());
    inInterceptors.add(new CheckFaultInterceptor());
    inInterceptors.add(new OutputSoapHeadersInterceptor());
    inInterceptors.add(new SoapActionInterceptor());
    if (mtomEnabled) {
      inInterceptors.add(new OutputMtomSoapAttachmentsInterceptor());
    }
  }

  /**
   * Removes default binding interceptors that this client does not use.
   * <p>
   * The removed interceptors are the wrapped-out interceptor and the SOAP 1.1/1.2 fault-in and check-fault interceptors. A
   * {@link CheckFaultInterceptor} is registered on the client itself in
   * {@link #addResponseInterceptors(Client, boolean)}; only the binding's copy is removed here.
   *
   * @param client the client whose endpoint binding is cleaned up.
   */
  private void removeUnnecessaryCxfInterceptors(Client client) {
    Binding binding = client.getEndpoint().getBinding();
    removeInterceptor(binding.getOutInterceptors(), WrappedOutInterceptor.class.getName());
    removeInterceptor(binding.getInInterceptors(), Soap11FaultInInterceptor.class.getName());
    removeInterceptor(binding.getInInterceptors(), Soap12FaultInInterceptor.class.getName());
    removeInterceptor(binding.getInInterceptors(), CheckFaultInterceptor.class.getName());
  }

  /**
   * Removes every {@link PhaseInterceptor} whose id equals {@code id} from the given interceptor list.
   * <p>
   * The list may be either an in or an out interceptor chain.
   *
   * @param interceptors the interceptor list to modify in place.
   * @param id           the interceptor id to remove.
   */
  private void removeInterceptor(List<Interceptor<? extends Message>> interceptors, String id) {
    // id.equals(...) keeps this null-safe if an interceptor reports no id.
    interceptors.removeIf(i -> i instanceof PhaseInterceptor && id.equals(((PhaseInterceptor<?>) i).getId()));
  }
}
