/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn;

import ai.djl.ndarray.NDArray;
import ai.djl.nn.Block;
import ai.djl.nn.LambdaBlock;

/**
 * Static helpers for flattening {@link NDArray}s per batch and for building the corresponding
 * {@link Block}s, plus an identity block. Not instantiable.
 */
public final class Blocks {

    /** Prevents instantiation of this static utility holder. */
    private Blocks() {
    }

    /**
     * Reshapes {@code array} into two dimensions, keeping {@code array.size(0)} as the first
     * dimension and collapsing everything after it into the second.
     *
     * @param array the array to flatten; its first axis is treated as the batch axis
     * @return the reshaped array with shape {@code (array.size(0), rest)}
     */
    public static NDArray batchFlatten(NDArray array) {
        final long batchSize = array.size(0);
        // An empty batch leaves nothing for reshape to infer -1 from, so in that case the
        // trailing size is computed explicitly from the shape; otherwise -1 lets reshape infer it.
        final long flattenedSize = (batchSize == 0L) ? array.getShape().slice(1).size() : -1L;
        return array.reshape(batchSize, flattenedSize);
    }

    /**
     * Reshapes {@code array} into two dimensions whose second dimension is fixed to
     * {@code size}, letting the first dimension be inferred.
     *
     * @param array the array to flatten
     * @param size the length of the second (flattened) dimension
     * @return the reshaped array with shape {@code (inferred, size)}
     */
    public static NDArray batchFlatten(NDArray array, long size) {
        return array.reshape(-1L, size);
    }

    /**
     * Creates a block that applies {@link #batchFlatten(NDArray)} via
     * {@code LambdaBlock.singleton}.
     *
     * @return a block wrapping {@link #batchFlatten(NDArray)}
     */
    public static Block batchFlattenBlock() {
        return LambdaBlock.singleton(input -> batchFlatten(input));
    }

    /**
     * Creates a block that applies {@link #batchFlatten(NDArray, long)} with the given
     * {@code size} via {@code LambdaBlock.singleton}.
     *
     * @param size the length of the second (flattened) dimension
     * @return a block wrapping {@link #batchFlatten(NDArray, long)}
     */
    public static Block batchFlattenBlock(long size) {
        return LambdaBlock.singleton(input -> batchFlatten(input, size));
    }

    /**
     * Creates a {@link LambdaBlock} whose function returns its input unchanged.
     *
     * @return a block that passes its input straight through
     */
    public static Block identityBlock() {
        return new LambdaBlock(inputs -> inputs);
    }
}

