/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn.core;

import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterBlock;
import ai.djl.nn.ParameterType;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Collections;
import java.util.List;

/**
 * A block that applies the engine's {@code prelu} operator to a single input using one learnable
 * parameter, {@code alpha}.
 *
 * <p>The computation itself is delegated to {@link NDArrayEx#prelu}; this block only supplies the
 * input together with the current value of {@code alpha}.
 */
public class Prelu
extends ParameterBlock {
    /** Encoding version written by {@link #saveParameters} and required by {@link #loadParameters}. */
    private static final byte VERSION = 1;

    /** The learnable {@code alpha} parameter; its shape is rank 0 (see {@link #getParameterShape}). */
    private final Parameter alpha = new Parameter("alpha", this, ParameterType.OTHER);

    /**
     * Runs the {@code prelu} operator on the single input array.
     *
     * @param parameterStore the store used to fetch {@code alpha} on the input's device
     * @param inputs a list that must contain exactly one array
     * @param params extra operator parameters, passed through unchanged to the engine
     * @return the result of {@link NDArrayEx#prelu} applied to {@code (input, alpha)}
     */
    @Override
    public NDList forward(ParameterStore parameterStore, NDList inputs, PairList<String, Object> params) {
        NDArray data = inputs.singletonOrThrow();
        // Fetch alpha on the same device as the input so the operator sees co-located arrays.
        NDArray alphaValue = parameterStore.getValue(this.alpha, data.getDevice());
        NDArrayEx ex = data.getNDArrayInternal();
        return ex.prelu(new NDList(data, alphaValue), params);
    }

    /**
     * Returns the output shape, which is the same as the first input's shape.
     *
     * @param manager unused
     * @param inputs the input shapes; only {@code inputs[0]} is read
     * @return a one-element array containing {@code inputs[0]}
     */
    @Override
    public Shape[] getOutputShapes(NDManager manager, Shape[] inputs) {
        return new Shape[]{inputs[0]};
    }

    /**
     * Returns the parameters owned directly by this block.
     *
     * @return an immutable single-element list containing {@code alpha}
     */
    @Override
    public List<Parameter> getDirectParameters() {
        return Collections.singletonList(this.alpha);
    }

    /**
     * Returns the shape of the named parameter.
     *
     * @param name the parameter name; only {@code "alpha"} is recognized
     * @param inputShapes unused
     * @return a rank-0 (empty) shape for {@code "alpha"}
     * @throws IllegalArgumentException if {@code name} is not {@code "alpha"}
     */
    @Override
    public Shape getParameterShape(String name, Shape[] inputShapes) {
        if ("alpha".equals(name)) {
            return new Shape(new long[0]);
        }
        throw new IllegalArgumentException("Invalid parameter name: " + name);
    }

    /**
     * Writes the encoding version byte followed by {@code alpha}.
     *
     * @param os the destination stream
     * @throws IOException if writing fails
     */
    @Override
    public void saveParameters(DataOutputStream os) throws IOException {
        // Use the shared constant so the written version always matches what loadParameters accepts.
        os.writeByte(VERSION);
        this.alpha.save(os);
    }

    /**
     * Reads a version byte and then {@code alpha}, as written by {@link #saveParameters}.
     *
     * @param manager the manager used to allocate the loaded parameter
     * @param is the source stream
     * @throws IOException if reading fails
     * @throws MalformedModelException if the version byte is not {@link #VERSION}
     */
    @Override
    public void loadParameters(NDManager manager, DataInputStream is) throws IOException, MalformedModelException {
        byte version = is.readByte();
        if (version != VERSION) {
            throw new MalformedModelException("Unsupported encoding version: " + version);
        }
        this.alpha.load(manager, is);
    }
}

