/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.modality;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.types.DataType;
import ai.djl.util.JsonSerializable;
import ai.djl.util.JsonUtils;
import com.google.gson.Gson;
import com.google.gson.JsonElement;
import com.google.gson.JsonSerializationContext;
import com.google.gson.JsonSerializer;
import java.lang.reflect.Type;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;

/**
 * A set of class names paired, by index, with their probabilities.
 *
 * <p>Entry {@code i} of {@link #classNames} corresponds to entry {@code i} of {@link
 * #probabilities}. {@link #toString()} shows only the most probable classes.
 */
public class Classifications implements JsonSerializable {
    private static final long serialVersionUID = 1L;

    /** Number of classes shown by {@link #toString()} and written by the JSON serializer. */
    private static final int DEFAULT_TOP_K = 5;

    private static final Gson GSON =
            JsonUtils.builder()
                    .registerTypeAdapter(Classifications.class, new ClassificationsSerializer())
                    .create();

    /** Class names; index-aligned with {@link #probabilities}. */
    protected List<String> classNames;

    /** Probability of each class; index-aligned with {@link #classNames}. */
    protected List<Double> probabilities;

    /**
     * Creates a {@code Classifications} from two parallel lists.
     *
     * <p>The lists are stored as given; they are not copied.
     *
     * @param classNames the class names
     * @param probabilities the probability of each class, in the same order as {@code classNames}
     */
    public Classifications(List<String> classNames, List<Double> probabilities) {
        this.classNames = classNames;
        this.probabilities = probabilities;
    }

    /**
     * Creates a {@code Classifications} whose probabilities are read from an {@link NDArray}.
     *
     * <p>The values are converted to {@code double} and copied into a new list. The
     * {@code probabilities} array itself is left open and remains owned by the caller.
     *
     * @param classNames the class names
     * @param probabilities the probability of each class, in the same order as {@code classNames}
     */
    public Classifications(List<String> classNames, NDArray probabilities) {
        this.classNames = classNames;
        NDArray array = probabilities.toType(DataType.FLOAT64, false);
        try {
            this.probabilities =
                    Arrays.stream(array.toDoubleArray()).boxed().collect(Collectors.toList());
        } finally {
            // With copy=false, toType may hand back the receiver itself when no conversion is
            // needed (NOTE(review): engine-dependent - confirm). Close only an array created
            // here, so the caller's NDArray is never closed out from under it.
            if (array != probabilities) {
                array.close();
            }
        }
    }

    /**
     * Returns every class as a {@link Classification}, in index order.
     *
     * @param <T> the {@code Classification} type produced by {@link #item(int)}
     * @return a new mutable list with one entry per class name
     */
    public <T extends Classification> List<T> items() {
        List<T> list = new ArrayList<>(this.classNames.size());
        for (int i = 0; i < this.classNames.size(); ++i) {
            list.add(this.item(i));
        }
        return list;
    }

    /**
     * Returns the class name and probability stored at {@code index}.
     *
     * @param index the position of the class
     * @param <T> the requested {@code Classification} type; this implementation always creates a
     *     plain {@code Classification}
     * @return the class at {@code index}
     * @throws IndexOutOfBoundsException if {@code index} is out of range for either list
     */
    @SuppressWarnings("unchecked")
    public <T extends Classification> T item(int index) {
        // Unchecked: T is erased, so this cast cannot be verified at runtime.
        return (T) new Classification(this.classNames.get(index), this.probabilities.get(index));
    }

    /**
     * Returns up to {@code k} classes ordered from most to least probable.
     *
     * <p>Classes with equal probability keep their index order. If {@code k} exceeds the number
     * of classes, all classes are returned.
     *
     * @param k the maximum number of classes to return
     * @param <T> the {@code Classification} type produced by {@link #item(int)}
     * @return a new list of at most {@code k} classes
     * @throws IllegalArgumentException if {@code k} is negative
     */
    public <T extends Classification> List<T> topK(int k) {
        if (k < 0) {
            throw new IllegalArgumentException("k must be non-negative, but was " + k);
        }
        List<T> items = this.items();
        // List.sort is stable, so ties keep their original index order.
        items.sort(Comparator.comparingDouble(Classification::getProbability).reversed());
        int count = Math.min(items.size(), k);
        // Copy rather than return a subList view, which would keep every sorted item reachable.
        return new ArrayList<>(items.subList(0, count));
    }

    /**
     * Returns the class with the highest probability.
     *
     * <p>If several classes share the highest probability, the one with the lowest index is
     * returned.
     *
     * @param <T> the {@code Classification} type produced by {@link #item(int)}
     * @return the most probable class
     * @throws java.util.NoSuchElementException if there are no probabilities
     */
    public <T extends Classification> T best() {
        return this.item(this.probabilities.indexOf(Collections.max(this.probabilities)));
    }

    /**
     * Returns the first class whose name equals {@code className}.
     *
     * @param className the class name to look up
     * @param <T> the {@code Classification} type produced by {@link #item(int)}
     * @return the matching class, or {@code null} if no class has that name
     */
    public <T extends Classification> T get(String className) {
        int size = this.classNames.size();
        for (int i = 0; i < size; ++i) {
            if (this.classNames.get(i).equals(className)) {
                return this.item(i);
            }
        }
        return null;
    }

    /**
     * Serializes this object to JSON with a Gson instance that writes a {@code Classifications}
     * as the list of its most probable classes (see {@link ClassificationsSerializer}).
     *
     * @return the JSON text
     */
    @Override
    public String toJson() {
        return GSON.toJson(this);
    }

    /**
     * Returns the most probable classes, one per tab-indented line, enclosed in brackets.
     *
     * @return a multi-line, human-readable summary
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append('[').append(System.lineSeparator());
        for (Classification item : this.topK(DEFAULT_TOP_K)) {
            sb.append('\t').append(item).append(System.lineSeparator());
        }
        sb.append(']');
        return sb.toString();
    }

    /** Gson serializer that writes a {@code Classifications} as its most probable classes. */
    public static final class ClassificationsSerializer
            implements JsonSerializer<Classifications> {

        /** {@inheritDoc} */
        @Override
        public JsonElement serialize(
                Classifications src, Type type, JsonSerializationContext ctx) {
            List<Classification> list = src.topK(DEFAULT_TOP_K);
            return ctx.serialize(list);
        }
    }

    /** A single class name together with its probability. */
    public static class Classification {
        private final String className;
        private final double probability;

        /**
         * Creates a {@code Classification}.
         *
         * @param className the class name
         * @param probability the probability of the class
         */
        public Classification(String className, double probability) {
            this.className = className;
            this.probability = probability;
        }

        /**
         * Returns the class name.
         *
         * @return the class name
         */
        public String getClassName() {
            return this.className;
        }

        /**
         * Returns the probability of the class.
         *
         * @return the probability
         */
        public double getProbability() {
            return this.probability;
        }

        /**
         * Returns a human-readable form of this class.
         *
         * <p>Probabilities below {@code 1e-5} are shown in scientific notation. All others are
         * truncated (not rounded) to five decimal places, so {@code 0.999999} is shown as
         * {@code 0.99999}.
         *
         * @return the formatted class name and probability
         */
        @Override
        public String toString() {
            if (this.probability < 1.0E-5) {
                return String.format(
                        "class: \"%s\", probability: %.1e", this.className, this.probability);
            }
            if (!Double.isFinite(this.probability)) {
                // NaN and +Infinity have no decimal expansion to truncate; print them as-is.
                return String.format(
                        "class: \"%s\", probability: %.5f", this.className, this.probability);
            }
            // Truncate in decimal: (int) (0.29 * 100000.0) is 28999 in binary floating point,
            // which would print 0.29 as 0.28999. A local keeps toString() free of side effects.
            BigDecimal truncated =
                    BigDecimal.valueOf(this.probability).setScale(5, RoundingMode.DOWN);
            return String.format("class: \"%s\", probability: %.5f", this.className, truncated);
        }
    }
}

