/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn;

import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Block;
import ai.djl.nn.BlockList;
import ai.djl.nn.LambdaBlock;
import ai.djl.nn.Parameter;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.function.Function;
import java.util.regex.Pattern;

/**
 * A {@link Block} that runs its child blocks one after another, feeding the output of each child
 * into the next.
 *
 * <p>This block has no direct parameters of its own; {@link #getDirectParameters()} is always
 * empty and parameter persistence is delegated to the children in order.
 */
public class SequentialBlock extends AbstractBlock {
    /** Encoding version written by {@link #saveParameters(DataOutputStream)}. */
    private static final byte VERSION = 2;

    /** Matches the start of every line; used to indent nested block descriptions. */
    private static final Pattern LINE_START = Pattern.compile("(?m)^");

    /** Child blocks in execution order; never contains {@code null}. */
    private final List<Block> blocks = new ArrayList<Block>();

    /**
     * Appends the given blocks, in order. {@code null} entries are skipped, consistent with
     * {@link #add(Block)}.
     *
     * @param blocks the blocks to append
     * @return this block, for chaining
     */
    public SequentialBlock addAll(Block... blocks) {
        for (Block block : blocks) {
            appendIfNonNull(block);
        }
        return this;
    }

    /**
     * Appends the given blocks in iteration order. {@code null} entries are skipped, consistent
     * with {@link #add(Block)}.
     *
     * @param blocks the blocks to append
     * @return this block, for chaining
     */
    public SequentialBlock addAll(Collection<Block> blocks) {
        for (Block block : blocks) {
            appendIfNonNull(block);
        }
        return this;
    }

    /**
     * Appends a single block. A {@code null} argument is ignored.
     *
     * @param block the block to append, may be {@code null}
     * @return this block, for chaining
     */
    public SequentialBlock add(Block block) {
        appendIfNonNull(block);
        return this;
    }

    /**
     * Appends a {@link LambdaBlock} wrapping the given function.
     *
     * @param f the function to wrap
     * @return this block, for chaining
     */
    public SequentialBlock add(Function<NDList, NDList> f) {
        this.blocks.add(new LambdaBlock(f));
        return this;
    }

    /**
     * Removes the last child block.
     *
     * @throws IndexOutOfBoundsException if this sequential block has no children
     */
    public void removeLastBlock() {
        this.blocks.remove(this.blocks.size() - 1);
    }

    /**
     * Removes the last child block and appends {@code block} in its place. If {@code block} is
     * {@code null} the last child is only removed.
     *
     * @param block the replacement block, may be {@code null}
     * @throws IndexOutOfBoundsException if this sequential block has no children
     */
    public void replaceLastBlock(Block block) {
        this.removeLastBlock();
        appendIfNonNull(block);
    }

    /** Adds {@code block} to the end of the child list unless it is {@code null}. */
    private void appendIfNonNull(Block block) {
        if (block != null) {
            this.blocks.add(block);
        }
    }

    /**
     * Runs every child in order, passing each child's output as the next child's input.
     *
     * <p>{@code params} is not forwarded to the children; each child is invoked through the
     * three-argument {@code forward} overload.
     *
     * @return the output of the last child, or {@code inputs} itself when there are no children
     */
    @Override
    public NDList forward(ParameterStore parameterStore, NDList inputs, boolean training, PairList<String, Object> params) {
        NDList current = inputs;
        for (Block block : this.blocks) {
            current = block.forward(parameterStore, current, training);
        }
        return current;
    }

    /**
     * Initializes each child in order, feeding the shapes returned by one child's
     * {@code initialize} into the next.
     *
     * @return the result of {@link #getOutputShapes(NDManager, Shape[])} for {@code inputShapes}
     * @throws IllegalArgumentException if this sequential block is empty
     */
    @Override
    public Shape[] initialize(NDManager manager, DataType dataType, Shape... inputShapes) {
        this.beforeInitialize(inputShapes);
        Shape[] shapes = inputShapes;
        for (Block child : this.getChildren().values()) {
            shapes = child.initialize(manager, dataType, shapes);
        }
        return this.getOutputShapes(manager, inputShapes);
    }

    /**
     * Computes the output shapes by threading {@code inputs} through every child's
     * {@code getOutputShapes}.
     *
     * @throws IllegalArgumentException if this sequential block is empty
     */
    @Override
    public Shape[] getOutputShapes(NDManager manager, Shape[] inputs) {
        if (this.blocks.isEmpty()) {
            throw new IllegalArgumentException("The sequential block is empty");
        }
        Shape[] current = inputs;
        for (Block block : this.blocks) {
            current = block.getOutputShapes(manager, current);
        }
        return current;
    }

    /** Returns an empty list: a sequential block owns no parameters directly. */
    @Override
    public List<Parameter> getDirectParameters() {
        return Collections.emptyList();
    }

    /**
     * Always throws, because a sequential block has no direct parameters.
     *
     * @throws IllegalArgumentException always
     */
    @Override
    public Shape getParameterShape(String name, Shape[] inputShapes) {
        throw new IllegalArgumentException("SequentialBlocks have no parameters");
    }

    /**
     * Returns the children in execution order.
     *
     * <p>Each child is named {@code <index>:<SimpleClassName>}. The index is zero-padded to the
     * number of decimal digits in the child count, so names sort in execution order (for example
     * {@code 00:Linear} ... {@code 10:Linear} for eleven children).
     */
    @Override
    public BlockList getChildren() {
        int size = this.blocks.size();
        BlockList children = new BlockList(size);
        if (size == 0) {
            // log10(0) is -Infinity; there is nothing to name, so skip building a format.
            return children;
        }
        int precision = (int) Math.log10(size) + 1;
        String format = "%0" + precision + "d:%s";
        for (int i = 0; i < size; ++i) {
            Block block = this.blocks.get(i);
            String name = String.format(format, i, block.getClass().getSimpleName());
            children.add(name, block);
        }
        return children;
    }

    /**
     * Writes the encoding version and the input shapes, then the parameters of each child in
     * order.
     */
    @Override
    public void saveParameters(DataOutputStream os) throws IOException {
        os.writeByte(VERSION);
        this.saveInputShapes(os);
        for (Block block : this.blocks) {
            block.saveParameters(os);
        }
    }

    /**
     * Reads parameters written by {@link #saveParameters(DataOutputStream)}.
     *
     * <p>Version {@value #VERSION} streams include input shapes. Version 1 streams have none, so
     * they are read straight into the children.
     *
     * @throws MalformedModelException if the stream declares an unsupported version
     */
    @Override
    public void loadParameters(NDManager manager, DataInputStream is) throws IOException, MalformedModelException {
        byte version = is.readByte();
        if (version == VERSION) {
            this.readInputShapes(is);
        } else if (version != 1) {
            throw new MalformedModelException("Unsupported encoding version: " + version);
        }
        for (Block block : this.blocks) {
            block.loadParameters(manager, is);
        }
    }

    /** Describes this block as {@code Sequential(...)} with each child's description tab-indented. */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder(200);
        sb.append("Sequential(\n");
        for (Block block : this.blocks) {
            String blockString = LINE_START.matcher(block.toString()).replaceAll("\t");
            sb.append(blockString).append('\n');
        }
        sb.append(')');
        return sb.toString();
    }
}

