/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.parallelism.trainer;

import lombok.NonNull;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.listeners.SharedGradient;
import org.deeplearning4j.optimize.solvers.accumulation.GradientsAccumulator;
import org.deeplearning4j.parallelism.ParallelWrapper;
import org.deeplearning4j.parallelism.trainer.CommunicativeTrainer;
import org.deeplearning4j.parallelism.trainer.DefaultTrainer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Trainer that hands the {@link ParallelWrapper}'s {@link GradientsAccumulator} down to its
 * replicated model instead of relying on parameter averaging.
 *
 * <p>{@link #averagingRequired()} always returns {@code false}. After initialization, the
 * accumulator obtained from the wrapper is attached to the replicated model when that model is a
 * {@link ComputationGraph} or a {@link MultiLayerNetwork}.
 */
public class SymmetricTrainer
extends DefaultTrainer
implements CommunicativeTrainer {
    private static final Logger log = LoggerFactory.getLogger(SymmetricTrainer.class);

    /** Accumulator fetched from the wrapper at construction time; may be {@code null}. */
    protected GradientsAccumulator accumulator;

    /**
     * Creates a trainer bound to a single worker thread.
     *
     * @param originalModel model to be trained; must not be {@code null}
     * @param uuid          base identifier; the thread index is appended to form this trainer's id
     * @param threadIdx     index of the worker thread this trainer belongs to
     * @param mode          workspace mode; must not be {@code null}
     * @param wrapper       owning wrapper, used as the source of the gradients accumulator; must not
     *                      be {@code null}
     * @param useMDS        stored as-is into the inherited {@code useMDS} flag
     * @throws NullPointerException if {@code originalModel}, {@code mode} or {@code wrapper} is null
     */
    public SymmetricTrainer(@NonNull Model originalModel, String uuid, int threadIdx, @NonNull WorkspaceMode mode, @NonNull ParallelWrapper wrapper, boolean useMDS) {
        // Validate every required argument before touching any state (same order as the
        // Lombok-generated checks).
        java.util.Objects.requireNonNull(originalModel, "originalModel is marked non-null but is null");
        java.util.Objects.requireNonNull(mode, "mode is marked non-null but is null");
        java.util.Objects.requireNonNull(wrapper, "wrapper is marked non-null but is null");

        this.originalModel = originalModel;
        this.workspaceMode = mode;
        this.parallelWrapper = wrapper;
        this.threadId = threadIdx;
        this.useMDS = useMDS;
        this.uuid = uuid + "_thread_" + threadIdx;
        this.accumulator = wrapper.getGradientsAccumulator();
    }

    /**
     * Intentionally ignores the supplied gradient.
     *
     * @param gradient ignored
     * @deprecated this trainer does nothing with enqueued gradients; the body is empty.
     */
    @Override
    @Deprecated
    public void enqueueGradient(SharedGradient gradient) {
        // no-op by design
    }

    /**
     * Reports whether parameter averaging is needed for this trainer.
     *
     * @return always {@code false}
     */
    @Override
    public boolean averagingRequired() {
        return false;
    }

    /**
     * Runs the parent initialization, then wires the gradients accumulator into the replicated
     * model.
     *
     * <p>If no accumulator is available, a warning is logged and nothing else happens. Otherwise
     * the accumulator is set on the model when it is a {@link ComputationGraph} or a
     * {@link MultiLayerNetwork}, and {@code touch()} is called on the accumulator in either case.
     */
    @Override
    protected void postInit() {
        super.postInit();
        if (this.accumulator != null) {
            if (this.replicatedModel instanceof ComputationGraph) {
                ComputationGraph graph = (ComputationGraph) this.replicatedModel;
                graph.setGradientsAccumulator(this.accumulator);
            } else if (this.replicatedModel instanceof MultiLayerNetwork) {
                MultiLayerNetwork network = (MultiLayerNetwork) this.replicatedModel;
                network.setGradientsAccumulator(this.accumulator);
            }
            // NOTE(review): presumably binds the accumulator to the calling thread/device;
            // semantics live in GradientsAccumulator.touch() — confirm there.
            this.accumulator.touch();
        } else {
            log.warn("GradientsAccumulator is undefined, gradients sharing will be skipped");
        }
    }
}

