/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.modality.nlp;

import ai.djl.MalformedModelException;
import ai.djl.modality.nlp.Decoder;
import ai.djl.modality.nlp.Encoder;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.BlockList;
import ai.djl.nn.Parameter;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

/**
 * A block that chains an {@link Encoder} and a {@link Decoder}.
 *
 * <p>On every forward pass the encoder is run first, the decoder state is initialised from
 * {@code encoder.getStates(encoderOutputs)}, and the decoder output is returned. The block owns
 * no parameters itself; saving, loading and initialisation are delegated to the two children,
 * always encoder first and decoder second.
 *
 * <p>Wherever inputs or input shapes are passed as a single sequence, index 0 belongs to the
 * encoder and index 1 to the decoder.
 */
public class EncoderDecoder extends AbstractBlock {
    /** Number of leading inputs (or input shapes) this block consumes: encoder, then decoder. */
    private static final int REQUIRED_INPUTS = 2;

    protected Encoder encoder;
    protected Decoder decoder;

    /**
     * Creates the block from its two children.
     *
     * @param encoder the block run first on the encoder input
     * @param decoder the block whose state is seeded from the encoder and whose output is returned
     */
    public EncoderDecoder(Encoder encoder, Decoder decoder) {
        this.encoder = encoder;
        this.decoder = decoder;
    }

    /**
     * Describes the inputs as the pair {@code "encoderInput"}, {@code "decoderInput"} together
     * with the input shapes recorded on this block.
     *
     * <p>Note that this call also overwrites {@code inputNames} with that fixed pair.
     *
     * @return the input names paired with the recorded input shapes
     * @throws IllegalStateException if the block reports that it is not initialised
     */
    @Override
    public PairList<String, Shape> describeInput() {
        if (!this.isInitialized()) {
            throw new IllegalStateException("Parameter of this block are not initialised");
        }
        this.inputNames = Arrays.asList("encoderInput", "decoderInput");
        return new PairList<>(this.inputNames, Arrays.asList(this.inputShapes));
    }

    /**
     * Runs the encoder, seeds the decoder state from the encoder output, then runs the decoder.
     *
     * <p>When {@code training} is {@code false} each input list is wrapped in a fresh
     * {@link NDList} before being handed to the child block; when training the caller's lists
     * are passed through as they are.
     *
     * @param parameterStore the parameter store forwarded to both children
     * @param encoderInputs the inputs for the encoder
     * @param decoderInputs the inputs for the decoder
     * @param training the training flag forwarded to both children
     * @param params optional extra arguments forwarded to both children
     * @return the output of the decoder
     */
    public NDList forward(ParameterStore parameterStore, NDList encoderInputs, NDList decoderInputs, boolean training, PairList<String, Object> params) {
        NDList encoderIn = training ? encoderInputs : new NDList(encoderInputs);
        NDList encoderOutputs = this.encoder.forward(parameterStore, encoderIn, training, params);
        this.decoder.initState(this.encoder.getStates(encoderOutputs));
        // The decoder input is wrapped only after the encoder has run, matching the original order.
        NDList decoderIn = training ? decoderInputs : new NDList(decoderInputs);
        return this.decoder.forward(parameterStore, decoderIn, training, params);
    }

    /**
     * Splits {@code inputs} into encoder input (element 0) and decoder input (element 1) and
     * delegates to the two-list {@code forward} overload. Elements beyond the second are ignored.
     *
     * @param parameterStore the parameter store forwarded to both children
     * @param inputs a list holding at least the encoder input and the decoder input, in that order
     * @param training the training flag forwarded to both children
     * @param params optional extra arguments forwarded to both children
     * @return the output of the decoder
     * @throws IllegalArgumentException if {@code inputs} holds fewer than two arrays
     */
    @Override
    public NDList forward(ParameterStore parameterStore, NDList inputs, boolean training, PairList<String, Object> params) {
        if (inputs.size() < REQUIRED_INPUTS) {
            throw new IllegalArgumentException(
                    "EncoderDecoder requires an encoder input and a decoder input, but got "
                            + inputs.size() + " input(s)");
        }
        NDArray encoderInput = inputs.get(0);
        NDArray decoderInput = inputs.get(1);
        return this.forward(parameterStore, new NDList(encoderInput), new NDList(decoderInput), training, params);
    }

    /**
     * Same as the four-argument overload with {@code params} set to {@code null}.
     *
     * @param parameterStore the parameter store forwarded to both children
     * @param inputs a list holding at least the encoder input and the decoder input, in that order
     * @param training the training flag forwarded to both children
     * @return the output of the decoder
     */
    @Override
    public NDList forward(ParameterStore parameterStore, NDList inputs, boolean training) {
        return this.forward(parameterStore, inputs, training, null);
    }

    /**
     * Initialises the encoder with {@code inputShapes[0]} and the decoder with
     * {@code inputShapes[1]}.
     *
     * <p>The shapes are validated before {@code beforeInitialize} runs, so a malformed call is
     * rejected before any initialisation step is taken.
     *
     * @param manager the manager forwarded to both children
     * @param dataType the data type forwarded to both children
     * @param inputShapes the encoder input shape followed by the decoder input shape
     * @return whatever the decoder's {@code initialize} returns
     * @throws IllegalArgumentException if fewer than two shapes are supplied
     */
    @Override
    public Shape[] initialize(NDManager manager, DataType dataType, Shape... inputShapes) {
        requireEncoderAndDecoderShapes(inputShapes);
        this.beforeInitialize(inputShapes);
        this.encoder.initialize(manager, dataType, inputShapes[0]);
        return this.decoder.initialize(manager, dataType, inputShapes[1]);
    }

    /**
     * Lists the two children under the names {@code "Encoder"} and {@code "Decoder"}.
     *
     * @return the encoder and the decoder, in that order
     */
    @Override
    public BlockList getChildren() {
        return new BlockList(Arrays.asList("Encoder", "Decoder"), Arrays.asList(this.encoder, this.decoder));
    }

    /**
     * This block holds no parameters of its own.
     *
     * @return an empty list
     */
    @Override
    public List<Parameter> getDirectParameters() {
        return Collections.emptyList();
    }

    /**
     * Always fails, because this block has no direct parameters to describe.
     *
     * @param name ignored
     * @param inputShapes ignored
     * @return never returns normally
     * @throws IllegalArgumentException always
     */
    @Override
    public Shape getParameterShape(String name, Shape[] inputShapes) {
        throw new IllegalArgumentException("EncoderDecoder blocks have no direct parameters");
    }

    /**
     * Reports the decoder's output shapes for the decoder input shape ({@code inputShapes[1]}).
     *
     * @param manager the manager forwarded to the decoder
     * @param inputShapes the encoder input shape followed by the decoder input shape
     * @return the output shapes reported by the decoder
     * @throws IllegalArgumentException if fewer than two shapes are supplied
     */
    @Override
    public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
        requireEncoderAndDecoderShapes(inputShapes);
        return this.decoder.getOutputShapes(manager, new Shape[]{inputShapes[1]});
    }

    /**
     * Writes the encoder's parameters followed by the decoder's parameters.
     *
     * @param os the stream both children write to
     * @throws IOException if either child fails to write
     */
    @Override
    public void saveParameters(DataOutputStream os) throws IOException {
        this.encoder.saveParameters(os);
        this.decoder.saveParameters(os);
    }

    /**
     * Reads the encoder's parameters followed by the decoder's parameters, mirroring the order
     * used by {@link #saveParameters(DataOutputStream)}.
     *
     * @param manager the manager forwarded to both children
     * @param is the stream both children read from
     * @throws IOException if either child fails to read
     * @throws MalformedModelException if either child rejects the stream contents
     */
    @Override
    public void loadParameters(NDManager manager, DataInputStream is) throws IOException, MalformedModelException {
        this.encoder.loadParameters(manager, is);
        this.decoder.loadParameters(manager, is);
    }

    /** Rejects shape arrays that cannot supply both the encoder shape and the decoder shape. */
    private static void requireEncoderAndDecoderShapes(Shape[] shapes) {
        if (shapes == null || shapes.length < REQUIRED_INPUTS) {
            int actual = shapes == null ? 0 : shapes.length;
            throw new IllegalArgumentException(
                    "EncoderDecoder requires an encoder input shape and a decoder input shape, but got "
                            + actual + " shape(s)");
        }
    }
}

