/*
 * Decompiled with CFR 0.152.
 */
package ai.onnxruntime;

import ai.onnxruntime.OnnxRuntime;
import ai.onnxruntime.OnnxTensorLike;
import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtAllocator;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import java.io.IOException;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.logging.Logger;

/**
 * A session wrapping a native ONNX Runtime training session, built from a training model, an
 * optional eval model, an optional optimizer model and a checkpoint.
 *
 * <p>The session takes ownership of the {@link OrtCheckpointState} it is created with: closing the
 * session also closes the checkpoint. The {@code closed} flag is not synchronized, so instances
 * should not be shared across threads without external synchronization.
 */
public final class OrtTrainingSession implements AutoCloseable {
    private final long nativeHandle;
    private final OrtAllocator allocator;
    private final OrtCheckpointState checkpoint;
    private final String trainPath;
    private final String evalPath;
    private final String optimizerPath;
    private final Set<String> trainInputNames;
    private final Set<String> trainOutputNames;
    private final Set<String> evalInputNames;
    private final Set<String> evalOutputNames;
    private boolean closed = false;

    /**
     * Creates the native training session and caches its input/output names.
     *
     * @param env the environment the session belongs to
     * @param allocator allocator used by native calls that return strings or outputs
     * @param options session options
     * @param checkpoint checkpoint state; owned (and closed) by this session afterwards
     * @param trainPath path to the training model
     * @param evalPath path to the eval model
     * @param optimizerPath path to the optimizer model
     * @throws OrtException if the native session cannot be created or queried
     */
    OrtTrainingSession(OrtEnvironment env, OrtAllocator allocator, OrtSession.SessionOptions options, OrtCheckpointState checkpoint, String trainPath, String evalPath, String optimizerPath) throws OrtException {
        this(createTrainingSession(OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, env.getNativeHandle(), options.getNativeHandle(), checkpoint.nativeHandle, trainPath, evalPath, optimizerPath), allocator, checkpoint, trainPath, evalPath, optimizerPath);
    }

    /**
     * Wraps an already-created native session handle. If querying the model input/output names
     * fails, the native session is released before the exception propagates so it is not leaked.
     */
    private OrtTrainingSession(long nativeHandle, OrtAllocator allocator, OrtCheckpointState checkpoint, String trainPath, String evalPath, String optimizerPath) throws OrtException {
        this.nativeHandle = nativeHandle;
        this.allocator = allocator;
        this.checkpoint = checkpoint;
        this.trainPath = trainPath;
        this.evalPath = evalPath;
        this.optimizerPath = optimizerPath;
        Set<String> trainIn;
        Set<String> trainOut;
        Set<String> evalIn;
        Set<String> evalOut;
        try {
            trainIn = toNameSet(getTrainInputNames(OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, nativeHandle, allocator.handle));
            trainOut = toNameSet(getTrainOutputNames(OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, nativeHandle, allocator.handle));
            evalIn = toNameSet(getEvalInputNames(OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, nativeHandle, allocator.handle));
            evalOut = toNameSet(getEvalOutputNames(OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, nativeHandle, allocator.handle));
        } catch (OrtException | RuntimeException e) {
            // The native session already exists at this point; nobody else holds its handle, so
            // release it here rather than leaking it when construction fails.
            try {
                closeSession(OnnxRuntime.ortTrainingApiHandle, nativeHandle);
            } catch (RuntimeException suppressed) {
                e.addSuppressed(suppressed);
            }
            throw e;
        }
        this.trainInputNames = trainIn;
        this.trainOutputNames = trainOut;
        this.evalInputNames = evalIn;
        this.evalOutputNames = evalOut;
    }

    /** Copies a native name array into an insertion-ordered, unmodifiable set. */
    private static Set<String> toNameSet(String[] names) {
        return Collections.unmodifiableSet(new LinkedHashSet<>(Arrays.asList(names)));
    }

    private static native long createTrainingSession(long var0, long var2, long var4, long var6, long var8, String var10, String var11, String var12);

    /** @return the training model's input names, in native order (unmodifiable) */
    public Set<String> getTrainInputNames() {
        return this.trainInputNames;
    }

    /** @return the training model's output names, in native order (unmodifiable) */
    public Set<String> getTrainOutputNames() {
        return this.trainOutputNames;
    }

    /** @return the eval model's input names, in native order (unmodifiable) */
    public Set<String> getEvalInputNames() {
        return this.evalInputNames;
    }

    /** @return the eval model's output names, in native order (unmodifiable) */
    public Set<String> getEvalOutputNames() {
        return this.evalOutputNames;
    }

    /**
     * Adds a float property to the checkpoint owned by this session.
     *
     * @throws IllegalStateException if the session is closed
     * @throws OrtException if the native call fails
     */
    public void addProperty(String name, float value) throws OrtException {
        this.checkClosed();
        this.checkpoint.addProperty(name, value);
    }

    /**
     * Adds an int property to the checkpoint owned by this session.
     *
     * @throws IllegalStateException if the session is closed
     * @throws OrtException if the native call fails
     */
    public void addProperty(String name, int value) throws OrtException {
        this.checkClosed();
        this.checkpoint.addProperty(name, value);
    }

    /**
     * Adds a string property to the checkpoint owned by this session.
     *
     * @throws IllegalStateException if the session is closed
     * @throws OrtException if the native call fails
     */
    public void addProperty(String name, String value) throws OrtException {
        this.checkClosed();
        this.checkpoint.addProperty(name, value);
    }

    /**
     * Reads a float property from the checkpoint owned by this session.
     *
     * @throws IllegalStateException if the session is closed
     * @throws OrtException if the native call fails
     */
    public float getFloatProperty(String name) throws OrtException {
        this.checkClosed();
        return this.checkpoint.getFloatProperty(this.allocator, name);
    }

    /**
     * Reads an int property from the checkpoint owned by this session.
     *
     * @throws IllegalStateException if the session is closed
     * @throws OrtException if the native call fails
     */
    public int getIntProperty(String name) throws OrtException {
        this.checkClosed();
        return this.checkpoint.getIntProperty(this.allocator, name);
    }

    /**
     * Reads a string property from the checkpoint owned by this session.
     *
     * @throws IllegalStateException if the session is closed
     * @throws OrtException if the native call fails
     */
    public String getStringProperty(String name) throws OrtException {
        this.checkClosed();
        return this.checkpoint.getStringProperty(this.allocator, name);
    }

    /** @throws IllegalStateException if this session has been closed */
    private void checkClosed() {
        if (this.closed) {
            throw new IllegalStateException("Trying to use a closed OrtTrainingSession");
        }
    }

    /**
     * Releases the native session and closes the owned checkpoint.
     *
     * @throws IllegalStateException if the session was already closed
     */
    @Override
    public void close() {
        if (this.closed) {
            throw new IllegalStateException("Trying to close an already closed OrtTrainingSession.");
        }
        this.closeSession(OnnxRuntime.ortTrainingApiHandle, this.nativeHandle);
        this.checkpoint.close();
        this.closed = true;
    }

    private native void closeSession(long var1, long var3);

    /**
     * Saves the owned checkpoint to {@code outputPath}.
     *
     * @param outputPath destination path, must not be null
     * @param saveOptimizer forwarded to the native save call
     * @throws IllegalStateException if the session is closed
     * @throws OrtException if the native call fails
     */
    public void saveCheckpoint(Path outputPath, boolean saveOptimizer) throws OrtException {
        this.checkClosed();
        this.checkpoint.saveCheckpoint(outputPath, saveOptimizer);
    }

    private native String[] getTrainInputNames(long var1, long var3, long var5, long var7) throws OrtException;

    private native String[] getTrainOutputNames(long var1, long var3, long var5, long var7) throws OrtException;

    private native String[] getEvalInputNames(long var1, long var3, long var5, long var7) throws OrtException;

    private native String[] getEvalOutputNames(long var1, long var3, long var5, long var7) throws OrtException;

    /**
     * Delegates to the native lazy-reset-grad call (presumably schedules a gradient reset for the
     * next train step — confirm against the ORT training API).
     *
     * @throws IllegalStateException if the session is closed
     * @throws OrtException if the native call fails
     */
    public void lazyResetGrad() throws OrtException {
        this.checkClosed();
        this.lazyResetGrad(OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, this.nativeHandle);
    }

    private native void lazyResetGrad(long var1, long var3, long var5) throws OrtException;

    /**
     * Passes {@code seed} to the native training API's seed setter. This is static and does not
     * depend on any session instance.
     *
     * @throws OrtException if the native call fails
     */
    public static void setSeed(long seed) throws OrtException {
        setSeed(OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, seed);
    }

    private static native void setSeed(long var0, long var2, long var4) throws OrtException;

    /** Runs a train step requesting every training output, with default run options. */
    public OrtSession.Result trainStep(Map<String, ? extends OnnxTensorLike> inputs) throws OrtException {
        return this.trainStep(inputs, this.trainOutputNames, Collections.emptyMap(), null);
    }

    /** Runs a train step requesting every training output, with the supplied run options. */
    public OrtSession.Result trainStep(Map<String, ? extends OnnxTensorLike> inputs, OrtSession.RunOptions runOptions) throws OrtException {
        return this.trainStep(inputs, this.trainOutputNames, Collections.emptyMap(), runOptions);
    }

    /** Runs a train step requesting only {@code requestedOutputs}. */
    public OrtSession.Result trainStep(Map<String, ? extends OnnxTensorLike> inputs, Set<String> requestedOutputs) throws OrtException {
        return this.trainStep(inputs, requestedOutputs, Collections.emptyMap(), null);
    }

    /** Runs a train step writing only into the caller-supplied {@code pinnedOutputs}. */
    public OrtSession.Result trainStep(Map<String, ? extends OnnxTensorLike> inputs, Map<String, ? extends OnnxValue> pinnedOutputs) throws OrtException {
        return this.trainStep(inputs, Collections.emptySet(), pinnedOutputs, null);
    }

    /**
     * Runs a single train step.
     *
     * @param inputs named inputs; each name must be a training input
     * @param requestedOutputs outputs to be allocated by the runtime
     * @param pinnedOutputs caller-supplied output values; names must not overlap {@code requestedOutputs}
     * @param runOptions run options, or null for defaults
     * @return the outputs, pinned outputs first, then requested outputs
     * @throws IllegalStateException if the session is closed
     * @throws OrtException if validation fails or the native call fails
     */
    public OrtSession.Result trainStep(Map<String, ? extends OnnxTensorLike> inputs, Set<String> requestedOutputs, Map<String, ? extends OnnxValue> pinnedOutputs, OrtSession.RunOptions runOptions) throws OrtException {
        this.checkClosed();
        RunArguments args = RunArguments.prepare(inputs, requestedOutputs, pinnedOutputs, this.trainInputNames, this.trainOutputNames);
        long runOptionsHandle = runOptions == null ? 0L : runOptions.getNativeHandle();
        boolean[] ownedByResult = this.trainStep(OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, this.nativeHandle, this.allocator.handle, args.inputNames, args.inputHandles, args.inputNames.length, args.outputNames, args.outputNames.length, args.outputValues, args.outputHandles, runOptionsHandle);
        return new OrtSession.Result(args.outputNames, args.outputValues, ownedByResult);
    }

    private native boolean[] trainStep(long var1, long var3, long var5, long var7, String[] var9, long[] var10, long var11, String[] var13, long var14, OnnxValue[] var16, long[] var17, long var18) throws OrtException;

    /** Runs an eval step requesting every eval output, with default run options. */
    public OrtSession.Result evalStep(Map<String, ? extends OnnxTensorLike> inputs) throws OrtException {
        return this.evalStep(inputs, this.evalOutputNames, Collections.emptyMap(), null);
    }

    /** Runs an eval step requesting every eval output, with the supplied run options. */
    public OrtSession.Result evalStep(Map<String, ? extends OnnxTensorLike> inputs, OrtSession.RunOptions runOptions) throws OrtException {
        return this.evalStep(inputs, this.evalOutputNames, Collections.emptyMap(), runOptions);
    }

    /** Runs an eval step requesting only {@code requestedOutputs}. */
    public OrtSession.Result evalStep(Map<String, ? extends OnnxTensorLike> inputs, Set<String> requestedOutputs) throws OrtException {
        return this.evalStep(inputs, requestedOutputs, Collections.emptyMap(), null);
    }

    /** Runs an eval step writing only into the caller-supplied {@code pinnedOutputs}. */
    public OrtSession.Result evalStep(Map<String, ? extends OnnxTensorLike> inputs, Map<String, ? extends OnnxValue> pinnedOutputs) throws OrtException {
        return this.evalStep(inputs, Collections.emptySet(), pinnedOutputs, null);
    }

    /**
     * Runs a single eval step.
     *
     * @param inputs named inputs; each name must be an eval input
     * @param requestedOutputs outputs to be allocated by the runtime
     * @param pinnedOutputs caller-supplied output values; names must not overlap {@code requestedOutputs}
     * @param runOptions run options, or null for defaults
     * @return the outputs, pinned outputs first, then requested outputs
     * @throws IllegalStateException if the session is closed
     * @throws OrtException if validation fails or the native call fails
     */
    public OrtSession.Result evalStep(Map<String, ? extends OnnxTensorLike> inputs, Set<String> requestedOutputs, Map<String, ? extends OnnxValue> pinnedOutputs, OrtSession.RunOptions runOptions) throws OrtException {
        this.checkClosed();
        RunArguments args = RunArguments.prepare(inputs, requestedOutputs, pinnedOutputs, this.evalInputNames, this.evalOutputNames);
        long runOptionsHandle = runOptions == null ? 0L : runOptions.getNativeHandle();
        boolean[] ownedByResult = this.evalStep(OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, this.nativeHandle, this.allocator.handle, args.inputNames, args.inputHandles, args.inputNames.length, args.outputNames, args.outputNames.length, args.outputValues, args.outputHandles, runOptionsHandle);
        return new OrtSession.Result(args.outputNames, args.outputValues, ownedByResult);
    }

    private native boolean[] evalStep(long var1, long var3, long var5, long var7, String[] var9, long[] var10, long var11, String[] var13, long var14, OnnxValue[] var16, long[] var17, long var18) throws OrtException;

    /**
     * Sets the optimizer learning rate via the native training API.
     *
     * @throws IllegalStateException if the session is closed
     * @throws OrtException if the native call fails
     */
    public void setLearningRate(float learningRate) throws OrtException {
        this.checkClosed();
        this.setLearningRate(OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, this.nativeHandle, learningRate);
    }

    private native void setLearningRate(long var1, long var3, long var5, float var7) throws OrtException;

    /**
     * Reads the optimizer learning rate via the native training API.
     *
     * @throws IllegalStateException if the session is closed
     * @throws OrtException if the native call fails
     */
    public float getLearningRate() throws OrtException {
        this.checkClosed();
        return this.getLearningRate(OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, this.nativeHandle);
    }

    private native float getLearningRate(long var1, long var3, long var5) throws OrtException;

    /** Runs an optimizer step with default run options. */
    public void optimizerStep() throws OrtException {
        this.optimizerStep(null);
    }

    /**
     * Runs an optimizer step.
     *
     * @param runOptions run options, or null for defaults
     * @throws IllegalStateException if the session is closed
     * @throws OrtException if the native call fails
     */
    public void optimizerStep(OrtSession.RunOptions runOptions) throws OrtException {
        this.checkClosed();
        long runOptionsHandle = runOptions == null ? 0L : runOptions.getNativeHandle();
        this.optimizerStep(OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, this.nativeHandle, runOptionsHandle);
    }

    private native void optimizerStep(long var1, long var3, long var5, long var7) throws OrtException;

    /**
     * Registers a linear learning-rate scheduler with the native session.
     *
     * @throws IllegalStateException if the session is closed (previously this reached native code
     *     with a released handle)
     * @throws OrtException if the native call fails
     */
    public void registerLinearLRScheduler(long warmupSteps, long totalSteps, float initialLearningRate) throws OrtException {
        this.checkClosed();
        this.registerLinearLRScheduler(OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, this.nativeHandle, warmupSteps, totalSteps, initialLearningRate);
    }

    private native void registerLinearLRScheduler(long var1, long var3, long var5, long var7, long var9, float var11) throws OrtException;

    /**
     * Advances the registered learning-rate scheduler by one step.
     *
     * @throws IllegalStateException if the session is closed
     * @throws OrtException if the native call fails
     */
    public void schedulerStep() throws OrtException {
        this.checkClosed();
        this.schedulerStep(OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, this.nativeHandle);
    }

    private native void schedulerStep(long var1, long var3, long var5) throws OrtException;

    /**
     * Exports an inference model to {@code outputPath} using the given graph output names.
     *
     * @param outputPath destination path, must not be null
     * @param outputNames output names to keep, must be non-null and non-empty
     * @throws IllegalStateException if the session is closed
     * @throws IllegalArgumentException if {@code outputNames} is empty
     * @throws OrtException if the native call fails
     */
    public void exportModelForInference(Path outputPath, String[] outputNames) throws OrtException {
        this.checkClosed();
        Objects.requireNonNull(outputPath, "output path must not be null");
        Objects.requireNonNull(outputNames, "output names must not be null");
        if (outputNames.length == 0) {
            throw new IllegalArgumentException("Requires at least one output name");
        }
        String outputStr = outputPath.toString();
        this.exportModelForInference(OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, this.nativeHandle, outputStr, outputNames.length, outputNames);
    }

    private native void exportModelForInference(long var1, long var3, long var5, String var7, long var8, String[] var10) throws OrtException;

    static {
        try {
            OnnxRuntime.init();
        }
        catch (IOException e) {
            throw new RuntimeException("Failed to load onnx-runtime library", e);
        }
    }

    /**
     * Validated name/handle arrays for a single train or eval step, in the layout the native step
     * functions expect. Shared by {@link #trainStep} and {@link #evalStep}, which differ only in the
     * name sets they validate against and the native function they call.
     */
    private static final class RunArguments {
        final String[] inputNames;
        final long[] inputHandles;
        final String[] outputNames;
        final OnnxValue[] outputValues;
        final long[] outputHandles;

        private RunArguments(String[] inputNames, long[] inputHandles, String[] outputNames, OnnxValue[] outputValues, long[] outputHandles) {
            this.inputNames = inputNames;
            this.inputHandles = inputHandles;
            this.outputNames = outputNames;
            this.outputValues = outputValues;
            this.outputHandles = outputHandles;
        }

        /**
         * Validates the step arguments against the model's names and flattens them into arrays.
         * Pinned outputs come first in the output arrays, followed by requested outputs, whose
         * {@code outputValues}/{@code outputHandles} slots are left as null/0.
         *
         * @throws OrtException if the input/output counts are out of range, a name is unknown, or
         *     an output is both requested and pinned
         */
        static RunArguments prepare(Map<String, ? extends OnnxTensorLike> inputs, Set<String> requestedOutputs, Map<String, ? extends OnnxValue> pinnedOutputs, Set<String> validInputs, Set<String> validOutputs) throws OrtException {
            if ((inputs.isEmpty() && !validInputs.isEmpty()) || inputs.size() > validInputs.size()) {
                throw new OrtException("Unexpected number of inputs, expected [1," + validInputs.size() + "] found " + inputs.size());
            }
            int numOutputs = validOutputs.size();
            int totalOutputs = requestedOutputs.size() + pinnedOutputs.size();
            if (totalOutputs == 0 || totalOutputs > numOutputs) {
                throw new OrtException("Unexpected number of requestedOutputs & pinnedOutputs, expected [1," + numOutputs + "] found " + totalOutputs);
            }

            String[] inputNames = new String[inputs.size()];
            long[] inputHandles = new long[inputs.size()];
            int i = 0;
            for (Map.Entry<String, ? extends OnnxTensorLike> t : inputs.entrySet()) {
                if (!validInputs.contains(t.getKey())) {
                    throw new OrtException("Unknown input name " + t.getKey() + ", expected one of " + validInputs);
                }
                inputNames[i] = t.getKey();
                inputHandles[i] = t.getValue().getNativeHandle();
                i++;
            }

            String[] outputNames = new String[totalOutputs];
            OnnxValue[] outputValues = new OnnxValue[totalOutputs];
            long[] outputHandles = new long[totalOutputs];
            int j = 0;
            for (Map.Entry<String, ? extends OnnxValue> e : pinnedOutputs.entrySet()) {
                if (!validOutputs.contains(e.getKey())) {
                    throw new OrtException("Unknown output name " + e.getKey() + ", expected one of " + validOutputs);
                }
                outputNames[j] = e.getKey();
                outputValues[j] = e.getValue();
                outputHandles[j] = OrtSession.getHandle(e.getValue());
                j++;
            }
            for (String s : requestedOutputs) {
                if (!validOutputs.contains(s)) {
                    throw new OrtException("Unknown output name " + s + ", expected one of " + validOutputs);
                }
                if (pinnedOutputs.containsKey(s)) {
                    throw new OrtException("Output '" + s + "' was found in both the requested outputs and the pinned outputs");
                }
                outputNames[j] = s;
                j++;
            }
            return new RunArguments(inputNames, inputHandles, outputNames, outputValues, outputHandles);
        }
    }

    /**
     * Wrapper around a native training checkpoint handle. {@link #close()} and {@link #isClosed()}
     * are synchronized; a second close logs a warning instead of throwing.
     */
    static final class OrtCheckpointState implements AutoCloseable {
        private static final Logger logger = Logger.getLogger(OrtCheckpointState.class.getName());
        final long nativeHandle;
        private boolean closed;

        OrtCheckpointState(long nativeHandle) {
            this.nativeHandle = nativeHandle;
            this.closed = false;
        }

        /**
         * Loads a checkpoint from {@code checkpointPath}.
         *
         * @throws NullPointerException if {@code checkpointPath} is null
         * @throws IllegalStateException if training is not enabled in this build
         * @throws OrtException if the native load fails
         */
        static OrtCheckpointState loadCheckpoint(Path checkpointPath) throws OrtException {
            Objects.requireNonNull(checkpointPath, "checkpoint path must not be null");
            String pathStr = checkpointPath.toString();
            return loadCheckpoint(pathStr);
        }

        /**
         * Loads a checkpoint from the given path string.
         *
         * @throws NullPointerException if {@code checkpoint} is null
         * @throws IllegalStateException if training is not enabled in this build
         * @throws OrtException if the native load fails
         */
        static OrtCheckpointState loadCheckpoint(String checkpoint) throws OrtException {
            if (!OnnxRuntime.trainingEnabled) {
                throw new IllegalStateException("Training is not enabled in this build of ONNX Runtime.");
            }
            Objects.requireNonNull(checkpoint, "checkpoint path must not be null");
            return new OrtCheckpointState(loadCheckpoint(OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, checkpoint));
        }

        /**
         * Saves this checkpoint to {@code outputPath}.
         *
         * @throws IllegalStateException if the checkpoint is closed
         * @throws OrtException if the native save fails
         */
        public void saveCheckpoint(Path outputPath, boolean saveOptimizer) throws OrtException {
            this.checkClosed();
            Objects.requireNonNull(outputPath, "checkpoint path must not be null");
            String outputStr = outputPath.toString();
            this.saveCheckpoint(OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, this.nativeHandle, outputStr, saveOptimizer);
        }

        /** Adds a float property; {@code name} must not be null. */
        public void addProperty(String name, float value) throws OrtException {
            this.checkClosed();
            Objects.requireNonNull(name, "property name must not be null");
            this.addProperty(OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, this.nativeHandle, name, value);
        }

        /** Adds an int property; {@code name} must not be null. */
        public void addProperty(String name, int value) throws OrtException {
            this.checkClosed();
            Objects.requireNonNull(name, "property name must not be null");
            this.addProperty(OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, this.nativeHandle, name, value);
        }

        /** Adds a string property; neither {@code name} nor {@code value} may be null. */
        public void addProperty(String name, String value) throws OrtException {
            this.checkClosed();
            Objects.requireNonNull(name, "property name must not be null");
            Objects.requireNonNull(value, "property value must not be null");
            this.addProperty(OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, this.nativeHandle, name, value);
        }

        /** Reads a float property using {@code allocator}; {@code name} must not be null. */
        public float getFloatProperty(OrtAllocator allocator, String name) throws OrtException {
            this.checkClosed();
            Objects.requireNonNull(name, "property name must not be null");
            return this.getFloatProperty(OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, this.nativeHandle, allocator.handle, name);
        }

        /** Reads an int property using {@code allocator}; {@code name} must not be null. */
        public int getIntProperty(OrtAllocator allocator, String name) throws OrtException {
            this.checkClosed();
            Objects.requireNonNull(name, "property name must not be null");
            return this.getIntProperty(OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, this.nativeHandle, allocator.handle, name);
        }

        /** Reads a string property using {@code allocator}; {@code name} must not be null. */
        public String getStringProperty(OrtAllocator allocator, String name) throws OrtException {
            this.checkClosed();
            Objects.requireNonNull(name, "property name must not be null");
            return this.getStringProperty(OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, this.nativeHandle, allocator.handle, name);
        }

        /** @throws IllegalStateException if this checkpoint has been closed */
        private void checkClosed() {
            if (this.closed) {
                throw new IllegalStateException("Trying to use a closed OrtCheckpointState");
            }
        }

        /** @return true if {@link #close()} has released the native checkpoint */
        public synchronized boolean isClosed() {
            return this.closed;
        }

        /** Releases the native checkpoint; a repeated call only logs a warning. */
        @Override
        public synchronized void close() {
            if (this.closed) {
                logger.warning("Closing a checkpoint twice");
                return;
            }
            this.close(OnnxRuntime.ortTrainingApiHandle, this.nativeHandle);
            this.closed = true;
        }

        private static native long loadCheckpoint(long var0, long var2, String var4) throws OrtException;

        private native void saveCheckpoint(long var1, long var3, long var5, String var7, boolean var8) throws OrtException;

        private native void addProperty(long var1, long var3, long var5, String var7, int var8) throws OrtException;

        private native void addProperty(long var1, long var3, long var5, String var7, float var8) throws OrtException;

        private native void addProperty(long var1, long var3, long var5, String var7, String var8) throws OrtException;

        private native int getIntProperty(long var1, long var3, long var5, long var7, String var9) throws OrtException;

        private native float getFloatProperty(long var1, long var3, long var5, long var7, String var9) throws OrtException;

        private native String getStringProperty(long var1, long var3, long var5, long var7, String var9) throws OrtException;

        private native void close(long var1, long var3);
    }
}

