/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.model.embedding;

import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import dev.langchain4j.model.embedding.BertTokenizer;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Objects;

/**
 * Cross-encoder that scores a pair of texts with a BERT-style ONNX model.
 *
 * <p>The model bytes are read from an {@link InputStream} into an ONNX Runtime session;
 * {@link #encode} runs the model on a pair of texts and returns element {@code [0][0]} of the
 * first output tensor.
 *
 * <p>NOTE(review): this class was recovered with the CFR decompiler and {@code runModel} could
 * not be decompiled, so {@link #encode} currently always fails until that method is rebuilt.
 */
public class OnnxBertCrossEncoder implements AutoCloseable {
    private final OrtEnvironment environment;
    private final OrtSession session;
    private final BertTokenizer tokenizer;

    /**
     * Demo entry point: scores one question/passage pair and prints the result.
     *
     * @param args optional; {@code args[0]} is the path to the ONNX model file. When absent the
     *     historical hard-coded development path is used.
     * @throws IOException if the model file cannot be opened
     */
    public static void main(String[] args) throws IOException {
        Path path = args.length > 0
                ? Paths.get(args[0])
                : Paths.get("C:\\dev\\ai\\onnx\\ms-marco-MiniLM-L-6-v2\\model.onnx");
        // Both resources are closed even if construction or scoring fails.
        try (InputStream modelInputStream = Files.newInputStream(path);
             OnnxBertCrossEncoder crossEncoder = new OnnxBertCrossEncoder(modelInputStream)) {
            float score = crossEncoder.encode("How many people live in Berlin?", "Berlin had a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers.");
            System.out.println(score);
        }
    }

    /**
     * Creates a cross-encoder from a serialized ONNX model.
     *
     * @param modelInputStream stream containing the ONNX model bytes; it is read to the end and
     *     closed by this constructor
     * @throws NullPointerException if {@code modelInputStream} is {@code null}
     * @throws RuntimeException wrapping any failure while reading the model, creating the ONNX
     *     session, or constructing the tokenizer
     */
    public OnnxBertCrossEncoder(InputStream modelInputStream) {
        Objects.requireNonNull(modelInputStream, "modelInputStream");
        try {
            this.environment = OrtEnvironment.getEnvironment();
            this.session = this.environment.createSession(loadModel(modelInputStream));
            this.tokenizer = new BertTokenizer();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Closes the underlying {@link OrtSession}.
     *
     * <p>The {@link OrtEnvironment} obtained from {@code getEnvironment()} is deliberately left
     * open (presumably shared process-wide — TODO confirm).
     *
     * @throws RuntimeException wrapping an {@link OrtException} raised while closing the session
     */
    @Override
    public void close() {
        try {
            session.close();
        } catch (OrtException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Scores the pair ({@code first}, {@code second}) with the model.
     *
     * @param first first text of the pair (e.g. a query)
     * @param second second text of the pair (e.g. a candidate passage)
     * @return element {@code [0][0]} of the model's first output, read as {@code float[][]}
     * @throws RuntimeException wrapping any {@link OrtException} from running the model
     */
    public float encode(String first, String second) {
        try (OrtSession.Result result = runModel(first, second)) {
            // Assumes a batch of one pair with a single score per row — TODO confirm output shape.
            float[][] output = (float[][]) result.get(0).getValue();
            return output[0][0];
        } catch (OrtException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Presumably tokenizes the pair and runs {@link #session} on it.
     *
     * <p>NOTE(review): CFR 0.152 failed to decompile this method ("ConfusedCFRException: Started
     * 4 blocks at once"), so the original body is lost. Until it is reimplemented this method
     * always throws, which makes {@link #encode} unusable. TODO: rebuild the model inputs from
     * {@link BertTokenizer} output and call {@code session.run(...)}.
     *
     * @throws OrtException declared for the eventual ONNX Runtime call
     */
    private OrtSession.Result runModel(String first, String second) throws OrtException {
        throw new IllegalStateException("Decompilation failed");
    }

    /**
     * Reads the whole model stream into memory and closes it.
     *
     * <p>The byte array is what {@code OrtEnvironment.createSession(byte[])} consumes in the
     * constructor. (Reconstructed: CFR failed to decompile the original body.)
     *
     * @param modelInputStream stream with the serialized ONNX model; closed on return
     * @return every byte of the stream, in order
     * @throws UncheckedIOException if reading or closing the stream fails
     */
    private byte[] loadModel(InputStream modelInputStream) {
        try (InputStream in = modelInputStream;
             ByteArrayOutputStream buffer = new ByteArrayOutputStream()) {
            byte[] chunk = new byte[8192];
            int read;
            while ((read = in.read(chunk)) != -1) {
                buffer.write(chunk, 0, read);
            }
            return buffer.toByteArray();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /**
     * Mean-pools and L2-normalizes the first batch element of the first output.
     *
     * <p>NOTE(review): not referenced by the decompiled code of this class; looks like a
     * leftover from a bi-encoder — confirm against the lost {@code runModel} before removing.
     */
    private static float[] toEmbedding(OrtSession.Result result) throws OrtException {
        float[][] vectors = ((float[][][]) result.get(0).getValue())[0];
        return normalize(meanPool(vectors));
    }

    /**
     * Returns the element-wise average of {@code vectors}.
     *
     * @param vectors non-empty array of equally sized rows
     * @throws IllegalArgumentException if {@code vectors} is empty
     */
    private static float[] meanPool(float[][] vectors) {
        if (vectors.length == 0) {
            throw new IllegalArgumentException("vectors must not be empty");
        }
        int dimension = vectors[0].length;
        float[] averaged = new float[dimension];
        for (float[] vector : vectors) {
            for (int j = 0; j < dimension; j++) {
                averaged[j] += vector[j];
            }
        }
        for (int j = 0; j < dimension; j++) {
            averaged[j] /= vectors.length;
        }
        return averaged;
    }

    /**
     * Returns {@code vector} scaled to unit Euclidean length.
     *
     * <p>A zero vector is returned as a new all-zero array instead of dividing by zero (which
     * would produce NaN components).
     */
    private static float[] normalize(float[] vector) {
        float sumSquare = 0.0f;
        for (float v : vector) {
            sumSquare += v * v;
        }
        float norm = (float) Math.sqrt(sumSquare);
        float[] normalized = new float[vector.length];
        if (norm == 0.0f) {
            return normalized;
        }
        for (int i = 0; i < vector.length; i++) {
            normalized[i] = vector[i] / norm;
        }
        return normalized;
    }

    /** Returns the number of tokens {@link BertTokenizer#tokenize} produces for {@code text}. */
    int countTokens(String text) {
        return this.tokenizer.tokenize(text).size();
    }
}

