/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.mxnet.jna;

import ai.djl.Device;
import ai.djl.mxnet.engine.MxNDArray;
import ai.djl.mxnet.engine.MxNDManager;
import ai.djl.mxnet.jna.JnaUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.SparseFormat;
import ai.djl.training.Trainer;
import ai.djl.util.PairList;
import com.sun.jna.Pointer;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A native MXNet operator handle together with its name and declared arguments.
 *
 * <p>Instances are created within this package (the constructor is package-private) and invoke
 * the operator through {@link JnaUtils#imperativeInvoke}.
 */
public class FunctionInfo {
    // NOTE(review): this logger is registered under Trainer's category, not FunctionInfo's, so
    // the debug-only device check below follows Trainer's log level. Presumably intentional;
    // confirm before renaming, since it changes which logging config enables the check.
    private static final Logger logger = LoggerFactory.getLogger(Trainer.class);

    private final Pointer handle;
    private final String name;
    private final PairList<String, String> arguments;

    /**
     * Creates a function description.
     *
     * @param pointer the native handle passed to {@link JnaUtils#imperativeInvoke}
     * @param functionName the operator name returned by {@link #getFunctionName()}
     * @param arguments argument names (keys) paired with their type strings (values)
     */
    FunctionInfo(Pointer pointer, String functionName, PairList<String, String> arguments) {
        this.handle = pointer;
        this.name = functionName;
        this.arguments = arguments;
    }

    /**
     * Invokes the operator, passing caller-supplied output arrays to the native call.
     *
     * @param manager not used by this overload
     * @param src the input arrays
     * @param dest the output arrays; a {@code null} value is tolerated by the device check and
     *     forwarded unchanged to the native call
     * @param params operator parameters forwarded to the native call
     * @return the number of entries in the result of {@link JnaUtils#imperativeInvoke}
     */
    public int invoke(NDManager manager, NDArray[] src, NDArray[] dest, PairList<String, ?> params) {
        checkDevices(src);
        checkDevices(dest);
        return JnaUtils.imperativeInvoke(handle, src, dest, params).size();
    }

    /**
     * Invokes the operator and wraps every native output in an array created through {@code
     * manager}.
     *
     * <p>Outputs whose reported format is not {@link SparseFormat#DENSE} are created with that
     * format; the others are created from the handle alone.
     *
     * @param manager the manager used to create the output arrays; must be an {@link MxNDManager}
     * @param src the input arrays
     * @param params operator parameters forwarded to the native call
     * @return one array per entry returned by the native call
     * @throws ClassCastException if {@code manager} is not an {@link MxNDManager}
     */
    public NDArray[] invoke(NDManager manager, NDArray[] src, PairList<String, ?> params) {
        checkDevices(src);
        PairList<Pointer, SparseFormat> outputs =
                JnaUtils.imperativeInvoke(handle, src, null, params);
        MxNDManager mxManager = (MxNDManager) manager;
        return outputs.stream()
                .map(
                        pair ->
                                pair.getValue() == SparseFormat.DENSE
                                        ? mxManager.create(pair.getKey())
                                        : mxManager.create(pair.getKey(), pair.getValue()))
                .toArray(MxNDArray[]::new);
    }

    /**
     * Returns the operator name.
     *
     * @return the name supplied at construction
     */
    public String getFunctionName() {
        return name;
    }

    /**
     * Returns the names of the operator's declared arguments.
     *
     * @return the keys of the argument list supplied at construction
     */
    public List<String> getArgumentNames() {
        return arguments.keys();
    }

    /**
     * Returns the type strings of the operator's declared arguments.
     *
     * @return the values of the argument list supplied at construction
     */
    public List<String> getArgumentTypes() {
        return arguments.values();
    }

    /**
     * Logs a warning if the given arrays are not all on the same device.
     *
     * <p>The comparison runs only when debug logging is enabled and is skipped otherwise. At most
     * one warning is emitted per call, however many arrays mismatch.
     *
     * @param arrays the arrays to check; {@code null} or fewer than two elements is a no-op
     */
    private void checkDevices(NDArray[] arrays) {
        if (arrays == null || arrays.length < 2 || !logger.isDebugEnabled()) {
            return;
        }
        Device device = arrays[0].getDevice();
        for (int i = 1; i < arrays.length; ++i) {
            if (!device.equals(arrays[i].getDevice())) {
                logger.warn(
                        "Please make sure all the NDArrays are in the same device. You can call toDevice() to move the NDArray to the desired Device.");
                // One warning is enough; repeating it for every mismatched array only adds noise.
                return;
            }
        }
    }
}

