/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn.recurrent;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.types.LayoutType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterType;
import ai.djl.nn.recurrent.RecurrentCell;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class LSTM
extends RecurrentCell {
    private static final LayoutType[] EXPECTED_LAYOUT = new LayoutType[]{LayoutType.TIME, LayoutType.BATCH, LayoutType.CHANNEL};
    private static final byte VERSION = 1;
    private boolean clipLstmState;
    private double lstmStateClipMin;
    private double lstmStateClipMax;
    private List<Parameter> parameters = Arrays.asList(new Parameter("i2iWeight", this, ParameterType.WEIGHT), new Parameter("i2iBias", this, ParameterType.BIAS), new Parameter("h2iWeight", this, ParameterType.WEIGHT), new Parameter("h2iBias", this, ParameterType.BIAS), new Parameter("i2fWeight", this, ParameterType.WEIGHT), new Parameter("i2fBias", this, ParameterType.BIAS), new Parameter("h2fWeight", this, ParameterType.WEIGHT), new Parameter("h2fBias", this, ParameterType.BIAS), new Parameter("i2gWeight", this, ParameterType.WEIGHT), new Parameter("i2gBias", this, ParameterType.BIAS), new Parameter("h2gWeight", this, ParameterType.WEIGHT), new Parameter("h2gBias", this, ParameterType.BIAS), new Parameter("i2oWeight", this, ParameterType.WEIGHT), new Parameter("i2oBias", this, ParameterType.BIAS), new Parameter("h2oWeight", this, ParameterType.WEIGHT), new Parameter("h2oBias", this, ParameterType.BIAS));
    private Parameter state = new Parameter("state", this, ParameterType.OTHER);
    private Parameter stateCell = new Parameter("state_cell", this, ParameterType.OTHER);

    LSTM(Builder builder) {
        super(builder);
        this.mode = "lstm";
        this.clipLstmState = builder.clipLstmState;
        this.lstmStateClipMin = builder.lstmStateClipMin;
        this.lstmStateClipMax = builder.lstmStateClipMax;
    }

    @Override
    public NDList forward(ParameterStore parameterStore, NDList inputs, PairList<String, Object> params) {
        inputs = this.opInputs(parameterStore, inputs);
        NDArrayEx ex = inputs.head().getNDArrayInternal();
        if (this.clipLstmState) {
            return ex.lstm(inputs, this.stateSize, this.dropRate, this.numStackedLayers, this.useSequenceLength, this.useBidirectional, this.stateOutputs, this.lstmStateClipMin, this.lstmStateClipMax, params);
        }
        return ex.rnn(inputs, this.mode, this.stateSize, this.dropRate, this.numStackedLayers, this.useSequenceLength, this.useBidirectional, this.stateOutputs, params);
    }

    @Override
    public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
        Shape inputShape = inputShapes[0];
        return new Shape[]{new Shape(inputShape.get(0), inputShape.get(1), this.stateSize)};
    }

    @Override
    public List<Parameter> getDirectParameters() {
        ArrayList<Parameter> directParameters = new ArrayList<Parameter>(this.parameters);
        directParameters.add(this.state);
        directParameters.add(this.stateCell);
        return directParameters;
    }

    @Override
    public void beforeInitialize(Shape[] inputs) {
        this.inputShapes = inputs;
        Shape inputShape = inputs[0];
        Block.validateLayout(EXPECTED_LAYOUT, inputShape.getLayout());
    }

    @Override
    public Shape getParameterShape(String name, Shape[] inputShapes) {
        Shape shape = inputShapes[0];
        long channelSize = shape.get(2);
        long batchSize = shape.get(1);
        switch (name) {
            case "i2iWeight": 
            case "i2fWeight": 
            case "i2gWeight": 
            case "i2oWeight": {
                return new Shape(this.stateSize, channelSize);
            }
            case "h2iWeight": 
            case "h2fWeight": 
            case "h2gWeight": 
            case "h2oWeight": {
                return new Shape(this.stateSize, this.stateSize);
            }
            case "h2iBias": 
            case "i2iBias": 
            case "h2fBias": 
            case "i2fBias": 
            case "h2gBias": 
            case "i2gBias": 
            case "h2oBias": 
            case "i2oBias": {
                return new Shape(this.stateSize);
            }
            case "state": 
            case "state_cell": {
                return new Shape(this.numStackedLayers, batchSize, this.stateSize);
            }
        }
        throw new IllegalArgumentException("Invalid parameter name");
    }

    private NDList opInputs(ParameterStore parameterStore, NDList inputs) {
        this.validateInputSize(inputs);
        NDArray head = inputs.head();
        Device device = head.getDevice();
        NDList result = new NDList(head);
        try (NDList parameterList = new NDList();){
            for (Parameter parameter : this.parameters) {
                NDArray array = parameterStore.getValue(parameter, device);
                parameterList.add(array.flatten());
            }
            NDArray array = NDArrays.concat(parameterList);
            result.add(array);
        }
        result.add(parameterStore.getValue(this.state, device));
        result.add(parameterStore.getValue(this.stateCell, device));
        if (this.useSequenceLength) {
            result.add(inputs.get(1));
        }
        return result;
    }

    @Override
    public void saveParameters(DataOutputStream os) throws IOException {
        os.writeByte(1);
        for (Parameter parameter : this.parameters) {
            parameter.save(os);
        }
        this.state.save(os);
        this.stateCell.save(os);
    }

    @Override
    public void loadParameters(NDManager manager, DataInputStream is) throws IOException, MalformedModelException {
        byte version = is.readByte();
        if (version != 1) {
            throw new MalformedModelException("Unsupported encoding version: " + version);
        }
        for (Parameter parameter : this.parameters) {
            parameter.load(manager, is);
        }
        this.state.load(manager, is);
        this.stateCell.load(manager, is);
    }

    public static final class Builder
    extends RecurrentCell.BaseBuilder<Builder> {
        @Override
        protected Builder self() {
            return this;
        }

        public LSTM build() {
            if (this.stateSize == -1L || this.numStackedLayers == -1) {
                throw new IllegalArgumentException("Must set stateSize and numStackedLayers");
            }
            return new LSTM(this);
        }
    }
}

