/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.huggingface.translator;

import ai.djl.Model;
import ai.djl.ModelException;
import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.inference.Predictor;
import ai.djl.modality.nlp.translator.ZeroShotClassificationInput;
import ai.djl.modality.nlp.translator.ZeroShotClassificationOutput;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.translate.ArgumentsUtil;
import ai.djl.translate.Batchifier;
import ai.djl.translate.NoBatchifyTranslator;
import ai.djl.translate.NoopTranslator;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import ai.djl.util.JsonUtils;
import ai.djl.util.Pair;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonParseException;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.Reader;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ZeroShotClassificationTranslator
implements NoBatchifyTranslator<ZeroShotClassificationInput, ZeroShotClassificationOutput> {
    private static final Logger logger = LoggerFactory.getLogger(ZeroShotClassificationTranslator.class);
    private HuggingFaceTokenizer tokenizer;
    private int entailmentId;
    private int contradictionId;
    private boolean tokenTypeId;
    private boolean int32;
    private Predictor<NDList, NDList> predictor;

    ZeroShotClassificationTranslator(HuggingFaceTokenizer tokenizer, boolean tokenTypeId, boolean int32) {
        this.tokenizer = tokenizer;
        this.tokenTypeId = tokenTypeId;
        this.int32 = int32;
    }

    ZeroShotClassificationTranslator(HuggingFaceTokenizer tokenizer, boolean tokenTypeId, boolean int32, int entailmentId, int contradictionId) {
        this(tokenizer, tokenTypeId, int32);
        this.entailmentId = entailmentId;
        this.contradictionId = contradictionId;
    }

    public void prepare(TranslatorContext ctx) throws IOException, ModelException {
        Model model = ctx.getModel();
        this.predictor = model.newPredictor((Translator)new NoopTranslator(null));
        ctx.getPredictorManager().attachInternal(NDManager.nextUid(), new AutoCloseable[]{this.predictor});
        Path configFile = model.getModelPath().resolve("config.json");
        if (!Files.isRegularFile(configFile, new LinkOption[0])) {
            return;
        }
        try (BufferedReader reader = Files.newBufferedReader(configFile);){
            String modelType;
            int typeVocabSize;
            JsonElement typeVocabSizeObj;
            JsonObject config = (JsonObject)JsonUtils.GSON.fromJson((Reader)reader, JsonObject.class);
            if (config.has("label2id")) {
                JsonObject label2Id = config.getAsJsonObject("label2id");
                for (Map.Entry entry : label2Id.entrySet()) {
                    String key = ((String)entry.getKey()).toLowerCase(Locale.ROOT);
                    int value = ((JsonElement)entry.getValue()).getAsInt();
                    if (key.startsWith("entail")) {
                        this.entailmentId = value;
                        continue;
                    }
                    if (!key.startsWith("contra")) continue;
                    this.contradictionId = value;
                }
            }
            boolean inferredWithTokenType = false;
            if (config.has("type_vocab_size") && (typeVocabSizeObj = config.get("type_vocab_size")).isJsonPrimitive() && (typeVocabSize = typeVocabSizeObj.getAsInt()) > 1) {
                inferredWithTokenType = true;
            }
            if (!inferredWithTokenType && config.has("model_type") && ("bert".equals(modelType = config.get("model_type").getAsString().toLowerCase(Locale.ROOT)) || "albert".equals(modelType) || "xlnet".equals(modelType) || modelType.startsWith("deberta"))) {
                inferredWithTokenType = true;
            }
            this.tokenTypeId = inferredWithTokenType;
        }
        catch (JsonParseException | IOException e) {
            logger.error("Failed to read or parse config.json for label2id", e);
        }
    }

    public NDList processInput(TranslatorContext ctx, ZeroShotClassificationInput input) {
        ctx.setAttachment("input", (Object)input);
        return new NDList();
    }

    public ZeroShotClassificationOutput processOutput(TranslatorContext ctx, NDList list) throws TranslateException {
        double[] finalScores;
        String[] finalLabels;
        ZeroShotClassificationInput input = (ZeroShotClassificationInput)ctx.getAttachment("input");
        String template = input.getHypothesisTemplate();
        String[] candidates = input.getCandidates();
        if (candidates == null || candidates.length == 0) {
            throw new TranslateException("Missing candidates in input");
        }
        NDManager manager = ctx.getNDManager();
        NDList output = new NDList(candidates.length);
        for (String candidate : candidates) {
            String hypothesis = this.applyTemplate(template, candidate);
            Encoding encoding = this.tokenizer.encode(input.getText(), hypothesis);
            NDList in = encoding.toNDList(manager, this.tokenTypeId, this.int32);
            NDList batch = Batchifier.STACK.batchify(new NDList[]{in});
            output.add((Object)((NDArray)((NDList)this.predictor.predict((Object)batch)).get(0)));
        }
        NDArray combinedLogits = NDArrays.concat((NDList)output);
        if (input.isMultiLabel() || candidates.length == 1) {
            int i;
            NDArray entailmentScores;
            if (combinedLogits.getShape().get(1) == 2L) {
                NDArray probs = combinedLogits.softmax(1);
                entailmentScores = probs.get(":, " + this.entailmentId, new Object[0]);
            } else {
                NDArray entailContrLogits = combinedLogits.get(new NDIndex(":, {}", new Object[]{manager.create(new int[]{this.contradictionId, this.entailmentId})}));
                NDArray scoresProbs = entailContrLogits.softmax(1);
                entailmentScores = scoresProbs.get(":, 1", new Object[0]);
            }
            float[] floatScores = entailmentScores.toFloatArray();
            ArrayList<Pair> pairs = new ArrayList<Pair>();
            for (i = 0; i < floatScores.length; ++i) {
                Pair pair = new Pair((Object)floatScores[i], (Object)candidates[i]);
                pairs.add(pair);
            }
            pairs.sort(Comparator.comparingDouble(e -> (Double)e.getKey()).reversed());
            finalLabels = new String[candidates.length];
            finalScores = new double[candidates.length];
            for (i = 0; i < candidates.length; ++i) {
                finalLabels[i] = (String)((Pair)pairs.get(i)).getValue();
                finalScores[i] = (Double)((Pair)pairs.get(i)).getKey();
            }
        } else {
            NDArray entailLogits = combinedLogits.get(":, " + this.entailmentId, new Object[0]);
            NDArray exp = entailLogits.exp();
            NDArray sum = exp.sum();
            NDArray normalizedScores = exp.div(sum);
            long[] indices = normalizedScores.argSort(-1, false).toLongArray();
            float[] probabilities = normalizedScores.toFloatArray();
            finalLabels = new String[candidates.length];
            finalScores = new double[candidates.length];
            for (int i = 0; i < finalLabels.length; ++i) {
                int index = (int)indices[i];
                finalLabels[i] = candidates[index];
                finalScores[i] = probabilities[index];
            }
        }
        return new ZeroShotClassificationOutput(input.getText(), finalLabels, finalScores);
    }

    private String applyTemplate(String template, String arg) {
        int pos = template.indexOf("{}");
        if (pos == -1) {
            return template + arg;
        }
        int len = template.length();
        return template.substring(0, pos) + arg + template.substring(pos + 2, len);
    }

    public static Builder builder(HuggingFaceTokenizer tokenizer) {
        return new Builder(tokenizer);
    }

    public static Builder builder(HuggingFaceTokenizer tokenizer, Map<String, ?> arguments) {
        Builder builder = ZeroShotClassificationTranslator.builder(tokenizer);
        builder.configure(arguments);
        return builder;
    }

    public static final class Builder {
        private HuggingFaceTokenizer tokenizer;
        private boolean tokenTypeId;
        private boolean int32;
        private int entailmentId = 2;
        private int contradictionId;

        Builder(HuggingFaceTokenizer tokenizer) {
            this.tokenizer = tokenizer;
        }

        public Builder optTokenTypeId(boolean tokenTypeId) {
            this.tokenTypeId = tokenTypeId;
            return this;
        }

        public Builder optInt32(boolean int32) {
            this.int32 = int32;
            return this;
        }

        public Builder optEntailmentId(int entailmentId) {
            this.entailmentId = entailmentId;
            return this;
        }

        public Builder optContradictionId(int contradictionId) {
            this.contradictionId = contradictionId;
            return this;
        }

        public void configure(Map<String, ?> arguments) {
            this.optTokenTypeId(ArgumentsUtil.booleanValue(arguments, (String)"tokenTypeId"));
            this.optInt32(ArgumentsUtil.booleanValue(arguments, (String)"int32"));
        }

        public ZeroShotClassificationTranslator build() throws IOException {
            return new ZeroShotClassificationTranslator(this.tokenizer, this.tokenTypeId, this.int32, this.entailmentId, this.contradictionId);
        }
    }
}

