/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.tensorflow.engine;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractSymbolBlock;
import ai.djl.tensorflow.engine.SavedModelBundle;
import ai.djl.tensorflow.engine.TfNDArray;
import ai.djl.tensorflow.engine.TfNDManager;
import ai.djl.tensorflow.engine.javacpp.JavacppUtils;
import ai.djl.training.ParameterStore;
import ai.djl.util.Pair;
import ai.djl.util.PairList;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.bytedeco.javacpp.Pointer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.internal.c_api.TFE_TensorHandle;
import org.tensorflow.internal.c_api.TF_Graph;
import org.tensorflow.internal.c_api.TF_Operation;
import org.tensorflow.internal.c_api.TF_Session;
import org.tensorflow.internal.c_api.TF_Tensor;
import org.tensorflow.proto.framework.DataType;
import org.tensorflow.proto.framework.MetaGraphDef;
import org.tensorflow.proto.framework.SignatureDef;
import org.tensorflow.proto.framework.TensorInfo;
import org.tensorflow.proto.framework.TensorShapeProto;

public class TfSymbolBlock
extends AbstractSymbolBlock
implements AutoCloseable {
    private static final Logger logger = LoggerFactory.getLogger(TfSymbolBlock.class);
    private static final byte VERSION = 1;
    private SavedModelBundle bundle;
    private TF_Graph graphHandle;
    private TF_Session sessionHandle;
    private SignatureDef servingDefault;
    private PairList<String, Shape> inputDescriptions;
    private PairList<String, Shape> outputDescriptions;
    private TF_Operation[] inputOpHandles;
    private int[] inputOpIndices;
    private TF_Operation[] outputOpHandles;
    private int[] outputOpIndices;
    private TF_Operation[] targetOpHandles;

    public TfSymbolBlock(SavedModelBundle bundle, String signatureDefKey) {
        super((byte)1);
        this.bundle = bundle;
        this.graphHandle = bundle.getGraph();
        this.sessionHandle = bundle.getSession();
        MetaGraphDef metaGraphDef = bundle.getMetaGraphDef();
        Map signatureDefMap = metaGraphDef.getSignatureDefMap();
        if (signatureDefMap.containsKey(signatureDefKey)) {
            this.servingDefault = (SignatureDef)signatureDefMap.get(signatureDefKey);
        } else {
            Set keys = signatureDefMap.keySet();
            logger.warn("SignatureDefKey: " + signatureDefKey + "not found in Saved Model Bundle.Available keys: " + String.join((CharSequence)" ", keys) + "Please use .optOption(\"SignatureDefKey\", \"value\") with Criteria.builder to load the model.Normally the value is \"default\" for TF1.x models and \"serving_default\" for TF2.x models. Refer to: https://www.tensorflow.org/guide/saved_modelLoading the model using next available key.");
            this.servingDefault = (SignatureDef)signatureDefMap.get(keys.iterator().next());
        }
        this.describeInput();
        this.describeOutput();
        this.targetOpHandles = new TF_Operation[0];
    }

    public void removeLastBlock() {
        throw new UnsupportedOperationException("Not supported for TensorFlow Engine");
    }

    protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, PairList<String, Object> params) {
        TF_Tensor[] inputTensorHandles = new TF_Tensor[this.inputDescriptions.size()];
        for (int i = 0; i < this.inputDescriptions.size(); ++i) {
            String inputName = (String)this.inputDescriptions.get(i).getKey();
            TfNDArray currentNDArray = (TfNDArray)((Object)inputs.get(i));
            String name = currentNDArray.getName();
            if (name == null || name.isEmpty() || name.equals(inputName)) {
                inputTensorHandles[i] = JavacppUtils.resolveTFETensor((TFE_TensorHandle)currentNDArray.getHandle());
                continue;
            }
            for (NDArray array : inputs) {
                if (!array.getName().equals(inputName)) continue;
                inputTensorHandles[i] = JavacppUtils.resolveTFETensor((TFE_TensorHandle)((TfNDArray)array).getHandle());
            }
        }
        TF_Tensor[] outputs = JavacppUtils.runSession(this.sessionHandle, null, inputTensorHandles, this.inputOpHandles, this.inputOpIndices, this.outputOpHandles, this.outputOpIndices, this.targetOpHandles);
        TfNDManager tfNDManager = (TfNDManager)inputs.head().getManager();
        NDList resultNDList = new NDList();
        for (int i = 0; i < outputs.length; ++i) {
            TfNDArray array = new TfNDArray(tfNDManager, JavacppUtils.createTFETensor(outputs[i]));
            array.setName((String)this.outputDescriptions.get(i).getKey());
            resultNDList.add((Object)array);
        }
        Arrays.stream(inputTensorHandles).forEach(Pointer::close);
        Arrays.stream(outputs).forEach(Pointer::close);
        return resultNDList;
    }

    public void initialize(NDManager manager, ai.djl.ndarray.types.DataType dataType, Shape ... inputShapes) {
        throw new IllegalStateException("TfSymbolBlock can't be initialized");
    }

    public boolean isInitialized() {
        return this.bundle != null;
    }

    public final PairList<String, Shape> describeInput() {
        if (this.inputDescriptions == null) {
            this.inputDescriptions = new PairList();
            Map inputsMap = this.servingDefault.getInputsMap();
            ArrayList keys = new ArrayList(inputsMap.keySet());
            Collections.sort(keys);
            this.inputOpHandles = new TF_Operation[keys.size()];
            this.inputOpIndices = new int[keys.size()];
            for (int i = 0; i < keys.size(); ++i) {
                TensorInfo tensorInfo = (TensorInfo)inputsMap.get(keys.get(i));
                TensorShapeProto shapeProto = tensorInfo.getTensorShape();
                this.inputDescriptions.add(keys.get(i), (Object)new Shape(shapeProto.getDimList().stream().mapToLong(TensorShapeProto.Dim::getSize).toArray()));
                Pair<TF_Operation, Integer> pair = JavacppUtils.getGraphOperationByName(this.graphHandle, tensorInfo.getName());
                this.inputOpHandles[i] = (TF_Operation)pair.getKey();
                this.inputOpIndices[i] = (Integer)pair.getValue();
            }
        }
        return this.inputDescriptions;
    }

    public final PairList<String, Shape> describeOutput() {
        if (this.outputDescriptions == null) {
            this.outputDescriptions = new PairList();
            Map outputsMap = this.servingDefault.getOutputsMap();
            ArrayList keys = new ArrayList(outputsMap.keySet());
            Collections.sort(keys);
            ArrayList<Object> outputOpHandlesList = new ArrayList<Object>();
            ArrayList<Object> outputOpIndicesList = new ArrayList<Object>();
            for (String key : keys) {
                TensorInfo tensorInfo = (TensorInfo)outputsMap.get(key);
                TensorShapeProto shapeProto = tensorInfo.getTensorShape();
                if (tensorInfo.getDtype() == DataType.DT_STRING) continue;
                this.outputDescriptions.add((Object)key, (Object)new Shape(shapeProto.getDimList().stream().mapToLong(TensorShapeProto.Dim::getSize).toArray()));
                Pair<TF_Operation, Integer> pair = JavacppUtils.getGraphOperationByName(this.graphHandle, tensorInfo.getName());
                outputOpHandlesList.add(pair.getKey());
                outputOpIndicesList.add(pair.getValue());
            }
            this.outputOpHandles = outputOpHandlesList.toArray(new TF_Operation[0]);
            this.outputOpIndices = outputOpIndicesList.stream().mapToInt(i -> i).toArray();
        }
        return this.outputDescriptions;
    }

    public Shape[] getOutputShapes(Shape[] inputShapes) {
        return new Shape[0];
    }

    @Override
    public void close() {
        if (this.bundle != null) {
            this.bundle.close();
        }
        Arrays.stream(this.inputOpHandles).forEach(Pointer::close);
        Arrays.stream(this.outputOpHandles).forEach(Pointer::close);
        Arrays.stream(this.targetOpHandles).forEach(Pointer::close);
    }
}

