/*
 * Decompiled with CFR 0.152.
 */
package org.terracotta.dynamic_config.cli.api.restart;

import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import java.time.Duration;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiConsumer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.terracotta.diagnostic.client.DiagnosticService;
import org.terracotta.diagnostic.client.connection.ConcurrencySizing;
import org.terracotta.diagnostic.client.connection.DiagnosticServiceProvider;
import org.terracotta.diagnostic.client.connection.DiagnosticServiceProviderException;
import org.terracotta.diagnostic.common.DiagnosticException;
import org.terracotta.diagnostic.model.LogicalServerState;
import org.terracotta.dynamic_config.api.model.Node;
import org.terracotta.dynamic_config.api.service.DynamicConfigService;
import org.terracotta.dynamic_config.cli.api.restart.RestartProgress;

/**
 * Asks a set of nodes to restart themselves and then tracks, on background threads, when each
 * node has stopped and come back in one of the accepted logical states.
 *
 * <p>The tracking is exposed to callers through the {@link RestartProgress} returned by the
 * {@code restartNodes*} methods. Background polling only stops once one of the
 * {@code await} methods of that progress object returns.
 */
public class RestartService {
    private static final Logger LOGGER = LoggerFactory.getLogger(RestartService.class);

    /** Pause between two consecutive polls of the same node, in milliseconds. */
    private static final long POLL_INTERVAL_MILLIS = 500L;

    /** Maximum time to wait for the polling threads to exit when shutting down, in seconds. */
    private static final long SHUTDOWN_TIMEOUT_SECONDS = 30L;

    private final DiagnosticServiceProvider diagnosticServiceProvider;
    private final ConcurrencySizing concurrencySizing;

    /**
     * @param diagnosticServiceProvider used to open diagnostic connections to the nodes; must not be null
     * @param concurrencySizing used to size the polling thread pool; must not be null
     * @throws NullPointerException if any argument is null
     */
    public RestartService(DiagnosticServiceProvider diagnosticServiceProvider, ConcurrencySizing concurrencySizing) {
        this.diagnosticServiceProvider = Objects.requireNonNull(diagnosticServiceProvider);
        this.concurrencySizing = Objects.requireNonNull(concurrencySizing);
    }

    /**
     * Requests a restart of every given node through {@code DynamicConfigService.restart}.
     *
     * @param endpoints nodes to restart
     * @param restartDelay delay passed to each node's restart request; must be at least 1 second
     * @param acceptedStates logical states in which a node is considered restarted
     * @return a handle to follow the restart progress and read the request failures
     * @throws IllegalArgumentException if {@code restartDelay} is shorter than 1 second
     */
    public RestartProgress restartNodes(Collection<Node.Endpoint> endpoints, Duration restartDelay, Collection<LogicalServerState> acceptedStates) {
        return this.restartNodes(endpoints, restartDelay, acceptedStates, DynamicConfigService::restart);
    }

    /**
     * Same as {@link #restartNodes(Collection, Duration, Collection)} but the request sent to each
     * node is {@code DynamicConfigService.restartIfActive}.
     *
     * @param endpoints nodes to send the request to
     * @param restartDelay delay passed to each node's restart request; must be at least 1 second
     * @param acceptedStates logical states in which a node is considered restarted
     * @return a handle to follow the restart progress and read the request failures
     * @throws IllegalArgumentException if {@code restartDelay} is shorter than 1 second
     */
    public RestartProgress restartNodesIfActives(Collection<Node.Endpoint> endpoints, Duration restartDelay, Collection<LogicalServerState> acceptedStates) {
        return this.restartNodes(endpoints, restartDelay, acceptedStates, DynamicConfigService::restartIfActive);
    }

    /**
     * Same as {@link #restartNodes(Collection, Duration, Collection)} but the request sent to each
     * node is {@code DynamicConfigService.restartIfPassive}.
     *
     * @param endpoints nodes to send the request to
     * @param restartDelay delay passed to each node's restart request; must be at least 1 second
     * @param acceptedStates logical states in which a node is considered restarted
     * @return a handle to follow the restart progress and read the request failures
     * @throws IllegalArgumentException if {@code restartDelay} is shorter than 1 second
     */
    public RestartProgress restartNodesIfPassives(Collection<Node.Endpoint> endpoints, Duration restartDelay, Collection<LogicalServerState> acceptedStates) {
        return this.restartNodes(endpoints, restartDelay, acceptedStates, DynamicConfigService::restartIfPassive);
    }

    /**
     * Sends the given restart request to every endpoint, then starts one background task per
     * contacted node that waits for the node to disconnect and to come back in an accepted state.
     *
     * <p>Endpoints for which the connection or the request failed are reported through
     * {@link RestartProgress#getErrors()}. A node whose connection was opened is still tracked
     * even if the request itself threw, as in the original implementation.
     */
    private RestartProgress restartNodes(Collection<Node.Endpoint> endpoints, Duration restartDelay, Collection<LogicalServerState> acceptedStates, BiConsumer<DynamicConfigService, Duration> restart) {
        if (restartDelay.getSeconds() < 1L) {
            throw new IllegalArgumentException("Restart delay must be at least 1 second");
        }
        LOGGER.debug("Asking all nodes: {} to restart themselves", endpoints);

        final Map<Node.Endpoint, DiagnosticService> restartRequested = new HashMap<>();
        final Map<Node.Endpoint, Exception> restartRequestFailed = new HashMap<>();
        for (Node.Endpoint endpoint : endpoints) {
            try {
                // The connection is deliberately kept open: its disconnection is what tells us
                // that the node has stopped. It is closed by the polling task below.
                DiagnosticService diagnosticService = this.diagnosticServiceProvider.fetchDiagnosticService(endpoint.getHostPort().createInetSocketAddress());
                restartRequested.put(endpoint, diagnosticService);
                restart.accept((DynamicConfigService) diagnosticService.getProxy(DynamicConfigService.class), restartDelay);
            } catch (Exception e) {
                restartRequestFailed.put(endpoint, e);
                LOGGER.debug("Failed asking node {} to restart: {}", endpoint, e.getMessage(), e);
            }
        }

        final CountDownLatch done = new CountDownLatch(restartRequested.size());
        final Map<Node.Endpoint, LogicalServerState> restartedNodes = new ConcurrentHashMap<>();
        final AtomicReference<BiConsumer<Node.Endpoint, LogicalServerState>> progressCallback = new AtomicReference<>();
        final AtomicBoolean continuePolling = new AtomicBoolean(true);
        final ExecutorService executorService = Executors.newFixedThreadPool(this.concurrencySizing.getThreadCount(endpoints.size()), r -> new Thread(r, this.getClass().getName()));

        restartRequested.forEach((endpoint, diagnosticService) -> executorService.submit(() -> {
            try {
                // Phase 1: wait until the node drops the connection used to request the restart.
                while (continuePolling.get() && !Thread.currentThread().isInterrupted() && diagnosticService.isConnected()) {
                    try {
                        LOGGER.debug("Waiting for node: {} to stop...", endpoint);
                        Thread.sleep(POLL_INTERVAL_MILLIS);
                    } catch (InterruptedException e) {
                        // Restore the flag so that the loop condition ends the wait.
                        Thread.currentThread().interrupt();
                    }
                }
                if (diagnosticService.isConnected()) {
                    // Polling was cancelled or the thread interrupted before the node stopped.
                    LOGGER.warn("Restart of node: {} has been interrupted", endpoint);
                    return;
                }
                LOGGER.debug("Node: {} has stopped", endpoint);
                LOGGER.debug("Waiting for node: {} to restart...", endpoint);

                // Phase 2: poll with fresh connections until the node reports an accepted state.
                LogicalServerState state = null;
                while (state == null && continuePolling.get() && !Thread.currentThread().isInterrupted()) {
                    try {
                        state = this.isRestarted(endpoint, acceptedStates);
                        if (state == null) {
                            Thread.sleep(POLL_INTERVAL_MILLIS);
                        }
                    } catch (InterruptedException e) {
                        Thread.currentThread().interrupt();
                    }
                }
                if (state != null) {
                    LOGGER.debug("Node: {} has restarted", endpoint);
                    restartedNodes.put(endpoint, state);
                    try {
                        BiConsumer<Node.Endpoint, LogicalServerState> callback = progressCallback.get();
                        if (callback != null) {
                            callback.accept(endpoint, state);
                        }
                    } finally {
                        // Always release the latch: a failing callback must not make await() hang
                        // for a node that did restart.
                        done.countDown();
                    }
                }
            } catch (RuntimeException e) {
                // The Future returned by submit() is discarded, so log here or the failure is lost.
                LOGGER.error("Unexpected error while tracking restart of node: {}", endpoint, e);
            } finally {
                this.closeQuietly(endpoint, diagnosticService);
            }
        }));

        return new RestartProgress() {

            /** Blocks until every tracked node has restarted, then stops the background polling. */
            @Override
            public void await() throws InterruptedException {
                try {
                    done.await();
                } finally {
                    continuePolling.set(false);
                    RestartService.this.shutdown(executorService);
                }
            }

            /**
             * Waits at most {@code duration} for the tracked nodes, then stops the background
             * polling and returns a snapshot of the nodes seen restarted so far.
             */
            @Override
            @SuppressFBWarnings(value={"RV_RETURN_VALUE_IGNORED"})
            public Map<Node.Endpoint, LogicalServerState> await(Duration duration) throws InterruptedException {
                try {
                    // The boolean result is ignored on purpose: callers inspect the returned map.
                    done.await(duration.toMillis(), TimeUnit.MILLISECONDS);
                    return new HashMap<>(restartedNodes);
                } finally {
                    continuePolling.set(false);
                    RestartService.this.shutdown(executorService);
                }
            }

            /**
             * Registers the callback and immediately replays the nodes already restarted.
             * A node that restarts while this method runs may be reported twice, once by the
             * polling thread and once by the replay.
             */
            @Override
            public void onRestarted(BiConsumer<Node.Endpoint, LogicalServerState> c) {
                progressCallback.set(c);
                restartedNodes.forEach(c);
            }

            /** Returns the endpoints whose restart request failed, with the failure cause. */
            @Override
            public Map<Node.Endpoint, Exception> getErrors() {
                return restartRequestFailed;
            }
        };
    }

    /** Closes the connection that was used to request a restart, logging instead of propagating any failure. */
    private void closeQuietly(Node.Endpoint endpoint, DiagnosticService diagnosticService) {
        try {
            diagnosticService.close();
        } catch (Exception e) {
            LOGGER.debug("Failed closing diagnostic connection to node: {}: {}", endpoint, e.getMessage(), e);
        }
    }

    /**
     * Interrupts the polling threads and waits up to {@value #SHUTDOWN_TIMEOUT_SECONDS} seconds
     * for them to exit. If the calling thread is interrupted meanwhile, its interrupt flag is restored.
     */
    private void shutdown(ExecutorService executorService) {
        executorService.shutdownNow();
        try {
            if (!executorService.awaitTermination(SHUTDOWN_TIMEOUT_SECONDS, TimeUnit.SECONDS)) {
                executorService.shutdownNow();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Opens a short-lived diagnostic connection to the node and reads its logical state.
     *
     * @return the node state if it is one of {@code acceptedStates}; {@code null} if the state is
     *         null, not accepted, or if the query failed for any reason
     */
    private LogicalServerState isRestarted(Node.Endpoint endpoint, Collection<LogicalServerState> acceptedStates) {
        LOGGER.debug("Checking if node: {} has restarted", endpoint);
        try (DiagnosticService diagnosticService = this.diagnosticServiceProvider.fetchDiagnosticService(endpoint.getHostPort().createInetSocketAddress())) {
            LogicalServerState state = diagnosticService.getLogicalServerState();
            return state != null && acceptedStates.contains(state) ? state : null;
        } catch (DiagnosticServiceProviderException | DiagnosticException e) {
            // Expected while the node is still down or starting: keep polling.
            LOGGER.debug("Status query for node: {} failed: {}", endpoint, e.getMessage());
            return null;
        } catch (Exception e) {
            LOGGER.error("Unexpected error during status query for node: {}", endpoint, e);
            return null;
        }
    }
}

