/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn.transformer;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.norm.BatchNorm;
import ai.djl.nn.norm.Dropout;
import ai.djl.nn.transformer.PointwiseFeedForwardBlock;
import ai.djl.nn.transformer.ScaledDotProductAttentionBlock;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.util.Collections;
import java.util.function.Function;

/**
 * A single transformer encoder layer: self-attention followed by a point-wise feed-forward
 * sub-block. Each sub-layer's output goes through dropout, a residual addition and a
 * normalization block (post-norm arrangement).
 *
 * <p>The first input is used as the residual stream (the embeddings). The complete input list is
 * handed to the self-attention block, so any additional inputs are consumed only there.
 *
 * <p>NOTE(review): both normalization blocks are {@link BatchNorm} over axis 2 rather than a layer
 * norm as in the reference transformer. They are kept as-is so the parameter layout of existing
 * models does not change. The axis choice presumably assumes inputs shaped
 * (batch, sequence, embeddingSize) — confirm against callers.
 */
public class TransformerEncoderBlock extends AbstractBlock {

    private static final byte VERSION = 1;

    private final ScaledDotProductAttentionBlock selfAttentionBlock;
    private final Dropout selfAttentionDropout;
    private final BatchNorm attentionNorm;
    private final PointwiseFeedForwardBlock pointWisefullyConnected;
    private final Dropout fullyConnectedDropout;
    private final BatchNorm outputNorm;

    /**
     * Creates an encoder layer.
     *
     * @param embeddingSize embedding size given to the self-attention block; also the output size
     *     of the feed-forward sub-block
     * @param headCount number of attention heads of the self-attention block
     * @param hiddenSize size of the single hidden layer of the feed-forward sub-block
     * @param dropoutProbability rate used for the attention-probability dropout and for the dropout
     *     applied to both sub-layer outputs
     * @param activationFunction activation handed to the feed-forward sub-block
     */
    public TransformerEncoderBlock(
            int embeddingSize,
            int headCount,
            int hiddenSize,
            float dropoutProbability,
            Function<NDList, NDList> activationFunction) {
        super(VERSION);
        this.selfAttentionBlock =
                addChildBlock(
                        "selfAttention",
                        ScaledDotProductAttentionBlock.builder()
                                .setEmbeddingSize(embeddingSize)
                                .setHeadCount(headCount)
                                .optAttentionProbsDropoutProb(dropoutProbability)
                                .build());
        // The dropout blocks are not registered as children (presumably because they hold no
        // parameters — confirm); they are only invoked from forwardInternal.
        this.selfAttentionDropout = Dropout.builder().optRate(dropoutProbability).build();
        this.attentionNorm = addChildBlock("attentionNorm", BatchNorm.builder().optAxis(2).build());
        this.pointWisefullyConnected =
                addChildBlock(
                        "outputBlock",
                        new PointwiseFeedForwardBlock(
                                Collections.singletonList(hiddenSize),
                                embeddingSize,
                                activationFunction));
        this.fullyConnectedDropout = Dropout.builder().optRate(dropoutProbability).build();
        this.outputNorm = addChildBlock("outputNorm", BatchNorm.builder().optAxis(2).build());
    }

    /**
     * Returns the output shape of this layer.
     *
     * <p>The layer produces a single array with the shape of the first input (the residual
     * stream). Additional inputs are consumed by self-attention and have no matching output.
     *
     * @param inputShapes shapes of the inputs; the first one is the embeddings
     * @return a single-element array holding the shape of the first input (or {@code inputShapes}
     *     itself when it has at most one element)
     */
    @Override
    public Shape[] getOutputShapes(Shape[] inputShapes) {
        if (inputShapes.length <= 1) {
            return inputShapes;
        }
        return new Shape[] {inputShapes[0]};
    }

    /**
     * Initializes the child blocks.
     *
     * <p>Self-attention receives every input shape, because it is called with the complete input
     * list. The normalization and feed-forward blocks only ever see the residual stream, whose
     * shape is that of the first input.
     *
     * @throws IllegalArgumentException if no input shape is given
     */
    @Override
    public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {
        if (inputShapes.length == 0) {
            throw new IllegalArgumentException(
                    "TransformerEncoderBlock expects at least one input shape (the embeddings)");
        }
        selfAttentionBlock.initialize(manager, dataType, inputShapes);
        Shape embeddingShape = inputShapes[0];
        attentionNorm.initialize(manager, dataType, embeddingShape);
        pointWisefullyConnected.initialize(manager, dataType, embeddingShape);
        outputNorm.initialize(manager, dataType, embeddingShape);
    }

    /**
     * Runs the encoder layer.
     *
     * <p>Computes {@code h = attentionNorm(x + dropout(selfAttention(inputs)))} and then returns
     * {@code outputNorm(h + dropout(feedForward(h)))}, where {@code x} is the first input.
     */
    @Override
    protected NDList forwardInternal(
            ParameterStore ps, NDList inputs, boolean training, PairList<String, Object> params) {
        NDArray embedding = inputs.head();

        // attention sub-layer, with the layer input as residual
        NDList attentionOutput = selfAttentionBlock.forward(ps, inputs, training);
        NDList attentionOutputAfterDropout =
                selfAttentionDropout.forward(ps, attentionOutput, training);
        NDArray withResidual = attentionOutputAfterDropout.singletonOrThrow().add(embedding);
        NDList normalized = attentionNorm.forward(ps, new NDList(withResidual), training);

        // feed-forward sub-layer. Its residual is the sub-layer's own input (the normalized
        // attention output), not the raw embedding the previous wiring added. Models trained
        // with the old wiring will produce different outputs.
        NDList afterFullyConnected = pointWisefullyConnected.forward(ps, normalized, training);
        NDList afterFullyConnectedDropout =
                fullyConnectedDropout.forward(ps, afterFullyConnected, training);
        NDArray outputWithResidual =
                afterFullyConnectedDropout.singletonOrThrow().add(normalized.singletonOrThrow());
        return outputNorm.forward(ps, new NDList(outputWithResidual), training);
    }
}

