/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.conf.layers;

import java.util.Map;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
import org.deeplearning4j.nn.conf.layers.samediff.SDLayerParams;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayerUtils;
import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Recurrent layer whose recurrence is driven by dot-product attention.
 *
 * <p>For every time step {@code i} the layer computes {@code x_i * W} (plus {@code b} when
 * {@code hasBias} is set). From the second step onwards, the previous step's output is used as
 * the attention query over the whole input sequence, which serves as both keys and values. The
 * attention result is multiplied by the recurrent weights {@code RW} and added before the
 * activation is applied.
 *
 * <p>When {@code projectInput} is true, multi-head attention with the learned projections
 * {@code Wq}, {@code Wk}, {@code Wv} and {@code Wo} is used. Otherwise plain dot-product attention
 * without projections is used, and the builder then requires {@code nHeads == 1} and
 * {@code nIn == nOut}.
 *
 * <p>The graph is unrolled for the number of time steps seen in {@link #defineLayer}, so only
 * fixed-length mini-batches are supported (see {@link #validateInput(INDArray)}).
 */
public class RecurrentAttentionLayer extends SameDiffLayer {
    /** Number of input features per time step. */
    private long nIn;
    /** Number of output features per time step. */
    private long nOut;
    /** Number of attention heads; must be 1 when {@code projectInput} is false. */
    private int nHeads;
    /** Size of each attention head; derived as {@code nOut / nHeads} when the builder leaves it at 0. */
    private long headSize;
    /** Whether multi-head attention with learned Q/K/V/O projection weights is used. */
    private boolean projectInput;
    /** Activation applied to each time step's output; taken from the global config when null. */
    private Activation activation;
    /** Whether the bias parameter {@code b} is created and added. */
    private boolean hasBias;
    private static final String WEIGHT_KEY_QUERY_PROJECTION = "Wq";
    private static final String WEIGHT_KEY_KEY_PROJECTION = "Wk";
    private static final String WEIGHT_KEY_VALUE_PROJECTION = "Wv";
    private static final String WEIGHT_KEY_OUT_PROJECTION = "Wo";
    private static final String WEIGHT_KEY = "W";
    private static final String BIAS_KEY = "b";
    private static final String RECURRENT_WEIGHT_KEY = "RW";
    /** Number of time steps the graph was unrolled for; recorded by {@link #defineLayer}. */
    private int timeSteps;

    /** No-arg constructor; kept private (presumably for serialization frameworks — verify). */
    private RecurrentAttentionLayer() {
    }

    /**
     * Creates the layer configuration from a builder.
     *
     * @param builder the populated builder; a {@code headSize} of 0 means "derive from nOut / nHeads"
     */
    protected RecurrentAttentionLayer(Builder builder) {
        super(builder);
        this.nIn = builder.nIn;
        this.nOut = builder.nOut;
        this.nHeads = builder.nHeads;
        // headSize == 0 means "not specified": split nOut evenly across the heads.
        // Builder.build() guarantees nHeads > 0, so the division is safe.
        this.headSize = builder.headSize == 0 ? this.nOut / (long) this.nHeads : (long) builder.headSize;
        this.projectInput = builder.projectInput;
        this.activation = builder.activation;
        this.hasBias = builder.hasBias;
    }

    @Override
    public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
        return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, this.getLayerName());
    }

    /**
     * Sets {@code nIn} from a recurrent input type.
     *
     * @param inputType the incoming input type; must be {@link InputType.Type#RNN}
     * @param override  when true, replace an already configured {@code nIn}
     * @throws IllegalStateException if {@code inputType} is null or not recurrent
     */
    @Override
    public void setNIn(InputType inputType, boolean override) {
        if (inputType == null || inputType.getType() != InputType.Type.RNN) {
            throw new IllegalStateException("Invalid input for Recurrent Attention layer (layer name = \"" + this.getLayerName() + "\"): expect RNN input type with size > 0. Got: " + inputType);
        }
        if (this.nIn <= 0L || override) {
            InputType.InputTypeRecurrent recurrent = (InputType.InputTypeRecurrent) inputType;
            this.nIn = recurrent.getSize();
        }
    }

    /**
     * Returns a recurrent output type with {@code nOut} features and the input's series length.
     *
     * @throws IllegalStateException if {@code inputType} is null or not recurrent
     */
    @Override
    public InputType getOutputType(int layerIndex, InputType inputType) {
        if (inputType == null || inputType.getType() != InputType.Type.RNN) {
            throw new IllegalStateException("Invalid input for Recurrent Attention layer (layer index = " + layerIndex + ", layer name = \"" + this.getLayerName() + "\"): expect RNN input type with size > 0. Got: " + inputType);
        }
        InputType.InputTypeRecurrent recurrent = (InputType.InputTypeRecurrent) inputType;
        return InputType.recurrent(this.nOut, recurrent.getTimeSeriesLength());
    }

    /**
     * Declares the layer's parameters.
     *
     * <ul>
     *   <li>{@code W}: [nIn, nOut] input weights</li>
     *   <li>{@code RW}: [nOut, nOut] recurrent weights applied to the attention output</li>
     *   <li>{@code b}: [nOut] bias (only when {@code hasBias})</li>
     *   <li>{@code Wq}: [nHeads, headSize, nOut], {@code Wk}/{@code Wv}: [nHeads, headSize, nIn],
     *       {@code Wo}: [nHeads * headSize, nOut] (only when {@code projectInput})</li>
     * </ul>
     */
    @Override
    public void defineParameters(SDLayerParams params) {
        params.clear();
        params.addWeightParam(WEIGHT_KEY, this.nIn, this.nOut);
        params.addWeightParam(RECURRENT_WEIGHT_KEY, this.nOut, this.nOut);
        if (this.hasBias) {
            params.addBiasParam(BIAS_KEY, this.nOut);
        }
        if (this.projectInput) {
            params.addWeightParam(WEIGHT_KEY_QUERY_PROJECTION, this.nHeads, this.headSize, this.nOut);
            params.addWeightParam(WEIGHT_KEY_KEY_PROJECTION, this.nHeads, this.headSize, this.nIn);
            params.addWeightParam(WEIGHT_KEY_VALUE_PROJECTION, this.nHeads, this.headSize, this.nIn);
            params.addWeightParam(WEIGHT_KEY_OUT_PROJECTION, (long) this.nHeads * this.headSize, this.nOut);
        }
    }

    /**
     * Initializes every parameter view in place.
     *
     * <p>Weights use the layer's {@code weightInit} scheme, with fan-in/fan-out derived from the
     * shapes declared in {@link #defineParameters}. The bias is set to zero.
     *
     * @param params parameter views keyed by the names declared in {@link #defineParameters}
     */
    @Override
    public void initializeParameters(Map<String, INDArray> params) {
        // Runs inside scopeOutOfWorkspaces(), as in the original implementation.
        try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
            for (Map.Entry<String, INDArray> entry : params.entrySet()) {
                INDArray view = entry.getValue();
                switch (entry.getKey()) {
                    case WEIGHT_KEY:
                        WeightInitUtil.initWeights((double) this.nIn, (double) this.nOut, view.shape(), this.weightInit, null, 'c', view);
                        break;
                    case RECURRENT_WEIGHT_KEY:
                        WeightInitUtil.initWeights((double) this.nOut, (double) this.nOut, view.shape(), this.weightInit, null, 'c', view);
                        break;
                    case BIAS_KEY:
                        view.assign((Number) 0);
                        break;
                    case WEIGHT_KEY_QUERY_PROJECTION:
                        // Wq is [nHeads, headSize, nOut]. The query is the previous step's output
                        // (nOut features), projected down to headSize per head.
                        WeightInitUtil.initWeights((double) this.nOut, (double) this.headSize, view.shape(), this.weightInit, null, 'c', view);
                        break;
                    case WEIGHT_KEY_KEY_PROJECTION:
                    case WEIGHT_KEY_VALUE_PROJECTION:
                        // Wk/Wv are [nHeads, headSize, nIn]: the input sequence (nIn features) is
                        // projected down to headSize per head.
                        WeightInitUtil.initWeights((double) this.nIn, (double) this.headSize, view.shape(), this.weightInit, null, 'c', view);
                        break;
                    case WEIGHT_KEY_OUT_PROJECTION:
                        // Wo is [nHeads * headSize, nOut]: the concatenated heads are mapped to nOut.
                        WeightInitUtil.initWeights((double) ((long) this.nHeads * this.headSize), (double) this.nOut, view.shape(), this.weightInit, null, 'c', view);
                        break;
                    default:
                        // Any other key (e.g. added by a subclass) keeps the previous generic initialization.
                        WeightInitUtil.initWeights((double) ((long) this.nHeads * this.headSize), (double) this.nOut, view.shape(), this.weightInit, null, 'c', view);
                        break;
                }
            }
        }
    }

    /** Falls back to the global configuration's activation when none was set on this layer. */
    @Override
    public void applyGlobalConfigToLayer(NeuralNetConfiguration.Builder globalConfig) {
        if (this.activation == null) {
            this.activation = SameDiffLayerUtils.fromIActivation(globalConfig.getActivationFn());
        }
    }

    /**
     * Checks that the input's time dimension (axis 2) matches the unrolled length.
     *
     * @throws IllegalArgumentException if the number of time steps differs from {@link #getTimeSteps()}
     */
    @Override
    public void validateInput(INDArray input) {
        long inputLength = input.size(2);
        Preconditions.checkArgument(inputLength == (long) this.timeSteps, "This layer only supports fixed length mini-batches. Expected %s time steps but got %s.", (long) this.timeSteps, inputLength);
    }

    /**
     * Builds the unrolled SameDiff graph for this layer.
     *
     * <p>{@code layerInput} is split along axis 2 into one slice per time step. Each slice is
     * multiplied by {@code W}, which is [nIn, nOut], so a slice is presumably [minibatch, nIn].
     * The number of slices is recorded in {@link #timeSteps}.
     *
     * @param sameDiff   the graph being built
     * @param layerInput the layer input; time is on axis 2
     * @param paramTable parameter variables keyed by the names from {@link #defineParameters}
     * @param mask       optional mask forwarded to the attention op
     * @return per-step outputs concatenated back along axis 2
     */
    @Override
    public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, Map<String, SDVariable> paramTable, SDVariable mask) {
        SDVariable W = paramTable.get(WEIGHT_KEY);
        SDVariable R = paramTable.get(RECURRENT_WEIGHT_KEY);
        SDVariable b = paramTable.get(BIAS_KEY);

        // Projection weights only exist when projectInput is set. They are loop-invariant, so
        // they are looked up once here.
        SDVariable Wq = null;
        SDVariable Wk = null;
        SDVariable Wv = null;
        SDVariable Wo = null;
        if (this.projectInput) {
            Wq = paramTable.get(WEIGHT_KEY_QUERY_PROJECTION);
            Wk = paramTable.get(WEIGHT_KEY_KEY_PROJECTION);
            Wv = paramTable.get(WEIGHT_KEY_VALUE_PROJECTION);
            Wo = paramTable.get(WEIGHT_KEY_OUT_PROJECTION);
        }

        SDVariable[] inputSlices = sameDiff.unstack(layerInput, 2);
        this.timeSteps = inputSlices.length;
        SDVariable[] outputSlices = new SDVariable[this.timeSteps];
        SDVariable prev = null;
        for (int i = 0; i < this.timeSteps; ++i) {
            SDVariable out = inputSlices[i].mmul(W);
            if (this.hasBias) {
                out = out.add(b);
            }
            // The first step has no previous output, so only steps after it attend.
            if (prev != null) {
                String attentionName = this.getLayerName() + "_attention_" + i;
                SDVariable attn = this.projectInput
                        ? sameDiff.nn.multiHeadDotProductAttention(attentionName, prev, layerInput, layerInput, Wq, Wk, Wv, Wo, mask, true)
                        : sameDiff.nn.dotProductAttention(attentionName, prev, layerInput, layerInput, mask, true);
                // The query has a singleton time axis (see expandDims below); drop it before RW.
                attn = sameDiff.squeeze(attn, 2);
                out = out.add(attn.mmul(R));
            }
            out = this.activation.asSameDiff(sameDiff, out);
            // Restore a time axis of length 1 so the slices can be concatenated along axis 2.
            out = sameDiff.expandDims(out, 2);
            outputSlices[i] = out;
            prev = out;
        }
        return sameDiff.concat(2, outputSlices);
    }

    public long getNIn() {
        return this.nIn;
    }

    public long getNOut() {
        return this.nOut;
    }

    public int getNHeads() {
        return this.nHeads;
    }

    public long getHeadSize() {
        return this.headSize;
    }

    public boolean isProjectInput() {
        return this.projectInput;
    }

    public Activation getActivation() {
        return this.activation;
    }

    public boolean isHasBias() {
        return this.hasBias;
    }

    public int getTimeSteps() {
        return this.timeSteps;
    }

    public void setNIn(long nIn) {
        this.nIn = nIn;
    }

    public void setNOut(long nOut) {
        this.nOut = nOut;
    }

    public void setNHeads(int nHeads) {
        this.nHeads = nHeads;
    }

    public void setHeadSize(long headSize) {
        this.headSize = headSize;
    }

    public void setProjectInput(boolean projectInput) {
        this.projectInput = projectInput;
    }

    public void setActivation(Activation activation) {
        this.activation = activation;
    }

    public void setHasBias(boolean hasBias) {
        this.hasBias = hasBias;
    }

    public void setTimeSteps(int timeSteps) {
        this.timeSteps = timeSteps;
    }

    @Override
    public String toString() {
        return "RecurrentAttentionLayer(nIn=" + this.getNIn() + ", nOut=" + this.getNOut() + ", nHeads=" + this.getNHeads() + ", headSize=" + this.getHeadSize() + ", projectInput=" + this.isProjectInput() + ", activation=" + this.getActivation() + ", hasBias=" + this.isHasBias() + ", timeSteps=" + this.getTimeSteps() + ")";
    }

    /**
     * Compares all configuration fields, including {@code timeSteps} (which is mutated by
     * {@link #defineLayer}), on top of the superclass equality.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof RecurrentAttentionLayer)) {
            return false;
        }
        RecurrentAttentionLayer other = (RecurrentAttentionLayer) o;
        if (!other.canEqual(this) || !super.equals(o)) {
            return false;
        }
        if (this.getNIn() != other.getNIn()
                || this.getNOut() != other.getNOut()
                || this.getNHeads() != other.getNHeads()
                || this.getHeadSize() != other.getHeadSize()
                || this.isProjectInput() != other.isProjectInput()) {
            return false;
        }
        Activation thisActivation = this.getActivation();
        Activation otherActivation = other.getActivation();
        if (thisActivation == null ? otherActivation != null : !thisActivation.equals(otherActivation)) {
            return false;
        }
        return this.isHasBias() == other.isHasBias() && this.getTimeSteps() == other.getTimeSteps();
    }

    @Override
    protected boolean canEqual(Object other) {
        return other instanceof RecurrentAttentionLayer;
    }

    /** Hash consistent with {@link #equals(Object)}; produces the same values as the previous implementation. */
    @Override
    public int hashCode() {
        final int PRIME = 59;
        int result = super.hashCode();
        long nInValue = this.getNIn();
        result = result * PRIME + (int) (nInValue >>> 32 ^ nInValue);
        long nOutValue = this.getNOut();
        result = result * PRIME + (int) (nOutValue >>> 32 ^ nOutValue);
        result = result * PRIME + this.getNHeads();
        long headSizeValue = this.getHeadSize();
        result = result * PRIME + (int) (headSizeValue >>> 32 ^ headSizeValue);
        result = result * PRIME + (this.isProjectInput() ? 79 : 97);
        Activation activationValue = this.getActivation();
        result = result * PRIME + (activationValue == null ? 43 : activationValue.hashCode());
        result = result * PRIME + (this.isHasBias() ? 79 : 97);
        result = result * PRIME + this.getTimeSteps();
        return result;
    }

    /** Builder for {@link RecurrentAttentionLayer}. */
    public static class Builder extends SameDiffLayer.Builder<Builder> {
        /** Number of input features. */
        private int nIn;
        /** Number of output features. */
        private int nOut;
        /** Number of attention heads; must be set to a positive value (it defaults to 0). */
        private int nHeads;
        /** Size of each head; 0 means "derive as nOut / nHeads". */
        private int headSize;
        private boolean projectInput = true;
        private boolean hasBias = true;
        private Activation activation = Activation.TANH;

        public Builder nIn(int nIn) {
            this.nIn = nIn;
            return this;
        }

        public Builder nOut(int nOut) {
            this.nOut = nOut;
            return this;
        }

        public Builder nHeads(int nHeads) {
            this.nHeads = nHeads;
            return this;
        }

        public Builder headSize(int headSize) {
            this.headSize = headSize;
            return this;
        }

        public Builder projectInput(boolean projectInput) {
            this.projectInput = projectInput;
            return this;
        }

        public Builder hasBias(boolean hasBias) {
            this.hasBias = hasBias;
            return this;
        }

        public Builder activation(Activation activation) {
            this.activation = activation;
            return this;
        }

        /**
         * Validates the configuration and builds the layer.
         *
         * @throws IllegalArgumentException if the configuration is inconsistent, including when
         *                                  {@code nHeads} was not set to a positive value
         */
        @Override
        public RecurrentAttentionLayer build() {
            Preconditions.checkArgument(this.projectInput || this.nHeads == 1, "projectInput must be true when nHeads != 1");
            Preconditions.checkArgument(this.projectInput || this.nIn == this.nOut, "nIn must be equal to nOut when projectInput is false");
            Preconditions.checkArgument(!this.projectInput || this.nOut != 0, "nOut must be specified when projectInput is true");
            // nHeads defaults to 0 and is used as a divisor below and in the layer constructor.
            // Reject it explicitly instead of failing with an ArithmeticException.
            Preconditions.checkArgument(this.nHeads > 0, "nHeads must be set to a positive value, got %s", this.nHeads);
            Preconditions.checkArgument(this.nOut % this.nHeads == 0 || this.headSize > 0, "nOut isn't divided by nHeads cleanly. Specify the headSize manually.");
            return new RecurrentAttentionLayer(this);
        }

        public int getNIn() {
            return this.nIn;
        }

        public int getNOut() {
            return this.nOut;
        }

        public int getNHeads() {
            return this.nHeads;
        }

        public int getHeadSize() {
            return this.headSize;
        }

        public boolean isProjectInput() {
            return this.projectInput;
        }

        public boolean isHasBias() {
            return this.hasBias;
        }

        public Activation getActivation() {
            return this.activation;
        }

        public void setNIn(int nIn) {
            this.nIn = nIn;
        }

        public void setNOut(int nOut) {
            this.nOut = nOut;
        }

        public void setNHeads(int nHeads) {
            this.nHeads = nHeads;
        }

        public void setHeadSize(int headSize) {
            this.headSize = headSize;
        }

        public void setProjectInput(boolean projectInput) {
            this.projectInput = projectInput;
        }

        public void setHasBias(boolean hasBias) {
            this.hasBias = hasBias;
        }

        public void setActivation(Activation activation) {
            this.activation = activation;
        }
    }
}

