/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn;

import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.ndarray.types.SparseFormat;
import ai.djl.nn.Block;
import ai.djl.nn.ParameterType;
import ai.djl.training.initializer.Initializer;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Objects;
import java.util.UUID;

/**
 * A named parameter belonging to a {@link Block}, wrapping the {@link NDArray} that holds its
 * value.
 *
 * <p>A parameter starts out without an array. The array is supplied by one of
 * {@link #setArray(NDArray)}, {@link #initialize(NDManager, DataType, Shape[])} or
 * {@link #load(NDManager, DataInputStream)}, and is released by {@link #close()}.
 */
public class Parameter implements AutoCloseable {

    /** Encoding version written by {@link #save} and the only one accepted by {@link #load}. */
    private static final byte VERSION = 1;

    /** Leading marker of a saved parameter that had no array when it was written. */
    private static final char MAGIC_UNINITIALIZED = 'N';

    /** Leading marker of a saved parameter that is followed by version, name and array data. */
    private static final char MAGIC_INITIALIZED = 'P';

    private final String id = UUID.randomUUID().toString();
    private final String name;
    private final Block block;
    private final ParameterType type;
    private final boolean requireGrad;
    private final SparseFormat gradientFormat;

    private DataType mandatoryDataType;
    private Initializer initializer;
    private NDArray array;

    /**
     * Creates a parameter that requires a gradient in {@link SparseFormat#DENSE} format.
     *
     * @param name the parameter name (may be {@code null}, see {@link #getName()})
     * @param block the block queried for this parameter's shape during initialization
     * @param type the parameter type; also supplies the initial {@link Initializer}
     */
    public Parameter(String name, Block block, ParameterType type) {
        this(name, block, type, true, SparseFormat.DENSE);
    }

    /**
     * Creates a parameter whose gradient, if required, uses {@link SparseFormat#DENSE} format.
     *
     * @param name the parameter name (may be {@code null}, see {@link #getName()})
     * @param block the block queried for this parameter's shape during initialization
     * @param type the parameter type; also supplies the initial {@link Initializer}
     * @param requireGrad whether a gradient is attached when the parameter is initialized
     */
    public Parameter(String name, Block block, ParameterType type, boolean requireGrad) {
        this(name, block, type, requireGrad, SparseFormat.DENSE);
    }

    /**
     * Creates a parameter.
     *
     * @param name the parameter name (may be {@code null}, see {@link #getName()})
     * @param block the block queried for this parameter's shape during initialization
     * @param type the parameter type; also supplies the initial {@link Initializer}
     * @param requireGrad whether a gradient is attached when the parameter is initialized
     * @param gradientFormat the format passed to {@code attachGradient} during initialization
     * @throws NullPointerException if {@code type} is {@code null}
     */
    public Parameter(
            String name,
            Block block,
            ParameterType type,
            boolean requireGrad,
            SparseFormat gradientFormat) {
        this.name = name;
        this.block = block;
        this.type = type;
        this.requireGrad = requireGrad;
        this.initializer = type.getInitializer();
        this.gradientFormat = gradientFormat;
    }

    /**
     * Returns the identifier of this parameter, a random UUID string generated at construction.
     *
     * @return the parameter id
     */
    public String getId() {
        return this.id;
    }

    /**
     * Returns the name of this parameter.
     *
     * @return the name, or an empty string if the parameter was created with a {@code null} name
     */
    public String getName() {
        return this.name == null ? "" : this.name;
    }

    /**
     * Returns the type this parameter was created with.
     *
     * @return the parameter type
     */
    public ParameterType getType() {
        return this.type;
    }

    /**
     * Replaces the array held by this parameter and names it after the parameter.
     *
     * @param array the new array; must not be {@code null}
     * @throws NullPointerException if {@code array} is {@code null}; the current array is kept
     */
    public void setArray(NDArray array) {
        // Validate before assigning so a null argument cannot leave the parameter without the
        // array it previously held.
        Objects.requireNonNull(array, "array must not be null");
        this.array = array;
        array.setName(this.name);
    }

    /**
     * Returns the array held by this parameter.
     *
     * @return the current array
     * @throws IllegalStateException if no array has been set, initialized or loaded yet
     */
    public NDArray getArray() {
        if (!this.isInitialized()) {
            throw new IllegalStateException("The array has not been initialized");
        }
        return this.array;
    }

    /**
     * Returns whether this parameter was created as requiring a gradient.
     *
     * @return {@code true} if a gradient is attached on initialization
     */
    public boolean requireGradient() {
        return this.requireGrad;
    }

    /**
     * Sets a data type that overrides the one passed to
     * {@link #initialize(NDManager, DataType, Shape[])}.
     *
     * @param mandatoryDataType the data type to force, or {@code null} to use the caller's
     */
    public void setMandatoryDataType(DataType mandatoryDataType) {
        this.mandatoryDataType = mandatoryDataType;
    }

    /**
     * Returns whether this parameter currently holds an array.
     *
     * @return {@code true} if an array is present
     */
    public boolean isInitialized() {
        return this.array != null;
    }

    /**
     * Sets the initializer used to create the array.
     *
     * @param initializer the initializer to use
     * @param overwrite if {@code false}, the given initializer is only used when none is set
     */
    public void setInitializer(Initializer initializer, boolean overwrite) {
        if (overwrite || this.initializer == null) {
            this.initializer = initializer;
        }
    }

    /**
     * Creates the array if this parameter does not hold one yet, then attaches a gradient if the
     * parameter requires one.
     *
     * <p>The shape is obtained from the owning block. The gradient is attached on every call,
     * including when the array already existed.
     *
     * @param manager the manager handed to the initializer
     * @param dataType the data type to use unless a mandatory data type has been set
     * @param inputShapes the input shapes forwarded to the block to resolve the parameter shape
     * @throws NullPointerException if no initializer has been set
     */
    public void initialize(NDManager manager, DataType dataType, Shape[] inputShapes) {
        Objects.requireNonNull(this.initializer, "No initializer has been set");
        if (!this.isInitialized()) {
            Shape shape = this.block.getParameterShape(this.name, inputShapes);
            DataType effectiveType =
                    this.mandatoryDataType == null ? dataType : this.mandatoryDataType;
            this.array = this.initializer.initialize(manager, shape, effectiveType);
            this.array.setName(this.name);
        }
        if (this.requireGradient()) {
            this.array.attachGradient(this.gradientFormat);
        }
    }

    /**
     * Writes this parameter to the stream.
     *
     * <p>A parameter without an array is written as the single char {@code 'N'}. Otherwise the
     * output is the char {@code 'P'}, the version byte, the name as modified UTF-8, and the bytes
     * returned by the array's {@code encode()}.
     *
     * @param dos the stream to write to
     * @throws IOException if writing to the stream fails
     */
    public void save(DataOutputStream dos) throws IOException {
        if (!this.isInitialized()) {
            dos.writeChar(MAGIC_UNINITIALIZED);
            return;
        }
        dos.writeChar(MAGIC_INITIALIZED);
        dos.writeByte(VERSION);
        dos.writeUTF(this.getName());
        dos.write(this.array.encode());
    }

    /**
     * Reads this parameter from a stream in the format produced by {@link #save}.
     *
     * <p>If the stream holds the {@code 'N'} marker, nothing else is read and the current array
     * is left unchanged.
     *
     * @param manager the manager used to decode the array
     * @param dis the stream to read from
     * @throws IOException if reading from the stream fails
     * @throws MalformedModelException if the marker or version is not recognised, or the stored
     *     name differs from {@link #getName()}
     */
    public void load(NDManager manager, DataInputStream dis)
            throws IOException, MalformedModelException {
        char magic = dis.readChar();
        if (magic == MAGIC_UNINITIALIZED) {
            return;
        }
        if (magic != MAGIC_INITIALIZED) {
            throw new MalformedModelException("Invalid input data.");
        }
        byte version = dis.readByte();
        if (version != VERSION) {
            throw new MalformedModelException("Unsupported encoding version: " + version);
        }
        String parameterName = dis.readUTF();
        String expectedName = this.getName();
        if (!parameterName.equals(expectedName)) {
            // Report the value that was actually compared, so an unnamed parameter shows an
            // empty expected name rather than "null".
            throw new MalformedModelException(
                    "Unexpected parameter name: " + parameterName + ", expected: " + expectedName);
        }
        this.array = manager.decode(dis);
    }

    /**
     * Closes the held array, if any, and drops the reference so that {@link #isInitialized()}
     * returns {@code false} afterwards. Calling this on a parameter without an array does nothing.
     */
    @Override
    public void close() {
        if (this.array != null) {
            this.array.close();
            this.array = null;
        }
    }
}

