/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.model.embedding.onnx;

import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import dev.langchain4j.internal.Exceptions;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.model.embedding.onnx.PoolingMode;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

public class OnnxBertBiEncoder {
    /**
     * Upper bound on the number of tokens placed in a single partition; equal to the partition size
     * that {@code embed} hands to {@code partition}.
     * NOTE(review): presumably a 512-token model limit minus two special tokens - confirm.
     */
    private static final int MAX_SEQUENCE_LENGTH = 510;
    /** ONNX Runtime environment: obtained via {@code OrtEnvironment.getEnvironment()} or supplied by the caller. */
    private final OrtEnvironment environment;
    /** Inference session: created from the loaded model bytes or supplied by the caller. */
    private final OrtSession session;
    /**
     * Input names reported by {@code session.getInputNames()}.
     * NOTE(review): not read by any method body recovered in this file; presumably consumed by the
     * undecompiled {@code encode} - confirm against the original sources.
     */
    private final Set<String> expectedInputs;
    /** Tokenizer built from the tokenizer stream with the option {@code "padding" = "false"}. */
    private final HuggingFaceTokenizer tokenizer;
    /** Strategy used by {@code pool} to reduce per-token vectors to one vector; validated as non-null on construction. */
    private final PoolingMode poolingMode;

    public OnnxBertBiEncoder(InputStream model, InputStream tokenizer, PoolingMode poolingMode) {
        try {
            this.environment = OrtEnvironment.getEnvironment();
            this.session = this.environment.createSession(this.loadModel(model));
            this.expectedInputs = this.session.getInputNames();
            this.tokenizer = HuggingFaceTokenizer.newInstance((InputStream)tokenizer, Collections.singletonMap("padding", "false"));
            this.poolingMode = (PoolingMode)((Object)ValidationUtils.ensureNotNull((Object)((Object)poolingMode), (String)"poolingMode"));
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    public OnnxBertBiEncoder(OrtEnvironment environment, OrtSession session, InputStream tokenizer, PoolingMode poolingMode) {
        try {
            this.environment = environment;
            this.session = session;
            this.expectedInputs = session.getInputNames();
            this.tokenizer = HuggingFaceTokenizer.newInstance((InputStream)tokenizer, Collections.singletonMap("padding", "false"));
            this.poolingMode = (PoolingMode)((Object)ValidationUtils.ensureNotNull((Object)((Object)poolingMode), (String)"poolingMode"));
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Computes a single normalized embedding for the given text.
     *
     * <p>The text is tokenized, the token list is split by {@code partition} into chunks of at most
     * {@code MAX_SEQUENCE_LENGTH} tokens, each chunk is encoded and pooled into one vector, and the
     * chunk vectors are combined by a weighted average (weight = chunk size) and normalized.
     *
     * <p>NOTE(review): when the tokenizer output has no tokens between its first and last element,
     * {@code partition} yields no chunks and {@code weightedAverage} fails on the empty list.
     *
     * @param text the text to embed
     * @return the embedding together with the size of the tokenizer's output for {@code text}
     * @throws RuntimeException wrapping an {@link OrtException} raised while encoding a chunk or
     *                          reading its result
     */
    EmbeddingAndTokenCount embed(String text) {
        List<String> tokens = this.tokenizer.tokenize(text);
        List<List<String>> partitions = OnnxBertBiEncoder.partition(tokens, MAX_SEQUENCE_LENGTH);

        List<float[]> embeddings = new ArrayList<>(partitions.size());
        for (List<String> partition : partitions) {
            // try-with-resources always releases the native result and, unlike a finally block
            // that ends in `continue`, never discards an exception thrown while reading it.
            try (OrtSession.Result result = this.encode(partition)) {
                embeddings.add(this.toEmbedding(result));
            } catch (OrtException e) {
                throw new RuntimeException(e);
            }
        }

        // Each chunk contributes proportionally to the number of tokens it covers.
        List<Integer> weights = partitions.stream().map(List::size).collect(Collectors.toList());
        float[] embedding = OnnxBertBiEncoder.normalize(this.weightedAverage(embeddings, weights));
        return new EmbeddingAndTokenCount(embedding, tokens.size());
    }

    /**
     * Splits a token list into consecutive chunks of at most {@code partitionSize} tokens.
     *
     * <p>The first and the last element of {@code tokens} are never included in any chunk
     * (presumably they are the tokenizer's special start/end tokens - confirm). A chunk boundary is
     * moved backwards while the token at the boundary starts with {@code "##"}, so that such a
     * continuation piece stays in the same chunk as the token it follows. If an entire window
     * consists of {@code "##"} pieces, so that backing off would produce an empty chunk, the window
     * is split at exactly {@code partitionSize} tokens instead; this guarantees progress.
     *
     * @param tokens        the full token list, including its first and last element
     * @param partitionSize maximum number of tokens per chunk; must be at least 1
     * @return the chunks, in order, as {@code subList} views of {@code tokens}; empty when
     *         {@code tokens} has no elements between its first and last
     * @throws IllegalArgumentException if {@code partitionSize} is smaller than 1
     */
    static List<List<String>> partition(List<String> tokens, int partitionSize) {
        if (partitionSize < 1) {
            throw new IllegalArgumentException("partitionSize must be at least 1, but was: " + partitionSize);
        }

        List<List<String>> partitions = new ArrayList<>();
        int end = tokens.size() - 1; // exclusive bound: the last token is never part of a chunk
        int from = 1;                // the first token is never part of a chunk

        while (from < end) {
            int to;
            if (end - from <= partitionSize) {
                // Everything that is left fits into one chunk (written this way to avoid int overflow).
                to = end;
            } else {
                to = from + partitionSize;
                // Back off so the next chunk does not start with a "##" continuation piece.
                while (to > from && tokens.get(to).startsWith("##")) {
                    --to;
                }
                if (to == from) {
                    // The whole window is continuation pieces: fall back to a hard split,
                    // otherwise an empty chunk would be emitted and the loop would never advance.
                    to = from + partitionSize;
                }
            }
            partitions.add(tokens.subList(from, to));
            from = to;
        }
        return partitions;
    }

    /**
     * Runs the model on one partition of tokens and returns the raw session result.
     *
     * <p><b>The real body of this method was not recovered by the decompiler</b> (see the CFR
     * failure trace below). As written in this file the method unconditionally throws
     * {@link IllegalStateException}, so {@code embed}, which calls it once per partition, cannot
     * succeed until the body is restored from the original sources.
     *
     * <p>NOTE(review): the otherwise unused members {@code toText}, {@code expectedInputs},
     * {@code environment} and {@code session} are presumably what the original body relied on -
     * confirm against the original sources.
     *
     * @param tokens the tokens of one partition, as produced by {@code partition}
     * @return the session result; the caller ({@code embed}) closes it
     * @throws OrtException declared by the signature
     * @throws IllegalStateException always, in this decompiled form
     */
    private OrtSession.Result encode(List<String> tokens) throws OrtException {
        /*
         * This method has failed to decompile.  When submitting a bug report, please provide this stack trace, and (if you hold appropriate legal rights) the relevant class file.
         * 
         * org.benf.cfr.reader.util.ConfusedCFRException: Started 4 blocks at once
         *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op04StructuredStatement.getStartingBlocks(Op04StructuredStatement.java:412)
         *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op04StructuredStatement.buildNestedBlocks(Op04StructuredStatement.java:487)
         *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op03SimpleStatement.createInitialStructuredBlock(Op03SimpleStatement.java:736)
         *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysisInner(CodeAnalyser.java:850)
         *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysisOrWrapFail(CodeAnalyser.java:278)
         *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysis(CodeAnalyser.java:201)
         *     at org.benf.cfr.reader.entities.attributes.AttributeCode.analyse(AttributeCode.java:94)
         *     at org.benf.cfr.reader.entities.Method.analyse(Method.java:531)
         *     at org.benf.cfr.reader.entities.ClassFile.analyseMid(ClassFile.java:1055)
         *     at org.benf.cfr.reader.entities.ClassFile.analyseTop(ClassFile.java:942)
         *     at org.benf.cfr.reader.Driver.doJarVersionTypes(Driver.java:257)
         *     at org.benf.cfr.reader.Driver.doJar(Driver.java:139)
         *     at org.benf.cfr.reader.CfrDriverImpl.analyse(CfrDriverImpl.java:76)
         *     at org.benf.cfr.reader.Main.main(Main.java:54)
         */
        throw new IllegalStateException("Decompilation failed");
    }

    /**
     * Turns a list of tokens back into text that tokenizes to the same tokens where possible.
     *
     * <p>The tokenizer first rebuilds a sentence from the tokens. That sentence is tokenized again,
     * the first and last resulting tokens are dropped, and the remainder is compared with the input.
     * If they match, the rebuilt sentence is returned; otherwise the tokens are simply concatenated
     * without a separator.
     *
     * @param tokens the tokens to convert
     * @return the rebuilt sentence, or the plain concatenation of {@code tokens} as a fallback
     */
    private String toText(List<String> tokens) {
        String sentence = this.tokenizer.buildSentence(tokens);

        // Round trip: re-tokenize and strip the first and last token before comparing.
        List<String> roundTrip = new ArrayList<>(this.tokenizer.tokenize(sentence));
        roundTrip.remove(0);
        roundTrip.remove(roundTrip.size() - 1);

        if (!roundTrip.equals(tokens)) {
            return String.join("", tokens);
        }
        return sentence;
    }

    /**
     * Extracts the vectors of the first output of a session result and pools them into one vector.
     *
     * <p>The first output's value is cast to {@code float[][][]} and its element at index 0 is
     * pooled (assumes a layout of batch x tokens x dimensions with a single batch entry - TODO confirm).
     *
     * @param result the session result to read
     * @return the pooled vector
     * @throws OrtException if the output value cannot be obtained
     */
    private float[] toEmbedding(OrtSession.Result result) throws OrtException {
        Object firstOutput = result.get(0).getValue();
        float[][][] batched = (float[][][]) firstOutput;
        float[][] tokenVectors = batched[0];
        return this.pool(tokenVectors);
    }

    /**
     * Reduces a matrix of vectors to a single vector according to the configured pooling mode.
     *
     * @param vectors the vectors to pool
     * @return the first vector for {@code CLS}, the element-wise mean for {@code MEAN}
     * @throws IllegalArgumentException (via {@code Exceptions.illegalArgument}) for any other mode
     */
    private float[] pool(float[][] vectors) {
        final PoolingMode mode = this.poolingMode;
        switch (mode) {
            case CLS:
                return OnnxBertBiEncoder.clsPool(vectors);
            case MEAN:
                return OnnxBertBiEncoder.meanPool(vectors);
            default:
                throw Exceptions.illegalArgument("Unknown pooling mode: " + mode, new Object[0]);
        }
    }

    /**
     * Pools by picking the first vector.
     *
     * @param vectors the vectors to pool; must contain at least one row
     * @return the row at index 0 itself (not a copy)
     */
    private static float[] clsPool(float[][] vectors) {
        final float[] firstVector = vectors[0];
        return firstVector;
    }

    /**
     * Pools by taking the element-wise arithmetic mean of all vectors.
     *
     * <p>The length of the result is the length of the first vector.
     *
     * @param vectors the vectors to pool; must contain at least one row
     * @return a new array holding the per-component mean
     */
    private static float[] meanPool(float[][] vectors) {
        final int rowCount = vectors.length;
        final int width = vectors[0].length;
        final float[] mean = new float[width];

        // Accumulate component sums row by row ...
        for (int row = 0; row < rowCount; row++) {
            final float[] current = vectors[row];
            for (int col = 0; col < width; col++) {
                mean[col] += current[col];
            }
        }
        // ... then turn the sums into means.
        for (int col = 0; col < width; col++) {
            mean[col] /= (float) rowCount;
        }
        return mean;
    }

    /**
     * Combines several embeddings into one by a weighted arithmetic mean.
     *
     * <p>With exactly one embedding, that embedding is returned as is (same array instance).
     *
     * @param embeddings the embeddings to combine; must not be empty
     * @param weights    one integer weight per embedding, in the same order
     * @return the single embedding, or a new array with the weighted mean of every component
     */
    private float[] weightedAverage(List<float[]> embeddings, List<Integer> weights) {
        final int count = embeddings.size();
        if (count == 1) {
            return embeddings.get(0);
        }

        final int width = embeddings.get(0).length;
        final float[] result = new float[width];
        int weightSum = 0;

        for (int index = 0; index < count; index++) {
            final int weight = weights.get(index);
            weightSum += weight;
            final float[] current = embeddings.get(index);
            for (int col = 0; col < width; col++) {
                result[col] += current[col] * (float) weight;
            }
        }
        for (int col = 0; col < width; col++) {
            result[col] /= (float) weightSum;
        }
        return result;
    }

    /**
     * Scales a vector to unit Euclidean length.
     *
     * <p>Each component is divided by the square root of the sum of squared components. No special
     * handling exists for a zero-length norm, so the division result is whatever float arithmetic
     * yields in that case.
     *
     * @param vector the vector to scale; left unmodified
     * @return a new array with the scaled components
     */
    private static float[] normalize(float[] vector) {
        float squaredSum = 0.0f;
        for (int i = 0; i < vector.length; i++) {
            squaredSum += vector[i] * vector[i];
        }
        final float length = (float) Math.sqrt(squaredSum);

        final float[] unit = new float[vector.length];
        int position = 0;
        for (float component : vector) {
            unit[position++] = component / length;
        }
        return unit;
    }

    /**
     * Counts the tokens the tokenizer produces for a text.
     *
     * @param text the text to tokenize
     * @return the size of the tokenizer's output list, with nothing removed from it
     */
    int countTokens(String text) {
        List<?> tokens = this.tokenizer.tokenize(text);
        return tokens.size();
    }

    /**
     * Reads the model from the given stream into a byte array.
     *
     * <p><b>The real body of this method was not recovered by the decompiler</b> (see the CFR
     * failure trace below). As written in this file the method unconditionally throws
     * {@link IllegalStateException}. The constructor taking a model stream calls it and wraps the
     * exception in a {@code RuntimeException}, so that constructor cannot succeed until the body is
     * restored from the original sources.
     *
     * @param modelInputStream stream supplying the model
     * @return the model bytes; the calling constructor passes them to {@code createSession}
     * @throws IllegalStateException always, in this decompiled form
     */
    private byte[] loadModel(InputStream modelInputStream) {
        /*
         * This method has failed to decompile.  When submitting a bug report, please provide this stack trace, and (if you hold appropriate legal rights) the relevant class file.
         * 
         * org.benf.cfr.reader.util.ConfusedCFRException: Started 3 blocks at once
         *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op04StructuredStatement.getStartingBlocks(Op04StructuredStatement.java:412)
         *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op04StructuredStatement.buildNestedBlocks(Op04StructuredStatement.java:487)
         *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op03SimpleStatement.createInitialStructuredBlock(Op03SimpleStatement.java:736)
         *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysisInner(CodeAnalyser.java:850)
         *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysisOrWrapFail(CodeAnalyser.java:278)
         *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysis(CodeAnalyser.java:201)
         *     at org.benf.cfr.reader.entities.attributes.AttributeCode.analyse(AttributeCode.java:94)
         *     at org.benf.cfr.reader.entities.Method.analyse(Method.java:531)
         *     at org.benf.cfr.reader.entities.ClassFile.analyseMid(ClassFile.java:1055)
         *     at org.benf.cfr.reader.entities.ClassFile.analyseTop(ClassFile.java:942)
         *     at org.benf.cfr.reader.Driver.doJarVersionTypes(Driver.java:257)
         *     at org.benf.cfr.reader.Driver.doJar(Driver.java:139)
         *     at org.benf.cfr.reader.CfrDriverImpl.analyse(CfrDriverImpl.java:76)
         *     at org.benf.cfr.reader.Main.main(Main.java:54)
         */
        throw new IllegalStateException("Decompilation failed");
    }

    /**
     * Plain holder pairing an embedding vector with a token count, as returned by {@code embed}.
     */
    static class EmbeddingAndTokenCount {
        /** The embedding vector; stored by reference, not copied. */
        float[] embedding;

        /** The token count supplied on construction. */
        int tokenCount;

        /**
         * @param embedding  the embedding vector to hold
         * @param tokenCount the token count to hold
         */
        EmbeddingAndTokenCount(float[] embedding, int tokenCount) {
            this.tokenCount = tokenCount;
            this.embedding = embedding;
        }
    }
}

