package org.nd4j.remote.serving;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Map;
import javax.servlet.ServletConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import lombok.NonNull;
import org.apache.commons.lang3.StringUtils;
import org.nd4j.adapters.InferenceAdapter;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.remote.clients.serde.BinaryDeserializer;
import org.nd4j.remote.clients.serde.BinarySerializer;
import org.nd4j.remote.clients.serde.JsonDeserializer;
import org.nd4j.remote.clients.serde.JsonSerializer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/remote/serving/SameDiffServlet.class */
/**
 * Servlet that exposes a {@link SameDiff} model over HTTP.
 *
 * <p>Endpoints are matched against {@link HttpServletRequest#getPathInfo()}:
 * <ul>
 *   <li>{@code GET /v1} - returns the text produced by {@code ServingProcessor.listEndpoints()}.</li>
 *   <li>{@code POST /v1/serving} - works in these steps:
 *     <ol>
 *       <li>Deserializes the JSON body into {@code I} and converts it to a {@link MultiDataSet}
 *           with the {@link InferenceAdapter}.</li>
 *       <li>Binds feature {@code i} to {@code orderedInputNodes[i]}.</li>
 *       <li>Executes the graph for {@code orderedOutputNodes}.</li>
 *       <li>Converts the outputs to {@code O} and serializes them back to JSON.</li>
 *     </ol>
 *   </li>
 * </ul>
 *
 * <p>Only requests whose content type is exactly {@code application/json} are handled. The binary
 * serializer/deserializer fields are stored, but the request handling in this class does not use
 * them.
 *
 * @param <I> request payload type produced by the JSON deserializer
 * @param <O> response payload type consumed by the JSON serializer
 */
public class SameDiffServlet<I, O> implements ModelServingServlet<I, O> {
    private static final Logger log = LoggerFactory.getLogger(SameDiffServlet.class);
    protected static final String typeJson = "application/json";
    protected static final String typeBinary = "application/octet-stream";
    protected SameDiff sdModel;
    protected JsonSerializer<O> serializer;
    protected JsonDeserializer<I> deserializer;
    protected BinarySerializer<O> binarySerializer;
    protected BinaryDeserializer<I> binaryDeserializer;
    protected InferenceAdapter<I, O> inferenceAdapter;
    /** Placeholder names; the i-th name receives feature {@code i} of the adapted {@link MultiDataSet}. */
    protected String[] orderedInputNodes;
    /** Output variable names; their order defines the order of arrays handed back to the adapter. */
    protected String[] orderedOutputNodes;
    protected static final String SERVING_ENDPOINT = "/v1/serving";
    protected static final String LISTING_ENDPOINT = "/v1";
    /** Maximum accepted request body size (10 MiB). */
    protected static final int PAYLOAD_SIZE_LIMIT = 10485760;

    /** Configuration passed to {@link #init(ServletConfig)}; returned by {@link #getServletConfig()}. */
    private ServletConfig servletConfig;

    /**
     * Fluent builder for {@link SameDiffServlet}.
     *
     * <p>Every field is optional at build time; {@link #build()} passes them unchanged to the
     * all-arguments constructor, without validation.
     */
    public static class SameDiffServletBuilder<I, O> {
        private SameDiff sdModel;
        private JsonSerializer<O> serializer;
        private JsonDeserializer<I> deserializer;
        private BinarySerializer<O> binarySerializer;
        private BinaryDeserializer<I> binaryDeserializer;
        private InferenceAdapter<I, O> inferenceAdapter;
        private String[] orderedInputNodes;
        private String[] orderedOutputNodes;

        SameDiffServletBuilder() {
        }

        public SameDiffServletBuilder<I, O> sdModel(SameDiff sameDiff) {
            this.sdModel = sameDiff;
            return this;
        }

        public SameDiffServletBuilder<I, O> serializer(JsonSerializer<O> jsonSerializer) {
            this.serializer = jsonSerializer;
            return this;
        }

        public SameDiffServletBuilder<I, O> deserializer(JsonDeserializer<I> jsonDeserializer) {
            this.deserializer = jsonDeserializer;
            return this;
        }

        public SameDiffServletBuilder<I, O> binarySerializer(BinarySerializer<O> binarySerializer) {
            this.binarySerializer = binarySerializer;
            return this;
        }

        public SameDiffServletBuilder<I, O> binaryDeserializer(BinaryDeserializer<I> binaryDeserializer) {
            this.binaryDeserializer = binaryDeserializer;
            return this;
        }

        public SameDiffServletBuilder<I, O> inferenceAdapter(InferenceAdapter<I, O> inferenceAdapter) {
            this.inferenceAdapter = inferenceAdapter;
            return this;
        }

        public SameDiffServletBuilder<I, O> orderedInputNodes(String[] strArr) {
            this.orderedInputNodes = strArr;
            return this;
        }

        public SameDiffServletBuilder<I, O> orderedOutputNodes(String[] strArr) {
            this.orderedOutputNodes = strArr;
            return this;
        }

        /** Creates the servlet from the values collected so far. */
        public SameDiffServlet<I, O> build() {
            return new SameDiffServlet<>(this.sdModel, this.serializer, this.deserializer, this.binarySerializer, this.binaryDeserializer, this.inferenceAdapter, this.orderedInputNodes, this.orderedOutputNodes);
        }

        public String toString() {
            return "SameDiffServlet.SameDiffServletBuilder(sdModel=" + this.sdModel + ", serializer=" + this.serializer + ", deserializer=" + this.deserializer + ", binarySerializer=" + this.binarySerializer + ", binaryDeserializer=" + this.binaryDeserializer + ", inferenceAdapter=" + this.inferenceAdapter + ", orderedInputNodes=" + Arrays.deepToString(this.orderedInputNodes) + ", orderedOutputNodes=" + Arrays.deepToString(this.orderedOutputNodes) + ")";
        }
    }

    /**
     * JSON-only servlet.
     *
     * @throws NullPointerException if {@code inferenceAdapter} is null
     */
    protected SameDiffServlet(@NonNull InferenceAdapter<I, O> inferenceAdapter, JsonSerializer<O> jsonSerializer, JsonDeserializer<I> jsonDeserializer) {
        if (inferenceAdapter == null) {
            throw new NullPointerException("inferenceAdapter is marked non-null but is null");
        }
        this.serializer = jsonSerializer;
        this.deserializer = jsonDeserializer;
        this.inferenceAdapter = inferenceAdapter;
    }

    /**
     * Binary-only servlet.
     *
     * <p>Note that {@link #doPost} handles JSON only, so POST requests to such an instance are
     * rejected with 415.
     *
     * @throws NullPointerException if {@code inferenceAdapter} is null
     */
    protected SameDiffServlet(@NonNull InferenceAdapter<I, O> inferenceAdapter, BinarySerializer<O> binarySerializer, BinaryDeserializer<I> binaryDeserializer) {
        if (inferenceAdapter == null) {
            throw new NullPointerException("inferenceAdapter is marked non-null but is null");
        }
        this.binarySerializer = binarySerializer;
        this.binaryDeserializer = binaryDeserializer;
        this.inferenceAdapter = inferenceAdapter;
    }

    /**
     * Servlet with both serde families supplied.
     *
     * <p>Exactly one of the JSON and binary serializers must be non-null.
     *
     * @throws NullPointerException  if {@code inferenceAdapter} is null
     * @throws IllegalStateException if both serializers, or neither, are set
     */
    protected SameDiffServlet(@NonNull InferenceAdapter<I, O> inferenceAdapter, JsonSerializer<O> jsonSerializer, JsonDeserializer<I> jsonDeserializer, BinarySerializer<O> binarySerializer, BinaryDeserializer<I> binaryDeserializer) {
        if (inferenceAdapter == null) {
            throw new NullPointerException("inferenceAdapter is marked non-null but is null");
        }
        this.serializer = jsonSerializer;
        this.deserializer = jsonDeserializer;
        this.binarySerializer = binarySerializer;
        this.binaryDeserializer = binaryDeserializer;
        this.inferenceAdapter = inferenceAdapter;
        if ((this.serializer != null && binarySerializer != null) || (this.serializer == null && binarySerializer == null)) {
            throw new IllegalStateException("Binary and JSON serializers/deserializers are mutually exclusive and mandatory.");
        }
    }

    /** Stores the container-supplied configuration so {@link #getServletConfig()} can return it. */
    public void init(ServletConfig servletConfig) throws ServletException {
        this.servletConfig = servletConfig;
    }

    /** Returns the configuration passed to {@link #init(ServletConfig)}, or null before init. */
    public ServletConfig getServletConfig() {
        return this.servletConfig;
    }

    /**
     * Dispatches GET to {@link #doGet} and POST to {@link #doPost}.
     *
     * <p>Any other method is answered with 405 and an {@code Allow} header.
     */
    public void service(ServletRequest servletRequest, ServletResponse servletResponse) throws ServletException, IOException {
        HttpServletRequest httpServletRequest = (HttpServletRequest) servletRequest;
        HttpServletResponse httpServletResponse = (HttpServletResponse) servletResponse;
        String method = httpServletRequest.getMethod();
        if ("GET".equals(method)) {
            doGet(httpServletRequest, httpServletResponse);
        } else if ("POST".equals(method)) {
            doPost(httpServletRequest, httpServletResponse);
        } else {
            // An unsupported method must not look like a successful, empty 200 response.
            httpServletResponse.setHeader("Allow", "GET, POST");
            httpServletResponse.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED, "Method [" + method + "] not allowed");
        }
    }

    /**
     * Sends 404 for an unknown endpoint.
     *
     * @param str the requested URI, echoed in the error message
     * @param httpServletResponse the response to commit
     */
    protected void sendError(String str, HttpServletResponse httpServletResponse) throws IOException {
        // sendError sets the status itself; the deprecated setStatus(int, String) call was redundant.
        httpServletResponse.sendError(HttpServletResponse.SC_NOT_FOUND, "Requested endpoint [" + str + "] not found");
    }

    /**
     * Sends 415 for an unsupported request content type.
     *
     * @param str the offending content type, echoed in the error message
     * @param httpServletResponse the response to commit
     */
    protected void sendBadContentType(String str, HttpServletResponse httpServletResponse) throws IOException {
        httpServletResponse.sendError(HttpServletResponse.SC_UNSUPPORTED_MEDIA_TYPE, "Content type [" + str + "] not supported");
    }

    /**
     * Checks that a serving request can be handled.
     *
     * <p>On failure, exactly one error response is sent. The request is accepted only when:
     * <ul>
     *   <li>its content type is exactly {@code application/json};</li>
     *   <li>JSON serde is configured on this servlet;</li>
     *   <li>its declared length does not exceed {@link #PAYLOAD_SIZE_LIMIT}.</li>
     * </ul>
     *
     * @return true if the request may be processed; false if an error has already been sent
     */
    protected boolean validateRequest(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) throws IOException {
        String contentType = httpServletRequest.getContentType();
        if (!StringUtils.equals(contentType, typeJson) || this.serializer == null || this.deserializer == null) {
            sendBadContentType(contentType, httpServletResponse);
            return false;
        }
        // getContentLength() is -1 when the length is unknown (e.g. chunked encoding).
        // readBody() also enforces the limit while reading, to cover that case.
        if (httpServletRequest.getContentLength() > PAYLOAD_SIZE_LIMIT) {
            httpServletResponse.sendError(HttpServletResponse.SC_REQUEST_ENTITY_TOO_LARGE, "Payload size limit violated!");
            return false;
        }
        return true;
    }

    /**
     * Handles {@code GET /v1} by writing the endpoint listing.
     *
     * <p>Any other path gets 404. A content type other than {@code application/json} gets 415.
     */
    protected void doGet(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) throws IOException {
        // Constant-first equals: getPathInfo() returns null when the request has no extra path.
        if (!LISTING_ENDPOINT.equals(httpServletRequest.getPathInfo())) {
            sendError(httpServletRequest.getRequestURI(), httpServletResponse);
            return;
        }
        String contentType = httpServletRequest.getContentType();
        if (!StringUtils.equals(contentType, typeJson)) {
            sendBadContentType(contentType, httpServletResponse);
            return;
        }
        writeResponse(httpServletResponse, new ServingProcessor().listEndpoints());
    }

    /**
     * Handles {@code POST /v1/serving}: runs JSON-in, JSON-out inference on the model.
     *
     * <p>Any other path gets 404. Requests rejected by {@link #validateRequest} get the error it
     * sent.
     */
    protected void doPost(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) throws IOException {
        if (!SERVING_ENDPOINT.equals(httpServletRequest.getPathInfo())) {
            sendError(httpServletRequest.getRequestURI(), httpServletResponse);
            return;
        }
        if (!validateRequest(httpServletRequest, httpServletResponse)) {
            return;
        }
        String body = readBody(httpServletRequest, httpServletResponse);
        if (body == null) {
            // Body exceeded the size limit; the error response has already been sent.
            return;
        }
        MultiDataSet features = this.inferenceAdapter.apply(this.deserializer.deserialize(body));
        INDArray[] outputs = runModel(features);
        String payload = this.serializer.serialize(this.inferenceAdapter.apply(outputs));
        // Declare the charset explicitly; otherwise getWriter() may default to ISO-8859-1.
        httpServletResponse.setContentType(typeJson);
        httpServletResponse.setCharacterEncoding("UTF-8");
        writeResponse(httpServletResponse, payload);
    }

    /**
     * Reads the whole request body as UTF-8 text.
     *
     * <p>validateRequest only admits an exact {@code application/json} content type, which carries
     * no charset parameter, so the JSON default of UTF-8 is used.
     *
     * @return the body, or null if it exceeded {@link #PAYLOAD_SIZE_LIMIT} (a 413 has then been sent)
     */
    private String readBody(HttpServletRequest request, HttpServletResponse response) throws IOException {
        StringBuilder body = new StringBuilder();
        char[] buffer = new char[1024];
        try (InputStreamReader reader = new InputStreamReader(request.getInputStream(), StandardCharsets.UTF_8)) {
            int read;
            while ((read = reader.read(buffer)) != -1) {
                body.append(buffer, 0, read);
                // A decoded UTF-8 body never has more chars than bytes, so this is a safe upper bound.
                if (body.length() > PAYLOAD_SIZE_LIMIT) {
                    response.sendError(HttpServletResponse.SC_REQUEST_ENTITY_TOO_LARGE, "Payload size limit violated!");
                    return null;
                }
            }
        }
        return body.toString();
    }

    /**
     * Executes the model on the given features.
     *
     * <p>Feature {@code i} is bound to {@code orderedInputNodes[i]}. The graph is evaluated for
     * {@code orderedOutputNodes}.
     *
     * @return the outputs in {@code orderedOutputNodes} order; an entry is null if the model
     *         returned no value for that name
     */
    private INDArray[] runModel(MultiDataSet features) {
        Map<String, INDArray> placeholders = new LinkedHashMap<>();
        if (this.orderedInputNodes != null) {
            for (int i = 0; i < this.orderedInputNodes.length; i++) {
                placeholders.put(this.orderedInputNodes[i], features.getFeatures(i));
            }
        }
        Map<String, INDArray> results = this.sdModel.output(placeholders, this.orderedOutputNodes);
        // Size by the requested names, not results.size(), so a missing or extra result cannot
        // cause an out-of-bounds write.
        INDArray[] outputs = new INDArray[this.orderedOutputNodes.length];
        for (int i = 0; i < outputs.length; i++) {
            outputs[i] = results.get(this.orderedOutputNodes[i]);
        }
        return outputs;
    }

    /**
     * Writes {@code payload} to the response on a best-effort basis.
     *
     * <p>An IOException is logged with its stack trace and not rethrown.
     */
    private void writeResponse(HttpServletResponse response, String payload) {
        try {
            response.getWriter().write(payload);
        } catch (IOException e) {
            log.error("Failed to write servlet response", e);
        }
    }

    public String getServletInfo() {
        return null;
    }

    public void destroy() {
    }

    /** Returns a new, empty {@link SameDiffServletBuilder}. */
    public static <I, O> SameDiffServletBuilder<I, O> builder() {
        return new SameDiffServletBuilder<>();
    }

    public SameDiffServlet() {
    }

    /** All-arguments constructor used by {@link SameDiffServletBuilder#build()}; performs no validation. */
    public SameDiffServlet(SameDiff sameDiff, JsonSerializer<O> jsonSerializer, JsonDeserializer<I> jsonDeserializer, BinarySerializer<O> binarySerializer, BinaryDeserializer<I> binaryDeserializer, InferenceAdapter<I, O> inferenceAdapter, String[] strArr, String[] strArr2) {
        this.sdModel = sameDiff;
        this.serializer = jsonSerializer;
        this.deserializer = jsonDeserializer;
        this.binarySerializer = binarySerializer;
        this.binaryDeserializer = binaryDeserializer;
        this.inferenceAdapter = inferenceAdapter;
        this.orderedInputNodes = strArr;
        this.orderedOutputNodes = strArr2;
    }
}
