/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.iterator;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import lombok.NonNull;
import org.deeplearning4j.iterator.LabeledSentenceProvider;
import org.deeplearning4j.iterator.bert.BertMaskedLMMasker;
import org.deeplearning4j.iterator.bert.BertSequenceMasker;
import org.deeplearning4j.text.tokenization.tokenizer.Tokenizer;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.primitives.Pair;

/**
 * Iterator that turns labelled sentences from a {@link LabeledSentenceProvider} into minibatches of
 * vocabulary indices (plus mask arrays) for BERT-style models.
 *
 * <p>Two tasks are implemented, see {@link Task}: sequence classification (one-hot class labels per
 * sentence) and unsupervised training (tokens selected by a {@link BertSequenceMasker} are replaced in
 * the features and the original token becomes the label at that position).
 *
 * <p>Instances are created via {@link #builder()}. The class keeps mutable iteration state in the
 * sentence provider and a lazily built vocabulary list; no synchronization is performed here.
 */
public class BertIterator implements MultiDataSetIterator {
    /** Task that determines how the label arrays are built. */
    protected Task task;
    /** Factory used to create a tokenizer for every sentence. */
    protected TokenizerFactory tokenizerFactory;
    /** Maximum sequence length used by {@link LengthHandling#FIXED_LENGTH} and {@link LengthHandling#CLIP_ONLY}. */
    protected int maxTokens = -1;
    /** Number of examples returned by {@link #next()}; also the padded batch size if {@link #padMinibatches} is set. */
    protected int minibatchSize = 32;
    /** If true, batches with fewer examples than {@link #minibatchSize} are padded with all-zero rows. */
    protected boolean padMinibatches = false;
    /** Optional pre-processor applied to every batch before it is returned; may be null. */
    protected MultiDataSetPreProcessor preProcessor;
    /** Source of (sentence, label) pairs. */
    protected LabeledSentenceProvider sentenceProvider = null;
    /** How the output sequence length of a batch is chosen. */
    protected LengthHandling lengthHandling;
    /** Which feature arrays are produced. */
    protected FeatureArrays featureArrays;
    /** Mapping from token to its integer vocabulary index. */
    protected Map<String, Integer> vocabMap;
    /** Masker used for the {@link Task#UNSUPERVISED} task. */
    protected BertSequenceMasker masker = null;
    /** Layout of the label array for the {@link Task#UNSUPERVISED} task. */
    protected UnsupervisedLabelFormat unsupervisedLabelFormat = null;
    /** Token passed to the masker as the mask token (for example "[MASK]"). */
    protected String maskToken;
    /** Token inserted at the start of every tokenized sentence; nothing is inserted when null. */
    protected String prependToken;
    /** Vocabulary tokens ordered by index; built lazily from {@link #vocabMap} on first unsupervised batch. */
    protected List<String> vocabKeysAsList;

    /**
     * Copies the configuration from the builder.
     *
     * @param b builder holding the configuration
     */
    protected BertIterator(Builder b) {
        this.task = b.task;
        this.tokenizerFactory = b.tokenizerFactory;
        this.maxTokens = b.maxTokens;
        this.minibatchSize = b.minibatchSize;
        this.padMinibatches = b.padMinibatches;
        this.preProcessor = b.preProcessor;
        this.sentenceProvider = b.sentenceProvider;
        this.lengthHandling = b.lengthHandling;
        this.featureArrays = b.featureArrays;
        this.vocabMap = b.vocabMap;
        this.masker = b.masker;
        this.unsupervisedLabelFormat = b.unsupervisedLabelFormat;
        this.maskToken = b.maskToken;
        this.prependToken = b.prependToken;
    }

    /**
     * @return whether the sentence provider has at least one more sentence
     */
    public boolean hasNext() {
        return this.sentenceProvider.hasNext();
    }

    /**
     * Returns the next batch of up to {@link #minibatchSize} examples.
     *
     * @return the next batch
     */
    public org.nd4j.linalg.dataset.api.MultiDataSet next() {
        return this.next(this.minibatchSize);
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    public void remove() {
        throw new UnsupportedOperationException("Not supported");
    }

    /**
     * Builds the next batch from up to {@code num} sentences.
     *
     * <p>Features are an INT array of vocabulary indices of shape [batch, outLength] with a matching
     * INT mask (1 where a token is present). With {@link FeatureArrays#INDICES_MASK_SEGMENTID} an
     * all-zero INT segment id array of the same shape is added as second feature (its mask is null).
     * Labels depend on {@link #task}; see {@link Task} and {@link UnsupervisedLabelFormat}.
     *
     * @param num maximum number of sentences to include; must be positive
     * @return the batch, after the optional pre-processor has been applied
     * @throws UnsupportedOperationException if no sentence provider is configured
     * @throws IllegalStateException if no sentence is left, a token or label is unknown, or the
     *     task/label format is not implemented
     * @throws IllegalArgumentException if {@code num} is not positive
     */
    public org.nd4j.linalg.dataset.api.MultiDataSet next(int num) {
        // Checked first: hasNext() dereferences the provider, which previously made this branch unreachable.
        if (this.sentenceProvider == null) {
            throw new UnsupportedOperationException("Labelled sentence provider is null and no other iterator types have yet been implemented");
        }
        Preconditions.checkState(this.hasNext(), "No next element available");
        if (num <= 0) {
            throw new IllegalArgumentException("Number of examples must be positive, got " + num);
        }

        List<Pair<String, String>> sentences = new ArrayList<Pair<String, String>>(num);
        while (this.sentenceProvider.hasNext() && sentences.size() < num) {
            sentences.add(this.sentenceProvider.nextSentence());
        }

        // Tokenize everything up front so the longest sequence of the batch is known.
        List<Pair<List<String>, String>> tokenizedSentences = new ArrayList<Pair<List<String>, String>>(sentences.size());
        int longestSeq = -1;
        for (Pair<String, String> sentence : sentences) {
            List<String> tokens = this.tokenizeSentence(sentence.getFirst());
            tokenizedSentences.add(new Pair<List<String>, String>(tokens, sentence.getSecond()));
            longestSeq = Math.max(longestSeq, tokens.size());
        }
        int outLength = this.resolveOutputLength(longestSeq);

        int n = tokenizedSentences.size();
        // num may exceed minibatchSize, so never allocate fewer rows than there are examples.
        int mbPadded = this.padMinibatches ? Math.max(this.minibatchSize, n) : n;
        int[][] outIdxs = new int[mbPadded][outLength];
        int[][] outMask = new int[mbPadded][outLength];
        for (int i = 0; i < n; ++i) {
            List<String> tokens = tokenizedSentences.get(i).getFirst();
            // Tokens beyond outLength are dropped; shorter sequences stay zero-padded and masked out.
            int len = Math.min(outLength, tokens.size());
            for (int j = 0; j < len; ++j) {
                String token = tokens.get(j);
                Preconditions.checkState(this.vocabMap.containsKey(token), "Unknown token encontered: token \"%s\" is not in vocabulary", (Object) token);
                outIdxs[i][j] = this.vocabMap.get(token).intValue();
                outMask[i][j] = 1;
            }
        }
        INDArray outIdxsArr = Nd4j.createFromArray(outIdxs);
        INDArray outMaskArr = Nd4j.createFromArray(outMask);

        INDArray[] f;
        INDArray[] fm;
        if (this.featureArrays == FeatureArrays.INDICES_MASK_SEGMENTID) {
            INDArray outSegmentIdArr = Nd4j.zeros(DataType.INT, new long[]{mbPadded, outLength});
            f = new INDArray[]{outIdxsArr, outSegmentIdArr};
            fm = new INDArray[]{outMaskArr, null};
        } else {
            f = new INDArray[]{outIdxsArr};
            fm = new INDArray[]{outMaskArr};
        }

        INDArray[] l = new INDArray[1];
        INDArray[] lm = null;
        if (this.task == Task.SEQ_CLASSIFICATION) {
            int numClasses = this.sentenceProvider.numLabelClasses();
            List<String> labels = this.sentenceProvider.allLabels();
            // Resolve and validate every label before allocating the label array.
            int[] classLabels = new int[n];
            for (int i = 0; i < n; ++i) {
                String lbl = tokenizedSentences.get(i).getSecond();
                classLabels[i] = labels.indexOf(lbl);
                Preconditions.checkState(classLabels[i] >= 0, "Provided label \"%s\" for sentence does not exist in set of classes/categories", (Object) lbl);
            }
            l[0] = Nd4j.create(DataType.FLOAT, new long[]{mbPadded, numClasses});
            for (int i = 0; i < n; ++i) {
                l[0].putScalar((long) i, (long) classLabels[i], 1.0);
            }
            if (this.padMinibatches && n != mbPadded) {
                // Label mask marks the real examples (first n rows) so padded rows are ignored.
                INDArray a = Nd4j.zeros(DataType.FLOAT, new long[]{mbPadded, 1L});
                lm = new INDArray[]{a};
                a.get(new INDArrayIndex[]{NDArrayIndex.interval(0, n), NDArrayIndex.all()}).assign((Number) 1);
            }
        } else if (this.task == Task.UNSUPERVISED) {
            List<String> vocabKeys = this.getVocabKeysAsList();
            int vocabSize = this.vocabMap.size();
            INDArray lMask = Nd4j.zeros(DataType.INT, new long[]{mbPadded, outLength});
            INDArray labelArr;
            if (this.unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK2_IDX) {
                labelArr = Nd4j.create(DataType.INT, new long[]{mbPadded, outLength});
            } else if (this.unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_NCL) {
                labelArr = Nd4j.create(DataType.FLOAT, new long[]{mbPadded, vocabSize, outLength});
            } else if (this.unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_LNC) {
                labelArr = Nd4j.create(DataType.FLOAT, new long[]{outLength, mbPadded, vocabSize});
            } else {
                throw new IllegalStateException("Unknown unsupervised label format: " + this.unsupervisedLabelFormat);
            }
            for (int i = 0; i < n; ++i) {
                List<String> tokens = tokenizedSentences.get(i).getFirst();
                Pair<List<String>, boolean[]> masked = this.masker.maskSequence(tokens, this.maskToken, vocabKeys);
                List<String> maskedTokens = masked.getFirst();
                boolean[] predictionTarget = masked.getSecond();
                int seqLen = Math.min(predictionTarget.length, outLength);
                for (int j = 0; j < seqLen; ++j) {
                    if (!predictionTarget[j]) {
                        continue;
                    }
                    // The label is the original token; it was already validated while building outIdxs.
                    int targetTokenIdx = this.vocabMap.get(tokens.get(j)).intValue();
                    if (this.unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK2_IDX) {
                        labelArr.putScalar((long) i, (long) j, (double) targetTokenIdx);
                    } else if (this.unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_NCL) {
                        // Array is [minibatch, vocabSize, outLength]: class index is dimension 1, position is dimension 2.
                        labelArr.putScalar((long) i, (long) targetTokenIdx, (long) j, 1.0);
                    } else if (this.unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_LNC) {
                        labelArr.putScalar((long) j, (long) i, (long) targetTokenIdx, 1.0);
                    }
                    lMask.putScalar((long) i, (long) j, 1.0);
                    // Overwrite the feature index with whatever token the masker put at this position.
                    String newToken = maskedTokens.get(j);
                    Preconditions.checkState(this.vocabMap.containsKey(newToken), "Token \"%s\" returned by the sequence masker is not in vocabulary", (Object) newToken);
                    int newTokenIdx = this.vocabMap.get(newToken).intValue();
                    outIdxsArr.putScalar((long) i, (long) j, (double) newTokenIdx);
                }
            }
            l[0] = labelArr;
            lm = new INDArray[]{lMask};
        } else {
            throw new IllegalStateException("Task not yet implemented: " + this.task);
        }

        MultiDataSet mds = new MultiDataSet(f, l, fm, lm);
        if (this.preProcessor != null) {
            this.preProcessor.preProcess((org.nd4j.linalg.dataset.api.MultiDataSet) mds);
        }
        return mds;
    }

    /** Chooses the sequence length of the output arrays according to {@link #lengthHandling}. */
    private int resolveOutputLength(int longestSeq) {
        switch (this.lengthHandling) {
            case FIXED_LENGTH: {
                return this.maxTokens;
            }
            case ANY_LENGTH: {
                return longestSeq;
            }
            case CLIP_ONLY: {
                return Math.min(this.maxTokens, longestSeq);
            }
            default: {
                throw new RuntimeException("Not implemented length handling mode: " + this.lengthHandling);
            }
        }
    }

    /**
     * Returns the vocabulary tokens ordered by index, building and caching the list on first use.
     * Each token is stored at the position given by its index in {@link #vocabMap}.
     */
    private List<String> getVocabKeysAsList() {
        if (this.vocabKeysAsList == null) {
            String[] arr = new String[this.vocabMap.size()];
            for (Map.Entry<String, Integer> e : this.vocabMap.entrySet()) {
                arr[e.getValue().intValue()] = e.getKey();
            }
            this.vocabKeysAsList = Arrays.asList(arr);
        }
        return this.vocabKeysAsList;
    }

    /** Tokenizes one sentence, putting {@link #prependToken} first when it is set. */
    private List<String> tokenizeSentence(String sentence) {
        Tokenizer t = this.tokenizerFactory.create(sentence);
        List<String> tokens = new ArrayList<String>();
        if (this.prependToken != null) {
            tokens.add(this.prependToken);
        }
        while (t.hasMoreTokens()) {
            tokens.add(t.nextToken());
        }
        return tokens;
    }

    /**
     * @return always true
     */
    public boolean resetSupported() {
        return true;
    }

    /**
     * @return always true
     */
    public boolean asyncSupported() {
        return true;
    }

    /** Resets the sentence provider, if one is configured. */
    public void reset() {
        if (this.sentenceProvider != null) {
            this.sentenceProvider.reset();
        }
    }

    /**
     * @return a new builder with default settings
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * @return the pre-processor applied to each batch, or null
     */
    public MultiDataSetPreProcessor getPreProcessor() {
        return this.preProcessor;
    }

    /**
     * @param preProcessor pre-processor to apply to each batch; null disables pre-processing
     */
    public void setPreProcessor(MultiDataSetPreProcessor preProcessor) {
        this.preProcessor = preProcessor;
    }

    /**
     * Fluent builder for {@link BertIterator}. Task, tokenizer factory and vocabulary map are
     * mandatory; the unsupervised task additionally needs a masker (a {@link BertMaskedLMMasker} by
     * default), a label format and a mask token.
     */
    public static class Builder {
        protected Task task;
        protected TokenizerFactory tokenizerFactory;
        protected LengthHandling lengthHandling = LengthHandling.FIXED_LENGTH;
        protected int maxTokens = -1;
        protected int minibatchSize = 32;
        protected boolean padMinibatches = false;
        protected MultiDataSetPreProcessor preProcessor;
        protected LabeledSentenceProvider sentenceProvider = null;
        protected FeatureArrays featureArrays = FeatureArrays.INDICES_MASK_SEGMENTID;
        protected Map<String, Integer> vocabMap;
        protected BertSequenceMasker masker = new BertMaskedLMMasker();
        protected UnsupervisedLabelFormat unsupervisedLabelFormat;
        protected String maskToken;
        protected String prependToken;

        /** Sets the task to perform. */
        public Builder task(Task task) {
            this.task = task;
            return this;
        }

        /** Sets the factory used to tokenize each sentence. */
        public Builder tokenizer(TokenizerFactory tokenizerFactory) {
            this.tokenizerFactory = tokenizerFactory;
            return this;
        }

        /**
         * Sets how the output sequence length is chosen.
         *
         * @param lengthHandling length handling mode; must not be null
         * @param maxLength maximum length used by FIXED_LENGTH and CLIP_ONLY; not read for ANY_LENGTH
         * @throws NullPointerException if {@code lengthHandling} is null
         */
        public Builder lengthHandling(@NonNull LengthHandling lengthHandling, int maxLength) {
            if (lengthHandling == null) {
                throw new NullPointerException("lengthHandling is marked @NonNull but is null");
            }
            this.lengthHandling = lengthHandling;
            this.maxTokens = maxLength;
            return this;
        }

        /** Sets the number of examples per batch returned by {@link BertIterator#next()}. */
        public Builder minibatchSize(int minibatchSize) {
            this.minibatchSize = minibatchSize;
            return this;
        }

        /** Sets whether batches smaller than the minibatch size are padded with all-zero rows. */
        public Builder padMinibatches(boolean padMinibatches) {
            this.padMinibatches = padMinibatches;
            return this;
        }

        /** Sets an optional pre-processor applied to every batch. */
        public Builder preProcessor(MultiDataSetPreProcessor preProcessor) {
            this.preProcessor = preProcessor;
            return this;
        }

        /** Sets the provider of (sentence, label) pairs. */
        public Builder sentenceProvider(LabeledSentenceProvider sentenceProvider) {
            this.sentenceProvider = sentenceProvider;
            return this;
        }

        /** Sets which feature arrays are produced. */
        public Builder featureArrays(FeatureArrays featureArrays) {
            this.featureArrays = featureArrays;
            return this;
        }

        /** Sets the token to vocabulary index mapping. */
        public Builder vocabMap(Map<String, Integer> vocabMap) {
            this.vocabMap = vocabMap;
            return this;
        }

        /** Sets the masker used for the unsupervised task. */
        public Builder masker(BertSequenceMasker masker) {
            this.masker = masker;
            return this;
        }

        /** Sets the label array layout for the unsupervised task. */
        public Builder unsupervisedLabelFormat(UnsupervisedLabelFormat labelFormat) {
            this.unsupervisedLabelFormat = labelFormat;
            return this;
        }

        /** Sets the mask token handed to the masker for the unsupervised task. */
        public Builder maskToken(String maskToken) {
            this.maskToken = maskToken;
            return this;
        }

        /** Sets a token that is inserted at the start of every tokenized sentence. */
        public Builder prependToken(String prependToken) {
            this.prependToken = prependToken;
            return this;
        }

        /**
         * Validates the configuration and creates the iterator.
         *
         * @return the configured iterator
         * @throws IllegalStateException if a mandatory setting is missing, or if FIXED_LENGTH or
         *     CLIP_ONLY length handling is used without a positive maximum length
         */
        public BertIterator build() {
            Preconditions.checkState(this.task != null, "No task has been set. Use .task(BertIterator.Task.X) to set the task to be performed");
            Preconditions.checkState(this.tokenizerFactory != null, "No tokenizer factory has been set. A tokenizer factory (such as BertWordPieceTokenizerFactory) is required");
            Preconditions.checkState(this.vocabMap != null, "Cannot create iterator: No vocabMap has been set. Use Builder.vocabMap(Map<String,Integer>) to set");
            // A non-positive length would otherwise only surface as a NegativeArraySizeException in next().
            Preconditions.checkState(this.lengthHandling == LengthHandling.ANY_LENGTH || this.maxTokens > 0, "A positive maximum length must be set via lengthHandling(LengthHandling, int) when using FIXED_LENGTH or CLIP_ONLY length handling");
            Preconditions.checkState(this.task != Task.UNSUPERVISED || this.masker != null, "If task is UNSUPERVISED training, a masker must be set via masker(BertSequenceMasker) method");
            Preconditions.checkState(this.task != Task.UNSUPERVISED || this.unsupervisedLabelFormat != null, "If task is UNSUPERVISED training, a label format must be set via unsupervisedLabelFormat(UnsupervisedLabelFormat) method");
            Preconditions.checkState(this.task != Task.UNSUPERVISED || this.maskToken != null, "If task is UNSUPERVISED training, the mask token in the vocab (such as \"[MASK]\" must be specified");
            return new BertIterator(this);
        }
    }

    /** Layout of the label array produced for {@link Task#UNSUPERVISED}. */
    public static enum UnsupervisedLabelFormat {
        /** INT array of shape [minibatch, length] holding the vocabulary index of the original token. */
        RANK2_IDX,
        /** FLOAT one-hot array of shape [minibatch, vocabSize, length]. */
        RANK3_NCL,
        /** FLOAT one-hot array of shape [length, minibatch, vocabSize]. */
        RANK3_LNC;

    }

    /** Selects which feature arrays a batch contains. */
    public static enum FeatureArrays {
        /** One feature array (token indices) with its mask. */
        INDICES_MASK,
        /** Token indices with mask, plus an all-zero segment id array without mask. */
        INDICES_MASK_SEGMENTID;

    }

    /** Selects how the sequence length of a batch is determined. */
    public static enum LengthHandling {
        /** Always use the configured maximum length; longer sequences are truncated, shorter ones padded. */
        FIXED_LENGTH,
        /** Use the length of the longest sequence in the batch. */
        ANY_LENGTH,
        /** Use the longest sequence in the batch, but never more than the configured maximum length. */
        CLIP_ONLY;

    }

    /** Task the iterator produces labels for. */
    public static enum Task {
        /** Labels are the original tokens at the positions selected by the sequence masker. */
        UNSUPERVISED,
        /** Labels are one-hot encoded sentence classes taken from the sentence provider. */
        SEQ_CLASSIFICATION;

    }
}

