/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.tensorflow.conversion.graphrunner;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.nd4j.TFGraphRunnerService;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.tensorflow.conversion.TensorDataType;
import org.nd4j.tensorflow.conversion.graphrunner.GraphRunner;

public class GraphRunnerServiceProvider
implements TFGraphRunnerService {
    private GraphRunner graphRunner;
    Map<String, INDArray> inputs;

    public TFGraphRunnerService init(List<String> inputNames, List<String> outputNames, byte[] graphBytes, Map<String, INDArray> constants, Map<String, String> inputDataTypes) {
        if (inputNames.size() != inputDataTypes.size()) {
            throw new IllegalArgumentException("inputNames.size() != inputDataTypes.size()");
        }
        HashMap<String, TensorDataType> convertedDataTypes = new HashMap<String, TensorDataType>();
        for (int i = 0; i < inputNames.size(); ++i) {
            convertedDataTypes.put(inputNames.get(i), TensorDataType.fromProtoValue(inputDataTypes.get(inputNames.get(i))));
        }
        HashMap<String, INDArray> castConstants = new HashMap<String, INDArray>();
        for (Map.Entry<String, INDArray> e : constants.entrySet()) {
            DataType requiredDtype = TensorDataType.toNd4jType(TensorDataType.fromProtoValue(inputDataTypes.get(e.getKey())));
            castConstants.put(e.getKey(), e.getValue().castTo(requiredDtype));
        }
        this.inputs = castConstants;
        this.graphRunner = GraphRunner.builder().inputNames(inputNames).outputNames(outputNames).graphBytes(graphBytes).inputDataTypes(convertedDataTypes).build();
        return this;
    }

    public Map<String, INDArray> run(Map<String, INDArray> inputs) {
        if (this.graphRunner == null) {
            throw new RuntimeException("GraphRunner not initialized.");
        }
        this.inputs.putAll(inputs);
        return this.graphRunner.run(this.inputs);
    }
}

