/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.conf.memory;

import java.text.DecimalFormat;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Map;
import lombok.NonNull;
import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryType;
import org.deeplearning4j.nn.conf.memory.MemoryUseMode;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.shade.jackson.annotation.JsonProperty;

/**
 * Memory report for a whole network: it aggregates the individual
 * {@link MemoryReport}s of the network's layers and vertices.
 *
 * <p>Working memory ({@link MemoryType#WORKING_MEMORY_FIXED} and
 * {@link MemoryType#WORKING_MEMORY_VARIABLE}) is aggregated by taking the
 * maximum over the contained reports; every other memory type is summed.
 */
public class NetworkMemoryReport extends MemoryReport {
    /**
     * Formatter used to render byte counts with grouping separators.
     * {@link DecimalFormat} is not thread-safe, so all access goes through
     * {@link #formatBytes(long)}, which synchronizes on this instance.
     */
    private static final DecimalFormat BYTES_FORMAT = new DecimalFormat("#,###");

    private final Map<String, MemoryReport> layerAndVertexReports;
    private final Class<?> modelClass;
    private final String modelName;
    private final InputType[] networkInputTypes;

    /**
     * Creates a report from the per-layer/vertex reports of a model.
     *
     * @param layerAndVertexReports reports keyed by name; must not be null. The map is stored
     *                              as-is (no copy is made)
     * @param modelClass            class of the model this report describes; must not be null
     * @param modelName             name of the model; may be null
     * @param networkInputTypes     input types of the network; the array itself must not be null
     *                              and is stored as-is (no copy is made)
     * @throws NullPointerException if {@code layerAndVertexReports}, {@code modelClass} or
     *                              {@code networkInputTypes} is null
     */
    public NetworkMemoryReport(@JsonProperty(value="layerAndVertexReports") @NonNull Map<String, MemoryReport> layerAndVertexReports, @JsonProperty(value="modelClass") @NonNull Class<?> modelClass, @JsonProperty(value="modelName") String modelName, @JsonProperty(value="networkInputTypes") @NonNull InputType ... networkInputTypes) {
        if (layerAndVertexReports == null) {
            throw new NullPointerException("layerAndVertexReports");
        }
        if (modelClass == null) {
            throw new NullPointerException("modelClass");
        }
        if (networkInputTypes == null) {
            throw new NullPointerException("networkInputTypes");
        }
        this.layerAndVertexReports = layerAndVertexReports;
        this.modelClass = modelClass;
        this.modelName = modelName;
        this.networkInputTypes = networkInputTypes;
    }

    /**
     * @return the model class this report was created for
     */
    @Override
    public Class<?> getReportClass() {
        return this.modelClass;
    }

    /**
     * @return the model name this report was created with (may be null)
     */
    @Override
    public String getName() {
        return this.modelName;
    }

    /**
     * Computes the total memory for the network.
     *
     * <p>The result is the sum, over all contained reports, of every memory type other than
     * the two working-memory types, plus the largest combined (fixed + variable) working
     * memory reported by any single contained report.
     *
     * @param minibatchSize minibatch size passed through to the contained reports
     * @param memoryUseMode memory use mode; must not be null
     * @param cacheMode     cache mode; must not be null
     * @param dataType      data type; must not be null
     * @return total number of bytes as described above
     * @throws NullPointerException if {@code memoryUseMode}, {@code cacheMode} or
     *                              {@code dataType} is null
     */
    @Override
    public long getTotalMemoryBytes(int minibatchSize, @NonNull MemoryUseMode memoryUseMode, @NonNull CacheMode cacheMode, @NonNull DataBuffer.Type dataType) {
        if (memoryUseMode == null) {
            throw new NullPointerException("memoryUseMode");
        }
        if (cacheMode == null) {
            throw new NullPointerException("cacheMode");
        }
        if (dataType == null) {
            throw new NullPointerException("dataType");
        }
        long totalBytes = 0L;
        // Largest (fixed + variable) working memory seen for a single report so far.
        long maxWorking = 0L;
        for (MemoryReport report : this.layerAndVertexReports.values()) {
            for (MemoryType type : MemoryType.values()) {
                if (isWorkingMemory(type)) {
                    // Working memory is not summed; it is handled separately below.
                    continue;
                }
                totalBytes += report.getMemoryBytes(type, minibatchSize, memoryUseMode, cacheMode, dataType);
            }
            long workFixed = report.getMemoryBytes(MemoryType.WORKING_MEMORY_FIXED, minibatchSize, memoryUseMode, cacheMode, dataType);
            long workVariable = report.getMemoryBytes(MemoryType.WORKING_MEMORY_VARIABLE, minibatchSize, memoryUseMode, cacheMode, dataType);
            long currWorking = workFixed + workVariable;
            if (currWorking > maxWorking) {
                maxWorking = currWorking;
            }
        }
        return totalBytes + maxWorking;
    }

    /**
     * Computes the memory of one type for the network.
     *
     * <p>For the two working-memory types the result is the maximum value reported by any
     * contained report (never less than 0); for all other types it is the sum over all
     * contained reports.
     *
     * @param memoryType    memory type to query
     * @param minibatchSize minibatch size passed through to the contained reports
     * @param memoryUseMode memory use mode passed through to the contained reports
     * @param cacheMode     cache mode passed through to the contained reports
     * @param dataType      data type passed through to the contained reports
     * @return number of bytes for the given memory type
     */
    @Override
    public long getMemoryBytes(MemoryType memoryType, int minibatchSize, MemoryUseMode memoryUseMode, CacheMode cacheMode, DataBuffer.Type dataType) {
        boolean takeMaximum = isWorkingMemory(memoryType);
        long totalBytes = 0L;
        for (MemoryReport report : this.layerAndVertexReports.values()) {
            long bytes = report.getMemoryBytes(memoryType, minibatchSize, memoryUseMode, cacheMode, dataType);
            if (takeMaximum) {
                totalBytes = Math.max(totalBytes, bytes);
            } else {
                totalBytes += bytes;
            }
        }
        return totalBytes;
    }

    /**
     * Builds a multi-line, human readable summary: model class and name, network inputs,
     * number and types of layers, followed by inference and training memory (FP32,
     * {@link CacheMode#NONE}) in total and broken down by memory type.
     *
     * <p>The fixed part of each figure is the value at minibatch size 0; the per-example part
     * is the difference between the values at minibatch size 1 and 0.
     */
    @Override
    public String toString() {
        long fixedMemBytes = this.getTotalMemoryBytes(0, MemoryUseMode.INFERENCE, CacheMode.NONE, DataBuffer.Type.FLOAT);
        long perEx = this.getTotalMemoryBytes(1, MemoryUseMode.INFERENCE, CacheMode.NONE, DataBuffer.Type.FLOAT) - fixedMemBytes;
        long fixedMemBytesTrain = this.getTotalMemoryBytes(0, MemoryUseMode.TRAINING, CacheMode.NONE, DataBuffer.Type.FLOAT);
        long perExTrain = this.getTotalMemoryBytes(1, MemoryUseMode.TRAINING, CacheMode.NONE, DataBuffer.Type.FLOAT) - fixedMemBytesTrain;

        // Count reports per report class; LinkedHashMap keeps first-seen order for the output.
        Map<Class<?>, Integer> layerCounts = new LinkedHashMap<Class<?>, Integer>();
        for (MemoryReport report : this.layerAndVertexReports.values()) {
            Class<?> reportClass = report.getReportClass();
            Integer count = layerCounts.get(reportClass);
            layerCounts.put(reportClass, count == null ? 1 : count + 1);
        }
        StringBuilder sbLayerCounts = new StringBuilder();
        for (Map.Entry<Class<?>, Integer> entry : layerCounts.entrySet()) {
            sbLayerCounts.append(entry.getValue()).append(" x ").append(entry.getKey().getSimpleName()).append(", ");
        }

        StringBuilder sb = new StringBuilder();
        sb.append("----- Network Memory Report -----\n")
                .append("  Model Class:                        ").append(this.modelClass.getName()).append("\n")
                .append("  Model Name:                         ").append(this.modelName).append("\n")
                .append("  Network Input:                      ").append(Arrays.toString(this.networkInputTypes)).append("\n")
                .append("  # Layers:                           ").append(this.layerAndVertexReports.size()).append("\n")
                .append("  Layer Types:                        ").append(sbLayerCounts.toString()).append("\n");
        this.appendFixedPlusVariable(sb, "  Inference Memory (FP32)             ", fixedMemBytes, perEx);
        this.appendFixedPlusVariable(sb, "  Training Memory (FP32):             ", fixedMemBytesTrain, perExTrain);
        sb.append("  Inference Memory Breakdown (FP32):\n");
        this.appendBreakDown(sb, MemoryUseMode.INFERENCE, CacheMode.NONE, DataBuffer.Type.FLOAT);
        sb.append("  Training Memory Breakdown (CacheMode = ").append(CacheMode.NONE).append(", FP32):\n");
        this.appendBreakDown(sb, MemoryUseMode.TRAINING, CacheMode.NONE, DataBuffer.Type.FLOAT);
        return sb.toString();
    }

    /**
     * Appends one line per memory type that has a positive fixed or per-example size.
     * In inference mode, memory types for which {@code isInference()} is false are skipped.
     */
    private void appendBreakDown(StringBuilder sb, MemoryUseMode useMode, CacheMode cacheMode, DataBuffer.Type dataType) {
        for (MemoryType type : MemoryType.values()) {
            if (useMode == MemoryUseMode.INFERENCE && !type.isInference()) {
                continue;
            }
            long bytesFixed = this.getMemoryBytes(type, 0, useMode, cacheMode, dataType);
            long bytesPerEx = this.getMemoryBytes(type, 1, useMode, cacheMode, dataType) - bytesFixed;
            if (bytesFixed <= 0L && bytesPerEx <= 0L) {
                // Nothing to report for this memory type.
                continue;
            }
            String title = String.format("  - %-34s", type);
            this.appendFixedPlusVariable(sb, title, bytesFixed, bytesPerEx);
        }
    }

    /**
     * Appends {@code title} followed by "X bytes", "nExamples * Y bytes" or
     * "X bytes + nExamples * Y bytes" (only the positive parts are written) and a newline.
     */
    private void appendFixedPlusVariable(StringBuilder sb, String title, long bytesFixed, long bytesPerEx) {
        sb.append(title);
        if (bytesFixed > 0L) {
            sb.append(this.formatBytes(bytesFixed)).append(" bytes");
        }
        if (bytesPerEx > 0L) {
            if (bytesFixed > 0L) {
                sb.append(" + ");
            }
            sb.append("nExamples * ").append(this.formatBytes(bytesPerEx)).append(" bytes");
        }
        sb.append("\n");
    }

    /**
     * Formats a byte count with grouping separators. Synchronizes on the shared formatter
     * because {@link DecimalFormat} instances must not be used concurrently.
     */
    private String formatBytes(long bytes) {
        synchronized (BYTES_FORMAT) {
            return BYTES_FORMAT.format(bytes);
        }
    }

    /** Returns true for the two working-memory types, which are aggregated by maximum. */
    private static boolean isWorkingMemory(MemoryType type) {
        return type == MemoryType.WORKING_MEMORY_FIXED || type == MemoryType.WORKING_MEMORY_VARIABLE;
    }

    /** Null-safe equality: two nulls are equal, otherwise delegates to {@code a.equals(b)}. */
    private static boolean nullSafeEquals(Object a, Object b) {
        return a == null ? b == null : a.equals(b);
    }

    /** Hash contribution of one field; 43 stands in for null. */
    private static int nullSafeHash(Object value) {
        return value == null ? 43 : value.hashCode();
    }

    /**
     * @return the map of per-layer/vertex reports this report was created with (not a copy)
     */
    public Map<String, MemoryReport> getLayerAndVertexReports() {
        return this.layerAndVertexReports;
    }

    /**
     * @return the model class this report was created for
     */
    public Class<?> getModelClass() {
        return this.modelClass;
    }

    /**
     * @return the model name (may be null)
     */
    public String getModelName() {
        return this.modelName;
    }

    /**
     * @return the network input types array this report was created with (not a copy)
     */
    public InputType[] getNetworkInputTypes() {
        return this.networkInputTypes;
    }

    /**
     * Two reports are equal when the superclass considers them equal and their layer/vertex
     * reports, model class, model name and network input types are all equal.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof NetworkMemoryReport)) {
            return false;
        }
        NetworkMemoryReport other = (NetworkMemoryReport)o;
        // canEqual lets a subclass veto equality with instances of this class.
        if (!other.canEqual(this)) {
            return false;
        }
        if (!super.equals(o)) {
            return false;
        }
        if (!nullSafeEquals(this.getLayerAndVertexReports(), other.getLayerAndVertexReports())) {
            return false;
        }
        if (!nullSafeEquals(this.getModelClass(), other.getModelClass())) {
            return false;
        }
        if (!nullSafeEquals(this.getModelName(), other.getModelName())) {
            return false;
        }
        return Arrays.deepEquals(this.getNetworkInputTypes(), other.getNetworkInputTypes());
    }

    /**
     * @return true if {@code other} is a {@code NetworkMemoryReport}; used by {@link #equals(Object)}
     */
    @Override
    protected boolean canEqual(Object other) {
        return other instanceof NetworkMemoryReport;
    }

    /**
     * Hash code over the same fields used by {@link #equals(Object)}, seeded with the
     * superclass hash code.
     */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = super.hashCode();
        result = result * prime + nullSafeHash(this.getLayerAndVertexReports());
        result = result * prime + nullSafeHash(this.getModelClass());
        result = result * prime + nullSafeHash(this.getModelName());
        result = result * prime + Arrays.deepHashCode(this.getNetworkInputTypes());
        return result;
    }
}

