/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn.core;

import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Parameter;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.IOException;

/**
 * Block that applies the PReLU operation to a single input array.
 *
 * <p>The block owns one parameter named {@code "alpha"}, registered as a
 * {@link Parameter.Type#WEIGHT} with an empty (zero-dimensional) shape. The actual
 * computation is delegated to {@link NDArrayEx#prelu(NDArray, NDArray)}.
 */
public class Prelu
extends AbstractBlock {
    /** Encoding version passed to the parent block and compared against in {@link #loadMetadata}. */
    private static final byte VERSION = 2;

    /** The {@code alpha} parameter handed to the prelu operation on every forward pass. */
    private final Parameter alpha;

    /** Creates the block and registers its {@code alpha} parameter. */
    public Prelu() {
        super(VERSION);
        Parameter alphaParameter = Parameter.builder()
                .setName("alpha")
                .setType(Parameter.Type.WEIGHT)
                .optShape(new Shape(new long[0]))
                .build();
        this.alpha = this.addParameter(alphaParameter);
    }

    /**
     * Runs prelu on the single array contained in {@code inputs}.
     *
     * @param parameterStore store from which the current {@code alpha} value is fetched
     * @param inputs list expected to hold exactly one array ({@code singletonOrThrow} is used)
     * @param training forwarded to the parameter store lookup
     * @param params not used by this block
     * @return the list produced by {@link #prelu(NDArray, NDArray)}
     */
    @Override
    protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, PairList<String, Object> params) {
        NDArray data = inputs.singletonOrThrow();
        // Fetch alpha for the same device the input array reports.
        NDArray slope = parameterStore.getValue(this.alpha, data.getDevice(), training);
        return Prelu.prelu(data, slope);
    }

    /**
     * Reports the output shapes for the given input shapes.
     *
     * @param inputs input shapes; only the first entry is consulted
     * @return a one-element array containing {@code inputs[0]} unchanged
     */
    @Override
    public Shape[] getOutputShapes(Shape[] inputs) {
        Shape first = inputs[0];
        return new Shape[]{first};
    }

    /**
     * Loads block metadata written with the given encoding version.
     *
     * <p>When {@code loadVersion} equals this block's version the input shapes are read from
     * the stream. Version {@code 1} is accepted without reading anything (presumably that
     * format stored no input shapes; confirm against the writer). Any other version is rejected.
     *
     * @param loadVersion encoding version found in the stream
     * @param is stream positioned at this block's metadata
     * @throws IOException if reading the input shapes fails
     * @throws MalformedModelException if {@code loadVersion} is neither the current version nor 1
     */
    @Override
    public void loadMetadata(byte loadVersion, DataInputStream is) throws IOException, MalformedModelException {
        if (loadVersion == this.version) {
            this.readInputShapes(is);
            return;
        }
        if (loadVersion == 1) {
            // Accepted as-is: nothing is read from the stream for this version.
            return;
        }
        throw new MalformedModelException("Unsupported encoding version: " + loadVersion);
    }

    /**
     * Applies prelu to {@code input} using {@code alpha}, via the input array's internal
     * {@link NDArrayEx} implementation.
     *
     * @param input the array to transform
     * @param alpha the alpha array passed through to the underlying operation
     * @return the list returned by {@link NDArrayEx#prelu(NDArray, NDArray)}
     */
    public static NDList prelu(NDArray input, NDArray alpha) {
        return input.getNDArrayInternal().prelu(input, alpha);
    }
}

