/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.ui.weights;

import com.fasterxml.jackson.jaxrs.json.JacksonJsonProvider;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import javax.ws.rs.client.Client;
import javax.ws.rs.client.ClientBuilder;
import javax.ws.rs.client.Entity;
import javax.ws.rs.client.WebTarget;
import javax.ws.rs.core.Response;
import lombok.NonNull;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.ui.UiConnectionInfo;
import org.deeplearning4j.ui.UiServer;
import org.deeplearning4j.ui.UiUtils;
import org.deeplearning4j.ui.providers.ObjectMapperProvider;
import org.deeplearning4j.ui.weights.HistogramBin;
import org.deeplearning4j.ui.weights.beans.CompactModelAndGradient;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Iteration listener that, every {@code iterations} iterations, gathers
 * histograms of the model's gradients and parameters, the running score
 * history and per-layer mean-magnitude histories, and POSTs them as JSON to
 * the UI's {@code update} endpoint below the {@code weights} sub-path.
 *
 * <p>Instances hold mutable, unsynchronized state (histories, layer index
 * bookkeeping); nothing in this class makes concurrent use safe.
 */
public class HistogramIterationListener
implements IterationListener {
    private static final Logger log = LoggerFactory.getLogger(HistogramIterationListener.class);
    /** Media type used for the request, the accepted response and the posted entity. */
    private static final String JSON_MEDIA_TYPE = "application/json";
    /** Bin count handed to {@code HistogramBin.Builder#setBinCount}. */
    private static final int HISTOGRAM_BIN_COUNT = 20;
    /** Value handed to {@code HistogramBin.Builder#setRounding}. */
    private static final int HISTOGRAM_ROUNDING = 6;

    private final Client client = ClientBuilder.newClient().register(JacksonJsonProvider.class).register(new ObjectMapperProvider());
    private WebTarget target;
    /** Reporting period; a report is sent when {@code iteration % iterations == 0}. */
    private int iterations = 1;
    private final ArrayList<Double> scoreHistory = new ArrayList<Double>();
    /** Indexed by layer index; each map goes from (renamed) parameter name to its mean-magnitude series. */
    private final List<Map<String, List<Double>>> meanMagHistoryParams = new ArrayList<Map<String, List<Double>>>();
    /** Same layout as {@link #meanMagHistoryParams}, but fed from the gradient arrays. */
    private final List<Map<String, List<Double>>> meanMagHistoryUpdates = new ArrayList<Map<String, List<Double>>>();
    /** Layer name to layer index, assigned in order of first appearance. */
    private final Map<String, Integer> layerNameIndexes = new HashMap<String, Integer>();
    /** Layer names in index order (position i holds the name with index i). */
    private final List<String> layerNames = new ArrayList<String>();
    private int layerNameIndexesCount = 0;
    private boolean openBrowser;
    private boolean firstIteration = true;
    /** Only set by the local-server constructor, where it is used for the printed URL. */
    private String path;
    private final String subPath = "weights";
    private UiConnectionInfo connectionInfo;

    /**
     * Creates a listener that reports to the UI described by {@code connection}.
     * This constructor never enables automatic browser opening.
     *
     * @param connection connection details of the target UI; must not be {@code null}
     * @param iterations reporting period; values below 1 are treated as 1
     * @throws NullPointerException if {@code connection} is {@code null}
     */
    public HistogramIterationListener(@NonNull UiConnectionInfo connection, int iterations) {
        if (connection == null) {
            throw new NullPointerException("connection");
        }
        // Previously the argument was dropped and the period silently stayed at 1;
        // store it with the same clamping the local-server constructor applies.
        this.iterations = Math.max(iterations, 1);
        this.target = this.client.target(connection.getFirstPart()).path(connection.getSecondPart(this.subPath)).path("update").queryParam("sid", connection.getSessionId());
        this.connectionInfo = connection;
        System.out.println("UI Histogram URL: " + connection.getFullAddress());
    }

    /**
     * Creates a listener bound to the local {@link UiServer} instance, with
     * browser opening enabled.
     *
     * @param iterations reporting period; values below 1 are treated as 1
     */
    public HistogramIterationListener(int iterations) {
        this(iterations, true);
    }

    /**
     * Creates a listener bound to the local {@link UiServer} instance, reached
     * via plain HTTP on {@code localhost} at the server's port.
     *
     * @param iterations  reporting period; values below 1 are treated as 1
     * @param openBrowser whether to try opening a browser after the first report
     * @throws RuntimeException wrapping any exception raised while obtaining the UI server or its port
     */
    public HistogramIterationListener(int iterations, boolean openBrowser) {
        int port;
        try {
            port = UiServer.getInstance().getPort();
        }
        catch (Exception e) {
            log.error("Error initializing UI server", e);
            throw new RuntimeException(e);
        }
        this.iterations = Math.max(iterations, 1);
        UiConnectionInfo info = new UiConnectionInfo.Builder().enableHttps(false).setAddress("localhost").setPort(port).build();
        this.connectionInfo = info;
        this.target = this.client.target(info.getFirstPart()).path(this.subPath).path("update").queryParam("sid", info.getSessionId());
        this.openBrowser = openBrowser;
        this.path = "http://localhost:" + port + "/" + this.subPath;
        System.out.println("UI Histogram URL: " + this.path + "?sid=" + info.getSessionId());
    }

    /**
     * @return always {@code false}
     */
    public boolean invoked() {
        return false;
    }

    /** No-op. */
    public void invoke() {
    }

    /**
     * Collects statistics from {@code model} and posts them to the UI when
     * {@code iteration} is a multiple of the reporting period; otherwise does
     * nothing.
     *
     * <p>Gradient collection is best effort: any exception raised while reading
     * or processing gradients is logged and the report is still sent with
     * whatever gradient data was gathered up to that point. Failures in the
     * parameter section or in the HTTP call are not caught here.
     *
     * @param model     the model to inspect
     * @param iteration the current iteration number
     */
    public void iterationDone(Model model, int iteration) {
        if (iteration % this.iterations != 0) {
            return;
        }

        // The raw Map value type mirrors what the original code hands to
        // CompactModelAndGradient's setters, whose signatures are not visible here.
        Map<String, Map> newGrad = new LinkedHashMap<String, Map>();
        try {
            Map<String, INDArray> grad = model.gradient().gradientForVariable();
            if (this.meanMagHistoryParams.isEmpty()) {
                // First pass: register every layer name, then size both histories
                // so they have a slot for each layer index (at least one slot).
                int maxLayerIdx = 0;
                for (String name : grad.keySet()) {
                    maxLayerIdx = Math.max(maxLayerIdx, this.indexFromString(name));
                }
                ensureSlot(this.meanMagHistoryParams, maxLayerIdx);
                ensureSlot(this.meanMagHistoryUpdates, maxLayerIdx);
            }
            this.collectStatistics(grad, newGrad, this.meanMagHistoryUpdates);
        }
        catch (Exception e) {
            // Deliberately non-fatal, but keep the cause instead of discarding it.
            log.warn("Skipping gradients update", e);
        }

        Map<String, Map> newParams = new LinkedHashMap<String, Map>();
        this.collectStatistics(model.paramTable(), newParams, this.meanMagHistoryParams);

        double score = model.score();
        this.scoreHistory.add(score);

        CompactModelAndGradient g = new CompactModelAndGradient();
        g.setGradients(newGrad);
        g.setParameters(newParams);
        g.setScore(score);
        g.setScores(this.scoreHistory);
        g.setPath(this.subPath);
        g.setUpdateMagnitudes(this.meanMagHistoryUpdates);
        g.setParamMagnitudes(this.meanMagHistoryParams);
        g.setLayerNames(this.layerNames);
        g.setLastUpdateTime(System.currentTimeMillis());

        Response resp = this.target.request(JSON_MEDIA_TYPE).accept(JSON_MEDIA_TYPE).post(Entity.entity(g, JSON_MEDIA_TYPE));
        try {
            log.info("{}", resp);
        }
        finally {
            // The response entity is never read, so close explicitly to release
            // the underlying connection instead of leaking one per report.
            resp.close();
        }

        if (this.openBrowser && this.firstIteration) {
            StringBuilder url = new StringBuilder(this.connectionInfo.getFullAddress());
            url.append(this.subPath).append("?sid=").append(this.connectionInfo.getSessionId());
            UiUtils.tryOpenBrowser(url.toString(), log);
            this.firstIteration = false;
        }
    }

    /**
     * For every named array: stores its histogram data in {@code histograms}
     * and appends its mean magnitude (L1 norm divided by length) to the series
     * kept in {@code magnitudeHistory} under the array's layer index.
     * Names starting with a digit are reported with a {@code "param_"} prefix.
     */
    private void collectStatistics(Map<String, INDArray> arrays, Map<String, Map> histograms, List<Map<String, List<Double>>> magnitudeHistory) {
        for (Map.Entry<String, INDArray> entry : arrays.entrySet()) {
            String param = entry.getKey();
            INDArray values = entry.getValue();
            String newName = Character.isDigit(param.charAt(0)) ? "param_" + param : param;

            // The histogram builder receives a copy, as in the original code.
            HistogramBin histogram = new HistogramBin.Builder(values.dup()).setBinCount(HISTOGRAM_BIN_COUNT).setRounding(HISTOGRAM_ROUNDING).build();
            histograms.put(newName, histogram.getData());

            int idx = this.indexFromString(param);
            // Grow until idx is valid; adding a single slot is not enough when
            // the index is more than one past the current end.
            ensureSlot(magnitudeHistory, idx);
            Map<String, List<Double>> perLayer = magnitudeHistory.get(idx);
            List<Double> series = perLayer.get(newName);
            if (series == null) {
                series = new ArrayList<Double>();
                perLayer.put(newName, series);
            }
            double meanMag = values.norm1Number().doubleValue() / (double)values.length();
            series.add(meanMag);
        }
    }

    /** Appends empty maps to {@code history} until {@code index} is a valid position. */
    private static void ensureSlot(List<Map<String, List<Double>>> history, int index) {
        while (history.size() <= index) {
            history.add(new LinkedHashMap<String, List<Double>>());
        }
    }

    /**
     * Returns the layer index for a parameter name, registering the layer on
     * first sight. The layer name is the text before the first underscore, or
     * the whole string when there is none.
     */
    private int indexFromString(String str) {
        int underscore = str.indexOf("_");
        String layerName = underscore == -1 ? str : str.substring(0, underscore);
        Integer known = this.layerNameIndexes.get(layerName);
        if (known != null) {
            return known;
        }
        int assigned = this.layerNameIndexesCount++;
        this.layerNames.add(layerName);
        this.layerNameIndexes.put(layerName, assigned);
        return assigned;
    }
}

