/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.validation;

import java.io.IOException;
import java.lang.reflect.Modifier;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import org.apache.commons.collections4.trie.PatriciaTrie;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.listeners.Listener;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.autodiff.validation.GradCheckUtil;
import org.nd4j.autodiff.validation.OpTestCase;
import org.nd4j.autodiff.validation.TestCase;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.config.ND4JClassLoading;
import org.nd4j.common.function.Function;
import org.nd4j.common.primitives.Pair;
import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
import org.nd4j.imports.descriptors.tensorflow.TensorflowDescriptorParser;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOpDescriptor;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.custom.BarnesEdgeForces;
import org.nd4j.linalg.api.ops.custom.BarnesHutGains;
import org.nd4j.linalg.api.ops.custom.BarnesHutSymmetrize;
import org.nd4j.linalg.api.ops.custom.BitCast;
import org.nd4j.linalg.api.ops.custom.SpTreeCell;
import org.nd4j.linalg.api.ops.custom.ToggleBits;
import org.nd4j.linalg.api.ops.impl.broadcast.BiasAddGrad;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAMax;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAMin;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAddOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastCopyOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastDivOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMax;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMin;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastRDivOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastRSubOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastSubOp;
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastEqualTo;
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastGreaterThan;
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastGreaterThanOrEqual;
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastLessThan;
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastLessThanOrEqual;
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastNotEqual;
import org.nd4j.linalg.api.ops.impl.grid.FreeGridOp;
import org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex;
import org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex;
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax;
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin;
import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
import org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNormDerivative;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Col2Im;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2DDerivative;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv3DDerivative;
import org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2DDerivative;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Im2colBp;
import org.nd4j.linalg.api.ops.impl.layers.convolution.LocalResponseNormalizationDerivative;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2DDerivative;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling3DDerivative;
import org.nd4j.linalg.api.ops.impl.layers.convolution.SConv2DDerivative;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Upsampling2dDerivative;
import org.nd4j.linalg.api.ops.impl.loss.bp.AbsoluteDifferenceLossBp;
import org.nd4j.linalg.api.ops.impl.loss.bp.CosineDistanceLossBp;
import org.nd4j.linalg.api.ops.impl.loss.bp.HingeLossBp;
import org.nd4j.linalg.api.ops.impl.loss.bp.HuberLossBp;
import org.nd4j.linalg.api.ops.impl.loss.bp.LogLossBp;
import org.nd4j.linalg.api.ops.impl.loss.bp.LogPoissonLossBp;
import org.nd4j.linalg.api.ops.impl.loss.bp.MeanPairwiseSquaredErrorLossBp;
import org.nd4j.linalg.api.ops.impl.loss.bp.MeanSquaredErrorLossBp;
import org.nd4j.linalg.api.ops.impl.loss.bp.SigmoidCrossEntropyLossBp;
import org.nd4j.linalg.api.ops.impl.loss.bp.SoftmaxCrossEntropyLossBp;
import org.nd4j.linalg.api.ops.impl.loss.bp.SparseSoftmaxCrossEntropyLossWithLogitsBp;
import org.nd4j.linalg.api.ops.impl.meta.InvertedPredicateMetaOp;
import org.nd4j.linalg.api.ops.impl.meta.PostulateMetaOp;
import org.nd4j.linalg.api.ops.impl.meta.PredicateMetaOp;
import org.nd4j.linalg.api.ops.impl.meta.ReduceMetaOp;
import org.nd4j.linalg.api.ops.impl.nlp.CbowRound;
import org.nd4j.linalg.api.ops.impl.nlp.SkipGramRound;
import org.nd4j.linalg.api.ops.impl.reduce.HashCode;
import org.nd4j.linalg.api.ops.impl.reduce.MmulBp;
import org.nd4j.linalg.api.ops.impl.reduce.NormalizeMoments;
import org.nd4j.linalg.api.ops.impl.reduce.bool.All;
import org.nd4j.linalg.api.ops.impl.reduce.bool.Any;
import org.nd4j.linalg.api.ops.impl.reduce.bool.IsNaN;
import org.nd4j.linalg.api.ops.impl.reduce.bp.CumProdBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.CumSumBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.DotBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.MaxBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.MeanBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.MinBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.Norm1Bp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.Norm2Bp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.NormMaxBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.ProdBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.SquaredNormBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.StandardDeviationBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.SumBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.VarianceBp;
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
import org.nd4j.linalg.api.ops.impl.reduce3.EqualsWithEps;
import org.nd4j.linalg.api.ops.impl.scalar.PowDerivative;
import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinearDerivative;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarRemainder;
import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarSetValue;
import org.nd4j.linalg.api.ops.impl.shape.BroadcastDynamicShape;
import org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix;
import org.nd4j.linalg.api.ops.impl.shape.Eye;
import org.nd4j.linalg.api.ops.impl.shape.MergeSum;
import org.nd4j.linalg.api.ops.impl.shape.OneHot;
import org.nd4j.linalg.api.ops.impl.shape.ReductionShape;
import org.nd4j.linalg.api.ops.impl.shape.Shape;
import org.nd4j.linalg.api.ops.impl.shape.ShapeN;
import org.nd4j.linalg.api.ops.impl.shape.SizeAt;
import org.nd4j.linalg.api.ops.impl.shape.bp.ConcatBp;
import org.nd4j.linalg.api.ops.impl.shape.bp.SliceBp;
import org.nd4j.linalg.api.ops.impl.shape.bp.StridedSliceBp;
import org.nd4j.linalg.api.ops.impl.shape.bp.TileBp;
import org.nd4j.linalg.api.ops.impl.transforms.Assert;
import org.nd4j.linalg.api.ops.impl.transforms.Histogram;
import org.nd4j.linalg.api.ops.impl.transforms.bool.BooleanNot;
import org.nd4j.linalg.api.ops.impl.transforms.bool.IsInf;
import org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform;
import org.nd4j.linalg.api.ops.impl.transforms.custom.BitsHammingDistance;
import org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits;
import org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits;
import org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttentionBp;
import org.nd4j.linalg.api.ops.impl.transforms.custom.InTopK;
import org.nd4j.linalg.api.ops.impl.transforms.custom.InvertPermutation;
import org.nd4j.linalg.api.ops.impl.transforms.custom.IsNonDecreasing;
import org.nd4j.linalg.api.ops.impl.transforms.custom.IsNumericTensor;
import org.nd4j.linalg.api.ops.impl.transforms.custom.IsStrictlyIncreasing;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNormBp;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LogicalAnd;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LogicalNot;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LogicalOr;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LogicalXor;
import org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttentionBp;
import org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits;
import org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits;
import org.nd4j.linalg.api.ops.impl.transforms.custom.StandardizeBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.DynamicPartitionBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LogSoftMaxDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.PReluBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.Relu6Derivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SELUDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SeluBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftPlusBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.TanhDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.ThresholdReluBp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.BinaryMinimalRelativeError;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.AddBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.DivBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.FloorDivBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.FloorModBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.ModBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.MulBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.RDivBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.RSubBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.SquaredDifferenceBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.SubBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Not;
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentMaxBp;
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentMeanBp;
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentMinBp;
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentProdBp;
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentSumBp;
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentMaxBp;
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentMeanBp;
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentMinBp;
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentProdBp;
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentSqrtNBp;
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentSumBp;
import org.nd4j.linalg.api.ops.impl.transforms.strict.GELUDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELUDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SwishDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.strict.TanDerivative;
import org.nd4j.linalg.api.ops.persistence.RestoreV2;
import org.nd4j.linalg.api.ops.persistence.SaveV2;
import org.nd4j.linalg.api.ops.random.compat.RandomStandardNormal;
import org.nd4j.linalg.api.ops.random.custom.DistributionUniform;
import org.nd4j.linalg.api.ops.random.impl.AlphaDropOut;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.api.ops.random.impl.BinomialDistribution;
import org.nd4j.linalg.api.ops.random.impl.BinomialDistributionEx;
import org.nd4j.linalg.api.ops.random.impl.Choice;
import org.nd4j.linalg.api.ops.random.impl.DropOut;
import org.nd4j.linalg.api.ops.random.impl.DropOutInverted;
import org.nd4j.linalg.api.ops.random.impl.GaussianDistribution;
import org.nd4j.linalg.api.ops.random.impl.Linspace;
import org.nd4j.linalg.api.ops.random.impl.LogNormalDistribution;
import org.nd4j.linalg.api.ops.random.impl.ProbablisticMerge;
import org.nd4j.linalg.api.ops.random.impl.Range;
import org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution;
import org.nd4j.linalg.api.ops.random.impl.UniformDistribution;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.shade.guava.collect.ImmutableSet;
import org.nd4j.shade.guava.reflect.ClassPath;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.framework.OpDef;

public class OpValidation {
    private static final Logger log = LoggerFactory.getLogger(OpValidation.class);
    // NOTE(review): none of the static coverage fields below are assigned in the portion of this file
    // reviewed here; presumably a static initializer elsewhere in the class sets them up - confirm.
    // Op classes tracked for coverage reporting (presumably every op class known to ND4J - TODO confirm).
    private static List<Class> allOps;
    // Presumably hashes/IDs of libnd4j custom ops that have no Java op mapping (inferred from the name - confirm).
    private static List<Long> nonMappedLibnd4jOps;
    // libnd4j custom op descriptors keyed by a Long (presumably the op hash), each paired with the list of
    // op names that share it (inferred from the type - confirm against the initializer).
    private static Map<Long, Pair<List<String>, CustomOpDescriptor>> dedupedCustomOps;
    // Presumably the total number of libnd4j ops (inferred from the name - confirm where it is computed).
    private static int countTotalLibnd4jOps;
    // Per op class: number of TestCase validations whose SameDiff graph contained an op of that class.
    // collectCoverageInformation increments this once per test case for every op class in sd.ops(),
    // without consulting testCase.gradientCheck().
    private static Map<Class, Integer> gradCheckCoverageCountPerClass;
    // Per op class: number of TestCase validations that had a forward-pass check on a variable produced
    // by an op of that class (incremented in collectCoverageInformation).
    private static Map<Class, Integer> fwdPassCoverageCountPerClass;
    // Presumably per op class: number of single-op (OpTestCase) validations - not visible here, confirm.
    private static Map<Class, Integer> singleOpTestCountPerClass;
    // Presumably per op class with a TensorFlow mapping: number of TF import tests covering it - confirm.
    private static Map<Class, Integer> opsWithTFMappingTFImportCounts;
    // Presumably per TensorFlow op name: number of TF import tests covering it - confirm.
    private static Map<String, Integer> tfMappedOpsImportTestCounts;

    /**
     * Validates a {@link TestCase}, letting any exception raised during validation propagate.
     *
     * @param testCase the test case to validate
     * @return {@code null} when every configured check passes, otherwise a description of the failure
     */
    public static String validate(TestCase testCase) {
        // false: rethrow exceptions instead of turning them into an "EXCEPTION: ..." message
        boolean exceptionsAsErrorMsg = false;
        return validate(testCase, exceptionsAsErrorMsg);
    }

    /**
     * Validates a {@link TestCase}, optionally converting thrown exceptions into an error message.
     *
     * @param testCase             the test case to validate
     * @param exceptionsAsErrorMsg when {@code true}, any {@link Throwable} raised during validation is
     *                             logged and returned as {@code "EXCEPTION: " + message}; when
     *                             {@code false} it is rethrown unchanged
     * @return {@code null} when every configured check passes, otherwise a description of the failure
     */
    public static String validate(TestCase testCase, boolean exceptionsAsErrorMsg) {
        String result;
        try {
            result = validateHelper(testCase);
        } catch (Throwable t) {
            if (!exceptionsAsErrorMsg) {
                // Precise rethrow: validateHelper declares no checked exceptions
                throw t;
            }
            log.info("Exception encountered - returning as error message", t);
            result = "EXCEPTION: " + t.getMessage();
        }
        return result;
    }

    /**
     * Runs every check configured on a {@link TestCase}.
     *
     * <p>Depending on the test case configuration this method:
     * <ul>
     *   <li>serializes the graph to FlatBuffers before execution ({@code BEFORE_EXEC}/{@code BOTH}) and/or
     *       after the forward pass ({@code AFTER_EXEC}/{@code BOTH}),</li>
     *   <li>runs the forward pass and applies each forward-pass validation function,</li>
     *   <li>deserializes every captured buffer and compares it against the executed graph (only when
     *       forward-pass functions are present), and</li>
     *   <li>runs a gradient check via {@link GradCheckUtil#checkGradients(TestCase)}.</li>
     * </ul>
     *
     * @param testCase the test case to validate
     * @return {@code null} if every check passed, otherwise a description of the first failure
     * @throws RuntimeException      if the forward pass fails to execute
     * @throws IllegalStateException if a checked variable is missing or yields no output, if a validation
     *                               function throws, or if the gradient check throws
     */
    private static String validateHelper(TestCase testCase) {
        testCase.assertConfigValid();
        OpValidation.collectCoverageInformation(testCase);

        TestCase.TestSerialization serialization = testCase.testFlatBufferSerialization();
        boolean serializeBeforeExec = serialization == TestCase.TestSerialization.BEFORE_EXEC
                || serialization == TestCase.TestSerialization.BOTH;
        // The post-execution snapshot belongs to AFTER_EXEC (and BOTH), not BEFORE_EXEC
        boolean serializeAfterExec = serialization == TestCase.TestSerialization.AFTER_EXEC
                || serialization == TestCase.TestSerialization.BOTH;

        ByteBuffer serializedBeforeExec = null;
        if (serializeBeforeExec) {
            serializedBeforeExec = testCase.sameDiff().asFlatBuffers(true);
            Preconditions.checkNotNull(serializedBeforeExec, "Serialization failed? Null output");
        }

        if (testCase.fwdTestFns() != null && !testCase.fwdTestFns().isEmpty()) {
            SameDiff sd = testCase.sameDiff();
            Map<String, INDArray> out;
            try {
                out = sd.output(testCase.placeholderValues(), new ArrayList<String>(testCase.fwdTestFns().keySet()));
            } catch (Exception e) {
                throw new RuntimeException("Error during forward pass testing" + testCase.testNameErrMsg(), e);
            }

            for (Map.Entry<String, Function<INDArray, String>> e : testCase.fwdTestFns().entrySet()) {
                SDVariable v = sd.getVariable(e.getKey());
                if (v == null) {
                    throw new IllegalStateException("Test case has expected result function defined for variable \"" + e.getKey() + "\" but SameDiff instance does not have a variable for this name" + testCase.testNameErrMsg());
                }
                INDArray actual = out.get(v.name());
                if (actual == null) {
                    throw new IllegalStateException("Null INDArray after forward pass for variable \"" + e.getKey() + "\"");
                }
                String error;
                try {
                    error = e.getValue().apply(actual);
                } catch (Throwable t) {
                    throw new IllegalStateException("Error checking forward pass for variable \"" + e.getKey() + "\": exception was thrown by forward pass validation function", t);
                }
                if (error != null) {
                    return testCase.testNameErrMsg() + ": Variable " + e.getKey() + " failed: " + error;
                }
            }

            ByteBuffer serializedAfterExec = null;
            if (serializeAfterExec) {
                serializedAfterExec = testCase.sameDiff().asFlatBuffers(true);
                Preconditions.checkNotNull(serializedAfterExec, "Serialization failed? Null output");
            }

            // Each captured snapshot must round-trip to a graph equivalent to the executed one
            if (serializedBeforeExec != null) {
                OpValidation.checkDeserializedEquality(sd, serializedBeforeExec, testCase);
            }
            if (serializedAfterExec != null) {
                OpValidation.checkDeserializedEquality(sd, serializedAfterExec, testCase);
            }
        }

        if (testCase.gradientCheck()) {
            boolean ok;
            try {
                ok = GradCheckUtil.checkGradients(testCase);
            } catch (Throwable t) {
                // The original throwable is kept as the cause, so no separate stack-trace print is needed
                throw new IllegalStateException("Exception encountered during gradient check" + testCase.testNameErrMsg(), t);
            }
            if (!ok) {
                return "Gradient check failed" + testCase.testNameErrMsg();
            }
        }
        return null;
    }

    /**
     * Deserializes a FlatBuffers snapshot of a graph and checks that it is equivalent to {@code original}.
     *
     * <p>The comparison covers, in order:
     * <ol>
     *   <li>variable names, in {@link SameDiff#variables()} order;</li>
     *   <li>ops and their inputs, outputs, control dependencies and classes;</li>
     *   <li>placeholder names;</li>
     *   <li>per-variable metadata;</li>
     *   <li>loss variables;</li>
     *   <li>when {@code tc} defines forward-pass test functions, the results of executing both graphs
     *       with {@code tc}'s placeholder values.</li>
     * </ol>
     *
     * @param original     the graph the buffer was serialized from
     * @param bbSerialized a FlatBuffers serialization, e.g. one produced by {@code SameDiff.asFlatBuffers(true)}
     * @param tc           test case supplying placeholder values and forward-pass test functions
     * @throws RuntimeException      if deserialization fails with an {@link IOException}
     * @throws IllegalStateException if any comparison fails (raised through {@code Preconditions.checkState})
     */
    public static void checkDeserializedEquality(SameDiff original, ByteBuffer bbSerialized, TestCase tc) {
        SameDiff deserialized;
        try {
            deserialized = SameDiff.fromFlatBuffers(bbSerialized);
        } catch (IOException e) {
            throw new RuntimeException("IOException deserializing from FlatBuffers", e);
        }
        checkDeserializedVariableNames(original, deserialized);
        checkDeserializedOps(original, deserialized);
        checkDeserializedPlaceholders(original, deserialized);
        checkDeserializedVariables(original, deserialized);
        checkDeserializedLossVariables(original, deserialized);
        if (tc.fwdTestFns() != null && !tc.fwdTestFns().isEmpty()) {
            checkDeserializedOutputs(original, deserialized, tc);
        }
    }

    /** Checks that both graphs list the same variable names in the same order. */
    private static void checkDeserializedVariableNames(SameDiff original, SameDiff deserialized) {
        List<SDVariable> vars = original.variables();
        List<SDVariable> varsDe = deserialized.variables();
        Preconditions.checkState(vars.size() == varsDe.size(), "Number of variables differs: expected %s, got %s", vars.size(), varsDe.size());
        for (int i = 0; i < vars.size(); i++) {
            String nameOrig = vars.get(i).name();
            String nameDe = varsDe.get(i).name();
            Preconditions.checkState(nameOrig.equals(nameDe), "Names should be equal for variable %s: expected %s vs %s", i, nameOrig, nameDe);
        }
    }

    /** Checks that both graphs have the same ops, each with identical wiring, control dependencies and class. */
    private static void checkDeserializedOps(SameDiff original, SameDiff deserialized) {
        Map<String, SameDiffOp> opsOrig = original.getOps();
        Map<String, SameDiffOp> opsDeser = deserialized.getOps();
        Preconditions.checkState(opsOrig.keySet().equals(opsDeser.keySet()), "Op names differs: %s vs. %s", opsOrig.keySet(), opsDeser.keySet());
        for (String s : opsOrig.keySet()) {
            SameDiffOp orig = opsOrig.get(s);
            SameDiffOp des = opsDeser.get(s);
            Preconditions.checkState(orig.getName().equals(des.getName()), "Names differ: %s vs %s", orig.getName(), des.getName());
            // Objects.equals: both null passes, exactly one null fails, otherwise equals()
            Preconditions.checkState(Objects.equals(orig.getInputsToOp(), des.getInputsToOp()), "Inputs differ: %s vs. %s", orig.getInputsToOp(), des.getInputsToOp());
            Preconditions.checkState(Objects.equals(orig.getOutputsOfOp(), des.getOutputsOfOp()), "Outputs differ: %s vs. %s", orig.getOutputsOfOp(), des.getOutputsOfOp());
            Preconditions.checkState(Objects.equals(orig.getControlDeps(), des.getControlDeps()), "Control dependencies differ: %s vs. %s", orig.getControlDeps(), des.getControlDeps());
            Preconditions.checkState(Objects.equals(orig.getVarControlDeps(), des.getVarControlDeps()), "Op variable control dependencies differ: %s vs. %s", orig.getVarControlDeps(), des.getVarControlDeps());
            Preconditions.checkState(Objects.equals(orig.getControlDepFor(), des.getControlDepFor()), "Op control dependencies for list differ: %s vs. %s", orig.getControlDepFor(), des.getControlDepFor());
            Preconditions.checkState(orig.getOp().getClass().equals(des.getOp().getClass()), "Classes differ: %s v. %s", orig.getOp().getClass(), des.getOp().getClass());
        }
    }

    /** Checks that both graphs declare the same set of placeholder names. */
    private static void checkDeserializedPlaceholders(SameDiff original, SameDiff deserialized) {
        Set<String> phBefore = collectPlaceholderNames(original);
        Set<String> phAfter = collectPlaceholderNames(deserialized);
        Preconditions.checkState(phBefore.equals(phAfter), "Before: %s, after deserialization: %s", phBefore, phAfter);
    }

    /** Returns the names of all placeholder variables in {@code sd}. */
    private static Set<String> collectPlaceholderNames(SameDiff sd) {
        Set<String> names = new HashSet<String>();
        for (Variable v : sd.getVariables().values()) {
            if (v.getVariable().isPlaceHolder()) {
                names.add(v.getName());
            }
        }
        return names;
    }

    /** Checks per-variable metadata (name, type, op inputs/outputs, control dependencies). */
    private static void checkDeserializedVariables(SameDiff original, SameDiff deserialized) {
        PatriciaTrie<Variable> varsBefore = original.getVariables();
        PatriciaTrie<Variable> varsAfter = deserialized.getVariables();
        Preconditions.checkState(varsBefore.keySet().equals(varsAfter.keySet()), "Variable keysets do not match: %s vs %s", varsBefore.keySet(), varsAfter.keySet());
        for (String s : varsBefore.keySet()) {
            Variable vB = varsBefore.get(s);
            Variable vA = varsAfter.get(s);
            Preconditions.checkState(vB.getName().equals(vA.getName()), "Variable names do not match: %s vs %s", vB.getName(), vA.getName());
            Preconditions.checkState(vB.getVariable().getVariableType() == vA.getVariable().getVariableType(), "Variable types do not match: %s - %s vs %s", s, vB.getVariable().getVariableType(), vA.getVariable().getVariableType());
            OpValidation.equalConsideringNull(vB.getInputsForOp(), vA.getInputsForOp(), "%s - Input to ops differ: %s vs. %s", s, vB.getInputsForOp(), vA.getInputsForOp());
            // Null-safe: the output op may legitimately be null on either side
            Preconditions.checkState(Objects.equals(vB.getOutputOfOp(), vA.getOutputOfOp()), "%s - Output of op differ: %s vs. %s", s, vB.getOutputOfOp(), vA.getOutputOfOp());
            OpValidation.equalConsideringNull(vB.getControlDeps(), vA.getControlDeps(), "%s - Control dependencies differ: %s vs. %s", s, vB.getControlDeps(), vA.getControlDeps());
            OpValidation.equalConsideringNull(vB.getControlDepsForOp(), vA.getControlDepsForOp(), "%s - Control dependencies for ops differ: %s vs. %s", s, vB.getControlDepsForOp(), vA.getControlDepsForOp());
            OpValidation.equalConsideringNull(vB.getControlDepsForVar(), vA.getControlDepsForVar(), "%s - Control dependencies for vars differ: %s vs. %s", s, vB.getControlDepsForVar(), vA.getControlDepsForVar());
        }
    }

    /** Checks that loss variables survive deserialization; null and empty are treated alike. */
    private static void checkDeserializedLossVariables(SameDiff original, SameDiff deserialized) {
        List<String> lossVarBefore = original.getLossVariables();
        List<String> lossVarAfter = deserialized.getLossVariables();
        if (lossVarBefore == null || lossVarBefore.isEmpty()) {
            Preconditions.checkState(lossVarAfter == null || lossVarAfter.isEmpty(), "Loss variables should be empty after deserialization, got %s", lossVarAfter);
        } else {
            Preconditions.checkState(lossVarBefore.equals(lossVarAfter), "Loss variables are not equal after deserialization: %s vs %s", lossVarBefore, lossVarAfter);
        }
    }

    /**
     * Executes both graphs with the test case's placeholders and checks every output. A variable with a
     * forward-pass test function is checked with that function. Any other variable must be equal to the
     * original output, or pass the NaN-tolerant comparison.
     */
    private static void checkDeserializedOutputs(SameDiff original, SameDiff deserialized, TestCase tc) {
        Map<String, INDArray> outOrig = original.outputAll(tc.placeholderValues());
        Map<String, INDArray> outDe = deserialized.outputAll(tc.placeholderValues());
        Preconditions.checkState(outOrig.keySet().equals(outDe.keySet()), "Keysets for execution after deserialization does not match key set for original model");
        for (Map.Entry<String, INDArray> entry : outOrig.entrySet()) {
            String s = entry.getKey();
            INDArray orig = entry.getValue();
            INDArray deser = outDe.get(s);
            Function<INDArray, String> f = tc.fwdTestFns().get(s);
            String err = null;
            if (f != null) {
                err = f.apply(deser);
            } else if (!orig.equals(deser) && !equalTreatingNaNsAsEqual(orig, deser)) {
                err = "INDArray equality failed";
            }
            Preconditions.checkState(err == null, "Variable result (%s) failed check - \"%ndSInfo\" vs \"%ndSInfo\" - %nd10 vs %nd10\nError:%s", s, orig, deser, orig, deser, err);
        }
    }

    /**
     * Fallback used when {@code orig.equals(deser)} fails. The arrays count as equal only when all of
     * the following hold: both are numerical with equal shapes, both contain the same positive number
     * of NaNs, and they agree element-wise. Element-wise agreement means same NaN-ness, same
     * infinity-ness, and an absolute difference of at most 1e-5.
     */
    private static boolean equalTreatingNaNsAsEqual(INDArray orig, INDArray deser) {
        if (!orig.dataType().isNumerical() || !orig.equalShapes(deser)) {
            return false;
        }
        long nanCount = countNaNValues(orig);
        if (nanCount <= 0L || nanCount != countNaNValues(deser)) {
            return false;
        }
        NdIndexIterator iter = new NdIndexIterator(orig.shape());
        while (iter.hasNext()) {
            long[] idx = iter.next();
            double d1 = orig.getDouble(idx);
            double d2 = deser.getDouble(idx);
            boolean sameKind = Double.isNaN(d1) == Double.isNaN(d2) && Double.isInfinite(d1) == Double.isInfinite(d2);
            // NaN-NaN and inf-inf yield NaN, and (NaN > 1e-5) is false, so matching specials pass
            if (!sameKind || Math.abs(d1 - d2) > 1.0E-5) {
                return false;
            }
        }
        return true;
    }

    /** Counts NaN entries in {@code arr} using a MatchCondition reduction. */
    private static long countNaNValues(INDArray arr) {
        return Nd4j.getExecutioner().execAndReturn(new MatchCondition(arr, Conditions.isNan(), new int[0])).getFinalResult().longValue();
    }

    /**
     * Checks that two string lists are equal, treating {@code null} and an empty list as equivalent.
     *
     * <p>If both lists are null or empty the check passes. Otherwise they must be equal per
     * {@link List#equals(Object)}. In particular, a null or empty list never matches a non-empty one.
     *
     * @param l1   first list (may be {@code null})
     * @param l2   second list (may be {@code null})
     * @param msg  {@code Preconditions}-style message template used when the lists differ
     * @param args arguments substituted into {@code msg}
     * @throws IllegalStateException if the lists differ (raised through {@code Preconditions.checkState})
     */
    protected static void equalConsideringNull(List<String> l1, List<String> l2, String msg, Object ... args) {
        boolean empty1 = l1 == null || l1.isEmpty();
        boolean empty2 = l2 == null || l2.isEmpty();
        if (empty1 && empty2) {
            return;
        }
        // At least one side is non-empty here. Objects.equals rejects l1 == null, whereas the old
        // "l1 == null || ..." guard accepted it.
        Preconditions.checkState(Objects.equals(l1, l2), msg, args);
    }

    /**
     * Validates a single op execution described by an {@link OpTestCase}.
     *
     * <p>First the op's output shapes are calculated and compared, by data type and shape, against
     * the expected shapes. Then the op is executed and each output is passed to its validation
     * function.
     *
     * @param testCase the single-op test case
     * @return {@code null} if all checks passed, otherwise a description of the first failure
     * @throws IllegalStateException if output-shape calculation, op execution, or an output validation
     *                               function throws
     */
    public static String validate(OpTestCase testCase) {
        OpValidation.collectCoverageInformation(testCase);

        List<LongShapeDescriptor> outShapes;
        try {
            outShapes = Nd4j.getExecutioner().calculateOutputShape(testCase.op());
        } catch (Throwable t) {
            throw new IllegalStateException("Error calculating output shapes during op validation", t);
        }
        if (outShapes.size() != testCase.testFns().size()) {
            return "Expected number of output shapes and number of outputs differ. " + outShapes.size() + " output shapes, but OpTestCase specifies " + testCase.testFns().size() + " outputs expected";
        }

        for (int i = 0; i < outShapes.size(); ++i) {
            LongShapeDescriptor act = outShapes.get(i);
            LongShapeDescriptor exp = testCase.expShapes().get(i);
            // NOTE(review): assumes expShapes() is looked up by output index and yields null when no
            // shape was registered for output i. Report that instead of failing with an NPE below.
            if (exp == null) {
                return "No expected shape specified for output " + i + ", actual shape " + act;
            }
            if (!Objects.equals(exp.dataType(), act.dataType()) || !Arrays.equals(act.getShape(), exp.getShape())) {
                return "Shape function check failed for output " + i + ": expected shape " + exp + ", actual shape " + act;
            }
        }

        try {
            Nd4j.getExecutioner().execAndReturn(testCase.op());
        } catch (Throwable t) {
            throw new IllegalStateException("Error during op execution", t);
        }

        for (int i = 0; i < testCase.testFns().size(); ++i) {
            String error;
            try {
                error = testCase.testFns().get(i).apply(testCase.op().outputArguments().get(i));
            } catch (Throwable t) {
                throw new IllegalStateException("Exception thrown during op output validation for output " + i, t);
            }
            if (error != null) {
                return "Output " + i + " failed: " + error;
            }
        }
        return null;
    }

    /*
     * WARNING - void declaration
     */
    private static void collectCoverageInformation(TestCase testCase) {
        DifferentialFunction df;
        void var6_10;
        SameDiff sd = testCase.sameDiff();
        DifferentialFunction[] functions = sd.ops();
        HashSet backpropSeen = new HashSet();
        DifferentialFunction[] differentialFunctionArray = functions;
        int n = differentialFunctionArray.length;
        boolean bl = false;
        while (var6_10 < n) {
            df = differentialFunctionArray[var6_10];
            backpropSeen.add(df.getClass());
            ++var6_10;
        }
        for (Class clazz : backpropSeen) {
            if (gradCheckCoverageCountPerClass.containsKey(clazz)) {
                gradCheckCoverageCountPerClass.put(clazz, gradCheckCoverageCountPerClass.get(clazz) + 1);
                continue;
            }
            gradCheckCoverageCountPerClass.put(clazz, 1);
        }
        HashSet seen = null;
        if (testCase.fwdTestFns() != null) {
            for (String string : testCase.fwdTestFns().keySet()) {
                df = sd.getVariableOutputOp(string);
                if (df == null) continue;
                if (seen == null) {
                    seen = new HashSet();
                }
                seen.add(df.getClass());
            }
        }
        if (seen != null) {
            for (Class clazz : seen) {
                if (fwdPassCoverageCountPerClass.containsKey(clazz)) {
                    fwdPassCoverageCountPerClass.put(clazz, fwdPassCoverageCountPerClass.get(clazz) + 1);
                    continue;
                }
                fwdPassCoverageCountPerClass.put(clazz, 1);
            }
        }
    }

    /**
     * Increments the single-op test counter for the class of the op wrapped by {@code testCase}.
     *
     * @param testCase single-op test case being validated
     */
    private static void collectCoverageInformation(OpTestCase testCase) {
        Class<?> opClass = testCase.op().getClass();
        int previous = singleOpTestCountPerClass.getOrDefault(opClass, 0);
        singleOpTestCountPerClass.put(opClass, previous + 1);
    }

    /**
     * Records TensorFlow-import coverage for every op in an imported graph.
     * <p>
     * For each op that reports at least one TensorFlow name, this increments three counters:
     * <ul>
     * <li>the op class's TF-import counter;</li>
     * <li>the op class's forward-pass counter;</li>
     * <li>the import counter of each TF op name the class maps to.</li>
     * </ul>
     *
     * @param graph imported SameDiff graph
     */
    public static void collectTensorflowImportCoverage(SameDiff graph) {
        for (SameDiffOp op : graph.getOps().values()) {
            DifferentialFunction d = op.getOp();
            String[] tfNames;
            try {
                tfNames = d.tensorflowNames();
            } catch (Throwable t) {
                // Deliberate best-effort: an op whose tensorflowNames() throws is treated as having no TF
                // mapping and is skipped.
                continue;
            }
            if (tfNames == null || tfNames.length == 0) {
                continue;
            }

            Class<?> opClass = d.getClass();
            opsWithTFMappingTFImportCounts.merge(opClass, 1, Integer::sum);
            fwdPassCoverageCountPerClass.merge(opClass, 1, Integer::sum);
            for (String tfName : tfNames) {
                tfMappedOpsImportTestCounts.merge(tfName, 1, Integer::sum);
            }
        }
    }

    /**
     * Builds the static coverage bookkeeping. This method:
     * <ul>
     * <li>scans {@code org.nd4j.linalg.api.ops} for concrete {@link DifferentialFunction} classes;</li>
     * <li>groups libnd4j custom ops by descriptor hash, since one hash may be registered under several
     * names;</li>
     * <li>records which of those hashes no Java op class maps to;</li>
     * <li>seeds every per-class counter with zero.</li>
     * </ul>
     *
     * @throws RuntimeException wrapping the {@link IOException} if the class path cannot be scanned
     */
    private static void initializeCoverage() {
        Set<ClassPath.ClassInfo> info;
        try {
            info = ClassPath.from(DifferentialFunctionClassHolder.class.getClassLoader())
                    .getTopLevelClassesRecursive("org.nd4j.linalg.api.ops");
        } catch (IOException e) {
            throw new RuntimeException(e);
        }

        // Group custom op names by descriptor hash
        Map<String, CustomOpDescriptor> customOps = Nd4j.getExecutioner().getCustomOperations();
        dedupedCustomOps = new HashMap<Long, Pair<List<String>, CustomOpDescriptor>>();
        for (Map.Entry<String, CustomOpDescriptor> e : customOps.entrySet()) {
            long hash = e.getValue().getHash();
            Pair<List<String>, CustomOpDescriptor> p = dedupedCustomOps.get(hash);
            if (p == null) {
                p = new Pair<List<String>, CustomOpDescriptor>(new ArrayList<String>(), e.getValue());
                dedupedCustomOps.put(hash, p);
            }
            List<String> names = p.getFirst();
            if (!names.contains(e.getKey())) {
                names.add(e.getKey());
            }
        }

        Set<Long> notSeenCustomOps = new HashSet<Long>(dedupedCustomOps.keySet());
        allOps = new ArrayList<Class>(gradCheckCoverageCountPerClass.keySet());
        for (ClassPath.ClassInfo c : info) {
            Class<?> clazz = ND4JClassLoading.loadClassByName(c.getName());
            Objects.requireNonNull(clazz, "Could not load class " + c.getName());
            if (Modifier.isAbstract(clazz.getModifiers()) || clazz.isInterface()
                    || !DifferentialFunction.class.isAssignableFrom(clazz)) {
                continue;
            }
            // Classes whose simple name contains "Old" are not added to the op list
            if (!clazz.getSimpleName().contains("Old")) {
                allOps.add(clazz);
            }

            // Instantiate via the no-arg constructor only to learn the op name. Class.newInstance() is
            // deprecated because it rethrows undeclared checked exceptions; Constructor.newInstance wraps them.
            String opName = null;
            try {
                opName = ((DifferentialFunction) clazz.getDeclaredConstructor().newInstance()).opName();
            } catch (Exception e) {
                log.warn("Could not instantiate object of type {}", clazz.getName(), e);
            }
            if (opName != null) {
                CustomOpDescriptor d = customOps.get(opName);
                if (d != null) {
                    notSeenCustomOps.remove(d.getHash());
                }
            }
        }

        countTotalLibnd4jOps = dedupedCustomOps.size();
        nonMappedLibnd4jOps = new ArrayList<Long>(notSeenCustomOps);
        // Sort unmapped ops by the first name registered for each hash
        Collections.sort(nonMappedLibnd4jOps, new Comparator<Long>() {
            @Override
            public int compare(Long o1, Long o2) {
                String n1 = dedupedCustomOps.get(o1).getFirst().get(0);
                String n2 = dedupedCustomOps.get(o2).getFirst().get(0);
                return n1.compareTo(n2);
            }
        });
        Collections.sort(allOps, new Comparator<Class>() {
            @Override
            public int compare(Class o1, Class o2) {
                return o1.getName().compareTo(o2.getName());
            }
        });
        for (Class c : allOps) {
            gradCheckCoverageCountPerClass.put(c, 0);
            fwdPassCoverageCountPerClass.put(c, 0);
            singleOpTestCountPerClass.put(c, 0);
        }
    }

    /**
     * Logs the op-validation coverage collected by this class.
     * <p>
     * The flags select which detail sections are printed. The summary block at the end is always printed,
     * and its figures are computed the same way whichever flags are set.
     *
     * @param logAdequatelyTested   log ops with a forward test and a gradient check (or a gradient-check exclusion)
     * @param logInadequate         log ops lacking a forward test, or lacking a gradient check without an exclusion
     * @param logUnmappedLibnd4jOps log libnd4j custom ops with no Java op class (names on the ignore list are skipped)
     * @param logUntestedTFImport   log TF-mapped op names that have no recorded TF import test
     * @param logUnmappedTFOps      log TF op names with no SameDiff mapping (names on the ignore list are skipped)
     */
    public static void logCoverageInformation(boolean logAdequatelyTested, boolean logInadequate, boolean logUnmappedLibnd4jOps, boolean logUntestedTFImport, boolean logUnmappedTFOps) {
        Set<Class> excludedFromBackpropCoverage = OpValidation.excludedFromGradientCheckCoverage();
        Set<Class> excludedFromAllTestCoverage = OpValidation.excludedFromAllTests();
        String numFormat = "%3d";

        // Coverage totals are accumulated unconditionally so that the summary does not depend on logAdequatelyTested
        int totalFwd = 0;
        int countAdequate = 0;
        int countAdequateBwd = 0;
        int countAdequateFwd = 0;
        for (Class c : allOps) {
            if (excludedFromAllTestCoverage.contains(c)) {
                continue;
            }
            totalFwd++;
            int countBackpropSeen = gradCheckCoverageCountPerClass.get(c);
            int countFwdValidation = fwdPassCoverageCountPerClass.get(c) + singleOpTestCountPerClass.get(c);
            if (countBackpropSeen > 0) {
                countAdequateBwd++;
            }
            if (countFwdValidation > 0) {
                countAdequateFwd++;
            }
            if (countFwdValidation > 0 && countBackpropSeen > 0) {
                countAdequate++;
            }
        }

        if (logAdequatelyTested) {
            log.info(" --- Adequately Tested Classes ---");
            for (Class c : allOps) {
                if (excludedFromAllTestCoverage.contains(c)) {
                    continue;
                }
                int countBackpropSeen = gradCheckCoverageCountPerClass.get(c);
                int countFwdValidation = fwdPassCoverageCountPerClass.get(c) + singleOpTestCountPerClass.get(c);
                boolean gradExcluded = excludedFromBackpropCoverage.contains(c);
                // Adequate: at least one forward test, plus a gradient check unless the op is excluded from them
                if (countFwdValidation <= 0 || (countBackpropSeen <= 0 && !gradExcluded)) {
                    continue;
                }
                if (gradExcluded) {
                    log.info("Forward: {} tests, GradCheck: <excluded> for op {}", String.format(numFormat, countFwdValidation), c.getName());
                } else {
                    log.info("Forward: {} tests, GradCheck: {} tests  for op {}", String.format(numFormat, countFwdValidation), String.format(numFormat, countBackpropSeen), c.getName());
                }
            }
        }

        if (logInadequate) {
            log.info(" --- Classes NOT Tested Adequately ---");
            for (Class c : allOps) {
                if (excludedFromAllTestCoverage.contains(c)) {
                    continue;
                }
                int countBackpropSeen = gradCheckCoverageCountPerClass.get(c);
                int countFwdValidation = fwdPassCoverageCountPerClass.get(c) + singleOpTestCountPerClass.get(c);
                boolean gradExcluded = excludedFromBackpropCoverage.contains(c);
                if (countFwdValidation != 0 && (countBackpropSeen != 0 || gradExcluded)) {
                    continue;
                }
                if (gradExcluded) {
                    log.info("Forward: {} tests, GradCheck: <excluded> for op {}", String.format(numFormat, countFwdValidation), c.getName());
                } else {
                    log.info("Forward: {} tests, GradCheck: {} tests  for op {}", String.format(numFormat, countFwdValidation), String.format(numFormat, countBackpropSeen), c.getName());
                }
            }
        }

        // Unmapped libnd4j ops with any name on the ignore list are counted once each but not logged.
        // The count is needed by the summary, so it is computed even when logUnmappedLibnd4jOps is false.
        Set<String> ignoreLibnd4j = OpValidation.excludeFromLibnd4jCustomOpMapping();
        int countLibnd4jIgnored = 0;
        if (logUnmappedLibnd4jOps) {
            log.info(" --- Libnd4j Ops Not Mapped ---");
        }
        for (long l : nonMappedLibnd4jOps) {
            Pair<List<String>, CustomOpDescriptor> p = dedupedCustomOps.get(l);
            boolean foundIgnore = false;
            for (String s : p.getFirst()) {
                if (ignoreLibnd4j.contains(s)) {
                    foundIgnore = true;
                    break;
                }
            }
            if (foundIgnore) {
                countLibnd4jIgnored++;
            } else if (logUnmappedLibnd4jOps) {
                log.info("Not mapped libnd4j custom op: {} (hash: {})", p.getFirst(), l);
            }
        }

        Map<String, DifferentialFunction> tfOpsMap = DifferentialFunctionClassHolder.getInstance().getTensorFlowNames();
        int totalTFMappedOps = tfOpsMap.size();
        int tfOpsWithImportTests = 0;
        if (logUntestedTFImport) {
            log.info(" --- Ops with TF Mapping but No TF Import Tests ---");
        }
        List<String> tfOpsKeys = new ArrayList<String>(tfOpsMap.keySet());
        Collections.sort(tfOpsKeys);
        Set<String> tfIgnored = OpValidation.excludeFromTfImportCoverage();
        int tfImportIgnored = 0;
        for (String s : tfOpsKeys) {
            Integer count = tfMappedOpsImportTestCounts.get(s);
            if (count != null && count != 0) {
                tfOpsWithImportTests++;
            } else if (tfIgnored.contains(s)) {
                tfImportIgnored++;
            } else if (logUntestedTFImport) {
                log.info("TF mapped op with no import tests: {}", s);
            }
        }

        if (logUnmappedTFOps) {
            log.info(" --- TF Ops Not Mapped for Import ---");
            Map<String, OpDef> allTFOps;
            try {
                allTFOps = TensorflowDescriptorParser.opDescs();
            } catch (Throwable t) {
                throw new RuntimeException(t);
            }
            List<String> notMapped = new ArrayList<String>();
            for (String s : allTFOps.keySet()) {
                if (DifferentialFunctionClassHolder.getInstance().getOpWithTensorflowName(s) == null && !tfIgnored.contains(s)) {
                    notMapped.add(s);
                }
            }
            Collections.sort(notMapped);
            // Log in batches of 10. Stepping by 10 also logs the final partial batch, which the former
            // (int) Math.ceil(size / 10) computation (integer division) dropped.
            for (int from = 0; from < notMapped.size(); from += 10) {
                log.info("TF ops not mapped for import: {}", notMapped.subList(from, Math.min(from + 10, notMapped.size())));
            }
        }

        int totalBwd = 0;
        for (Class c : allOps) {
            if (!OpValidation.isBackpropOp(c)) {
                totalBwd++;
            }
        }

        double fracFwdAdequate = (double) countAdequateFwd / (double) totalFwd;
        double fracBwdAdequate = (double) countAdequateBwd / (double) totalBwd;
        // countAdequate only counts classes not excluded from all tests and is reported "of totalFwd",
        // so totalFwd is the matching denominator
        double fracAdequate = (double) countAdequate / (double) totalFwd;
        String pcFwd = String.format("%.2f", fracFwdAdequate * 100.0);
        String pcBwd = String.format("%.2f", fracBwdAdequate * 100.0);
        String pc = String.format("%.2f", fracAdequate * 100.0);
        int countTf = DifferentialFunctionClassHolder.getInstance().getCountTotalTfOps();
        int countTfMapped = DifferentialFunctionClassHolder.getInstance().getCountTotalMappedOps();
        double tfFrac = (double) countTfMapped / (double) countTf;
        String fracTfStr = String.format("%.2f", 100.0 * tfFrac);
        int countLibnd4jMapped = countTotalLibnd4jOps - nonMappedLibnd4jOps.size();
        String fracLibnd4j = String.format("%.2f", 100.0 * ((double) countLibnd4jMapped / (double) (countTotalLibnd4jOps - countLibnd4jIgnored)));
        String fracTFMappedTested = String.format("%.2f", 100.0 * (double) tfOpsWithImportTests / (double) (totalTFMappedOps - tfImportIgnored));
        log.info("*****************************************************");
        log.info("Op Validation:                        {} of {} classes with adequate tests ({}% coverage)", countAdequate, totalFwd, pc);
        log.info("Forward pass tests:                   {} of {} classes ({}% coverage)", countAdequateFwd, totalFwd, pcFwd);
        log.info("Gradient check tests:                 {} of {} classes ({}% coverage)", countAdequateBwd, totalBwd, pcBwd);
        log.info("({} ops excluded from gradient check coverage)", excludedFromBackpropCoverage.size());
        log.info("({} ops excluded from fwd+gradient tests)", excludedFromAllTestCoverage.size());
        log.info("TF mapped ops:                        {} of {} ({}%)", countTfMapped, countTf, fracTfStr);
        log.info("SD ops with TF import mapping + test  {} of {} ({}%) - {} ignored for coverage", tfOpsWithImportTests, totalTFMappedOps - tfImportIgnored, fracTFMappedTested, tfImportIgnored);
        log.info("Libnd4j mapped ops:                   {} of {} ({}%) - {} excluded for coverage", countLibnd4jMapped, countTotalLibnd4jOps, fracLibnd4j, countLibnd4jIgnored);
        log.info("*****************************************************");
    }

    /**
     * Heuristically classifies an op class as a backprop/gradient op by its simple name.
     *
     * @param c op class to inspect
     * @return true if the simple name contains "Bp", "Derivative" or "Grad"
     */
    private static boolean isBackpropOp(Class<?> c) {
        String simpleName = c.getSimpleName();
        for (String marker : new String[]{"Bp", "Derivative", "Grad"}) {
            if (simpleName.contains(marker)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns the op classes that are skipped entirely when computing forward-pass and gradient-check coverage.
     *
     * @return a new mutable set of excluded classes
     */
    private static Set<Class> excludedFromAllTests() {
        Class[] excluded = {
                DynamicCustomOp.class, GradientBackwardsMarker.class, EqualsWithEps.class, FreeGridOp.class,
                MergeSum.class, ScalarRemainder.class, RestoreV2.class, SaveV2.class, ScalarSetValue.class,
                BinomialDistributionEx.class,
                // Broadcast* classes
                BroadcastAMax.class, BroadcastAMin.class, BroadcastAddOp.class, BroadcastCopyOp.class,
                BroadcastDivOp.class, BroadcastEqualTo.class, BroadcastGreaterThan.class,
                BroadcastGreaterThanOrEqual.class, BroadcastLessThan.class, BroadcastLessThanOrEqual.class,
                BroadcastMax.class, BroadcastMin.class, BroadcastMulOp.class, BroadcastNotEqual.class,
                BroadcastRDivOp.class, BroadcastRSubOp.class, BroadcastSubOp.class,
                // *Bp / *Derivative / *Grad classes
                AddBpOp.class, DivBpOp.class, FloorDivBpOp.class, FloorModBpOp.class, MulBpOp.class,
                RDivBpOp.class, RSubBpOp.class, SquaredDifferenceBpOp.class, SubBpOp.class,
                CumProdBp.class, DotBp.class, SquaredNormBp.class, SoftmaxBp.class,
                CubeDerivative.class, GELUDerivative.class, PreciseGELUDerivative.class,
                HardSigmoidDerivative.class, HardTanhDerivative.class, LeakyReLUDerivative.class,
                LogSoftMaxDerivative.class, RationalTanhDerivative.class, RectifiedTanhDerivative.class,
                Relu6Derivative.class, PReluBp.class, SELUDerivative.class, SigmoidDerivative.class,
                org.nd4j.linalg.api.ops.impl.transforms.strict.SigmoidDerivative.class,
                SoftSignDerivative.class, TanhDerivative.class, SwishDerivative.class, TanDerivative.class,
                org.nd4j.linalg.api.ops.impl.transforms.strict.TanhDerivative.class,
                PowDerivative.class, RectifiedLinearDerivative.class,
                CubeBp.class, EluBp.class, HardSigmoidBp.class, HardTanhBp.class, LeakyReLUBp.class,
                RationalTanhBp.class, RectifiedTanhBp.class, SeluBp.class, SoftPlusBp.class, SoftSignBp.class,
                ThresholdReluBp.class, ModBpOp.class,
                BiasAddGrad.class, ConcatBp.class, TileBp.class,
                BatchNormDerivative.class, Conv2DDerivative.class, Conv3DDerivative.class,
                DeConv2DDerivative.class, LocalResponseNormalizationDerivative.class, Pooling2DDerivative.class,
                Pooling3DDerivative.class, SConv2DDerivative.class, Upsampling2dDerivative.class, Im2colBp.class,
                SliceBp.class, StridedSliceBp.class, MmulBp.class, DotProductAttentionBp.class,
                MultiHeadDotProductAttentionBp.class, LayerNormBp.class, StandardizeBp.class,
                DynamicPartitionBp.class,
                AbsoluteDifferenceLossBp.class, CosineDistanceLossBp.class, HingeLossBp.class, HuberLossBp.class,
                LogLossBp.class, LogPoissonLossBp.class, MeanPairwiseSquaredErrorLossBp.class,
                MeanSquaredErrorLossBp.class, SigmoidCrossEntropyLossBp.class, SoftmaxCrossEntropyLossBp.class,
                SparseSoftmaxCrossEntropyLossWithLogitsBp.class,
                SegmentMaxBp.class, SegmentMeanBp.class, SegmentMinBp.class, SegmentProdBp.class,
                SegmentSumBp.class, UnsortedSegmentMaxBp.class, UnsortedSegmentMeanBp.class,
                UnsortedSegmentMinBp.class, UnsortedSegmentProdBp.class, UnsortedSegmentSqrtNBp.class,
                UnsortedSegmentSumBp.class,
                // *MetaOp and other non-standard classes
                ExternalErrorsFunction.class, InvertedPredicateMetaOp.class, PostulateMetaOp.class,
                PredicateMetaOp.class, ReduceMetaOp.class,
                BarnesEdgeForces.class, BarnesHutGains.class, BarnesHutSymmetrize.class, SpTreeCell.class,
                CbowRound.class, SkipGramRound.class, HashCode.class, BitCast.class, ToggleBits.class
        };
        return new HashSet<Class>(Arrays.asList(excluded));
    }

    /**
     * Returns the op classes that are not expected to have gradient checks when computing backprop coverage.
     *
     * @return a new mutable set of excluded classes
     */
    private static Set<Class> excludedFromGradientCheckCoverage() {
        Class[] excluded = {
                DynamicCustomOp.class, EqualsWithEps.class, ConfusionMatrix.class, Eye.class, OneHot.class,
                BinaryMinimalRelativeError.class, InvertPermutation.class, Linspace.class, Assert.class,
                // Boolean / predicate / index-returning classes
                Any.class, All.class, org.nd4j.linalg.api.ops.impl.reduce.bool.IsInf.class, IsInf.class,
                IsNaN.class, org.nd4j.linalg.api.ops.impl.transforms.bool.IsNaN.class, BooleanNot.class,
                Not.class, MatchConditionTransform.class, InTopK.class, IsNonDecreasing.class,
                IsStrictlyIncreasing.class, IsNumericTensor.class, FirstIndex.class, LastIndex.class,
                ArgMax.class, ArgMin.class,
                // Shape classes
                Shape.class, ShapeN.class, SizeAt.class, BroadcastDynamicShape.class, ReductionShape.class,
                // Bit-manipulation classes
                ShiftBits.class, RShiftBits.class, BitsHammingDistance.class, CyclicShiftBits.class,
                CyclicRShiftBits.class,
                // Random / distribution classes
                RandomStandardNormal.class, DistributionUniform.class, AlphaDropOut.class,
                BernoulliDistribution.class, BinomialDistribution.class, BinomialDistributionEx.class,
                Choice.class, DropOut.class, DropOutInverted.class, GaussianDistribution.class,
                LogNormalDistribution.class, ProbablisticMerge.class, Range.class,
                TruncatedNormalDistribution.class, UniformDistribution.class,
                Col2Im.class, NormalizeMoments.class,
                // *Bp classes
                CumProdBp.class, CumSumBp.class, DotBp.class, MaxBp.class, MeanBp.class, MinBp.class,
                Norm1Bp.class, Norm2Bp.class, NormMaxBp.class, ProdBp.class, StandardDeviationBp.class,
                SumBp.class, VarianceBp.class,
                // Logical* classes
                LogicalAnd.class, LogicalNot.class, LogicalOr.class, LogicalXor.class,
                Histogram.class
        };
        return new HashSet<Class>(Arrays.asList(excluded));
    }

    /**
     * Returns the TensorFlow op names that are ignored when computing TF import coverage.
     *
     * @return a new mutable set of TF op names
     */
    private static Set<String> excludeFromTfImportCoverage() {
        Set<String> ignored = new HashSet<String>();
        Collections.addAll(ignored,
                "Reverse", "LogSigmoid", "HardSigmoid", "SpaceToBatch", "BatchToSpace", "Pad", "TopK", "InTopK");
        // Batch* names
        Collections.addAll(ignored,
                "BatchMatrixDeterminant", "BatchMatrixDiagPart", "BatchMatrixDiag", "BatchMatrixBandPart",
                "BatchMatrixInverse", "BatchMatrixSetDiag", "BatchMatrixSolve", "BatchMatrixSolveLs",
                "BatchMatrixTriangularSolve", "BatchSelfAdjointEig", "BatchSelfAdjointEigV2", "BatchSvd");
        // Experimental* names
        Collections.addAll(ignored,
                "ExperimentalBytesProducedStatsDataset", "ExperimentalCSVDataset", "ExperimentalDatasetCardinality",
                "ExperimentalDatasetToTFRecord", "ExperimentalDenseToSparseBatchDataset",
                "ExperimentalDirectedInterleaveDataset", "ExperimentalGroupByReducerDataset",
                "ExperimentalGroupByWindowDataset", "ExperimentalIdentityIndexedDataset",
                "ExperimentalIgnoreErrorsDataset", "ExperimentalIndexedDatasetGet",
                "ExperimentalIndexedDatasetMaterialize", "ExperimentalIteratorGetDevice", "ExperimentalLMDBDataset",
                "ExperimentalLatencyStatsDataset", "ExperimentalMapAndBatchDataset", "ExperimentalMapDataset",
                "ExperimentalMatchingFilesDataset", "ExperimentalMaterializedIndexDatasetHandle",
                "ExperimentalMaxIntraOpParallelismDataset", "ExperimentalNonSerializableDataset",
                "ExperimentalNumaMapAndBatchDataset", "ExperimentalParallelInterleaveDataset",
                "ExperimentalParseExampleDataset", "ExperimentalPrivateThreadPoolDataset",
                "ExperimentalRandomDataset", "ExperimentalScanDataset", "ExperimentalSetStatsAggregatorDataset",
                "ExperimentalSleepDataset", "ExperimentalSlidingWindowDataset", "ExperimentalSqlDataset",
                "ExperimentalStatsAggregatorHandle", "ExperimentalStatsAggregatorSummary",
                "ExperimentalThreadPoolDataset", "ExperimentalThreadPoolHandle", "ExperimentalUnbatchDataset",
                "ExperimentalUniqueDataset");
        Collections.addAll(ignored,
                "DebugIdentity", "NcclAllReduce", "NcclBroadcast", "NcclReduce", "PyFunc", "PyFuncStateless");
        // Quantized* names
        Collections.addAll(ignored,
                "QuantizedAdd", "QuantizedAvgPool", "QuantizedBatchNormWithGlobalNormalization", "QuantizedBiasAdd",
                "QuantizedConcat", "QuantizedConv2D", "QuantizedInstanceNorm", "QuantizedMatMul", "QuantizedMaxPool",
                "QuantizedMul", "QuantizedRelu", "QuantizedRelu6", "QuantizedReluX", "QuantizedReshape",
                "QuantizedResizeBilinear");
        Collections.addAll(ignored,
                "HardTanh", "Swish", "RDiv", "DivScalar", "LogX", "RationalTanh", "absargmax", "absargmin",
                "entropy_shannon", "count_zero", "SaveV2", "LoadV2", "RestoreV2", "RandomCrop");
        return ignored;
    }

    /**
     * Returns the libnd4j custom op names that are ignored when reporting unmapped libnd4j ops.
     *
     * @return a new mutable set of libnd4j op names
     */
    private static Set<String> excludeFromLibnd4jCustomOpMapping() {
        return new HashSet<String>(Arrays.asList(
                "TestOp2i2o", "testop2i2o", "firas_sparse", "test_output_reshape", "test_scalar", "testcustom",
                "testreduction",
                "to_double", "to_float16", "to_float32", "to_int32", "to_int64", "to_uint32", "to_uint64"));
    }

    static {
        // All counter maps are created first: initializeCoverage() writes zero entries into the per-class maps.
        // LinkedHashMap keeps entries in insertion order.
        tfMappedOpsImportTestCounts = new LinkedHashMap<String, Integer>();
        opsWithTFMappingTFImportCounts = new LinkedHashMap<Class, Integer>();
        singleOpTestCountPerClass = new LinkedHashMap<Class, Integer>();
        fwdPassCoverageCountPerClass = new LinkedHashMap<Class, Integer>();
        gradCheckCoverageCountPerClass = new LinkedHashMap<Class, Integer>();
        initializeCoverage();
    }
}

