/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.ui.stats;

import java.io.InputStream;
import java.io.Serializable;
import java.lang.management.GarbageCollectorMXBean;
import java.lang.management.ManagementFactory;
import java.lang.management.OperatingSystemMXBean;
import java.lang.management.RuntimeMXBean;
import java.lang.reflect.Constructor;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.UUID;
import org.apache.commons.io.IOUtils;
import org.bytedeco.javacpp.Pointer;
import org.deeplearning4j.api.storage.Persistable;
import org.deeplearning4j.api.storage.StatsStorageRouter;
import org.deeplearning4j.api.storage.StorageMetaData;
import org.deeplearning4j.api.storage.listener.RoutingIterationListener;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.ui.stats.api.Histogram;
import org.deeplearning4j.ui.stats.api.StatsInitializationConfiguration;
import org.deeplearning4j.ui.stats.api.StatsInitializationReport;
import org.deeplearning4j.ui.stats.api.StatsReport;
import org.deeplearning4j.ui.stats.api.StatsType;
import org.deeplearning4j.ui.stats.api.StatsUpdateConfiguration;
import org.deeplearning4j.ui.stats.impl.DefaultStatsInitializationConfiguration;
import org.deeplearning4j.ui.stats.impl.DefaultStatsUpdateConfiguration;
import org.deeplearning4j.util.UIDProvider;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public abstract class BaseStatsListener
implements RoutingIterationListener {
    private static final Logger log = LoggerFactory.getLogger(BaseStatsListener.class);
    public static final String TYPE_ID = "StatsListener";
    private StatsStorageRouter router;
    private final StatsInitializationConfiguration initConfig;
    private StatsUpdateConfiguration updateConfig;
    private String sessionID;
    private String workerID;
    private transient List<GarbageCollectorMXBean> gcBeans;
    private Map<String, Pair<Long, Long>> gcStatsAtLastReport;
    private List<ModelInfo> modelInfos = new ArrayList<ModelInfo>();
    private Map<String, Histogram> activationHistograms;
    private Map<String, Double> meanActivations;
    private Map<String, Double> stdevActivations;
    private Map<String, Double> meanMagActivations;
    private Map<String, Histogram> gradientHistograms;
    private Map<String, Double> meanGradients;
    private Map<String, Double> stdevGradient;
    private Map<String, Double> meanMagGradients;
    private Map<Integer, Pointer> devPointers = new HashMap<Integer, Pointer>();

    private ModelInfo getModelInfo(Model model) {
        ModelInfo mi = null;
        for (ModelInfo m : this.modelInfos) {
            if (m.model != model) continue;
            mi = m;
            break;
        }
        if (mi == null) {
            mi = new ModelInfo(model);
            this.modelInfos.add(mi);
        }
        return mi;
    }

    public BaseStatsListener(StatsStorageRouter router) {
        this(router, null, null, null, null);
    }

    public BaseStatsListener(StatsStorageRouter router, int listenerFrequency) {
        this(router, null, new DefaultStatsUpdateConfiguration.Builder().reportingFrequency(listenerFrequency).build(), null, null);
    }

    public BaseStatsListener(StatsStorageRouter router, StatsInitializationConfiguration initConfig, StatsUpdateConfiguration updateConfig, String sessionID, String workerID) {
        this.router = router;
        this.initConfig = initConfig == null ? new DefaultStatsInitializationConfiguration(true, true, true) : initConfig;
        this.updateConfig = updateConfig == null ? new DefaultStatsUpdateConfiguration.Builder().build() : updateConfig;
        this.sessionID = sessionID == null ? UUID.randomUUID().toString() : sessionID;
        this.workerID = workerID == null ? UIDProvider.getJVMUID() + "_" + Thread.currentThread().getId() : workerID;
    }

    public abstract StatsInitializationReport getNewInitializationReport();

    public abstract StatsReport getNewStatsReport();

    public abstract StorageMetaData getNewStorageMetaData(long var1, String var3, String var4);

    public StatsInitializationConfiguration getInitConfig() {
        return this.initConfig;
    }

    public StatsUpdateConfiguration getUpdateConfig() {
        return this.updateConfig;
    }

    public void setUpdateConfig(StatsUpdateConfiguration newConfig) {
        this.updateConfig = newConfig;
    }

    public void setStorageRouter(StatsStorageRouter router) {
        this.router = router;
    }

    public StatsStorageRouter getStorageRouter() {
        return this.router;
    }

    public void setWorkerID(String workerID) {
        this.workerID = workerID;
    }

    public String getWorkerID() {
        return this.workerID;
    }

    public void setSessionID(String sessionID) {
        this.sessionID = sessionID;
    }

    public String getSessionID() {
        return this.sessionID;
    }

    private String getSessionID(Model model) {
        if (model instanceof MultiLayerNetwork || model instanceof ComputationGraph) {
            return this.sessionID;
        }
        if (model instanceof Layer) {
            Layer l = (Layer)model;
            int layerIdx = l.getIndex();
            return this.sessionID + "_layer" + layerIdx;
        }
        return this.sessionID;
    }

    public void onEpochStart(Model model) {
    }

    public void onEpochEnd(Model model) {
    }

    public void onForwardPass(Model model, List<INDArray> activations) {
        int iterCount = this.getModelInfo(model).iterCount;
        if (this.calcFromActivations() && (iterCount == 0 || iterCount % this.updateConfig.reportingFrequency() == 0)) {
            HashMap<String, INDArray> activationsMap = new HashMap<String, INDArray>();
            int count = 0;
            for (INDArray arr : activations) {
                String layerName = count == 0 ? "input" : String.valueOf(count - 1);
                activationsMap.put(layerName, arr);
                ++count;
            }
            this.onForwardPass(model, activationsMap);
        }
    }

    public void onForwardPass(Model model, Map<String, INDArray> activations) {
        int iterCount = this.getModelInfo(model).iterCount;
        if (this.calcFromActivations() && this.updateConfig.reportingFrequency() > 0 && (iterCount == 0 || iterCount % this.updateConfig.reportingFrequency() == 0)) {
            if (this.updateConfig.collectHistograms(StatsType.Activations)) {
                this.activationHistograms = BaseStatsListener.getHistograms(activations, this.updateConfig.numHistogramBins(StatsType.Activations));
            }
            if (this.updateConfig.collectMean(StatsType.Activations)) {
                this.meanActivations = BaseStatsListener.calculateSummaryStats(activations, StatType.Mean);
            }
            if (this.updateConfig.collectStdev(StatsType.Activations)) {
                this.stdevActivations = BaseStatsListener.calculateSummaryStats(activations, StatType.Stdev);
            }
            if (this.updateConfig.collectMeanMagnitudes(StatsType.Activations)) {
                this.meanMagActivations = BaseStatsListener.calculateSummaryStats(activations, StatType.MeanMagnitude);
            }
        }
    }

    public void onGradientCalculation(Model model) {
        int iterCount = this.getModelInfo(model).iterCount;
        if (this.calcFromGradients() && this.updateConfig.reportingFrequency() > 0 && (iterCount == 0 || iterCount % this.updateConfig.reportingFrequency() == 0)) {
            Gradient g = model.gradient();
            if (this.updateConfig.collectHistograms(StatsType.Gradients)) {
                this.gradientHistograms = BaseStatsListener.getHistograms(g.gradientForVariable(), this.updateConfig.numHistogramBins(StatsType.Gradients));
            }
            if (this.updateConfig.collectMean(StatsType.Gradients)) {
                this.meanGradients = BaseStatsListener.calculateSummaryStats(g.gradientForVariable(), StatType.Mean);
            }
            if (this.updateConfig.collectStdev(StatsType.Gradients)) {
                this.stdevGradient = BaseStatsListener.calculateSummaryStats(g.gradientForVariable(), StatType.Stdev);
            }
            if (this.updateConfig.collectMeanMagnitudes(StatsType.Gradients)) {
                this.meanMagGradients = BaseStatsListener.calculateSummaryStats(g.gradientForVariable(), StatType.MeanMagnitude);
            }
        }
    }

    private boolean calcFromActivations() {
        return this.updateConfig.collectMean(StatsType.Activations) || this.updateConfig.collectStdev(StatsType.Activations) || this.updateConfig.collectMeanMagnitudes(StatsType.Activations) || this.updateConfig.collectHistograms(StatsType.Activations);
    }

    private boolean calcFromGradients() {
        return this.updateConfig.collectMean(StatsType.Gradients) || this.updateConfig.collectStdev(StatsType.Gradients) || this.updateConfig.collectMeanMagnitudes(StatsType.Gradients) || this.updateConfig.collectHistograms(StatsType.Gradients);
    }

    public void onBackwardPass(Model model) {
    }

    public void iterationDone(Model model, int iteration, int epoch) {
        ModelInfo modelInfo = this.getModelInfo(model);
        boolean backpropParamsOnly = this.backpropParamsOnly(model);
        long currentTime = this.getTime();
        if (modelInfo.iterCount == 0) {
            modelInfo.initTime = currentTime;
            this.doInit(model);
        }
        if (this.updateConfig.collectPerformanceStats()) {
            this.updateExamplesMinibatchesCounts(model);
        }
        if (this.updateConfig.reportingFrequency() > 1 && (iteration == 0 || iteration % this.updateConfig.reportingFrequency() != 0)) {
            modelInfo.iterCount = iteration;
            return;
        }
        StatsReport report = this.getNewStatsReport();
        report.reportIDs(this.getSessionID(model), TYPE_ID, this.workerID, System.currentTimeMillis());
        if (this.updateConfig.collectPerformanceStats()) {
            double minibatchesPerSecond;
            double examplesPerSecond;
            if (modelInfo.iterCount == 0) {
                examplesPerSecond = 0.0;
                minibatchesPerSecond = 0.0;
            } else {
                long deltaTimeMS = currentTime - modelInfo.lastReportTime;
                examplesPerSecond = 1000.0 * (double)modelInfo.examplesSinceLastReport / (double)deltaTimeMS;
                minibatchesPerSecond = 1000.0 * (double)modelInfo.minibatchesSinceLastReport / (double)deltaTimeMS;
            }
            long totalRuntimeMS = currentTime - modelInfo.initTime;
            report.reportPerformance(totalRuntimeMS, modelInfo.totalExamples, modelInfo.totalMinibatches, examplesPerSecond, minibatchesPerSecond);
            modelInfo.examplesSinceLastReport = 0;
            modelInfo.minibatchesSinceLastReport = 0;
        }
        if (this.updateConfig.collectMemoryStats()) {
            Runtime runtime = Runtime.getRuntime();
            long jvmTotal = runtime.totalMemory();
            long jvmMax = runtime.maxMemory();
            long offheapTotal = Pointer.totalBytes();
            long offheapMax = Pointer.maxBytes();
            long[] gpuCurrentBytes = null;
            long[] gpuMaxBytes = null;
            NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
            int nDevices = nativeOps.getAvailableDevices();
            if (nDevices > 0) {
                gpuCurrentBytes = new long[nDevices];
                gpuMaxBytes = new long[nDevices];
                for (int i = 0; i < nDevices; ++i) {
                    try {
                        Pointer p = this.getDevicePointer(i);
                        if (p == null) {
                            gpuMaxBytes[i] = 0L;
                            gpuCurrentBytes[i] = 0L;
                            continue;
                        }
                        gpuMaxBytes[i] = nativeOps.getDeviceTotalMemory(p);
                        gpuCurrentBytes[i] = gpuMaxBytes[i] - nativeOps.getDeviceFreeMemory(p);
                        continue;
                    }
                    catch (Exception e) {
                        e.printStackTrace();
                    }
                }
            }
            report.reportMemoryUse(jvmTotal, jvmMax, offheapTotal, offheapMax, gpuCurrentBytes, gpuMaxBytes);
        }
        if (this.updateConfig.collectGarbageCollectionStats()) {
            long timeMs;
            if (modelInfo.lastReportIteration == -1 || this.gcBeans == null) {
                this.gcBeans = ManagementFactory.getGarbageCollectorMXBeans();
                this.gcStatsAtLastReport = new HashMap<String, Pair<Long, Long>>();
                for (GarbageCollectorMXBean bean : this.gcBeans) {
                    long count = bean.getCollectionCount();
                    timeMs = bean.getCollectionTime();
                    this.gcStatsAtLastReport.put(bean.getName(), (Pair<Long, Long>)new Pair((Object)count, (Object)timeMs));
                }
            } else {
                for (GarbageCollectorMXBean bean : this.gcBeans) {
                    long count = bean.getCollectionCount();
                    timeMs = bean.getCollectionTime();
                    Pair<Long, Long> lastStats = this.gcStatsAtLastReport.get(bean.getName());
                    long deltaGCCount = count - (Long)lastStats.getFirst();
                    long deltaGCTime = timeMs - (Long)lastStats.getSecond();
                    lastStats.setFirst((Object)count);
                    lastStats.setSecond((Object)timeMs);
                    report.reportGarbageCollection(bean.getName(), (int)deltaGCCount, (int)deltaGCTime);
                }
            }
        }
        report.reportScore(model.score());
        if (this.updateConfig.collectLearningRates()) {
            HashMap<String, Double> lrs = new HashMap<String, Double>();
            if (model instanceof MultiLayerNetwork) {
                int layerIdx = 0;
                for (Layer l : ((MultiLayerNetwork)model).getLayers()) {
                    NeuralNetConfiguration conf = l.conf();
                    List paramkeys = l.conf().getLayer().initializer().paramKeys(l.conf().getLayer());
                    for (String s : paramkeys) {
                        double lr = conf.getLayer().getUpdaterByParam(s).getLearningRate(l.getIterationCount(), l.getEpochCount());
                        if (Double.isNaN(lr)) {
                            lr = 0.0;
                        }
                        lrs.put(layerIdx + "_" + s, lr);
                    }
                    ++layerIdx;
                }
            } else if (model instanceof ComputationGraph) {
                for (Layer l : ((ComputationGraph)model).getLayers()) {
                    NeuralNetConfiguration conf = l.conf();
                    String layerName = conf.getLayer().getLayerName();
                    List paramkeys = l.conf().getLayer().initializer().paramKeys(l.conf().getLayer());
                    for (String s : paramkeys) {
                        double lr = conf.getLayer().getUpdaterByParam(s).getLearningRate(l.getIterationCount(), l.getEpochCount());
                        if (Double.isNaN(lr)) {
                            lr = 0.0;
                        }
                        lrs.put(layerName + "_" + s, lr);
                    }
                }
            } else if (model instanceof Layer) {
                Layer l = (Layer)model;
                List paramkeys = l.conf().getLayer().initializer().paramKeys(l.conf().getLayer());
                for (String s : paramkeys) {
                    double lr = l.conf().getLayer().getUpdaterByParam(s).getLearningRate(l.getIterationCount(), l.getEpochCount());
                    lrs.put(s, lr);
                }
            }
            report.reportLearningRates(lrs);
        }
        if (this.updateConfig.collectHistograms(StatsType.Parameters)) {
            Map<String, Histogram> paramHistograms = BaseStatsListener.getHistograms(model.paramTable(backpropParamsOnly), this.updateConfig.numHistogramBins(StatsType.Parameters));
            report.reportHistograms(StatsType.Parameters, paramHistograms);
        }
        if (this.updateConfig.collectHistograms(StatsType.Gradients)) {
            report.reportHistograms(StatsType.Gradients, this.gradientHistograms);
        }
        if (this.updateConfig.collectHistograms(StatsType.Updates)) {
            Map<String, Histogram> updateHistograms = BaseStatsListener.getHistograms(model.gradient().gradientForVariable(), this.updateConfig.numHistogramBins(StatsType.Updates));
            report.reportHistograms(StatsType.Updates, updateHistograms);
        }
        if (this.updateConfig.collectHistograms(StatsType.Activations)) {
            report.reportHistograms(StatsType.Activations, this.activationHistograms);
        }
        if (this.updateConfig.collectMean(StatsType.Parameters)) {
            Map<String, Double> meanParams = BaseStatsListener.calculateSummaryStats(model.paramTable(backpropParamsOnly), StatType.Mean);
            report.reportMean(StatsType.Parameters, meanParams);
        }
        if (this.updateConfig.collectMean(StatsType.Gradients)) {
            report.reportMean(StatsType.Gradients, this.meanGradients);
        }
        if (this.updateConfig.collectMean(StatsType.Updates)) {
            Map<String, Double> meanUpdates = BaseStatsListener.calculateSummaryStats(model.gradient().gradientForVariable(), StatType.Mean);
            report.reportMean(StatsType.Updates, meanUpdates);
        }
        if (this.updateConfig.collectMean(StatsType.Activations)) {
            report.reportMean(StatsType.Activations, this.meanActivations);
        }
        if (this.updateConfig.collectStdev(StatsType.Parameters)) {
            Map<String, Double> stdevParams = BaseStatsListener.calculateSummaryStats(model.paramTable(backpropParamsOnly), StatType.Stdev);
            report.reportStdev(StatsType.Parameters, stdevParams);
        }
        if (this.updateConfig.collectStdev(StatsType.Gradients)) {
            report.reportStdev(StatsType.Gradients, this.stdevGradient);
        }
        if (this.updateConfig.collectStdev(StatsType.Updates)) {
            Map<String, Double> stdevUpdates = BaseStatsListener.calculateSummaryStats(model.gradient().gradientForVariable(), StatType.Stdev);
            report.reportStdev(StatsType.Updates, stdevUpdates);
        }
        if (this.updateConfig.collectStdev(StatsType.Activations)) {
            report.reportStdev(StatsType.Activations, this.stdevActivations);
        }
        if (this.updateConfig.collectMeanMagnitudes(StatsType.Parameters)) {
            Map<String, Double> meanMagParams = BaseStatsListener.calculateSummaryStats(model.paramTable(backpropParamsOnly), StatType.MeanMagnitude);
            report.reportMeanMagnitudes(StatsType.Parameters, meanMagParams);
        }
        if (this.updateConfig.collectMeanMagnitudes(StatsType.Gradients)) {
            report.reportMeanMagnitudes(StatsType.Gradients, this.meanMagGradients);
        }
        if (this.updateConfig.collectMeanMagnitudes(StatsType.Updates)) {
            Map<String, Double> meanMagUpdates = BaseStatsListener.calculateSummaryStats(model.gradient().gradientForVariable(), StatType.MeanMagnitude);
            report.reportMeanMagnitudes(StatsType.Updates, meanMagUpdates);
        }
        if (this.updateConfig.collectMeanMagnitudes(StatsType.Activations)) {
            report.reportMeanMagnitudes(StatsType.Activations, this.meanMagActivations);
        }
        long endTime = this.getTime();
        report.reportStatsCollectionDurationMS((int)(endTime - currentTime));
        modelInfo.lastReportTime = currentTime;
        modelInfo.lastReportIteration = iteration;
        report.reportIterationCount(iteration);
        this.router.putUpdate((Persistable)report);
        modelInfo.iterCount = iteration;
        this.activationHistograms = null;
        this.meanActivations = null;
        this.stdevActivations = null;
        this.meanMagActivations = null;
        this.gradientHistograms = null;
        this.meanGradients = null;
        this.stdevGradient = null;
        this.meanMagGradients = null;
    }

    private long getTime() {
        return System.currentTimeMillis();
    }

    private void doInit(Model model) {
        boolean backpropParamsOnly = this.backpropParamsOnly(model);
        long initTime = System.currentTimeMillis();
        StatsInitializationReport initReport = this.getNewInitializationReport();
        initReport.reportIDs(this.getSessionID(model), TYPE_ID, this.workerID, initTime);
        if (this.initConfig.collectSoftwareInfo()) {
            OperatingSystemMXBean osBean = ManagementFactory.getOperatingSystemMXBean();
            RuntimeMXBean runtime = ManagementFactory.getRuntimeMXBean();
            String arch = osBean.getArch();
            String osName = osBean.getName();
            String jvmName = runtime.getVmName();
            String jvmVersion = System.getProperty("java.version");
            String jvmSpecVersion = runtime.getSpecVersion();
            String nd4jBackendClass = Nd4j.getNDArrayFactory().getClass().getName();
            String nd4jDataTypeName = DataTypeUtil.getDtypeFromContext().name();
            String hostname = System.getenv("COMPUTERNAME");
            if (hostname == null || hostname.isEmpty()) {
                try {
                    Process proc = Runtime.getRuntime().exec("hostname");
                    InputStream stream = proc.getInputStream();
                    Object object = null;
                    try {
                        hostname = IOUtils.toString((InputStream)stream);
                    }
                    catch (Throwable throwable) {
                        object = throwable;
                        throw throwable;
                    }
                    finally {
                        if (stream != null) {
                            if (object != null) {
                                try {
                                    stream.close();
                                }
                                catch (Throwable throwable) {
                                    ((Throwable)object).addSuppressed(throwable);
                                }
                            } else {
                                stream.close();
                            }
                        }
                    }
                }
                catch (Exception proc) {
                    // empty catch block
                }
            }
            Properties p = Nd4j.getExecutioner().getEnvironmentInformation();
            HashMap<String, String> envInfo = new HashMap<String, String>();
            for (Map.Entry entry : p.entrySet()) {
                Object v = entry.getValue();
                String value = v == null ? "" : v.toString();
                envInfo.put(entry.getKey().toString(), value);
            }
            initReport.reportSoftwareInfo(arch, osName, jvmName, jvmVersion, jvmSpecVersion, nd4jBackendClass, nd4jDataTypeName, hostname, UIDProvider.getJVMUID(), envInfo);
        }
        if (this.initConfig.collectHardwareInfo()) {
            int availableProcessors = Runtime.getRuntime().availableProcessors();
            NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
            int nDevices = nativeOps.getAvailableDevices();
            long[] deviceTotalMem = null;
            String[] deviceDescription = null;
            if (nDevices > 0) {
                deviceTotalMem = new long[nDevices];
                deviceDescription = new String[nDevices];
                for (int i = 0; i < nDevices; ++i) {
                    try {
                        Pointer p = this.getDevicePointer(i);
                        if (p == null) {
                            deviceTotalMem[i] = 0L;
                            deviceDescription[i] = "Device(" + i + ")";
                            continue;
                        }
                        deviceTotalMem[i] = nativeOps.getDeviceTotalMemory(p);
                        deviceDescription[i] = nativeOps.getDeviceName(p);
                        if (nDevices <= 1) continue;
                        deviceDescription[i] = deviceDescription[i] + " (" + i + ")";
                        continue;
                    }
                    catch (Exception e) {
                        log.debug("Error getting device info", (Throwable)e);
                    }
                }
            }
            long jvmMaxMemory = Runtime.getRuntime().maxMemory();
            long offheapMaxMemory = Pointer.maxBytes();
            initReport.reportHardwareInfo(availableProcessors, nDevices, jvmMaxMemory, offheapMaxMemory, deviceTotalMem, deviceDescription, UIDProvider.getHardwareUID());
        }
        if (this.initConfig.collectModelInfo()) {
            int numParams;
            int numLayers;
            String jsonConf;
            if (model instanceof MultiLayerNetwork) {
                MultiLayerNetwork net = (MultiLayerNetwork)model;
                jsonConf = net.getLayerWiseConfigurations().toJson();
                numLayers = net.getnLayers();
                numParams = net.numParams();
            } else if (model instanceof ComputationGraph) {
                ComputationGraph cg = (ComputationGraph)model;
                jsonConf = cg.getConfiguration().toJson();
                numLayers = cg.getNumLayers();
                numParams = cg.numParams();
            } else if (model instanceof Layer) {
                Layer l = (Layer)model;
                jsonConf = l.conf().toJson();
                numLayers = 1;
                numParams = l.numParams();
            } else {
                throw new RuntimeException("Invalid model: Expected MultiLayerNetwork or ComputationGraph. Got: " + (model == null ? null : model.getClass()));
            }
            Map paramMap = model.paramTable(backpropParamsOnly);
            String[] paramNames = new String[paramMap.size()];
            int i = 0;
            for (String s : paramMap.keySet()) {
                paramNames[i++] = s;
            }
            initReport.reportModelInfo(model.getClass().getName(), jsonConf, paramNames, numLayers, numParams);
        }
        StorageMetaData meta = this.getNewStorageMetaData(initTime, this.getSessionID(model), this.workerID);
        this.router.putStorageMetaData(meta);
        this.router.putStaticInfo((Persistable)initReport);
    }

    private synchronized Pointer getDevicePointer(int device) {
        if (this.devPointers.containsKey(device)) {
            return this.devPointers.get(device);
        }
        try {
            Class<?> c = Class.forName("org.nd4j.jita.allocator.pointers.CudaPointer");
            Constructor<?> constructor = c.getConstructor(Long.TYPE);
            Pointer p = (Pointer)constructor.newInstance(device);
            this.devPointers.put(device, p);
            return p;
        }
        catch (Throwable t) {
            this.devPointers.put(device, null);
            return null;
        }
    }

    private void updateExamplesMinibatchesCounts(Model model) {
        ModelInfo modelInfo = this.getModelInfo(model);
        int examplesThisMinibatch = 0;
        if (model instanceof MultiLayerNetwork) {
            examplesThisMinibatch = ((MultiLayerNetwork)model).batchSize();
        } else if (model instanceof ComputationGraph) {
            examplesThisMinibatch = ((ComputationGraph)model).batchSize();
        } else if (model instanceof Layer) {
            examplesThisMinibatch = ((Layer)model).getInputMiniBatchSize();
        }
        ModelInfo modelInfo2 = modelInfo;
        modelInfo2.examplesSinceLastReport = modelInfo2.examplesSinceLastReport + examplesThisMinibatch;
        modelInfo2 = modelInfo;
        modelInfo2.totalExamples = modelInfo2.totalExamples + (long)examplesThisMinibatch;
        modelInfo.minibatchesSinceLastReport++;
        modelInfo.totalMinibatches++;
    }

    private boolean backpropParamsOnly(Model model) {
        return model instanceof MultiLayerNetwork || model instanceof ComputationGraph;
    }

    private static Map<String, Double> calculateSummaryStats(Map<String, INDArray> source, StatType statType) {
        LinkedHashMap<String, Double> out = new LinkedHashMap<String, Double>();
        if (source == null) {
            return out;
        }
        for (Map.Entry<String, INDArray> entry : source.entrySet()) {
            double value;
            String name = entry.getKey();
            switch (statType) {
                case Mean: {
                    value = entry.getValue().meanNumber().doubleValue();
                    break;
                }
                case Stdev: {
                    value = entry.getValue().stdNumber().doubleValue();
                    break;
                }
                case MeanMagnitude: {
                    value = entry.getValue().norm1Number().doubleValue() / (double)entry.getValue().length();
                    break;
                }
                default: {
                    throw new RuntimeException();
                }
            }
            out.put(name, value);
        }
        return out;
    }

    private static Map<String, Histogram> getHistograms(Map<String, INDArray> map, int nBins) {
        LinkedHashMap<String, Histogram> out = new LinkedHashMap<String, Histogram>();
        if (map == null) {
            return out;
        }
        for (Map.Entry<String, INDArray> entry : map.entrySet()) {
            org.nd4j.linalg.api.ops.impl.transforms.Histogram hOp = new org.nd4j.linalg.api.ops.impl.transforms.Histogram(entry.getValue(), nBins);
            Nd4j.getExecutioner().exec((Op)hOp);
            INDArray bins = hOp.z();
            int[] count = new int[nBins];
            int i = 0;
            while ((long)i < bins.length()) {
                count[i] = (int)bins.getDouble((long)i);
                ++i;
            }
            double min = entry.getValue().minNumber().doubleValue();
            double max = entry.getValue().maxNumber().doubleValue();
            Histogram h = new Histogram(min, max, nBins, count);
            out.put(entry.getKey(), h);
        }
        return out;
    }

    public abstract BaseStatsListener clone();

    private static class ModelInfo
    implements Serializable {
        private final Model model;
        private long initTime;
        private long lastReportTime = -1L;
        private int lastReportIteration = -1;
        private int examplesSinceLastReport = 0;
        private int minibatchesSinceLastReport = 0;
        private long totalExamples = 0L;
        private long totalMinibatches = 0L;
        private int iterCount = 0;

        private ModelInfo(Model model) {
            this.model = model;
        }
    }

    private static enum StatType {
        Mean,
        Stdev,
        MeanMagnitude;

    }
}

