/*
 * Decompiled with CFR 0.152.
 */
package org.tensorflow.op;

import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.List;
import org.bytedeco.javacpp.PointerScope;
import org.tensorflow.BaseGradientAdapter;
import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.Output;
import org.tensorflow.internal.c_api.NativeOperation;
import org.tensorflow.internal.c_api.NativeOutputVector;
import org.tensorflow.internal.c_api.NativeStatus;
import org.tensorflow.internal.c_api.TF_Scope;
import org.tensorflow.op.CustomGradient;
import org.tensorflow.op.GradientScope;
import org.tensorflow.op.Ops;
import org.tensorflow.op.RawOp;
import org.tensorflow.op.RawOpInputs;

/**
 * Adapts a typed {@link CustomGradient} to the native gradient callback defined by
 * {@link BaseGradientAdapter}.
 *
 * <p>For each native gradient request, the adapter:
 * <ul>
 *   <li>wraps the forward operation in an instance of {@code opInputClass}, built reflectively;</li>
 *   <li>opens a {@link GradientScope} sub-scope named after that operation;</li>
 *   <li>invokes the user gradient;</li>
 *   <li>copies the returned operands into the native output vector.</li>
 * </ul>
 *
 * @param <T> the typed op-inputs class handed to the custom gradient
 */
final class TypedGradientAdapter<T extends RawOpInputs<?>>
extends BaseGradientAdapter {
    private final CustomGradient<T> gradient;
    private final Class<T> opInputClass;
    private final Constructor<T> ctor;

    /**
     * Creates an adapter for a typed custom gradient.
     *
     * @param gradient the user-supplied gradient function
     * @param opInputClass the inputs class to instantiate for each forward op; it must declare a
     *     constructor taking exactly one argument (the graph operation)
     * @throws IllegalArgumentException if {@code opInputClass} declares no single-argument
     *     constructor
     */
    TypedGradientAdapter(CustomGradient<T> gradient, Class<T> opInputClass) {
        this.gradient = gradient;
        this.opInputClass = opInputClass;
        this.ctor = findSingleArgConstructor(opInputClass);
    }

    /**
     * Returns the first declared constructor of {@code type} that takes exactly one parameter.
     *
     * <p>{@link Class#getDeclaredConstructors()} returns constructors in no particular order, so
     * blindly taking index 0 is unreliable when a class declares several. {@link #call} always
     * passes a single argument, so only a one-parameter constructor can succeed.
     */
    // Safe: every constructor reported by Class<T>#getDeclaredConstructors constructs a T.
    @SuppressWarnings("unchecked")
    private static <T> Constructor<T> findSingleArgConstructor(Class<T> type) {
        for (Constructor<?> candidate : type.getDeclaredConstructors()) {
            if (candidate.getParameterCount() == 1) {
                return (Constructor<T>) candidate;
            }
        }
        throw new IllegalArgumentException(
            "Op inputs class " + type + " must declare a constructor taking a single operation argument");
    }

    /**
     * Native gradient callback: builds the typed inputs for {@code op}, runs the custom gradient
     * inside a sub-scope named after the op, and writes the results to {@code grad_outputs}.
     *
     * @param scope native scope whose graph the gradient ops are added to
     * @param op the forward operation being differentiated
     * @param grad_inputs incoming gradients, converted to {@link Output}s for the custom gradient
     * @param grad_outputs receives the operands returned by the custom gradient
     * @return {@link NativeStatus#OK()} on success
     * @throws IllegalStateException if no Java {@link Graph} is registered for the scope's graph
     * @throws RuntimeException if the op inputs class cannot be instantiated reflectively
     */
    @Override
    public NativeStatus call(TF_Scope scope, NativeOperation op, NativeOutputVector grad_inputs, NativeOutputVector grad_outputs) {
        try (PointerScope pointerScope = new PointerScope()) {
            Graph g = Graph.findGraphForPointer(scope.graph());
            if (g == null) {
                throw new IllegalStateException("No graph found for native gradient scope.");
            }
            T rawOp = this.ctor.newInstance(BaseGradientAdapter.getGraphOp(g, op.node()));
            GradientScope nativeScope =
                new GradientScope(scope, g, null).withSubScope(((RawOp) rawOp.getOutputs()).op().name());
            Ops tf = new Ops(nativeScope);
            List<Output<?>> gradInputs = BaseGradientAdapter.fromNativeOutputs(g, grad_inputs);

            List<Operand<?>> gradOutputs;
            BaseGradientAdapter.useDangerousLockedBuilders(g, true);
            try {
                gradOutputs = this.gradient.call(tf, rawOp, gradInputs);
            } finally {
                // Always restore the flag, even if the user gradient throws; otherwise the graph
                // would be left in the locked-builders mode for all later operations.
                BaseGradientAdapter.useDangerousLockedBuilders(g, false);
            }
            BaseGradientAdapter.putToNativeOutputs(gradOutputs, grad_outputs);
        }
        catch (IllegalAccessException | InstantiationException | InvocationTargetException e) {
            throw new RuntimeException("Could not instantiate Op class " + this.opInputClass, e);
        }
        return NativeStatus.OK();
    }
}

