/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.ui.weights;

import java.awt.Color;
import java.awt.Graphics2D;
import java.awt.Image;
import java.awt.image.BufferedImage;
import java.awt.image.RenderedImage;
import java.awt.image.WritableRaster;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.UUID;
import javax.imageio.ImageIO;
import lombok.NonNull;
import org.datavec.image.loader.ImageLoader;
import org.deeplearning4j.api.storage.Persistable;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.api.storage.StatsStorageRouter;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.api.BaseTrainingListener;
import org.deeplearning4j.ui.UiConnectionInfo;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.storage.mapdb.MapDBStatsStorage;
import org.deeplearning4j.ui.weights.ConvolutionListenerPersistable;
import org.deeplearning4j.util.UIDProvider;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.io.ClassPathResource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Training listener that periodically renders the activations of the convolutional layers of a
 * {@link MultiLayerNetwork} or {@link ComputationGraph} into a single image. The image is
 * published to a {@link StatsStorageRouter} as a {@link ConvolutionListenerPersistable}.
 *
 * <p>On every forward pass where the model's iteration count is a multiple of the configured
 * frequency, one example of the current minibatch is chosen at random. The rendering contains:
 * <ul>
 *   <li>that example's input, when it can be restored as an image;</li>
 *   <li>the feature maps every convolutional layer produced for that same example.</li>
 * </ul>
 * Up to three convolutional layers are laid out horizontally; more are laid out vertically.
 */
public class ConvolutionalIterationListener
extends BaseTrainingListener {
    private static final Logger log = LoggerFactory.getLogger(ConvolutionalIterationListener.class);

    /** Render every {@code freq}-th iteration (iteration count modulo {@code freq} == 0). */
    private final int freq;
    /** Number of renderings published so far. */
    private int minibatchNum = 0;
    private final boolean openBrowser;
    /** UI URL reported at construction time. */
    private final String path;
    private final Color borderColor = new Color(140, 140, 140);
    private final Color bgColor = new Color(255, 255, 255);
    private final StatsStorageRouter ssr;
    private final String sessionID;
    private final String workerID;

    /**
     * Creates a listener backed by a new {@link MapDBStatsStorage}, with {@code openBrowser = true}.
     *
     * @param connectionInfo         not used; retained for backward compatibility
     * @param visualizationFrequency render every {@code visualizationFrequency} iterations
     */
    public ConvolutionalIterationListener(UiConnectionInfo connectionInfo, int visualizationFrequency) {
        this(new MapDBStatsStorage(), visualizationFrequency, true);
    }

    /**
     * Creates a listener backed by a new {@link MapDBStatsStorage}, with {@code openBrowser = true}.
     *
     * @param visualizationFrequency render every {@code visualizationFrequency} iterations
     */
    public ConvolutionalIterationListener(int visualizationFrequency) {
        this(visualizationFrequency, true);
    }

    /**
     * Creates a listener backed by a new {@link MapDBStatsStorage}.
     *
     * @param iterations  render every {@code iterations} iterations
     * @param openBrowser if {@code true}, the storage is attached to the UI server
     */
    public ConvolutionalIterationListener(int iterations, boolean openBrowser) {
        this(new MapDBStatsStorage(), iterations, openBrowser);
    }

    /**
     * Creates a listener that publishes to {@code ssr} with a random session ID and a
     * JVM/thread-derived worker ID.
     *
     * @param ssr         router that receives the rendered images
     * @param iterations  render every {@code iterations} iterations
     * @param openBrowser if {@code true} and {@code ssr} is a {@link StatsStorage}, it is attached
     *                    to the UI server
     */
    public ConvolutionalIterationListener(StatsStorageRouter ssr, int iterations, boolean openBrowser) {
        this(ssr, iterations, openBrowser, null, null);
    }

    /**
     * Creates a listener that publishes to {@code ssr}.
     *
     * @param ssr         router that receives the rendered images
     * @param iterations  render every {@code iterations} iterations
     * @param openBrowser if {@code true} and {@code ssr} is a {@link StatsStorage}, it is attached
     *                    to the UI server
     * @param sessionID   session ID to report; a random UUID when {@code null}
     * @param workerID    worker ID to report; derived from the JVM UID and current thread ID when
     *                    {@code null}
     */
    public ConvolutionalIterationListener(StatsStorageRouter ssr, int iterations, boolean openBrowser, String sessionID, String workerID) {
        this.ssr = ssr;
        this.sessionID = sessionID == null ? UUID.randomUUID().toString() : sessionID;
        this.workerID = workerID == null ? UIDProvider.getJVMUID() + "_" + Thread.currentThread().getId() : workerID;
        this.freq = iterations;
        this.openBrowser = openBrowser;
        String subPath = "activations";
        this.path = "http://localhost:" + UIServer.getInstance().getPort() + "/" + subPath;
        if (openBrowser && ssr instanceof StatsStorage) {
            UIServer.getInstance().attach((StatsStorage) ssr);
        }
        log.info("ConvolutionTrainingListener path: {}", this.path);
    }

    /** No-op: all rendering happens in the {@code onForwardPass} callbacks. */
    @Override
    public void iterationDone(Model model, int iteration, int epoch) {
    }

    /**
     * Renders convolutional activations of a {@link ComputationGraph}. Other model types are
     * ignored; {@link MultiLayerNetwork} activations are handled by the {@code List} overload.
     *
     * <p>{@code activations} must contain exactly one entry per layer (checked below). Entries are
     * looked up by each convolutional layer's name.
     *
     * @throws IllegalStateException if the graph has more than one input, or if a convolutional
     *                               layer has no activation entry
     */
    @Override
    public void onForwardPass(Model model, Map<String, INDArray> activations) {
        if (!(model instanceof ComputationGraph)) {
            return;
        }
        ComputationGraph graph = (ComputationGraph) model;
        if (graph.getIterationCount() % this.freq != 0) {
            return;
        }
        Layer[] layers = graph.getLayers();
        if (layers.length != activations.size()) {
            throw new RuntimeException("layers.length != activations.size(). Got layers.length=" + layers.length
                    + ", activations.size()=" + activations.size());
        }

        Random rnd = new Random();
        List<INDArray> tensors = new ArrayList<>();
        int sampleIdx = -1;
        for (Layer layer : layers) {
            if (layer.type() != Layer.Type.CONVOLUTIONAL) {
                continue;
            }
            String layerName = layer.conf().getLayer().getLayerName();
            INDArray output = activations.get(layerName);
            if (output == null) {
                throw new IllegalStateException("No activations found for convolutional layer \"" + layerName + "\"");
            }
            if (sampleIdx < 0) {
                // Pick the example once so all layers (and the source image) show the same one
                sampleIdx = pickSampleIndex(output, rnd);
            }
            tensors.add(output.tensorAlongDimension((long) sampleIdx, 3, 2, 1));
        }

        INDArray[] inputs = graph.getInputs();
        if (inputs != null && inputs.length > 1) {
            throw new IllegalStateException("ConvolutionIterationListener does not support ComputationGraph models with more than 1 input; model has " + inputs.length + " inputs");
        }
        if (tensors.isEmpty()) {
            // No convolutional layers: nothing to visualize
            return;
        }

        BufferedImage sourceImage = null;
        if (inputs != null && inputs.length == 1 && inputs[0] != null && inputs[0].rank() == 4) {
            sourceImage = restoreSourceImage(inputs[0], sampleIdx);
        }
        publish(tensors, sourceImage);
    }

    /**
     * Renders convolutional activations of a {@link MultiLayerNetwork}. Other model types are
     * ignored; {@link ComputationGraph} activations are handled by the {@code Map} overload.
     *
     * <p>{@code activations} must have exactly one entry per layer (checked below). The output of
     * layer {@code i} is read from index {@code i + 1}; index 0 presumably holds the network input.
     * The source image is restored from the input of the first convolutional layer.
     */
    @Override
    public void onForwardPass(Model model, List<INDArray> activations) {
        if (!(model instanceof MultiLayerNetwork)) {
            return;
        }
        MultiLayerNetwork net = (MultiLayerNetwork) model;
        if (net.getIterationCount() % this.freq != 0) {
            return;
        }
        Layer[] layers = net.getLayers();
        if (layers.length != activations.size()) {
            throw new RuntimeException("layers.length != activations.size(). Got layers.length=" + layers.length
                    + ", activations.size()=" + activations.size());
        }

        Random rnd = new Random();
        List<INDArray> tensors = new ArrayList<>();
        BufferedImage sourceImage = null;
        int sampleIdx = -1;
        for (int i = 0; i < layers.length; ++i) {
            if (layers[i].type() != Layer.Type.CONVOLUTIONAL) {
                continue;
            }
            if (i + 1 >= activations.size()) {
                // The last layer's output is not part of the supplied activations
                break;
            }
            INDArray output = activations.get(i + 1);
            if (sampleIdx < 0) {
                // Pick the example once so all layers (and the source image) show the same one
                sampleIdx = pickSampleIndex(output, rnd);
                sourceImage = restoreSourceImage(layers[i].input(), sampleIdx);
            }
            tensors.add(output.tensorAlongDimension((long) sampleIdx, 3, 2, 1));
        }
        if (tensors.isEmpty()) {
            // No convolutional layers: nothing to visualize
            return;
        }
        publish(tensors, sourceImage);
    }

    /**
     * Picks a uniformly random example index in {@code [0, batchSize)}, where {@code batchSize}
     * is {@code activations.shape()[0]}.
     *
     * @throws ND4JArraySizeException if the batch size does not fit in an {@code int}
     */
    private static int pickSampleIndex(INDArray activations, Random rnd) {
        long batchSize = activations.shape()[0];
        if (batchSize > Integer.MAX_VALUE) {
            throw new ND4JArraySizeException();
        }
        return rnd.nextInt((int) batchSize);
    }

    /**
     * Restores example {@code sampleIdx} of {@code input} as an RGB (or grey) image; any failure
     * is rethrown wrapped in a {@link RuntimeException}.
     */
    private BufferedImage restoreSourceImage(INDArray input, int sampleIdx) {
        try {
            return this.restoreRGBImage(input.tensorAlongDimension((long) sampleIdx, 3, 2, 1));
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /** Rasterizes the tensors and publishes the result as static info to the router. */
    private void publish(List<INDArray> tensors, BufferedImage sourceImage) {
        BufferedImage render = this.rasterizeConvoLayers(tensors, sourceImage);
        ConvolutionListenerPersistable p = new ConvolutionListenerPersistable(this.sessionID, this.workerID, System.currentTimeMillis(), render);
        this.ssr.putStaticInfo(p);
        ++this.minibatchNum;
    }

    /**
     * Composes the source image (if any) and one rendering per layer tensor into a single RGB
     * image, separated by arrows when the arrow resources can be loaded.
     *
     * <p>NOTE(review): the canvas reserves only {@code 2 * padding_col} for the source image.
     * Source images wider/taller than roughly 80 px may therefore be clipped — confirm whether
     * this is intended.
     *
     * @param tensors3D   per-layer activation tensors; must not be empty
     * @param sourceImage the restored input image, or {@code null} to omit it
     */
    private BufferedImage rasterizeConvoLayers(@NonNull List<INDArray> tensors3D, BufferedImage sourceImage) {
        if (tensors3D == null) {
            throw new NullPointerException("tensors3D is marked @NonNull but is null");
        }
        if (tensors3D.isEmpty()) {
            throw new IllegalArgumentException("tensors3D must contain at least one tensor");
        }
        int border = 1;
        int padding_row = 2;
        int padding_col = 80;
        // Dimensions of the first tensor drive the layout of every layer's rendering
        long[] shape = tensors3D.get(0).shape();
        long numImages = shape[0];
        long height = shape[2];
        long width = shape[1];
        Orientation orientation = tensors3D.size() > 3 ? Orientation.PORTRAIT : Orientation.LANDSCAPE;

        int maxHeight = 0;
        int totalWidth = 0;
        List<BufferedImage> images = new ArrayList<>(tensors3D.size());
        for (INDArray tad : tensors3D) {
            BufferedImage image;
            if (orientation == Orientation.LANDSCAPE) {
                maxHeight = (int) ((height + border * 2 + padding_row) * numImages);
                image = this.renderMultipleImagesLandscape(tad, maxHeight, (int) width, (int) height);
                totalWidth += image.getWidth() + padding_col;
            } else {
                totalWidth = (int) ((width + border * 2 + padding_row) * numImages);
                image = this.renderMultipleImagesPortrait(tad, totalWidth, (int) width, (int) height);
                maxHeight += image.getHeight() + padding_col;
            }
            images.add(image);
        }
        if (orientation == Orientation.LANDSCAPE) {
            totalWidth += padding_col * 2;
        } else {
            maxHeight += padding_col * 2;
            if (sourceImage != null) {
                maxHeight += sourceImage.getHeight() + padding_col * 2;
            }
        }

        BufferedImage output = new BufferedImage(totalWidth, maxHeight, BufferedImage.TYPE_INT_RGB);
        Graphics2D graphics2D = output.createGraphics();
        try {
            graphics2D.setPaint(this.bgColor);
            graphics2D.fillRect(0, 0, output.getWidth(), output.getHeight());

            BufferedImage singleArrow;
            BufferedImage multipleArrows;
            if (orientation == Orientation.LANDSCAPE) {
                singleArrow = loadArrow("arrow_sing.PNG");
                multipleArrows = loadArrow("arrow_mul.PNG");
            } else {
                singleArrow = loadArrow("arrow_singi.PNG");
                multipleArrows = loadArrow("arrow_muli.PNG");
            }
            boolean drawArrows = singleArrow != null && multipleArrows != null;

            // Running x (landscape) or y (portrait) position of the next element
            int iOffset = 1;
            if (sourceImage != null) {
                int srcWidth = sourceImage.getWidth();
                int srcHeight = sourceImage.getHeight();
                if (orientation == Orientation.LANDSCAPE) {
                    int x = padding_col / 2 - srcWidth / 2;
                    int y = maxHeight / 2 - srcHeight / 2;
                    graphics2D.drawImage(sourceImage, x, y, null);
                    graphics2D.setPaint(this.borderColor);
                    graphics2D.drawRect(x, y, srcWidth, srcHeight);
                    iOffset += srcWidth;
                    if (singleArrow != null) {
                        graphics2D.drawImage(singleArrow, iOffset + padding_col / 2 - singleArrow.getWidth() / 2, maxHeight / 2 - singleArrow.getHeight() / 2, null);
                    }
                } else {
                    int x = totalWidth / 2 - srcWidth / 2;
                    int y = padding_col / 2 - srcHeight / 2;
                    graphics2D.drawImage(sourceImage, x, y, null);
                    graphics2D.setPaint(this.borderColor);
                    graphics2D.drawRect(x, y, srcWidth, srcHeight);
                    iOffset += srcHeight;
                    if (singleArrow != null) {
                        graphics2D.drawImage(singleArrow, totalWidth / 2 - singleArrow.getWidth() / 2, iOffset + padding_col / 2 - singleArrow.getHeight() / 2, null);
                    }
                }
                iOffset += padding_col;
            }

            for (int i = 0; i < images.size(); ++i) {
                BufferedImage curImage = images.get(i);
                boolean hasNext = i < images.size() - 1;
                if (orientation == Orientation.LANDSCAPE) {
                    graphics2D.drawImage(curImage, iOffset, 1, null);
                    iOffset += curImage.getWidth() + padding_col;
                    if (drawArrows && hasNext) {
                        graphics2D.drawImage(multipleArrows, iOffset - padding_col / 2 - multipleArrows.getWidth() / 2, maxHeight / 2 - multipleArrows.getHeight() / 2, null);
                    }
                } else {
                    graphics2D.drawImage(curImage, 1, iOffset, null);
                    iOffset += curImage.getHeight() + padding_col;
                    if (drawArrows && hasNext) {
                        graphics2D.drawImage(multipleArrows, totalWidth / 2 - multipleArrows.getWidth() / 2, iOffset - padding_col / 2 - multipleArrows.getHeight() / 2, null);
                    }
                }
            }
        } finally {
            graphics2D.dispose();
        }
        return output;
    }

    /**
     * Loads a decorative arrow image from the classpath, closing the stream afterwards
     * ({@link ImageIO#read(InputStream)} does not close it).
     *
     * @return the image, or {@code null} if the resource is missing or unreadable
     */
    private static BufferedImage loadArrow(String resourceName) {
        try (InputStream in = new ClassPathResource(resourceName).getInputStream()) {
            return ImageIO.read(in);
        } catch (Exception e) {
            // Best effort: arrows are decorative, so a failure only means they are not drawn
            log.debug("Could not load arrow image {}", resourceName, e);
            return null;
        }
    }

    /**
     * Renders every 2D slice of {@code tensor3D} (iterating over dimension 0) into a grey image of
     * width {@code maxWidth}, wrapping to a new row when a row is full. Every 7th slice (up to 5)
     * is also drawn zoomed to {@code zoomWidth x zoomHeight} along the bottom edge.
     */
    private BufferedImage renderMultipleImagesPortrait(INDArray tensor3D, int maxWidth, int zoomWidth, int zoomHeight) {
        int border = 1;
        int padding_row = 2;
        int padding_col = 2;
        int zoomPadding = 20;
        long[] tShape = tensor3D.shape();
        long numRows = tShape[0] / tShape[2];
        long height = numRows * (tShape[1] + border + padding_col) + padding_col + zoomPadding + zoomWidth;
        if (height > Integer.MAX_VALUE) {
            throw new ND4JArraySizeException();
        }
        BufferedImage outputImage = new BufferedImage(maxWidth, (int) height, BufferedImage.TYPE_BYTE_GRAY);
        Graphics2D graphics2D = outputImage.createGraphics();
        try {
            graphics2D.setPaint(this.bgColor);
            graphics2D.fillRect(0, 0, outputImage.getWidth(), outputImage.getHeight());
            int columnOffset = 0;
            int rowOffset = 0;
            int numZoomed = 0;
            int limZoomed = 5;
            int zoomSpan = maxWidth / limZoomed;
            long numSlices = tShape[0];
            for (int z = 0; z < numSlices; ++z) {
                INDArray tad2D = tensor3D.tensorAlongDimension((long) z, 2, 1);
                long rWidth = tad2D.shape()[0];
                long rHeight = tad2D.shape()[1];
                if (rWidth > Integer.MAX_VALUE || rHeight > Integer.MAX_VALUE) {
                    throw new ND4JArraySizeException();
                }
                long locHeight = rHeight + border * 2 + padding_row;
                long locWidth = rWidth + border * 2 + padding_col;
                BufferedImage currentImage = this.renderImageGrayscale(tad2D);
                if (columnOffset + locWidth > maxWidth) {
                    rowOffset = (int) (rowOffset + locHeight);
                    columnOffset = 0;
                }
                graphics2D.drawImage(currentImage, columnOffset + 1, rowOffset + 1, null);
                graphics2D.setPaint(this.borderColor);
                graphics2D.drawRect(columnOffset, rowOffset, (int) rWidth, (int) rHeight);
                if (z % 7 == 0 && z != 0 && numZoomed < limZoomed && rHeight != zoomHeight && rWidth != zoomWidth) {
                    int cX = zoomSpan * numZoomed + zoomWidth;
                    graphics2D.drawImage(currentImage, cX - 1, (int) height - zoomWidth - 1, zoomWidth, zoomHeight, null);
                    graphics2D.drawRect(cX - 2, (int) height - zoomWidth - 2, zoomWidth, zoomHeight);
                    graphics2D.drawLine(columnOffset + (int) rWidth, rowOffset + (int) rHeight, cX - 2, (int) height - zoomWidth - 2);
                    ++numZoomed;
                }
                columnOffset = (int) (columnOffset + locWidth);
            }
        } finally {
            graphics2D.dispose();
        }
        return outputImage;
    }

    /**
     * Renders every 2D slice of {@code tensor3D} (iterating over dimension 0) into a grey image of
     * height {@code maxHeight}, wrapping to a new column when a column is full. Every 5th slice
     * (up to 5) is also drawn zoomed to {@code zoomWidth x zoomHeight} along the right edge.
     */
    private BufferedImage renderMultipleImagesLandscape(INDArray tensor3D, int maxHeight, int zoomWidth, int zoomHeight) {
        int border = 1;
        int padding_row = 2;
        int padding_col = 2;
        int zoomPadding = 20;
        long[] tShape = tensor3D.shape();
        long numColumns = tShape[0] / tShape[1];
        long width = numColumns * (tShape[1] + border + padding_col) + padding_col + zoomPadding + zoomWidth;
        if (width > Integer.MAX_VALUE) {
            throw new ND4JArraySizeException();
        }
        BufferedImage outputImage = new BufferedImage((int) width, maxHeight, BufferedImage.TYPE_BYTE_GRAY);
        Graphics2D graphics2D = outputImage.createGraphics();
        try {
            graphics2D.setPaint(this.bgColor);
            graphics2D.fillRect(0, 0, outputImage.getWidth(), outputImage.getHeight());
            int columnOffset = 0;
            int rowOffset = 0;
            int numZoomed = 0;
            int limZoomed = 5;
            int zoomSpan = maxHeight / limZoomed;
            long numSlices = tShape[0];
            for (int z = 0; z < numSlices; ++z) {
                INDArray tad2D = tensor3D.tensorAlongDimension((long) z, 2, 1);
                long rWidth = tad2D.shape()[0];
                long rHeight = tad2D.shape()[1];
                long locHeight = rHeight + border * 2 + padding_row;
                long locWidth = rWidth + border * 2 + padding_col;
                BufferedImage currentImage = this.renderImageGrayscale(tad2D);
                if (rowOffset + locHeight > maxHeight) {
                    columnOffset = (int) (columnOffset + locWidth);
                    rowOffset = 0;
                }
                graphics2D.drawImage(currentImage, columnOffset + 1, rowOffset + 1, null);
                graphics2D.setPaint(this.borderColor);
                if (rWidth > Integer.MAX_VALUE || rHeight > Integer.MAX_VALUE) {
                    throw new ND4JArraySizeException();
                }
                graphics2D.drawRect(columnOffset, rowOffset, (int) rWidth, (int) rHeight);
                if (z % 5 == 0 && z != 0 && numZoomed < limZoomed && rHeight != zoomHeight && rWidth != zoomWidth) {
                    int cY = zoomSpan * numZoomed + zoomHeight;
                    graphics2D.drawImage(currentImage, (int) width - zoomWidth - 1, cY - 1, zoomWidth, zoomHeight, null);
                    graphics2D.drawRect((int) width - zoomWidth - 2, cY - 2, zoomWidth, zoomHeight);
                    graphics2D.drawLine(columnOffset + (int) rWidth, rowOffset + (int) rHeight, (int) width - zoomWidth - 2, cY - 2 + zoomHeight);
                    ++numZoomed;
                }
                rowOffset = (int) (rowOffset + locHeight);
            }
        } finally {
            graphics2D.dispose();
        }
        return outputImage;
    }

    /**
     * Converts a 3D tensor into an RGB image.
     *
     * <p>When dimension 0 has size 3, slices 2, 1 and 0 are used as R, G and B respectively.
     * Otherwise slice 0 is used for all three channels, giving a grey image. Values are scaled by
     * 255 and clamped to {@code [0, 255]}; values in {@code [0, 1]} are presumably expected.
     */
    private BufferedImage restoreRGBImage(INDArray tensor3D) {
        INDArray arrayR;
        INDArray arrayG;
        INDArray arrayB;
        if (tensor3D.shape()[0] == 3L) {
            arrayR = tensor3D.tensorAlongDimension(2L, 2, 1);
            arrayG = tensor3D.tensorAlongDimension(1L, 2, 1);
            arrayB = tensor3D.tensorAlongDimension(0L, 2, 1);
        } else {
            arrayR = tensor3D.tensorAlongDimension(0L, 2, 1);
            arrayG = arrayR;
            arrayB = arrayR;
        }
        int imageWidth = arrayR.columns();
        int imageHeight = arrayR.rows();
        BufferedImage imageToRender = new BufferedImage(imageWidth, imageHeight, BufferedImage.TYPE_INT_RGB);
        for (int y = 0; y < imageHeight; ++y) {
            // Fetch each row view once instead of once per pixel
            INDArray rowR = arrayR.getRow((long) y);
            INDArray rowG = arrayG.getRow((long) y);
            INDArray rowB = arrayB.getRow((long) y);
            for (int x = 0; x < imageWidth; ++x) {
                Color pix = new Color(
                        toColorComponent(rowR.getDouble((long) x)),
                        toColorComponent(rowG.getDouble((long) x)),
                        toColorComponent(rowB.getDouble((long) x)));
                imageToRender.setRGB(x, y, pix.getRGB());
            }
        }
        return imageToRender;
    }

    /**
     * Converts a 2D array into an 8-bit grey image. Values are scaled by 255 and clamped to
     * {@code [0, 255]}, so out-of-range values saturate instead of wrapping around in the byte raster.
     */
    private BufferedImage renderImageGrayscale(INDArray array) {
        int imageWidth = array.columns();
        int imageHeight = array.rows();
        BufferedImage imageToRender = new BufferedImage(imageWidth, imageHeight, BufferedImage.TYPE_BYTE_GRAY);
        WritableRaster raster = imageToRender.getRaster();
        for (int y = 0; y < imageHeight; ++y) {
            INDArray row = array.getRow((long) y);
            for (int x = 0; x < imageWidth; ++x) {
                raster.setSample(x, y, 0, toColorComponent(row.getDouble((long) x)));
            }
        }
        return imageToRender;
    }

    /**
     * Scales a value by 255 and clamps the result to {@code [0, 255]}. {@code NaN} maps to 0.
     */
    private static int toColorComponent(double value) {
        int scaled = (int) (255.0 * value);
        return Math.max(0, Math.min(255, scaled));
    }

    /** Debug helper: writes {@code array} as a grey PNG; failures are logged, not thrown. */
    private void writeImageGrayscale(INDArray array, File file) {
        try {
            ImageIO.write(this.renderImageGrayscale(array), "png", file);
        } catch (IOException e) {
            log.warn("Failed to write grayscale image to {}", file, e);
        }
    }

    /** Debug helper: writes {@code array} as a PNG via {@link ImageLoader}; failures are logged. */
    private void writeImage(INDArray array, File file) {
        BufferedImage image = ImageLoader.toImage(array);
        try {
            ImageIO.write(image, "png", file);
        } catch (IOException e) {
            log.warn("Failed to write image to {}", file, e);
        }
    }

    /**
     * Debug helper: writes one text line per row of {@code array} to {@code file}.
     *
     * @throws RuntimeException wrapping the I/O error if the file cannot be opened
     */
    private void writeRows(INDArray array, File file) {
        try (PrintWriter writer = new PrintWriter(file)) {
            for (int x = 0; x < array.rows(); ++x) {
                writer.println("Row [" + x + "]: " + array.getRow((long) x));
            }
            writer.flush();
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /** Layout direction of the composed rendering. */
    private enum Orientation {
        LANDSCAPE,
        PORTRAIT
    }
}

