package org.nd4j.remote;

import java.util.List;
import lombok.NonNull;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.servlet.ServletContextHandler;
import org.eclipse.jetty.servlet.ServletHolder;
import org.glassfish.jersey.servlet.ServletContainer;
import org.nd4j.adapters.InferenceAdapter;
import org.nd4j.adapters.InputAdapter;
import org.nd4j.adapters.OutputAdapter;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.remote.clients.serde.BinaryDeserializer;
import org.nd4j.remote.clients.serde.BinarySerializer;
import org.nd4j.remote.clients.serde.JsonDeserializer;
import org.nd4j.remote.clients.serde.JsonSerializer;
import org.nd4j.remote.serving.ModelServingServlet;
import org.nd4j.remote.serving.SameDiffServlet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/remote/SameDiffJsonModelServer.class */
/**
 * Exposes a {@link SameDiff} model through an embedded Jetty HTTP server.
 *
 * <p>{@link #start()} builds a {@link SameDiffServlet} from the model, the JSON
 * serializer/deserializer, the {@link InferenceAdapter} and the ordered input/output node names,
 * mounts it at {@code /*} behind a Jersey {@link ServletContainer} holder, and starts a Jetty
 * {@link Server} on {@link #port}.
 *
 * @param <I> type of the requests accepted by the server
 * @param <O> type of the responses produced by the server
 */
public class SameDiffJsonModelServer<I, O> {
    private static final Logger log = LoggerFactory.getLogger(SameDiffJsonModelServer.class);

    /** Model to serve; required by {@link #start()}. Not set by the protected constructor. */
    protected SameDiff sdModel;
    protected final JsonSerializer<O> serializer;
    protected final JsonDeserializer<I> deserializer;
    protected final BinarySerializer<O> binarySerializer;
    protected final BinaryDeserializer<I> binaryDeserializer;
    protected final InferenceAdapter<I, O> inferenceAdapter;
    /** TCP port the Jetty server binds to; validated to be in 1..65535. */
    protected final int port;
    /** Servlet created by {@link #start()}; null before that. */
    protected ModelServingServlet<I, O> servingServlet;
    /** Jetty server created by {@link #start(int, ModelServingServlet)}; null before that. */
    protected Server server;
    protected String[] orderedInputNodes;
    protected String[] orderedOutputNodes;

    /**
     * Fluent builder for {@link SameDiffJsonModelServer}.
     *
     * <p>Before {@link #build()}, either {@link #inferenceAdapter(InferenceAdapter)} or both
     * {@link #inputAdapter(InputAdapter)} and {@link #outputAdapter(OutputAdapter)} must be set.
     * The port must also be set, since the default of 0 is rejected by the server constructor.
     */
    public static class Builder<I, O> {
        private SameDiff sameDiff;
        private String[] orderedInputNodes;
        private String[] orderedOutputNodes;
        private InferenceAdapter<I, O> inferenceAdapter;
        private JsonSerializer<O> serializer;
        private JsonDeserializer<I> deserializer;
        private int port;
        private InputAdapter<I> inputAdapter;
        private OutputAdapter<O> outputAdapter;

        /**
         * Sets the model to serve.
         *
         * @throws NullPointerException if {@code sameDiff} is null
         */
        public Builder<I, O> sdModel(@NonNull SameDiff sameDiff) {
            if (sameDiff == null) {
                throw new NullPointerException("sameDiff is marked non-null but is null");
            }
            this.sameDiff = sameDiff;
            return this;
        }

        /** Sets a combined input/output adapter; takes precedence over the separate adapters. */
        public Builder<I, O> inferenceAdapter(InferenceAdapter<I, O> inferenceAdapter) {
            this.inferenceAdapter = inferenceAdapter;
            return this;
        }

        /**
         * Sets the request-to-model-input adapter (used only when no inference adapter is set).
         *
         * @throws NullPointerException if {@code inputAdapter} is null
         */
        public Builder<I, O> inputAdapter(@NonNull InputAdapter<I> inputAdapter) {
            if (inputAdapter == null) {
                throw new NullPointerException("inputAdapter is marked non-null but is null");
            }
            this.inputAdapter = inputAdapter;
            return this;
        }

        /**
         * Sets the model-output-to-response adapter (used only when no inference adapter is set).
         *
         * @throws NullPointerException if {@code outputAdapter} is null
         */
        public Builder<I, O> outputAdapter(@NonNull OutputAdapter<O> outputAdapter) {
            if (outputAdapter == null) {
                throw new NullPointerException("outputAdapter is marked non-null but is null");
            }
            this.outputAdapter = outputAdapter;
            return this;
        }

        /**
         * Sets the JSON serializer for responses.
         *
         * @throws NullPointerException if {@code serializer} is null
         */
        public Builder<I, O> outputSerializer(@NonNull JsonSerializer<O> serializer) {
            if (serializer == null) {
                throw new NullPointerException("serializer is marked non-null but is null");
            }
            this.serializer = serializer;
            return this;
        }

        /**
         * Sets the JSON deserializer for requests.
         *
         * @throws NullPointerException if {@code deserializer} is null
         */
        public Builder<I, O> inputDeserializer(@NonNull JsonDeserializer<I> deserializer) {
            if (deserializer == null) {
                throw new NullPointerException("deserializer is marked non-null but is null");
            }
            this.deserializer = deserializer;
            return this;
        }

        /** Sets the model input node names, in the order the inputs are supplied. */
        public Builder<I, O> orderedInputNodes(String... args) {
            this.orderedInputNodes = args;
            return this;
        }

        /**
         * Sets the model input node names, in the order the inputs are supplied.
         *
         * @throws NullPointerException if {@code args} is null
         */
        public Builder<I, O> orderedInputNodes(@NonNull List<String> args) {
            if (args == null) {
                throw new NullPointerException("args is marked non-null but is null");
            }
            this.orderedInputNodes = args.toArray(new String[0]);
            return this;
        }

        /**
         * Sets the model output node names, in order.
         *
         * @throws IllegalArgumentException if {@code args} is null or empty
         */
        public Builder<I, O> orderedOutputNodes(String... args) {
            Preconditions.checkArgument(args != null && args.length > 0, "OutputNodes should contain at least 1 element");
            this.orderedOutputNodes = args;
            return this;
        }

        /**
         * Sets the model output node names, in order.
         *
         * @throws NullPointerException     if {@code args} is null
         * @throws IllegalArgumentException if {@code args} is empty
         */
        public Builder<I, O> orderedOutputNodes(@NonNull List<String> args) {
            if (args == null) {
                throw new NullPointerException("args is marked non-null but is null");
            }
            Preconditions.checkArgument(!args.isEmpty(), "OutputNodes should contain at least 1 element");
            this.orderedOutputNodes = args.toArray(new String[0]);
            return this;
        }

        /** Sets the TCP port the server will listen on (must be in 1..65535). */
        public Builder<I, O> port(int port) {
            this.port = port;
            return this;
        }

        /**
         * Creates the server from the current builder state. The builder itself is not modified.
         *
         * @throws IllegalArgumentException if neither an inference adapter nor both input and
         *                                  output adapters were configured, or if the server
         *                                  constructor rejects the configuration
         */
        public SameDiffJsonModelServer<I, O> build() {
            InferenceAdapter<I, O> adapter = this.inferenceAdapter;
            if (adapter == null) {
                if (this.inputAdapter == null || this.outputAdapter == null) {
                    throw new IllegalArgumentException("Either InferenceAdapter<I,O> or InputAdapter<I> + OutputAdapter<O> should be configured");
                }
                // Snapshot the adapters so later changes to this builder cannot alter a server
                // that has already been built.
                final InputAdapter<I> in = this.inputAdapter;
                final OutputAdapter<O> out = this.outputAdapter;
                adapter = new InferenceAdapter<I, O>() {
                    @Override
                    public MultiDataSet apply(I input) {
                        return in.apply(input);
                    }

                    @Override
                    public O apply(INDArray... outputs) {
                        return out.apply(outputs);
                    }
                };
            }
            return new SameDiffJsonModelServer<>(this.sameDiff, adapter, this.serializer, this.deserializer, null, null, this.port, this.orderedInputNodes, this.orderedOutputNodes);
        }
    }

    /**
     * Base constructor that validates and stores the adapter, serializers and port. It does not set a
     * model or node names.
     *
     * @param inferenceAdapter   converts requests to model inputs and model outputs to responses
     * @param serializer         JSON response serializer (may be null)
     * @param deserializer       JSON request deserializer (may be null)
     * @param binarySerializer   binary response serializer (may be null)
     * @param binaryDeserializer binary request deserializer (may be null)
     * @param port               TCP port, 1..65535
     * @throws NullPointerException     if {@code inferenceAdapter} is null
     * @throws IllegalArgumentException if the port is out of range, or both a JSON and a binary
     *                                  serializer are supplied
     */
    protected SameDiffJsonModelServer(@NonNull InferenceAdapter<I, O> inferenceAdapter, JsonSerializer<O> serializer, JsonDeserializer<I> deserializer, BinarySerializer<O> binarySerializer, BinaryDeserializer<I> binaryDeserializer, int port) {
        if (inferenceAdapter == null) {
            throw new NullPointerException("inferenceAdapter is marked non-null but is null");
        }
        // 65535 is a valid TCP port, so the upper bound is inclusive; 0 is rejected.
        Preconditions.checkArgument(port > 0 && port <= 65535, "TCP port must be in range of 1..65535");
        // Only mutual exclusion is enforced here: supplying neither serializer is accepted,
        // despite the "mandatory" wording of the message.
        Preconditions.checkArgument(serializer == null || binarySerializer == null, "JSON and binary serializers/deserializers are mutually exclusive and mandatory.");
        this.binarySerializer = binarySerializer;
        this.binaryDeserializer = binaryDeserializer;
        this.inferenceAdapter = inferenceAdapter;
        this.serializer = serializer;
        this.deserializer = deserializer;
        this.port = port;
    }

    /**
     * Creates a server for a SameDiff model.
     *
     * @param sdModel            model to serve (checked for null only in {@link #start()})
     * @param orderedInputNodes  input node names in order (may be null)
     * @param orderedOutputNodes output node names in order; must be non-null and non-empty
     * @throws NullPointerException     if {@code inferenceAdapter} or {@code orderedOutputNodes} is null
     * @throws IllegalArgumentException if {@code orderedOutputNodes} is empty, or the base
     *                                  constructor rejects the port or serializers
     */
    public SameDiffJsonModelServer(SameDiff sdModel, @NonNull InferenceAdapter<I, O> inferenceAdapter, JsonSerializer<O> serializer, JsonDeserializer<I> deserializer, BinarySerializer<O> binarySerializer, BinaryDeserializer<I> binaryDeserializer, int port, String[] orderedInputNodes, @NonNull String[] orderedOutputNodes) {
        // The delegated constructor already null-checks inferenceAdapter.
        this(inferenceAdapter, serializer, deserializer, binarySerializer, binaryDeserializer, port);
        if (orderedOutputNodes == null) {
            throw new NullPointerException("orderedOutputNodes is marked non-null but is null");
        }
        Preconditions.checkArgument(orderedOutputNodes.length > 0, "SameDiff serving requires at least 1 output node");
        this.sdModel = sdModel;
        this.orderedInputNodes = orderedInputNodes;
        this.orderedOutputNodes = orderedOutputNodes;
    }

    /**
     * Creates a Jetty server on {@code port}, mounts {@code servlet} at {@code /*}, and starts it.
     *
     * @throws NullPointerException if {@code servlet} is null
     * @throws Exception            propagated from {@link Server#start()}
     */
    protected void start(int port, @NonNull ModelServingServlet<I, O> servlet) throws Exception {
        if (servlet == null) {
            throw new NullPointerException("servlet is marked non-null but is null");
        }
        // 1 == ServletContextHandler.SESSIONS (enables session support)
        ServletContextHandler context = new ServletContextHandler(1);
        context.setContextPath("/");

        this.server = new Server(port);
        this.server.setHandler(context);

        ServletHolder holder = context.addServlet(ServletContainer.class, "/*");
        holder.setInitOrder(0);
        holder.setServlet(servlet);

        this.server.start();
    }

    /**
     * Builds the {@link SameDiffServlet} for the configured model and starts serving on {@link #port}.
     *
     * @throws IllegalArgumentException if no SameDiff model was configured
     * @throws Exception                propagated from the Jetty server start
     */
    public void start() throws Exception {
        Preconditions.checkArgument(this.sdModel != null, "SameDiff model wasn't defined");
        // Explicit type witness: the receiver of a method chain is not target-typed, so without
        // it I/O would be inferred as Object and the typed setters below would not line up.
        this.servingServlet = SameDiffServlet.<I, O>builder()
                .sdModel(this.sdModel)
                .serializer(this.serializer)
                .deserializer(this.deserializer)
                .inferenceAdapter(this.inferenceAdapter)
                .orderedInputNodes(this.orderedInputNodes)
                .orderedOutputNodes(this.orderedOutputNodes)
                .build();
        start(this.port, this.servingServlet);
    }

    /**
     * Blocks until the Jetty server thread pool stops.
     *
     * @throws IllegalArgumentException if the server was not started
     * @throws InterruptedException     if interrupted while waiting
     */
    public void join() throws InterruptedException {
        Preconditions.checkArgument(this.server != null, "Model server wasn't started yet");
        this.server.join();
    }

    /**
     * Stops the Jetty server. Does nothing if the server was never started.
     *
     * @throws Exception propagated from {@link Server#stop()}
     */
    public void stop() throws Exception {
        if (this.server == null) {
            // Nothing to stop; previously this threw a NullPointerException.
            return;
        }
        this.server.stop();
    }
}
