/*
 * Decompiled with CFR 0.152.
 */
package org.apache.cassandra.transport;

import com.google.common.annotations.VisibleForTesting;
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.nio.ByteBuffer;
import java.util.EnumSet;
import java.util.List;
import java.util.Map;
import org.apache.cassandra.service.QueryState;
import org.apache.cassandra.service.StorageService;
import org.apache.cassandra.tracing.Tracing;
import org.apache.cassandra.transport.CBCodec;
import org.apache.cassandra.transport.CBUtil;
import org.apache.cassandra.transport.Connection;
import org.apache.cassandra.transport.Envelope;
import org.apache.cassandra.transport.ProtocolException;
import org.apache.cassandra.transport.ProtocolVersion;
import org.apache.cassandra.transport.messages.AuthChallenge;
import org.apache.cassandra.transport.messages.AuthResponse;
import org.apache.cassandra.transport.messages.AuthSuccess;
import org.apache.cassandra.transport.messages.AuthenticateMessage;
import org.apache.cassandra.transport.messages.BatchMessage;
import org.apache.cassandra.transport.messages.ErrorMessage;
import org.apache.cassandra.transport.messages.EventMessage;
import org.apache.cassandra.transport.messages.ExecuteMessage;
import org.apache.cassandra.transport.messages.OptionsMessage;
import org.apache.cassandra.transport.messages.PrepareMessage;
import org.apache.cassandra.transport.messages.QueryMessage;
import org.apache.cassandra.transport.messages.ReadyMessage;
import org.apache.cassandra.transport.messages.RegisterMessage;
import org.apache.cassandra.transport.messages.ResultMessage;
import org.apache.cassandra.transport.messages.StartupMessage;
import org.apache.cassandra.transport.messages.SupportedMessage;
import org.apache.cassandra.transport.messages.UnsupportedMessageCodec;
import org.apache.cassandra.utils.TimeUUID;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for all native-protocol messages, both requests and responses.
 *
 * <p>A message carries its {@link Type} (opcode, direction and codec), the stream id of the frame it
 * belongs to, an optional custom payload and, once decoded, the {@link Envelope} it was read from.
 * {@link #encode(ProtocolVersion)} serializes a message into an {@link Envelope}; the package-private
 * {@link Decoder}s turn an inbound {@link Envelope} back into a message.
 */
public abstract class Message {
    protected static final Logger logger = LoggerFactory.getLogger(Message.class);

    /** Opcode, direction and codec of this message. */
    public final Type type;
    /** Connection this message is attached to; may be {@code null} (e.g. before {@link #attach} is called). */
    protected Connection connection;
    /** Stream id of the frame this message belongs to. */
    private int streamId;
    /** Envelope this message was decoded from, if any. */
    private Envelope source;
    /** Optional custom payload; {@code null} when absent. */
    private Map<String, ByteBuffer> customPayload;
    /**
     * When non-null, {@link #encode(ProtocolVersion)} writes this version into the envelope header instead
     * of the version it was asked to encode with (the body is still encoded with the requested version).
     */
    protected ProtocolVersion forcedProtocolVersion = null;

    private static final Decoder.RequestDecoder REQUEST_DECODER = new Decoder.RequestDecoder();
    private static final Decoder.ResponseDecoder RESPONSE_DECODER = new Decoder.ResponseDecoder();

    protected Message(Type type) {
        this.type = type;
    }

    /** Associates this message with the given connection. */
    public void attach(Connection connection) {
        this.connection = connection;
    }

    /** Returns the attached connection, or {@code null} if none has been attached. */
    public Connection connection() {
        return this.connection;
    }

    /**
     * Sets the stream id of this message.
     *
     * @return {@code this}, for chaining
     */
    public Message setStreamId(int streamId) {
        this.streamId = streamId;
        return this;
    }

    public int getStreamId() {
        return this.streamId;
    }

    /** Records the envelope this message was decoded from. */
    public void setSource(Envelope source) {
        this.source = source;
    }

    /** Returns the envelope this message was decoded from, or {@code null} if it was not decoded. */
    public Envelope getSource() {
        return this.source;
    }

    /** Returns the custom payload, or {@code null} when there is none. */
    public Map<String, ByteBuffer> getCustomPayload() {
        return this.customPayload;
    }

    /** Sets the custom payload; {@code null} means no payload. */
    public void setCustomPayload(Map<String, ByteBuffer> customPayload) {
        this.customPayload = customPayload;
    }

    /**
     * Returns {@code (type:streamId:version)}, where {@code version} is the attached connection's protocol
     * version number, or the literal {@code null} when no connection is attached.
     */
    @Override
    public String toString() {
        Object version = this.connection == null ? "null" : Integer.valueOf(this.connection.getVersion().asInt());
        return String.format("(%s:%s:%s)", this.type, this.streamId, version);
    }

    /**
     * Serializes this message into an {@link Envelope}.
     *
     * <p>For responses, the body is prefixed, in this order, with the tracing id, the warnings and the
     * custom payload when each is present, and the matching header flag is set for each. Below protocol
     * v4, warnings are dropped (with a log warning) and a custom payload is rejected with a
     * {@link ProtocolException}. For requests, the TRACING flag is set when tracing was requested and the
     * custom payload, if any, prefixes the body.
     *
     * <p>Any failure during encoding is rethrown as the exception produced by {@link ErrorMessage#wrap},
     * tagged with this message's stream id.
     *
     * @param version protocol version used to encode the body
     * @return the encoded envelope; its header carries {@link #forcedProtocolVersion} when that is set,
     *         otherwise {@code version}
     */
    public Envelope encode(ProtocolVersion version) {
        EnumSet<Envelope.Header.Flag> flags = EnumSet.noneOf(Envelope.Header.Flag.class);
        // Type.codec is declared Codec<?>, but it is the codec for this message's own type, so viewing it
        // as a Codec<Message> is what allows `this` to be passed to it.
        @SuppressWarnings("unchecked")
        Codec<Message> codec = (Codec<Message>) this.type.codec;
        try {
            int messageSize = codec.encodedSize(this, version);
            ByteBuf body;
            if (this instanceof Response) {
                Response response = (Response) this;
                TimeUUID tracingId = response.getTracingId();
                List<String> warnings = response.getWarnings();
                Map<String, ByteBuffer> customPayload = response.getCustomPayload();

                // First pass: compute the exact body size so the buffer is allocated once.
                if (tracingId != null) {
                    messageSize += TimeUUID.sizeInBytes();
                }
                if (warnings != null) {
                    if (version.isSmallerThan(ProtocolVersion.V4)) {
                        // Warnings cannot be sent before v4: drop them rather than fail the response.
                        logger.warn("Warnings present in message with version less than v4 (it is {}); warnings={}", version, warnings);
                        warnings = null;
                    } else {
                        messageSize += CBUtil.sizeOfStringList(warnings);
                    }
                }
                if (customPayload != null) {
                    if (version.isSmallerThan(ProtocolVersion.V4)) {
                        throw new ProtocolException("Must not send frame with CUSTOM_PAYLOAD flag for native protocol version < 4");
                    }
                    messageSize += CBUtil.sizeOfBytesMap(customPayload);
                }

                // Second pass: write the optional prefix sections in the order the decoder reads them.
                body = CBUtil.allocator.buffer(messageSize);
                if (tracingId != null) {
                    CBUtil.writeUUID(tracingId, body);
                    flags.add(Envelope.Header.Flag.TRACING);
                }
                if (warnings != null) {
                    CBUtil.writeStringList(warnings, body);
                    flags.add(Envelope.Header.Flag.WARNING);
                }
                if (customPayload != null) {
                    CBUtil.writeBytesMap(customPayload, body);
                    flags.add(Envelope.Header.Flag.CUSTOM_PAYLOAD);
                }
            } else {
                assert this instanceof Request;
                if (((Request) this).isTracingRequested()) {
                    flags.add(Envelope.Header.Flag.TRACING);
                }
                Map<String, ByteBuffer> payload = this.getCustomPayload();
                if (payload != null) {
                    messageSize += CBUtil.sizeOfBytesMap(payload);
                }
                body = CBUtil.allocator.buffer(messageSize);
                if (payload != null) {
                    CBUtil.writeBytesMap(payload, body);
                    flags.add(Envelope.Header.Flag.CUSTOM_PAYLOAD);
                }
            }

            try {
                codec.encode(this, body, version);
            } catch (Throwable e) {
                // Release the buffer allocated above so it is not leaked when the codec fails.
                body.release();
                throw e;
            }

            ProtocolVersion responseVersion = this.forcedProtocolVersion == null ? version : this.forcedProtocolVersion;
            if (responseVersion.isBeta()) {
                flags.add(Envelope.Header.Flag.USE_BETA);
            }
            return Envelope.create(this.type, this.getStreamId(), responseVersion, flags, body);
        } catch (Throwable e) {
            throw ErrorMessage.wrap(e, this.getStreamId());
        }
    }

    /** Returns the shared decoder for inbound request envelopes. */
    static Decoder<Request> requestDecoder() {
        return REQUEST_DECODER;
    }

    /** Returns the shared decoder for inbound response envelopes. */
    static Decoder<Response> responseDecoder() {
        return RESPONSE_DECODER;
    }

    /**
     * Turns an inbound {@link Envelope} into a message of the expected direction.
     *
     * @param <M> {@link Request} or {@link Response}
     */
    static abstract class Decoder<M extends Message> {
        Decoder() {
        }

        /**
         * Decodes {@code inbound} into a message of whatever direction its header declares.
         *
         * <p>For responses, the tracing id and warnings (when flagged) are read from the front of the body.
         * For both directions the custom payload (when flagged) comes next, and the type's codec reads the
         * rest. Requests are attached to the {@link Connection} stored on {@code channel} and marked as
         * tracing-requested when the TRACING flag is set.
         *
         * @throws ProtocolException if the CUSTOM_PAYLOAD flag is set on a frame below protocol v4
         */
        static Message decodeMessage(Channel channel, Envelope inbound) {
            boolean isRequest = inbound.header.type.direction == Direction.REQUEST;
            boolean isTracing = inbound.header.flags.contains(Envelope.Header.Flag.TRACING);
            boolean isCustomPayload = inbound.header.flags.contains(Envelope.Header.Flag.CUSTOM_PAYLOAD);
            boolean hasWarning = inbound.header.flags.contains(Envelope.Header.Flag.WARNING);

            // Validate the header before consuming the body. Such a frame is rejected anyway, so there is
            // no point parsing (possibly malformed) body bytes that could fail with a less precise error.
            if (isCustomPayload && inbound.header.version.isSmallerThan(ProtocolVersion.V4)) {
                throw new ProtocolException("Received frame with CUSTOM_PAYLOAD flag for native protocol version < 4");
            }

            // Order matters: these optional sections precede the message body, in exactly this order.
            TimeUUID tracingId = isRequest || !isTracing ? null : CBUtil.readTimeUUID(inbound.body);
            List<String> warnings = isRequest || !hasWarning ? null : CBUtil.readStringList(inbound.body);
            Map<String, ByteBuffer> customPayload = isCustomPayload ? CBUtil.readBytesMap(inbound.body) : null;

            Message message = (Message) inbound.header.type.codec.decode(inbound.body, inbound.header.version);
            message.setStreamId(inbound.header.streamId);
            message.setSource(inbound);
            message.setCustomPayload(customPayload);

            if (isRequest) {
                assert message instanceof Request;
                Request request = (Request) message;
                Connection connection = (Connection) channel.attr(Connection.attributeKey).get();
                request.attach(connection);
                if (isTracing) {
                    request.setTracingRequested();
                }
            } else {
                assert message instanceof Response;
                Response response = (Response) message;
                if (isTracing) {
                    response.setTracingId(tracingId);
                }
                if (hasWarning) {
                    response.setWarnings(warnings);
                }
            }
            return message;
        }

        /**
         * Decodes {@code envelope}, rejecting envelopes of the wrong direction.
         *
         * @throws ProtocolException if the envelope's direction is not the one this decoder handles
         */
        abstract M decode(Channel channel, Envelope envelope);

        /** Decoder that accepts only RESPONSE envelopes. */
        private static class ResponseDecoder extends Decoder<Response> {
            private ResponseDecoder() {
            }

            @Override
            Response decode(Channel channel, Envelope response) {
                if (response.header.type.direction != Direction.RESPONSE) {
                    throw new ProtocolException(String.format("Unexpected REQUEST message %s, expecting RESPONSE", response.header.type));
                }
                return (Response) decodeMessage(channel, response);
            }
        }

        /** Decoder that accepts only REQUEST envelopes. */
        private static class RequestDecoder extends Decoder<Request> {
            private RequestDecoder() {
            }

            @Override
            Request decode(Channel channel, Envelope request) {
                if (request.header.type.direction != Direction.REQUEST) {
                    throw new ProtocolException(String.format("Unexpected RESPONSE message %s, expecting REQUEST", request.header.type));
                }
                return (Request) decodeMessage(channel, request);
            }
        }
    }

    /** A message sent from server to client; may carry a tracing id and warnings. */
    public static abstract class Response extends Message {
        /** Tracing session id to send back to the client, or {@code null}. */
        protected TimeUUID tracingId;
        /** Warnings to send back to the client, or {@code null}. */
        protected List<String> warnings;

        /**
         * @throws IllegalArgumentException if {@code type} is not a RESPONSE type
         */
        protected Response(Type type) {
            super(type);
            if (type.direction != Direction.RESPONSE) {
                throw new IllegalArgumentException();
            }
        }

        /** Sets the tracing id; returns {@code this}. */
        Message setTracingId(TimeUUID tracingId) {
            this.tracingId = tracingId;
            return this;
        }

        TimeUUID getTracingId() {
            return this.tracingId;
        }

        /** Sets the warnings; returns {@code this}. */
        public Message setWarnings(List<String> warnings) {
            this.warnings = warnings;
            return this;
        }

        public List<String> getWarnings() {
            return this.warnings;
        }
    }

    /** A message sent from client to server that the server executes to produce a {@link Response}. */
    public static abstract class Request extends Message {
        /** Whether the client set the TRACING flag on this request. */
        private boolean tracingRequested;

        /**
         * @throws IllegalArgumentException if {@code type} is not a REQUEST type
         */
        protected Request(Type type) {
            super(type);
            if (type.direction != Direction.REQUEST) {
                throw new IllegalArgumentException();
            }
        }

        /** Whether this request may be traced; {@code false} unless overridden. */
        protected boolean isTraceable() {
            return false;
        }

        /** Returns {@code false} unless overridden by a subclass. */
        protected boolean isTrackable() {
            return false;
        }

        /**
         * Performs the request-specific work.
         *
         * @param queryState         state of the query
         * @param queryStartNanoTime start time of the query, in nanoseconds
         * @param traceRequest       whether a tracing session was started for this execution
         * @return the response to send
         */
        protected abstract Response execute(QueryState queryState, long queryStartNanoTime, boolean traceRequest);

        /**
         * Executes the request, wrapping it in a tracing session when appropriate.
         *
         * <p>For traceable requests, a session is started when the client asked for tracing (the new
         * session id is then attached to the response) or when
         * {@code StorageService.instance.shouldTraceProbablistically()} says so. A session started here is
         * stopped once execution finishes, even if it throws.
         */
        public final Response execute(QueryState queryState, long queryStartNanoTime) {
            boolean shouldTrace = false;
            TimeUUID tracingSessionId = null;
            if (this.isTraceable()) {
                if (this.isTracingRequested()) {
                    shouldTrace = true;
                    tracingSessionId = TimeUUID.Generator.nextTimeUUID();
                    Tracing.instance.newSession(tracingSessionId, this.getCustomPayload());
                } else if (StorageService.instance.shouldTraceProbablistically()) {
                    shouldTrace = true;
                    Tracing.instance.newSession(this.getCustomPayload());
                }
            }

            Response response;
            try {
                response = this.execute(queryState, queryStartNanoTime, shouldTrace);
            } finally {
                if (shouldTrace) {
                    Tracing.instance.stopSession();
                }
            }

            // Only client-requested sessions report their id back; probabilistic ones do not.
            if (this.isTraceable() && this.isTracingRequested()) {
                response.setTracingId(tracingSessionId);
            }
            return response;
        }

        void setTracingRequested() {
            this.tracingRequested = true;
        }

        boolean isTracingRequested() {
            return this.tracingRequested;
        }
    }

    /**
     * Native-protocol message types, each with its opcode, the direction it travels in and the codec used
     * to (de)serialize its body.
     */
    public static enum Type {
        ERROR(0, Direction.RESPONSE, ErrorMessage.codec),
        STARTUP(1, Direction.REQUEST, StartupMessage.codec),
        READY(2, Direction.RESPONSE, ReadyMessage.codec),
        AUTHENTICATE(3, Direction.RESPONSE, AuthenticateMessage.codec),
        CREDENTIALS(4, Direction.REQUEST, UnsupportedMessageCodec.instance),
        OPTIONS(5, Direction.REQUEST, OptionsMessage.codec),
        SUPPORTED(6, Direction.RESPONSE, SupportedMessage.codec),
        QUERY(7, Direction.REQUEST, QueryMessage.codec),
        RESULT(8, Direction.RESPONSE, ResultMessage.codec),
        PREPARE(9, Direction.REQUEST, PrepareMessage.codec),
        EXECUTE(10, Direction.REQUEST, ExecuteMessage.codec),
        REGISTER(11, Direction.REQUEST, RegisterMessage.codec),
        EVENT(12, Direction.RESPONSE, EventMessage.codec),
        BATCH(13, Direction.REQUEST, BatchMessage.codec),
        AUTH_CHALLENGE(14, Direction.RESPONSE, AuthChallenge.codec),
        AUTH_RESPONSE(15, Direction.REQUEST, AuthResponse.codec),
        AUTH_SUCCESS(16, Direction.RESPONSE, AuthSuccess.codec);

        public final int opcode;
        public final Direction direction;
        public final Codec<?> codec;
        /** Lookup table indexed by opcode; slots for unused opcodes are {@code null}. */
        private static final Type[] opcodeIdx;

        private Type(int opcode, Direction direction, Codec<?> codec) {
            this.opcode = opcode;
            this.direction = direction;
            this.codec = codec;
        }

        /**
         * Resolves an opcode to its {@link Type}.
         *
         * @param opcode    the opcode
         * @param direction the direction the frame claims to travel in
         * @return the matching type
         * @throws ProtocolException if the opcode is negative or unknown, or belongs to the other direction
         */
        public static Type fromOpcode(int opcode, Direction direction) {
            // A negative opcode would otherwise index the table and escape as an
            // ArrayIndexOutOfBoundsException rather than a ProtocolException.
            if (opcode < 0 || opcode >= opcodeIdx.length) {
                throw new ProtocolException(String.format("Unknown opcode %d", opcode));
            }
            Type t = opcodeIdx[opcode];
            if (t == null) {
                throw new ProtocolException(String.format("Unknown opcode %d", opcode));
            }
            if (t.direction != direction) {
                throw new ProtocolException(String.format("Wrong protocol direction (expected %s, got %s) for opcode %d (%s)", t.direction, direction, opcode, t));
            }
            return t;
        }

        /**
         * Test-only hook that replaces this type's {@link #codec} via reflection, by clearing the
         * {@code final} modifier on the {@code codec} field and then setting it.
         *
         * <p>NOTE(review): newer JDKs (12+) filter reflective access to {@code Field.modifiers}, so there this
         * throws {@link NoSuchFieldException}. Confirm the JDK this is expected to run on.
         *
         * @param codec the replacement codec
         * @return the codec installed before the call, so callers can restore it
         */
        @VisibleForTesting
        public Codec<?> unsafeSetCodec(Codec<?> codec) throws NoSuchFieldException, IllegalAccessException {
            Codec<?> original = this.codec;
            Field field = Type.class.getDeclaredField("codec");
            field.setAccessible(true);
            Field modifiers = Field.class.getDeclaredField("modifiers");
            modifiers.setAccessible(true);
            modifiers.setInt(field, field.getModifiers() & ~Modifier.FINAL);
            field.set(this, codec);
            return original;
        }

        static {
            int maxOpcode = -1;
            for (Type type : Type.values()) {
                maxOpcode = Math.max(maxOpcode, type.opcode);
            }
            opcodeIdx = new Type[maxOpcode + 1];
            for (Type type : Type.values()) {
                if (opcodeIdx[type.opcode] != null) {
                    throw new IllegalStateException("Duplicate opcode");
                }
                opcodeIdx[type.opcode] = type;
            }
        }
    }

    /**
     * Direction a message travels in. It is carried in the high bit (0x80) of the version byte: clear for
     * {@link #REQUEST}, set for {@link #RESPONSE}.
     */
    public static enum Direction {
        REQUEST,
        RESPONSE;

        /** Reads the direction from the high bit of a version byte. */
        public static Direction extractFromVersion(int versionWithDirection) {
            return (versionWithDirection & 0x80) == 0 ? REQUEST : RESPONSE;
        }

        /** Returns {@code rawVersion} with the high bit cleared for requests or set for responses. */
        public int addToVersion(int rawVersion) {
            return this == REQUEST ? rawVersion & 0x7F : rawVersion | 0x80;
        }
    }

    /** Codec that (de)serializes the body of a particular message type. */
    public static interface Codec<M extends Message> extends CBCodec<M> {
    }
}

