/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.modality.cv;

import ai.djl.modality.cv.ImageTranslator;
import ai.djl.modality.cv.Joints;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.TranslatorContext;
import java.util.ArrayList;

/**
 * An {@link ImageTranslator} that decodes a single pose-estimation heatmap output into
 * {@link Joints}.
 *
 * <p>For every joint channel, the position of the heatmap maximum becomes the joint's location
 * and the maximum value becomes its confidence. Only joints whose confidence is strictly greater
 * than the configured threshold are returned.
 */
public class SimplePoseTranslator
extends ImageTranslator<Joints> {
    private final float threshold;

    /**
     * Creates a translator from the given builder.
     *
     * @param builder the builder holding the image-processing options and the confidence threshold
     */
    public SimplePoseTranslator(Builder builder) {
        super(builder);
        this.threshold = builder.threshold;
    }

    /**
     * Decodes the model output into joints.
     *
     * <p>The single output array is read as {@code (numJoints, height, width)}. For each joint
     * channel the flat argmax index is split into a column (x) and a row (y). Those are divided by
     * the heatmap width and height respectively before being stored in the returned joint. Joints
     * whose peak value is not positive are reported at {@code (0, 0)}.
     *
     * @param ctx the translator context (not used by this method)
     * @param list the model output, expected to hold a single heatmap array
     * @return the joints whose peak confidence is strictly greater than the threshold, in channel
     *     order
     */
    @Override
    public Joints processOutput(TranslatorContext ctx, NDList list) {
        NDArray pred = list.singletonOrThrow();
        Shape shape = pred.getShape();
        int numJoints = (int) shape.get(0);
        int height = (int) shape.get(1);
        int width = (int) shape.get(2);

        // Flatten each heatmap so one argMax/max over axis 2 finds the peak of every joint.
        NDArray predReshaped = pred.reshape(new Shape(1L, numJoints, -1L));
        NDArray maxIndices =
                predReshaped
                        .argMax(2)
                        .reshape(new Shape(1L, numJoints, -1L))
                        .toType(DataType.FLOAT32, false);
        NDArray maxValues = predReshaped.max(new int[] {2}, true);

        // Duplicate the flat index into two columns, then turn them into (x = col, y = row).
        NDArray coords = maxIndices.tile(2, 2L);
        coords.set(new NDIndex(":, :, 0"), coords.get(":, :, 0").mod(width));
        coords.set(new NDIndex(":, :, 1"), coords.get(":, :, 1").div(width).floor());

        // Read all coordinates positionally so that indices i*2 and i*2+1 always belong to
        // joint i. A boolean-mask gather would drop the rows of non-positive peaks and shift
        // every later joint, or even run past the end of the array.
        float[] flattened = coords.toFloatArray();
        float[] flattenedConfidence = maxValues.toFloatArray();

        ArrayList<Joints.Joint> joints = new ArrayList<>(numJoints);
        for (int i = 0; i < numJoints; ++i) {
            float confidence = flattenedConfidence[i];
            // Written as !(a > b) so that NaN confidences are skipped as well.
            if (!(confidence > this.threshold)) {
                continue;
            }
            float x = 0f;
            float y = 0f;
            // Non-positive peaks are zeroed out, mirroring the positive-peak mask of the
            // reference decoding.
            if (confidence > 0f) {
                x = flattened[i * 2] / (float) width;
                y = flattened[i * 2 + 1] / (float) height;
            }
            joints.add(new Joints.Joint(x, y, confidence));
        }
        return new Joints(joints);
    }

    /**
     * Creates a new builder for {@link SimplePoseTranslator}.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /** A builder for {@link SimplePoseTranslator}. */
    public static class Builder
    extends ImageTranslator.BaseBuilder<Builder> {
        // Minimum (exclusive) confidence a joint must exceed; 0 unless optThreshold is called.
        float threshold;

        Builder() {
        }

        @Override
        protected Builder self() {
            return this;
        }

        /**
         * Sets the confidence threshold. Only joints whose confidence is strictly greater than
         * this value are returned. If this method is not called, the threshold is 0.
         *
         * @param threshold the exclusive lower bound on joint confidence
         * @return this builder
         */
        public Builder optThreshold(float threshold) {
            this.threshold = threshold;
            return this.self();
        }

        /**
         * Builds the translator from the current builder state.
         *
         * @return a new {@link SimplePoseTranslator}
         */
        public SimplePoseTranslator build() {
            return new SimplePoseTranslator(this);
        }
    }
}

