/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn;

import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterBlock;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.function.Function;

/**
 * A parameterless block whose forward pass is a user-supplied {@link Function} from
 * {@link NDList} to {@link NDList}.
 */
public class LambdaBlock extends ParameterBlock {
    /** Encoding version written by {@link #saveParameters} and checked by {@link #loadParameters}. */
    private static final byte VERSION = 2;

    private final Function<NDList, NDList> lambda;

    /**
     * Creates a block that applies {@code lambda} in its forward pass.
     *
     * @param lambda the function applied to the block inputs
     */
    public LambdaBlock(Function<NDList, NDList> lambda) {
        this.lambda = lambda;
    }

    /**
     * Applies the wrapped function to {@code inputs}.
     *
     * <p>{@code parameterStore}, {@code training} and {@code params} are ignored because this
     * block has no parameters.
     *
     * @return the result of the wrapped function
     */
    @Override
    public NDList forward(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        return lambda.apply(inputs);
    }

    /**
     * Infers output shapes by running the wrapped function on freshly created arrays of the given
     * input shapes.
     *
     * <p>The probe arrays are created in a sub-manager that is closed before returning. Only the
     * {@link Shape} objects escape that scope.
     *
     * @param manager the manager from which a temporary sub-manager is derived
     * @param inputShapes the shapes of the inputs to probe with
     * @return the shapes of the arrays produced by the wrapped function, in output order
     */
    @Override
    public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
        try (NDManager subManager = manager.newSubManager()) {
            NDList probe = new NDList(inputShapes.length);
            for (Shape shape : inputShapes) {
                probe.add(subManager.create(shape));
            }
            NDList output = lambda.apply(probe);
            Shape[] outputShapes = new Shape[output.size()];
            for (int i = 0; i < output.size(); ++i) {
                outputShapes[i] = ((NDArray) output.get(i)).getShape();
            }
            return outputShapes;
        }
    }

    /**
     * Returns an empty list, because a lambda block owns no parameters.
     *
     * @return an empty, immutable list
     */
    @Override
    public List<Parameter> getDirectParameters() {
        return Collections.emptyList();
    }

    /**
     * Always throws, because a lambda block has no parameters to size.
     *
     * @throws IllegalArgumentException always
     */
    @Override
    public Shape getParameterShape(String name, Shape[] inputShapes) {
        throw new IllegalArgumentException("LambdaBlocks have no parameters");
    }

    /**
     * Writes the encoding version byte, followed by whatever the inherited
     * {@code saveInputShapes} helper emits.
     *
     * @param os the stream to write to
     * @throws IOException if writing fails
     */
    @Override
    public void saveParameters(DataOutputStream os) throws IOException {
        // Use the constant so the writer stays in sync with the version check in loadParameters.
        os.writeByte(VERSION);
        saveInputShapes(os);
    }

    /**
     * Reads the encoding version byte and, for the current version, the data written by
     * {@code saveInputShapes}.
     *
     * <p>Version 1 is accepted with nothing further read; it presumably predates input-shape
     * serialization.
     *
     * @param manager unused; no parameters are loaded
     * @param is the stream to read from
     * @throws IOException if reading fails
     * @throws MalformedModelException if the version byte is neither {@code VERSION} nor 1
     */
    @Override
    public void loadParameters(NDManager manager, DataInputStream is)
            throws IOException, MalformedModelException {
        byte version = is.readByte();
        if (version == VERSION) {
            readInputShapes(is);
        } else if (version != 1) {
            throw new MalformedModelException("Unsupported encoding version: " + version);
        }
    }

    @Override
    public String toString() {
        return "Lambda()";
    }
}

