/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.tensorflow.zoo.cv.objectdetction;

import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.BoundingBox;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.modality.cv.translator.ObjectDetectionTranslator;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.translate.Batchifier;
import ai.djl.translate.TranslatorContext;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

/**
 * An {@link ObjectDetectionTranslator} for TensorFlow SSD-style detection models.
 *
 * <p>The model outputs are located by name: an optional detection count, the bounding boxes, the
 * scores and the class labels. Each bounding box is read as {@code [yMin, xMin, yMax, xMax]} and
 * converted to a {@link Rectangle}. Class labels are treated as 1-based indices into the class
 * list (label {@code k} maps to {@code classes.get(k - 1)}).
 */
public class TfSsdTranslator extends ObjectDetectionTranslator {

    private static final int DEFAULT_MAX_BOXES = 10;
    /** Threshold applied when the translator is configured from an arguments map. */
    private static final float DEFAULT_THRESHOLD = 0.4f;
    private static final String DEFAULT_NUM_DETECTIONS_OUTPUT_NAME = "num_detections";
    private static final String DEFAULT_BOUNDING_BOX_OUTPUT_NAME = "detection_boxes";
    private static final String DEFAULT_SCORES_OUTPUT_NAME = "detection_scores";
    private static final String DEFAULT_CLASS_LABEL_OUTPUT_NAME = "detection_class_labels";

    private final int maxBoxes;
    private final String numDetectionsOutputName;
    private final String boundingBoxOutputName;
    private final String scoresOutputName;
    private final String classLabelOutputName;

    /**
     * Creates a translator from the given builder.
     *
     * @param builder the builder holding the configuration
     */
    protected TfSsdTranslator(Builder builder) {
        super(builder);
        this.maxBoxes = builder.maxBoxes;
        this.numDetectionsOutputName = builder.numDetectionsOutputName;
        this.boundingBoxOutputName = builder.boundingBoxOutputName;
        this.scoresOutputName = builder.scoresOutputName;
        this.classLabelOutputName = builder.classLabelOutputName;
    }

    /**
     * Pre-processes the image with the parent pipeline and prepends a batch dimension of size 1.
     *
     * @param ctx the translator context
     * @param input the input image
     * @return a single-element list holding the batched image array
     */
    public NDList processInput(TranslatorContext ctx, Image input) {
        NDArray image = super.processInput(ctx, input).get(0);
        return new NDList(image.expandDims(0));
    }

    /**
     * Returns {@code null}. {@link #processInput} already adds the batch dimension itself, so no
     * batchifier is supplied (presumably meaning inputs are passed through unbatched — confirm
     * against the DJL {@code Translator} contract).
     *
     * @return {@code null}
     */
    public Batchifier getBatchifier() {
        return null;
    }

    /**
     * Converts the raw model outputs into {@link DetectedObjects}.
     *
     * <p>Outputs are matched by their configured names. At most {@code maxBoxes} detections are
     * examined. The loop never reads past the end of the scores or class-label arrays. When the
     * detection-count output is present, it also never reads past the reported count. Detections
     * with a negative class label, or with a score not strictly above the threshold, are skipped.
     *
     * @param ctx the translator context
     * @param list the model outputs
     * @return the detected objects
     * @throws AssertionError if a kept detection's class label is outside {@code 1..classes.size()}
     */
    public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {
        NDArray numDetectionsArray = findOutput(list, numDetectionsOutputName);
        NDArray scoresArray = findOutput(list, scoresOutputName);
        NDArray classLabelArray = findOutput(list, classLabelOutputName);
        NDArray boxesArray = findOutput(list, boundingBoxOutputName);

        // Without a count output, fall back to the leading dimension of the first output.
        int numDetections =
                numDetectionsArray != null
                        ? numDetectionsArray.toArray()[0].intValue()
                        : (int) list.get(0).getShape().get(0);
        float[] scores =
                scoresArray != null ? scoresArray.toFloatArray() : new float[numDetections];
        long[] classIds =
                classLabelArray != null
                        ? Arrays.stream(classLabelArray.toArray())
                                .mapToLong(Number::longValue)
                                .toArray()
                        : new long[numDetections];
        NDArray boundingBoxes = boxesArray != null ? boxesArray : list.get(0);

        // Index i reads both scores[i] and classIds[i], so neither array may be overrun.
        int limit = Math.min(maxBoxes, Math.min(scores.length, classIds.length));
        if (numDetectionsArray != null) {
            // Honor the model's reported count; later entries are presumably padding.
            limit = Math.min(limit, numDetections);
        }

        List<String> retNames = new ArrayList<>();
        List<Double> retProbs = new ArrayList<>();
        List<BoundingBox> retBB = new ArrayList<>();
        for (int i = 0; i < limit; ++i) {
            long classId = classIds[i];
            double score = scores[i];
            // Written as !(score > threshold) so that NaN scores are skipped as well.
            if (classId < 0 || !(score > threshold)) {
                continue;
            }
            // Labels are 1-based (label k -> classes.get(k - 1)), so the valid range is 1..size.
            if (classId == 0 || classId > classes.size()) {
                throw new AssertionError("Unexpected index: " + classId);
            }
            String className = classes.get((int) classId - 1);

            // Box row layout: [yMin, xMin, yMax, xMax].
            float[] box = boundingBoxes.get(i).toFloatArray();
            float yMin = box[0];
            float xMin = box[1];
            float yMax = box[2];
            float xMax = box[3];
            retNames.add(className);
            retProbs.add(score);
            retBB.add(new Rectangle(xMin, yMin, xMax - xMin, yMax - yMin));
        }
        return new DetectedObjects(retNames, retProbs, retBB);
    }

    /**
     * Returns the first output whose name equals {@code name}.
     *
     * @param list the model outputs to search
     * @param name the output name to look for
     * @return the matching array, or {@code null} if no output has that name
     */
    private static NDArray findOutput(NDList list, String name) {
        for (NDArray array : list) {
            if (name.equals(array.getName())) {
                return array;
            }
        }
        return null;
    }

    /**
     * Creates a builder with the default configuration.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Creates a builder configured from an arguments map.
     *
     * @param arguments the pre- and post-processing arguments
     * @return a new, configured builder
     */
    public static Builder builder(Map<String, ?> arguments) {
        Builder builder = new Builder();
        builder.configPreProcess(arguments);
        builder.configPostProcess(arguments);
        return builder;
    }

    /** Builder for {@link TfSsdTranslator}. */
    public static class Builder extends ObjectDetectionTranslator.ObjectDetectionBuilder<Builder> {

        int maxBoxes = DEFAULT_MAX_BOXES;
        String numDetectionsOutputName = DEFAULT_NUM_DETECTIONS_OUTPUT_NAME;
        String boundingBoxOutputName = DEFAULT_BOUNDING_BOX_OUTPUT_NAME;
        String scoresOutputName = DEFAULT_SCORES_OUTPUT_NAME;
        String classLabelOutputName = DEFAULT_CLASS_LABEL_OUTPUT_NAME;

        /**
         * Sets the name of the output holding the detection count.
         *
         * @param numDetectionsOutputName the output name
         * @return this builder
         */
        public Builder optNumDetectionsOutputName(String numDetectionsOutputName) {
            this.numDetectionsOutputName = numDetectionsOutputName;
            return this;
        }

        /**
         * Sets the name of the output holding the bounding boxes.
         *
         * @param boundingBoxOutputName the output name
         * @return this builder
         */
        public Builder optBoundingBoxOutputName(String boundingBoxOutputName) {
            this.boundingBoxOutputName = boundingBoxOutputName;
            return this;
        }

        /**
         * Sets the name of the output holding the detection scores.
         *
         * @param scoresOutputName the output name
         * @return this builder
         */
        public Builder optScoresOutputName(String scoresOutputName) {
            this.scoresOutputName = scoresOutputName;
            return this;
        }

        /**
         * Sets the name of the output holding the class labels.
         *
         * @param classLabelOutputName the output name
         * @return this builder
         */
        public Builder optClassLabelOutputName(String classLabelOutputName) {
            this.classLabelOutputName = classLabelOutputName;
            return this;
        }

        /**
         * Sets the maximum number of detections examined per image.
         *
         * @param maxBoxes the maximum number of detections
         * @return this builder
         */
        public Builder optMaxBoxes(int maxBoxes) {
            this.maxBoxes = maxBoxes;
            return this;
        }

        /** {@inheritDoc} */
        protected Builder self() {
            return this;
        }

        /** {@inheritDoc} */
        protected void configPreProcess(Map<String, ?> arguments) {
            super.configPreProcess(arguments);
        }

        /**
         * Reads the post-processing settings from {@code arguments}. Missing keys fall back to
         * this translator's defaults, including a threshold of {@code 0.4}.
         *
         * @param arguments the post-processing arguments
         */
        protected void configPostProcess(Map<String, ?> arguments) {
            super.configPostProcess(arguments);
            maxBoxes = TfSsdTranslator.getIntValue(arguments, "maxBoxes", DEFAULT_MAX_BOXES);
            threshold = TfSsdTranslator.getFloatValue(arguments, "threshold", DEFAULT_THRESHOLD);
            numDetectionsOutputName =
                    TfSsdTranslator.getStringValue(
                            arguments,
                            "numDetectionsOutputName",
                            DEFAULT_NUM_DETECTIONS_OUTPUT_NAME);
            boundingBoxOutputName =
                    TfSsdTranslator.getStringValue(
                            arguments, "boundingBoxOutputName", DEFAULT_BOUNDING_BOX_OUTPUT_NAME);
            scoresOutputName =
                    TfSsdTranslator.getStringValue(
                            arguments, "scoresOutputName", DEFAULT_SCORES_OUTPUT_NAME);
            classLabelOutputName =
                    TfSsdTranslator.getStringValue(
                            arguments, "classLabelOutputName", DEFAULT_CLASS_LABEL_OUTPUT_NAME);
        }

        /**
         * Validates the configuration and builds the translator.
         *
         * @return a new {@link TfSsdTranslator}
         */
        public TfSsdTranslator build() {
            validate();
            return new TfSsdTranslator(this);
        }
    }
}

