/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.pytorch.engine;

import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractSymbolBlock;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterList;
import ai.djl.pytorch.engine.PtNDManager;
import ai.djl.pytorch.jni.IValue;
import ai.djl.pytorch.jni.IValueUtils;
import ai.djl.pytorch.jni.JniUtils;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A symbol block backed by a native PyTorch module that is addressed through a raw handle.
 *
 * <p>The block owns the native module: {@link #close()} releases it through {@code JniUtils}
 * and detaches the block from the manager it was created with.
 *
 * <p>Input and output descriptions are not known up front; they are recorded during the first
 * successful call to {@code forwardInternal}.
 */
public class PtSymbolBlock
extends AbstractSymbolBlock
implements AutoCloseable {
    private static final Logger logger = LoggerFactory.getLogger(PtSymbolBlock.class);

    /** Native module handle; holds {@code null} while no module is loaded or after close. */
    private final AtomicReference<Long> handle = new AtomicReference<>();
    private String uid;
    private PtNDManager manager;
    private boolean isTrain;
    // Published from inside the synchronized first-forward section and read without a lock by
    // describeInput()/describeOutput(), hence volatile.
    private volatile PairList<String, Shape> inputDescriptions;
    private volatile PairList<String, Shape> outputDescriptions;
    // Read outside the lock in forwardInternal (check / lock / re-check), so it must be
    // volatile for that check to be reliable.
    private volatile boolean first;
    private Map<String, Parameter> parameters;

    /**
     * Creates a block around an already loaded native module and registers the block with the
     * manager under the handle's string value.
     *
     * @param manager the manager the block is attached to
     * @param handle the raw native module handle
     */
    public PtSymbolBlock(PtNDManager manager, long handle) {
        this(manager);
        this.handle.set(handle);
        this.uid = String.valueOf(handle);
        manager.attachInternal(this.uid, new AutoCloseable[]{this});
    }

    /**
     * Creates a block without a native module. A module is supplied later through
     * {@link #loadParameters(NDManager, DataInputStream)}.
     *
     * @param manager the manager stored for later use by this block
     */
    public PtSymbolBlock(PtNDManager manager) {
        this.manager = manager;
        this.isTrain = true;
        this.first = true;
    }

    /**
     * Releases the native module and detaches this block from its manager.
     *
     * <p>Calling this method more than once, or on a block that never received a handle, is a
     * no-op.
     */
    @Override
    public void close() {
        Long pointer = this.handle.getAndSet(null);
        if (pointer == null) {
            // Never loaded, or already closed.
            return;
        }
        JniUtils.deleteModule(pointer);
        // The manager reference is cleared on close; it is still null if the block was
        // reloaded and closed a second time.
        if (this.manager != null) {
            this.manager.detachInternal(this.uid);
            this.manager = null;
        }
    }

    /**
     * Runs the module on raw {@link IValue} inputs by delegating to {@code IValueUtils.forward}.
     *
     * @param inputs the input values
     * @return the value produced by the module
     */
    public IValue forward(IValue ... inputs) {
        return IValueUtils.forward(this, inputs);
    }

    /**
     * Runs a forward pass.
     *
     * <p>Before the pass, the module is switched between training and inference mode when
     * {@code training} differs from the mode used last time, and the graph executor optimizer
     * flag is applied when the {@code ai.djl.pytorch.graph_optimizer} system property is set.
     *
     * <p>The first successful pass also records the names and shapes of the inputs and outputs
     * so that {@link #describeInput()} and {@link #describeOutput()} can report them.
     *
     * @param parameterStore not used by this implementation
     * @param inputs the input arrays
     * @param training {@code true} to run in training mode, {@code false} for inference
     * @param params not used by this implementation
     * @return the output arrays
     */
    protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, PairList<String, Object> params) {
        if (this.isTrain != training) {
            this.isTrain = training;
            if (this.isTrain) {
                JniUtils.enableTrainingMode(this);
            } else {
                JniUtils.enableInferenceMode(this);
            }
        }
        if (System.getProperty("ai.djl.pytorch.graph_optimizer") != null) {
            boolean setOptimizer = Boolean.getBoolean("ai.djl.pytorch.graph_optimizer");
            JniUtils.setGraphExecutorOptimize(setOptimizer);
        }
        if (this.first) {
            synchronized (this) {
                if (this.first) {
                    return this.forwardAndRecordDescriptions(inputs, training);
                }
            }
        }
        return IValueUtils.forward(this, inputs, training);
    }

    /**
     * Runs the forward pass and records input/output names and shapes. Must be called while
     * holding this block's monitor.
     *
     * <p>The descriptions are built in local lists and only published once the forward call has
     * returned, so a failed first pass leaves no half-initialized descriptions behind and
     * readers never observe a partially filled list.
     */
    private NDList forwardAndRecordDescriptions(NDList inputs, boolean training) {
        PairList<String, Shape> recordedInputs = new PairList<>();
        for (NDArray array : inputs) {
            recordedInputs.add(array.getName(), array.getShape());
        }
        NDList outputs = IValueUtils.forward(this, inputs, training);
        PairList<String, Shape> recordedOutputs = new PairList<>();
        for (NDArray array : outputs) {
            recordedOutputs.add(array.getName(), array.getShape());
        }
        this.inputDescriptions = recordedInputs;
        this.outputDescriptions = recordedOutputs;
        this.first = false;
        return outputs;
    }

    /**
     * Returns the input names and shapes recorded during the first forward pass.
     *
     * @return the recorded input descriptions, or {@code null} (after logging a warning) when
     *     no forward pass has completed yet
     */
    public PairList<String, Shape> describeInput() {
        PairList<String, Shape> descriptions = this.inputDescriptions;
        if (descriptions == null) {
            logger.warn("Input shapes are unknown, please run predict or forward once and call describeInput again.");
        }
        return descriptions;
    }

    /**
     * Returns the parameters of the native module.
     *
     * <p>The parameters are fetched through {@code JniUtils.moduleGetParams} on the first call
     * and cached; each parameter's type is derived from its name.
     *
     * @return a new {@link ParameterList} built from the cached parameters
     */
    public ParameterList getDirectParameters() {
        if (this.parameters == null) {
            NDList params = JniUtils.moduleGetParams(this, this.manager);
            Map<String, Parameter> collected = new LinkedHashMap<String, Parameter>(params.size());
            for (NDArray param : params) {
                String name = param.getName();
                collected.put(name, Parameter.builder().setName(name).setType(PtSymbolBlock.inferType(name)).optArray(param).build());
            }
            this.parameters = collected;
        }
        return new ParameterList(this.parameters);
    }

    /**
     * Derives a parameter type from substrings of the parameter name.
     *
     * <p>The checks run in a fixed order and the first match wins, so a name that contains
     * several of the keywords is classified by the earliest check.
     *
     * @param name the parameter name
     * @return the matching type, or {@code Parameter.Type.OTHER} when nothing matches
     */
    private static Parameter.Type inferType(String name) {
        if (name.contains("bias")) {
            return Parameter.Type.BIAS;
        }
        if (name.contains("gamma")) {
            return Parameter.Type.GAMMA;
        }
        if (name.contains("beta")) {
            return Parameter.Type.BETA;
        }
        if (name.contains("moving_mean") || name.contains("running_mean")) {
            return Parameter.Type.RUNNING_MEAN;
        }
        if (name.contains("moving_var") || name.contains("running_var")) {
            return Parameter.Type.RUNNING_VAR;
        }
        if (name.contains("weight")) {
            return Parameter.Type.WEIGHT;
        }
        return Parameter.Type.OTHER;
    }

    /**
     * Returns the output names and shapes recorded during the first forward pass.
     *
     * @return the recorded output descriptions, or {@code null} (after logging a warning) when
     *     no forward pass has completed yet
     */
    public PairList<String, Shape> describeOutput() {
        PairList<String, Shape> descriptions = this.outputDescriptions;
        if (descriptions == null) {
            logger.warn("Output shapes are unknown, please run predict or forward once and call describeOutput again.");
        }
        return descriptions;
    }

    /**
     * Computes output shapes by running an inference forward pass on arrays of ones that have
     * the given shapes. The arrays live in a temporary manager that is closed before returning.
     *
     * <p>Because this goes through {@code forwardInternal}, it counts as the first forward pass
     * if none has happened yet.
     *
     * @param inputShapes the shapes of the inputs
     * @return the shapes of the outputs
     */
    public Shape[] getOutputShapes(Shape[] inputShapes) {
        try (NDManager tempManager = NDManager.newBaseManager()) {
            NDList dummyInputs = new NDList();
            for (Shape shape : inputShapes) {
                dummyInputs.add(tempManager.ones(shape));
            }
            return this.forwardForShapes(tempManager, dummyInputs);
        }
    }

    /**
     * Computes output shapes by running an inference forward pass on arrays of ones that have
     * the given shapes and data types. The arrays live in a temporary manager that is closed
     * before returning.
     *
     * <p>Because this goes through {@code forwardInternal}, it counts as the first forward pass
     * if none has happened yet.
     *
     * @param inputShapes the shapes of the inputs
     * @param dataTypes the data type of each input, index-aligned with {@code inputShapes};
     *     {@code null} selects {@code DataType.FLOAT32} for every input
     * @return the shapes of the outputs
     */
    public Shape[] getOutputShapes(Shape[] inputShapes, DataType[] dataTypes) {
        try (NDManager tempManager = NDManager.newBaseManager()) {
            NDList dummyInputs = new NDList();
            for (int i = 0; i < inputShapes.length; ++i) {
                DataType dataType = dataTypes == null ? DataType.FLOAT32 : dataTypes[i];
                dummyInputs.add(tempManager.ones(inputShapes[i], dataType));
            }
            return this.forwardForShapes(tempManager, dummyInputs);
        }
    }

    /** Runs an inference forward pass on the given inputs and collects the output shapes. */
    private Shape[] forwardForShapes(NDManager tempManager, NDList dummyInputs) {
        NDList result = this.forwardInternal(new ParameterStore(tempManager, false), dummyInputs, false, null);
        return result.stream().map(NDArray::getShape).toArray(Shape[]::new);
    }

    /**
     * Writes the encoding version byte followed by the module, as serialized by
     * {@code JniUtils.writeModule}.
     *
     * @param os the stream to write to
     * @throws IOException if writing fails
     */
    public void saveParameters(DataOutputStream os) throws IOException {
        os.writeByte(this.version);
        JniUtils.writeModule(this, os, true);
    }

    /**
     * Reads the encoding version byte, loads a native module from the stream and makes it the
     * module of this block. The block is attached to the given manager under the new handle's
     * string value.
     *
     * @param manager the manager that supplies the device and receives the attachment
     * @param is the stream to read from
     * @throws IOException if reading fails
     * @throws MalformedModelException if the version byte differs from this block's version
     */
    public void loadParameters(NDManager manager, DataInputStream is) throws IOException, MalformedModelException {
        byte loadVersion = is.readByte();
        if (loadVersion != this.version) {
            throw new MalformedModelException("Unsupported encoding version: " + loadVersion);
        }
        long rawHandle = JniUtils.loadModuleHandle(is, manager.getDevice(), true, true);
        this.handle.set(rawHandle);
        this.uid = String.valueOf(rawHandle);
        manager.attachInternal(this.uid, new AutoCloseable[]{this});
    }

    /**
     * Returns the native module handle.
     *
     * @return the handle
     * @throws IllegalStateException if the block currently holds no handle, either because it
     *     was closed or because no module has been loaded yet
     */
    public Long getHandle() {
        Long reference = this.handle.get();
        if (reference == null) {
            throw new IllegalStateException("PyTorch model handle has been released!");
        }
        return reference;
    }
}

