/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.validation.listeners;

import java.security.MessageDigest;
import java.util.Arrays;
import java.util.concurrent.atomic.AtomicInteger;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.dataset.api.MultiDataSet;

/**
 * Debugging listener that checks that non-inplace ops do not modify their input arrays.
 *
 * <p>Before each non-inplace op executes, {@link #preOpExecution} stores a {@code dup()} copy of
 * every input array. After the op executes, {@link #opExecution} compares the raw bytes of each
 * copy against the op's current input and fails via {@link Preconditions#checkState} if any
 * input changed. Ops that report {@code isInPlace()} are skipped entirely.
 *
 * <p>The static counters accumulate across all instances: how many listeners were constructed,
 * and how many per-input comparisons passed or failed.
 *
 * <p>NOTE(review): the per-op snapshot lives in an instance field, so one instance is presumably
 * not meant to be shared by concurrently executing graphs — confirm against callers.
 */
public class NonInplaceValidationListener
extends BaseListener {
    /** Number of listener instances constructed. */
    private static final AtomicInteger useCounter = new AtomicInteger();
    /** Number of input comparisons that found the input unchanged. */
    private static final AtomicInteger passCounter = new AtomicInteger();
    /** Number of input comparisons that found the input modified. */
    private static final AtomicInteger failCounter = new AtomicInteger();

    /**
     * Copies of the current op's inputs, taken in {@link #preOpExecution}; {@code null} when no
     * snapshot is pending.
     */
    protected INDArray[] opInputs;

    public NonInplaceValidationListener() {
        useCounter.getAndIncrement();
    }

    /**
     * Snapshots (via {@code dup()}) every input of a non-inplace op before it executes.
     *
     * @throws IllegalStateException if the op is neither an {@link Op} nor a {@link DynamicCustomOp}
     */
    @Override
    public void preOpExecution(SameDiff sd, At at, SameDiffOp op) {
        // A snapshot still pending here belongs to an earlier op whose opExecution never ran
        // (e.g. it threw). Drop it so it can never be compared against this op's inputs.
        releaseSnapshot();
        if (op.getOp().isInPlace()) {
            return;
        }
        INDArray[] inputs = currentInputs(op);
        if (inputs == null) {
            return;
        }
        INDArray[] copies = new INDArray[inputs.length];
        for (int i = 0; i < inputs.length; ++i) {
            copies[i] = inputs[i].dup();
        }
        this.opInputs = copies;
    }

    /**
     * Compares each input of a non-inplace op against the snapshot taken before execution, then
     * releases the snapshot.
     *
     * @throws IllegalStateException if an input was modified, if the op now has fewer inputs than
     *     were snapshotted, or if the op type is unknown
     */
    @Override
    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
        if (op.getOp().isInPlace()) {
            return;
        }
        INDArray[] inputsAfter = currentInputs(op);
        INDArray[] before = this.opInputs;
        try {
            if (inputsAfter == null || before == null) {
                // Nothing to validate: the op has no inputs, or no snapshot was taken for it
                return;
            }
            if (inputsAfter.length < before.length) {
                throw new IllegalStateException("Non-inplace op " + op.getOp().getClass() + " had "
                        + before.length + " inputs before execution but only " + inputsAfter.length + " after");
            }
            for (int i = 0; i < before.length; ++i) {
                checkInputUnchanged(op, i, before[i], inputsAfter[i]);
            }
        } finally {
            // Always release the copies, including when a check above throws
            releaseSnapshot();
        }
    }

    /**
     * Returns the op's current input arrays (not copies), or {@code null} if a legacy {@link Op}
     * has no x input or a {@link DynamicCustomOp} reports no input array.
     *
     * @throws IllegalStateException if the op is neither an {@link Op} nor a {@link DynamicCustomOp}
     */
    private static INDArray[] currentInputs(SameDiffOp op) {
        if (op.getOp() instanceof Op) {
            Op o = (Op) op.getOp();
            if (o.x() == null) {
                return null;
            }
            return o.y() == null ? new INDArray[]{o.x()} : new INDArray[]{o.x(), o.y()};
        }
        if (op.getOp() instanceof DynamicCustomOp) {
            return ((DynamicCustomOp) op.getOp()).inputArguments();
        }
        throw new IllegalStateException("Unknown op type: " + op.getOp().getClass());
    }

    /**
     * Checks that one input's data matches its pre-execution copy, and updates the pass/fail
     * counters. Empty snapshot arrays are skipped.
     */
    private static void checkInputUnchanged(SameDiffOp op, int i, INDArray before, INDArray current) {
        if (before.isEmpty()) {
            return;
        }
        // The comparison below is over whole data buffers. If the current array is laid out
        // differently from the snapshot, or is a view (its buffer may contain elements outside
        // the view), compare against a dup() of it instead.
        boolean needsCopy = before.ordering() != current.ordering()
                || !Arrays.equals(before.stride(), current.stride())
                || before.elementWiseStride() != current.elementWiseStride()
                || current.isView();
        INDArray after = needsCopy ? current.dup() : current;
        try {
            boolean eq = Arrays.equals(before.data().asBytes(), after.data().asBytes());
            if (eq) {
                passCounter.incrementAndGet();
            } else {
                failCounter.incrementAndGet();
            }
            Preconditions.checkState(eq, "Input array for non-inplace op was modified during execution for op %s - input %s", op.getOp().getClass(), i);
        } finally {
            if (needsCopy && after.closeable()) {
                after.close();
            }
        }
    }

    /** Closes any pending snapshot copies (skipping empty or non-closeable arrays) and clears the field. */
    private void releaseSnapshot() {
        INDArray[] snapshot = this.opInputs;
        this.opInputs = null;
        if (snapshot == null) {
            return;
        }
        for (INDArray arr : snapshot) {
            if (arr != null && !arr.isEmpty() && arr.closeable()) {
                arr.close();
            }
        }
    }

    /** This listener is active for every operation type. */
    @Override
    public boolean isActive(Operation operation) {
        return true;
    }

    /** @return global count of constructed listener instances */
    public static AtomicInteger getUseCounter() {
        return useCounter;
    }

    /** @return global count of input comparisons that passed */
    public static AtomicInteger getPassCounter() {
        return passCounter;
    }

    /** @return global count of input comparisons that failed */
    public static AtomicInteger getFailCounter() {
        return failCounter;
    }
}

