/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.plot;

import java.awt.Color;
import java.awt.Dimension;
import java.awt.Graphics;
import java.awt.Graphics2D;
import java.awt.GraphicsEnvironment;
import java.awt.Image;
import java.awt.image.BufferStrategy;
import java.awt.image.BufferedImage;
import java.awt.image.DataBufferInt;
import java.awt.image.RenderedImage;
import java.awt.image.WritableRaster;
import java.io.File;
import java.io.IOException;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.Map;
import java.util.TreeMap;
import javax.imageio.ImageIO;
import javax.swing.ImageIcon;
import javax.swing.JFrame;
import javax.swing.JLabel;
import org.deeplearning4j.plot.FilterPanel;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Renders weight matrices, bias matrices, activations and value histograms to PNG files, and can
 * optionally display an image in a Swing window.
 *
 * <p>Rendering state (the current image, its dimensions and draw offsets) lives in mutable
 * instance fields; this class performs no synchronization.
 */
public class FilterRenderer {
    private static final Logger log = LoggerFactory.getLogger(FilterRenderer.class);

    /** Value returned by {@link #computeHistogramBucketIndex} when a value falls in no bucket. */
    private static final int NO_BUCKET = -10;

    /** Largest value an 8-bit colour or gray channel can hold. */
    private static final int MAX_CHANNEL_VALUE = 255;

    /** Added to a column's value range so normalizing a constant column cannot divide by zero. */
    private static final double NORMALIZATION_EPSILON = 1.0e-6;

    /** Gap, in pixels, left between adjacent filter tiles by {@link #renderFilters}. */
    private static final int PATCH_BORDER = 2;

    /** How long {@link #renderFilters} keeps its preview window open, in milliseconds. */
    private static final long PREVIEW_MILLIS = 10000L;

    /** Window created by {@link #draw()}; {@code null} until then. */
    public JFrame frame;
    /** Image produced by the last renderHiddenBiases / renderFilters / renderActivations call. */
    BufferedImage img;
    private int width = 28;
    private int height = 28;
    /** Title of the window created by {@link #draw()}. */
    public String title = "TEST";
    private int heightOffset = 0;
    private int widthOffset = 0;

    /**
     * Renders a bias matrix as a {@code TYPE_INT_RGB} image and saves it as PNG.
     *
     * <p>Each element is scaled by 255, rounded and clamped to [0, 255]. The pixels are written as
     * packed RGB ints, so the value ends up in the blue channel only. The previous factor of 256
     * turned 1.0 into 0x100, which spilled into the green channel and rendered almost black.
     *
     * @param heightOffset offset stored on this renderer (used by {@link #start()} when drawing)
     * @param widthOffset offset stored on this renderer (used by {@link #start()} when drawing)
     * @param render_data matrix to render; one pixel per element
     * @param filename path of the PNG file to write; save failures are logged, not thrown
     */
    public void renderHiddenBiases(int heightOffset, int widthOffset, INDArray render_data, String filename) {
        this.width = render_data.columns();
        this.height = render_data.rows();
        this.img = new BufferedImage(this.width, this.height, BufferedImage.TYPE_INT_RGB);
        this.heightOffset = heightOffset;
        this.widthOffset = widthOffset;
        int[] pixels = new int[render_data.length()];
        for (int i = 0; i < pixels.length; i++) {
            pixels[i] = toChannelValue(render_data.getDouble(i));
            log.debug("> {}", pixels[i]);
        }
        log.debug("hbias size: Cols: {}, Rows: {}", render_data.columns(), render_data.rows());
        this.img.getRaster().setDataElements(0, 0, this.width, this.height, pixels);
        this.saveToDisk(filename);
    }

    /**
     * Finds the histogram bin containing {@code value}.
     *
     * <p>Bin {@code x} covers {@code [min + x*stepSize, min + (x+1)*stepSize]}. Both edges use the
     * same formula, so adjacent bins share exactly the same boundary. The old {@code lower +
     * stepSize} upper edge could leave rounding gaps between bins. A value sitting exactly on a
     * shared boundary goes to the lower bin.
     *
     * @param min lower edge of bin 0
     * @param stepSize width of each bin
     * @param value value to classify
     * @param numberBins number of bins
     * @return the bin index, or {@code -10} if the value lies in no bin (including NaN)
     */
    public int computeHistogramBucketIndex(double min, double stepSize, double value, int numberBins) {
        for (int bin = 0; bin < numberBins; bin++) {
            double lower = min + bin * stepSize;
            double upper = min + (bin + 1) * stepSize;
            if (value >= lower && value <= upper) {
                return bin;
            }
        }
        return NO_BUCKET;
    }

    /**
     * Rounds {@code unrounded} to {@code precision} decimal places.
     *
     * @param unrounded value to round; NaN and infinities are returned unchanged
     * @param precision number of digits after the decimal point
     * @param roundingMode a legacy {@code BigDecimal.ROUND_*} constant (e.g. 4 = half-up)
     * @return the rounded value
     * @throws IllegalArgumentException if {@code roundingMode} is not a valid rounding mode
     */
    public static double round(double unrounded, int precision, int roundingMode) {
        return round(unrounded, precision, RoundingMode.valueOf(roundingMode));
    }

    /**
     * Rounds via the decimal representation of {@code unrounded}.
     *
     * <p>{@code BigDecimal.valueOf} is used rather than {@code new BigDecimal(double)} so that,
     * e.g., 1.005 rounds to 1.01 instead of being treated as 1.00499999....
     */
    private static double round(double unrounded, int precision, RoundingMode mode) {
        if (Double.isNaN(unrounded) || Double.isInfinite(unrounded)) {
            // BigDecimal cannot represent these; leave them as-is instead of throwing.
            return unrounded;
        }
        return BigDecimal.valueOf(unrounded).setScale(precision, mode).doubleValue();
    }

    /** Formats the lower edge of bin {@code bucketIndex}, rounded half-up to two decimals. */
    private String buildBucketLabel(int bucketIndex, double stepSize, double min) {
        double lowerEdge = min + (double) bucketIndex * stepSize;
        return String.valueOf(round(lowerEdge, 2, RoundingMode.HALF_UP));
    }

    /**
     * Counts the elements of {@code data} per histogram bin.
     *
     * <p>The bins split {@code [min(data), max(data)]} into {@code numberBins} equal-width
     * intervals. Values that fit in no bin (e.g. NaN) are counted under key {@code -10}.
     *
     * @param data matrix whose elements are counted
     * @param numberBins number of bins; must be positive
     * @return map from bin index to count, sorted by bin index; bins with no values are absent
     * @throws IllegalArgumentException if {@code numberBins <= 0}
     */
    public Map<Integer, Integer> generateHistogramBuckets(INDArray data, int numberBins) {
        if (numberBins <= 0) {
            throw new IllegalArgumentException("numberBins must be positive, got " + numberBins);
        }
        double min = data.min(Integer.MAX_VALUE).getDouble(0);
        double max = data.max(Integer.MAX_VALUE).getDouble(0);
        double stepSize = (max - min) / (double) numberBins;
        Map<Integer, Integer> counts = new TreeMap<Integer, Integer>();
        for (int row = 0; row < data.rows(); row++) {
            for (int col = 0; col < data.columns(); col++) {
                double value = data.getScalar(row, col).getDouble(0);
                int bucket = this.computeHistogramBucketIndex(min, stepSize, value, numberBins);
                if (bucket == NO_BUCKET && value >= min && value <= max) {
                    // Bins are contiguous from min, so an in-range miss can only come from
                    // floating-point rounding of the last bin's upper edge falling below max.
                    bucket = numberBins - 1;
                }
                Integer previous = counts.get(bucket);
                counts.put(bucket, previous == null ? 1 : previous + 1);
            }
        }
        return counts;
    }

    /**
     * Draws a 600x400 bar chart of the histogram of {@code data} and saves it as PNG.
     *
     * <p>Each bar is labelled with its bin's lower edge. The y axis gets five labels from 0 to the
     * largest bin count. Bars are 40px wide with a 10px gap; with more than about 11 bins, the
     * right-most bars fall outside the image.
     *
     * @param data matrix whose values are binned
     * @param filename path of the PNG file to write; save failures are logged, not thrown
     * @param numberBins number of bins; must be positive
     * @throws IllegalArgumentException if {@code numberBins <= 0}
     */
    public void renderHistogram(INDArray data, String filename, int numberBins) {
        Map<Integer, Integer> counts = this.generateHistogramBuckets(data, numberBins);
        double min = data.min(Integer.MAX_VALUE).getDouble(0);
        double max = data.max(Integer.MAX_VALUE).getDouble(0);
        double stepSize = (max - min) / (double) numberBins;

        final int xOffset = 50;
        final int yOffset = -50;
        final int graphWidth = 600;
        final int graphHeight = 400;
        final int barWidth = 40;
        final int barGap = 10;
        // Bars stand on this baseline; the tallest bar leaves 20px of headroom at the top.
        final int baseline = graphHeight + yOffset;
        final int maxBarHeight = baseline - 20;

        int maxCount = 0;
        for (Integer count : counts.values()) {
            maxCount = Math.max(maxCount, count);
        }
        // Guards the divisions below against an empty histogram; equals maxCount otherwise.
        double scaleMax = Math.max(maxCount, 1);

        BufferedImage chart = new BufferedImage(graphWidth, graphHeight, BufferedImage.TYPE_INT_RGB);
        Graphics2D g2d = chart.createGraphics();
        try {
            g2d.setColor(Color.LIGHT_GRAY);
            g2d.fillRect(0, 0, graphWidth, graphHeight);

            // Y-axis labels at 0, 1/4, ..., 4/4 of the largest count.
            double labelStep = maxCount / 4.0;
            g2d.setColor(Color.BLACK);
            for (int step = 0; step < 5; step++) {
                double label = step * labelStep;
                long labelY = baseline - Math.round((double) ((int) label) / scaleMax * maxBarHeight);
                g2d.drawString(String.valueOf(label), 10.0f, (float) labelY);
            }

            int xPos = xOffset;
            for (Map.Entry<Integer, Integer> entry : counts.entrySet()) {
                String bucketLabel = this.buildBucketLabel(entry.getKey(), stepSize, min);
                long barHeight = Math.round(entry.getValue() / scaleMax * maxBarHeight);
                long yPos = baseline - barHeight;
                g2d.setColor(Color.BLUE);
                g2d.fillRect(xPos, (int) yPos, barWidth, (int) barHeight);
                g2d.setColor(Color.DARK_GRAY);
                g2d.drawRect(xPos, (int) yPos, barWidth, (int) barHeight);
                g2d.setColor(Color.BLACK);
                // Bin label sits 20px below the baseline, roughly centred under the bar.
                g2d.drawString(bucketLabel, (float) (xPos + (barWidth / 2 - 10)), (float) (baseline + 20));
                xPos += barWidth + barGap;
            }
        } finally {
            g2d.dispose();
        }

        try {
            saveImageToDisk(chart, filename);
        } catch (IOException e) {
            log.error("Failed to save histogram to " + filename, e);
        }
    }

    /**
     * Renders each column of {@code data} as a grayscale tile and saves the mosaic as PNG.
     *
     * <p>Tiles are {@code patchWidth x patchHeight}, laid out {@code patchesPerRow} per row with
     * a 2-pixel gap. Each column is min-max normalized to [0, 255] independently. If a display is
     * available, the image is also shown in a window for 10 seconds (this call blocks meanwhile).
     *
     * <p>NOTE(review): assumes each column holds at least {@code patchWidth * patchHeight} values
     * in row-major tile order; confirm with callers. Extra values are ignored.
     *
     * @param data matrix whose columns are the filters to render
     * @param filename path of the PNG file to write
     * @param patchWidth tile width in pixels
     * @param patchHeight tile height in pixels
     * @param patchesPerRow number of tiles per row; must be positive
     * @return the rendered image (also stored in this renderer)
     * @throws IllegalArgumentException if {@code patchesPerRow <= 0}
     * @throws RuntimeException wrapping an {@link IOException} if saving fails
     * @throws Exception e.g. {@link InterruptedException} while the preview window is open
     */
    public BufferedImage renderFilters(INDArray data, String filename, int patchWidth, int patchHeight, int patchesPerRow) throws Exception {
        if (patchesPerRow <= 0) {
            throw new IllegalArgumentException("patchesPerRow must be positive, got " + patchesPerRow);
        }
        int numberCols = data.columns();
        // Ceiling division: a partially filled last row still needs room. The previous
        // Math.round could drop that row and, with it, the trailing filters.
        int numPatchRows = numberCols / patchesPerRow + (numberCols % patchesPerRow == 0 ? 0 : 1);
        numPatchRows = Math.max(1, numPatchRows);
        int filterImgWidth = (patchWidth + PATCH_BORDER) * patchesPerRow;
        int filterImgHeight = numPatchRows * (patchHeight + PATCH_BORDER);
        this.img = new BufferedImage(filterImgWidth, filterImgHeight, BufferedImage.TYPE_BYTE_GRAY);
        WritableRaster raster = this.img.getRaster();

        int[] tile = new int[data.length()];
        for (int col = 0; col < numberCols; col++) {
            int curX = (col % patchesPerRow) * (patchWidth + PATCH_BORDER);
            int curY = (col / patchesPerRow) * (patchHeight + PATCH_BORDER);
            INDArray column = data.getColumn(col);
            double colMin = column.min(Integer.MAX_VALUE).getDouble(0);
            double colMax = column.max(Integer.MAX_VALUE).getDouble(0);
            if (log.isDebugEnabled()) {
                log.debug("rendering " + column.length() + " pixels in column " + col + " for filter patch " + patchWidth + " x " + patchHeight + ", total size: " + patchWidth * patchHeight + " at " + curX);
            }
            // The epsilon keeps the denominator strictly positive. The old code had min and max
            // swapped, which made the denominator negative and the epsilon ineffective.
            double denominator = colMax - colMin + NORMALIZATION_EPSILON;
            for (int i = 0; i < column.length(); i++) {
                double normalized = (column.getScalar(i).getDouble(0) - colMin) / denominator;
                tile[i] = (int) (MAX_CHANNEL_VALUE * normalized);
            }
            raster.setPixels(curX, curY, patchWidth, patchHeight, tile);
        }

        try {
            FilterRenderer.saveImageToDisk(this.img, filename);
            GraphicsEnvironment ge = GraphicsEnvironment.getLocalGraphicsEnvironment();
            if (!ge.isHeadlessInstance()) {
                log.info("Rendering frame...");
                JFrame window = new JFrame();
                try {
                    window.setDefaultCloseOperation(JFrame.DISPOSE_ON_CLOSE);
                    FilterPanel panel = new FilterPanel(this.img);
                    window.add(panel);
                    // Size the window to the image, not to every column laid out in one row.
                    Dimension size = new Dimension(filterImgWidth, filterImgHeight);
                    window.setSize(size);
                    window.setMinimumSize(size);
                    panel.setMinimumSize(size);
                    window.pack();
                    window.setVisible(true);
                    Thread.sleep(PREVIEW_MILLIS);
                } finally {
                    // Release the window even if the wait is interrupted.
                    window.dispose();
                }
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        return this.img;
    }

    /**
     * Renders an activation matrix as an 8-bit grayscale image and saves it as PNG.
     *
     * <p>Each element is scaled by 255, rounded and clamped to [0, 255]. Without the clamp,
     * out-of-range values would wrap around when stored in the byte raster.
     *
     * @param heightOffset offset stored on this renderer (used by {@link #start()} when drawing)
     * @param widthOffset offset stored on this renderer (used by {@link #start()} when drawing)
     * @param activation_data matrix to render; one pixel per element
     * @param filename path of the PNG file to write; save failures are logged, not thrown
     * @param scale currently has no effect; kept for API compatibility
     */
    public void renderActivations(int heightOffset, int widthOffset, INDArray activation_data, String filename, int scale) {
        this.width = activation_data.columns();
        this.height = activation_data.rows();
        log.debug("----- renderActivations ------");
        this.img = new BufferedImage(this.width, this.height, BufferedImage.TYPE_BYTE_GRAY);
        this.heightOffset = heightOffset;
        this.widthOffset = widthOffset;
        int[] pixels = new int[activation_data.length()];
        for (int i = 0; i < pixels.length; i++) {
            pixels[i] = toChannelValue(activation_data.getDouble(i));
        }
        log.debug("activations size: Cols: {}, Rows: {}", activation_data.columns(), activation_data.rows());
        this.img.getRaster().setPixels(0, 0, this.width, this.height, pixels);
        this.saveToDisk(filename);
    }

    /** Scales {@code value} by 255, rounds it, and clamps the result to [0, 255]. */
    private static int toChannelValue(double value) {
        long scaled = Math.round(value * MAX_CHANNEL_VALUE);
        return (int) Math.max(0L, Math.min((long) MAX_CHANNEL_VALUE, scaled));
    }

    /**
     * Writes {@code img} as a PNG file, replacing any existing file.
     *
     * @param img image to write
     * @param imageName destination path
     * @throws IOException if writing fails or no PNG writer is available
     */
    public static void saveImageToDisk(BufferedImage img, String imageName) throws IOException {
        File outputFile = new File(imageName);
        // ImageIO.write creates or replaces the file itself, so no pre-creation is needed.
        // It returns false (rather than throwing) when no writer exists for the format.
        if (!ImageIO.write((RenderedImage) img, "png", outputFile)) {
            throw new IOException("No ImageIO writer available for format png: " + imageName);
        }
    }

    /**
     * Saves the current image as PNG. Failures are logged instead of propagated, so the
     * render methods that call this never throw because of I/O problems.
     *
     * @param filename destination path
     */
    public void saveToDisk(String filename) {
        try {
            FilterRenderer.saveImageToDisk(this.img, filename);
        } catch (IOException e) {
            log.error("Failed to save image to " + filename, e);
        }
    }

    /** Opens a window titled {@link #title} that shows the current image. */
    public void draw() {
        this.frame = new JFrame(this.title);
        // Keep setVisible before start(): start() calls createBufferStrategy on the frame,
        // which requires it to be displayable.
        this.frame.setVisible(true);
        this.start();
        this.frame.add(new JLabel(new ImageIcon(this.getImage())));
        this.frame.pack();
        this.frame.setDefaultCloseOperation(JFrame.DISPOSE_ON_CLOSE);
    }

    /** Disposes the window created by {@link #draw()}; does nothing if there is none. */
    public void close() {
        if (this.frame != null) {
            this.frame.dispose();
        }
    }

    /** @return the current image, or {@code null} if nothing has been rendered yet */
    public Image getImage() {
        return this.img;
    }

    /**
     * Creates a 4-buffer strategy on {@link #frame} if it has none and returns immediately.
     * Otherwise it enters the draw loop below.
     *
     * <p>NOTE(review): once a buffer strategy exists, nothing ever stops this loop. Each
     * iteration also zeroes the first {@code width*height} pixels of {@link #img} before drawing
     * it. It draws at x={@code heightOffset}, y={@code widthOffset}, which looks swapped. The loop
     * only works for images backed by an int buffer (e.g. from renderHiddenBiases). Confirm the
     * intent before relying on this path.
     */
    public void start() {
        boolean running = true;
        while (running) {
            BufferStrategy strategy = this.frame.getBufferStrategy();
            if (strategy == null) {
                this.frame.createBufferStrategy(4);
                return;
            }
            // Fetched only after the early return, so draw() no longer fails with a
            // ClassCastException for gray (byte-backed) images.
            int[] pixels = ((DataBufferInt) this.img.getRaster().getDataBuffer()).getData();
            for (int i = 0; i < this.width * this.height; i++) {
                pixels[i] = 0;
            }
            Graphics g = strategy.getDrawGraphics();
            try {
                g.drawImage(this.img, this.heightOffset, this.widthOffset, this.width, this.height, null);
            } finally {
                g.dispose();
            }
            strategy.show();
        }
    }
}

