/*
 * Decompiled with CFR 0.152.
 */
package opennlp.dl.vectors;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import java.io.File;
import java.io.IOException;
import java.nio.LongBuffer;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import opennlp.dl.AbstractDL;
import opennlp.dl.Tokens;
import opennlp.tools.tokenize.Tokenizer;
import opennlp.tools.tokenize.WordpieceTokenizer;

/**
 * Computes a vector for a sentence by running an ONNX model over the sentence's
 * wordpiece tokens.
 *
 * <p>The model is fed {@code input_ids}, {@code attention_mask} and {@code token_type_ids}
 * tensors of shape {@code [1, tokenCount]}. The returned vector is the model's first output
 * at batch 0, token position 0 (presumably the {@code [CLS]} token if the tokenizer prepends
 * one — confirm against {@link WordpieceTokenizer}).
 *
 * <p>NOTE(review): the ONNX environment and session are stored in fields inherited from
 * {@link AbstractDL}; whether and how they are released is decided there — confirm.
 */
public class SentenceVectorsDL extends AbstractDL {

    /**
     * Loads the ONNX model and its vocabulary, and builds a wordpiece tokenizer whose
     * vocabulary is the key set of the loaded vocabulary map.
     *
     * @param model the ONNX model file
     * @param vocabulary the vocabulary file, read with {@code loadVocab}
     * @throws OrtException if ONNX Runtime cannot create a session for {@code model}
     * @throws IOException if the vocabulary file cannot be read
     */
    public SentenceVectorsDL(File model, File vocabulary) throws OrtException, IOException {
        this.env = OrtEnvironment.getEnvironment();
        this.session = this.env.createSession(model.getPath(), new OrtSession.SessionOptions());
        this.vocab = this.loadVocab(vocabulary);
        this.tokenizer = new WordpieceTokenizer(this.vocab.keySet());
    }

    /**
     * Tokenizes {@code sentence}, runs the model on it and returns the output vector of the
     * first token position.
     *
     * @param sentence the text to encode
     * @return the model output at {@code [0][0]} of its first (rank-3 float) output
     * @throws OrtException if tensor creation or inference fails
     * @throws IllegalStateException if a produced token is missing from the vocabulary
     */
    public float[] getVectors(String sentence) throws OrtException {
        Tokens tokens = this.tokenize(sentence, this.tokenizer, this.vocab);
        long[] ids = tokens.ids();
        long[] mask = tokens.mask();
        long[] types = tokens.types();

        // Tensors and the Result own native memory; close them on every call to avoid leaks.
        try (OnnxTensor idsTensor = OnnxTensor.createTensor(
                     this.env, LongBuffer.wrap(ids), new long[]{1L, ids.length});
             OnnxTensor maskTensor = OnnxTensor.createTensor(
                     this.env, LongBuffer.wrap(mask), new long[]{1L, mask.length});
             OnnxTensor typesTensor = OnnxTensor.createTensor(
                     this.env, LongBuffer.wrap(types), new long[]{1L, types.length})) {
            Map<String, OnnxTensor> inputs = new HashMap<>();
            inputs.put("input_ids", idsTensor);
            inputs.put("attention_mask", maskTensor);
            inputs.put("token_type_ids", typesTensor);
            try (OrtSession.Result result = this.session.run(inputs)) {
                // getValue() copies into a fresh Java array, so it stays valid after close.
                float[][][] output = (float[][][]) result.get(0).getValue();
                return output[0][0];
            }
        }
    }

    /**
     * Splits {@code text} into wordpiece tokens and builds the model inputs for a single,
     * unpadded sentence.
     *
     * @param text the text to tokenize
     * @param tokenizer the tokenizer that produces the token strings
     * @param vocab maps each token string to its model id
     * @return the tokens with their ids, an all-ones attention mask and all-zero token types
     * @throws IllegalStateException if a token has no entry in {@code vocab}
     */
    private Tokens tokenize(String text, Tokenizer tokenizer, Map<String, Integer> vocab) {
        String[] tokens = tokenizer.tokenize(text);
        long[] ids = new long[tokens.length];
        for (int x = 0; x < tokens.length; ++x) {
            Integer id = vocab.get(tokens[x]);
            if (id == null) {
                // Previously an NPE from auto-unboxing; report which token is missing instead.
                throw new IllegalStateException(
                        "Token '" + tokens[x] + "' is not present in the vocabulary");
            }
            ids[x] = id;
        }
        // Every position holds a real token (no padding), so all of them must be attended to.
        long[] mask = new long[ids.length];
        Arrays.fill(mask, 1L);
        // Single-segment input uses segment id 0. Id 0 is valid for any token-type
        // embedding size, whereas 1 needs at least two segment types.
        long[] types = new long[ids.length];
        return new Tokens(tokens, ids, mask, types);
    }
}

