/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.imports.graphmapper.tf;

import java.io.BufferedInputStream;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Field;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
import org.nd4j.imports.descriptors.properties.AttributeAdapter;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.imports.descriptors.tensorflow.TensorflowDescriptorParser;
import org.nd4j.imports.graphmapper.BaseGraphMapper;
import org.nd4j.imports.graphmapper.ImportState;
import org.nd4j.imports.graphmapper.OpImportFilter;
import org.nd4j.imports.graphmapper.OpImportOverride;
import org.nd4j.imports.graphmapper.tf.tensors.TFTensorMapper;
import org.nd4j.imports.graphmapper.tf.tensors.TFTensorMappers;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.controlflow.IfImportState;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.shade.guava.primitives.Floats;
import org.nd4j.shade.guava.primitives.Ints;
import org.nd4j.shade.protobuf.Message;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;
import org.tensorflow.framework.OpDef;
import org.tensorflow.framework.TensorProto;
import org.tensorflow.framework.TensorShapeProto;

public class TFGraphMapper
extends BaseGraphMapper<GraphDef, NodeDef, AttrValue, NodeDef> {
    private static final Logger log = LoggerFactory.getLogger(TFGraphMapper.class);
    /** Names of nodes already seen; consulted by {@link #alreadySeen(NodeDef)}. */
    private Set<String> seenNodes = new LinkedHashSet<String>();
    /** Node attribute name checked for a tensor value. */
    public static final String VALUE_ATTR_KEY = "value";
    /** Node attribute name read as a shape by {@link #hasShape(NodeDef)} / {@link #getShape(NodeDef)}. */
    public static final String SHAPE_KEY = "shape";
    /** Shared singleton returned by {@link #getInstance()}. */
    private static TFGraphMapper MAPPER_INSTANCE = new TFGraphMapper();
    /**
     * TensorFlow control-flow / no-op op type names.
     * NOTE(review): consumers of this set are not visible in this section of the file.
     * Built from a plain list rather than double-brace initialisation, which would create an
     * anonymous inner subclass of HashSet holding a reference to this mapper.
     */
    private Set<String> graphMapper = new HashSet<String>(Arrays.asList(
            "LoopCond", "Merge", "Exit", "NextIteration", "NoOp", "Switch"));

    /** Private: obtain the shared instance through {@link #getInstance()}. */
    private TFGraphMapper() {
        // nothing to do beyond the field initialisers
    }

    /**
     * Returns the shared mapper.
     *
     * @return the singleton {@code TFGraphMapper}
     */
    public static TFGraphMapper getInstance() {
        final TFGraphMapper shared = TFGraphMapper.MAPPER_INSTANCE;
        return shared;
    }

    /**
     * Parses a binary {@link GraphDef} from {@code inputFile} and appends the text form of every
     * node to {@code outputFile}.
     *
     * <p>The input stream belongs to the caller and is not closed here. The output file is opened
     * in append mode and is always closed, even if a write fails. I/O failures are logged rather
     * than thrown.
     *
     * @param inputFile  stream containing a serialized {@code GraphDef}
     * @param outputFile file the node text is appended to
     */
    @Override
    public void dumpBinaryProtoAsText(InputStream inputFile, File outputFile) {
        try {
            GraphDef graphDef = GraphDef.parseFrom(inputFile);
            // try-with-resources: the original leaked the writer when write() threw
            try (BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(outputFile, true))) {
                for (NodeDef node : graphDef.getNodeList()) {
                    bufferedWriter.write(node.toString());
                }
            }
        } catch (IOException e) {
            log.error("Failed to dump binary GraphDef as text to {}", outputFile, e);
        }
    }

    /**
     * Returns {@code true} for every node; {@code node} is not inspected.
     */
    @Override
    public boolean isOpIgnoreException(NodeDef node) {
        final boolean ignoreAll = true;
        return ignoreAll;
    }

    /**
     * Returns the op name of {@code function}; {@code node} is not consulted.
     */
    @Override
    public String getTargetMappingForOp(DifferentialFunction function, NodeDef node) {
        String targetName = function.opName();
        return targetName;
    }

    /**
     * Finds the first node in {@code graph} whose name equals {@code name}.
     *
     * @return the matching node, or {@code null} if no node has that name
     */
    @Override
    public NodeDef getNodeWithNameFromGraph(GraphDef graph, String name) {
        for (NodeDef candidate : graph.getNodeList()) {
            if (candidate.getName().equals(name)) {
                return candidate;
            }
        }
        return null;
    }

    /**
     * Sets the property {@code name} on {@code on} from TensorFlow node {@code node}.
     *
     * <p>The {@link PropertyMapping} is looked up in {@code propertyMappingsForFunction} under the
     * node's op type. If the mapping names a TF input position that exists on the node, the value
     * comes from that input's array (see {@link #mapInputProperty}). Otherwise it comes from the TF
     * attribute named by the mapping, if the node has that attribute (see {@link #mapAttrProperty}).
     *
     * @throws ND4JIllegalStateException if {@code node} is null, no mapping exists for {@code name}
     *     under the node's op type, or the target field metadata is missing
     */
    @Override
    public void mapProperty(String name, DifferentialFunction on, NodeDef node, GraphDef graph, SameDiff sameDiff, Map<String, Map<String, PropertyMapping>> propertyMappingsForFunction) {
        if (node == null) {
            throw new ND4JIllegalStateException("No node found for name " + name);
        }
        String opType = this.getOpType(node);
        Map<String, PropertyMapping> mappingsForOp = propertyMappingsForFunction == null ? null : propertyMappingsForFunction.get(opType);
        PropertyMapping mapping = mappingsForOp == null ? null : mappingsForOp.get(name);
        if (mapping == null) {
            // previously surfaced as an unexplained NullPointerException
            throw new ND4JIllegalStateException("No property mapping found for [" + name + "] in op [" + on.opName() + "] for TF op type [" + opType + "]");
        }
        Map<String, Field> fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on);
        Integer inputPosition = mapping.getTfInputPosition();
        if (inputPosition != null && inputPosition < node.getInputCount()) {
            this.mapInputProperty(name, on, node, graph, sameDiff, mapping, inputPosition, fields);
        } else {
            this.mapAttrProperty(name, on, node, mapping, fields);
        }
    }

    /**
     * Sets property {@code name} of {@code on} from the array of the node's TF input at
     * {@code inputPosition}, or registers it with {@code sameDiff} for later resolution when the
     * array is not available yet.
     */
    private void mapInputProperty(String name, DifferentialFunction on, NodeDef node, GraphDef graph, SameDiff sameDiff, PropertyMapping mapping, int inputPosition, Map<String, Field> fields) {
        int tfMappingIdx = inputPosition;
        if (tfMappingIdx < 0) {
            // negative positions count back from the end of the input list
            tfMappingIdx += node.getInputCount();
        }
        String input = node.getInput(tfMappingIdx);
        NodeDef inputNode = TFGraphMapper.getInstance().getNodeWithNameFromGraph(graph, input);
        INDArray arr = this.getArrayFrom(inputNode, graph);
        if (arr == null && sameDiff.hasVariable(input)) {
            arr = sameDiff.getArrForVarName(input);
        }
        if (inputNode == null) {
            // NOTE(review): the input reference did not match a graph node by exact name; nothing is set
            return;
        }
        if (arr == null) {
            // value not known yet: defer until the input variable has an array
            sameDiff.addPropertyToResolve(on, name);
            sameDiff.addVariableMappingForField(on, name, this.getNodeName(inputNode.getName()));
            return;
        }
        Field field = fields == null ? null : fields.get(name);
        if (field == null) {
            throw new ND4JIllegalStateException("No field found for property [" + name + "] in op [" + on.opName() + "]");
        }
        Class<?> type = field.getType();
        if (type.equals(int[].class)) {
            on.setValueFor(field, arr.data().asInt());
        } else if (type.equals(Integer.TYPE) || type.equals(Long.TYPE) || type.equals(Long.class) || type.equals(Integer.class)) {
            if (mapping.getShapePosition() != null) {
                // the mapping asks for a dimension size of the input, not its contents
                on.setValueFor(field, arr.size(mapping.getShapePosition()));
            } else {
                on.setValueFor(field, arr.getInt(0));
            }
        } else if (type.equals(Float.TYPE) || type.equals(Double.TYPE) || type.equals(Float.class) || type.equals(Double.class)) {
            on.setValueFor(field, arr.getDouble(0L));
        }
    }

    /**
     * Sets the first property named by {@code mapping} on {@code on} from the node's TF attribute
     * {@code mapping.getTfAttrName()}. Does nothing if the mapping names no attribute or the node
     * lacks it. The value is chosen by the attribute's {@code type} field; unhandled types are ignored.
     */
    private void mapAttrProperty(String name, DifferentialFunction on, NodeDef node, PropertyMapping mapping, Map<String, Field> fields) {
        String tfMappingAttrName = mapping.getTfAttrName();
        if (tfMappingAttrName == null || !node.containsAttr(tfMappingAttrName)) {
            return;
        }
        AttrValue attr = node.getAttrOrThrow(tfMappingAttrName);
        org.tensorflow.framework.DataType type = attr.getType();
        if (fields == null) {
            throw new ND4JIllegalStateException("No fields found for op [" + mapping + "]");
        }
        String[] propertyNames = mapping.getPropertyNames();
        if (propertyNames == null || propertyNames.length == 0) {
            throw new ND4JIllegalStateException("no property found for [" + name + "] in op [" + on.opName() + "]");
        }
        Field field = fields.get(propertyNames[0]);
        Object valueToSet;
        switch (type) {
            case DT_BOOL:
                valueToSet = attr.getB();
                break;
            case DT_INT8:
            case DT_INT16:
            case DT_INT32:
            case DT_INT64:
                valueToSet = attr.getI();
                break;
            case DT_FLOAT:
            case DT_DOUBLE:
                // AttrValue only carries a single-precision float
                valueToSet = attr.getF();
                break;
            case DT_STRING:
                valueToSet = attr.getS();
                break;
            default:
                valueToSet = null;
                break;
        }
        if (field != null && valueToSet != null) {
            on.setValueFor(field, valueToSet);
        }
    }

    /**
     * Reports whether {@code node}'s op type name begins with {@code "Placeholder"}.
     */
    @Override
    public boolean isPlaceHolderNode(NodeDef node) {
        String opType = node.getOp();
        return opType.startsWith("Placeholder");
    }

    /**
     * Reads a binary {@link GraphDef} from {@code inputFile} and appends the text form of every node
     * to {@code outputFile}.
     *
     * <p>Both files are closed before returning, including on failure. The original never closed
     * the input stream and leaked the writer when a write threw. The output file is only opened
     * after parsing succeeds. I/O failures are logged rather than thrown.
     *
     * @param inputFile  file containing a serialized {@code GraphDef}
     * @param outputFile file the node text is appended to
     */
    @Override
    public void dumpBinaryProtoAsText(File inputFile, File outputFile) {
        try (InputStream in = new BufferedInputStream(new FileInputStream(inputFile))) {
            GraphDef graphDef = GraphDef.parseFrom(in);
            try (BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(outputFile, true))) {
                for (NodeDef node : graphDef.getNodeList()) {
                    bufferedWriter.write(node.toString());
                }
            }
        } catch (IOException e) {
            log.error("Failed to dump binary GraphDef {} as text to {}", inputFile, outputFile, e);
        }
    }

    /**
     * Converts the shape held in {@code attr} to a {@code long[]} via {@code shapeFromShapeProto}.
     */
    @Override
    public long[] getShapeFromAttr(AttrValue attr) {
        TensorShapeProto shapeProto = attr.getShape();
        return this.shapeFromShapeProto(shapeProto);
    }

    /**
     * Returns the attribute map of {@code nodeDef} as provided by the protobuf message.
     */
    @Override
    public Map<String, AttrValue> getAttrMap(NodeDef nodeDef) {
        Map<String, AttrValue> attrs = nodeDef.getAttrMap();
        return attrs;
    }

    /**
     * Returns the node's name.
     */
    @Override
    public String getName(NodeDef nodeDef) {
        final String nodeName = nodeDef.getName();
        return nodeName;
    }

    /**
     * Reports whether {@code nodeDef}'s name is recorded in the seen-nodes set.
     */
    @Override
    public boolean alreadySeen(NodeDef nodeDef) {
        String nodeName = nodeDef.getName();
        return seenNodes.contains(nodeName);
    }

    /**
     * Treats ops whose type starts with {@code "VariableV"}, or equals {@code "const"} ignoring
     * case, as variable nodes.
     */
    @Override
    public boolean isVariableNode(NodeDef nodeDef) {
        String op = nodeDef.getOp();
        if (op.startsWith("VariableV")) {
            return true;
        }
        return op.equalsIgnoreCase("const");
    }

    /**
     * Skips {@code null} nodes and nodes whose name ends in {@code "/read"}.
     */
    @Override
    public boolean shouldSkip(NodeDef opType) {
        return opType == null || opType.getName().endsWith("/read");
    }

    /**
     * Reports whether the node carries a {@code shape} attribute.
     */
    @Override
    public boolean hasShape(NodeDef nodeDef) {
        return nodeDef.containsAttr(TFGraphMapper.SHAPE_KEY);
    }

    /**
     * Reads the node's {@code shape} attribute; throws (from {@code getAttrOrThrow}) if absent.
     */
    @Override
    public long[] getShape(NodeDef nodeDef) {
        AttrValue shapeAttr = nodeDef.getAttrOrThrow(TFGraphMapper.SHAPE_KEY);
        return this.getShapeFromAttr(shapeAttr);
    }

    /**
     * Returns the node's tensor value via {@code getNDArrayFromTensor}, or {@code null} for a null node.
     */
    @Override
    public INDArray getArrayFrom(NodeDef nodeDef, GraphDef graph) {
        return nodeDef == null ? null : this.getNDArrayFromTensor(nodeDef.getName(), nodeDef, graph);
    }

    /**
     * Returns the TF op type name of the node.
     */
    @Override
    public String getOpType(NodeDef nodeDef) {
        final String op = nodeDef.getOp();
        return op;
    }

    /**
     * Returns the nodes of the graph, as provided by the protobuf message.
     */
    @Override
    public List<NodeDef> getNodeList(GraphDef graphDef) {
        List<NodeDef> nodes = graphDef.getNodeList();
        return nodes;
    }

    /**
     * Looks up the registered {@link DifferentialFunction} for TF op name {@code name}.
     */
    @Override
    public DifferentialFunction getMappedOp(String name) {
        DifferentialFunctionClassHolder holder = DifferentialFunctionClassHolder.getInstance();
        return holder.getOpWithTensorflowName(name);
    }

    /**
     * Normalises a TensorFlow input reference to a node/variable name.
     *
     * <p>Strips, in this order: a leading {@code ^} (control-dependency marker), a trailing
     * {@code /read} suffix, and a trailing {@code :0} output index. Other output indices such as
     * {@code :1} are kept.
     *
     * <p>Only the trailing {@code /read} is removed. The original used {@code String.replace},
     * which also removed {@code /read} occurring elsewhere in the name (e.g.
     * {@code dense/readout/read} became {@code denseout}).
     *
     * @param name raw input reference from a {@code NodeDef}
     * @return the normalised name
     */
    public String getNodeName(String name) {
        final String readSuffix = "/read";
        String ret = name;
        if (ret.startsWith("^")) {
            ret = ret.substring(1);
        }
        if (ret.endsWith(readSuffix)) {
            ret = ret.substring(0, ret.length() - readSuffix.length());
        }
        if (ret.endsWith(":0")) {
            ret = ret.substring(0, ret.length() - 2);
        }
        return ret;
    }

    /**
     * Reports whether an input reference is a control dependency, i.e. starts with {@code '^'}.
     */
    public boolean isControlDependency(String name) {
        return !name.isEmpty() && name.charAt(0) == '^';
    }

    /**
     * Indexes every node of {@code graphDef} by its SameDiff name, skipping {@code "/read"} nodes.
     * Insertion order follows the graph's node order.
     */
    @Override
    public Map<String, NodeDef> variablesForGraph(GraphDef graphDef) {
        Map<String, NodeDef> byName = new LinkedHashMap<String, NodeDef>();
        for (NodeDef nodeDef : graphDef.getNodeList()) {
            String nodeName = nodeDef.getName();
            if (!nodeName.endsWith("/read")) {
                byName.put(this.translateToSameDiffName(nodeName, nodeDef), nodeDef);
            }
        }
        return byName;
    }

    /**
     * Returns {@code name} unchanged for variable and placeholder nodes. Otherwise returns
     * {@code name} truncated at its last {@code ':'} (if any).
     */
    @Override
    public String translateToSameDiffName(String name, NodeDef node) {
        if (this.isVariableNode(node) || this.isPlaceHolder(node)) {
            return name;
        }
        int colon = name.lastIndexOf(':');
        return colon < 0 ? name : name.substring(0, colon);
    }

    /**
     * Strips a trailing {@code ":<suffix>"} (from the last colon) from {@code varName}.
     */
    public String varNameToOpName(String varName) {
        int colon = varName.lastIndexOf(':');
        return colon >= 0 ? varName.substring(0, colon) : varName;
    }

    /**
     * Parses the integer after the last colon of {@code varName}, or returns 0 if there is no colon.
     *
     * @throws NumberFormatException if the text after the last colon is not an integer
     */
    public static int varNameToOpOutputNumber(String varName) {
        int colon = varName.lastIndexOf(':');
        if (colon >= 0) {
            return Integer.parseInt(varName.substring(colon + 1));
        }
        return 0;
    }

    /**
     * Returns a fresh {@link GraphDef} builder.
     */
    @Override
    public Message.Builder getNewGraphBuilder() {
        GraphDef.Builder builder = GraphDef.newBuilder();
        return builder;
    }

    /**
     * Parses a binary {@link GraphDef} from a byte array.
     */
    @Override
    public GraphDef parseGraphFrom(byte[] inputStream) throws IOException {
        GraphDef parsed = GraphDef.parseFrom(inputStream);
        return parsed;
    }

    /**
     * Parses a binary {@link GraphDef} from a stream; the stream is not closed here.
     */
    @Override
    public GraphDef parseGraphFrom(InputStream inputStream) throws IOException {
        GraphDef parsed = GraphDef.parseFrom(inputStream);
        return parsed;
    }

    /**
     * Hook for condition import; this implementation does nothing.
     */
    protected void importCondition(String conditionName, NodeDef tfNode, ImportState<GraphDef, NodeDef> importState) {
        // intentionally empty
    }

    /**
     * Imports a single TensorFlow node into the {@link SameDiff} held by {@code importState}.
     *
     * <p>Handling depends on the node:
     * <ul>
     *   <li>Nodes rejected by {@link #shouldSkip}, already seen, or variable/constant nodes are ignored.</li>
     *   <li>For placeholders, the existing SameDiff variable of the same name is only checked to be
     *       marked as a placeholder.</li>
     *   <li>Any other node is imported through {@code importOverride} when it is non-null, or by
     *       instantiating the {@link DifferentialFunction} registered for the node's TF op type.</li>
     * </ul>
     *
     * <p>The original contained a variable-creation branch guarded by {@code isVariableNode}. It
     * was unreachable because such nodes return at the first check, so it has been removed.
     *
     * @throws ND4JIllegalStateException if no op class is registered for the node's op type
     * @throws RuntimeException wrapping any failure while importing through a registered op class
     */
    @Override
    public void mapNodeType(NodeDef tfNode, ImportState<GraphDef, NodeDef> importState, OpImportOverride<GraphDef, NodeDef, AttrValue> importOverride, OpImportFilter<GraphDef, NodeDef, AttrValue> opFilter) {
        if (this.shouldSkip(tfNode) || this.alreadySeen(tfNode) || this.isVariableNode(tfNode)) {
            return;
        }
        SameDiff diff = importState.getSameDiff();
        if (this.isPlaceHolder(tfNode)) {
            SDVariable var = diff.getVariable(this.getName(tfNode));
            Preconditions.checkState(var.isPlaceHolder(), "Variable should be marked as placeholder at this point: %s", (Object) var);
            return;
        }
        if (importOverride != null) {
            this.importNodeWithOverride(tfNode, diff, importState, importOverride, opFilter);
        } else {
            this.importNodeWithOpClass(tfNode, diff, importState);
        }
    }

    /** Collects the node's inputs as SDVariables and delegates the import to {@code importOverride}. */
    private void importNodeWithOverride(NodeDef tfNode, SameDiff diff, ImportState<GraphDef, NodeDef> importState, OpImportOverride<GraphDef, NodeDef, AttrValue> importOverride, OpImportFilter<GraphDef, NodeDef, AttrValue> opFilter) {
        String opName = tfNode.getOp();
        int numInputs = tfNode.getInputCount();
        List<SDVariable> inputs = new ArrayList<SDVariable>(numInputs);
        List<SDVariable> controlDeps = null;
        for (int i = 0; i < numInputs; ++i) {
            String inName = tfNode.getInput(i);
            boolean controlDep = this.isControlDependency(inName);
            String name = this.getNodeName(inName);
            SDVariable v = diff.getVariable(name);
            if (v == null) {
                if (opFilter != null) {
                    // NOTE(review): the filter itself is never consulted; the input node only has to
                    // exist (matched by the raw input string) and the input is then left null.
                    boolean found = false;
                    for (NodeDef nd : importState.getGraph().getNodeList()) {
                        if (inName.equals(nd.getName())) {
                            found = true;
                            break;
                        }
                    }
                    Preconditions.checkState(found, "Could not find node with name \"%s\"", (Object) inName);
                } else {
                    v = this.createInputArrayVariable(diff, importState, inName, name);
                }
            }
            if (controlDep) {
                if (controlDeps == null) {
                    controlDeps = new ArrayList<SDVariable>();
                }
                controlDeps.add(v);
            } else {
                inputs.add(v);
            }
        }
        log.info("Importing op {} using override {}", opName, importOverride);
        importOverride.initFromTensorFlow(inputs, controlDeps, tfNode, diff, this.getAttrMap(tfNode), importState.getGraph());
    }

    /** Instantiates the op class registered for the node's TF op type and wires its inputs and properties. */
    private void importNodeWithOpClass(NodeDef tfNode, SameDiff diff, ImportState<GraphDef, NodeDef> importState) {
        String opName = tfNode.getOp();
        DifferentialFunction differentialFunction = DifferentialFunctionClassHolder.getInstance().getOpWithTensorflowName(opName);
        if (differentialFunction == null) {
            throw new ND4JIllegalStateException("No tensorflow op found for " + opName + " possibly missing operation class?");
        }
        try {
            DifferentialFunction newInstance = differentialFunction.getClass().newInstance();
            List<SDVariable> args = new ArrayList<SDVariable>();
            List<String> controlDeps = null;
            newInstance.setOwnName(tfNode.getName());
            for (int i = 0; i < tfNode.getInputCount(); ++i) {
                String inName = tfNode.getInput(i);
                String inputOpName = this.varNameToOpName(inName);
                NodeDef inputNode = importState.getVariables().get(inputOpName);
                // NOTE(review): inputOpName keeps a leading '^', so control-dependency inputs look up
                // null here (assuming the variables map is keyed by plain node names) and are skipped
                // unless the reference ends in "/read".
                if (this.shouldSkip(inputNode) && !inName.endsWith("/read")) {
                    continue;
                }
                boolean controlDep = this.isControlDependency(inName);
                String name = this.getNodeName(inName);
                SDVariable v = diff.getVariable(name);
                if (v == null) {
                    v = this.createInputArrayVariable(diff, importState, inName, name);
                }
                if (controlDep) {
                    if (controlDeps == null) {
                        controlDeps = new ArrayList<String>();
                    }
                    if (!controlDeps.contains(name)) {
                        controlDeps.add(name);
                    }
                    continue;
                }
                args.add(v);
            }
            diff.addArgsFor(args.toArray(new SDVariable[args.size()]), newInstance);
            newInstance.setSameDiff(importState.getSameDiff());
            if (controlDeps != null) {
                SameDiffOp op = diff.getOps().get(newInstance.getOwnName());
                op.setControlDeps(controlDeps);
                // also record the op on each control-dependency variable
                for (String s : controlDeps) {
                    Variable v = diff.getVariables().get(s);
                    if (v.getControlDepsForOp() == null) {
                        // fix: original called setControlDeps here, leaving controlDepsForOp null (NPE below)
                        v.setControlDepsForOp(new ArrayList<String>());
                    }
                    List<String> l = v.getControlDepsForOp();
                    if (!l.contains(op.getName())) {
                        l.add(op.getName());
                    }
                }
            }
            newInstance.initFromTensorFlow(tfNode, diff, this.getAttrMap(tfNode), importState.getGraph());
            this.mapProperties(newInstance, tfNode, importState.getGraph(), importState.getSameDiff(), newInstance.mappingsForFunction());
            importState.getSameDiff().putOpForId(newInstance.getOwnName(), newInstance);
            diff.setBaseNameForFunctionInstanceId(tfNode.getName(), newInstance);
        } catch (Exception e) {
            log.error("Failed to import op [{}]", opName);
            throw new RuntimeException(e);
        }
    }

    /**
     * Creates an ARRAY-type SameDiff variable {@code name} for an input that has no variable yet.
     * Its data type is inferred from the producing node; UNKNOWN is passed on as {@code null}.
     */
    private SDVariable createInputArrayVariable(SameDiff diff, ImportState<GraphDef, NodeDef> importState, String inName, String name) {
        NodeDef inputOp = importState.getVariables().get(this.varNameToOpName(inName));
        DataType dt = this.dataTypeForTensor(inputOp, TFGraphMapper.varNameToOpOutputNumber(name));
        if (dt == DataType.UNKNOWN) {
            dt = null;
        }
        return diff.var(name, VariableType.ARRAY, null, dt, (long[]) null);
    }

    /**
     * Convenience overload that uses {@code on.tensorflowName()} as the mapped TF op name.
     */
    public void initFunctionFromProperties(DifferentialFunction on, Map<String, AttrValue> attributesForNode, NodeDef node, GraphDef graph) {
        String tfName = on.tensorflowName();
        this.initFunctionFromProperties(tfName, on, attributesForNode, node, graph);
    }

    /**
     * Populates the mapped fields of {@code on} from TF node {@code node}, using the property
     * mappings that {@code on} registers under {@code mappedTfName}.
     *
     * <p>Each mapping is handled in one of two ways:
     * <ul>
     *   <li>It names a TF attribute, which is read from {@code attributesForNode}.</li>
     *   <li>It names a TF input position (negative positions count from the end of the input
     *       list), and the array of that input is used.</li>
     * </ul>
     *
     * <p>When an {@link AttributeAdapter} is registered for a property, the adapter receives the
     * value instead of the field being set directly. Properties with adapters are processed after
     * those without. Input-backed properties whose array is not available yet are registered with
     * {@code on}'s SameDiff instance for later resolution. Does nothing if {@code on} has no
     * mappings for {@code mappedTfName}.
     *
     * <p>Fixes vs. the original:
     * <ul>
     *   <li>A non-empty adapter map without an entry for {@code mappedTfName} no longer throws an NPE.</li>
     *   <li>An input-backed property without a matching field and without an adapter is skipped
     *       instead of throwing an NPE. This matches the attribute path.</li>
     * </ul>
     */
    public void initFunctionFromProperties(String mappedTfName, DifferentialFunction on, Map<String, AttrValue> attributesForNode, NodeDef node, GraphDef graph) {
        Map<String, Map<String, PropertyMapping>> properties = on.mappingsForFunction();
        Map<String, PropertyMapping> tfProperties = properties == null ? null : properties.get(mappedTfName);
        Map<String, Field> fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on);
        Map<String, Map<String, AttributeAdapter>> attributeAdapters = on.attributeAdaptersForFunction();
        if (tfProperties == null) {
            return;
        }
        Map<String, AttributeAdapter> adaptersForOp = attributeAdapters == null ? null : attributeAdapters.get(mappedTfName);
        Map<String, PropertyMapping> ordered = tfProperties;
        if (adaptersForOp != null) {
            // properties without an adapter first, then those with one
            ordered = new LinkedHashMap<String, PropertyMapping>();
            for (Map.Entry<String, PropertyMapping> e : tfProperties.entrySet()) {
                if (!adaptersForOp.containsKey(e.getKey())) {
                    ordered.put(e.getKey(), e.getValue());
                }
            }
            for (Map.Entry<String, PropertyMapping> e : tfProperties.entrySet()) {
                if (adaptersForOp.containsKey(e.getKey())) {
                    ordered.put(e.getKey(), e.getValue());
                }
            }
        }
        for (Map.Entry<String, PropertyMapping> entry : ordered.entrySet()) {
            String propertyName = entry.getKey();
            PropertyMapping mapping = entry.getValue();
            Field currentField = fields == null ? null : fields.get(propertyName);
            AttributeAdapter adapter = adaptersForOp == null ? null : adaptersForOp.get(propertyName);
            String tfAttrName = mapping.getTfAttrName();
            if (tfAttrName != null) {
                if (currentField != null && attributesForNode.containsKey(tfAttrName)) {
                    this.applyAttrValue(attributesForNode.get(tfAttrName), currentField, adapter, on);
                }
                continue;
            }
            Integer tfInputPosition = mapping.getTfInputPosition();
            if (tfInputPosition == null) {
                continue;
            }
            int position = tfInputPosition;
            if (position < 0) {
                position += node.getInputCount();
            }
            String inputName = node.getInput(position);
            NodeDef inputFromNode = TFGraphMapper.getInstance().getNodeWithNameFromGraph(graph, inputName);
            INDArray tensor = inputFromNode == null ? null : TFGraphMapper.getInstance().getNDArrayFromTensor(VALUE_ATTR_KEY, inputFromNode, graph);
            if (tensor == null) {
                tensor = on.getSameDiff().getArrForVarName(this.getNodeName(inputName));
            }
            if (tensor == null) {
                // value not known yet: resolve once the input has an array
                on.getSameDiff().addPropertyToResolve(on, propertyName);
            } else if (adapter != null) {
                adapter.mapAttributeFor(tensor, currentField, on);
            } else if (currentField != null) {
                this.setFieldFromTensor(on, currentField, tensor);
            }
        }
    }

    /** Applies one TF attribute value to {@code field} of {@code on}, through {@code adapter} when non-null. */
    private void applyAttrValue(AttrValue attr, Field field, AttributeAdapter adapter, DifferentialFunction on) {
        switch (attr.getValueCase()) {
            case B:
                // booleans are only applied through an adapter
                if (adapter != null) {
                    adapter.mapAttributeFor(attr.getB(), field, on);
                }
                break;
            case S:
                this.setOrAdapt(attr.getS().toStringUtf8(), field, adapter, on);
                break;
            case I:
                this.setOrAdapt((int) attr.getI(), field, adapter, on);
                break;
            case SHAPE: {
                List<TensorShapeProto.Dim> shape = attr.getShape().getDimList();
                int[] dimsToSet = new int[shape.size()];
                for (int i = 0; i < dimsToSet.length; ++i) {
                    dimsToSet[i] = (int) shape.get(i).getSize();
                }
                this.setOrAdapt(dimsToSet, field, adapter, on);
                break;
            }
            case LIST: {
                AttrValue.ListValue list = attr.getList();
                if (!list.getIList().isEmpty()) {
                    this.setOrAdapt(Ints.toArray(list.getIList()), field, adapter, on);
                } else if (list.getBList().isEmpty() && !list.getFList().isEmpty()) {
                    this.setOrAdapt(Floats.toArray(list.getFList()), field, adapter, on);
                }
                // bool/func/tensor/other lists are not mapped
                break;
            }
            case TENSOR:
                this.setOrAdapt(TFGraphMapper.getInstance().mapTensorProto(attr.getTensor()), field, adapter, on);
                break;
            case TYPE:
                // data types are only applied through an adapter
                if (adapter != null) {
                    adapter.mapAttributeFor((Object) attr.getType(), field, on);
                }
                break;
            default:
                // F, FUNC, PLACEHOLDER, VALUE_NOT_SET: not mapped
                // NOTE(review): scalar float attributes are ignored here — confirm intended
                break;
        }
    }

    /** Passes {@code value} to {@code adapter} if present, otherwise sets it directly on {@code field}. */
    private void setOrAdapt(Object value, Field field, AttributeAdapter adapter, DifferentialFunction on) {
        if (adapter != null) {
            adapter.mapAttributeFor(value, field, on);
        } else {
            on.setValueFor(field, value);
        }
    }

    /**
     * Sets {@code field} from {@code tensor} according to the field's declared type. Arrays are
     * copied whole and scalar primitives take element 0. Other field types are left untouched.
     */
    private void setFieldFromTensor(DifferentialFunction on, Field field, INDArray tensor) {
        Class<?> type = field.getType();
        if (type.equals(int[].class)) {
            on.setValueFor(field, tensor.data().asInt());
        } else if (type.equals(double[].class)) {
            on.setValueFor(field, tensor.data().asDouble());
        } else if (type.equals(float[].class)) {
            on.setValueFor(field, tensor.data().asFloat());
        } else if (type.equals(INDArray.class)) {
            on.setValueFor(field, tensor);
        } else if (type.equals(Integer.TYPE)) {
            on.setValueFor(field, tensor.getInt(0));
        } else if (type.equals(Double.TYPE)) {
            on.setValueFor(field, tensor.getDouble(0L));
        } else if (type.equals(Float.TYPE)) {
            on.setValueFor(field, Float.valueOf(tensor.getFloat(0L)));
        }
    }

    /**
     * Infers the ND4J data type of output {@code outNum} of TF node {@code tensorProto}.
     *
     * <p>When TensorFlow's op descriptor for the node's op type declares outputs, the output
     * argument containing {@code outNum} is located and its {@code type_attr} is read from the node.
     * An argument with a {@code number_attr} counts as N outputs, where N is read from the node.
     * Otherwise, {@code NoOp} gives UNKNOWN, {@code Assert} gives BOOL, and for any other op the
     * node's {@code dtype}, else {@code T}, else {@code Tidx} attribute is used.
     *
     * <p>Fixes vs. the original:
     * <ul>
     *   <li>The precondition message arguments are now in the order the message expects.</li>
     *   <li>The output argument is always found by walking the per-argument counts. The original
     *       indexed directly whenever the totals matched, which is wrong if, e.g., one list
     *       argument has N=0 and another N=2.</li>
     * </ul>
     *
     * @param tensorProto node producing the output
     * @param outNum      index of the node's output
     * @return the data type, or {@link DataType#UNKNOWN} when it cannot be determined
     * @throws IllegalStateException (from {@code Preconditions}) if {@code outNum} exceeds the op's outputs
     */
    @Override
    public DataType dataTypeForTensor(NodeDef tensorProto, int outNum) {
        String opName = tensorProto.getOp();
        OpDef opDef = TensorflowDescriptorParser.opDescs().get(opName);
        int outputArgCount = opDef == null ? 0 : opDef.getOutputArgCount();
        if (outputArgCount > 0) {
            int[] outVarsPerOutputArg = new int[outputArgCount];
            int actualOutputCount = 0;
            for (int i = 0; i < outputArgCount; ++i) {
                String numAttr = opDef.getOutputArg(i).getNumberAttr();
                outVarsPerOutputArg[i] = (numAttr != null && !numAttr.isEmpty())
                        ? (int) tensorProto.getAttrOrThrow(numAttr).getI()
                        : 1;
                actualOutputCount += outVarsPerOutputArg[i];
            }
            Preconditions.checkState(outNum < actualOutputCount, "Cannot get output argument %s from op %s with %s output variables - variable %s", (Object) outNum, (Object) opName, (Object) actualOutputCount, (Object) tensorProto.getName());
            // walk the argument list; the precondition above keeps argIdx in range
            int argIdx = 0;
            int soFar = 0;
            while (soFar + outVarsPerOutputArg[argIdx] <= outNum) {
                soFar += outVarsPerOutputArg[argIdx++];
            }
            String typeAttr = opDef.getOutputArg(argIdx).getTypeAttr();
            if (typeAttr == null || !tensorProto.containsAttr(typeAttr)) {
                return DataType.UNKNOWN;
            }
            return TFGraphMapper.convertType(tensorProto.getAttrOrThrow(typeAttr).getType());
        }
        if (opName.equals("NoOp")) {
            return DataType.UNKNOWN;
        }
        if (opName.equals("Assert")) {
            return DataType.BOOL;
        }
        log.debug("No TensorFlow descriptor found for tensor \"{}\", op \"{}\"", (Object) tensorProto.getName(), (Object) tensorProto.getOp());
        String dtypeAttr;
        if (tensorProto.containsAttr("dtype")) {
            dtypeAttr = "dtype";
        } else if (tensorProto.containsAttr("T")) {
            dtypeAttr = "T";
        } else if (tensorProto.containsAttr("Tidx")) {
            dtypeAttr = "Tidx";
        } else {
            return DataType.UNKNOWN;
        }
        return TFGraphMapper.convertType(tensorProto.getAttrOrThrow(dtypeAttr).getType());
    }

    /**
     * Translates a TensorFlow data type enum value into the corresponding ND4J {@link DataType}.
     *
     * @param tfType the TensorFlow data type
     * @return the matching ND4J type, or {@link DataType#UNKNOWN} for any TensorFlow type without an
     *         explicit mapping below (for example unsigned types wider than 8 bits or reference types)
     */
    public static DataType convertType(org.tensorflow.framework.DataType tfType) {
        DataType mapped;
        switch (tfType) {
            case DT_DOUBLE:
                mapped = DataType.DOUBLE;
                break;
            case DT_FLOAT:
                mapped = DataType.FLOAT;
                break;
            case DT_HALF:
                mapped = DataType.HALF;
                break;
            case DT_BFLOAT16:
                mapped = DataType.BFLOAT16;
                break;
            case DT_INT8:
                mapped = DataType.BYTE;
                break;
            case DT_INT16:
                mapped = DataType.SHORT;
                break;
            case DT_INT32:
                mapped = DataType.INT;
                break;
            case DT_INT64:
                mapped = DataType.LONG;
                break;
            case DT_UINT8:
                mapped = DataType.UBYTE;
                break;
            case DT_STRING:
                mapped = DataType.UTF8;
                break;
            case DT_BOOL:
                mapped = DataType.BOOL;
                break;
            default:
                mapped = DataType.UNKNOWN;
                break;
        }
        return mapped;
    }

    /**
     * Reports whether the node's data type attribute is a TensorFlow string type.
     *
     * <p>The attributes {@code dtype}, {@code T} and {@code Tidx} are consulted in that order; the
     * first one present decides. {@code DT_STRING} and {@code DT_STRING_REF} count as strings.
     *
     * @param tensorProto the node to inspect
     * @return true if the deciding attribute is a string type; false otherwise or if none is present
     */
    @Override
    public boolean isStringType(NodeDef tensorProto) {
        for (String attrName : new String[]{"dtype", "T", "Tidx"}) {
            if (!tensorProto.containsAttr(attrName)) {
                continue;
            }
            org.tensorflow.framework.DataType declared = tensorProto.getAttrOrThrow(attrName).getType();
            return declared == org.tensorflow.framework.DataType.DT_STRING
                    || declared == org.tensorflow.framework.DataType.DT_STRING_REF;
        }
        return false;
    }

    /**
     * Reads a string-valued attribute from a node, decoding its bytes as UTF-8.
     *
     * @param nodeDef the node holding the attribute
     * @param key     the attribute name; {@code getAttrOrThrow} fails if it is absent
     * @return the attribute's {@code s} field as a UTF-8 string
     */
    @Override
    public String getAttrValueFromNode(NodeDef nodeDef, String key) {
        AttrValue attribute = nodeDef.getAttrOrThrow(key);
        return attribute.getS().toStringUtf8();
    }

    /**
     * Extracts the dimension sizes from a shape-typed attribute value.
     *
     * <p>Dimension sizes are copied as full 64-bit values. The previous implementation narrowed each
     * size through an {@code int} cast, silently corrupting dimensions larger than
     * {@code Integer.MAX_VALUE}. Sizes such as -1, as stored in the proto, are passed through unchanged.
     *
     * @param attrValue the attribute holding a {@code TensorShapeProto}
     * @return the dimension sizes, one entry per dimension (empty for a rank-0 shape)
     */
    @Override
    public long[] getShapeFromAttribute(AttrValue attrValue) {
        return this.shapeFromShapeProto(attrValue.getShape());
    }

    /**
     * Reports whether the node is a placeholder, i.e. its op name starts with {@code "Placeholder"}.
     * This prefix test also matches op names that merely begin with that word.
     *
     * @param nodeDef the node to inspect
     * @return true if the op name has the {@code "Placeholder"} prefix
     */
    @Override
    public boolean isPlaceHolder(NodeDef nodeDef) {
        String opName = nodeDef.getOp();
        return opName.startsWith("Placeholder");
    }

    /**
     * Reports whether the node is a constant, i.e. its op name starts with {@code "Const"}.
     * This prefix test also matches op names that merely begin with that string.
     *
     * @param nodeDef the node to inspect
     * @return true if the op name has the {@code "Const"} prefix
     */
    @Override
    public boolean isConstant(NodeDef nodeDef) {
        String opName = nodeDef.getOp();
        return opName.startsWith("Const");
    }

    /**
     * Collects the names of the nodes this node has control dependencies on.
     *
     * <p>Each input accepted by {@code isControlDependency} is converted with {@code getNodeName}. The
     * inputs are kept in their original order.
     *
     * @param node the node to inspect
     * @return the dependency node names, or {@code null} when the node has no inputs or no
     *         control-dependency inputs (callers rely on {@code null} rather than an empty list)
     */
    @Override
    public List<String> getControlDependencies(NodeDef node) {
        final int inputCount = node.getInputCount();
        if (inputCount == 0) {
            return null;
        }
        List<String> controlDeps = null;
        for (int idx = 0; idx < inputCount; idx++) {
            String input = node.getInput(idx);
            if (this.isControlDependency(input)) {
                if (controlDeps == null) {
                    // Allocate lazily so that "no control dependencies" stays null.
                    controlDeps = new ArrayList<>();
                }
                controlDeps.add(this.getNodeName(input));
            }
        }
        return controlDeps;
    }

    /**
     * Converts the tensor stored in the node's {@code VALUE_ATTR_KEY} attribute into an
     * {@link INDArray}.
     *
     * @param tensorName unused by this implementation
     * @param node       the node that may carry a tensor value attribute
     * @param graph      unused by this implementation
     * @return the mapped array, or {@code null} if the node has no value attribute
     */
    @Override
    public INDArray getNDArrayFromTensor(String tensorName, NodeDef node, GraphDef graph) {
        if (node.getAttrMap().containsKey(VALUE_ATTR_KEY)) {
            TensorProto valueTensor = node.getAttrOrThrow(VALUE_ATTR_KEY).getTensor();
            return this.mapTensorProto(valueTensor);
        }
        return null;
    }

    /**
     * Converts a TensorFlow tensor proto into an {@link INDArray} using the mapper selected by
     * {@code TFTensorMappers.newMapper}.
     *
     * @param tfTensor the tensor to convert
     * @return the converted array
     * @throws RuntimeException if no mapper exists for the tensor's data type
     */
    public INDArray mapTensorProto(TensorProto tfTensor) {
        TFTensorMapper<?, ?> mapper = TFTensorMappers.newMapper(tfTensor);
        if (mapper != null) {
            return mapper.toNDArray();
        }
        throw new RuntimeException("Not implemented datatype: " + tfTensor.getDtype());
    }

    /**
     * Writes a 16-bit value, held in the low 16 bits of an int, into element {@code idx} of the
     * array's backing buffer. Each element occupies two bytes.
     *
     * <p>NOTE(review): the high byte is written first (big-endian layout) regardless of the
     * platform's native byte order. Confirm this matches how half-precision data is read back on
     * little-endian machines.
     *
     * @param arr              the array whose raw buffer is written
     * @param idx              element index; the byte offset is {@code 2 * idx}
     * @param bytesAsPaddedInt the 16-bit pattern stored in the low bits of an int
     */
    protected static void setFloat16ValueFromInt(INDArray arr, int idx, int bytesAsPaddedInt) {
        ByteBuffer raw = arr.data().pointer().asByteBuffer();
        int byteOffset = idx * 2;
        byte highByte = (byte) ((bytesAsPaddedInt >> 8) & 0xFF);
        byte lowByte = (byte) (bytesAsPaddedInt & 0xFF);
        raw.put(byteOffset, highByte);
        raw.put(byteOffset + 1, lowByte);
    }

    /**
     * Determines a node's shape from its attributes.
     *
     * <p>The {@code SHAPE_KEY} attribute takes precedence. Otherwise the shape of the tensor stored in
     * the {@code VALUE_ATTR_KEY} attribute is used.
     *
     * @param tensorProto the node to inspect
     * @return the dimension sizes, or {@code null} if neither attribute is present
     */
    @Override
    public long[] getShapeFromTensor(NodeDef tensorProto) {
        if (tensorProto.containsAttr(SHAPE_KEY)) {
            TensorShapeProto declaredShape = tensorProto.getAttrOrThrow(SHAPE_KEY).getShape();
            return this.shapeFromShapeProto(declaredShape);
        } else if (tensorProto.containsAttr(VALUE_ATTR_KEY)) {
            TensorProto value = tensorProto.getAttrOrThrow(VALUE_ATTR_KEY).getTensor();
            return this.shapeFromShapeProto(value.getTensorShape());
        }
        return null;
    }

    /**
     * Returns the set of op names to ignore during import, as held in this mapper's
     * {@code graphMapper} field.
     *
     * @return the ignored-op set (the field itself, not a copy)
     */
    @Override
    public Set<String> opsToIgnore() {
        Set<String> ignored = this.graphMapper;
        return ignored;
    }

    /**
     * Returns the raw input string at the given position of the node's input list.
     *
     * @param node  the node to read from
     * @param index 0-based input position
     * @return the input name exactly as stored in the node (including any port or control prefix)
     */
    @Override
    public String getInputFromNode(NodeDef node, int index) {
        String rawInput = node.getInput(index);
        return rawInput;
    }

    /**
     * Counts all inputs of a node, including any control-dependency inputs.
     *
     * @param nodeDef the node to inspect
     * @return the size of the node's input list
     */
    @Override
    public int numInputsFor(NodeDef nodeDef) {
        return nodeDef.getInputList().size();
    }

    /**
     * Copies the dimension sizes of a TensorFlow shape proto into a {@code long[]}.
     *
     * @param tensorShapeProto the shape to read
     * @return one entry per dimension, in order, with sizes copied verbatim
     */
    private long[] shapeFromShapeProto(TensorShapeProto tensorShapeProto) {
        long[] dims = new long[tensorShapeProto.getDimCount()];
        int pos = 0;
        for (TensorShapeProto.Dim dim : tensorShapeProto.getDimList()) {
            dims[pos++] = dim.getSize();
        }
        return dims;
    }

    /**
     * Partitions the nodes preceding an if-style node into false-body, true-body and condition nodes.
     *
     * <p>The graph's node list is walked backwards starting at {@code from}'s position:
     * <ul>
     *   <li>nodes seen before reaching the node named by {@code from}'s input 1 go to the false body;</li>
     *   <li>from that node until a node whose name contains {@code "pred_id"}, nodes go to the true body;</li>
     *   <li>after that, nodes are collected as condition nodes while they are already referenced by a
     *       previously collected node (or contain {@code "pred_id"}); the walk stops at the first
     *       unreferenced node.</li>
     * </ul>
     * Each list is returned in forward graph order.
     *
     * <p>Scope names are prefixed with a random UUID and the first name-scope component of the
     * true-branch input. If that input name contains no '/', the whole name is used; previously this
     * case threw {@code StringIndexOutOfBoundsException}.
     *
     * <p>NOTE(review): referenced names are recorded as raw input strings (possibly with ":N" or "^"
     * decorations) and compared against plain node names. Confirm this is intended.
     *
     * @param from  the if-style node; must have at least two inputs
     * @param graph the graph containing {@code from}
     * @return the partitioned nodes and scope names
     */
    public IfImportState nodesForIf(NodeDef from, GraphDef graph) {
        int currNodeIndex = graph.getNodeList().indexOf(from);
        String trueDefName = from.getInput(1);

        int slash = trueDefName.indexOf('/');
        String rootScope = slash >= 0 ? trueDefName.substring(0, slash) : trueDefName;
        String scopeName = UUID.randomUUID().toString() + "-" + rootScope;
        String trueDefScopeName = scopeName + "-true-scope";
        String falseDefScopeName = scopeName + "-false-scope";

        boolean onFalseDefinition = true;
        boolean onTrueDefinition = false;
        List<NodeDef> falseBodyNodes = new ArrayList<>();
        List<NodeDef> trueBodyNodes = new ArrayList<>();
        List<NodeDef> conditionNodes = new ArrayList<>();
        Set<String> seenNames = new LinkedHashSet<>();

        for (int i = currNodeIndex; i >= 0; --i) {
            NodeDef node = graph.getNode(i);
            String name = node.getName();
            boolean isFrom = node.equals(from);

            if (name.equals(trueDefName)) {
                onFalseDefinition = false;
                onTrueDefinition = true;
            }
            if (name.contains("pred_id")) {
                onTrueDefinition = false;
            }
            if (onTrueDefinition && !isFrom) {
                trueBodyNodes.add(node);
                continue;
            }
            if (onFalseDefinition && !isFrom) {
                falseBodyNodes.add(node);
                continue;
            }
            if (isFrom) {
                continue;
            }
            // Condition region: stop at the first node no collected node refers to.
            if (!seenNames.contains(name) && !name.contains("pred_id")) {
                break;
            }
            for (int inputIdx = 0; inputIdx < node.getInputCount(); ++inputIdx) {
                seenNames.add(node.getInput(inputIdx));
            }
            seenNames.add(name);
            conditionNodes.add(node);
        }

        // Nodes were gathered walking backwards; restore forward graph order.
        Collections.reverse(falseBodyNodes);
        Collections.reverse(trueBodyNodes);
        Collections.reverse(conditionNodes);
        return IfImportState.builder()
                .condNodes(conditionNodes)
                .falseNodes(falseBodyNodes)
                .trueNodes(trueBodyNodes)
                .falseBodyScopeName(falseDefScopeName)
                .trueBodyScopeName(trueDefScopeName)
                .conditionBodyScopeName(scopeName)
                .build();
    }
}

