package com.atlassian.diagnostics.internal.platform.monitor.db;

import com.atlassian.diagnostics.internal.platform.plugin.PluginFinder;
import com.atlassian.plugin.util.PluginKeyStack;
import com.atlassian.util.profiling.Ticker;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.RemovalCause;
import com.google.common.cache.RemovalListener;
import com.google.common.cache.RemovalNotification;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.sql.Connection;
import java.sql.SQLException;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;

import static com.atlassian.diagnostics.internal.platform.monitor.DurationUtils.durationOf;
import static com.atlassian.util.profiling.Metrics.metric;
import static java.lang.System.currentTimeMillis;
import static java.lang.Thread.currentThread;
import static java.time.Duration.ZERO;

/**
 * Default {@link DatabaseDiagnosticsCollector}.
 * <p>
 * Times SQL operations (publishing a {@code db.core.executionTime} metric), raises an alert through the
 * {@link DatabaseMonitor} when an operation takes at least the configured long-running limit, and tracks
 * connections so that one not released within the configured pool connection leak timeout is reported as a
 * potential leak.
 */
public class DefaultDatabaseDiagnosticsCollector implements DatabaseDiagnosticsCollector {
    private static final Logger logger = LoggerFactory.getLogger(DefaultDatabaseDiagnosticsCollector.class);

    /**
     * Upper bound on the number of tracked connections. Entries evicted because of this bound are removed with
     * {@link RemovalCause#SIZE} and are therefore NOT reported as leaks.
     */
    private static final int MAX_TRACKED_CONNECTIONS = 500;

    /** Only created when the improved-accuracy invoker lookup is enabled; {@code null} otherwise. */
    @Nullable
    private final ClassContextSecurityManager classContextSecurityManager;
    private final DatabaseMonitorConfiguration configuration;
    private final DatabasePoolDiagnosticProvider databasePoolDiagnosticProvider;
    private final Duration poolConnectionLeakTimeout;
    /** {@code true} when {@link #poolConnectionLeakTimeout} is positive, i.e. leak tracking is switched on. */
    private final boolean connectionLeakDetectionEnabled;
    private final DatabaseMonitor databaseMonitor;
    private final boolean findInvoker;
    private final boolean improvedAccuracy;
    private final PluginFinder pluginFinder;
    private final Clock clock;
    /** Tracked connection -> instant at which it was (last) tracked. */
    private final Cache<Connection, Instant> connectionCache;

    public DefaultDatabaseDiagnosticsCollector(
            @Nonnull final DatabaseMonitorConfiguration configuration,
            @Nonnull final DatabasePoolDiagnosticProvider databasePoolDiagnosticProvider,
            @Nonnull final Clock clock,
            @Nonnull final DatabaseMonitor databaseMonitor,
            @Nonnull final PluginFinder pluginFinder
    ) {
        this.configuration = configuration;
        this.databasePoolDiagnosticProvider = databasePoolDiagnosticProvider;
        this.databaseMonitor = databaseMonitor;
        this.pluginFinder = pluginFinder;
        this.clock = clock;
        this.poolConnectionLeakTimeout = configuration.poolConnectionLeakTimeout();
        // The timeout is final, so whether tracking is enabled never changes; compute it once.
        this.connectionLeakDetectionEnabled = durationOf(poolConnectionLeakTimeout).isGreaterThan(ZERO);
        this.connectionCache = CacheBuilder.newBuilder()
                // Weak keys: tracking a connection must not, by itself, keep it reachable.
                .weakKeys()
                .maximumSize(MAX_TRACKED_CONNECTIONS)
                // expireAfterAccess rejects negative durations; when tracking is disabled nothing is ever put
                // into the cache, so fall back to ZERO instead of failing construction on such a timeout.
                .expireAfterAccess(connectionLeakDetectionEnabled ? poolConnectionLeakTimeout : ZERO)
                .removalListener(new LeakedConnectionListener())
                .build();
        // We want these cached for performance, and this enables products to figure out if they should be enabled
        // however they want.
        this.findInvoker = configuration.findStaticMethodInvoker();
        this.improvedAccuracy = configuration.staticMethodInvokerImprovedAccuracy();
        // The SecurityManager subclass is only needed for the improved-accuracy lookup; avoid instantiating a
        // deprecated SecurityManager when that path can never be taken.
        this.classContextSecurityManager = findInvoker && improvedAccuracy ? new ClassContextSecurityManager() : null;
    }

    @Override
    public boolean isEnabled() {
        return configuration.isEnabled();
    }

    /**
     * Starts (or restarts) leak tracking for {@code connection}, recording the current instant from the injected
     * clock. Does nothing when the pool connection leak timeout is not positive.
     */
    @Override
    public void trackConnection(final Connection connection) {
        if (connectionLeakDetectionEnabled) {
            connectionCache.put(connection, clock.instant());
        }
    }

    /**
     * Stops leak tracking for {@code connection}. The resulting explicit removal is not reported as a leak.
     * Does nothing when the pool connection leak timeout is not positive.
     */
    @Override
    public void removeTrackedConnection(final Connection connection) {
        if (connectionLeakDetectionEnabled) {
            connectionCache.invalidate(connection);
        }
    }

    /**
     * Executes {@code operation}, timing it under the {@code db.core.executionTime} metric and raising a
     * slow-operation alert when it took at least the configured long-running operation limit.
     * <p>
     * Exceptions thrown by the diagnostics themselves (starting or stopping the timer, raising the alert) are
     * logged and never prevent the operation from running, nor replace its result or its exception.
     *
     * @param operation the database operation to execute
     * @param sql       the SQL statement being executed, used to tag the metric and the slow-operation alert
     * @param <T>       the operation's result type
     * @return whatever {@code operation} returns
     * @throws SQLException if {@code operation} throws it
     */
    public <T> T recordExecutionTime(final SqlOperation<T> operation, final String sql) throws SQLException {
        // Started inline (not via a helper) so the call stack inspected by PluginFinder keeps its original depth.
        Ticker ticker = null;
        try {
            ticker = startTimingDatabaseOperation(sql);
        } catch (final RuntimeException exception) {
            logger.error("Something threw an exception while starting timing of a DB operation. The operation " +
                    "will still be executed, untimed. The exception was: ", exception);
        }
        // nanoTime is monotonic: unlike wall-clock time it cannot jump (or go backwards) if the system clock is
        // adjusted while the operation runs, which would otherwise produce bogus durations and alerts.
        final long startNanos = System.nanoTime();
        try {
            return operation.execute();
        } finally {
            // Keep the millisecond granularity used for the reported duration.
            final Duration elapsed = Duration.ofMillis((System.nanoTime() - startNanos) / 1_000_000L);
            try {
                if (ticker != null) {
                    ticker.close();
                }
                raiseAlertIfExecutionExceededThreshold(sql, elapsed);
            } catch (final Exception exception) {
                logger.error("Something threw an exception while completing timing of a DB operation. This logging is" +
                        " just a safety net so the application continues to work. The exception was: ", exception);
            }
        }
    }

    /**
     * Starts the {@code db.core.executionTime} long-running timer, tagged with the SQL (when present) and the key
     * of the plugin that invoked the operation (when it can be determined).
     * <p>
     * The invoking plugin is resolved, in order, from {@link PluginKeyStack}, then (when the improved-accuracy
     * lookup is enabled) from the caller class context, then (when invoker lookup is enabled and no class context
     * was obtained) from a captured stack trace.
     */
    @Nonnull
    private Ticker startTimingDatabaseOperation(@Nullable final String sql) {
        // This code block has to be done on the thread, so we minimise the cost as much as possible
        final String pluginKeyFromOsgiStack = PluginKeyStack.getFirstPluginKey();
        // Specifically getting the class context can be expensive while cache replicating in Jira.
        // The security manager is non-null exactly when findInvoker && improvedAccuracy.
        final Class<?>[] classContext = classContextSecurityManager != null
                ? classContextSecurityManager.getClassContext()
                : null;
        // Fall back to a (more expensive) stack trace only when the class context is unavailable.
        final Throwable throwable = findInvoker && classContext == null ? new Throwable() : null;

        String invokingPluginKey = pluginKeyFromOsgiStack;

        if (invokingPluginKey == null && classContext != null) {
            invokingPluginKey = pluginFinder.getInvokingPluginKeyFromClassContext(classContext);
        }

        if (invokingPluginKey == null && throwable != null) {
            invokingPluginKey = pluginFinder.getInvokingPluginKeyFromStackTrace(throwable.getStackTrace());
        }

        return metric("db.core.executionTime")
                .optionalTag("sql", sql)
                .invokerPluginKey(invokingPluginKey)
                .withAnalytics()
                .startLongRunningTimer();
    }

    /**
     * Raises a slow-operation alert, stamped with the injected clock and naming the current thread, when
     * {@code duration} is at least the configured long-running operation limit.
     */
    private void raiseAlertIfExecutionExceededThreshold(final String sql, final Duration duration) {
        if (durationOf(duration).isGreaterThanOrEqualTo(configuration.longRunningOperationLimit())) {
            final DatabaseOperationDiagnostic diagnostic =
                    new DatabaseOperationDiagnostic(sql, duration, currentThread().getName());
            // Use the injected clock (as the leak listener does) so alert timestamps are consistent and testable.
            databaseMonitor.raiseAlertForSlowOperation(clock.instant(), diagnostic);
        }
    }

    /**
     * Reports a potential connection leak when a tracked connection expires from the cache, i.e. it was neither
     * re-tracked nor removed within the pool connection leak timeout.
     * <p>
     * Only {@link RemovalCause#EXPIRED} is reported; explicit removals, size evictions and garbage-collected keys
     * are ignored. The alert is raised only when the pool diagnostic is non-empty.
     * <p>
     * NOTE(review): per Guava's CacheBuilder documentation, expiry is processed lazily during other cache
     * operations rather than by a timer, so a leak may only be reported on a later track/remove call — confirm
     * this latency is acceptable.
     */
    private class LeakedConnectionListener implements RemovalListener<Connection, Instant> {

        @Override
        public void onRemoval(final RemovalNotification<Connection, Instant> notification) {
            if (notification.getCause() == RemovalCause.EXPIRED) {
                final DatabasePoolDiagnostic databasePoolDiagnostic = databasePoolDiagnosticProvider.getDiagnostic();
                if (!databasePoolDiagnostic.isEmpty()) {
                    // notification.getValue() is the instant recorded when the connection was tracked.
                    databaseMonitor.raiseAlertForConnectionLeak(clock.instant(), notification.getValue(), databasePoolDiagnostic);
                }
            }
        }
    }

    /**
     * A simple {@link SecurityManager} that allows for accessing the class context.
     * <p>
     * Once JDK 8 support is dropped, the {@code java.lang.StackWalker} API could be used instead.
     * <p>
     * Inspired by the internals of {@link com.atlassian.diagnostics.internal.platform.plugin.PluginFinderImpl},
     * we don't want to support others using this though, especially with SecurityManager deprecated in Java 17.
     */
    private static class ClassContextSecurityManager extends SecurityManager {
        @Nullable
        @Override
        protected Class<?>[] getClassContext() {
            return super.getClassContext();
        }
    }
}
