/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn.recurrent;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.types.LayoutType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterType;
import ai.djl.nn.recurrent.RecurrentCell;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;

public class RNN
extends RecurrentCell {
    private static final LayoutType[] EXPECTED_LAYOUT = new LayoutType[]{LayoutType.TIME, LayoutType.BATCH, LayoutType.CHANNEL};
    private static final byte VERSION = 1;
    private Parameter i2hWeight;
    private Parameter h2hWeight;
    private Parameter i2hBias;
    private Parameter h2hBias;
    private Parameter state;

    RNN(Builder builder) {
        super(builder);
        this.mode = builder.activation == Activation.RELU ? "rnn_relu" : "rnn_tanh";
        this.i2hWeight = new Parameter("i2h_weight", this, ParameterType.WEIGHT);
        this.h2hWeight = new Parameter("h2h_weight", this, ParameterType.WEIGHT);
        this.i2hBias = new Parameter("i2h_bias", this, ParameterType.BIAS);
        this.h2hBias = new Parameter("h2h_bias", this, ParameterType.BIAS);
        this.state = new Parameter("state", this, ParameterType.OTHER);
    }

    @Override
    public NDList forward(ParameterStore parameterStore, NDList inputs, PairList<String, Object> params) {
        inputs = this.opInputs(parameterStore, inputs);
        NDArrayEx ex = inputs.head().getNDArrayInternal();
        return ex.rnn(inputs, this.mode, this.stateSize, this.dropRate, this.numStackedLayers, this.useSequenceLength, this.useBidirectional, this.stateOutputs, params);
    }

    @Override
    public Shape[] getOutputShapes(NDManager manager, Shape[] inputs) {
        Shape inputShape = inputs[0];
        return new Shape[]{new Shape(inputShape.get(0), inputShape.get(1), this.stateSize)};
    }

    @Override
    public List<Parameter> getDirectParameters() {
        return Arrays.asList(this.i2hWeight, this.i2hBias, this.h2hWeight, this.h2hBias, this.state);
    }

    @Override
    public void beforeInitialize(Shape[] inputs) {
        this.inputShapes = inputs;
        Shape inputShape = inputs[0];
        Block.validateLayout(EXPECTED_LAYOUT, inputShape.getLayout());
    }

    @Override
    public Shape getParameterShape(String name, Shape[] inputShapes) {
        Shape shape = inputShapes[0];
        long channelSize = shape.get(2);
        long batchSize = shape.get(1);
        switch (name) {
            case "i2h_weight": {
                return new Shape(this.stateSize, channelSize);
            }
            case "h2h_weight": {
                return new Shape(this.stateSize, this.stateSize);
            }
            case "i2h_bias": 
            case "h2h_bias": {
                return new Shape(this.stateSize);
            }
            case "state": {
                return new Shape(this.numStackedLayers, batchSize, this.stateSize);
            }
        }
        throw new IllegalArgumentException("Invalid parameter name");
    }

    @Override
    public void saveParameters(DataOutputStream os) throws IOException {
        os.writeByte(1);
        this.i2hWeight.save(os);
        this.h2hWeight.save(os);
        this.i2hBias.save(os);
        this.h2hBias.save(os);
        this.state.save(os);
    }

    @Override
    public void loadParameters(NDManager manager, DataInputStream is) throws IOException, MalformedModelException {
        byte version = is.readByte();
        if (version != 1) {
            throw new MalformedModelException("Unsupported encoding version: " + version);
        }
        this.i2hWeight.load(manager, is);
        this.h2hWeight.load(manager, is);
        this.i2hBias.load(manager, is);
        this.h2hBias.load(manager, is);
        this.state.load(manager, is);
    }

    private NDList opInputs(ParameterStore parameterStore, NDList inputs) {
        this.validateInputSize(inputs);
        NDArray head = inputs.head();
        Device device = head.getDevice();
        NDList result = new NDList(head);
        try (NDList parameterList = new NDList(4);){
            parameterList.add(parameterStore.getValue(this.i2hWeight, device).flatten());
            parameterList.add(parameterStore.getValue(this.i2hBias, device).flatten());
            parameterList.add(parameterStore.getValue(this.h2hWeight, device).flatten());
            parameterList.add(parameterStore.getValue(this.h2hBias, device).flatten());
            NDArray array = NDArrays.concat(parameterList);
            result.add(array);
        }
        result.add(parameterStore.getValue(this.state, device));
        if (this.useSequenceLength) {
            result.add(inputs.get(1));
        }
        return result;
    }

    public static enum Activation {
        RELU,
        TANH;

    }

    public static final class Builder
    extends RecurrentCell.BaseBuilder<Builder> {
        @Override
        protected Builder self() {
            return this;
        }

        public RNN build() {
            if (this.stateSize == -1L || this.numStackedLayers == -1) {
                throw new IllegalArgumentException("Must set stateSize and numStackedLayers");
            }
            return new RNN(this);
        }
    }
}

