/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.parameterserver.accumulation;

import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import org.apache.spark.api.java.function.Function2;
import org.deeplearning4j.core.storage.Persistable;
import org.deeplearning4j.core.storage.StorageMetaData;
import org.deeplearning4j.optimize.solvers.accumulation.encoding.ThresholdAlgorithmReducer;
import org.deeplearning4j.spark.api.stats.SparkTrainingStats;
import org.deeplearning4j.spark.parameterserver.accumulation.SharedTrainingAccumulationTuple;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Spark {@link Function2} that folds two {@link SharedTrainingAccumulationTuple}
 * instances into a single, newly built tuple.
 *
 * <p>Each field is combined independently: scores and aggregation counts are
 * summed, updater state arrays are added together, training stats and listener
 * collections from the second tuple are appended to those of the first, the
 * per-key minibatch counts are summed, and threshold algorithm reducers are
 * merged.
 *
 * <p>Note that several merges reuse (and mutate) objects owned by the first
 * tuple: its updater state array, its training stats and its listener
 * collections are updated in place and then handed to the result.
 */
public class SharedTrainingAccumulationFunction
implements Function2<SharedTrainingAccumulationTuple, SharedTrainingAccumulationTuple, SharedTrainingAccumulationTuple> {

    /**
     * Combines two accumulation tuples.
     *
     * @param tuple1 first tuple, may be {@code null}
     * @param tuple2 second tuple, may be {@code null}
     * @return {@code tuple2} when {@code tuple1} is {@code null}, {@code tuple1}
     *         when {@code tuple2} is {@code null}, otherwise a new tuple holding
     *         the combined values of both
     * @throws Exception declared by the functional interface signature
     */
    public SharedTrainingAccumulationTuple call(SharedTrainingAccumulationTuple tuple1, SharedTrainingAccumulationTuple tuple2) throws Exception {
        // Nothing to combine when either side is missing: hand back the other one as-is.
        if (tuple1 == null) {
            return tuple2;
        }
        if (tuple2 == null) {
            return tuple1;
        }

        // Updater state: add both arrays when present, otherwise keep whichever exists (or null).
        INDArray firstState = tuple1.getUpdaterStateArray();
        INDArray secondState = tuple2.getUpdaterStateArray();
        INDArray combinedState;
        if (firstState == null) {
            combinedState = secondState;
        } else if (secondState == null) {
            combinedState = firstState;
        } else {
            combinedState = firstState.addi(secondState);
        }

        int totalAggregations = tuple1.getAggregationsCount() + tuple2.getAggregationsCount();
        double totalScore = tuple1.getScoreSum() + tuple2.getScoreSum();

        // Training stats: the first tuple's stats object absorbs the second's when both exist.
        SparkTrainingStats combinedStats = tuple1.getSparkTrainingStats();
        SparkTrainingStats secondStats = tuple2.getSparkTrainingStats();
        if (combinedStats == null) {
            combinedStats = secondStats;
        } else if (secondStats != null) {
            combinedStats.addOtherTrainingStats(secondStats);
        }

        // Commit on the ND4J executioner after the array addition above and before building the result.
        Nd4j.getExecutioner().commit();

        Collection<StorageMetaData> combinedMetaData =
                appendAll(tuple1.getListenerMetaData(), tuple2.getListenerMetaData());
        Collection<Persistable> combinedStaticInfo =
                appendAll(tuple1.getListenerStaticInfo(), tuple2.getListenerStaticInfo());
        Collection<Persistable> combinedUpdates =
                appendAll(tuple1.getListenerUpdates(), tuple2.getListenerUpdates());

        HashMap<String, Integer> combinedMinibatches =
                sumMinibatchCounts(tuple1.getMinibatchesPerExecutor(), tuple2.getMinibatchesPerExecutor());

        // Threshold reducer: merge when both exist, otherwise keep whichever exists (or null).
        ThresholdAlgorithmReducer combinedReducer = tuple1.getThresholdAlgorithmReducer();
        ThresholdAlgorithmReducer secondReducer = tuple2.getThresholdAlgorithmReducer();
        if (combinedReducer == null) {
            combinedReducer = secondReducer;
        } else if (secondReducer != null) {
            combinedReducer = combinedReducer.merge(secondReducer);
        }

        return SharedTrainingAccumulationTuple.builder()
                .scoreSum(totalScore)
                .updaterStateArray(combinedState)
                .aggregationsCount(totalAggregations)
                .sparkTrainingStats(combinedStats)
                .listenerMetaData(combinedMetaData)
                .listenerUpdates(combinedUpdates)
                .listenerStaticInfo(combinedStaticInfo)
                .minibatchesPerExecutor(combinedMinibatches)
                .thresholdAlgorithmReducer(combinedReducer)
                .build();
    }

    /**
     * Appends {@code extra} to {@code base} in place and returns {@code base};
     * when {@code base} is {@code null}, {@code extra} is returned unchanged
     * (which may itself be {@code null}).
     */
    private static <T> Collection<T> appendAll(Collection<T> base, Collection<T> extra) {
        if (base == null) {
            return extra;
        }
        if (extra != null) {
            base.addAll(extra);
        }
        return base;
    }

    /**
     * Builds a fresh map containing every entry of {@code first}, then adds the
     * entries of {@code second}, summing the values of keys present in both.
     * Either argument may be {@code null}, in which case it contributes nothing.
     */
    private static HashMap<String, Integer> sumMinibatchCounts(Map<String, Integer> first, Map<String, Integer> second) {
        HashMap<String, Integer> totals = new HashMap<String, Integer>();
        if (first != null) {
            totals.putAll(first);
        }
        if (second != null) {
            for (Map.Entry<String, Integer> entry : second.entrySet()) {
                String key = entry.getKey();
                Integer count = entry.getValue();
                if (totals.containsKey(key)) {
                    // Key already seen in the first map: accumulate instead of overwriting.
                    count = totals.get(key) + count;
                }
                totals.put(key, count);
            }
        }
        return totals;
    }
}

