/*
 * Decompiled with CFR 0.152.
 */
package com.robrua.nlp.bert;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.Lists;
import com.google.common.io.Resources;
import com.robrua.nlp.bert.FullTokenizer;
import java.io.File;
import java.io.IOException;
import java.io.OutputStream;
import java.io.UncheckedIOException;
import java.net.URL;
import java.nio.IntBuffer;
import java.nio.file.FileVisitOption;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.attribute.FileAttribute;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Stream;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;

/**
 * Wraps a TensorFlow SavedModel export of BERT together with its vocabulary ({@code assets/vocab.txt})
 * and model metadata ({@code assets/model.json}), and exposes sequence- and token-level embeddings.
 *
 * <p>Each instance owns a {@link SavedModelBundle}. Call {@link #close()} (or use
 * try-with-resources) to release it.
 *
 * <p>NOTE(review): CFR could not decompile the bodies of {@code embedSequence},
 * {@code embedSequences(String...)} and both {@code embedTokens} overloads. They were rebuilt from
 * two pieces of evidence: the graph node names in {@code ModelDetails}, and the three input tensors
 * held by {@code Inputs}. Confirm them against the original source.
 */
public class Bert implements AutoCloseable {
    private static final String ASSETS_DIRECTORY = "assets";
    private static final String MODEL_DETAILS = "model.json";
    private static final String SAVED_MODEL_TAG = "serve";
    private static final String SEPARATOR_TOKEN = "[SEP]";
    private static final String START_TOKEN = "[CLS]";
    private static final String TEMP_DIRECTORY_PREFIX = "easy-bert-";
    private static final String VOCAB_FILE = "vocab.txt";

    private final SavedModelBundle bundle;
    private final ModelDetails model;
    private final int separatorTokenId;
    private final int startTokenId;
    private final FullTokenizer tokenizer;

    /**
     * Loads a model from an exported model directory on disk.
     *
     * @param model the model directory
     * @return the loaded model; the caller is responsible for closing it
     * @throws UncheckedIOException if {@code assets/model.json} cannot be read or parsed
     */
    public static Bert load(File model) {
        return Bert.load(Paths.get(model.toURI()));
    }

    /**
     * Loads a model from an exported SavedModel directory.
     *
     * <p>The directory must contain:
     * <ul>
     *   <li>a SavedModel tagged {@code "serve"}</li>
     *   <li>{@code assets/model.json}</li>
     *   <li>{@code assets/vocab.txt}</li>
     * </ul>
     *
     * @param path the model directory; a relative path is made absolute first
     * @return the loaded model; the caller is responsible for closing it
     * @throws UncheckedIOException if {@code assets/model.json} cannot be read or parsed
     * @throws IllegalArgumentException if the model's {@code maxSequenceLength} is below 2
     */
    public static Bert load(Path path) {
        Path directory = path.toAbsolutePath();
        Path assets = directory.resolve(ASSETS_DIRECTORY);
        ModelDetails details;
        try {
            details = new ObjectMapper().readValue(assets.resolve(MODEL_DETAILS).toFile(), ModelDetails.class);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }

        SavedModelBundle bundle = SavedModelBundle.load(directory.toString(), SAVED_MODEL_TAG);
        try {
            return new Bert(bundle, details, assets.resolve(VOCAB_FILE));
        } catch (RuntimeException | Error e) {
            // Nobody else holds a reference to the bundle yet; release it rather than leak it.
            bundle.close();
            throw e;
        }
    }

    /**
     * Loads a model packaged as a zip archive of the model directory on the classpath.
     *
     * <p>Steps:
     * <ol>
     *   <li>Extract the archive into a fresh temporary directory.</li>
     *   <li>Load the model from that directory with {@link #load(Path)}.</li>
     *   <li>Delete the temporary directory before returning.</li>
     * </ol>
     *
     * @param resource classpath resource name of the zipped model directory
     * @return the loaded model; the caller is responsible for closing it
     * @throws UncheckedIOException in any of these cases:
     *     <ul>
     *       <li>extraction fails;</li>
     *       <li>an archive entry would be written outside the temporary directory;</li>
     *       <li>the temporary directory cannot be deleted afterwards. The loaded model is closed
     *           before this is thrown.</li>
     *     </ul>
     */
    public static Bert load(String resource) {
        URL archive = Resources.getResource(resource);
        Path directory;
        try {
            directory = Files.createTempDirectory(TEMP_DIRECTORY_PREFIX);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }

        Bert bert;
        try {
            extractZip(archive, directory);
            bert = Bert.load(directory);
        } catch (IOException e) {
            UncheckedIOException failure = new UncheckedIOException(e);
            deleteAfterFailure(directory, failure);
            throw failure;
        } catch (RuntimeException | Error e) {
            deleteAfterFailure(directory, e);
            throw e;
        }

        try {
            deleteRecursively(directory);
        } catch (IOException e) {
            // A cleanup failure is still reported, but the model that was already loaded must not leak.
            bert.close();
            throw new UncheckedIOException(e);
        }
        return bert;
    }

    /**
     * Extracts every entry of the zip archive at {@code archive} into {@code target}.
     * Entries whose normalized path would fall outside {@code target} are rejected.
     */
    private static void extractZip(URL archive, Path target) throws IOException {
        Path root = target.toAbsolutePath().normalize();
        try (ZipInputStream zip = new ZipInputStream(Resources.asByteSource(archive).openBufferedStream())) {
            for (ZipEntry entry = zip.getNextEntry(); entry != null; entry = zip.getNextEntry()) {
                // Guard against "zip slip": names such as "../x" or "/abs" must stay inside root.
                Path destination = root.resolve(entry.getName()).normalize();
                if (!destination.startsWith(root)) {
                    throw new IOException("Zip entry escapes extraction directory: " + entry.getName());
                }
                if (entry.isDirectory()) {
                    Files.createDirectories(destination);
                } else {
                    // Archives are not required to list a file's parent directories before the file.
                    Path parent = destination.getParent();
                    if (parent != null) {
                        Files.createDirectories(parent);
                    }
                    // ZipInputStream reports end-of-stream at the end of the current entry,
                    // so this copies exactly one entry and leaves the zip stream open.
                    Files.copy(zip, destination);
                }
                zip.closeEntry();
            }
        }
    }

    /** Deletes {@code root} and everything beneath it, children before parents. Does nothing if absent. */
    private static void deleteRecursively(Path root) throws IOException {
        if (!Files.exists(root)) {
            return;
        }
        // The stream returned by Files.walk wraps open directory streams and must be closed.
        try (Stream<Path> files = Files.walk(root)) {
            // A child path sorts after its parent, so reverse order visits children first.
            Iterator<Path> deepestFirst = files.sorted(Comparator.reverseOrder()).iterator();
            while (deepestFirst.hasNext()) {
                Files.delete(deepestFirst.next());
            }
        } catch (UncheckedIOException e) {
            // Files.walk reports traversal errors unchecked; rethrow the underlying checked exception.
            throw e.getCause();
        }
    }

    /**
     * Performs best-effort cleanup on an error path.
     * Any problem during cleanup is attached to {@code failure} as a suppressed exception, so it
     * does not mask the original error.
     */
    private static void deleteAfterFailure(Path directory, Throwable failure) {
        try {
            deleteRecursively(directory);
        } catch (IOException | RuntimeException e) {
            failure.addSuppressed(e);
        }
    }

    private Bert(SavedModelBundle bundle, ModelDetails model, Path vocabulary) {
        // Each encoded row always holds [CLS] and [SEP]; fail fast instead of overflowing buffers later.
        if (model.maxSequenceLength < 2) {
            throw new IllegalArgumentException(
                    "maxSequenceLength must be at least 2 to hold [CLS] and [SEP], got " + model.maxSequenceLength);
        }
        this.tokenizer = new FullTokenizer(vocabulary, model.doLowerCase);
        this.bundle = bundle;
        this.model = model;
        int[] ids = this.tokenizer.convert(new String[] {START_TOKEN, SEPARATOR_TOKEN});
        this.startTokenId = ids[0];
        this.separatorTokenId = ids[1];
    }

    /** Releases the underlying {@link SavedModelBundle}. */
    @Override
    public void close() {
        this.bundle.close();
    }

    /**
     * Embeds one sequence using the model's pooled-output node ({@code pooledOutput} in model.json).
     *
     * @param sequence the text to embed; tokens beyond {@code maxSequenceLength - 2} are dropped
     * @return the pooled embedding for the sequence
     */
    public float[] embedSequence(String sequence) {
        try (Inputs inputs = getInputs(sequence);
                Tensor<?> embedding = runModel(inputs, model.pooledOutput)) {
            // NOTE(review): assumes the pooled output is [batch, hidden]; confirm against the export.
            long[] shape = embedding.shape();
            float[][] converted = new float[(int) shape[0]][(int) shape[1]];
            embedding.copyTo(converted);
            return converted[0];
        }
    }

    /**
     * Embeds each sequence produced by {@code sequences}.
     *
     * @param sequences the texts to embed
     * @return one pooled embedding per sequence, in iteration order
     */
    public float[][] embedSequences(Iterable<String> sequences) {
        List<String> list = Lists.newArrayList(sequences);
        return this.embedSequences(list.toArray(new String[0]));
    }

    /**
     * Embeds each remaining sequence of {@code sequences}.
     *
     * @param sequences the texts to embed; the iterator is exhausted
     * @return one pooled embedding per sequence, in iteration order
     */
    public float[][] embedSequences(Iterator<String> sequences) {
        List<String> list = Lists.newArrayList(sequences);
        return this.embedSequences(list.toArray(new String[0]));
    }

    /**
     * Embeds a batch of sequences using the model's pooled-output node in a single session run.
     *
     * @param sequences the texts to embed; tokens beyond {@code maxSequenceLength - 2} are dropped
     * @return one pooled embedding per sequence, in argument order; empty when no sequences are given
     */
    public float[][] embedSequences(String... sequences) {
        if (sequences.length == 0) {
            return new float[0][];
        }
        try (Inputs inputs = getInputs(sequences);
                Tensor<?> embedding = runModel(inputs, model.pooledOutput)) {
            // NOTE(review): assumes the pooled output is [batch, hidden]; confirm against the export.
            long[] shape = embedding.shape();
            float[][] converted = new float[(int) shape[0]][(int) shape[1]];
            embedding.copyTo(converted);
            return converted;
        }
    }

    /**
     * Computes per-token embeddings for each sequence produced by {@code sequences}.
     *
     * @param sequences the texts to embed
     * @return one per-token embedding matrix per sequence, in iteration order
     */
    public float[][][] embedTokens(Iterable<String> sequences) {
        List<String> list = Lists.newArrayList(sequences);
        return this.embedTokens(list.toArray(new String[0]));
    }

    /**
     * Computes per-token embeddings for each remaining sequence of {@code sequences}.
     *
     * @param sequences the texts to embed; the iterator is exhausted
     * @return one per-token embedding matrix per sequence, in iteration order
     */
    public float[][][] embedTokens(Iterator<String> sequences) {
        List<String> list = Lists.newArrayList(sequences);
        return this.embedTokens(list.toArray(new String[0]));
    }

    /**
     * Computes per-token embeddings for one sequence using the model's sequence-output node
     * ({@code sequenceOutput} in model.json).
     *
     * @param sequence the text to embed; tokens beyond {@code maxSequenceLength - 2} are dropped
     * @return one row per position of the padded, {@code maxSequenceLength}-long input
     *     (presumably; this depends on the exported graph)
     */
    public float[][] embedTokens(String sequence) {
        try (Inputs inputs = getInputs(sequence);
                Tensor<?> embedding = runModel(inputs, model.sequenceOutput)) {
            // NOTE(review): assumes the sequence output is [batch, seq, hidden]; confirm against the export.
            long[] shape = embedding.shape();
            float[][][] converted = new float[(int) shape[0]][(int) shape[1]][(int) shape[2]];
            embedding.copyTo(converted);
            return converted[0];
        }
    }

    /**
     * Computes per-token embeddings for a batch of sequences in a single session run.
     *
     * @param sequences the texts to embed; tokens beyond {@code maxSequenceLength - 2} are dropped
     * @return one per-token embedding matrix per sequence, in argument order; empty when no
     *     sequences are given
     */
    public float[][][] embedTokens(String... sequences) {
        if (sequences.length == 0) {
            return new float[0][][];
        }
        try (Inputs inputs = getInputs(sequences);
                Tensor<?> embedding = runModel(inputs, model.sequenceOutput)) {
            // NOTE(review): assumes the sequence output is [batch, seq, hidden]; confirm against the export.
            long[] shape = embedding.shape();
            float[][][] converted = new float[(int) shape[0]][(int) shape[1]][(int) shape[2]];
            embedding.copyTo(converted);
            return converted;
        }
    }

    /**
     * Runs one session step that feeds the three input tensors and fetches the node named
     * {@code outputName}.
     *
     * @return the single fetched tensor; the caller must close it
     */
    private Tensor<?> runModel(Inputs inputs, String outputName) {
        List<Tensor<?>> outputs = bundle.session().runner()
                .feed(model.inputIds, inputs.inputIds)
                .feed(model.inputMask, inputs.inputMask)
                .feed(model.segmentIds, inputs.segmentIds)
                .fetch(outputName)
                .run();
        return outputs.get(0);
    }

    /** Tokenizes and encodes a single sequence into a batch of one. */
    private Inputs getInputs(String sequence) {
        return encode(new String[][] {this.tokenizer.tokenize(sequence)});
    }

    /** Tokenizes and encodes a batch of sequences, one row per sequence. */
    private Inputs getInputs(String[] sequences) {
        return encode(this.tokenizer.tokenize(sequences));
    }

    /**
     * Packs each token list into one row of exactly {@code maxSequenceLength} entries.
     *
     * <p>For each row:
     * <ul>
     *   <li>ids: {@code [CLS] tokens... [SEP]} followed by zero padding;</li>
     *   <li>mask: 1 over the real positions (including {@code [CLS]} and {@code [SEP]}) and 0 over
     *       the padding;</li>
     *   <li>segment ids: all 0.</li>
     * </ul>
     * Token lists longer than {@code maxSequenceLength - 2} are truncated.
     */
    private Inputs encode(String[][] tokenized) {
        int length = this.model.maxSequenceLength;
        IntBuffer inputIds = IntBuffer.allocate(tokenized.length * length);
        IntBuffer inputMask = IntBuffer.allocate(tokenized.length * length);
        IntBuffer segmentIds = IntBuffer.allocate(tokenized.length * length);
        for (String[] tokens : tokenized) {
            int[] ids = this.tokenizer.convert(tokens);
            // Reserve two slots for [CLS] and [SEP]; the constructor guarantees length >= 2.
            int kept = Math.min(ids.length, length - 2);
            int used = kept + 2;
            inputIds.put(this.startTokenId).put(ids, 0, kept).put(this.separatorTokenId);
            for (int i = used; i < length; ++i) {
                inputIds.put(0);
            }
            for (int i = 0; i < length; ++i) {
                inputMask.put(i < used ? 1 : 0);
                segmentIds.put(0);
            }
        }
        inputIds.rewind();
        inputMask.rewind();
        segmentIds.rewind();
        return new Inputs(inputIds, inputMask, segmentIds, tokenized.length, length);
    }

    /**
     * Model metadata read from {@code assets/model.json}. Jackson binds these public fields by name.
     * The String fields name graph nodes that are fed or fetched in {@code runModel}.
     */
    private static class ModelDetails {
        public boolean doLowerCase;
        public String inputIds;
        public String inputMask;
        public String segmentIds;
        public String pooledOutput;
        public String sequenceOutput;
        public int maxSequenceLength;

        private ModelDetails() {
        }
    }

    /** The three Integer input tensors fed to the model, each shaped {@code [count, maxSequenceLength]}. */
    private static final class Inputs implements AutoCloseable {
        private final Tensor<Integer> inputIds;
        private final Tensor<Integer> inputMask;
        private final Tensor<Integer> segmentIds;

        Inputs(IntBuffer inputIds, IntBuffer inputMask, IntBuffer segmentIds, int count, int maxSequenceLength) {
            long[] shape = {count, maxSequenceLength};
            this.inputIds = Tensor.create(shape, inputIds);
            this.inputMask = Tensor.create(shape, inputMask);
            this.segmentIds = Tensor.create(shape, segmentIds);
        }

        @Override
        public void close() {
            this.inputIds.close();
            this.inputMask.close();
            this.segmentIds.close();
        }
    }
}

