/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.imports.graphmapper.onnx;

import com.github.os72.protobuf351.ByteString;
import com.github.os72.protobuf351.Message;
import com.google.common.primitives.Floats;
import com.google.common.primitives.Ints;
import com.google.common.primitives.Longs;
import java.io.BufferedInputStream;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Field;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import onnx.OnnxProto3;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
import org.nd4j.imports.descriptors.properties.AttributeAdapter;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.imports.graphmapper.BaseGraphMapper;
import org.nd4j.imports.graphmapper.ImportState;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;

/**
 * Maps ONNX protobuf graphs ({@link OnnxProto3.GraphProto}) onto nd4j {@link SameDiff} graphs.
 *
 * <p>Use the shared instance returned by {@link #getInstance()}.
 */
public class OnnxGraphMapper
extends BaseGraphMapper<OnnxProto3.GraphProto, OnnxProto3.NodeProto, OnnxProto3.AttributeProto, OnnxProto3.TypeProto.Tensor> {
    /** Shared instance returned by {@link #getInstance()}. */
    private static final OnnxGraphMapper INSTANCE = new OnnxGraphMapper();

    /**
     * Returns the shared mapper instance.
     *
     * @return the singleton {@code OnnxGraphMapper}
     */
    public static OnnxGraphMapper getInstance() {
        return INSTANCE;
    }

    /**
     * Parses an ONNX {@code ModelProto} from {@code inputFile} and appends the text form of every
     * graph node, each followed by a newline, to {@code outputFile}.
     *
     * <p>The caller keeps ownership of {@code inputFile}; it is not closed here. I/O errors are
     * printed and otherwise ignored (best-effort debugging helper).
     *
     * @param inputFile  stream containing a binary ONNX model
     * @param outputFile file the node text is appended to
     */
    @Override
    public void dumpBinaryProtoAsText(InputStream inputFile, File outputFile) {
        try {
            appendNodesAsText(OnnxProto3.ModelProto.parseFrom(inputFile), outputFile);
        }
        catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Copies each mapped ONNX attribute of {@code node} into the corresponding field of {@code on}.
     *
     * <p>For every property mapping registered under {@code mappedTfName}, the attribute named by
     * the mapping's TF attribute name is looked up in {@code attributesForNode}. Its value is routed
     * through the op's {@link AttributeAdapter} when one exists, and set directly otherwise.
     * Mappings without an attribute name, without a matching field or without a matching attribute
     * are skipped, as are empty INTS/FLOATS lists.
     *
     * @param mappedTfName      key into {@link DifferentialFunction#mappingsForFunction()}
     * @param on                function whose fields are populated
     * @param attributesForNode attributes of {@code node} keyed by name
     * @param node              ONNX node being imported (currently unused)
     * @param graph             enclosing graph (currently unused)
     */
    public void initFunctionFromProperties(String mappedTfName, DifferentialFunction on, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.NodeProto node, OnnxProto3.GraphProto graph) {
        Map<String, Map<String, PropertyMapping>> properties = on.mappingsForFunction();
        Map<String, PropertyMapping> propertyMappings = properties == null ? null : properties.get(mappedTfName);
        if (propertyMappings == null || propertyMappings.isEmpty()) {
            // Nothing is mapped for this name, so there is nothing to initialize.
            return;
        }
        Map<String, Field> fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on);
        Map<String, Map<String, AttributeAdapter>> attributeAdapters = on.attributeAdaptersForFunction();
        for (Map.Entry<String, PropertyMapping> entry : propertyMappings.entrySet()) {
            // NOTE(review): the ONNX mapper keys on the TF attribute name; confirm that is intended
            // rather than PropertyMapping#getOnnxAttrName().
            String attrName = entry.getValue().getTfAttrName();
            Field currentField = fields.get(entry.getKey());
            if (attrName == null || currentField == null || !attributesForNode.containsKey(attrName)) continue;
            OnnxProto3.AttributeProto attr = attributesForNode.get(attrName);
            Object value;
            switch (attr.getType()) {
                case STRING: {
                    value = attr.getS().toStringUtf8();
                    break;
                }
                case INT: {
                    value = (int)attr.getI();
                    break;
                }
                case INTS: {
                    List<Long> ints = attr.getIntsList();
                    if (ints.isEmpty()) continue;
                    value = Ints.toArray(ints);
                    break;
                }
                case FLOATS: {
                    List<Float> floats = attr.getFloatsList();
                    if (floats.isEmpty()) continue;
                    value = Floats.toArray(floats);
                    break;
                }
                case TENSOR: {
                    value = this.mapTensorProto(attr.getT());
                    break;
                }
                default: {
                    // Other attribute kinds (including single FLOAT) are not mapped here.
                    continue;
                }
            }
            AttributeAdapter adapter = null;
            if (attributeAdapters != null && !attributeAdapters.isEmpty()) {
                Map<String, AttributeAdapter> adaptersForOp = attributeAdapters.get(on.tensorflowName());
                if (adaptersForOp != null) {
                    adapter = adaptersForOp.get(entry.getKey());
                }
            }
            if (adapter != null) {
                adapter.mapAttributeFor(value, currentField, on);
            } else {
                on.setValueFor(currentField, value);
            }
        }
    }

    @Override
    public boolean isOpIgnoreException(OnnxProto3.NodeProto node) {
        return false;
    }

    /** Returns the function's own op name as the mapping key. */
    @Override
    public String getTargetMappingForOp(DifferentialFunction function, OnnxProto3.NodeProto node) {
        return function.opName();
    }

    /**
     * Sets a single mapped property on {@code on}.
     *
     * <p>When the mapping has no TF attribute name, the value is read from the node input at the
     * mapping's input position (negative positions count from the end). The value becomes an
     * {@code int[]}, the first element as an integer, or the first element as a floating-point
     * number, depending on the field type. Otherwise, the ONNX attribute named by the mapping is
     * read (INT, FLOAT, or a STRING holding a number). Numeric values are converted to the field's
     * declared type before assignment. If the mapping, field, input array or attribute is missing,
     * nothing is set.
     */
    @Override
    public void mapProperty(String name, DifferentialFunction on, OnnxProto3.NodeProto node, OnnxProto3.GraphProto graph, SameDiff sameDiff, Map<String, Map<String, PropertyMapping>> propertyMappingsForFunction) {
        Map<String, PropertyMapping> mappingsForName = propertyMappingsForFunction.get(name);
        PropertyMapping mapping = mappingsForName == null ? null : mappingsForName.get(this.getTargetMappingForOp(on, node));
        if (mapping == null) return;
        Map<String, Field> fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on);
        Field field = fields.get(mapping.getPropertyNames()[0]);
        if (field == null) return;
        Class<?> type = field.getType();
        Object value;
        // NOTE(review): this branch tests the TF attribute name while the else-branch reads the
        // ONNX attribute name; confirm which one should decide input-vs-attribute mapping.
        if (mapping.getTfAttrName() == null) {
            int inputIdx = mapping.getTfInputPosition();
            if (inputIdx < 0) {
                inputIdx += node.getInputCount();
            }
            INDArray arr = sameDiff.getArrForVarName(node.getInput(inputIdx));
            if (arr == null) return;
            if (type.equals(int[].class)) {
                value = arr.data().asInt();
            } else if (type.equals(Integer.TYPE) || type.equals(Long.TYPE) || type.equals(Long.class) || type.equals(Integer.class)) {
                value = arr.getInt(0);
            } else if (type.equals(Float.TYPE) || type.equals(Double.TYPE) || type.equals(Float.class) || type.equals(Double.class)) {
                value = arr.getDouble(0L);
            } else {
                return;
            }
        } else {
            OnnxProto3.AttributeProto attr = this.getAttrMap(node).get(mapping.getOnnxAttrName());
            if (attr == null) return;
            switch (attr.getType()) {
                case INT: {
                    value = attr.getI();
                    break;
                }
                case FLOAT: {
                    value = attr.getF();
                    break;
                }
                case STRING: {
                    // A STRING attribute carries its payload in `s`, not `f`; accept it only when
                    // it parses as a number and leave the field untouched otherwise.
                    try {
                        value = Double.valueOf(attr.getS().toStringUtf8().trim());
                    }
                    catch (NumberFormatException e) {
                        return;
                    }
                    break;
                }
                default: {
                    return;
                }
            }
        }
        try {
            field.setAccessible(true);
            // Field.set takes the target object first, then the value.
            field.set(on, OnnxGraphMapper.coerceToFieldType(value, type));
        }
        catch (IllegalAccessException e) {
            e.printStackTrace();
        }
    }

    /**
     * Converts a boxed number to the boxed form of {@code type} so reflective assignment does not
     * fail on narrowing or boxed-type mismatches. Non-numbers and non-numeric targets pass through.
     */
    private static Object coerceToFieldType(Object value, Class<?> type) {
        if (!(value instanceof Number)) return value;
        Number number = (Number)value;
        if (type == Integer.TYPE || type == Integer.class) return number.intValue();
        if (type == Long.TYPE || type == Long.class) return number.longValue();
        if (type == Float.TYPE || type == Float.class) return number.floatValue();
        if (type == Double.TYPE || type == Double.class) return number.doubleValue();
        return value;
    }

    /**
     * Returns the first node in {@code graph} whose name equals {@code name}, or {@code null}.
     */
    @Override
    public OnnxProto3.NodeProto getNodeWithNameFromGraph(OnnxProto3.GraphProto graph, String name) {
        for (OnnxProto3.NodeProto node : graph.getNodeList()) {
            if (node.getName().equals(name)) {
                return node;
            }
        }
        return null;
    }

    @Override
    public boolean isPlaceHolderNode(OnnxProto3.TypeProto.Tensor node) {
        return false;
    }

    /**
     * Reads a binary ONNX {@code ModelProto} from {@code inputFile} and appends the text form of
     * every graph node, each followed by a newline, to {@code outputFile}. I/O errors are printed
     * and otherwise ignored (best-effort debugging helper).
     *
     * @param inputFile  file containing a binary ONNX model
     * @param outputFile file the node text is appended to
     */
    @Override
    public void dumpBinaryProtoAsText(File inputFile, File outputFile) {
        try (InputStream in = new BufferedInputStream(new FileInputStream(inputFile))) {
            appendNodesAsText(OnnxProto3.ModelProto.parseFrom(in), outputFile);
        }
        catch (IOException e) {
            e.printStackTrace();
        }
    }

    /** Appends {@code node.toString() + "\n"} for each graph node of {@code model}; always closes the writer. */
    private static void appendNodesAsText(OnnxProto3.ModelProto model, File outputFile) throws IOException {
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(outputFile, true))) {
            for (OnnxProto3.NodeProto node : model.getGraph().getNodeList()) {
                writer.write(node.toString() + "\n");
            }
        }
    }

    /** Looks up the function registered for the ONNX op type {@code name}. */
    @Override
    public DifferentialFunction getMappedOp(String name) {
        return DifferentialFunctionClassHolder.getInstance().getOpWithOnnxName(name);
    }

    /**
     * Collects tensor type info for every name referenced by the graph.
     *
     * <p>Graph inputs and outputs use their declared tensor types. Node names (or the node index
     * for unnamed nodes) and any node input/output not already present get a placeholder tensor
     * from {@link #addDummyTensor}.
     */
    @Override
    public Map<String, OnnxProto3.TypeProto.Tensor> variablesForGraph(OnnxProto3.GraphProto graphProto) {
        Map<String, OnnxProto3.TypeProto.Tensor> ret = new HashMap<String, OnnxProto3.TypeProto.Tensor>();
        for (int i = 0; i < graphProto.getInputCount(); ++i) {
            ret.put(graphProto.getInput(i).getName(), graphProto.getInput(i).getType().getTensorType());
        }
        for (int i = 0; i < graphProto.getOutputCount(); ++i) {
            ret.put(graphProto.getOutput(i).getName(), graphProto.getOutput(i).getType().getTensorType());
        }
        for (int i = 0; i < graphProto.getNodeCount(); ++i) {
            OnnxProto3.NodeProto node = graphProto.getNode(i);
            String name = node.getName().isEmpty() ? String.valueOf(i) : node.getName();
            this.addDummyTensorIfAbsent(name, ret);
            for (String input : node.getInputList()) {
                this.addDummyTensorIfAbsent(input, ret);
            }
            for (String output : node.getOutputList()) {
                this.addDummyTensorIfAbsent(output, ret);
            }
        }
        return ret;
    }

    /** Calls {@link #addDummyTensor} only when {@code name} is not yet a key of {@code to}. */
    private void addDummyTensorIfAbsent(String name, Map<String, OnnxProto3.TypeProto.Tensor> to) {
        if (!to.containsKey(name)) {
            this.addDummyTensor(name, to);
        }
    }

    @Override
    public String translateToSameDiffName(String name, OnnxProto3.NodeProto node) {
        return null;
    }

    /**
     * Registers a placeholder tensor type under {@code name}: a rank-2 shape whose dims are both -1.
     */
    protected void addDummyTensor(String name, Map<String, OnnxProto3.TypeProto.Tensor> to) {
        OnnxProto3.TensorShapeProto.Dimension dim = OnnxProto3.TensorShapeProto.Dimension.newBuilder().setDimValue(-1L).build();
        OnnxProto3.TypeProto.Tensor typeProto = OnnxProto3.TypeProto.Tensor.newBuilder().setShape(OnnxProto3.TensorShapeProto.newBuilder().addDim(dim).addDim(dim).build()).build();
        to.put(name, typeProto);
    }

    @Override
    public Message.Builder getNewGraphBuilder() {
        return OnnxProto3.GraphProto.newBuilder();
    }

    /** Parses a serialized {@code ModelProto} and returns its graph. */
    @Override
    public OnnxProto3.GraphProto parseGraphFrom(byte[] inputStream) throws IOException {
        return OnnxProto3.ModelProto.parseFrom(inputStream).getGraph();
    }

    /** Parses a serialized {@code ModelProto} from the stream and returns its graph. */
    @Override
    public OnnxProto3.GraphProto parseGraphFrom(InputStream inputStream) throws IOException {
        return OnnxProto3.ModelProto.parseFrom(inputStream).getGraph();
    }

    /**
     * Instantiates the function registered for the node's op type, initializes it from the node,
     * and registers it with the import's {@link SameDiff}.
     *
     * <p>Unnamed nodes are registered under their index in the graph, matching
     * {@link #variablesForGraph}.
     *
     * @throws NoOpNameFoundException if no function is registered for the node's op type
     */
    @Override
    public void mapNodeType(OnnxProto3.NodeProto tfNode, ImportState<OnnxProto3.GraphProto, OnnxProto3.TypeProto.Tensor> importState) {
        DifferentialFunction differentialFunction = DifferentialFunctionClassHolder.getInstance().getOpWithOnnxName(tfNode.getOpType());
        if (differentialFunction == null) {
            throw new NoOpNameFoundException("No op name found " + tfNode.getOpType());
        }
        SameDiff diff = importState.getSameDiff();
        int idx = importState.getGraph().getNodeList().indexOf(tfNode);
        String name = !tfNode.getName().isEmpty() ? tfNode.getName() : String.valueOf(idx);
        try {
            DifferentialFunction newInstance = (DifferentialFunction)differentialFunction.getClass().newInstance();
            newInstance.setSameDiff(diff);
            newInstance.initFromOnnx(tfNode, diff, this.getAttrMap(tfNode), importState.getGraph());
            diff.putFunctionForId(newInstance.getOwnName(), newInstance);
            // Register under the resolved name so unnamed nodes do not all collide on "".
            diff.setBaseNameForFunctionInstanceId(name, newInstance);
            diff.addVarNameForImport(name);
        }
        catch (Exception e) {
            // Best-effort import: a failing node is reported but does not abort the whole import.
            e.printStackTrace();
        }
    }

    @Override
    public DataBuffer.Type dataTypeForTensor(OnnxProto3.TypeProto.Tensor tensorProto) {
        return this.nd4jTypeFromOnnxType(tensorProto.getElemType());
    }

    @Override
    public boolean unknownTypeNodeImportable(OnnxProto3.TypeProto.Tensor tensor) {
        return false;
    }

    /**
     * Maps an ONNX element type to an nd4j buffer type; unsupported types map to
     * {@link DataBuffer.Type#UNKNOWN}.
     */
    public DataBuffer.Type nd4jTypeFromOnnxType(OnnxProto3.TensorProto.DataType dataType) {
        switch (dataType) {
            case DOUBLE: {
                return DataBuffer.Type.DOUBLE;
            }
            case FLOAT: {
                return DataBuffer.Type.FLOAT;
            }
            case FLOAT16: {
                return DataBuffer.Type.HALF;
            }
            // NOTE(review): INT64 raw data has 8-byte elements but is mapped to a 4-byte INT buffer;
            // confirm whether a LONG buffer type should be used here.
            case INT32: 
            case INT64: {
                return DataBuffer.Type.INT;
            }
        }
        return DataBuffer.Type.UNKNOWN;
    }

    /**
     * Returns the UTF-8 string value of the attribute named {@code key}.
     *
     * @throws ND4JIllegalStateException if the node has no such attribute
     */
    @Override
    public String getAttrValueFromNode(OnnxProto3.NodeProto nodeProto, String key) {
        for (OnnxProto3.AttributeProto attributeProto : nodeProto.getAttributeList()) {
            if (attributeProto.getName().equals(key)) {
                // ByteString#toString() is a debug description, not the content.
                return attributeProto.getS().toStringUtf8();
            }
        }
        throw new ND4JIllegalStateException("No key found for " + key);
    }

    /** Returns the dims of the tensor held by the attribute. */
    @Override
    public long[] getShapeFromAttribute(OnnxProto3.AttributeProto attributeProto) {
        return Longs.toArray(attributeProto.getT().getDimsList());
    }

    @Override
    public boolean isPlaceHolder(OnnxProto3.TypeProto.Tensor nodeType) {
        return false;
    }

    @Override
    public boolean isConstant(OnnxProto3.TypeProto.Tensor nodeType) {
        return false;
    }

    /**
     * Builds an array from the graph initializer named {@code tensorName}, using the element type
     * and shape declared by {@code tensorProto}.
     *
     * @return the array, or {@code null} if no initializer has that name
     * @throws ND4JIllegalStateException if {@code tensorProto} is not initialized
     */
    @Override
    public INDArray getNDArrayFromTensor(String tensorName, OnnxProto3.TypeProto.Tensor tensorProto, OnnxProto3.GraphProto graph) {
        DataBuffer.Type type = this.dataTypeForTensor(tensorProto);
        if (!tensorProto.isInitialized()) {
            throw new ND4JIllegalStateException("Unable to retrieve ndarray. Tensor was not initialized");
        }
        OnnxProto3.TensorProto tensor = null;
        for (OnnxProto3.TensorProto initializer : graph.getInitializerList()) {
            if (initializer.getName().equals(tensorName)) {
                tensor = initializer;
                break;
            }
        }
        if (tensor == null) {
            return null;
        }
        return OnnxGraphMapper.arrayFromRawData(tensor.getRawData(), type, this.getShapeFromTensor(tensorProto));
    }

    /**
     * Converts an ONNX {@code TensorProto} to an array using its own element type, dims and
     * {@code raw_data}.
     *
     * @return the array, or {@code null} if {@code tensor} is {@code null}
     */
    public INDArray mapTensorProto(OnnxProto3.TensorProto tensor) {
        if (tensor == null) {
            return null;
        }
        DataBuffer.Type type = this.nd4jTypeFromOnnxType(tensor.getDataType());
        return OnnxGraphMapper.arrayFromRawData(tensor.getRawData(), type, this.getShapeFromTensor(tensor));
    }

    /**
     * Copies {@code rawData} into a direct native-order buffer and wraps it as an array of
     * {@code shape}.
     *
     * <p>NOTE(review): only {@code raw_data} is read; tensors that store values in the typed
     * fields (e.g. {@code float_data}) are not handled.
     */
    private static INDArray arrayFromRawData(ByteString rawData, DataBuffer.Type type, long[] shape) {
        ByteBuffer source = rawData.asReadOnlyByteBuffer().order(ByteOrder.nativeOrder());
        // Size by remaining(): a sliced ByteString's backing buffer can be larger than its content.
        ByteBuffer direct = ByteBuffer.allocateDirect(source.remaining()).order(ByteOrder.nativeOrder());
        direct.put(source);
        direct.rewind();
        DataBuffer buffer = Nd4j.createBuffer(direct, type, ArrayUtil.prod((long[])shape));
        return Nd4j.create(buffer).reshape(shape);
    }

    /**
     * Returns the declared shape, left-padded with 1s to at least rank 2. Scalars become
     * {@code [1, 1]}.
     */
    @Override
    public long[] getShapeFromTensor(OnnxProto3.TypeProto.Tensor tensorProto) {
        int dimCount = tensorProto.getShape().getDimCount();
        long[] dims = new long[dimCount];
        for (int i = 0; i < dimCount; ++i) {
            dims[i] = tensorProto.getShape().getDim(i).getDimValue();
        }
        return OnnxGraphMapper.padToAtLeastRank2(dims);
    }

    /**
     * Returns the tensor's dims, left-padded with 1s to at least rank 2. Scalars become
     * {@code [1, 1]}.
     */
    @Override
    public long[] getShapeFromTensor(OnnxProto3.TensorProto tensorProto) {
        return OnnxGraphMapper.padToAtLeastRank2(Longs.toArray(tensorProto.getDimsList()));
    }

    /** Left-pads {@code dims} with 1s so the result has length {@code max(2, dims.length)}. */
    private static long[] padToAtLeastRank2(long[] dims) {
        long[] shape = new long[Math.max(2, dims.length)];
        int pad = shape.length - dims.length;
        for (int i = 0; i < pad; ++i) {
            shape[i] = 1L;
        }
        System.arraycopy(dims, 0, shape, pad, dims.length);
        return shape;
    }

    @Override
    public Set<String> opsToIgnore() {
        return Collections.emptySet();
    }

    @Override
    public String getInputFromNode(OnnxProto3.NodeProto node, int index) {
        return node.getInput(index);
    }

    @Override
    public int numInputsFor(OnnxProto3.NodeProto nodeProto) {
        return nodeProto.getInputCount();
    }

    /** Returns the dims of the tensor held by the attribute (same as {@link #getShapeFromAttribute}). */
    @Override
    public long[] getShapeFromAttr(OnnxProto3.AttributeProto attr) {
        return Longs.toArray(attr.getT().getDimsList());
    }

    /** Returns the node's attributes keyed by name; later duplicates overwrite earlier ones. */
    @Override
    public Map<String, OnnxProto3.AttributeProto> getAttrMap(OnnxProto3.NodeProto nodeProto) {
        Map<String, OnnxProto3.AttributeProto> proto = new HashMap<String, OnnxProto3.AttributeProto>();
        for (OnnxProto3.AttributeProto attributeProto : nodeProto.getAttributeList()) {
            proto.put(attributeProto.getName(), attributeProto);
        }
        return proto;
    }

    @Override
    public String getName(OnnxProto3.NodeProto nodeProto) {
        return nodeProto.getName();
    }

    @Override
    public boolean alreadySeen(OnnxProto3.NodeProto nodeProto) {
        return false;
    }

    /** Treats any node whose op type contains "Var" as a variable node. */
    @Override
    public boolean isVariableNode(OnnxProto3.NodeProto nodeProto) {
        return nodeProto.getOpType().contains("Var");
    }

    @Override
    public boolean shouldSkip(OnnxProto3.NodeProto opType) {
        return false;
    }

    @Override
    public boolean hasShape(OnnxProto3.NodeProto nodeProto) {
        return false;
    }

    @Override
    public long[] getShape(OnnxProto3.NodeProto nodeProto) {
        return null;
    }

    @Override
    public INDArray getArrayFrom(OnnxProto3.NodeProto nodeProto, OnnxProto3.GraphProto graph) {
        return null;
    }

    @Override
    public String getOpType(OnnxProto3.NodeProto nodeProto) {
        return nodeProto.getOpType();
    }

    @Override
    public List<OnnxProto3.NodeProto> getNodeList(OnnxProto3.GraphProto graphProto) {
        return graphProto.getNodeList();
    }
}

