package com.atlassian.diagnostics.internal.jmx;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nonnull;
import java.lang.management.ThreadInfo;
import java.lang.management.ThreadMXBean;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

/**
 * Reports per-thread heap allocation figures obtained from the JVM's {@link ThreadMXBean}, paired with
 * each thread's name and stack trace.
 * <p>
 * Allocation figures come from {@code com.sun.management.ThreadMXBean#getThreadAllocatedBytes(long[])}.
 * When that extension cannot be used, every query returns an empty list rather than failing. This covers
 * the class not being visible to this class loader, the bean not implementing it, or the JVM not
 * supporting it. Threads for which the JVM reports no figure (it reports {@code -1}) are omitted.
 */
public class ThreadMemoryAllocationService {

    private static final Logger logger = LoggerFactory.getLogger(ThreadMemoryAllocationService.class);

    private final ThreadMXBean threadMXBean;

    public ThreadMemoryAllocationService(@Nonnull final JmxService jmxService) {
        this.threadMXBean = jmxService.getThreadMXBean();
    }

    /**
     * Returns the threads that have allocated at least {@code minimumMemoryAllocation} bytes, with
     * complete stack traces.
     *
     * @param minimumMemoryAllocation inclusive lower bound, in bytes, on a thread's reported allocation
     * @return matching threads in no particular order; empty if none match or allocation data is unavailable
     */
    @Nonnull
    public List<ThreadMemoryAllocation> getThreadMemoryAllocations(final long minimumMemoryAllocation) {
        return getThreadMemoryAllocations(minimumMemoryAllocation, Integer.MAX_VALUE);
    }

    /**
     * Returns the threads that have allocated at least {@code minimumMemoryAllocation} bytes.
     *
     * @param minimumMemoryAllocation inclusive lower bound, in bytes, on a thread's reported allocation
     * @param maxStackTraceDepth      maximum number of stack frames captured per thread; negative values
     *                                are treated as 0
     * @return matching threads in no particular order; empty if none match or allocation data is unavailable
     */
    @Nonnull
    public List<ThreadMemoryAllocation> getThreadMemoryAllocations(final long minimumMemoryAllocation, final int maxStackTraceDepth) {
        final Map<Long, Long> threadIdMemoryAllocationsAboveThreshold = threadIdMemoryAllocations().entrySet().stream()
                .filter(e -> e.getValue() >= minimumMemoryAllocation)
                .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));

        return threadMemoryAllocations(threadIdMemoryAllocationsAboveThreshold, maxStackTraceDepth);
    }

    /**
     * Returns the {@code numberOfThreadMemoryAllocations} threads with the largest reported allocation,
     * with complete stack traces.
     *
     * @param numberOfThreadMemoryAllocations maximum number of threads to return; values &lt;= 0 yield an
     *                                        empty list
     * @return threads ordered by allocation, largest first; empty if allocation data is unavailable
     */
    @Nonnull
    public List<ThreadMemoryAllocation> getTopThreadMemoryAllocations(final int numberOfThreadMemoryAllocations) {
        return getTopThreadMemoryAllocations(numberOfThreadMemoryAllocations, Integer.MAX_VALUE);
    }

    /**
     * Returns the {@code numberOfThreadMemoryAllocations} threads with the largest reported allocation.
     *
     * @param numberOfThreadMemoryAllocations maximum number of threads to return; values &lt;= 0 yield an
     *                                        empty list
     * @param maxStackTraceDepth              maximum number of stack frames captured per thread; negative
     *                                        values are treated as 0
     * @return threads ordered by allocation, largest first; empty if allocation data is unavailable
     */
    @Nonnull
    public List<ThreadMemoryAllocation> getTopThreadMemoryAllocations(final int numberOfThreadMemoryAllocations, final int maxStackTraceDepth) {
        // A LinkedHashMap keeps the descending sort order through to threadMemoryAllocations.
        final Map<Long, Long> topThreadsByMemoryAllocationSize = threadIdMemoryAllocations().entrySet().stream()
                .sorted(Comparator.comparing(Map.Entry::getValue, Comparator.reverseOrder()))
                .limit(Math.max(0, numberOfThreadMemoryAllocations))
                .collect(Collectors.toMap(
                        Map.Entry::getKey,
                        Map.Entry::getValue,
                        (v1, v2) -> v2,
                        LinkedHashMap::new
                ));

        return threadMemoryAllocations(topThreadsByMemoryAllocationSize, maxStackTraceDepth);
    }

    /**
     * Maps each thread ID currently known to the bean to its reported allocated byte count.
     * <p>
     * Negative figures are skipped. The JVM reports {@code -1} for a thread that is no longer alive, or
     * when allocation measurement is disabled, and such values must not be surfaced as real allocations.
     *
     * @return thread ID to allocated bytes; empty if allocation data is unavailable
     */
    private Map<Long, Long> threadIdMemoryAllocations() {
        final long[] threadIds = threadMXBean.getAllThreadIds();
        final long[] memoryAllocations = getMemoryAllocations(threadIds);
        if (threadIds.length != memoryAllocations.length) {
            // getMemoryAllocations signals "unavailable" by returning an empty array.
            return Collections.emptyMap();
        }

        final Map<Long, Long> threadIdMemoryAllocations = new HashMap<>();
        for (int i = 0; i < threadIds.length; i++) {
            if (memoryAllocations[i] >= 0) {
                threadIdMemoryAllocations.put(threadIds[i], memoryAllocations[i]);
            }
        }
        return threadIdMemoryAllocations;
    }

    /**
     * Queries allocated bytes for the given threads via the {@code com.sun.management} extension.
     *
     * @param threadIds the thread IDs to query
     * @return one value per ID, index-aligned with {@code threadIds}, or an empty array if the extension
     *         cannot be used
     */
    private long[] getMemoryAllocations(final long[] threadIds) {
        try {
            return ((com.sun.management.ThreadMXBean) threadMXBean).getThreadAllocatedBytes(threadIds);
        } catch (ClassCastException | NoClassDefFoundError e) {
            logger.debug("Plugins cannot compile this class as com.sun.management.* is filtered out of the ClassLoader. Compile from the product instead", e);
            return new long[0];
        } catch (UnsupportedOperationException e) {
            // Thrown by getThreadAllocatedBytes when the JVM cannot measure thread allocation at all.
            logger.debug("Thread memory allocation measurement is not supported by this JVM", e);
            return new long[0];
        }
    }

    /**
     * Builds result objects for the given threads, preserving the map's iteration order.
     * Threads that died before their {@link ThreadInfo} could be read (null entries) are dropped.
     *
     * @param threadIdMemoryAllocations thread ID to allocated bytes, in the desired output order
     * @param maxStackTraceDepth        maximum stack frames per thread; negative values are treated as 0
     * @return one entry per thread still alive when queried; empty if the input map is empty
     */
    private List<ThreadMemoryAllocation> threadMemoryAllocations(final Map<Long, Long> threadIdMemoryAllocations, final int maxStackTraceDepth) {
        if (threadIdMemoryAllocations.isEmpty()) {
            return Collections.emptyList();
        }

        final long[] threadIds = threadIdMemoryAllocations.keySet().stream().mapToLong(Long::longValue).toArray();
        // getThreadInfo returns an array index-aligned with threadIds, so input order is preserved.
        final ThreadInfo[] threadInfos = threadMXBean.getThreadInfo(threadIds, Math.max(maxStackTraceDepth, 0));

        return Arrays.stream(threadInfos)
                .filter(Objects::nonNull)
                .map(threadInfo -> new ThreadMemoryAllocation(
                        threadInfo.getThreadName(),
                        threadIdMemoryAllocations.get(threadInfo.getThreadId()),
                        threadInfo.getStackTrace()
                )).collect(Collectors.toList());
    }
}
