/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn;

import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.UninitializedParameterException;
import ai.djl.training.initializer.Initializer;
import ai.djl.training.initializer.XavierInitializer;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Objects;
import java.util.UUID;

/**
 * A named, typed holder for an {@link NDArray} that acts as a parameter of a block.
 *
 * <p>The backing array can be supplied directly ({@link #setArray(NDArray)} or
 * {@link Builder#optArray(NDArray)}), restored with {@link #load(NDManager, DataInputStream)}, or
 * created lazily by {@link #initialize(NDManager, DataType)} from the configured
 * {@link Initializer} and {@link Shape}. Each instance is assigned a random UUID as its id.
 */
public class Parameter implements AutoCloseable {

    /** Encoding version written by {@link #save} and the only version accepted by {@link #load}. */
    private static final byte VERSION = 1;

    /** Leading marker written by {@link #save} when the parameter holds an array. */
    private static final char MAGIC_PRESENT = 'P';

    /** Leading marker written by {@link #save} when the parameter holds no array. */
    private static final char MAGIC_ABSENT = 'N';

    private final String id = UUID.randomUUID().toString();
    private final String name;
    private Shape shape;
    private final Type type;
    private Initializer initializer;
    private NDArray array;
    private boolean requiresGrad;

    /**
     * Creates a parameter from the values collected by a {@link Builder}.
     *
     * @param builder the builder holding the configuration
     * @throws NullPointerException if the builder has neither an initializer nor a type, since the
     *     default initializer is taken from the type
     */
    Parameter(Builder builder) {
        this.name = builder.name;
        this.shape = builder.shape;
        this.type = builder.type;
        this.array = builder.array;
        this.requiresGrad = builder.requiresGrad;
        if (builder.initializer != null) {
            this.initializer = builder.initializer;
        } else {
            // The default comes from the type; fail with a descriptive message instead of a bare NPE.
            this.initializer =
                    Objects.requireNonNull(
                                    this.type,
                                    "Parameter type must be set when no initializer is provided")
                            .getInitializer();
        }
    }

    /**
     * Returns the randomly generated unique id of this parameter.
     *
     * @return the parameter id
     */
    public String getId() {
        return this.id;
    }

    /**
     * Returns the name of this parameter.
     *
     * @return the name, or an empty string if no name was set
     */
    public String getName() {
        return this.name == null ? "" : this.name;
    }

    /**
     * Returns the type of this parameter.
     *
     * @return the parameter type; may be {@code null} if the builder only supplied an initializer
     */
    public Type getType() {
        return this.type;
    }

    /**
     * Sets the backing array directly, taking the shape from it and naming it after this parameter.
     *
     * @param array the array to use
     * @throws IllegalStateException if a shape has already been set (via {@link #setShape},
     *     {@link Builder#optShape}, or a previous call to this method)
     */
    public void setArray(NDArray array) {
        if (this.shape != null) {
            throw new IllegalStateException("array has been set! Use either setArray or setShape");
        }
        this.array = array;
        this.shape = array.getShape();
        array.setName(this.name);
    }

    /**
     * Sets the shape used by {@link #initialize(NDManager, DataType)} to create the array.
     *
     * @param shape the parameter shape
     * @throws IllegalStateException if an array has already been set
     */
    public void setShape(Shape shape) {
        if (this.array != null) {
            throw new IllegalStateException("array has been set! Use either setArray or setShape");
        }
        this.shape = shape;
    }

    /**
     * Returns the shape of this parameter.
     *
     * @return the shape, or {@code null} if none has been set or derived yet
     */
    public Shape getShape() {
        return this.shape;
    }

    /**
     * Returns the backing array.
     *
     * @return the array
     * @throws UninitializedParameterException if no array is present
     */
    public NDArray getArray() {
        if (!this.isInitialized()) {
            throw new UninitializedParameterException(
                    "The array for parameter \"" + this.getName() + "\" has not been initialized");
        }
        return this.array;
    }

    /**
     * Returns whether this parameter requires a gradient.
     *
     * @return {@code true} if gradients are required
     */
    public boolean requiresGradient() {
        return this.requiresGrad;
    }

    /**
     * Freezes or unfreezes this parameter.
     *
     * <p>The flag is always recorded. If an array is already present, its gradient requirement is
     * updated immediately. Otherwise {@link #initialize(NDManager, DataType)} enables gradients on
     * the array it creates only when the parameter is not frozen.
     *
     * @param freeze {@code true} to stop requiring gradients, {@code false} to require them
     */
    public void freeze(boolean freeze) {
        this.requiresGrad = !freeze;
        // The array may not exist yet; previously this threw NullPointerException.
        if (this.array != null) {
            this.array.setRequiresGradient(this.requiresGrad);
        }
    }

    /**
     * Returns whether a backing array is present.
     *
     * @return {@code true} if the parameter holds an array
     */
    public boolean isInitialized() {
        return this.array != null;
    }

    /**
     * Sets the initializer used by {@link #initialize(NDManager, DataType)}.
     *
     * @param initializer the initializer
     */
    public void setInitializer(Initializer initializer) {
        this.initializer = initializer;
    }

    /**
     * Returns the initializer used by {@link #initialize(NDManager, DataType)}.
     *
     * @return the initializer, possibly {@code null} (for example for {@link Type#OTHER})
     */
    public Initializer getInitializer() {
        return this.initializer;
    }

    /**
     * Creates the backing array if it does not exist yet, then enables gradients on it if required.
     *
     * @param manager the manager used to create the array
     * @param dataType the data type of the created array
     * @throws NullPointerException if the array must be created but no initializer or shape is set
     */
    public void initialize(NDManager manager, DataType dataType) {
        if (!this.isInitialized()) {
            Objects.requireNonNull(this.initializer, "No initializer has been set");
            Objects.requireNonNull(this.shape, "No parameter shape has been set");
            this.array = this.initializer.initialize(manager, this.shape, dataType);
            this.array.setName(this.name);
        }
        if (this.requiresGradient()) {
            this.array.setRequiresGradient(true);
        }
    }

    /**
     * Writes this parameter to a stream.
     *
     * <p>If no array is present, only the char {@code 'N'} is written. Otherwise the output is the
     * char {@code 'P'}, the version byte, the name (modified UTF-8), and the encoded array bytes.
     *
     * @param dos the output stream
     * @throws IOException if writing fails
     */
    public void save(DataOutputStream dos) throws IOException {
        if (!this.isInitialized()) {
            dos.writeChar(MAGIC_ABSENT);
            return;
        }
        dos.writeChar(MAGIC_PRESENT);
        dos.writeByte(VERSION);
        dos.writeUTF(this.getName());
        dos.write(this.array.encode());
    }

    /**
     * Reads a parameter previously written by {@link #save(DataOutputStream)}.
     *
     * <p>If the stream marks the parameter as absent, nothing is changed. Otherwise the decoded
     * array replaces the current one, and the shape is taken from it.
     *
     * @param manager the manager used to decode the array
     * @param dis the input stream
     * @throws IOException if reading fails
     * @throws MalformedModelException if the marker, version, or name does not match
     */
    public void load(NDManager manager, DataInputStream dis)
            throws IOException, MalformedModelException {
        char magic = dis.readChar();
        if (magic == MAGIC_ABSENT) {
            return;
        }
        if (magic != MAGIC_PRESENT) {
            throw new MalformedModelException("Invalid input data.");
        }
        byte version = dis.readByte();
        if (version != VERSION) {
            throw new MalformedModelException("Unsupported encoding version: " + version);
        }
        String parameterName = dis.readUTF();
        String expectedName = this.getName();
        if (!parameterName.equals(expectedName)) {
            // Report the same name that was compared (and that save() writes), never "null".
            throw new MalformedModelException(
                    "Unexpected parameter name: " + parameterName + ", expected: " + expectedName);
        }
        this.array = manager.decode(dis);
        this.shape = this.array.getShape();
    }

    /** Closes the backing array, if any, and clears it so the parameter is uninitialized again. */
    @Override
    public void close() {
        if (this.array != null) {
            this.array.close();
            this.array = null;
        }
    }

    /**
     * Creates a new {@link Builder}.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /** Builder for {@link Parameter}; gradients are required by default. */
    public static final class Builder {
        String name;
        Shape shape;
        Type type;
        Initializer initializer;
        NDArray array;
        boolean requiresGrad = true;

        /**
         * Sets the parameter name.
         *
         * @param name the name
         * @return this builder
         */
        public Builder setName(String name) {
            this.name = name;
            return this;
        }

        /**
         * Sets the parameter type, which supplies the default initializer.
         *
         * @param type the type
         * @return this builder
         */
        public Builder setType(Type type) {
            this.type = type;
            return this;
        }

        /**
         * Sets the parameter shape.
         *
         * @param shape the shape
         * @return this builder
         */
        public Builder optShape(Shape shape) {
            this.shape = shape;
            return this;
        }

        /**
         * Sets an initializer that overrides the type's default.
         *
         * @param initializer the initializer
         * @return this builder
         */
        public Builder optInitializer(Initializer initializer) {
            this.initializer = initializer;
            return this;
        }

        /**
         * Sets a pre-existing backing array.
         *
         * @param array the array
         * @return this builder
         */
        public Builder optArray(NDArray array) {
            this.array = array;
            return this;
        }

        /**
         * Sets whether the parameter requires a gradient ({@code true} by default).
         *
         * @param requiresGrad whether gradients are required
         * @return this builder
         */
        public Builder optRequiresGrad(boolean requiresGrad) {
            this.requiresGrad = requiresGrad;
            return this;
        }

        /**
         * Builds the parameter.
         *
         * @return the new parameter
         * @throws NullPointerException if neither an initializer nor a type was set
         */
        public Parameter build() {
            return new Parameter(this);
        }
    }

    /** Kinds of parameters, each with a default initializer ({@code null} for {@link #OTHER}). */
    public enum Type {
        /** Defaults to {@code XavierInitializer(GAUSSIAN, IN, 2.0f)}. */
        WEIGHT(
                new XavierInitializer(
                        XavierInitializer.RandomType.GAUSSIAN,
                        XavierInitializer.FactorType.IN,
                        2.0f)),
        /** Defaults to {@link Initializer#ZEROS}. */
        BIAS(Initializer.ZEROS),
        /** Defaults to {@link Initializer#ONES}. */
        GAMMA(Initializer.ONES),
        /** Defaults to {@link Initializer#ZEROS}. */
        BETA(Initializer.ZEROS),
        /** Defaults to {@link Initializer#ZEROS}. */
        RUNNING_MEAN(Initializer.ZEROS),
        /** Defaults to {@link Initializer#ONES}. */
        RUNNING_VAR(Initializer.ONES),
        /** Has no default initializer; one must be supplied explicitly. */
        OTHER(null);

        private final transient Initializer initializer;

        Type(Initializer initializer) {
            this.initializer = initializer;
        }

        /**
         * Returns the default initializer for this type.
         *
         * @return the default initializer, or {@code null} for {@link #OTHER}
         */
        public Initializer getInitializer() {
            return this.initializer;
        }
    }
}

