/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.iterator.bert;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import org.deeplearning4j.iterator.bert.BertSequenceMasker;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.primitives.Pair;

/**
 * A {@link BertSequenceMasker} implementing the masked-language-model token selection scheme.
 * <p>
 * Each token of the input sequence is independently selected with probability {@code maskProb}.
 * A selected token is then:
 * <ul>
 *     <li>replaced by the mask token with probability {@code maskTokenProb},</li>
 *     <li>replaced by a word drawn uniformly at random from the vocabulary with probability
 *     {@code randomTokenProb},</li>
 *     <li>left unchanged otherwise (probability {@code 1 - maskTokenProb - randomTokenProb}).</li>
 * </ul>
 * Every selected token is flagged in the returned {@code boolean[]}, regardless of which of the
 * three outcomes was applied to it; unselected tokens are copied through and left unflagged.
 */
public class BertMaskedLMMasker implements BertSequenceMasker {
    /** Default probability that any given token is selected. */
    public static final double DEFAULT_MASK_PROB = 0.15;
    /** Default probability that a selected token is replaced by the mask token. */
    public static final double DEFAULT_MASK_TOKEN_PROB = 0.8;
    /** Default probability that a selected token is replaced by a random vocabulary word. */
    public static final double DEFAULT_RANDOM_WORD_PROB = 0.1;

    /** Source of randomness for token selection and replacement. */
    protected final Random r;
    /** Probability that a token is selected; strictly between 0 and 1. */
    protected final double maskProb;
    /** Probability that a selected token becomes the mask token; in [0, 1]. */
    protected final double maskTokenProb;
    /** Probability that a selected token becomes a random vocabulary word; in [0, 1]. */
    protected final double randomTokenProb;

    /**
     * Creates a masker with an unseeded {@link Random} and the default probabilities
     * {@link #DEFAULT_MASK_PROB}, {@link #DEFAULT_MASK_TOKEN_PROB} and
     * {@link #DEFAULT_RANDOM_WORD_PROB}.
     */
    public BertMaskedLMMasker() {
        this(new Random(), DEFAULT_MASK_PROB, DEFAULT_MASK_TOKEN_PROB, DEFAULT_RANDOM_WORD_PROB);
    }

    /**
     * Creates a masker with explicit randomness and probabilities.
     *
     * @param r               random number generator used for all sampling decisions
     * @param maskProb        probability that a token is selected; must be strictly between 0 and 1
     * @param maskTokenProb   probability that a selected token is replaced by the mask token; in [0, 1]
     * @param randomTokenProb probability that a selected token is replaced by a random vocabulary
     *                        word; in [0, 1]
     * @throws IllegalArgumentException if {@code maskProb} is not strictly between 0 and 1
     * @throws IllegalStateException    if {@code maskTokenProb} or {@code randomTokenProb} is outside
     *                                  [0, 1], or their sum exceeds 1
     */
    public BertMaskedLMMasker(Random r, double maskProb, double maskTokenProb, double randomTokenProb) {
        Preconditions.checkArgument(maskProb > 0.0 && maskProb < 1.0,
                "Probability must be between 0 and 1, got %s", maskProb);
        // The remaining checks intentionally keep throwing IllegalStateException (checkState), as
        // before, so that existing callers' exception handling is unaffected.
        Preconditions.checkState(maskTokenProb >= 0.0 && maskTokenProb <= 1.0,
                "Mask token probability must be between 0 and 1, got %s", maskTokenProb);
        Preconditions.checkState(randomTokenProb >= 0.0 && randomTokenProb <= 1.0,
                "Random token probability must be between 0 and 1, got %s", randomTokenProb);
        Preconditions.checkState(maskTokenProb + randomTokenProb <= 1.0,
                "Sum of maskTokenProb (%s) and randomTokenProb (%s) must be <= 1.0, got sum is %s",
                maskTokenProb, randomTokenProb, maskTokenProb + randomTokenProb);
        this.r = r;
        this.maskProb = maskProb;
        this.maskTokenProb = maskTokenProb;
        this.randomTokenProb = randomTokenProb;
    }

    /**
     * Applies the masking scheme to a token sequence.
     *
     * @param input      the original tokens; not modified
     * @param maskToken  token substituted for selected tokens in the mask-token case
     * @param vocabWords candidate words for the random-replacement case; must be non-empty if that
     *                   case is ever reached
     * @return a pair of (output tokens, same length as {@code input}; flags that are {@code true}
     *         at every selected position)
     * @throws IllegalArgumentException if the random-replacement case is reached and
     *                                  {@code vocabWords} is empty
     */
    @Override
    public Pair<List<String>, boolean[]> maskSequence(List<String> input, String maskToken, List<String> vocabWords) {
        List<String> out = new ArrayList<>(input.size());
        boolean[] masked = new boolean[input.size()];
        // Iterate rather than index so non-RandomAccess lists (e.g. LinkedList) stay O(n).
        int i = 0;
        for (String token : input) {
            if (r.nextDouble() < maskProb) {
                // A single uniform draw partitions [0, 1) into mask / random-word / keep regions.
                double d = r.nextDouble();
                if (d < maskTokenProb) {
                    out.add(maskToken);
                } else if (d < maskTokenProb + randomTokenProb) {
                    // Fail with a descriptive message instead of Random.nextInt(0)'s opaque one.
                    Preconditions.checkArgument(!vocabWords.isEmpty(),
                            "Vocabulary words list must not be empty when random-word replacement is used");
                    out.add(vocabWords.get(r.nextInt(vocabWords.size())));
                } else {
                    out.add(token);
                }
                // Flagged even when the token was kept unchanged: it is still a selected position.
                masked[i] = true;
            } else {
                out.add(token);
            }
            i++;
        }
        return new Pair<>(out, masked);
    }
}

