/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.modality;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.types.DataType;
import ai.djl.translate.Ensembleable;
import ai.djl.util.JsonSerializable;
import ai.djl.util.JsonUtils;
import com.google.gson.Gson;
import com.google.gson.JsonElement;
import com.google.gson.JsonSerializationContext;
import com.google.gson.JsonSerializer;
import java.lang.reflect.Type;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.stream.Collectors;

/**
 * Holds the result of classifying a single input: a list of class names and a parallel list of
 * probabilities, where entry {@code i} of each list describes the same class.
 *
 * <p>{@link #toString()}, {@link #topK()} and the JSON form ({@link #toJson()}) only expose the
 * {@code topK} most probable classes; {@code topK} defaults to 5 and can be changed with {@link
 * #setTopK(int)}.
 */
public class Classifications implements JsonSerializable, Ensembleable<Classifications> {

    private static final long serialVersionUID = 1L;

    /** Default number of items returned by {@link #topK()} when no value is given. */
    private static final int DEFAULT_TOP_K = 5;

    /** Gson instance that serializes a {@code Classifications} as the list of its top-K items. */
    private static final Gson GSON =
            JsonUtils.builder()
                    .registerTypeAdapter(Classifications.class, new ClassificationsSerializer())
                    .create();

    protected List<String> classNames;
    protected List<Double> probabilities;
    protected int topK;

    /**
     * Creates a {@code Classifications} from parallel lists of names and probabilities.
     *
     * <p>The lists are stored as given (not copied).
     *
     * @param classNames the class names
     * @param probabilities the probability of each class, in the same order as {@code classNames}
     */
    public Classifications(List<String> classNames, List<Double> probabilities) {
        this.classNames = classNames;
        this.probabilities = probabilities;
        this.topK = DEFAULT_TOP_K;
    }

    /**
     * Creates a {@code Classifications} from class names and an {@link NDArray} of probabilities,
     * using the default top-K of 5.
     *
     * @param classNames the class names
     * @param probabilities the probabilities, in the same order as {@code classNames}
     */
    public Classifications(List<String> classNames, NDArray probabilities) {
        this(classNames, probabilities, DEFAULT_TOP_K);
    }

    /**
     * Creates a {@code Classifications} from class names and an {@link NDArray} of probabilities.
     *
     * <p>The caller keeps ownership of {@code probabilities}; this constructor never closes it.
     *
     * @param classNames the class names
     * @param probabilities the probabilities, in the same order as {@code classNames}
     * @param topK the number of items returned by {@link #topK()}
     */
    public Classifications(List<String> classNames, NDArray probabilities, int topK) {
        this.classNames = classNames;
        if (probabilities.getDataType() == DataType.FLOAT32) {
            // Read float32 directly instead of converting the NDArray to float64.
            float[] values = probabilities.toFloatArray();
            this.probabilities = new ArrayList<>(values.length);
            for (float prob : values) {
                this.probabilities.add((double) prob);
            }
        } else {
            // toType(..., false) does not force a copy, so it may hand back the very same
            // instance (e.g. when the input is already FLOAT64). Only close an array we created,
            // and close it even if reading the values fails.
            NDArray array = probabilities.toType(DataType.FLOAT64, false);
            try {
                this.probabilities =
                        Arrays.stream(array.toDoubleArray()).boxed().collect(Collectors.toList());
            } finally {
                if (array != probabilities) {
                    array.close();
                }
            }
        }
        this.topK = topK;
    }

    /**
     * Returns the class names (the internal list, not a copy).
     *
     * @return the class names
     */
    public List<String> getClassNames() {
        return this.classNames;
    }

    /**
     * Returns the probabilities (the internal list, not a copy).
     *
     * @return the probabilities, in the same order as {@link #getClassNames()}
     */
    public List<Double> getProbabilities() {
        return this.probabilities;
    }

    /**
     * Sets the number of items returned by {@link #topK()} and shown by {@link #toString()} and
     * {@link #toJson()}.
     *
     * @param topK the number of top items
     */
    public final void setTopK(int topK) {
        this.topK = topK;
    }

    /**
     * Returns every class as a {@link Classification}, in the original (unsorted) order.
     *
     * @param <T> the {@link Classification} subtype produced by {@link #item(int)}
     * @return a new mutable list with one entry per class name
     */
    public <T extends Classification> List<T> items() {
        List<T> list = new ArrayList<>(this.classNames.size());
        for (int i = 0; i < this.classNames.size(); ++i) {
            list.add(this.item(i));
        }
        return list;
    }

    /**
     * Returns the class at {@code index} as a {@link Classification}.
     *
     * @param index the index of the class
     * @param <T> the {@link Classification} subtype; this base implementation always creates a
     *     plain {@code Classification}
     * @return the class name and probability at {@code index}
     * @throws IndexOutOfBoundsException if {@code index} is out of range
     */
    @SuppressWarnings("unchecked") // subclasses override item() to produce their own subtype
    public <T extends Classification> T item(int index) {
        return (T) new Classification(this.classNames.get(index), this.probabilities.get(index));
    }

    /**
     * Returns the {@code topK} most probable classes, highest probability first.
     *
     * @param <T> the {@link Classification} subtype
     * @return at most {@code topK} items sorted by descending probability
     */
    public <T extends Classification> List<T> topK() {
        return this.topK(this.topK);
    }

    /**
     * Returns the {@code k} most probable classes, highest probability first.
     *
     * @param k the maximum number of items to return
     * @param <T> the {@link Classification} subtype
     * @return at most {@code k} items sorted by descending probability
     */
    public <T extends Classification> List<T> topK(int k) {
        List<T> items = this.items();
        items.sort(Comparator.comparingDouble(Classification::getProbability).reversed());
        int count = Math.min(items.size(), k);
        return items.subList(0, count);
    }

    /**
     * Returns the most probable class. If several classes share the highest probability, the one
     * with the lowest index is returned.
     *
     * @param <T> the {@link Classification} subtype
     * @return the most probable class
     * @throws java.util.NoSuchElementException if there are no probabilities
     */
    public <T extends Classification> T best() {
        return this.item(this.probabilities.indexOf(Collections.max(this.probabilities)));
    }

    /**
     * Returns the result for the first class whose name equals {@code className}.
     *
     * @param className the class name to look up
     * @param <T> the {@link Classification} subtype
     * @return the matching item, or {@code null} if no class has that name
     */
    public <T extends Classification> T get(String className) {
        int size = this.classNames.size();
        for (int i = 0; i < size; ++i) {
            if (this.classNames.get(i).equals(className)) {
                return this.item(i);
            }
        }
        return null;
    }

    /**
     * Serializes the top-K items as a JSON array followed by a newline.
     *
     * @return the JSON text
     */
    @Override
    public String toJson() {
        return GSON.toJson(this) + '\n';
    }

    /** {@inheritDoc} */
    @Override
    public String getAsString() {
        return this.toJson();
    }

    /**
     * Returns {@link #toJson()} encoded as UTF-8.
     *
     * @return the UTF-8 bytes of the JSON text
     */
    @Override
    public ByteBuffer toByteBuffer() {
        return ByteBuffer.wrap(this.toJson().getBytes(StandardCharsets.UTF_8));
    }

    /**
     * Returns a multi-line listing of the top-K items, one per line, enclosed in brackets.
     *
     * @return the human-readable representation
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append('[').append(System.lineSeparator());
        for (Classification item : this.topK(this.topK)) {
            sb.append('\t').append(item).append(System.lineSeparator());
        }
        sb.append(']');
        return sb.toString();
    }

    /**
     * Averages this result with the given results, class by class.
     *
     * @param it the other results to ensemble with this one
     * @return a new {@code Classifications} holding the mean probability of each class
     * @throws IllegalArgumentException if any input has different class names or a different
     *     number of probabilities than this instance
     */
    @Override
    public Classifications ensembleWith(Iterator<Classifications> it) {
        int size = this.probabilities.size();
        List<Double> sums = new ArrayList<>(this.probabilities);
        int count = 1;
        while (it.hasNext()) {
            Classifications c = it.next();
            // Validate before summing so a mismatched input fails with a clear message instead
            // of an IndexOutOfBoundsException from the summing loop.
            if (!c.classNames.equals(this.classNames)) {
                throw new IllegalArgumentException(
                        "Found a classNames mismatch during ensembling. All input Classifications should have the same classNames, but some were different");
            }
            if (c.probabilities.size() != size) {
                throw new IllegalArgumentException(
                        "Found a probabilities size mismatch during ensembling. Expected "
                                + size
                                + " probabilities but got "
                                + c.probabilities.size());
            }
            for (int i = 0; i < size; ++i) {
                sums.set(i, sums.get(i) + c.probabilities.get(i));
            }
            ++count;
        }
        int total = count; // effectively-final copy for the lambda
        sums.replaceAll(p -> p / (double) total);
        return new Classifications(this.classNames, sums);
    }

    /** A single class name paired with its probability. */
    public static class Classification {

        private String className;
        private double probability;

        /**
         * Creates a {@code Classification}.
         *
         * @param className the class name
         * @param probability the probability of the class
         */
        public Classification(String className, double probability) {
            this.className = className;
            this.probability = probability;
        }

        /**
         * Returns the class name.
         *
         * @return the class name
         */
        public String getClassName() {
            return this.className;
        }

        /**
         * Returns the probability.
         *
         * @return the probability
         */
        public double getProbability() {
            return this.probability;
        }

        /**
         * Returns a JSON-like {@code {"class": ..., "probability": ...}} string.
         *
         * <p>Probabilities below 1e-5 are printed in scientific notation. Other values are
         * truncated (not rounded) to 5 decimal places. Formatting uses {@link Locale#ROOT}, so
         * the decimal separator is always '.'. This method does not modify the object.
         *
         * @return the display string
         */
        @Override
        public String toString() {
            StringBuilder sb = new StringBuilder(100);
            sb.append("{\"class\": \"").append(this.className).append("\", \"probability\": ");
            if (this.probability < 1.0E-5) {
                sb.append(String.format(Locale.ROOT, "%.1e", this.probability));
            } else {
                // Truncate into a local instead of overwriting the field, so printing an item
                // never changes its value. Math.floor avoids int overflow and float precision
                // loss, and lets NaN/Infinity print as such.
                double truncated = Math.floor(this.probability * 100000.0) / 100000.0;
                sb.append(String.format(Locale.ROOT, "%.5f", truncated));
            }
            sb.append('}');
            return sb.toString();
        }
    }

    /** Gson serializer that writes a {@link Classifications} as the JSON array of its top-K items. */
    public static final class ClassificationsSerializer
            implements JsonSerializer<Classifications> {

        /** {@inheritDoc} */
        @Override
        public JsonElement serialize(
                Classifications src, Type type, JsonSerializationContext ctx) {
            List<Classification> list = src.topK();
            return ctx.serialize(list);
        }
    }
}

