/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.scaleout.perform.models.glove;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.deeplearning4j.nn.conf.Configuration;
import org.deeplearning4j.scaleout.aggregator.JobAggregator;
import org.deeplearning4j.scaleout.job.Job;
import org.deeplearning4j.scaleout.perform.models.glove.GloveResult;
import org.deeplearning4j.util.MultiDimensionalMap;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class GloveJobAggregator
implements JobAggregator {
    private List<GloveResult> work = new ArrayList<GloveResult>();

    public void accumulate(Job job) {
        if (job.getResult() instanceof GloveResult) {
            GloveResult work = (GloveResult)job.getResult();
            this.work.add(work);
        } else if (job.getResult() instanceof Collection) {
            Collection coll = (Collection)((Object)job.getResult());
            this.work.addAll(coll);
        }
    }

    public Job aggregate() {
        Job ret = new Job((Serializable)((Object)""), "");
        GloveResult aggregateResult = new GloveResult();
        MultiDimensionalMap workResults = MultiDimensionalMap.newHashBackedMap();
        HashSet<String> vocab = new HashSet<String>();
        for (GloveResult r : this.work) {
            for (String syn0Key : r.getSyn0Change().keySet()) {
                List<INDArray> syn0List = this.getOrPutIfNotExists((MultiDimensionalMap<String, String, List<INDArray>>)workResults, syn0Key, "syn0");
                syn0List.add(r.getSyn0Change().get(syn0Key));
                vocab.add(syn0Key);
            }
        }
        for (String key : vocab) {
            aggregateResult.getSyn0Change().put(key, this.average((List)workResults.get((Object)key, (Object)"syn0")));
        }
        ret.setResult((Serializable)((Object)Arrays.asList(aggregateResult)));
        return ret;
    }

    private INDArray average(List<INDArray> list) {
        if (list == null || list.isEmpty()) {
            throw new IllegalArgumentException("Can't average empty or null list");
        }
        if (list.get(0) == null) {
            return null;
        }
        INDArray ret = Nd4j.create((int[])list.get(0).shape());
        for (INDArray arr : list) {
            ret.addi(arr);
        }
        if (list.size() > 1) {
            return ret.divi((Number)list.size());
        }
        return ret;
    }

    private List<INDArray> getOrPutIfNotExists(MultiDimensionalMap<String, String, List<INDArray>> workResults, String key, String otherKey) {
        ArrayList syn0List = (ArrayList)workResults.get((Object)key, (Object)otherKey);
        if (syn0List == null) {
            syn0List = new ArrayList();
            workResults.put((Object)key, (Object)otherKey, syn0List);
        }
        return syn0List;
    }

    public void init(Configuration conf) {
    }
}

