/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.onnxruntime.engine;

import ai.djl.engine.EngineException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractSymbolBlock;
import ai.djl.nn.ParameterList;
import ai.djl.onnxruntime.engine.OrtNDArray;
import ai.djl.onnxruntime.engine.OrtNDManager;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import ai.onnxruntime.OnnxJavaType;
import ai.onnxruntime.OnnxMap;
import ai.onnxruntime.OnnxModelMetadata;
import ai.onnxruntime.OnnxSequence;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public class OrtSymbolBlock
extends AbstractSymbolBlock
implements AutoCloseable {
    private OrtSession session;
    private OrtNDManager manager;

    public OrtSymbolBlock(OrtSession session, OrtNDManager manager) {
        this.session = session;
        this.manager = manager;
        manager.attachInternal(NDManager.nextUid(), new AutoCloseable[]{this});
    }

    public void removeLastBlock() {
        throw new UnsupportedOperationException("ONNX Runtime not supported");
    }

    protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, PairList<String, Object> params) {
        ArrayList inputNames = new ArrayList(this.session.getInputNames());
        if (inputs.size() != inputNames.size()) {
            throw new IllegalArgumentException("Input mismatch, looking for: " + inputNames);
        }
        ConcurrentHashMap<String, OnnxTensor> container = new ConcurrentHashMap<String, OnnxTensor>();
        OrtNDManager sub = (OrtNDManager)this.manager.newSubManager();
        try {
            if (((NDArray)inputs.get(0)).getName() != null) {
                for (NDArray input : inputs) {
                    String name = input.getName();
                    if (name == null) {
                        throw new IllegalArgumentException("All or none of input tensors must have a name.");
                    }
                    if (!inputNames.contains(name)) {
                        throw new IllegalArgumentException("Invalid input tensor name: " + name);
                    }
                    OrtNDArray ortNDArray = sub.from(input);
                    container.put(name, ortNDArray.getTensor());
                }
            } else {
                for (int i = 0; i < inputNames.size(); ++i) {
                    OrtNDArray ortNDArray = sub.from((NDArray)inputs.get(i));
                    container.put((String)inputNames.get(i), ortNDArray.getTensor());
                }
            }
            OrtSession.Result results = this.session.run(container);
            NDList ret = this.evaluateOutput(results);
            ret.attach(inputs.head().getManager());
            NDList nDList = ret;
            if (sub != null) {
                sub.close();
            }
            return nDList;
        }
        catch (Throwable throwable) {
            try {
                if (sub != null) {
                    try {
                        sub.close();
                    }
                    catch (Throwable throwable2) {
                        throwable.addSuppressed(throwable2);
                    }
                }
                throw throwable;
            }
            catch (OrtException e) {
                throw new EngineException((Throwable)e);
            }
        }
    }

    public PairList<String, Shape> describeInput() {
        PairList result = new PairList();
        for (String name : this.session.getInputNames()) {
            result.add((Object)name, null);
        }
        return result;
    }

    public Map<String, String> getCustomMetadata() {
        try {
            OnnxModelMetadata modelMetadata = this.session.getMetadata();
            return modelMetadata.getCustomMetadata();
        }
        catch (OrtException e) {
            throw new EngineException((Throwable)e);
        }
    }

    private NDList evaluateOutput(OrtSession.Result results) {
        NDList output = new NDList();
        for (Map.Entry r : results) {
            OnnxValue value = (OnnxValue)r.getValue();
            if (value instanceof OnnxTensor) {
                OrtNDArray array = this.manager.createInternal((OnnxTensor)value);
                array.setName((String)r.getKey());
                output.add((Object)array);
                continue;
            }
            if (value instanceof OnnxSequence) {
                OnnxSequence seq = (OnnxSequence)value;
                if (seq.getInfo().isSequenceOfMaps()) {
                    NDArray array = this.seq2Nd(seq);
                    array.setName((String)r.getKey());
                    output.add((Object)array);
                    continue;
                }
                output.addAll(this.seq2NdList(seq));
                continue;
            }
            throw new UnsupportedOperationException("Unsupported output type! " + (String)r.getKey());
        }
        return output;
    }

    private NDArray seq2Nd(OnnxSequence seq) {
        try {
            List values = seq.getValue();
            ArrayList finalData = new ArrayList();
            OnnxJavaType type = seq.getInfo().mapInfo.valueType;
            for (OnnxMap map : values) {
                finalData.addAll(map.getValue().values());
            }
            Shape shape = new Shape(new long[]{values.size(), finalData.size() / values.size()});
            ByteBuffer buffer = ByteBuffer.allocate(finalData.size() * type.size);
            switch (type) {
                case FLOAT: {
                    finalData.forEach(ele -> buffer.putFloat(((Float)ele).floatValue()));
                    buffer.rewind();
                    return this.manager.create(buffer.asFloatBuffer(), shape, DataType.FLOAT32);
                }
                case DOUBLE: {
                    finalData.forEach(ele -> buffer.putDouble((Double)ele));
                    buffer.rewind();
                    return this.manager.create(buffer.asDoubleBuffer(), shape, DataType.FLOAT64);
                }
                case BOOL: 
                case INT8: {
                    DataType dp = type == OnnxJavaType.BOOL ? DataType.BOOLEAN : DataType.INT8;
                    finalData.forEach(ele -> buffer.put((Byte)ele));
                    buffer.rewind();
                    return this.manager.create(buffer, shape, dp);
                }
                case INT32: {
                    finalData.forEach(ele -> buffer.putInt((Integer)ele));
                    buffer.rewind();
                    return this.manager.create(buffer.asIntBuffer(), shape, DataType.INT32);
                }
                case INT64: {
                    finalData.forEach(ele -> buffer.putLong((Long)ele));
                    buffer.rewind();
                    return this.manager.create(buffer.asLongBuffer(), shape, DataType.INT64);
                }
            }
            throw new UnsupportedOperationException("type is not supported: " + type);
        }
        catch (OrtException e) {
            throw new EngineException((Throwable)e);
        }
    }

    private NDList seq2NdList(OnnxSequence sequence) {
        try {
            NDList list = new NDList();
            for (OnnxValue value : sequence.getValue()) {
                list.add((Object)this.manager.createInternal((OnnxTensor)value));
            }
            return list;
        }
        catch (OrtException e) {
            throw new EngineException((Throwable)e);
        }
    }

    @Override
    public void close() {
        if (this.session != null) {
            try {
                this.session.close();
                this.session = null;
            }
            catch (OrtException e) {
                throw new EngineException((Throwable)e);
            }
        }
    }

    public ParameterList getDirectParameters() {
        throw new UnsupportedOperationException("Not yet supported");
    }
}

