/*
 * Decompiled with CFR 0.152.
 */
package software.amazon.jdbc.plugin;

import java.lang.ref.WeakReference;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.logging.Level;
import java.util.logging.Logger;
import software.amazon.jdbc.HostSpec;
import software.amazon.jdbc.PluginService;
import software.amazon.jdbc.util.Messages;
import software.amazon.jdbc.util.RdsUtils;
import software.amazon.jdbc.util.StringUtils;
import software.amazon.jdbc.util.SynchronousExecutor;
import software.amazon.jdbc.util.telemetry.TelemetryContext;
import software.amazon.jdbc.util.telemetry.TelemetryFactory;
import software.amazon.jdbc.util.telemetry.TelemetryTraceLevel;

/**
 * Tracks opened connections, grouped by host key, so that every connection to a given host can later be
 * aborted in bulk.
 *
 * <p>Tracking state is static and therefore shared by all instances of this class. It is held in a
 * {@link ConcurrentHashMap} of {@link ConcurrentLinkedQueue}s of {@link WeakReference}s, so tracking a connection
 * never prevents it from being garbage-collected.
 */
public class OpenedConnectionTracker {
    /** Tracked connections keyed by host alias (normally an RDS instance endpoint, optionally with port). */
    static final Map<String, Queue<WeakReference<Connection>>> openedConnections =
            new ConcurrentHashMap<String, Queue<WeakReference<Connection>>>();
    private static final String TELEMETRY_INVALIDATE_CONNECTIONS = "invalidate connections";
    // Daemon threads: pending invalidation work never keeps the JVM from exiting.
    private static final ExecutorService invalidateConnectionsExecutorService = Executors.newCachedThreadPool(r -> {
        Thread invalidateThread = new Thread(r);
        invalidateThread.setDaemon(true);
        return invalidateThread;
    });
    private static final Executor abortConnectionExecutor = new SynchronousExecutor();
    private static final Logger LOGGER = Logger.getLogger(OpenedConnectionTracker.class.getName());
    private static final RdsUtils rdsUtils = new RdsUtils();
    // Only connections whose class (simple or fully-qualified name) is listed here have isClosed() consulted by
    // pruneNullConnections(); any other connection is pruned only once its weak reference has been cleared.
    private static final Set<String> safeToCheckClosedClasses = new HashSet<String>(Arrays.asList(
            "HikariProxyConnection",
            "org.postgresql.jdbc.PgConnection",
            "com.mysql.cj.jdbc.ConnectionImpl",
            "org.mariadb.jdbc.Connection"));
    private final PluginService pluginService;

    /**
     * @param pluginService service supplying the telemetry factory used when invalidating connections
     */
    public OpenedConnectionTracker(PluginService pluginService) {
        this.pluginService = pluginService;
    }

    /**
     * Starts tracking {@code conn} under the key(s) identifying the host it was opened to.
     *
     * <p>If the host itself is an RDS instance endpoint, the connection is tracked under its host-and-port.
     * Otherwise, if some alias is an RDS instance endpoint, it is tracked under that alias (see
     * {@link #findInstanceEndpoint(Collection)}). If no instance endpoint can be determined, the connection is
     * tracked under every alias.
     *
     * @param hostSpec the host the connection was opened to
     * @param conn     the connection to track
     */
    public void populateOpenedConnectionQueue(HostSpec hostSpec, Connection conn) {
        Set<String> aliases = hostSpec.asAliases();
        if (rdsUtils.isRdsInstance(hostSpec.getHost())) {
            this.trackConnection(hostSpec.getHostAndPort(), conn);
        } else {
            String instanceEndpoint = findInstanceEndpoint(aliases);
            if (instanceEndpoint != null) {
                this.trackConnection(instanceEndpoint, conn);
            } else {
                for (String alias : aliases) {
                    this.trackConnection(alias, conn);
                }
            }
        }
        this.logOpenedConnections();
    }

    /**
     * Invalidates every connection tracked under the host's primary alias or any of its other aliases.
     *
     * @param hostSpec the host whose connections should be aborted
     */
    public void invalidateAllConnections(HostSpec hostSpec) {
        this.invalidateAllConnections(hostSpec.asAlias());
        this.invalidateAllConnections(hostSpec.getAliases().toArray(new String[0]));
    }

    /**
     * Aborts every connection tracked under any of {@code keys}.
     *
     * <p>The abort work runs asynchronously on a background executor. Each affected queue is drained, so aborted
     * connections are no longer tracked. Processing is best-effort: a failure for one key is logged at FINEST and
     * does not prevent the remaining keys from being processed.
     *
     * @param keys tracking keys (host aliases) whose connections should be aborted
     */
    public void invalidateAllConnections(String ... keys) {
        TelemetryFactory telemetryFactory = this.pluginService.getTelemetryFactory();
        TelemetryContext telemetryContext = telemetryFactory.openTelemetryContext(
                TELEMETRY_INVALIDATE_CONNECTIONS, TelemetryTraceLevel.NESTED);
        try {
            for (String key : keys) {
                try {
                    Queue<WeakReference<Connection>> connectionQueue = openedConnections.get(key);
                    this.logConnectionQueue(key, connectionQueue);
                    this.invalidateConnections(connectionQueue);
                } catch (Exception ex) {
                    // Deliberately best-effort: keep going with the remaining keys.
                    LOGGER.log(Level.FINEST, ex, () -> "Failed to invalidate connections tracked under " + key);
                }
            }
        } finally {
            telemetryContext.closeContext();
        }
    }

    /**
     * Stops tracking {@code connection} for the given host.
     *
     * <p>The lookup key is chosen with the same instance-endpoint rule as
     * {@link #populateOpenedConnectionQueue(HostSpec, Connection)}, so the connection is removed from the queue it
     * was added to. Nothing happens if no instance-endpoint key can be determined.
     *
     * @param hostSpec   the host the connection was opened to
     * @param connection the connection to stop tracking
     */
    public void removeConnectionTracking(HostSpec hostSpec, Connection connection) {
        String host = rdsUtils.isRdsInstance(hostSpec.getHost())
                ? hostSpec.asAlias()
                : findInstanceEndpoint(hostSpec.getAliases());
        if (StringUtils.isNullOrEmpty(host)) {
            return;
        }
        Queue<WeakReference<Connection>> connectionQueue = openedConnections.get(host);
        if (connectionQueue != null) {
            this.logConnectionQueue(host, connectionQueue);
            connectionQueue.removeIf(ref -> Objects.equals(ref.get(), connection));
        }
    }

    /**
     * Selects the tracking key among {@code aliases}. The key is the greatest alias (case-insensitive ordering)
     * whose host part, port removed, is an RDS instance endpoint. A total ordering is used so the choice is
     * deterministic regardless of the set's iteration order.
     *
     * @param aliases candidate aliases
     * @return the chosen alias, or {@code null} if none is an RDS instance endpoint
     */
    private static String findInstanceEndpoint(Collection<String> aliases) {
        return aliases.stream()
                .filter(alias -> rdsUtils.isRdsInstance(rdsUtils.removePort(alias)))
                .max(String::compareToIgnoreCase)
                .orElse(null);
    }

    /** Appends a weak reference to {@code connection} to the queue for {@code instanceEndpoint}, creating it if needed. */
    private void trackConnection(String instanceEndpoint, Connection connection) {
        Queue<WeakReference<Connection>> connectionQueue =
                openedConnections.computeIfAbsent(instanceEndpoint, k -> new ConcurrentLinkedQueue<>());
        connectionQueue.add(new WeakReference<>(connection));
    }

    /**
     * Asynchronously drains {@code connectionQueue}, aborting every connection that is still reachable.
     * A failure to abort one connection is logged and does not stop the remaining ones from being aborted.
     */
    private void invalidateConnections(Queue<WeakReference<Connection>> connectionQueue) {
        if (connectionQueue == null || connectionQueue.isEmpty()) {
            return;
        }
        invalidateConnectionsExecutorService.submit(() -> {
            WeakReference<Connection> connReference;
            while ((connReference = connectionQueue.poll()) != null) {
                Connection conn = connReference.get();
                if (conn == null) {
                    continue;
                }
                try {
                    conn.abort(abortConnectionExecutor);
                } catch (SQLException | RuntimeException ex) {
                    // The connection is being discarded anyway; keep draining so the others are still aborted.
                    LOGGER.log(Level.FINEST, ex, () -> "Failed to abort connection " + conn);
                }
            }
        });
    }

    /** Logs, at FINEST, every non-empty tracked queue and its connections. The message is only built when enabled. */
    public void logOpenedConnections() {
        LOGGER.finest(() -> {
            StringBuilder builder = new StringBuilder();
            openedConnections.forEach((key, queue) -> {
                if (!queue.isEmpty()) {
                    builder.append("\t");
                    builder.append(key).append(" :");
                    builder.append("\n\t{");
                    for (WeakReference<Connection> connection : queue) {
                        builder.append("\n\t\t").append(connection.get());
                    }
                    builder.append("\n\t}\n");
                }
            });
            return String.format("Opened Connections Tracked: \n[\n%s\n]", builder);
        });
    }

    /** Logs, at FINEST, the connections in {@code queue}; skipped for a null/empty queue or when FINEST is disabled. */
    private void logConnectionQueue(String host, Queue<WeakReference<Connection>> queue) {
        if (queue == null || queue.isEmpty()) {
            return;
        }
        LOGGER.finest(() -> {
            StringBuilder builder = new StringBuilder();
            builder.append(host).append("\n[");
            for (WeakReference<Connection> weakReference : queue) {
                builder.append("\n\t").append(weakReference.get());
            }
            builder.append("\n]");
            return Messages.get("OpenedConnectionTracker.invalidatingConnections", new Object[]{builder.toString()});
        });
    }

    /**
     * Removes queue entries whose connection was garbage-collected. Also removes entries whose connection is closed,
     * but only for connection classes listed in {@code safeToCheckClosedClasses}. Empty queues are kept in the map.
     */
    public void pruneNullConnections() {
        openedConnections.forEach((key, queue) -> queue.removeIf(OpenedConnectionTracker::isPrunable));
    }

    /** Returns whether a tracked reference can be dropped: it was cleared, or it points to a connection known closed. */
    private static boolean isPrunable(WeakReference<Connection> connectionWeakReference) {
        Connection conn = connectionWeakReference.get();
        if (conn == null) {
            return true;
        }
        Class<?> connClass = conn.getClass();
        if (!safeToCheckClosedClasses.contains(connClass.getSimpleName())
                && !safeToCheckClosedClasses.contains(connClass.getName())) {
            return false;
        }
        try {
            return conn.isClosed();
        } catch (SQLException ex) {
            // State unknown: keep tracking it.
            return false;
        }
    }

    /** Drops all tracking state (shared across every tracker instance). Tracked connections are not aborted. */
    public static void clearCache() {
        openedConnections.clear();
    }
}

