/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn;

import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.ParameterType;
import ai.djl.training.initializer.Initializer;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Objects;
import java.util.UUID;

/**
 * A named {@link NDArray} that belongs to a {@link Block}.
 *
 * <p>The underlying array is either created lazily by {@link #initialize}, or supplied directly
 * through {@link #setArray} or {@link #load}. This object owns that array: {@link #close()}
 * closes it and clears the reference.
 */
public class Parameter implements AutoCloseable {

    /** Encoding version written by {@link #save} and the only version accepted by {@link #load}. */
    private static final byte VERSION = 1;

    /** Chunk size, in bytes, used when copying array data to or from a stream. */
    private static final int BUFFER_SIZE = 81920;

    /** Marker written by {@link #save} in place of a parameter that has no array. */
    private static final char MAGIC_EMPTY = 'N';

    /** Marker written by {@link #save} before a serialized parameter. */
    private static final char MAGIC_PARAMETER = 'P';

    private final String id = UUID.randomUUID().toString();
    private final String name;
    private final Block block;
    private final ParameterType type;
    private final boolean requireGrad;
    private DataType mandatoryDataType;
    private Initializer initializer;
    private NDArray array;

    /**
     * Creates a parameter that requires a gradient.
     *
     * @param name the parameter name; may be {@code null}, in which case {@link #getName()}
     *     returns an empty string
     * @param block the block that owns this parameter and computes its shape
     * @param type the parameter type, which supplies the default initializer
     */
    public Parameter(String name, Block block, ParameterType type) {
        this(name, block, type, true);
    }

    /**
     * Creates a parameter.
     *
     * @param name the parameter name; may be {@code null}, in which case {@link #getName()}
     *     returns an empty string
     * @param block the block that owns this parameter and computes its shape
     * @param type the parameter type, which supplies the default initializer
     * @param requireGrad whether {@link #initialize} should attach a gradient to the array
     */
    public Parameter(String name, Block block, ParameterType type, boolean requireGrad) {
        this.name = name;
        this.block = block;
        this.type = type;
        this.requireGrad = requireGrad;
        this.initializer = type.getInitializer();
    }

    /**
     * Returns the randomly generated unique id of this parameter.
     *
     * @return a random UUID string assigned at construction
     */
    public String getId() {
        return id;
    }

    /**
     * Returns the parameter name.
     *
     * @return the name given at construction, or an empty string if it was {@code null}
     */
    public String getName() {
        return name == null ? "" : name;
    }

    /**
     * Returns the parameter type.
     *
     * @return the type given at construction
     */
    public ParameterType getType() {
        return type;
    }

    /**
     * Replaces the underlying array and names it after this parameter.
     *
     * @param array the new array; must not be {@code null}
     * @throws NullPointerException if {@code array} is {@code null}
     */
    public void setArray(NDArray array) {
        // Validate before mutating so a null argument does not leave the field half-updated.
        Objects.requireNonNull(array, "array");
        this.array = array;
        array.setName(name);
    }

    /**
     * Returns the underlying array.
     *
     * @return the array
     * @throws IllegalStateException if the array has not been initialized, set, or loaded
     */
    public NDArray getArray() {
        if (!isInitialized()) {
            throw new IllegalStateException("The array has not been initialized");
        }
        return array;
    }

    /**
     * Returns whether a gradient is attached to the array on {@link #initialize}.
     *
     * @return the {@code requireGrad} flag given at construction ({@code true} by default)
     */
    public boolean requireGradient() {
        return requireGrad;
    }

    /**
     * Forces the data type used by {@link #initialize}, overriding the type passed to it.
     *
     * @param mandatoryDataType the data type to force, or {@code null} to use the caller's type
     */
    public void setMandatoryDataType(DataType mandatoryDataType) {
        this.mandatoryDataType = mandatoryDataType;
    }

    /**
     * Returns whether this parameter currently holds an array.
     *
     * @return {@code true} if an array has been initialized, set, or loaded and not yet closed
     */
    public boolean isInitialized() {
        return array != null;
    }

    /**
     * Sets the initializer used by {@link #initialize}.
     *
     * @param initializer the new initializer
     * @param overwrite if {@code false}, the initializer is only set when none is present
     */
    public void setInitializer(Initializer initializer, boolean overwrite) {
        if (overwrite || this.initializer == null) {
            this.initializer = initializer;
        }
    }

    /**
     * Creates the underlying array if it does not exist yet. Whenever {@link #requireGradient()}
     * is true, a gradient is then attached, including when the array already existed.
     *
     * @param manager the manager used to create the array
     * @param dataType the data type to use unless a mandatory data type has been set
     * @param inputShapes the input shapes the owning block uses to compute this parameter's shape
     * @throws NullPointerException if no initializer has been set
     */
    public void initialize(NDManager manager, DataType dataType, Shape[] inputShapes) {
        Objects.requireNonNull(initializer, "No initializer has been set");
        if (!isInitialized()) {
            Shape shape = block.getParameterShape(name, inputShapes);
            DataType effectiveType = mandatoryDataType == null ? dataType : mandatoryDataType;
            array = initializer.initialize(manager, shape, effectiveType);
            array.setName(name);
        }
        if (requireGradient()) {
            array.attachGradient();
        }
    }

    /**
     * Writes this parameter to a stream.
     *
     * <p>If no array is present, only the {@code 'N'} marker is written. Otherwise the layout is:
     * {@code 'P'} marker (char), version (byte), name (UTF), sparse format name (UTF), data type
     * name (UTF), encoded shape, data length (int), then the raw array bytes.
     *
     * @param dos the destination stream
     * @throws IOException if writing fails
     */
    public void save(DataOutputStream dos) throws IOException {
        if (!isInitialized()) {
            dos.writeChar(MAGIC_EMPTY);
            return;
        }
        dos.writeChar(MAGIC_PARAMETER);
        dos.writeByte(VERSION);
        dos.writeUTF(getName());
        dos.writeUTF(array.getSparseFormat().name());
        dos.writeUTF(array.getDataType().name());
        dos.write(array.getShape().getEncoded());

        ByteBuffer bb = array.toByteBuffer();
        int length = bb.remaining();
        dos.writeInt(length);
        if (length > 0) {
            // Copy through a bounded buffer so huge arrays do not need a second full-size copy.
            byte[] buf = new byte[Math.min(length, BUFFER_SIZE)];
            while (bb.hasRemaining()) {
                int chunk = Math.min(bb.remaining(), buf.length);
                bb.get(buf, 0, chunk);
                dos.write(buf, 0, chunk);
            }
        }
        dos.flush();
    }

    /**
     * Reads a parameter previously written by {@link #save} and replaces the underlying array.
     *
     * <p>If the stream holds the {@code 'N'} marker, nothing else is read and the current array is
     * left unchanged.
     *
     * @param manager the manager used to allocate and create the array
     * @param dis the source stream
     * @throws IOException if reading fails
     * @throws MalformedModelException if the marker, version, name, data type, or data length in
     *     the stream is invalid
     */
    public void load(NDManager manager, DataInputStream dis)
            throws IOException, MalformedModelException {
        char magic = dis.readChar();
        if (magic == MAGIC_EMPTY) {
            return;
        }
        if (magic != MAGIC_PARAMETER) {
            throw new MalformedModelException("Invalid input data.");
        }

        byte version = dis.readByte();
        if (version != VERSION) {
            throw new MalformedModelException("Unsupported encoding version: " + version);
        }

        String parameterName = dis.readUTF();
        if (!parameterName.equals(getName())) {
            throw new MalformedModelException(
                    "Unexpected parameter name: " + parameterName + ", expected: " + getName());
        }

        // The sparse format is recorded by save() but not used when loading.
        dis.readUTF();

        String dataTypeName = dis.readUTF();
        DataType dataType;
        try {
            dataType = DataType.valueOf(dataTypeName);
        } catch (IllegalArgumentException e) {
            throw new MalformedModelException("Unsupported data type: " + dataTypeName, e);
        }

        Shape shape = Shape.decode(dis);

        // The length must be read before allocating the destination buffer.
        int length = dis.readInt();
        if (length < 0) {
            throw new MalformedModelException("Invalid data length: " + length);
        }
        ByteBuffer data = manager.allocateDirect(length);
        if (length > 0) {
            byte[] buf = new byte[Math.min(length, BUFFER_SIZE)];
            int remaining = length;
            while (remaining > 0) {
                int chunk = Math.min(remaining, buf.length);
                dis.readFully(buf, 0, chunk);
                data.put(buf, 0, chunk);
                remaining -= chunk;
            }
            data.rewind();
        }
        array = manager.create(dataType.asDataType(data), shape);
    }

    /** Closes the underlying array, if any, and clears the reference. */
    @Override
    public void close() {
        if (array != null) {
            array.close();
            array = null;
        }
    }
}

