/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.pytorch.jni;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.pytorch.engine.PtNDArray;
import ai.djl.pytorch.engine.PtNDManager;
import ai.djl.pytorch.engine.PtSymbolBlock;
import ai.djl.pytorch.jni.IValue;
import ai.djl.pytorch.jni.PyTorchLibrary;
import ai.djl.util.NativeResource;
import ai.djl.util.Pair;
import ai.djl.util.PairList;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Pattern;

/**
 * Helpers for packing DJL {@link NDList} inputs into PyTorch {@link IValue}s and running the
 * forward pass of a {@link PtSymbolBlock} through {@code PyTorchLibrary.LIB.moduleForward}.
 */
public final class IValueUtils {

    /** Array names of the form {@code word[]} are collected into one list IValue. */
    private static final Pattern LIST_NAME = Pattern.compile("\\w+\\[]");

    private IValueUtils() {
    }

    /**
     * Runs the block's forward pass on {@code inputs} and converts the result back to an NDList.
     *
     * <p>The inputs are first packed into IValues (see {@link #getInputs(NDList)} for the naming
     * rules). Those temporary IValues are always closed, even if the native call throws, and the
     * native result is closed once it has been converted.
     *
     * @param block the block whose module is run
     * @param inputs the input arrays; must be non-empty because the first array's manager is used
     *     for the conversion of the outputs
     * @param isTrain passed through unchanged to the native {@code moduleForward} call
     * @return the forward result converted with {@code IValue.toNDList} using the first input's
     *     manager
     * @throws IndexOutOfBoundsException if {@code inputs} is empty
     */
    public static NDList forward(PtSymbolBlock block, NDList inputs, boolean isTrain) {
        // Resolve the output manager before allocating any native handle, so an empty input list
        // fails fast instead of running the module and leaking its result handle.
        PtNDManager manager = (PtNDManager) inputs.get(0).getManager();
        IValue[] iValues = getInputs(inputs);
        long result;
        try {
            long[] iValueHandles = Arrays.stream(iValues).mapToLong(IValue::getHandle).toArray();
            result = PyTorchLibrary.LIB.moduleForward(block.getHandle(), iValueHandles, isTrain);
        } finally {
            // The packed inputs are temporaries owned by this method.
            closeAll(iValues);
        }
        try (IValue iValue = new IValue(result)) {
            return iValue.toNDList(manager);
        }
    }

    /**
     * Runs the block's forward pass on already-built IValues, always with {@code isTrain = false}.
     *
     * <p>The inputs are not closed by this method.
     *
     * @param block the block whose module is run
     * @param inputs the input values, passed to the native call in the given order
     * @return a new IValue wrapping the native result; nothing here closes it, so the caller is
     *     responsible for doing so
     */
    public static IValue forward(PtSymbolBlock block, IValue... inputs) {
        long[] handles = Arrays.stream(inputs).mapToLong(IValue::getHandle).toArray();
        return new IValue(PyTorchLibrary.LIB.moduleForward(block.getHandle(), handles, false));
    }

    /**
     * Returns the index in {@code list} of the group named {@code key}. The first time a key is
     * seen, a new empty group is appended to {@code list} and its index is recorded in {@code map}.
     */
    private static int addToMap(
            Map<String, Integer> map, String key, List<PairList<String, PtNDArray>> list) {
        return map.computeIfAbsent(
                key,
                k -> {
                    list.add(new PairList<>());
                    return list.size() - 1;
                });
    }

    /**
     * Packs the arrays of {@code ndList} into one IValue per group.
     *
     * <ul>
     *   <li>A name containing {@code '.'} is split at the first dot. Arrays sharing the prefix
     *       become one {@link IValue#stringMapFrom string-map} IValue keyed by the remainder.
     *   <li>A name matching {@code \w+[]} groups all arrays with that exact name into one {@link
     *       IValue#listFrom list} IValue.
     *   <li>Any other array, including unnamed ones, becomes its own single-array IValue.
     * </ul>
     *
     * <p>Groups appear in the order in which their first array occurs in {@code ndList}. If
     * building any IValue fails, the ones already built are closed before the exception
     * propagates.
     *
     * @param ndList the input arrays; each element must be a {@link PtNDArray}
     * @return the packed IValues; the caller owns them and must close them
     */
    private static IValue[] getInputs(NDList ndList) {
        List<PairList<String, PtNDArray>> groups = groupByName(ndList);
        IValue[] ret = new IValue[groups.size()];
        boolean success = false;
        try {
            for (int i = 0; i < groups.size(); ++i) {
                ret[i] = toIValue(groups.get(i));
            }
            success = true;
            return ret;
        } finally {
            if (!success) {
                closeAll(ret);
            }
        }
    }

    /**
     * Splits {@code ndList} into groups according to the naming rules of {@link
     * #getInputs(NDList)}.
     *
     * <p>Every group contains at least one entry. Its key is {@code null} for a stand-alone array,
     * {@code "[]"} for list entries, or the map key otherwise.
     *
     * <p>NOTE(review): a dotted name whose prefix itself looks like {@code a[]} (e.g. {@code
     * "a[].b"}) lands in the same group as arrays named {@code "a[]"}. The first entry's key then
     * decides how the whole group is converted. Confirm whether such names can occur.
     */
    private static List<PairList<String, PtNDArray>> groupByName(NDList ndList) {
        List<PairList<String, PtNDArray>> groups = new ArrayList<>();
        Map<String, Integer> indexMap = new HashMap<>();
        for (NDArray array : ndList) {
            PtNDArray ptArray = (PtNDArray) array;
            String name = array.getName();
            if (name != null && name.contains(".")) {
                String[] parts = name.split("\\.", 2);
                groups.get(addToMap(indexMap, parts[0], groups)).add(parts[1], ptArray);
            } else if (name != null && LIST_NAME.matcher(name).matches()) {
                groups.get(addToMap(indexMap, name, groups)).add("[]", ptArray);
            } else {
                PairList<String, PtNDArray> single = new PairList<>();
                single.add(null, ptArray);
                groups.add(single);
            }
        }
        return groups;
    }

    /** Converts one non-empty group produced by {@link #groupByName(NDList)} into an IValue. */
    private static IValue toIValue(PairList<String, PtNDArray> group) {
        String key = group.get(0).getKey();
        if (key == null) {
            return IValue.from(group.get(0).getValue());
        }
        if ("[]".equals(key)) {
            return IValue.listFrom(group.values().toArray(new PtNDArray[0]));
        }
        // LinkedHashMap makes the map handed to stringMapFrom iterate in input order rather than
        // hash order. A repeated key keeps the last array, as a plain put() does.
        Map<String, PtNDArray> map = new LinkedHashMap<>();
        for (Pair<String, PtNDArray> pair : group) {
            map.put(pair.getKey(), pair.getValue());
        }
        return IValue.stringMapFrom(map);
    }

    /** Closes every non-null IValue in {@code values}; null slots come from a failed build. */
    private static void closeAll(IValue[] values) {
        for (IValue value : values) {
            if (value != null) {
                value.close();
            }
        }
    }
}

