/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn.transformer;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Parameter;
import ai.djl.nn.core.Linear;
import ai.djl.nn.norm.BatchNorm;
import ai.djl.nn.transformer.BertBlock;
import ai.djl.nn.transformer.MissingOps;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.util.Arrays;
import java.util.function.Function;

public class BertMaskedLanguageModelBlock
extends AbstractBlock {
    private static final byte VERSION = 1;
    private Linear sequenceProjection;
    private BatchNorm sequenceNorm;
    private Parameter dictionaryBias;
    private Function<NDArray, NDArray> hiddenActivation;

    public BertMaskedLanguageModelBlock(BertBlock bertBlock, Function<NDArray, NDArray> hiddenActivation) {
        super((byte)1);
        this.sequenceProjection = this.addChildBlock("sequenceProjection", Linear.builder().setUnits(bertBlock.getEmbeddingSize()).optBias(true).build());
        this.sequenceNorm = this.addChildBlock("sequenceNorm", BatchNorm.builder().optAxis(1).build());
        this.dictionaryBias = this.addParameter(Parameter.builder().setName("dictionaryBias").setType(Parameter.Type.BIAS).optShape(new Shape(bertBlock.getTokenDictionarySize())).build());
        this.hiddenActivation = hiddenActivation;
    }

    public static NDArray gatherFromIndices(NDArray sequences, NDArray indices) {
        int batchSize = (int)sequences.getShape().get(0);
        int sequenceLength = (int)sequences.getShape().get(1);
        int width = (int)sequences.getShape().get(2);
        int indicesPerSequence = (int)indices.getShape().get(1);
        NDArray sequenceOffsets = indices.getManager().newSubManager(indices.getDevice()).arange(0, batchSize).mul(sequenceLength).reshape(batchSize, 1L);
        NDArray absoluteIndices = indices.add(sequenceOffsets).reshape(1L, (long)batchSize * (long)indicesPerSequence);
        NDArray flattenedSequences = sequences.reshape((long)batchSize * (long)sequenceLength, width);
        return MissingOps.gatherNd(flattenedSequences, absoluteIndices);
    }

    @Override
    public void initializeChildBlocks(NDManager manager, DataType dataType, Shape ... inputShapes) {
        this.inputNames = Arrays.asList("sequence", "maskedIndices", "embeddingTable");
        int width = (int)inputShapes[0].get(2);
        this.sequenceProjection.initialize(manager, dataType, new Shape(-1L, width));
        this.sequenceNorm.initialize(manager, dataType, new Shape(-1L, width));
    }

    @Override
    protected NDList forwardInternal(ParameterStore ps, NDList inputs, boolean training, PairList<String, Object> params) {
        NDArray sequenceOutput = (NDArray)inputs.get(0);
        NDArray maskedIndices = (NDArray)inputs.get(1);
        NDArray embeddingTable = (NDArray)inputs.get(2);
        try (NDManager scope = NDManager.subManagerOf(sequenceOutput);){
            scope.tempAttachAll(sequenceOutput, maskedIndices);
            NDArray gatheredTokens = BertMaskedLanguageModelBlock.gatherFromIndices(sequenceOutput, maskedIndices);
            NDArray projectedTokens = this.hiddenActivation.apply(this.sequenceProjection.forward(ps, new NDList(gatheredTokens), training).head());
            NDArray normalizedTokens = this.sequenceNorm.forward(ps, new NDList(projectedTokens), training).head();
            NDArray embeddingTransposed = embeddingTable.transpose();
            embeddingTransposed.attach(gatheredTokens.getManager());
            NDArray logits = normalizedTokens.dot(embeddingTransposed);
            NDArray logitsWithBias = logits.add(ps.getValue(this.dictionaryBias, logits.getDevice(), training));
            NDArray logProbs = logitsWithBias.logSoftmax(1);
            NDList nDList = scope.ret(new NDList(logProbs));
            return nDList;
        }
    }

    @Override
    public Shape[] getOutputShapes(Shape[] inputShapes) {
        int batchSize = (int)inputShapes[0].get(0);
        int indexCount = (int)inputShapes[1].get(1);
        int dictionarySize = (int)inputShapes[2].get(0);
        return new Shape[]{new Shape((long)batchSize * (long)indexCount, dictionarySize)};
    }
}

