/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.profiler;

import java.util.ArrayList;
import java.util.concurrent.atomic.AtomicLong;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.GridOp;
import org.nd4j.linalg.api.ops.IndexAccumulation;
import org.nd4j.linalg.api.ops.MetaOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.RandomOp;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.profiler.data.StackAggregator;
import org.nd4j.linalg.profiler.data.StringAggregator;
import org.nd4j.linalg.profiler.data.StringCounter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Collects global statistics about op execution: how often each op and op class is called,
 * which ops / op classes follow each other, how long ops take, and which memory-access
 * penalties (mixed ordering, non-EWS access, strided access) their operands incur.
 *
 * <p>All counters and aggregators are static and shared; obtain the singleton through
 * {@link #getInstance()}. Apart from the {@link AtomicLong} invocation counter nothing in this
 * class is synchronized, and the "previous op" tracking fields are plain instance fields.
 * NOTE(review): thread-safety of the aggregator/counter classes is not visible from this file —
 * confirm before profiling from several threads at once.
 */
public class OpProfiler {
    /**
     * Duration above which {@link #timeOpCall(Op, long)} additionally records the op in the
     * "long op" statistics. Compared against a {@code System.nanoTime()} delta, so it is in
     * nanoseconds when callers pass a {@code System.nanoTime()} start value.
     */
    private static final long THRESHOLD = 100000L;

    private static final Logger logger = LoggerFactory.getLogger(OpProfiler.class);

    private static final AtomicLong invocationsCount = new AtomicLong(0L);
    private static final StringAggregator classAggregator = new StringAggregator();
    private static final StringAggregator longAggregator = new StringAggregator();
    private static final StringCounter classCounter = new StringCounter();
    private static final StringCounter opCounter = new StringCounter();
    private static final StringCounter classPairsCounter = new StringCounter();
    private static final StringCounter opPairsCounter = new StringCounter();
    private static final StringCounter matchingCounter = new StringCounter();
    private static final StringCounter matchingCounterDetailed = new StringCounter();
    private static final StringCounter matchingCounterInverted = new StringCounter();
    private static final StringCounter orderCounter = new StringCounter();
    private static final StackAggregator methodsAggregator = new StackAggregator();
    private static final StackAggregator scalarAggregator = new StackAggregator();
    private static final StackAggregator mixedOrderAggregator = new StackAggregator();
    private static final StackAggregator nonEwsAggregator = new StackAggregator();
    private static final StackAggregator stridedAggregator = new StackAggregator();
    private static final StackAggregator tadStridedAggregator = new StackAggregator();
    private static final StackAggregator tadNonEwsAggregator = new StackAggregator();
    private static final StackAggregator blasAggregator = new StackAggregator();
    private static final StringCounter blasOrderCounter = new StringCounter();

    // Declared after every other static field so the singleton is created once all shared
    // state has been initialized.
    private static final OpProfiler ourInstance = new OpProfiler();

    // Tracking of the previously profiled op, used for pair and "matching" statistics.
    private String prevOpClass = "";
    private String prevOpName = "";
    private String prevOpMatching = "";
    private String prevOpMatchingDetailed = "";
    private String prevOpMatchingInverted = "";
    // Buffer address of the previous op's z operand (0 = none / unknown).
    private long lastZ = 0L;

    /**
     * Clears every counter and aggregator, and forgets the previously profiled op so that the
     * first op after a reset is not paired or matched against an op recorded before it.
     */
    public void reset() {
        invocationsCount.set(0L);
        classAggregator.reset();
        longAggregator.reset();
        classCounter.reset();
        opCounter.reset();
        classPairsCounter.reset();
        opPairsCounter.reset();
        matchingCounter.reset();
        matchingCounterDetailed.reset();
        matchingCounterInverted.reset();
        methodsAggregator.reset();
        scalarAggregator.reset();
        nonEwsAggregator.reset();
        stridedAggregator.reset();
        tadNonEwsAggregator.reset();
        tadStridedAggregator.reset();
        mixedOrderAggregator.reset();
        blasAggregator.reset();
        blasOrderCounter.reset();
        orderCounter.reset();

        prevOpClass = "";
        prevOpName = "";
        prevOpMatching = "";
        prevOpMatchingDetailed = "";
        prevOpMatchingInverted = "";
        lastZ = 0L;
    }

    /**
     * Returns the shared profiler instance.
     *
     * @return the singleton {@code OpProfiler}
     */
    public static OpProfiler getInstance() {
        return ourInstance;
    }

    private OpProfiler() {
    }

    /**
     * Maps an op to a coarse class label used as a statistics key.
     *
     * <p>Checks run in a fixed order and the first match wins, so an op implementing several
     * of these interfaces gets the earliest label. A {@link TransformOp} with a {@code y}
     * operand is reported as {@code "PairWiseTransformOp"}.
     *
     * @param op the op to classify
     * @return the class label, or {@code "Unknown Op calls"} if no interface matches
     */
    protected String getOpClass(Op op) {
        if (op instanceof ScalarOp) {
            return "ScalarOp";
        }
        if (op instanceof MetaOp) {
            return "MetaOp";
        }
        if (op instanceof GridOp) {
            return "GridOp";
        }
        if (op instanceof BroadcastOp) {
            return "BroadcastOp";
        }
        if (op instanceof RandomOp) {
            return "RandomOp";
        }
        if (op instanceof Accumulation) {
            return "AccumulationOp";
        }
        if (op instanceof TransformOp) {
            return op.y() == null ? "TransformOp" : "PairWiseTransformOp";
        }
        if (op instanceof IndexAccumulation) {
            return "IndexAccumulationOp";
        }
        return "Unknown Op calls";
    }

    /** Counts one scalar access and records it in the scalar-access stack aggregator. */
    public void processScalarCall() {
        invocationsCount.incrementAndGet();
        scalarAggregator.incrementCount();
    }

    /**
     * Records one op invocation: per-op and per-class counters, op pair statistics, "matching"
     * statistics against the previous op, and operand memory-access penalties.
     *
     * @param op the op being executed; its {@code x} and {@code z} operands must be non-null
     */
    public void processOpCall(Op op) {
        invocationsCount.incrementAndGet();
        String opName = op.name();
        String opClass = getOpClass(op);
        opCounter.incrementCount(opName);
        classCounter.incrementCount(opClass);

        INDArray x = op.x();
        INDArray y = op.y();
        INDArray z = op.z();

        // "Matching": this op has no y, works in place (z == x), and reads the very buffer the
        // previous op wrote to.
        if (x.data().address() == lastZ && z == x && y == null) {
            matchingCounter.incrementCount(prevOpMatching + " -> " + opClass);
            matchingCounterDetailed.incrementCount(prevOpMatchingDetailed + " -> " + opClass + " " + opName);
        } else {
            matchingCounter.totalsIncrement();
            matchingCounterDetailed.totalsIncrement();
            // "Inverted" matching: the previous op's output is consumed as this op's y operand.
            if (y != null && y.data().address() == lastZ) {
                matchingCounterInverted.incrementCount(prevOpMatchingInverted + " -> " + opClass + " " + opName);
            } else {
                matchingCounterInverted.totalsIncrement();
            }
        }

        lastZ = z.data().address();
        prevOpMatching = opClass;
        prevOpMatchingDetailed = opClass + " " + opName;
        prevOpMatchingInverted = opClass + " " + opName;
        updatePairs(opName, opClass);

        // NOTE(review): incrementCount() is kept directly in this method (not in a helper)
        // because StackAggregator may capture the caller's stack — confirm before refactoring.
        for (PenaltyCause cause : processOperands(x, y, z)) {
            switch (cause) {
                case NON_EWS_ACCESS:
                    nonEwsAggregator.incrementCount();
                    break;
                case STRIDED_ACCESS:
                    stridedAggregator.incrementCount();
                    break;
                case MIXED_ORDER:
                    mixedOrderAggregator.incrementCount();
                    break;
                default:
                    break;
            }
        }
    }

    /**
     * Records an op invocation like {@link #processOpCall(Op)} and additionally classifies the
     * supplied TAD shape-info buffers for TAD-level access penalties.
     *
     * @param op         the op being executed
     * @param tadBuffers TAD shape-info buffers; {@code null} entries (or a {@code null} array)
     *                   are ignored
     */
    public void processOpCall(Op op, DataBuffer... tadBuffers) {
        processOpCall(op);
        for (PenaltyCause cause : processTADOperands(tadBuffers)) {
            switch (cause) {
                case TAD_NON_EWS_ACCESS:
                    tadNonEwsAggregator.incrementCount();
                    break;
                case TAD_STRIDED_ACCESS:
                    tadStridedAggregator.incrementCount();
                    break;
                default:
                    break;
            }
        }
    }

    /** @return the stack aggregator counting mixed-order operand accesses */
    public StackAggregator getMixedOrderAggregator() {
        return mixedOrderAggregator;
    }

    /** @return the stack aggregator counting scalar accesses */
    public StackAggregator getScalarAggregator() {
        return scalarAggregator;
    }

    /**
     * Counts the transition from the previously recorded op (name and class) to this one, then
     * remembers this op as the previous one.
     *
     * @param opName  name of the current op
     * @param opClass class label of the current op
     */
    protected void updatePairs(String opName, String opClass) {
        classPairsCounter.incrementCount(prevOpClass + " -> " + opClass);
        opPairsCounter.incrementCount(prevOpName + " -> " + opName);
        prevOpName = opName;
        prevOpClass = opClass;
    }

    /**
     * Records the time elapsed since {@code startTime} for the op's class. Ops exceeding
     * {@link #THRESHOLD} are also recorded individually, keyed by class, name and op number.
     *
     * @param op        the op that was executed
     * @param startTime start timestamp, presumably taken with {@code System.nanoTime()}
     */
    public void timeOpCall(Op op, long startTime) {
        long elapsed = System.nanoTime() - startTime;
        String opClass = getOpClass(op);
        classAggregator.putTime(opClass, op, elapsed);
        if (elapsed > THRESHOLD) {
            String keyExt = opClass + " " + op.name() + " (" + op.opNum() + ")";
            longAggregator.putTime(keyExt, elapsed);
        }
    }

    /**
     * Records a BLAS call by name under the {@code "BLAS"} class and breaks the op "matching"
     * chain, so the next op is never treated as consuming this call's output.
     *
     * @param blasOpName name of the BLAS routine
     * @deprecated presumably superseded by {@link #processBlasCall(boolean, INDArray...)} —
     *             confirm with callers
     */
    @Deprecated
    public void processBlasCall(String blasOpName) {
        String key = "BLAS";
        invocationsCount.incrementAndGet();
        opCounter.incrementCount(blasOpName);
        classCounter.incrementCount(key);
        updatePairs(blasOpName, key);
        prevOpMatching = "";
        prevOpMatchingDetailed = "";
        prevOpMatchingInverted = "";
        lastZ = 0L;
    }

    /** No-op placeholder; BLAS timing is not recorded. */
    public void timeBlasCall() {
    }

    /** Prints every collected statistic, mixing SLF4J section headers with stdout output. */
    public void printOutDashboard() {
        logger.info("---Total Op Calls: {}", invocationsCount.get());
        System.out.println();
        logger.info("--- OpClass calls statistics: ---");
        System.out.println(classCounter.asString());
        System.out.println();
        logger.info("--- OpClass pairs statistics: ---");
        System.out.println(classPairsCounter.asString());
        System.out.println();
        logger.info("--- Individual Op calls statistics: ---");
        System.out.println(opCounter.asString());
        System.out.println();
        logger.info("--- Matching Op calls statistics: ---");
        System.out.println(matchingCounter.asString());
        System.out.println();
        logger.info("--- Matching detailed Op calls statistics: ---");
        System.out.println(matchingCounterDetailed.asString());
        System.out.println();
        logger.info("--- Matching inverts Op calls statistics: ---");
        System.out.println(matchingCounterInverted.asString());
        System.out.println();
        logger.info("--- Time for OpClass calls statistics: ---");
        System.out.println(classAggregator.asString());
        System.out.println();
        logger.info("--- Time for long Op calls statistics: ---");
        System.out.println(longAggregator.asString());
        System.out.println();
        logger.info("--- Time spent for Op calls statistics: ---");
        System.out.println(classAggregator.asPercentageString());
        System.out.println();
        logger.info("--- Time spent for long Op calls statistics: ---");
        System.out.println(longAggregator.asPercentageString());
        System.out.println();
        logger.info("--- Time spent within methods: ---");
        methodsAggregator.renderTree(true);
        System.out.println();
        logger.info("--- Bad strides stack tree: ---");
        System.out.println("Unique entries: " + stridedAggregator.getUniqueBranchesNumber());
        stridedAggregator.renderTree();
        System.out.println();
        logger.info("--- non-EWS access stack tree: ---");
        System.out.println("Unique entries: " + nonEwsAggregator.getUniqueBranchesNumber());
        nonEwsAggregator.renderTree();
        System.out.println();
        logger.info("--- Mixed orders access stack tree: ---");
        System.out.println("Unique entries: " + mixedOrderAggregator.getUniqueBranchesNumber());
        mixedOrderAggregator.renderTree();
        System.out.println();
        logger.info("--- TAD bad strides stack tree: ---");
        System.out.println("Unique entries: " + tadStridedAggregator.getUniqueBranchesNumber());
        tadStridedAggregator.renderTree();
        System.out.println();
        logger.info("--- TAD non-EWS access stack tree: ---");
        System.out.println("Unique entries: " + tadNonEwsAggregator.getUniqueBranchesNumber());
        tadNonEwsAggregator.renderTree();
        System.out.println();
        logger.info("--- Scalar access stack tree: ---");
        System.out.println("Unique entries: " + scalarAggregator.getUniqueBranchesNumber());
        scalarAggregator.renderTree(false);
        System.out.println();
        logger.info("--- Blas GEMM orders count: ---");
        System.out.println(blasOrderCounter.asString());
        System.out.println();
        logger.info("--- BLAS access stack trace: ---");
        System.out.println("Unique entries: " + blasAggregator.getUniqueBranchesNumber());
        blasAggregator.renderTree(false);
        System.out.println();
    }

    /** @return the number of op, BLAS and scalar calls recorded since the last reset */
    public long getInvocationsCount() {
        return invocationsCount.get();
    }

    /**
     * Records time spent in the calling method stack. The elapsed {@code System.nanoTime()}
     * delta is divided by 1000 (microseconds for a nanoTime start value).
     *
     * @param op        currently unused
     * @param timeStart start timestamp, presumably from {@code System.nanoTime()}
     */
    public void processStackCall(Op op, long timeStart) {
        long timeSpent = (System.nanoTime() - timeStart) / 1000L;
        methodsAggregator.incrementCount(timeSpent);
    }

    /**
     * Builds an order key such as {@code "C x F x null"} from the operands' orderings
     * (upper-cased; {@code null} operands render as {@code "null"}) and counts it.
     *
     * @param operands operands to describe; a {@code null} array is treated as empty
     * @return the order key that was counted
     */
    public String processOrders(INDArray... operands) {
        StringBuilder key = new StringBuilder();
        if (operands != null) {
            for (int i = 0; i < operands.length; i++) {
                if (i > 0) {
                    key.append(" x ");
                }
                if (operands[i] == null) {
                    key.append("null");
                } else {
                    // Character.toUpperCase is locale-independent, unlike String.toUpperCase().
                    key.append(Character.toUpperCase(operands[i].ordering()));
                }
            }
        }
        String result = key.toString();
        orderCounter.incrementCount(result);
        return result;
    }

    /**
     * Records operand access penalties for a BLAS call.
     *
     * <p>For GEMM the operand order key is counted and every NONE / NON_EWS / STRIDED cause
     * bumps the BLAS stack aggregator; MIXED_ORDER is not counted there, since GEMM operand
     * orders are recorded in the order counter. For other BLAS calls the causes go to the
     * generic non-EWS / strided / mixed-order aggregators.
     *
     * @param isGemm   whether the call is a GEMM
     * @param operands the BLAS operands; {@code null} entries are tolerated
     */
    public void processBlasCall(boolean isGemm, INDArray... operands) {
        if (isGemm) {
            String key = processOrders(operands);
            blasOrderCounter.incrementCount(key);
            for (PenaltyCause cause : processOperands(operands)) {
                switch (cause) {
                    case NON_EWS_ACCESS:
                    case STRIDED_ACCESS:
                    case NONE:
                        blasAggregator.incrementCount();
                        break;
                    default:
                        break;
                }
            }
        } else {
            for (PenaltyCause cause : processOperands(operands)) {
                switch (cause) {
                    case NON_EWS_ACCESS:
                        nonEwsAggregator.incrementCount();
                        break;
                    case STRIDED_ACCESS:
                        stridedAggregator.incrementCount();
                        break;
                    case MIXED_ORDER:
                        mixedOrderAggregator.incrementCount();
                        break;
                    default:
                        break;
                }
            }
        }
    }

    /**
     * Classifies the access penalties of an operand pair.
     *
     * <p>Adds MIXED_ORDER when both operands exist and their orderings differ, NON_EWS_ACCESS
     * when any present operand has an element-wise stride below 1, and STRIDED_ACCESS when any
     * present operand has an element-wise stride above 1. {@code null} operands are skipped
     * instead of throwing.
     *
     * @param x first operand, may be {@code null}
     * @param y second operand, may be {@code null}
     * @return the detected causes, or {@code [NONE]} if there are none
     */
    public PenaltyCause[] processOperands(INDArray x, INDArray y) {
        ArrayList<PenaltyCause> penalties = new ArrayList<PenaltyCause>();
        if (x != null && y != null && x.ordering() != y.ordering()) {
            penalties.add(PenaltyCause.MIXED_ORDER);
        }
        // 1 is neutral: it is neither "< 1" (non-EWS) nor "> 1" (strided).
        long xEws = x == null ? 1L : x.elementWiseStride();
        long yEws = y == null ? 1L : y.elementWiseStride();
        if (xEws < 1 || yEws < 1) {
            penalties.add(PenaltyCause.NON_EWS_ACCESS);
        }
        if (xEws > 1 || yEws > 1) {
            penalties.add(PenaltyCause.STRIDED_ACCESS);
        }
        if (penalties.isEmpty()) {
            penalties.add(PenaltyCause.NONE);
        }
        return penalties.toArray(new PenaltyCause[0]);
    }

    /**
     * Classifies TAD shape-info buffers for TAD-level access penalties. Each cause is reported
     * at most once across all buffers.
     *
     * <p>For each buffer, the int at index 0 is read as the rank and the int at index
     * {@code rank * 2 + 2} as the element-wise stride. NOTE(review): this presumably follows the
     * {@code [rank, shape..., stride..., offset, ews, order]} shape-info layout — confirm.
     * TAD_NON_EWS_ACCESS is reported for ews below 1, rank above 2, or a rank-2 TAD whose two
     * dimensions (ints 1 and 2) are both greater than 1. Otherwise TAD_STRIDED_ACCESS is
     * reported for ews above 1.
     *
     * @param tadBuffers TAD shape-info buffers; {@code null} entries or array are ignored
     * @return the detected causes, or {@code [NONE]} if there are none
     */
    public PenaltyCause[] processTADOperands(DataBuffer... tadBuffers) {
        ArrayList<PenaltyCause> causes = new ArrayList<PenaltyCause>();
        if (tadBuffers != null) {
            for (DataBuffer tadBuffer : tadBuffers) {
                if (tadBuffer == null) {
                    continue;
                }
                int rank = tadBuffer.getInt(0L);
                int length = rank * 2 + 4;
                int ews = tadBuffer.getInt((long) (length - 2));
                boolean nonEws = ews < 1 || rank > 2
                        || (rank == 2 && tadBuffer.getInt(1L) > 1 && tadBuffer.getInt(2L) > 1);
                if (nonEws && !causes.contains(PenaltyCause.TAD_NON_EWS_ACCESS)) {
                    causes.add(PenaltyCause.TAD_NON_EWS_ACCESS);
                } else if (ews > 1 && !causes.contains(PenaltyCause.TAD_STRIDED_ACCESS)) {
                    causes.add(PenaltyCause.TAD_STRIDED_ACCESS);
                }
            }
        }
        if (causes.isEmpty()) {
            causes.add(PenaltyCause.NONE);
        }
        return causes.toArray(new PenaltyCause[0]);
    }

    /**
     * Classifies the access penalties of an op's x / y / z operands.
     *
     * <p>Without y, x is compared to z. When z aliases x or y, only x and y are compared.
     * Otherwise the causes of (x, y) and (x, z) are merged without duplicates, ignoring NONE
     * unless both comparisons report nothing.
     *
     * @param x first input operand
     * @param y second input operand, may be {@code null}
     * @param z output operand
     * @return the detected causes, or {@code [NONE]} if there are none
     */
    public PenaltyCause[] processOperands(INDArray x, INDArray y, INDArray z) {
        if (y == null) {
            return processOperands(x, z);
        }
        if (x == z || y == z) {
            return processOperands(x, y);
        }
        PenaltyCause[] causeXY = processOperands(x, y);
        PenaltyCause[] causeXZ = processOperands(x, z);
        boolean xyNone = causeXY.length == 1 && causeXY[0] == PenaltyCause.NONE;
        boolean xzNone = causeXZ.length == 1 && causeXZ[0] == PenaltyCause.NONE;
        if (xyNone) {
            // Covers "both NONE" too: causeXZ is then [NONE] as well.
            return xzNone ? causeXY : causeXZ;
        }
        if (xzNone) {
            return causeXY;
        }
        return joinDistinct(causeXY, causeXZ);
    }

    /**
     * Concatenates two cause arrays, keeping first-occurrence order and dropping duplicates and
     * {@code null} entries.
     *
     * @param a first causes
     * @param b second causes
     * @return the distinct causes of {@code a} followed by new ones from {@code b}
     */
    protected PenaltyCause[] joinDistinct(PenaltyCause[] a, PenaltyCause[] b) {
        ArrayList<PenaltyCause> causes = new ArrayList<PenaltyCause>();
        for (PenaltyCause cause : a) {
            if (cause != null && !causes.contains(cause)) {
                causes.add(cause);
            }
        }
        for (PenaltyCause cause : b) {
            if (cause != null && !causes.contains(cause)) {
                causes.add(cause);
            }
        }
        return causes.toArray(new PenaltyCause[0]);
    }

    /**
     * Classifies the access penalties of a sequence of operands by comparing each adjacent
     * pair. Pairs with a {@code null} operand are tolerated; all-{@code null} pairs are skipped.
     *
     * @param operands the operands, may be {@code null}
     * @return the distinct non-NONE causes found, or {@code [NONE]} if there are none
     */
    public PenaltyCause[] processOperands(INDArray... operands) {
        if (operands == null) {
            return new PenaltyCause[] {PenaltyCause.NONE};
        }
        ArrayList<PenaltyCause> causes = new ArrayList<PenaltyCause>();
        for (int i = 0; i + 1 < operands.length; i++) {
            INDArray left = operands[i];
            INDArray right = operands[i + 1];
            if (left == null && right == null) {
                continue;
            }
            for (PenaltyCause cause : processOperands(left, right)) {
                if (cause != PenaltyCause.NONE && !causes.contains(cause)) {
                    causes.add(cause);
                }
            }
        }
        if (causes.isEmpty()) {
            causes.add(PenaltyCause.NONE);
        }
        return causes.toArray(new PenaltyCause[0]);
    }

    /** No-op placeholder; memory accesses are not profiled. */
    public void processMemoryAccess() {
    }

    /** Memory-access penalty categories detected by the {@code process*Operands} methods. */
    public enum PenaltyCause {
        /** No penalty detected. */
        NONE,
        /** An operand has an element-wise stride below 1. */
        NON_EWS_ACCESS,
        /** An operand has an element-wise stride above 1. */
        STRIDED_ACCESS,
        /** Two compared operands have different orderings. */
        MIXED_ORDER,
        /** A TAD buffer failed the element-wise-stride / rank checks. */
        TAD_NON_EWS_ACCESS,
        /** A TAD buffer has an element-wise stride above 1. */
        TAD_STRIDED_ACCESS
    }
}

