/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.datasets;

import java.awt.Color;
import java.awt.Dimension;
import java.awt.Graphics;
import java.awt.GridLayout;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.io.IOException;
import javax.swing.BorderFactory;
import javax.swing.JButton;
import javax.swing.JFrame;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.SwingUtilities;
import org.deeplearning4j.datasets.MnistManager;
import org.deeplearning4j.datasets.NN;

/**
 * Small Swing tool for browsing MNIST training images and driving an {@link NN} by hand.
 *
 * <p>The window holds three panels side by side:
 * <ul>
 *   <li>the current image and its label ({@link InputPanel});</li>
 *   <li>the network's input, state and output activations ({@link OutputPanel});</li>
 *   <li>buttons that step through the data set and push values into the network
 *       ({@link OptionsPanel}).</li>
 * </ul>
 */
public class MNISTViewer {
    MnistManager manager;
    NN network;
    InputPanel input;
    OutputPanel output;
    OptionsPanel options;
    /** Width of the network node grid drawn by {@link OutputPanel}. */
    int xnodes = 50;
    /** Height of the network node grid drawn by {@link OutputPanel}. */
    int ynodes = 24;
    /** Passed to the {@link NN} constructor together with the node count. */
    int connections = 10;
    /** Rows of the label pattern; the row indexed by the label is lit. */
    int outputRows = 10;
    /** Columns of the label pattern. */
    int outputCols = 10;

    /**
     * Loads the MNIST training files, initialises the network and shows the viewer window.
     *
     * @throws IllegalStateException if the MNIST files cannot be read (the cause is attached)
     */
    public MNISTViewer() {
        try {
            this.manager = new MnistManager("MNIST/train-images-idx3-ubyte", "MNIST/train-labels-idx1-ubyte");
        }
        catch (IOException e) {
            // Every panel dereferences the manager, so there is no useful UI without it.
            throw new IllegalStateException("Could not load MNIST training files", e);
        }
        this.network = new NN(this.xnodes * this.ynodes, this.connections);
        this.network.init();
        MyFrame mf = new MyFrame("MNIST Viewer");
        mf.pack();
        mf.setVisible(true);
    }

    /** Starts the viewer; Swing components are created on the event dispatch thread. */
    public static void main(String[] args) {
        SwingUtilities.invokeLater(new Runnable() {

            @Override
            public void run() {
                new MNISTViewer();
            }
        });
    }

    /**
     * Maps an activation to an opaque gray: 0 (or less) is black, 1 (or more) is white.
     * Clamping both ends keeps every component inside the 0..255 range {@link Color} requires.
     */
    private static Color toGray(float value) {
        int c = (int)(255.0 * (double)value);
        c = Math.max(0, Math.min(255, c));
        return new Color(c, c, c);
    }

    /** Buttons that navigate the data set and feed the current sample into the network. */
    class OptionsPanel
    extends JPanel {
        public OptionsPanel() {
            this.setPreferredSize(new Dimension(200, 200));
            this.setBorder(BorderFactory.createLineBorder(Color.black));
            JLabel title = new JLabel("Options");
            this.add(title);
            JButton next = new JButton("Next");
            JButton previous = new JButton("Previous");
            JButton setInput = new JButton("Set Input");
            JButton setOutput = new JButton("Set Output");
            JButton update = new JButton("Update");
            JButton reset = new JButton("Reset");
            next.addActionListener(new ActionListener(){

                @Override
                public void actionPerformed(ActionEvent e) {
                    MNISTViewer.this.input.nextImage();
                    MNISTViewer.this.input.repaint();
                }
            });
            previous.addActionListener(new ActionListener(){

                @Override
                public void actionPerformed(ActionEvent e) {
                    MNISTViewer.this.input.previousImage();
                    MNISTViewer.this.input.repaint();
                }
            });
            setInput.addActionListener(new ActionListener(){

                /** Copies the current image, scaled by 1/255, into input nodes 0..rows*cols-1. */
                @Override
                public void actionPerformed(ActionEvent e) {
                    try {
                        MNISTViewer.this.manager.setCurrent(MNISTViewer.this.input.imageIndex);
                        int[][] image = MNISTViewer.this.manager.readImage();
                        int rows = image.length;
                        int cols = image[0].length;
                        int size = rows * cols;
                        int[] inputNodes = new int[size];
                        float[] inputValues = new float[size];
                        for (int i = 0; i < rows; ++i) {
                            for (int j = 0; j < cols; ++j) {
                                // Row-major flattening: the stride is the column count.
                                int k = i * cols + j;
                                inputNodes[k] = k;
                                inputValues[k] = (float)image[i][j] / 255.0f;
                            }
                        }
                        MNISTViewer.this.network.setInput(inputNodes, inputValues);
                    }
                    catch (IOException e1) {
                        e1.printStackTrace();
                    }
                    MNISTViewer.this.output.repaint();
                }
            });
            setOutput.addActionListener(new ActionListener(){

                /** Sets the output nodes to the label pattern of the image shown in the input panel. */
                @Override
                public void actionPerformed(ActionEvent e) {
                    int rows = MNISTViewer.this.outputRows;
                    int cols = MNISTViewer.this.outputCols;
                    int size = rows * cols;
                    int[] outputNodes = new int[size];
                    // NOTE(review): this selects nodes [total-1-size, total-2], leaving the last node
                    // unused; confirm whether the final `size` nodes (total - size) were intended.
                    int first = MNISTViewer.this.xnodes * MNISTViewer.this.ynodes - 1 - size;
                    for (int i = 0; i < size; ++i) {
                        outputNodes[i] = first + i;
                    }
                    try {
                        // Re-seek so the label belongs to the displayed image: painting the input
                        // panel also calls readLabel(), which presumably moves the label cursor.
                        MNISTViewer.this.manager.setCurrent(MNISTViewer.this.input.imageIndex);
                        int target = MNISTViewer.this.manager.readLabel();
                        float[] outputValues = new float[size];
                        // Light the whole row that corresponds to the label.
                        for (int j = 0; j < cols; ++j) {
                            outputValues[target * cols + j] = 1.0f;
                        }
                        MNISTViewer.this.network.setOutput(outputNodes, outputValues);
                        MNISTViewer.this.input.repaint();
                        MNISTViewer.this.output.repaint();
                    }
                    catch (IOException e1) {
                        e1.printStackTrace();
                    }
                }
            });
            update.addActionListener(new ActionListener(){

                @Override
                public void actionPerformed(ActionEvent e) {
                    MNISTViewer.this.network.update();
                    MNISTViewer.this.output.repaint();
                }
            });
            reset.addActionListener(new ActionListener(){

                @Override
                public void actionPerformed(ActionEvent e) {
                    MNISTViewer.this.network.reset();
                    MNISTViewer.this.output.repaint();
                }
            });
            this.add(next);
            this.add(previous);
            this.add(setInput);
            this.add(setOutput);
            this.add(update);
            this.add(reset);
        }
    }

    /** Draws the network's input nodes, full state grid and output nodes as gray pixels. */
    class OutputPanel
    extends JPanel {
        private int width = 200;
        private int height = 200;
        private int mx = 20;
        private int my = 30;

        public OutputPanel() {
            this.setPreferredSize(new Dimension(this.width, this.height));
            this.setBorder(BorderFactory.createLineBorder(Color.black));
            JLabel title = new JLabel("Output");
            this.add(title);
        }

        /** Plots a row-major grid of activations, one pixel per value, with its top-left at (x0, y0). */
        private void drawGrid(Graphics g, float[] values, int rows, int cols, int x0, int y0) {
            for (int i = 0; i < rows; ++i) {
                for (int j = 0; j < cols; ++j) {
                    g.setColor(toGray(values[i * cols + j]));
                    g.fillRect(x0 + j, y0 + i, 1, 1);
                }
            }
        }

        /** Draws the network's input activations using the image dimensions; missing input is black. */
        private void drawInputNodes(Graphics g) {
            int rows = MNISTViewer.this.manager.getImages().getRows();
            int cols = MNISTViewer.this.manager.getImages().getCols();
            float[] state = MNISTViewer.this.network.readInput();
            if (state == null) {
                state = new float[rows * cols];
            }
            this.drawGrid(g, state, rows, cols, this.mx, this.my);
        }

        /** Draws the whole ynodes x xnodes state grid; a null state is drawn as black. */
        private void drawState(Graphics g) {
            int rows = MNISTViewer.this.ynodes;
            int cols = MNISTViewer.this.xnodes;
            float[] state = MNISTViewer.this.network.getState();
            if (state == null) {
                state = new float[rows * cols];
            }
            this.drawGrid(g, state, rows, cols, this.mx, 3 * this.my);
        }

        /** Draws the output activations as an outputRows x outputCols grid; missing output is black. */
        private void drawOutputNodes(Graphics g) {
            int rows = MNISTViewer.this.outputRows;
            int cols = MNISTViewer.this.outputCols;
            float[] state = MNISTViewer.this.network.readOutput();
            if (state == null) {
                state = new float[rows * cols];
            }
            this.drawGrid(g, state, rows, cols, this.mx, 4 * this.my);
        }

        @Override
        public void paintComponent(Graphics g) {
            super.paintComponent(g);
            g.drawRect(this.mx, this.my, this.getWidth() - 2 * this.mx, this.getHeight() - 2 * this.my);
            this.drawInputNodes(g);
            this.drawState(g);
            this.drawOutputNodes(g);
        }
    }

    /** Shows the currently selected training image and its label pattern. */
    class InputPanel
    extends JPanel {
        private int width = 200;
        private int height = 200;
        private int mx = 20;
        private int my = 30;
        /** Index passed to MnistManager.setCurrent; kept within [1, maxIndex]. */
        private int imageIndex = 1;
        private int maxIndex;

        public InputPanel() {
            this.setPreferredSize(new Dimension(this.width, this.height));
            this.setBorder(BorderFactory.createLineBorder(Color.black));
            JLabel title = new JLabel("Input");
            this.add(title);
            MNISTViewer.this.manager.setCurrent(this.imageIndex);
            this.maxIndex = MNISTViewer.this.manager.getImages().getCount();
        }

        /** Advances to the next image, stopping at the last one. */
        public void nextImage() {
            this.imageIndex = Math.min(this.imageIndex + 1, this.maxIndex);
            MNISTViewer.this.manager.setCurrent(this.imageIndex);
        }

        /** Steps back to the previous image, stopping at index 1. */
        public void previousImage() {
            this.imageIndex = Math.max(this.imageIndex - 1, 1);
            MNISTViewer.this.manager.setCurrent(this.imageIndex);
        }

        /** Draws the current image pixel by pixel; draws nothing if it cannot be read. */
        public void drawCurrentImage(Graphics g) {
            MNISTViewer.this.manager.setCurrent(this.imageIndex);
            int[][] image;
            try {
                image = MNISTViewer.this.manager.readImage();
            }
            catch (IOException e) {
                e.printStackTrace();
                return;
            }
            for (int i = 0; i < image.length; ++i) {
                for (int j = 0; j < image[i].length; ++j) {
                    int c = image[i][j];
                    g.setColor(new Color(c, c, c));
                    g.fillRect(this.mx + j, this.my + i, 1, 1);
                }
            }
        }

        /** Draws the label pattern: the row indexed by the current label is white, the rest black. */
        public void drawCurrentOutput(Graphics g) {
            int xoffset = this.mx;
            int yoffset = 3 * this.my;
            int rows = MNISTViewer.this.outputRows;
            int cols = MNISTViewer.this.outputCols;
            try {
                // Seek explicitly rather than relying on drawCurrentImage having done it first.
                MNISTViewer.this.manager.setCurrent(this.imageIndex);
                int target = MNISTViewer.this.manager.readLabel();
                for (int i = 0; i < rows; ++i) {
                    g.setColor(i == target ? Color.WHITE : Color.BLACK);
                    g.fillRect(xoffset, yoffset + i, cols, 1);
                }
            }
            catch (IOException e) {
                e.printStackTrace();
            }
        }

        @Override
        public void paintComponent(Graphics g) {
            super.paintComponent(g);
            g.drawRect(this.mx, this.my, this.getWidth() - 2 * this.mx, this.getHeight() - 2 * this.my);
            this.drawCurrentImage(g);
            this.drawCurrentOutput(g);
        }
    }

    /** Top-level window arranging the input, output and options panels in one row. */
    class MyFrame
    extends JFrame {
        public MyFrame(String s) {
            super(s);
            this.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
            JPanel main = new JPanel();
            main.setBorder(BorderFactory.createLineBorder(Color.black));
            main.setLayout(new GridLayout(1, 3));
            MNISTViewer.this.input = new InputPanel();
            MNISTViewer.this.output = new OutputPanel();
            MNISTViewer.this.options = new OptionsPanel();
            this.add(main);
            main.add(MNISTViewer.this.input);
            main.add(MNISTViewer.this.output);
            main.add(MNISTViewer.this.options);
        }
    }
}

