/*
 * Decompiled with CFR 0.152.
 */
package com.alipay.sofa.jraft.rpc.impl;

import com.alipay.sofa.jraft.ReplicatorGroup;
import com.alipay.sofa.jraft.entity.PeerId;
import com.alipay.sofa.jraft.error.InvokeTimeoutException;
import com.alipay.sofa.jraft.error.RemotingException;
import com.alipay.sofa.jraft.option.RpcOptions;
import com.alipay.sofa.jraft.rpc.InvokeCallback;
import com.alipay.sofa.jraft.rpc.InvokeContext;
import com.alipay.sofa.jraft.rpc.RpcClient;
import com.alipay.sofa.jraft.rpc.RpcUtils;
import com.alipay.sofa.jraft.rpc.impl.ManagedChannelHelper;
import com.alipay.sofa.jraft.rpc.impl.MarshallerRegistry;
import com.alipay.sofa.jraft.util.DirectExecutor;
import com.alipay.sofa.jraft.util.Endpoint;
import com.alipay.sofa.jraft.util.Requires;
import com.alipay.sofa.jraft.util.SystemPropertyUtil;
import com.google.protobuf.Message;
import io.grpc.CallOptions;
import io.grpc.ClientCall;
import io.grpc.ConnectivityState;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.MethodDescriptor;
import io.grpc.protobuf.ProtoUtils;
import io.grpc.stub.ClientCalls;
import io.grpc.stub.StreamObserver;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link RpcClient} implementation backed by gRPC {@link ManagedChannel}s.
 *
 * <p>One plaintext channel is cached per remote {@link Endpoint}. Every request is sent as a
 * unary call whose full method name is derived from the request message's class name; the
 * request marshaller uses the prototype registered in {@code parserClasses}, and the response
 * marshaller uses the prototype returned by the {@link MarshallerRegistry}.
 */
public class GrpcClient implements RpcClient {

    private static final Logger LOG = LoggerFactory.getLogger(GrpcClient.class);

    /**
     * Number of TRANSIENT_FAILURE observations made by {@link #checkConnection} after which the
     * channel's reconnect backoff is reset. Read from the
     * {@code jraft.grpc.max.connect.failures} system property, defaulting to 20.
     */
    private static final int MAX_FAILURES = SystemPropertyUtil.getInt("jraft.grpc.max.connect.failures", 20);

    /** Cached channels, one per remote endpoint. */
    private final Map<Endpoint, ManagedChannel> managedChannelPool = new ConcurrentHashMap<>();
    /**
     * Per-endpoint count of TRANSIENT_FAILURE observations; reset once it exceeds
     * {@link #MAX_FAILURES} or when the endpoint's channel is closed.
     */
    private final Map<Endpoint, AtomicInteger> transientFailures = new ConcurrentHashMap<>();
    /** Default request instances keyed by request class name. */
    private final Map<String, Message> parserClasses;
    /** Supplies the default response instance for a given request class name. */
    private final MarshallerRegistry marshallerRegistry;
    /** Notified (via {@code checkReplicator}) when a channel becomes READY; may be null. */
    private volatile ReplicatorGroup replicatorGroup;

    public GrpcClient(final Map<String, Message> parserClasses, final MarshallerRegistry marshallerRegistry) {
        this.parserClasses = parserClasses;
        this.marshallerRegistry = marshallerRegistry;
    }

    /**
     * No-op: channels are created lazily on first use.
     *
     * @param opts ignored
     * @return always {@code true}
     */
    public boolean init(final RpcOptions opts) {
        return true;
    }

    /** Shuts down and forgets every cached channel, and drops all failure counters. */
    public void shutdown() {
        closeAllChannels();
        this.transientFailures.clear();
    }

    /**
     * Checks an existing connection without creating one.
     *
     * @param endpoint remote endpoint, must not be null
     * @return {@code true} if a cached channel exists and is neither TRANSIENT_FAILURE nor SHUTDOWN
     */
    public boolean checkConnection(final Endpoint endpoint) {
        return checkConnection(endpoint, false);
    }

    /**
     * Checks the connection to {@code endpoint}, optionally creating the channel first.
     *
     * @param endpoint       remote endpoint, must not be null
     * @param createIfAbsent create and cache a channel when none exists yet
     * @return {@code true} if a channel exists and is neither TRANSIENT_FAILURE nor SHUTDOWN
     */
    public boolean checkConnection(final Endpoint endpoint, final boolean createIfAbsent) {
        Requires.requireNonNull(endpoint, "endpoint");
        return checkChannel(endpoint, createIfAbsent);
    }

    /**
     * Removes the cached channel for {@code endpoint} (if any) and shuts it down.
     *
     * @param endpoint remote endpoint, must not be null
     */
    public void closeConnection(final Endpoint endpoint) {
        Requires.requireNonNull(endpoint, "endpoint");
        closeChannel(endpoint);
    }

    /**
     * Sets the group whose {@code checkReplicator} is invoked when a channel becomes READY.
     *
     * @param replicatorGroup listener, may be null to disable notifications
     */
    public void registerConnectEventListener(final ReplicatorGroup replicatorGroup) {
        this.replicatorGroup = replicatorGroup;
    }

    /**
     * Sends {@code request} and blocks until the response arrives or {@code timeoutMs} elapses.
     *
     * @return the response message
     * @throws InvokeTimeoutException if no response arrives within {@code timeoutMs}
     * @throws RemotingException      if the call fails or the waiting thread is interrupted
     *                                (the interrupt flag is restored in that case)
     */
    public Object invokeSync(final Endpoint endpoint, final Object request, final InvokeContext ctx,
                             final long timeoutMs) throws RemotingException {
        final CompletableFuture<Object> future = new CompletableFuture<>();

        invokeAsync(endpoint, request, ctx, (result, err) -> {
            if (err == null) {
                future.complete(result);
            } else {
                future.completeExceptionally(err);
            }
        }, timeoutMs);

        try {
            return future.get(timeoutMs, TimeUnit.MILLISECONDS);
        } catch (final TimeoutException e) {
            future.cancel(true);
            throw new InvokeTimeoutException(e);
        } catch (final InterruptedException e) {
            future.cancel(true);
            // Restore the interrupt flag so callers further up the stack can still observe it.
            Thread.currentThread().interrupt();
            throw new RemotingException(e);
        } catch (final Throwable t) {
            future.cancel(true);
            throw new RemotingException(t);
        }
    }

    /**
     * Sends {@code request} as a unary gRPC call with a deadline of {@code timeoutMs}.
     *
     * <p>The callback is completed exactly once, with either the response or the error, on
     * {@code callback.executor()} when it is non-null, otherwise on {@link DirectExecutor}.
     *
     * @param endpoint  remote endpoint, must not be null
     * @param request   protobuf request message, must not be null
     * @param ctx       unused by this implementation
     * @param callback  completion callback, must not be null
     * @param timeoutMs call deadline in milliseconds
     */
    public void invokeAsync(final Endpoint endpoint, final Object request, final InvokeContext ctx,
                            final InvokeCallback callback, final long timeoutMs) {
        Requires.requireNonNull(endpoint, "endpoint");
        Requires.requireNonNull(request, "request");
        Requires.requireNonNull(callback, "callback");

        final ManagedChannel ch = getChannel(endpoint);
        final MethodDescriptor<Message, Message> method = getCallMethod(request);
        final CallOptions callOpts = CallOptions.DEFAULT.withDeadlineAfter(timeoutMs, TimeUnit.MILLISECONDS);
        final Executor executor = callback.executor() != null ? callback.executor() : DirectExecutor.INSTANCE;

        ClientCalls.asyncUnaryCall(ch.newCall(method, callOpts), (Message) request, new StreamObserver<Message>() {

            @Override
            public void onNext(final Message value) {
                executor.execute(() -> callback.complete(value, null));
            }

            @Override
            public void onError(final Throwable throwable) {
                executor.execute(() -> callback.complete(null, throwable));
            }

            @Override
            public void onCompleted() {
                // Unary call: the single response has already been delivered through onNext.
            }
        });
    }

    /** Builds the unary method descriptor for {@code request}'s class, named "{class}/_call". */
    private MethodDescriptor<Message, Message> getCallMethod(final Object request) {
        final String interest = request.getClass().getName();
        final Message reqIns = Requires.requireNonNull(this.parserClasses.get(interest),
            "null default instance: " + interest);
        final Message respIns = Requires.requireNonNull(
            this.marshallerRegistry.findResponseInstanceByRequest(interest), "null response instance: " + interest);

        return MethodDescriptor.<Message, Message>newBuilder()
            .setType(MethodDescriptor.MethodType.UNARY)
            .setFullMethodName(MethodDescriptor.generateFullMethodName(interest, "_call"))
            .setRequestMarshaller(ProtoUtils.marshaller(reqIns))
            .setResponseMarshaller(ProtoUtils.marshaller(respIns))
            .build();
    }

    /** Returns the cached channel for {@code endpoint}, creating and caching it if absent. */
    private ManagedChannel getChannel(final Endpoint endpoint) {
        return this.managedChannelPool.computeIfAbsent(endpoint, this::newChannel);
    }

    /** Builds a plaintext channel to {@code endpoint} and starts watching its connectivity. */
    private ManagedChannel newChannel(final Endpoint endpoint) {
        final ManagedChannel ch = ManagedChannelBuilder.forAddress(endpoint.getIp(), endpoint.getPort())
            .usePlaintext()
            .directExecutor()
            .build();
        // notifyWhenStateChanged fires when the channel LEAVES the given state, so watch from the
        // state the channel is in right now rather than from the state we are interested in.
        watchState(endpoint, ch, ch.getState(false));
        return ch;
    }

    /** Registers a one-shot callback that runs once {@code ch} leaves {@code source}. */
    private void watchState(final Endpoint endpoint, final ManagedChannel ch, final ConnectivityState source) {
        ch.notifyWhenStateChanged(source, () -> onStateChanged(endpoint, ch));
    }

    /** Reacts to the state {@code ch} has just entered, then re-arms the watcher. */
    private void onStateChanged(final Endpoint endpoint, final ManagedChannel ch) {
        final ConnectivityState state = ch.getState(false);
        switch (state) {
            case READY:
                notifyReady(endpoint);
                break;
            case TRANSIENT_FAILURE:
                LOG.warn("Channel in TRANSIENT_FAILURE state: {}.", endpoint);
                break;
            case SHUTDOWN:
                LOG.warn("Channel in SHUTDOWN state: {}.", endpoint);
                // SHUTDOWN is terminal; there is no further transition to watch for.
                return;
            default:
                break;
        }
        // State callbacks are one-shot, so register again for the next transition.
        watchState(endpoint, ch, state);
    }

    /** Asks the registered replicator group, if any, to check the peer at {@code endpoint}. */
    private void notifyReady(final Endpoint endpoint) {
        final ReplicatorGroup rpGroup = this.replicatorGroup;
        if (rpGroup == null) {
            return;
        }
        try {
            RpcUtils.runInThread(() -> {
                final PeerId peer = new PeerId();
                if (peer.parse(endpoint.toString())) {
                    LOG.info("Peer {} is connected.", peer);
                    rpGroup.checkReplicator(peer, true);
                } else {
                    LOG.error("Fail to parse peer: {}.", endpoint);
                }
            });
        } catch (final Throwable t) {
            LOG.error("Fail to check replicator {}.", endpoint, t);
        }
    }

    /** Removes every cached channel from the pool and shuts it down. */
    private void closeAllChannels() {
        // ConcurrentHashMap iteration is weakly consistent, so removing while iterating is safe.
        for (final Map.Entry<Endpoint, ManagedChannel> entry : this.managedChannelPool.entrySet()) {
            final Endpoint endpoint = entry.getKey();
            final ManagedChannel ch = entry.getValue();
            // Only shut down channels we actually removed, so a concurrent closeChannel
            // does not make us shut the same channel down twice.
            if (this.managedChannelPool.remove(endpoint, ch)) {
                LOG.info("Shutdown managed channel: {}, {}.", endpoint, ch);
                ManagedChannelHelper.shutdownAndAwaitTermination(ch);
            }
        }
    }

    /** Removes and shuts down the channel for {@code endpoint}, and drops its failure counter. */
    private void closeChannel(final Endpoint endpoint) {
        final ManagedChannel ch = this.managedChannelPool.remove(endpoint);
        this.transientFailures.remove(endpoint);
        LOG.info("Close connection: {}, {}.", endpoint, ch);
        if (ch != null) {
            ManagedChannelHelper.shutdownAndAwaitTermination(ch);
        }
    }

    /**
     * Reports whether the channel for {@code endpoint} is usable.
     *
     * <p>Each TRANSIENT_FAILURE observation is counted. Once the count exceeds
     * {@link #MAX_FAILURES}, the counter is cleared and the channel's reconnect backoff is reset.
     */
    private boolean checkChannel(final Endpoint endpoint, final boolean createIfAbsent) {
        ManagedChannel ch = this.managedChannelPool.get(endpoint);
        if (ch == null && createIfAbsent) {
            ch = getChannel(endpoint);
        }
        if (ch == null) {
            return false;
        }

        // getState(true) also asks an IDLE channel to start connecting.
        final ConnectivityState st = ch.getState(true);
        if (st == ConnectivityState.TRANSIENT_FAILURE) {
            final AtomicInteger num = this.transientFailures.computeIfAbsent(endpoint, ep -> new AtomicInteger());
            if (num.incrementAndGet() > MAX_FAILURES) {
                this.transientFailures.remove(endpoint);
                LOG.warn("Channel[{}] in {} state {} times, will be reset connect backoff.", endpoint, st, num.get());
                ch.resetConnectBackoff();
            }
        }
        return st != ConnectivityState.TRANSIENT_FAILURE && st != ConnectivityState.SHUTDOWN;
    }
}

