/*
 * Decompiled with CFR 0.152.
 */
package org.eclipse.californium.scandium.dtls;

import java.net.DatagramPacket;
import java.net.InetSocketAddress;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import org.eclipse.californium.elements.util.Bytes;
import org.eclipse.californium.elements.util.DatagramWriter;
import org.eclipse.californium.elements.util.NoPublicAPI;
import org.eclipse.californium.elements.util.StringUtil;
import org.eclipse.californium.scandium.dtls.AlertMessage;
import org.eclipse.californium.scandium.dtls.ConnectionId;
import org.eclipse.californium.scandium.dtls.ContentType;
import org.eclipse.californium.scandium.dtls.DTLSContext;
import org.eclipse.californium.scandium.dtls.DTLSMessage;
import org.eclipse.californium.scandium.dtls.Finished;
import org.eclipse.californium.scandium.dtls.FragmentedHandshakeMessage;
import org.eclipse.californium.scandium.dtls.HandshakeException;
import org.eclipse.californium.scandium.dtls.HandshakeMessage;
import org.eclipse.californium.scandium.dtls.MultiHandshakeMessage;
import org.eclipse.californium.scandium.dtls.Record;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A flight of DTLS messages that is sent to one peer as a unit.
 * <p>
 * Messages are collected together with the epoch they are sent in and are
 * wrapped into {@link Record}s on demand. Handshake messages are either
 * combined into a {@link MultiHandshakeMessage} or split into
 * {@link FragmentedHandshakeMessage}s so that they fit the configured datagram
 * and fragment sizes. The records are then packed into datagrams. The flight
 * also keeps its retransmission state (tries, timeout, scheduled timeout task).
 */
@NoPublicAPI
public class DTLSFlight {
    private static final Logger LOGGER = LoggerFactory.getLogger(DTLSFlight.class);
    /** Length of a DTLS record header (type, version, epoch, sequence number, length). */
    private static final int RECORD_HEADER_LENGTH = 13;
    /** Length of a DTLS handshake header (type, length, message_seq, fragment offset, fragment length). */
    private static final int HANDSHAKE_HEADER_LENGTH = 12;
    /** Upper bound for the datagram size used when {@code backOff} is requested. */
    private static final int BACK_OFF_MAX_DATAGRAM_SIZE = 512;

    /** Records built from {@link #dtlsMessages}; rebuilt when the size parameters change. */
    private final List<Record> records;
    /** Messages of this flight, in the order they were added. */
    private final List<EpochMessage> dtlsMessages;
    private final DTLSContext context;
    private final InetSocketAddress peer;
    /** Peer address prepared for logging. */
    private final Object peerToLog;
    private final int flightNumber;
    private int tries;
    private int timeoutMillis;
    /** Size parameters the current {@link #records} were built with. */
    private int maxDatagramSize;
    private int maxFragmentSize;
    /** Effective limits computed while wrapping the most recent handshake message. */
    private int effectiveMaxDatagramSize;
    private int effectiveMaxMessageSize;
    private boolean useMultiHandshakeMessageRecords;
    /** Epoch and CID usage of the pending {@link #multiHandshakeMessage}. */
    private int multiEpoch;
    private boolean multiUseCid;
    /** Handshake messages collected for one combined record, or {@code null} if none is pending. */
    private MultiHandshakeMessage multiHandshakeMessage;
    private boolean retransmissionNeeded;
    private boolean finishedIncluded;
    private volatile boolean responseStarted;
    private volatile boolean responseCompleted;
    private ScheduledFuture<?> timeoutTask;

    /**
     * Creates an empty flight.
     *
     * @param context DTLS context used to create the records. Must not be {@code null}.
     * @param flightNumber number of this flight
     * @param peer destination of the datagrams
     * @throws NullPointerException if {@code context} is {@code null}
     */
    public DTLSFlight(DTLSContext context, int flightNumber, InetSocketAddress peer) {
        if (context == null) {
            throw new NullPointerException("Session must not be null");
        }
        this.context = context;
        this.peer = peer;
        this.peerToLog = StringUtil.toLog(peer);
        this.records = new ArrayList<Record>();
        this.dtlsMessages = new ArrayList<EpochMessage>();
        this.retransmissionNeeded = true;
        this.flightNumber = flightNumber;
    }

    /**
     * Appends a message to this flight.
     *
     * @param epoch epoch the message is sent in
     * @param messageToAdd message to add. Must not be {@code null}.
     * @throws NullPointerException if {@code messageToAdd} is {@code null}
     */
    public void addDtlsMessage(int epoch, DTLSMessage messageToAdd) {
        if (messageToAdd == null) {
            throw new NullPointerException("message must not be null!");
        }
        if (messageToAdd instanceof Finished) {
            this.finishedIncluded = true;
        }
        this.dtlsMessages.add(new EpochMessage(epoch, messageToAdd));
    }

    /**
     * @return number of messages added to this flight
     */
    public int getNumberOfMessages() {
        return this.dtlsMessages.size();
    }

    /**
     * Checks whether this flight contains a message with the same encoding.
     *
     * @param message message to look up
     * @return {@code true} if any added message serializes to the same bytes
     */
    public boolean contains(DTLSMessage message) {
        if (this.dtlsMessages.isEmpty()) {
            return false;
        }
        // serialize the probe once instead of once per stored message
        byte[] encoded = message.toByteArray();
        for (EpochMessage epochMessage : this.dtlsMessages) {
            if (Arrays.equals(encoded, epochMessage.message.toByteArray())) {
                return true;
            }
        }
        return false;
    }

    /**
     * Wraps a single message into one or more records appended to {@link #records}.
     *
     * @param epochMessage message with its epoch
     * @throws HandshakeException if the content type is not supported in a
     *             flight, or if a record cannot be created
     */
    protected final void wrapMessage(EpochMessage epochMessage) throws HandshakeException {
        try {
            DTLSMessage message = epochMessage.message;
            switch (message.getContentType()) {
                case HANDSHAKE:
                    this.wrapHandshakeMessage(epochMessage);
                    break;
                case CHANGE_CIPHER_SPEC:
                    // a CCS must not be merged into a pending multi-handshake record
                    this.flushMultiHandshakeMessages();
                    this.records.add(new Record(message.getContentType(), epochMessage.epoch, message, this.context, false, 0));
                    LOGGER.debug("Add CCS message of {} bytes for [{}]", message.size(), this.peerToLog);
                    break;
                default:
                    throw new HandshakeException("Cannot create " + message.getContentType() + " record for flight", new AlertMessage(AlertMessage.AlertLevel.FATAL, AlertMessage.AlertDescription.INTERNAL_ERROR));
            }
        } catch (GeneralSecurityException e) {
            throw new HandshakeException("Cannot create record", new AlertMessage(AlertMessage.AlertLevel.FATAL, AlertMessage.AlertDescription.INTERNAL_ERROR), e);
        }
    }

    /**
     * Wraps a handshake message into records.
     * <p>
     * A message that fits the effective maximum message size is either added
     * to a multi-handshake-message (if enabled) or wrapped into its own
     * record. A larger message is split into fragments.
     *
     * @param epochMessage handshake message with its epoch
     * @throws GeneralSecurityException if a record cannot be created
     * @throws IllegalStateException if the size limits leave no room for
     *             fragment payload, or the serialized message length does
     *             not match its declared length
     */
    private void wrapHandshakeMessage(EpochMessage epochMessage) throws GeneralSecurityException {
        HandshakeMessage handshakeMessage = (HandshakeMessage) epochMessage.message;
        int maxPayloadLength = this.maxDatagramSize - RECORD_HEADER_LENGTH;
        boolean useCid = false;
        if (epochMessage.epoch > 0) {
            ConnectionId connectionId = this.context.getWriteConnectionId();
            if (connectionId != null && !connectionId.isEmpty()) {
                useCid = true;
                maxPayloadLength -= connectionId.length();
            }
        }
        int effectiveMaxMessageSize;
        if (this.maxFragmentSize >= maxPayloadLength) {
            effectiveMaxMessageSize = maxPayloadLength;
            this.effectiveMaxDatagramSize = this.maxDatagramSize;
        } else {
            effectiveMaxMessageSize = this.maxFragmentSize;
            this.effectiveMaxDatagramSize = this.maxFragmentSize + (this.maxDatagramSize - maxPayloadLength);
        }
        if (epochMessage.epoch > 0) {
            effectiveMaxMessageSize -= this.context.getSession().getMaxCiphertextExpansion();
            if (useCid) {
                // the CID record format carries one extra inner content-type byte (RFC 9146)
                --effectiveMaxMessageSize;
            }
        }
        this.effectiveMaxMessageSize = effectiveMaxMessageSize;

        int messageSize = handshakeMessage.size();
        if (messageSize <= effectiveMaxMessageSize) {
            if (this.useMultiHandshakeMessageRecords) {
                if (this.multiHandshakeMessage != null) {
                    if (this.multiEpoch == epochMessage.epoch && this.multiUseCid == useCid
                            && this.multiHandshakeMessage.size() + messageSize < effectiveMaxMessageSize) {
                        this.multiHandshakeMessage.add(handshakeMessage);
                        LOGGER.debug("Add multi-handshake-message {} message of {} bytes, resulting in {} bytes for [{}]", handshakeMessage.getMessageType(), messageSize, this.multiHandshakeMessage.getMessageLength(), this.peerToLog);
                        return;
                    }
                    // different epoch/CID or no room left: emit the pending record first
                    this.flushMultiHandshakeMessages();
                }
                if (messageSize < effectiveMaxMessageSize) {
                    this.multiHandshakeMessage = new MultiHandshakeMessage();
                    this.multiHandshakeMessage.add(handshakeMessage);
                    this.multiEpoch = epochMessage.epoch;
                    this.multiUseCid = useCid;
                    LOGGER.debug("Start multi-handshake-message with {} message of {} bytes for [{}]", handshakeMessage.getMessageType(), messageSize, this.peerToLog);
                    return;
                }
            }
            this.records.add(new Record(ContentType.HANDSHAKE, epochMessage.epoch, handshakeMessage, this.context, useCid, 0));
            LOGGER.debug("Add {} message of {} bytes for [{}]", handshakeMessage.getMessageType(), messageSize, this.peerToLog);
            return;
        }

        this.flushMultiHandshakeMessages();
        LOGGER.debug("Splitting up {} message of {} bytes for [{}] into multiple handshake fragments of max. {} bytes", handshakeMessage.getMessageType(), messageSize, this.peerToLog, effectiveMaxMessageSize);
        byte[] messageBytes = handshakeMessage.fragmentToByteArray();
        int handshakeMessageLength = handshakeMessage.getMessageLength();
        if (messageBytes.length != handshakeMessageLength) {
            throw new IllegalStateException("message length " + handshakeMessageLength + " differs from message " + messageBytes.length + "!");
        }
        int maxFragmentLength = effectiveMaxMessageSize - HANDSHAKE_HEADER_LENGTH;
        if (maxFragmentLength <= 0) {
            // without this check a zero length never advances the offset (endless loop)
            // and a negative one fails on the array allocation
            throw new IllegalStateException("max. message size " + effectiveMaxMessageSize + " leaves no room for " + handshakeMessage.getMessageType() + " fragments!");
        }
        int messageSeq = handshakeMessage.getMessageSeq();
        int offset = 0;
        while (offset < handshakeMessageLength) {
            // the last fragment is normally shorter; Math.min avoids offset overflow
            int fragmentLength = Math.min(maxFragmentLength, handshakeMessageLength - offset);
            byte[] fragmentBytes = Arrays.copyOfRange(messageBytes, offset, offset + fragmentLength);
            FragmentedHandshakeMessage fragmentedMessage = new FragmentedHandshakeMessage(handshakeMessage.getMessageType(), handshakeMessageLength, messageSeq, offset, fragmentBytes);
            LOGGER.debug("fragment for offset {}, {} bytes", offset, fragmentedMessage.size());
            // use the same CID setting the fragment budget above was computed for
            this.records.add(new Record(ContentType.HANDSHAKE, epochMessage.epoch, fragmentedMessage, this.context, useCid, 0));
            offset += fragmentLength;
        }
    }

    /**
     * Appends the pending multi-handshake-message (if any) as a record and
     * clears the pending state.
     *
     * @throws GeneralSecurityException if the record cannot be created
     */
    private void flushMultiHandshakeMessages() throws GeneralSecurityException {
        if (this.multiHandshakeMessage != null) {
            this.records.add(new Record(ContentType.HANDSHAKE, this.multiEpoch, this.multiHandshakeMessage, this.context, this.multiUseCid, 0));
            int count = this.multiHandshakeMessage.getNumberOfHandshakeMessages();
            LOGGER.debug("Add {} multi handshake message, epoch {} of {} bytes (max. {}) for [{}]", count, this.multiEpoch, this.multiHandshakeMessage.getMessageLength(), this.effectiveMaxMessageSize, this.peerToLog);
            this.multiHandshakeMessage = null;
            this.multiEpoch = 0;
            this.multiUseCid = false;
        }
    }

    /**
     * Clears the records and any pending multi-handshake-message.
     */
    private void resetRecords() {
        this.records.clear();
        this.multiHandshakeMessage = null;
        this.multiEpoch = 0;
        this.multiUseCid = false;
    }

    /**
     * Gets the records of this flight.
     * <p>
     * If records were already built with the same parameters, new
     * {@link Record} instances are created from the existing fragments.
     * Otherwise all messages are wrapped again. If building fails, the
     * records are discarded, so the next call wraps everything again.
     *
     * @param maxDatagramSize maximum datagram size
     * @param maxFragmentSize maximum fragment size
     * @param useMultiHandshakeMessageRecords {@code true} to combine several
     *            handshake messages into one record
     * @return the internal (mutable) list of records
     * @throws HandshakeException if a record cannot be created
     */
    public List<Record> getRecords(int maxDatagramSize, int maxFragmentSize, boolean useMultiHandshakeMessageRecords) throws HandshakeException {
        boolean success = false;
        try {
            if (!this.records.isEmpty() && this.maxDatagramSize == maxDatagramSize
                    && this.maxFragmentSize == maxFragmentSize
                    && this.useMultiHandshakeMessageRecords == useMultiHandshakeMessageRecords) {
                // same layout: keep the fragments, but create new Record instances
                // (presumably for fresh record sequence numbers - TODO confirm in Record)
                for (int index = 0; index < this.records.size(); ++index) {
                    Record record = this.records.get(index);
                    this.records.set(index, new Record(record.getType(), record.getEpoch(), record.getFragment(), this.context, record.useConnectionId(), 0));
                }
            } else {
                this.effectiveMaxDatagramSize = maxDatagramSize;
                this.maxDatagramSize = maxDatagramSize;
                this.maxFragmentSize = maxFragmentSize;
                this.useMultiHandshakeMessageRecords = useMultiHandshakeMessageRecords;
                this.resetRecords();
                for (EpochMessage message : this.dtlsMessages) {
                    this.wrapMessage(message);
                }
                this.flushMultiHandshakeMessages();
            }
            success = true;
        } catch (GeneralSecurityException e) {
            throw new HandshakeException("Cannot create record", new AlertMessage(AlertMessage.AlertLevel.FATAL, AlertMessage.AlertDescription.INTERNAL_ERROR), e);
        } finally {
            if (!success) {
                // drop partial results; an empty record list forces a full re-wrap next time
                this.resetRecords();
            }
        }
        return this.records;
    }

    /**
     * Packs the records of this flight into datagrams.
     *
     * @param maxDatagramSize maximum datagram size
     * @param maxFragmentSize maximum fragment size
     * @param useMultiHandshakeMessageRecords {@code Boolean.TRUE} to combine
     *            handshake messages into one record; {@code null} or
     *            {@code false} to disable
     * @param useMultiRecordMessages {@code Boolean.FALSE} to send one record
     *            per datagram; {@code null} or {@code true} to pack several
     *            records per datagram (with {@code backOff}, only if
     *            explicitly {@code true})
     * @param backOff {@code true} to limit the datagram size to
     *            {@value #BACK_OFF_MAX_DATAGRAM_SIZE} bytes
     * @return datagrams to send. Always contains at least one datagram, which
     *         is empty if no record was written.
     * @throws HandshakeException if a record cannot be created
     */
    public List<DatagramPacket> getDatagrams(int maxDatagramSize, int maxFragmentSize, Boolean useMultiHandshakeMessageRecords, Boolean useMultiRecordMessages, boolean backOff) throws HandshakeException {
        DatagramWriter writer = new DatagramWriter(maxDatagramSize);
        List<DatagramPacket> datagrams = new ArrayList<DatagramPacket>();
        boolean multiHandshakeMessages = Boolean.TRUE.equals(useMultiHandshakeMessageRecords);
        boolean multiRecords = !Boolean.FALSE.equals(useMultiRecordMessages);
        if (backOff) {
            maxDatagramSize = Math.min(BACK_OFF_MAX_DATAGRAM_SIZE, maxDatagramSize);
        }
        // when backing off, records are only packed together if explicitly requested
        boolean packRecords = multiRecords && (!backOff || useMultiRecordMessages != null);
        LOGGER.trace("Prepare flight {}, using max. datagram size {}, max. fragment size {} [mhm={}, mr={}]", this.flightNumber, maxDatagramSize, maxFragmentSize, multiHandshakeMessages, multiRecords);
        List<Record> records = this.getRecords(maxDatagramSize, maxFragmentSize, multiHandshakeMessages);
        LOGGER.trace("Effective max. datagram size {}, max. message size {}", this.effectiveMaxDatagramSize, this.effectiveMaxMessageSize);
        for (int index = 0; index < records.size(); ++index) {
            Record record = records.get(index);
            byte[] recordBytes = record.toByteArray();
            if (recordBytes.length > this.effectiveMaxDatagramSize) {
                LOGGER.error("{} record of {} bytes for peer [{}] exceeds max. datagram size [{}], discarding...", record.getType(), recordBytes.length, this.peerToLog, this.effectiveMaxDatagramSize);
                LOGGER.debug("{}", record);
                continue;
            }
            LOGGER.trace("Sending record of {} bytes to peer [{}]:\n{}", recordBytes.length, this.peerToLog, record);
            if (multiRecords && record.getType() == ContentType.CHANGE_CIPHER_SPEC && ++index < records.size()) {
                // keep CCS and the following record in the same datagram
                Record finish = records.get(index);
                recordBytes = Bytes.concatenate(recordBytes, finish.toByteArray());
            }
            int left = packRecords ? this.effectiveMaxDatagramSize - recordBytes.length : 0;
            // left may be negative (oversized CCS+Finished); never emit an empty datagram then
            if (writer.size() > 0 && writer.size() > left) {
                // NOTE(review): relies on DatagramWriter.toByteArray() clearing the writer - TODO confirm
                this.addDatagram(datagrams, writer.toByteArray());
            }
            writer.writeBytes(recordBytes);
        }
        this.addDatagram(datagrams, writer.toByteArray());
        return datagrams;
    }

    /**
     * Appends a datagram with the given payload, addressed to the peer.
     *
     * @param datagrams list to append to
     * @param payload datagram payload
     */
    private void addDatagram(List<DatagramPacket> datagrams, byte[] payload) {
        datagrams.add(new DatagramPacket(payload, payload.length, this.peer.getAddress(), this.peer.getPort()));
        LOGGER.debug("Sending datagram of {} bytes to peer [{}]", payload.length, this.peerToLog);
    }

    /**
     * @return effective maximum message size computed while wrapping the most
     *         recent handshake message
     */
    public int getEffectiveMaxMessageSize() {
        return this.effectiveMaxMessageSize;
    }

    public int getFlightNumber() {
        return this.flightNumber;
    }

    public int getTries() {
        return this.tries;
    }

    public void incrementTries() {
        ++this.tries;
    }

    /**
     * @return retransmission timeout in milliseconds
     */
    public int getTimeout() {
        return this.timeoutMillis;
    }

    public void setTimeout(int timeoutMillis) {
        this.timeoutMillis = timeoutMillis;
    }

    /**
     * Scales the retransmission timeout.
     *
     * @param scale factor applied to the current timeout
     * @param maxTimeoutMillis upper bound for the timeout
     */
    public void incrementTimeout(float scale, int maxTimeoutMillis) {
        this.timeoutMillis = DTLSFlight.incrementTimeout(this.timeoutMillis, scale, maxTimeoutMillis);
    }

    public boolean isRetransmissionNeeded() {
        return this.retransmissionNeeded;
    }

    public void setRetransmissionNeeded(boolean needsRetransmission) {
        this.retransmissionNeeded = needsRetransmission;
    }

    public boolean isResponseStarted() {
        return this.responseStarted;
    }

    public void setResponseStarted() {
        this.responseStarted = true;
    }

    /**
     * Cancels the scheduled timeout task, if any is still pending.
     */
    private void cancelTimeout() {
        if (this.timeoutTask != null) {
            if (!this.timeoutTask.isDone()) {
                this.timeoutTask.cancel(true);
            }
            this.timeoutTask = null;
        }
    }

    /**
     * Marks the response as completed and cancels the pending timeout task.
     */
    public void setResponseCompleted() {
        this.responseCompleted = true;
        this.cancelTimeout();
    }

    public boolean isResponseCompleted() {
        return this.responseCompleted;
    }

    /**
     * @return {@code true} if a {@link Finished} message was added to this flight
     */
    public boolean isFinishedIncluded() {
        return this.finishedIncluded;
    }

    /**
     * Schedules the retransmission task after the current timeout.
     * <p>
     * Does nothing once the response is completed, and only logs if no
     * retransmission is needed. A previously scheduled task is cancelled first.
     *
     * @param timer executor to schedule the task with
     * @param task retransmission task
     */
    public void scheduleRetransmission(ScheduledExecutorService timer, Runnable task) {
        if (this.responseCompleted) {
            return;
        }
        if (!this.isRetransmissionNeeded()) {
            LOGGER.trace("handshake flight to peer {}, no retransmission!", this.peerToLog);
            return;
        }
        this.cancelTimeout();
        try {
            this.timeoutTask = timer.schedule(task, (long) this.timeoutMillis, TimeUnit.MILLISECONDS);
            LOGGER.trace("handshake flight to peer {}, retransmission {} ms.", this.peerToLog, this.timeoutMillis);
        } catch (RejectedExecutionException ex) {
            // the executor no longer accepts tasks
            LOGGER.trace("handshake flight stopped by shutdown.");
        }
    }

    /**
     * Scales a timeout, capped at a maximum.
     *
     * @param timeoutMillis current timeout in milliseconds
     * @param scale factor to apply
     * @param maxTimeoutMillis upper bound
     * @return the scaled and capped timeout, or {@code timeoutMillis} unchanged
     *         if it already reached {@code maxTimeoutMillis}
     */
    public static int incrementTimeout(int timeoutMillis, float scale, int maxTimeoutMillis) {
        if (timeoutMillis < maxTimeoutMillis) {
            timeoutMillis = Math.round((float) timeoutMillis * scale);
            timeoutMillis = Math.min(timeoutMillis, maxTimeoutMillis);
        }
        return timeoutMillis;
    }

    /**
     * A message of the flight together with the epoch it is sent in.
     */
    private static class EpochMessage {
        private final int epoch;
        private final DTLSMessage message;

        private EpochMessage(int epoch, DTLSMessage message) {
            this.epoch = epoch;
            this.message = message;
        }
    }
}

