/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn.core;

import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterType;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.IOException;

/**
 * A block with one learnable {@code alpha} parameter that applies the engine's PReLU operator to
 * its single input.
 *
 * <p>NOTE(review): the element-wise math is delegated entirely to {@link NDArrayEx#prelu}; this
 * class only registers the parameter, fetches it on the input's device, and forwards both to the
 * engine.
 */
public class Prelu extends AbstractBlock {

    /**
     * Encoding version passed to {@link AbstractBlock}. {@link #loadMetadata} reads input shapes
     * only for this version and also accepts the legacy version {@code 1}.
     */
    private static final byte VERSION = 2;

    /** The {@code alpha} parameter, registered with an empty (rank-0) shape. */
    private final Parameter alpha =
            this.addParameter(
                    new Parameter("alpha", this, ParameterType.OTHER), new Shape(new long[0]));

    /** Creates a new {@code Prelu} block using the current encoding {@link #VERSION}. */
    public Prelu() {
        // Use the named constant rather than a duplicated literal so the two cannot drift apart.
        super(VERSION);
    }

    /**
     * Applies PReLU to the single input array using this block's {@code alpha} parameter.
     *
     * @param parameterStore store from which the {@code alpha} value is fetched, placed on the
     *     same device as the input
     * @param inputs expected to hold exactly one array (taken via {@link
     *     NDList#singletonOrThrow()}, which presumably throws otherwise)
     * @param training unused by this block
     * @param params unused by this block
     * @return the result of {@link #prelu(NDArray, NDArray)}
     */
    @Override
    public NDList forward(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        NDArray input = inputs.singletonOrThrow();
        // Fetch alpha on the input's device so the engine operates on co-located arrays.
        NDArray alphaArr = parameterStore.getValue(this.alpha, input.getDevice());
        return Prelu.prelu(input, alphaArr);
    }

    /**
     * Returns the output shape, which is the shape of the first input.
     *
     * @param manager unused by this block
     * @param inputs input shapes; only {@code inputs[0]} is read
     * @return a one-element array containing {@code inputs[0]}
     */
    @Override
    public Shape[] getOutputShapes(NDManager manager, Shape[] inputs) {
        return new Shape[] {inputs[0]};
    }

    /**
     * Reads this block's metadata for the given encoding version.
     *
     * <p>Version {@link #VERSION} carries input shapes, which are read from the stream. Legacy
     * version {@code 1} carries no extra metadata, so nothing is read.
     *
     * @param version the encoding version found in the stream
     * @param is the stream to read from
     * @throws IOException if reading the input shapes fails
     * @throws MalformedModelException if {@code version} is neither {@link #VERSION} nor {@code 1}
     */
    @Override
    public void loadMetadata(byte version, DataInputStream is)
            throws IOException, MalformedModelException {
        if (version == VERSION) {
            this.readInputShapes(is);
        } else if (version != 1) {
            throw new MalformedModelException("Unsupported encoding version: " + version);
        }
    }

    /**
     * Applies the engine's PReLU operator to {@code input} with the given {@code alpha}.
     *
     * @param input the array to transform; its internal {@link NDArrayEx} performs the operation
     * @param alpha the PReLU coefficient array
     * @return the {@link NDList} returned by {@link NDArrayEx#prelu}
     */
    public static NDList prelu(NDArray input, NDArray alpha) {
        NDArrayEx ex = input.getNDArrayInternal();
        return ex.prelu(input, alpha);
    }
}

