/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 * The software in this package is published under the terms of the CPAL v1.0
 * license, a copy of which has been included with this distribution in the
 * LICENSE.txt file.
 */
package org.mule.soap.internal.rm;

import static org.mule.soap.internal.rm.RMUtils.getReliableMessagingSequence;
import static org.apache.commons.lang3.StringUtils.isBlank;
import static org.apache.cxf.ws.rm.RMUtils.getEndpointIdentifier;
import static org.apache.cxf.ws.rm.RMUtils.getWSRMFactory;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

import org.apache.cxf.endpoint.Client;
import org.apache.cxf.io.CachedOutputStream;
import org.apache.cxf.message.Message;
import org.apache.cxf.ws.addressing.AddressingProperties;
import org.apache.cxf.ws.rm.RMEndpoint;
import org.apache.cxf.ws.rm.RMException;
import org.apache.cxf.ws.rm.RMManager;
import org.apache.cxf.ws.rm.SourceSequence;
import org.apache.cxf.ws.rm.persistence.RMMessage;
import org.apache.cxf.ws.rm.v200702.Identifier;
import org.apache.cxf.ws.rm.v200702.SequenceAcknowledgement;

/**
 * {@link RMManager} decorator that adds two behaviors on top of the wrapped manager:
 * <ul>
 * <li>{@link #getSequence(Identifier, Message, AddressingProperties)} resolves the source sequence from the sequence
 * identifier carried by the message (when one is present) instead of using the manager's default lookup.</li>
 * <li>{@link #clientCreated(Client)} makes sure that persisted source sequences without messages pending
 * acknowledgement are still recovered when a client is created.</li>
 * </ul>
 */
public class MuleRMManager extends ForwardingRMManager {

  /**
   * Message number given to the dummy message injected during recovery, and used as both bounds of the acknowledgement
   * range that later acknowledges it. WS-ReliableMessaging message numbers start at 1, so this value is not expected to
   * collide with a real message.
   */
  private static final long RM_MESSAGE_NUMBER = -1L;

  /**
   * Creates a manager that forwards to the given one.
   *
   * @param manager the {@link RMManager} to wrap
   */
  public MuleRMManager(RMManager manager) {
    super(manager);
  }

  /**
   * Returns the source sequence to use for the given message.
   * <p>
   * If {@code RMUtils.getReliableMessagingSequence(message)} yields a non-blank identifier, the sequence with that
   * identifier is looked up in the message's source; otherwise the call is forwarded unchanged to the parent
   * implementation.
   *
   * @param inSeqId the inbound sequence identifier, only used by the parent implementation
   * @param message the message being processed
   * @param maps the addressing properties, only used by the parent implementation
   * @return the source sequence resolved for the message
   * @throws RMException if the parent implementation fails to resolve the sequence
   */
  @Override
  public SourceSequence getSequence(Identifier inSeqId, Message message, AddressingProperties maps) throws RMException {
    String sequenceIdentifier = getReliableMessagingSequence(message);
    if (isBlank(sequenceIdentifier)) {
      return super.getSequence(inSeqId, message, maps);
    }

    Identifier identifier = new Identifier();
    identifier.setValue(sequenceIdentifier);
    return super.getSource(message).getSequence(identifier);
  }

  /**
   * Recover reliable messaging state from store for the created client.
   *
   * @param client the client that has just been created and whose endpoint state must be recovered
   * @throws IllegalStateException if the reliable endpoint, or one of the sequences that received a dummy message,
   *         cannot be found once the parent recovery has run
   */
  @Override
  public void clientCreated(Client client) {
    // CXF does not allow to recover source sequences that do not have messages pending ack.
    // To avoid source sequences being removed from the store because of that, a dummy pending ack
    // message is added to the sequence. Doing that, cxf will process and recover the sequence in
    // the right way. After that, the dummy message is removed from the sequence.
    List<Identifier> sequences = preClientCreated(client);
    super.clientCreated(client);
    postClientCreated(client, sequences);
  }

  /**
   * Inject a dummy message to source sequences that do not have messages pending ack.
   *
   * @param client the client whose endpoint identifies the persisted source sequences
   * @return the identifiers of the sequences that received a dummy message; empty (never {@code null}) when there is no
   *         store, no retransmission queue, or nothing to patch
   */
  private List<Identifier> preClientCreated(Client client) {
    final List<Identifier> sequences = new ArrayList<>();
    if (this.getStore() == null || this.getRetransmissionQueue() == null) {
      return sequences;
    }

    String id = getEndpointIdentifier(client.getEndpoint(), this.getBus());
    Collection<SourceSequence> storedSequences = this.getStore().getSourceSequences(id);
    if (storedSequences == null || storedSequences.isEmpty()) {
      return sequences;
    }

    // The same dummy message instance is persisted for every sequence that needs one.
    final RMMessage dummyMessage = createRMMessage();
    for (SourceSequence storedSequence : storedSequences) {
      Collection<RMMessage> pending = this.getStore().getMessages(storedSequence.getIdentifier(), true);
      if (pending == null || pending.isEmpty()) {
        this.getStore().persistOutgoing(storedSequence, dummyMessage);
        sequences.add(storedSequence.getIdentifier());
      }
    }
    return sequences;
  }

  /**
   * Remove dummy messages added to source sequences at {@code preClientCreated} method.
   * <p>
   * Each listed sequence is first acknowledged for the dummy message number and then reset to an empty
   * acknowledgement.
   *
   * @param client the client whose service name is used to find the reliable endpoint
   * @param sequences identifiers of the sequences that received a dummy message; {@code null} or empty is a no-op
   * @throws IllegalStateException if the reliable endpoint or one of the sequences cannot be found
   * @throws RuntimeException if updating the acknowledgement of a sequence fails
   */
  private void postClientCreated(Client client, List<Identifier> sequences) {
    if (sequences == null || sequences.isEmpty()) {
      return;
    }

    final SequenceAcknowledgement emptyAck = getWSRMFactory().createSequenceAcknowledgement();
    final SequenceAcknowledgement ack = createSequenceAcknowledgement();
    final RMEndpoint rmEndpoint = super.findReliableEndpoint(client.getEndpoint().getService().getName());
    if (rmEndpoint == null) {
      // Fail with context instead of an anonymous NullPointerException further down.
      throw new IllegalStateException("Error trying to recover RM state. No reliable endpoint found for service "
          + client.getEndpoint().getService().getName());
    }

    for (Identifier sequenceId : sequences) {
      SourceSequence sequence = rmEndpoint.getSource().getSequence(sequenceId);
      if (sequence == null) {
        throw new IllegalStateException("Error trying to recover RM state. Source sequence "
            + sequenceId.getValue() + " was not recovered.");
      }
      try {
        // Acknowledge the dummy message first, then leave the sequence with an empty acknowledgement.
        sequence.setAcknowledged(ack);
        sequence.setAcknowledged(emptyAck);
      } catch (RMException e) {
        throw new RuntimeException("Error trying to recover RM state.", e);
      }
    }
  }

  /**
   * Builds the dummy message: numbered {@link #RM_MESSAGE_NUMBER} and backed by a new, empty
   * {@link CachedOutputStream}.
   *
   * @return the dummy message
   */
  private RMMessage createRMMessage() {
    RMMessage message = new RMMessage();
    message.setMessageNumber(RM_MESSAGE_NUMBER);
    message.setContent(new CachedOutputStream());
    return message;
  }

  /**
   * Builds an acknowledgement with a single range whose lower and upper bounds are both {@link #RM_MESSAGE_NUMBER},
   * i.e. one that covers only the dummy message.
   *
   * @return the acknowledgement for the dummy message
   */
  private SequenceAcknowledgement createSequenceAcknowledgement() {
    SequenceAcknowledgement.AcknowledgementRange range = getWSRMFactory().createSequenceAcknowledgementAcknowledgementRange();
    range.setLower(RM_MESSAGE_NUMBER);
    range.setUpper(RM_MESSAGE_NUMBER);
    SequenceAcknowledgement ack = getWSRMFactory().createSequenceAcknowledgement();
    ack.getAcknowledgementRange().add(range);
    return ack;
  }
}
