/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.mxnet.engine;

import ai.djl.MalformedModelException;
import ai.djl.mxnet.engine.CachedOp;
import ai.djl.mxnet.engine.MxNDManager;
import ai.djl.mxnet.engine.Symbol;
import ai.djl.mxnet.jna.JnaUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractSymbolBlock;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterList;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * An MXNet-engine block whose computation graph is defined by a {@link Symbol}.
 *
 * <p>One {@link Parameter} is created for every name reported by the symbol. Names registered
 * as inputs via {@link #setInputNames(List)} are excluded from the direct parameters. The
 * forward pass runs through a {@link CachedOp} that is created lazily on the first forward call
 * and re-created whenever the underlying symbol is replaced.
 */
public class MxSymbolBlock extends AbstractSymbolBlock {

    private static final Logger logger = LoggerFactory.getLogger(MxSymbolBlock.class);

    /** Encoding version written by {@link #saveParameters(DataOutputStream)}. */
    private static final byte VERSION = 3;

    private NDManager manager;
    private CachedOp op;
    private Symbol symbol;
    /** One parameter per symbol name (input names included), in symbol order. */
    private List<Parameter> mxNetParams;
    /** {@link #mxNetParams} minus the input names, keyed by name; null until input names are set. */
    private Map<String, Parameter> parameters;
    /** Shapes inferred by the symbol for {@link #outputShapesInputs}; null until first inferred. */
    private Map<String, Shape> paramShapes;
    private Shape[] outputShapes;
    /** Copy of the input shapes that {@link #outputShapes} was inferred for. */
    private Shape[] outputShapesInputs;
    private PairList<String, Shape> inputDescriptions;
    private PairList<String, Shape> outputDescriptions;
    /**
     * True until the cached op has been created for the current symbol. Volatile because it is
     * read outside the lock in {@code forwardInternal}; the volatile write of {@code false} also
     * publishes {@link #op} to other threads.
     */
    private volatile boolean first;

    /**
     * Creates a block backed by the given symbol.
     *
     * @param manager the manager used to create the cached op and to load parameters
     * @param symbol the symbol defining the computation graph
     */
    public MxSymbolBlock(NDManager manager, Symbol symbol) {
        this(manager);
        this.symbol = symbol;
        initBlock();
    }

    /**
     * Creates a block without a symbol. A symbol must then be supplied by {@link
     * #loadParameters(NDManager, DataInputStream)} using the current encoding version.
     *
     * @param manager the manager used to create the cached op and to load parameters
     */
    public MxSymbolBlock(NDManager manager) {
        super(VERSION);
        this.manager = manager;
    }

    /**
     * Sets the names of the block's data inputs and rebuilds the direct-parameter map. The
     * rebuilt map contains every symbol parameter except those inputs.
     *
     * @param inputNames the input names
     */
    public void setInputNames(List<String> inputNames) {
        this.inputNames = inputNames;
        HashSet<String> inputLookup = new HashSet<>(inputNames);
        this.parameters = new LinkedHashMap<>(mxNetParams.size());
        for (Parameter parameter : mxNetParams) {
            if (!inputLookup.contains(parameter.getName())) {
                parameters.put(parameter.getName(), parameter);
            }
        }
    }

    /**
     * Returns every parameter derived from the symbol, including input placeholders.
     *
     * @return the internal (mutable) parameter list
     */
    public List<Parameter> getAllParameters() {
        return mxNetParams;
    }

    /**
     * Returns the layer names reported by the underlying symbol.
     *
     * @return the layer names
     */
    public List<String> getLayerNames() {
        return symbol.getLayerNames();
    }

    /**
     * Returns the underlying symbol.
     *
     * @return the symbol
     */
    public Symbol getSymbol() {
        return symbol;
    }

    /**
     * Replaces the symbol with a version optimized for the given backend on the manager's
     * device, then closes the old symbol.
     *
     * @param optimization the optimization backend name passed to {@link Symbol#optimizeFor}
     */
    public void optimizeFor(String optimization) {
        Symbol newSymbol = symbol.optimizeFor(optimization, manager.getDevice());
        symbol.close();
        symbol = newSymbol;
        // Any cached op was built from the old symbol; rebuild it on the next forward pass.
        invalidateSymbolCaches();
    }

    /**
     * Describes the block inputs. Before the first forward pass the shapes are unknown, and
     * each input is reported with an empty shape.
     *
     * @return input names paired with their shapes
     */
    public PairList<String, Shape> describeInput() {
        if (inputDescriptions == null) {
            inputDescriptions = new PairList<>();
            if (!inputNames.isEmpty()) {
                // Warn once rather than once per input.
                logger.warn("Input shapes are unknown, please run predict or forward once and call describeInput again.");
            }
            for (String name : inputNames) {
                inputDescriptions.add(name, new Shape(new long[0]));
            }
        }
        return inputDescriptions;
    }

    /**
     * Returns the parameters owned directly by this block (symbol parameters minus inputs).
     *
     * @return the direct parameters
     */
    public ParameterList getDirectParameters() {
        return new ParameterList(parameters);
    }

    /**
     * Describes the block outputs as recorded during the first forward pass.
     *
     * @return output names paired with their shapes, or {@code null} if no forward pass has run
     *     for the current symbol
     */
    public PairList<String, Shape> describeOutput() {
        if (outputDescriptions == null) {
            logger.warn("Output shapes are unknown, please run predict or forward once and call describeOutput again.");
        }
        return outputDescriptions;
    }

    /**
     * Runs the forward pass through the cached op.
     *
     * <p>On the first call the cached op is created under a lock. The same call records the
     * input and output descriptions.
     */
    protected NDList forwardInternal(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        if (first) {
            synchronized (this) {
                // Re-check under the lock so only one thread creates the cached op.
                if (first) {
                    op = JnaUtils.createCachedOp(this, (MxNDManager) manager, training);
                    inputDescriptions = new PairList<>();
                    outputDescriptions = new PairList<>();
                    for (NDArray array : inputs) {
                        inputDescriptions.add(array.getName(), array.getShape());
                    }
                    NDList outputs = op.forward(parameterStore, inputs, training);
                    for (NDArray array : outputs) {
                        outputDescriptions.add(array.getName(), array.getShape());
                    }
                    // Volatile write: publishes op to threads that skip the lock.
                    first = false;
                    return outputs;
                }
            }
        }
        return op.forward(parameterStore, inputs, training);
    }

    /**
     * Infers the shapes of the symbol's outputs for the given input shapes. Results are cached
     * and reused only while the same input shapes are requested.
     *
     * @param inputShapes shapes of the inputs, in the order of the input names
     * @return the inferred output shapes
     */
    public Shape[] getOutputShapes(Shape[] inputShapes) {
        if (outputShapes == null || !Arrays.equals(outputShapesInputs, inputShapes)) {
            // Parameter shapes inferred for other input shapes no longer apply.
            paramShapes = null;
            String[] outputNames = symbol.getOutputNames();
            Shape[] shapes = new Shape[outputNames.length];
            for (int i = 0; i < shapes.length; ++i) {
                shapes[i] = getParameterShape(outputNames[i], inputShapes);
            }
            outputShapes = shapes;
            // Defensive copy so later mutation of the caller's array cannot corrupt the cache key.
            outputShapesInputs = inputShapes == null ? null : inputShapes.clone();
        }
        return outputShapes;
    }

    /**
     * Removes the final layer by slicing the symbol at its second-to-last layer name.
     *
     * <p>Parameters that no longer appear in the sliced symbol are closed and dropped.
     *
     * @throws IllegalStateException if the symbol has fewer than two layers
     */
    public void removeLastBlock() {
        List<String> layerNames = getLayerNames();
        if (layerNames.size() < 2) {
            throw new IllegalStateException(
                    "Cannot remove the last block of a symbol with " + layerNames.size() + " layer(s)");
        }
        String layerName = layerNames.get(layerNames.size() - 2);
        Symbol sliced = symbol.get(layerName);
        symbol.close();
        symbol = sliced;
        invalidateSymbolCaches();
        // Output descriptions were recorded for the old graph; re-record on the next forward.
        outputDescriptions = null;

        HashSet<String> remaining = new HashSet<>(Arrays.asList(symbol.getAllNames()));
        // Iterate backwards so removal by index does not shift unvisited elements.
        for (int i = mxNetParams.size() - 1; i >= 0; --i) {
            Parameter parameter = mxNetParams.get(i);
            if (remaining.contains(parameter.getName())) {
                continue;
            }
            mxNetParams.remove(i).close();
            if (parameters != null) {
                parameters.remove(parameter.getName(), parameter);
            }
        }
    }

    /**
     * Drops everything derived from the current symbol. Shape caches are recomputed on demand,
     * and the cached op is re-created on the next forward pass.
     */
    private void invalidateSymbolCaches() {
        outputShapes = null;
        outputShapesInputs = null;
        paramShapes = null;
        first = true;
    }

    /**
     * Looks up the inferred shape of a named symbol entry. Shapes for all entries are inferred
     * once from the input shapes and then cached.
     *
     * @param name the entry name
     * @param inputShapes shapes of the inputs, in the order of the input names
     * @return the inferred shape
     * @throws IllegalArgumentException if too few input shapes are given or the name is unknown
     */
    private Shape getParameterShape(String name, Shape[] inputShapes) {
        if (paramShapes == null) {
            int expected = inputNames.size();
            int provided = inputShapes == null ? 0 : inputShapes.length;
            if (provided < expected) {
                throw new IllegalArgumentException(
                        "Expected " + expected + " input shapes but got " + provided);
            }
            PairList<String, Shape> pairs = new PairList<>();
            for (int i = 0; i < expected; ++i) {
                pairs.add(inputNames.get(i), inputShapes[i]);
            }
            paramShapes = symbol.inferShape(pairs);
        }
        if (paramShapes.containsKey(name)) {
            return paramShapes.get(name);
        }
        throw new IllegalArgumentException("Name " + name + " not found");
    }

    /**
     * Serializes the block. The stream receives, in order:
     *
     * <ol>
     *   <li>the version byte;
     *   <li>the symbol JSON as an int length followed by UTF-8 bytes;
     *   <li>the input-name count and each name via {@link DataOutputStream#writeUTF};
     *   <li>every parameter in symbol order.
     * </ol>
     *
     * @param os the destination stream
     * @throws IOException if writing fails
     */
    public void saveParameters(DataOutputStream os) throws IOException {
        os.writeByte(VERSION);
        byte[] json = symbol.toJsonString().getBytes(StandardCharsets.UTF_8);
        os.writeInt(json.length);
        os.write(json);
        os.writeInt(inputNames.size());
        for (String name : inputNames) {
            os.writeUTF(name);
        }
        for (Parameter parameter : mxNetParams) {
            parameter.save(os);
        }
    }

    /**
     * Restores the block from the format written by {@link #saveParameters(DataOutputStream)}.
     *
     * <p>Older encodings carry no symbol, so the block must already have one.
     *
     * @param manager the manager used to load the symbol JSON
     * @param is the source stream
     * @throws IOException if reading fails
     * @throws MalformedModelException if the version is unsupported or the symbol data is
     *     truncated or invalid
     */
    public void loadParameters(NDManager manager, DataInputStream is)
            throws IOException, MalformedModelException {
        byte version = is.readByte();
        if (version > VERSION) {
            throw new MalformedModelException("Unsupported encoding version: " + version);
        }
        if (version < VERSION && symbol == null) {
            throw new IllegalStateException("Symbol is required for version 2, please use Model to load");
        }
        if (version == VERSION) {
            int len = is.readInt();
            if (len < 0) {
                throw new MalformedModelException("Invalid symbol length: " + len);
            }
            byte[] bytes = new byte[len];
            try {
                // read(byte[]) may return fewer bytes than requested; readFully does not.
                is.readFully(bytes);
            } catch (EOFException e) {
                throw new MalformedModelException("InputStream ends at symbol loading!", e);
            }
            symbol = Symbol.loadJson((MxNDManager) manager, new String(bytes, StandardCharsets.UTF_8));
            initBlock();
        }
        int size = is.readInt();
        for (int i = 0; i < size; ++i) {
            inputNames.add(is.readUTF());
        }
        // Parameters are loaded with this block's own manager.
        for (Parameter parameter : mxNetParams) {
            parameter.load(this.manager, is);
        }
        setInputNames(inputNames);
    }

    /**
     * Rebuilds per-symbol state: clears the input names and creates one parameter per symbol
     * name. Names listed as auxiliary are created without gradients. Symbol-derived caches are
     * also reset.
     */
    private void initBlock() {
        inputNames = new ArrayList<>();
        String[] allNames = symbol.getAllNames();
        mxNetParams = new ArrayList<>(allNames.length);
        HashSet<String> auxNames = new HashSet<>(Arrays.asList(symbol.getAuxNames()));
        for (String name : allNames) {
            mxNetParams.add(
                    Parameter.builder()
                            .setName(name)
                            .setType(inferType(name))
                            .optRequiresGrad(!auxNames.contains(name))
                            .build());
        }
        invalidateSymbolCaches();
    }

    /**
     * Guesses a parameter's type from the suffix of its name.
     *
     * @param name the parameter name
     * @return the matching type, or {@link Parameter.Type#OTHER} if no suffix matches
     */
    private static Parameter.Type inferType(String name) {
        if (name.endsWith("bias")) {
            return Parameter.Type.BIAS;
        }
        if (name.endsWith("gamma")) {
            return Parameter.Type.GAMMA;
        }
        if (name.endsWith("beta")) {
            return Parameter.Type.BETA;
        }
        if (name.endsWith("moving_mean") || name.endsWith("running_mean")) {
            return Parameter.Type.RUNNING_MEAN;
        }
        if (name.endsWith("moving_var") || name.endsWith("running_var")) {
            return Parameter.Type.RUNNING_VAR;
        }
        if (name.endsWith("weight")) {
            return Parameter.Type.WEIGHT;
        }
        return Parameter.Type.OTHER;
    }
}

