package com.microsoft.azure.documentdb.internal;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.microsoft.azure.documentdb.DocumentClientException;

/**
 * Session token composed of a version, a global LSN and an optional per-region local LSN vector.
 *
 * <p>String form: {@code version#globalLsn[#regionId=localLsn]*}. Instances are parsed with
 * {@link #create(String)}, compared with {@link #isValid(VectorSessionToken)} and combined with
 * {@link #merge(VectorSessionToken)}.
 */
public class VectorSessionToken {

    private static final Logger logger = LoggerFactory.getLogger(VectorSessionToken.class);
    private static final String SEGMENT_SEPARATOR = "#";
    private static final String REGION_PROGRESS_SEPARATOR = "=";

    /** Cached string form; either the text this token was parsed from or one built from the fields. */
    private final String sessionToken;
    private final long version;
    private final long globalLsn;
    /** Private copy of the per-region local LSNs, keyed by region id. Never null. */
    private final Map<Integer, Long> localLsnByRegion;

    /**
     * Creates a token and derives its string form from the given components.
     *
     * @param version          session token version
     * @param globalLsn        global LSN
     * @param localLsnByRegion local LSN per region id; {@code null} is treated as "no regions"
     */
    public VectorSessionToken(long version, long globalLsn, Map<Integer, Long> localLsnByRegion) {
        this(version, globalLsn, localLsnByRegion, null);
    }

    /**
     * Creates a token.
     *
     * @param version          session token version
     * @param globalLsn        global LSN
     * @param localLsnByRegion local LSN per region id; {@code null} is treated as "no regions"
     * @param sessionToken     string form to report from {@link #convertToString()}; when
     *                         {@code null} it is built from the other arguments
     */
    public VectorSessionToken(long version, long globalLsn, Map<Integer, Long> localLsnByRegion, String sessionToken) {
        Map<Integer, Long> regions = localLsnByRegion == null
                ? new HashMap<Integer, Long>()
                : localLsnByRegion;

        this.version = version;
        this.globalLsn = globalLsn;
        // Build the string from the caller's map so its iteration order is honoured, then keep a
        // private copy so later mutation by the caller cannot desynchronize map and cached string.
        this.sessionToken = sessionToken != null
                ? sessionToken
                : buildSessionTokenString(version, globalLsn, regions);
        this.localLsnByRegion = new HashMap<Integer, Long>(regions);
    }

    /**
     * Parses a session token of the form {@code version#globalLsn[#regionId=localLsn]*}.
     *
     * @param sessionToken the text to parse
     * @return the parsed token, or {@code null} if the text is empty or malformed
     */
    public static VectorSessionToken create(String sessionToken) {
        if (StringUtils.isEmpty(sessionToken)) {
            logger.trace("Session token is empty");
            return null;
        }

        String[] segments = sessionToken.split(SEGMENT_SEPARATOR);

        if (segments.length < 2) {
            logger.trace("Unexpected number of segments '{}' in session token.", segments.length);
            return null;
        }

        long version;
        try {
            version = Long.parseLong(segments[0]);
        } catch (NumberFormatException e) {
            logger.trace("Unexpected session token version number '{}'", segments[0]);
            return null;
        }

        long globalLsn;
        try {
            globalLsn = Long.parseLong(segments[1]);
        } catch (NumberFormatException e) {
            logger.trace("Unexpected global lsn '{}'", segments[1]);
            return null;
        }

        Map<Integer, Long> lsnByRegion = new HashMap<Integer, Long>();
        for (int i = 2; i < segments.length; i++) {
            String[] regionIdWithLsn = segments[i].split(REGION_PROGRESS_SEPARATOR);

            if (regionIdWithLsn.length != 2) {
                logger.trace("Unexpected region progress segment length '{}' in session token.", regionIdWithLsn.length);
                return null;
            }

            int regionId;
            long localLsn;
            try {
                regionId = Integer.parseInt(regionIdWithLsn[0]);
                localLsn = Long.parseLong(regionIdWithLsn[1]);
            } catch (NumberFormatException e) {
                // Arguments ordered to match the message: progress (LSN) first, then region id.
                logger.trace("Unexpected region progress '{}' for region '{}' in session token.", regionIdWithLsn[1], regionIdWithLsn[0]);
                return null;
            }

            lsnByRegion.put(regionId, localLsn);
        }

        return new VectorSessionToken(version, globalLsn, lsnByRegion, sessionToken);
    }

    /**
     * Returns whether {@code other} has the same version, global LSN and per-region local LSNs.
     *
     * @param other token to compare with; may be {@code null}
     * @return {@code true} if all components are equal
     */
    public boolean equals(VectorSessionToken other) {
        return other != null
                && this.version == other.version
                && this.globalLsn == other.globalLsn
                && this.areRegionProgressEqual(other.localLsnByRegion);
    }

    /** Value equality consistent with {@link #equals(VectorSessionToken)}. */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        return obj instanceof VectorSessionToken && this.equals((VectorSessionToken) obj);
    }

    /** Hash code consistent with {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        int result = Long.hashCode(this.version);
        result = 31 * result + Long.hashCode(this.globalLsn);
        result = 31 * result + this.localLsnByRegion.hashCode();
        return result;
    }

    /**
     * Checks whether {@code other} is at least as recent as this token.
     *
     * @param other token to check; must not be {@code null}
     * @return {@code false} if {@code other} has a smaller version, a smaller global LSN, or a
     *         smaller local LSN for any region present in both tokens; {@code true} otherwise
     * @throws IllegalArgumentException if {@code other} is {@code null}
     * @throws DocumentClientException  if both tokens have the same version but different regions
     */
    public boolean isValid(VectorSessionToken other) throws DocumentClientException {
        if (other == null) {
            throw new IllegalArgumentException("Invalid Session Token (should not be null).");
        }

        if (other.version < this.version || other.globalLsn < this.globalLsn) {
            return false;
        }

        if (other.version == this.version && other.localLsnByRegion.size() != this.localLsnByRegion.size()) {
            throw this.unexpectedRegionsException(other);
        }

        for (Entry<Integer, Long> kvp : other.localLsnByRegion.entrySet()) {
            long otherLocalLsn = kvp.getValue();
            Long localLsn = this.localLsnByRegion.get(kvp.getKey());

            if (localLsn == null) {
                // Region mismatch: other has progress for a region unknown to this token. It can only
                // be ignored when this token's version is smaller than the other token's version.
                if (this.version == other.version) {
                    throw this.unexpectedRegionsException(other);
                }
            } else if (otherLocalLsn < localLsn) {
                return false;
            }
        }
        return true;
    }

    /**
     * Merges this token with {@code other}.
     *
     * <p>The result has the larger version and the larger global LSN. For every region of the token
     * with the higher version, it keeps the larger of the two local LSNs. Regions present only in the
     * lower-version token are dropped.
     *
     * @param other token to merge with; must not be {@code null}
     * @return a new merged token
     * @throws IllegalArgumentException if {@code other} is {@code null}
     * @throws DocumentClientException  if both tokens have the same version but different regions
     */
    public VectorSessionToken merge(VectorSessionToken other) throws DocumentClientException {
        if (other == null) {
            throw new IllegalArgumentException("Invalid Session Token (should not be null).");
        }

        if (this.version == other.version && this.localLsnByRegion.size() != other.localLsnByRegion.size()) {
            throw this.unexpectedRegionsException(other);
        }

        VectorSessionToken higher = this.version < other.version ? other : this;
        VectorSessionToken lower = this.version < other.version ? this : other;

        Map<Integer, Long> highestLocalLsnByRegion = new HashMap<Integer, Long>();

        for (Entry<Integer, Long> kvp : higher.localLsnByRegion.entrySet()) {
            int regionId = kvp.getKey();
            long higherLocalLsn = kvp.getValue();
            Long lowerLocalLsn = lower.localLsnByRegion.get(regionId);

            if (lowerLocalLsn != null) {
                highestLocalLsnByRegion.put(regionId, Math.max(higherLocalLsn, lowerLocalLsn));
            } else if (this.version == other.version) {
                throw this.unexpectedRegionsException(other);
            } else {
                highestLocalLsnByRegion.put(regionId, higherLocalLsn);
            }
        }

        return new VectorSessionToken(
            Math.max(this.version, other.version),
            Math.max(this.globalLsn, other.globalLsn),
            highestLocalLsnByRegion);
    }

    /**
     * Returns the string form of this token.
     *
     * @return the text this token was parsed from, or the text built from its components
     */
    public String convertToString() {
        return this.sessionToken;
    }

    /** Builds {@code version#globalLsn[#regionId=localLsn]*} in the map's iteration order. */
    private static String buildSessionTokenString(long version, long globalLsn, Map<Integer, Long> localLsnByRegion) {
        StringBuilder builder = new StringBuilder();
        builder.append(version).append(SEGMENT_SEPARATOR).append(globalLsn);
        for (Entry<Integer, Long> kvp : localLsnByRegion.entrySet()) {
            builder.append(SEGMENT_SEPARATOR)
                    .append(kvp.getKey())
                    .append(REGION_PROGRESS_SEPARATOR)
                    .append(kvp.getValue());
        }
        return builder.toString();
    }

    /** Creates the error raised when two same-version tokens disagree on their set of regions. */
    private DocumentClientException unexpectedRegionsException(VectorSessionToken other) {
        // String.format uses %s placeholders; SLF4J-style {} would never be substituted.
        return new DocumentClientException(HttpConstants.StatusCodes.INTERNAL_SERVER_ERROR,
                String.format("Compared session tokens '%s' and '%s' has unexpected regions.",
                        this.sessionToken, other.sessionToken));
    }

    /** Returns whether {@code other} holds exactly the same region ids with the same local LSNs. */
    private boolean areRegionProgressEqual(Map<Integer, Long> other) {
        if (this.localLsnByRegion.size() != other.size()) {
            return false;
        }

        for (Entry<Integer, Long> kvp : this.localLsnByRegion.entrySet()) {
            Long otherLocalLsn = other.get(kvp.getKey());

            // A region missing from the other map means the region sets differ even though the
            // counts match, so the progress vectors are not equal.
            if (otherLocalLsn == null || otherLocalLsn.longValue() != kvp.getValue().longValue()) {
                return false;
            }
        }
        return true;
    }
}
