/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.tensorflow.zoo.cv.objectdetction;

import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.modality.cv.translator.SingleShotDetectionTranslator;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.translate.Batchifier;
import ai.djl.translate.TranslatorContext;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/**
 * A {@link SingleShotDetectionTranslator} for TensorFlow SSD-style detection models.
 *
 * <p>The model output is classified by data type and rank (see {@link #processOutput}); at most
 * {@code maxBoxes} detections whose score is strictly above the configured threshold are reported.
 */
public class TfSsdTranslator extends SingleShotDetectionTranslator {
    /** Upper bound on the number of detections inspected (from {@link Builder#optMaxBoxes}). */
    private final int maxBoxes;
    /** A detection is kept only if its score is strictly greater than this value. */
    private final float scoreThreshold;

    /**
     * Creates a translator from the given builder.
     *
     * @param builder supplies the base SSD settings plus {@code maxBoxes} and the score threshold
     */
    protected TfSsdTranslator(Builder builder) {
        super(builder);
        this.maxBoxes = builder.maxBoxes;
        this.scoreThreshold = builder.getThreshold();
    }

    /**
     * Runs the parent image preprocessing and prepends a size-1 leading axis to its first output.
     *
     * @param ctx the translator context
     * @param input the image to preprocess
     * @return a list holding the single expanded array
     */
    public NDList processInput(TranslatorContext ctx, Image input) {
        NDArray image = super.processInput(ctx, input).get(0);
        // getBatchifier() returns null, so presumably nothing else adds this leading (batch) axis.
        return new NDList(image.expandDims(0));
    }

    /**
     * Returns {@code null}: no batchifier is used; {@link #processInput} adds the leading axis.
     *
     * @return {@code null}
     */
    public Batchifier getBatchifier() {
        return null;
    }

    /**
     * Converts raw model outputs into {@link DetectedObjects}.
     *
     * <p>Each output array is identified by type and rank: FLOAT32 1-D holds scores, FLOAT32 2-D
     * holds boxes, INT64 1-D holds class ids; if several arrays share a kind, the last one wins.
     * Box rows are read as {@code [yMin, xMin, yMax, xMax]} (units not established here — presumably
     * normalized image coordinates, TODO confirm). Class ids index {@code classes} with a -1 offset,
     * so ids {@code <= 0} carry no label and are skipped.
     *
     * @param ctx the translator context (unused)
     * @param list the model outputs
     * @return the detections that pass the score threshold, in output order
     * @throws IllegalStateException if an output has an unexpected type/rank combination
     * @throws AssertionError if a class id is larger than the number of known classes
     */
    public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {
        // Fallbacks for outputs the model did not produce, sized by dimension 0 of the first output.
        int len = (int) list.get(0).getShape().get(0);
        float[] scores = new float[len];
        long[] classIds = new long[len];
        NDArray boundingBoxes = list.get(0);
        for (NDArray array : list) {
            DataType dType = array.getDataType();
            int dim = array.getShape().dimension();
            if (dType == DataType.FLOAT32 && dim == 1) {
                scores = array.toFloatArray();
            } else if (dType == DataType.FLOAT32 && dim == 2) {
                boundingBoxes = array;
            } else if (dType == DataType.INT64 && dim == 1) {
                classIds = array.toLongArray();
            } else {
                throw new IllegalStateException(
                        "Unexpected result NDArray type:" + dType + ", and dim: " + dim);
            }
        }

        List<String> names = new ArrayList<>();
        List<Double> probabilities = new ArrayList<>();
        List<Rectangle> boxes = new ArrayList<>();
        // Also bound by scores.length so a short score array cannot be indexed out of range.
        int limit = Math.min(maxBoxes, Math.min(classIds.length, scores.length));
        for (int i = 0; i < limit; ++i) {
            long classId = classIds[i];
            double score = scores[i];
            // Labels are looked up at classId - 1, so valid ids are 1..classes.size(); ids <= 0
            // (including the all-zero fallback when no class-id output exists) are skipped instead
            // of crashing on get(-1). !(score > t) also skips NaN scores.
            if (classId <= 0 || !(score > scoreThreshold)) {
                continue;
            }
            if (classId > classes.size()) {
                throw new AssertionError("Unexpected index: " + classId);
            }
            String className = classes.get((int) classId - 1);
            float[] box = boundingBoxes.get(i).toFloatArray();
            float yMin = box[0];
            float xMin = box[1];
            float yMax = box[2];
            float xMax = box[3];
            names.add(className);
            probabilities.add(score);
            boxes.add(new Rectangle(xMin, yMin, xMax - xMin, yMax - yMin));
        }
        // Generics are invariant: a List<Rectangle> is not a List of a Rectangle supertype (DJL's
        // DetectedObjects presumably takes List<BoundingBox> — confirm). unmodifiableList accepts
        // List<? extends T> and lets T be inferred from the constructor parameter.
        return new DetectedObjects(names, probabilities, Collections.unmodifiableList(boxes));
    }

    /**
     * Creates a new {@link Builder}.
     *
     * @return a fresh builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /** Builder for {@link TfSsdTranslator}. */
    public static class Builder extends SingleShotDetectionTranslator.Builder {
        private int maxBoxes = 10;

        /**
         * Sets the maximum number of detections to inspect (default 10).
         *
         * @param maxBoxes the maximum number of detections
         * @return this builder
         */
        public Builder optMaxBoxes(int maxBoxes) {
            this.maxBoxes = maxBoxes;
            return this;
        }

        /**
         * Validates the configuration and builds the translator.
         *
         * @return a new {@link TfSsdTranslator}
         */
        public TfSsdTranslator build() {
            validate();
            return new TfSsdTranslator(this);
        }
    }
}

