package com.atlassian.diagnostics.internal.dao;

import com.atlassian.diagnostics.Alert;
import com.atlassian.diagnostics.AlertCriteria;
import com.atlassian.diagnostics.PageRequest;
import com.atlassian.diagnostics.Severity;
import org.apache.commons.lang3.mutable.MutableLong;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nonnull;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.Collectors;

import static com.atlassian.diagnostics.CallbackResult.DONE;
import static com.atlassian.diagnostics.internal.util.InstantPrecisionUtil.truncateNanoSecondPrecision;
import static org.apache.commons.lang3.StringUtils.defaultString;
import static org.apache.commons.lang3.StringUtils.length;

/**
 * Simple implementation that's used for ref-app and validating the sanity of the DAO contract
 */
public class InMemoryAlertEntityDao implements AlertEntityDao {

    private static final Logger logger = LoggerFactory.getLogger(InMemoryAlertEntityDao.class);

    private static final Comparator<AlertEntity> ALERT_ENTITY_COMPARATOR = (alert1, alert2) -> {
        int result = alert2.getTimestamp().compareTo(alert1.getTimestamp()); // new -> old
        if (result != 0) {
            return result;
        }

        return Long.compare(alert2.getId(), alert1.getId()); // new -> old
    };

    private final AtomicLong nextId = new AtomicLong(1000L);
    private final List<AlertEntity> entities = new CopyOnWriteArrayList<>();

    @Override
    public void deleteAll(@Nonnull AlertCriteria criteria) {
        entities.removeIf(entity -> matches(criteria, entity));
    }

    @Override
    public Set<String> findAllComponentIds() {
        Set<String> results = entities.stream().map(AlertEntity::getIssueComponentId).collect(Collectors.toSet());
        logger.info("Got findAllComponentIds: [{}]", results);
        return results;
    }

    @Override
    public Map<String, Severity> findAllIssueIds() {
        Map<String, Severity> result = new HashMap<>();
        entities.forEach(entity -> result.putIfAbsent(entity.getIssueId(), entity.getIssueSeverity()));
        return result;
    }

    @Override
    public Set<String> findAllNodeNames() {
        Set<String> results = entities.stream().map(AlertEntity::getNodeName).filter(Objects::nonNull).collect(Collectors.toSet());
        logger.info("Got findAllNodeNames: [{}]", results);
        return results;
    }

    @Override
    public Set<String> findAllPluginKeys() {
        Set<String> results = entities.stream().map(AlertEntity::getTriggerPluginKey).filter(Objects::nonNull)
                .collect(Collectors.toSet());
        logger.info("Got findAllPluginKeys: [{}]", results);
        return results;
    }

    @Override
    public AlertEntity getById(long id) {
        return entities.stream()
                .filter(entity -> entity.getId() == id)
                .findFirst()
                .orElse(null);
    }

    @Nonnull
    @Override
    public AlertEntity save(@Nonnull Alert alert) {
        SimpleAlertEntity entity = new SimpleAlertEntity(alert, nextId.getAndIncrement());
        entities.add(entity);
        entities.sort(ALERT_ENTITY_COMPARATOR);

        logger.info("Saved {} entities", entities.size());
        logger.trace("Saved entities: [{}]", entities);
        return entity;
    }

    @Override
    public void streamAll(@Nonnull AlertCriteria criteria, @Nonnull RowCallback<AlertEntity> callback,
                          @Nonnull PageRequest pageRequest) {

        int row = 0;
        int startOffset = pageRequest.getStart();
        int endOffset = startOffset + pageRequest.getLimit();

        for (AlertEntity entity : entities) {
            if (matches(criteria, entity)) {
                if (row >= startOffset && row <= endOffset) {
                    if (callback.onRow(entity) == DONE) {
                        return;
                    }
                }
                if (row++ == endOffset) {
                    // for fromStart page requests, return limit + 1 items
                    return;
                }
            }
        }
    }

    @Override
    public void streamByIds(@Nonnull Collection<Long> ids, @Nonnull RowCallback<AlertEntity> callback) {
        Set<Long> remaining = new HashSet<>(ids);
        for (AlertEntity entity : entities) {
            if (remaining.remove(entity.getId())) {
                if (callback.onRow(entity) == DONE || remaining.isEmpty()) {
                    return;
                }
            }
        }
    }

    @Override
    public void streamMetrics(@Nonnull AlertCriteria criteria, @Nonnull RowCallback<AlertMetric> callback,
                              @Nonnull PageRequest pageRequest) {

        LinkedHashMap<MetricKey, MutableLong> metricCounts = new LinkedHashMap<>();
        entities.stream()
                .filter(entity -> matches(criteria, entity))
                .forEach(entity ->
                        metricCounts.computeIfAbsent(new MetricKey(entity), key -> new MutableLong(0L))
                                .increment());

        List<AlertMetric> metrics = metricCounts.entrySet().stream()
                .map(entry -> {
                    MetricKey key = entry.getKey();
                    return new AlertMetric(key.issueId, key.issueSeverity, key.pluginKey, key.pluginVersion,
                            key.nodeName, entry.getValue().getValue());
                })
                .collect(Collectors.toList());

        metrics.sort(Comparator.comparing(AlertMetric::getIssueSeverity, Comparator.comparingInt(Severity::getId).reversed())
                .thenComparing(AlertMetric::getIssueId)
                .thenComparing(AlertMetric::getPluginKey)
                .thenComparing(metric -> defaultString(metric.getPluginVersion(), ""))
                .thenComparing(AlertMetric::getNodeName));

        int row = 0;
        int startOffset = pageRequest.getStart();
        int endOffset = startOffset + pageRequest.getLimit();

        for (AlertMetric metric : metrics) {
            if (row >= startOffset && row <= endOffset) {
                if (callback.onRow(metric) == DONE) {
                    return;
                }
            }
            if (row++ == endOffset) {
                // for fromStart page requests, return limit + 1 items
                break;
            }
        }
    }

    @Override
    public void streamMinimalAlerts(@Nonnull AlertCriteria criteria, @Nonnull RowCallback<MinimalAlertEntity> callback,
                                    @Nonnull PageRequest pageRequest) {
        streamAll(criteria, entity -> callback.onRow(toMinimalAlert(entity)), pageRequest);
    }

    private static boolean matches(AlertCriteria criteria, AlertEntity entity) {
        return valueMatchesCaseInsensitively(criteria.getIssueIds(), entity.getIssueId()) &&
                valueMatchesCaseInsensitively(criteria.getPluginKeys(), entity.getTriggerPluginKey()) &&
                valueMatchesCaseInsensitively(criteria.getNodeNames(), entity.getNodeName()) &&
                valueMatches(criteria.getSeverities(), entity.getIssueSeverity()) &&
                valueMatches(criteria.getComponentIds(), entity.getIssueComponentId()) &&
                criteria.getSince().map(since -> since.isBefore(truncateNanoSecondPrecision(entity.getTimestamp()))).orElse(true) &&
                criteria.getUntil().map(until -> !until.isBefore(truncateNanoSecondPrecision(entity.getTimestamp()))).orElse(true);
    }

    private static <T> boolean valueMatches(Set<T> acceptedValues, T value) {
        return acceptedValues.isEmpty() || acceptedValues.contains(value);
    }

    private static boolean valueMatchesCaseInsensitively(Set<String> acceptedValues, String value) {
        return acceptedValues.isEmpty() || acceptedValues.stream().anyMatch(val -> val.equalsIgnoreCase(value));
    }

    private static MinimalAlertEntity toMinimalAlert(AlertEntity entity) {
        return new SimpleMinimalAlertEntity(entity.getId(), entity.getTimestamp().toEpochMilli(), entity.getIssueId(),
                entity.getTriggerPluginKey(), entity.getNodeName(), length(entity.getDetailsJson()));
    }

    private static class MetricKey {

        private final String issueId;
        private final Severity issueSeverity;
        private final String nodeName;
        private final String pluginKey;
        private final String pluginVersion;

        private MetricKey(AlertEntity entity) {
            issueId = entity.getIssueId();
            issueSeverity = entity.getIssueSeverity();
            nodeName = entity.getNodeName();
            pluginKey = entity.getTriggerPluginKey();
            pluginVersion = entity.getTriggerPluginVersion();
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            MetricKey metricKey = (MetricKey) o;
            return Objects.equals(issueId, metricKey.issueId) &&
                    Objects.equals(nodeName, metricKey.nodeName) &&
                    Objects.equals(pluginKey, metricKey.pluginKey) &&
                    Objects.equals(pluginVersion, metricKey.pluginVersion);
        }

        @Override
        public int hashCode() {
            return Objects.hash(issueId, nodeName, pluginKey, pluginVersion);
        }
    }
}
