/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.validation;

import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.autodiff.validation.TestCase;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class GradCheckUtil {
    private static final Logger log = LoggerFactory.getLogger(GradCheckUtil.class);
    private static final boolean DEFAULT_PRINT = true;
    private static final boolean DEFAULT_EXIT_FIRST_FAILURE = false;
    private static final boolean DEFAULT_DEBUG_MODE = false;
    private static final double DEFAULT_EPS = 1.0E-5;
    private static final double DEFAULT_MAX_REL_ERROR = 1.0E-5;
    private static final double DEFAULT_MIN_ABS_ERROR = 1.0E-6;

    public static boolean checkGradients(TestCase t) {
        return GradCheckUtil.checkGradients(t.sameDiff(), t.placeholderValues(), t.gradCheckEpsilon(), t.gradCheckMaxRelativeError(), t.gradCheckMinAbsError(), t.gradCheckPrint(), t.gradCheckDefaultExitFirstFailure(), false, t.gradCheckDebugMode(), t.gradCheckSkipVariables(), t.gradCheckMask());
    }

    public static boolean checkGradients(SameDiff sd, Map<String, INDArray> placeholderValues, String ... skipVariables) {
        HashSet<String> skip = null;
        if (skipVariables != null) {
            skip = new HashSet<String>();
            Collections.addAll(skip, skipVariables);
        }
        return GradCheckUtil.checkGradients(sd, placeholderValues, 1.0E-5, 1.0E-5, 1.0E-6, true, false, false, false, skip, null);
    }

    public static boolean checkGradients(SameDiff sd, Map<String, INDArray> placeholderValues, boolean print, boolean exitOnFirstFailure) {
        return GradCheckUtil.checkGradients(sd, placeholderValues, 1.0E-5, 1.0E-5, 1.0E-6, print, exitOnFirstFailure);
    }

    public static boolean checkGradients(SameDiff sd, Map<String, INDArray> placeholderValues, double eps, double maxRelError, double minAbsError, boolean print, boolean exitOnFirstFailure) {
        return GradCheckUtil.checkGradients(sd, placeholderValues, eps, maxRelError, minAbsError, print, exitOnFirstFailure, false, false, null, null);
    }

    /*
     * WARNING - void declaration
     */
    public static boolean checkGradients(SameDiff sd, Map<String, INDArray> placeholderValues, double eps, double maxRelError, double minAbsError, boolean print, boolean exitOnFirstFailure, boolean skipValidation, boolean debugMode, Set<String> skipVariables, Map<String, INDArray> gradCheckMask) {
        void var19_24;
        boolean debugBefore = sd.isDebugMode();
        if (debugMode) {
            sd.enableDebugMode();
        }
        if (!skipValidation) {
            GradCheckUtil.validateInternalState(sd, true);
        }
        if (Nd4j.dataType() != DataType.DOUBLE) {
            throw new IllegalStateException("Data type must be set to double");
        }
        HashSet<String> fnOutputs = new HashSet<String>();
        for (DifferentialFunction differentialFunction : sd.functions()) {
            for (SDVariable s : differentialFunction.outputVariables()) {
                fnOutputs.add(s.getVarName());
            }
        }
        for (Variable v : sd.getVariables().values()) {
            if (v.getVariable().getVariableType() == VariableType.ARRAY || v.getVariable().getArr(true) != null) continue;
            throw new IllegalStateException("Variable \"" + v.getName() + "\" does not have array associated with it");
        }
        List<String> lossFnVariables = sd.getLossVariables();
        Preconditions.checkState((lossFnVariables != null && !lossFnVariables.isEmpty() ? 1 : 0) != 0, (String)"Expected 1 or more loss function variables for gradient check, got %s", lossFnVariables);
        HashSet<String> gradVarNames = new HashSet<String>();
        for (Variable variable : sd.getVariables().values()) {
            if (!variable.getVariable().dataType().isFPType() || variable.getVariable().getVariableType() != VariableType.VARIABLE && variable.getVariable().getVariableType() != VariableType.PLACEHOLDER) continue;
            SDVariable g = variable.getVariable().getGradient();
            Preconditions.checkNotNull((Object)g, (String)"No gradient variable found for variable %s", (Object)variable.getVariable());
            gradVarNames.add(g.getVarName());
        }
        sd.execBackwards(placeholderValues, new ArrayList<String>(gradVarNames));
        HashMap<String, INDArray> grad = new HashMap<String, INDArray>();
        for (SDVariable v : sd.variables()) {
            if (fnOutputs.contains(v.getVarName()) || !v.hasGradient()) continue;
            SDVariable g = sd.grad(v.getVarName());
            if (g == null) {
                throw new IllegalStateException("Null gradient variable for \"" + v.getVarName() + "\"");
            }
            INDArray ga = g.getArr();
            if (ga == null) {
                throw new IllegalStateException("Null gradient array encountered for variable: " + v.getVarName());
            }
            if (!Arrays.equals(v.getArr().shape(), g.getArr().shape())) {
                throw new IllegalStateException("Gradient shape does not match variable shape for variable \"" + v.getVarName() + "\": shape " + Arrays.toString(v.getArr().shape()) + " vs. gradient shape " + Arrays.toString(ga.shape()));
            }
            grad.put(v.getVarName(), ga.dup());
        }
        boolean bl = false;
        int totalCount = 0;
        double maxError = 0.0;
        for (SDVariable s : sd.variables()) {
            INDArray varMask;
            if (fnOutputs.contains(s.getVarName())) continue;
            if (skipVariables != null && skipVariables.contains(s.getVarName())) {
                log.info("Grad check: skipping variable \"{}\"", (Object)s.getVarName());
                continue;
            }
            String name = s.getVarName();
            INDArray a = s.getArr();
            long n = a.length();
            if (print) {
                log.info("Starting test for variable \"{}\" with {} values", (Object)s.getVarName(), (Object)n);
            }
            NdIndexIterator iter = new NdIndexIterator('c', a.shape());
            INDArray iNDArray = varMask = gradCheckMask == null ? null : gradCheckMask.get(s.getVarName());
            if (varMask != null) {
                Preconditions.checkState((boolean)a.equalShapes(varMask), (String)"Variable \"%s\": Gradient check mask and array shapes must be equal: got %s vs. mask shape %s", (Object)s.getVarName(), (Object)a.shape(), (Object)varMask.shape());
                Preconditions.checkState((varMask.dataType() == DataType.BOOL ? 1 : 0) != 0, (String)"Variable \"%s\": Gradient check mask must be BOOLEAN datatype, got %s", (Object)s.getVarName(), (Object)varMask.dataType());
            }
            int i = 0;
            while (iter.hasNext()) {
                boolean maskValue;
                long[] idx = iter.next();
                String strIdx = null;
                if (print) {
                    strIdx = Arrays.toString(idx).replaceAll(" ", "");
                }
                boolean bl2 = maskValue = varMask == null || varMask.getDouble(idx) != 0.0;
                if (!maskValue) continue;
                ++totalCount;
                double orig = a.getDouble(idx);
                a.putScalar(idx, orig + eps);
                double scorePlus = 0.0;
                Map<String, INDArray> m = sd.exec(placeholderValues, lossFnVariables);
                for (INDArray arr : m.values()) {
                    scorePlus += arr.sumNumber().doubleValue();
                }
                a.putScalar(idx, orig - eps);
                m = sd.exec(placeholderValues, lossFnVariables);
                double scoreMinus = 0.0;
                for (INDArray arr : m.values()) {
                    scoreMinus += arr.sumNumber().doubleValue();
                }
                a.putScalar(idx, orig);
                double numericalGrad = (scorePlus - scoreMinus) / (2.0 * eps);
                INDArray aGrad = (INDArray)grad.get(s.getVarName());
                double analyticGrad = aGrad.getDouble(idx);
                if (Double.isInfinite(numericalGrad) || Double.isNaN(numericalGrad)) {
                    throw new IllegalStateException("Numerical gradient was " + numericalGrad + " for variable \"" + name + "\", parameter " + i + " of " + n + " (position: " + strIdx + ")");
                }
                if (Double.isInfinite(analyticGrad) || Double.isNaN(analyticGrad)) {
                    throw new IllegalStateException("Analytic (SameDiff) gradient was " + analyticGrad + " for variable \"" + name + "\", parameter " + i + " of " + n + " (position: " + strIdx + ")");
                }
                double relError = numericalGrad == 0.0 && analyticGrad == 0.0 ? 0.0 : Math.abs(analyticGrad - numericalGrad) / Math.abs(Math.abs(analyticGrad) + Math.abs(numericalGrad));
                if (relError > maxError) {
                    maxError = relError;
                }
                if (relError > maxRelError || Double.isNaN(relError)) {
                    double absError = Math.abs(analyticGrad - numericalGrad);
                    if (absError < minAbsError) {
                        if (print) {
                            log.info("Param " + i + " (" + name + strIdx + ") passed: grad= " + analyticGrad + ", numericalGrad= " + numericalGrad + ", relError= " + relError + "; absolute error = " + absError + " < minAbsoluteError = " + minAbsError);
                        }
                    } else {
                        if (print) {
                            log.info("Param " + i + " (" + name + strIdx + ") FAILED: grad= " + analyticGrad + ", numericalGrad= " + numericalGrad + ", relError= " + relError + ", absError=" + absError + ", scorePlus=" + scorePlus + ", scoreMinus= " + scoreMinus);
                        }
                        if (exitOnFirstFailure) {
                            return false;
                        }
                        ++var19_24;
                    }
                } else if (print) {
                    log.info("Param " + i + " (" + name + strIdx + ") passed: grad= " + analyticGrad + ", numericalGrad= " + numericalGrad + ", relError= " + relError);
                }
                ++i;
            }
        }
        if (print) {
            int nPass = totalCount - var19_24;
            log.info("GradCheckUtil.checkGradients(): " + totalCount + " params checked, " + nPass + " passed, " + (int)var19_24 + " failed. Largest relative error = " + maxError);
        }
        if (debugMode && !debugBefore) {
            sd.disableDebugging();
        }
        return var19_24 == false;
    }

    public static void validateInternalState(SameDiff sd, boolean generateAndCheckGradFn) {
        DifferentialFunction[] dfs = sd.functions();
        List<SDVariable> vars = sd.variables();
        HashSet<SDVariable> varsSet = new HashSet<SDVariable>(vars);
        Preconditions.checkState((vars.size() == varsSet.size() ? 1 : 0) != 0, (String)"Duplicate variables in variables() list");
        HashSet<String> varSetStr = new HashSet<String>();
        for (SDVariable sDVariable : vars) {
            if (varSetStr.contains(sDVariable.getVarName())) {
                throw new IllegalStateException("Variable with name " + sDVariable.getVarName() + " already encountered");
            }
            varSetStr.add(sDVariable.getVarName());
        }
        Map<String, SameDiffOp> ops = sd.getOps();
        Preconditions.checkState((dfs.length == ops.size() ? 1 : 0) != 0, (String)"All functions not present in incomingArgsReverse");
        for (DifferentialFunction df : dfs) {
            Preconditions.checkState((boolean)ops.containsKey(df.getOwnName()), (String)(df.getOwnName() + " not present in ops map"));
            List<String> str = ops.get(df.getOwnName()).getInputsToOp();
            if (str != null) {
                for (String s : str) {
                    Preconditions.checkState((boolean)varSetStr.contains(s), (String)("Variable " + s + " in op inputs not a known variable name"));
                }
            }
            if ((str = ops.get(df.getOwnName()).getOutputsOfOp()) == null) continue;
            for (String s : str) {
                Preconditions.checkState((boolean)varSetStr.contains(s), (String)("Variable " + s + " in op outputs not a known variable name"));
            }
        }
        HashMap<String, String> hashMap = new HashMap<String, String>();
        for (Map.Entry<String, SameDiffOp> e : ops.entrySet()) {
            List<String> varNames = e.getValue().getOutputsOfOp();
            if (varNames == null) continue;
            for (String s : varNames) {
                if (hashMap.containsKey(s)) {
                    throw new IllegalStateException("Already saw variable \"" + s + "\" as output for op \"" + (String)hashMap.get(s) + "\": expected variables to be present as an output only once; also seen as output for op \"" + e.getKey() + "\"");
                }
                hashMap.put(s, e.getKey());
            }
        }
        Map<String, Variable> variableMap = sd.getVariables();
        Preconditions.checkState((vars.size() == variableMap.size() ? 1 : 0) != 0, (String)"Variable map size check failed");
        for (Map.Entry<String, Variable> e : variableMap.entrySet()) {
            Preconditions.checkState((boolean)e.getKey().equals(e.getValue().getVariable().getVarName()), (String)"Name not equal");
        }
        if (generateAndCheckGradFn) {
            if (sd.getFunction("grad") == null) {
                sd.createGradFunction();
            }
            SameDiff gradFn = sd.getFunction("grad");
            GradCheckUtil.validateInternalState(gradFn, false);
            for (DifferentialFunction dfOrig : dfs) {
                Preconditions.checkNotNull((Object)gradFn.getFunctionById(dfOrig.getOwnName()), (String)("DifferentialFunction " + dfOrig.getOwnName() + " from original SameDiff instance not present in grad fn"));
            }
        }
    }

    private static <T> T getObject(String fieldName, Object from, Class<?> fromClass) {
        try {
            Field f = fromClass.getDeclaredField(fieldName);
            f.setAccessible(true);
            return (T)f.get(from);
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
}

