/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.parameterserver.networking.v1;

import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.deeplearning4j.exception.DL4JInvalidConfigException;
import org.deeplearning4j.optimize.api.StepFunction;
import org.deeplearning4j.optimize.solvers.accumulation.GradientsAccumulator;
import org.deeplearning4j.optimize.solvers.accumulation.IndexedTail;
import org.deeplearning4j.spark.parameterserver.networking.v1.messages.SilentUpdatesMessage;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.parameterserver.distributed.conf.VoidConfiguration;
import org.nd4j.parameterserver.distributed.logic.Storage;
import org.nd4j.parameterserver.distributed.logic.completion.Clipboard;
import org.nd4j.parameterserver.distributed.messages.VoidAggregation;
import org.nd4j.parameterserver.distributed.messages.VoidMessage;
import org.nd4j.parameterserver.distributed.training.TrainingDriver;
import org.nd4j.parameterserver.distributed.transport.Transport;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link TrainingDriver} for {@link SilentUpdatesMessage}s that runs in one of two modes,
 * selected by the constructor used:
 * <ul>
 *   <li><b>worker</b> ({@link #SilentTrainingDriver(GradientsAccumulator)}): incoming encoded
 *       updates are queued into an {@link IndexedTail} that is registered as the accumulator's
 *       external source;</li>
 *   <li><b>master</b> ({@link #SilentTrainingDriver(INDArray, StepFunction)}): incoming encoded
 *       updates are decoded into a local buffer, periodically applied to the parameters through
 *       the step function, and the message is relayed to the other clients.</li>
 * </ul>
 *
 * <p>On the master side, all access to {@code params} and {@code updates} is guarded by the
 * intrinsic lock of this instance.
 */
@Deprecated
public class SilentTrainingDriver
implements TrainingDriver<SilentUpdatesMessage> {
    private static final Logger log = LoggerFactory.getLogger(SilentTrainingDriver.class);

    /** Header value (read from index 3 of the update's data buffer) handled by threshold decoding. */
    private static final int ENCODING_THRESHOLD = 0;
    /** Header value (read from index 3 of the update's data buffer) handled by bitmap decoding. */
    private static final int ENCODING_BITMAP = 1;
    /** Lower bound for the number of received updates between two applications of the step function. */
    private static final int MIN_UPDATES_PER_STEP = 5;

    /** Master mode only: parameters the step function is applied to. */
    protected transient INDArray params;
    /** Master mode only: buffer that decoded updates are written into between steps. */
    protected transient INDArray updates;
    /** Master mode only: function applying {@link #updates} to {@link #params}. */
    protected transient StepFunction stepFunction;
    /** Worker mode only: accumulator fed through {@link #updatesBuffer}. */
    protected transient GradientsAccumulator accumulator;
    protected transient VoidConfiguration voidConfiguration;
    protected transient Transport transport;
    /** Number of update messages processed on the master path. */
    protected transient AtomicLong updatesCount;
    /** Master mode only: {@code true} while {@link #updates} holds decoded data not yet applied. */
    protected transient AtomicBoolean hasSomething;
    /** When set, the worker path discards incoming updates instead of queueing them. */
    protected transient AtomicBoolean bypassMode = new AtomicBoolean(false);
    /** Number of bitmap-encoded messages decoded on the master path. */
    protected transient AtomicLong denseCounter = new AtomicLong(0L);
    /** Number of threshold-encoded messages decoded on the master path. */
    protected transient AtomicLong sparseCounter = new AtomicLong(0L);
    /** Worker mode only: queue of received updates; {@code null} in master mode. */
    protected transient IndexedTail updatesBuffer;
    protected transient Storage storage;
    protected transient Clipboard clipboard;

    /**
     * Creates a worker-side driver. A new single-slot {@link IndexedTail} is created and
     * registered as the external source of the given accumulator.
     *
     * @param accumulator accumulator that will consume received updates; must not be {@code null}
     * @throws NullPointerException if {@code accumulator} is {@code null}
     */
    public SilentTrainingDriver(@NonNull GradientsAccumulator accumulator) {
        if (accumulator == null) {
            throw new NullPointerException("accumulator is marked non-null but is null");
        }
        log.info("Creating TrainingDriver for worker...");
        this.accumulator = accumulator;
        this.updatesCount = new AtomicLong(0L);
        this.updatesBuffer = new IndexedTail(1);
        this.accumulator.setExternalSource(this.updatesBuffer);
    }

    /**
     * Creates a master-side driver. An updates buffer with the same shape and ordering as
     * {@code params} is allocated outside of any memory workspace.
     *
     * @param params       parameters to be updated; must not be {@code null}
     * @param stepFunction function used to apply accumulated updates; must not be {@code null}
     * @throws NullPointerException if either argument is {@code null}
     */
    public SilentTrainingDriver(@NonNull INDArray params, @NonNull StepFunction stepFunction) {
        if (params == null) {
            throw new NullPointerException("params is marked non-null but is null");
        }
        if (stepFunction == null) {
            throw new NullPointerException("stepFunction is marked non-null but is null");
        }
        log.info("Creating TrainingDriver for master...");
        log.info("Params at Master BEFORE: {}", (Object) params.meanNumber().doubleValue());
        this.params = params;
        this.stepFunction = stepFunction;
        this.updatesCount = new AtomicLong(0L);
        this.hasSomething = new AtomicBoolean(false);
        // The buffer is created while scoped out of workspaces; the scope is closed right after.
        try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            this.updates = Nd4j.create(params.shape(), params.ordering());
        }
    }

    /**
     * @return the queue of received updates, or {@code null} for a master-side driver
     */
    public IndexedTail getUpdatesBuffer() {
        return this.updatesBuffer;
    }

    /**
     * Stores the runtime collaborators of this driver.
     *
     * @param voidConfiguration configuration; must not be {@code null}
     * @param transport         transport used to identify this node and to relay messages;
     *                          must not be {@code null}
     * @param storage           stored as-is; not used by this class itself
     * @param clipboard         stored as-is; not used by this class itself
     * @throws NullPointerException if {@code voidConfiguration} or {@code transport} is {@code null}
     */
    public void init(@NonNull VoidConfiguration voidConfiguration, @NonNull Transport transport, Storage storage, Clipboard clipboard) {
        if (voidConfiguration == null) {
            throw new NullPointerException("voidConfiguration is marked non-null but is null");
        }
        if (transport == null) {
            throw new NullPointerException("transport is marked non-null but is null");
        }
        this.voidConfiguration = voidConfiguration;
        this.transport = transport;
        // Previously these two arguments were silently dropped, leaving the fields always null.
        this.storage = storage;
        this.clipboard = clipboard;
    }

    /**
     * Enables or disables bypass mode. While enabled, updates received on the worker path
     * are dropped instead of being queued.
     *
     * @param reallyBypass {@code true} to drop incoming updates on the worker path
     */
    public void bypassMode(boolean reallyBypass) {
        this.bypassMode.set(reallyBypass);
    }

    /**
     * Handles an incoming update message.
     *
     * <p>Worker mode: messages originating from this node are ignored; otherwise the encoded
     * update is queued unless bypass mode is enabled.
     *
     * <p>Master mode: the update is decoded into the local buffer, the step function is applied
     * once every {@code max(numberOfKnownClients, 5)} messages, and the message is relayed to
     * all clients except its originator and this node when more than one client is known.
     *
     * @param message the update message to process
     * @throws DL4JInvalidConfigException if the driver is configured for neither mode, or the
     *                                    message carries an unknown compression header
     * @throws RuntimeException           wrapping any exception raised while queueing the update
     */
    public void startTraining(SilentUpdatesMessage message) {
        if (this.accumulator != null) {
            this.queueOnWorker(message);
            return;
        }
        if (this.params == null || this.stepFunction == null) {
            throw new DL4JInvalidConfigException("Neither GradientsAccumulator or StepFunction is defined!");
        }
        this.applyOnMaster(message);
        if (this.transport.numberOfKnownClients() > 1) {
            this.transport.sendMessageToAllClients(message, new Long[]{message.getOriginatorId(), this.transport.getOwnOriginatorId()});
        }
    }

    /** Worker path: queues the update unless it is our own message or bypass mode is on. */
    private void queueOnWorker(SilentUpdatesMessage message) {
        if (message.getOriginatorId() == this.transport.getOwnOriginatorId()) {
            return;
        }
        try {
            if (!this.bypassMode.get()) {
                this.updatesBuffer.put(message.getUpdates());
            }
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /** Master path: decodes the update into the buffer and periodically applies the step function. */
    private void applyOnMaster(SilentUpdatesMessage message) {
        synchronized (this) {
            int encoding = message.getUpdates().data().getInt(3L);
            if (encoding == ENCODING_THRESHOLD) {
                Nd4j.getExecutioner().thresholdDecode(message.getUpdates(), this.updates);
                this.sparseCounter.incrementAndGet();
            } else if (encoding == ENCODING_BITMAP) {
                Nd4j.getExecutioner().bitmapDecode(message.getUpdates(), this.updates);
                this.denseCounter.incrementAndGet();
            } else {
                throw new DL4JInvalidConfigException("Unknown compression header received: " + encoding);
            }
            this.hasSomething.set(true);

            long stepInterval = Math.max(this.transport.numberOfKnownClients(), MIN_UPDATES_PER_STEP);
            if (this.updatesCount.incrementAndGet() % stepInterval == 0L) {
                this.stepFunction.step(this.params, this.updates);
                Nd4j.getMemoryManager().memset(this.updates);
                this.hasSomething.set(false);
            }
        }
    }

    /**
     * Not supported by this driver.
     *
     * @throws UnsupportedOperationException always
     */
    public void pickTraining(SilentUpdatesMessage message) {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported by this driver.
     *
     * @throws UnsupportedOperationException always
     */
    public void aggregationFinished(VoidAggregation aggregation) {
        throw new UnsupportedOperationException();
    }

    /**
     * On a master-side driver, applies any decoded-but-not-yet-applied updates to the
     * parameters and clears the buffer. Does nothing on a worker-side driver or when there
     * is nothing pending.
     *
     * @param originatorId not used by this implementation
     * @param taskId       not used by this implementation
     */
    public void finishTraining(long originatorId, long taskId) {
        if (this.params == null || this.stepFunction == null) {
            return;
        }
        // Same lock as the master path of startTraining, which mutates the same arrays.
        synchronized (this) {
            if (this.hasSomething.get()) {
                this.stepFunction.step(this.params, this.updates);
                this.updates.assign(0.0);
                // Buffer is empty again; without this a repeated call would step with zeros.
                this.hasSomething.set(false);
            }
        }
    }

    /**
     * Not supported by this driver.
     *
     * @throws UnsupportedOperationException always
     */
    public void addCompletionHook(long originatorId, long frameId, long messageId) {
        throw new UnsupportedOperationException();
    }

    /**
     * @return the simple class name of {@link SilentUpdatesMessage}
     */
    public String targetMessageClass() {
        return SilentUpdatesMessage.class.getSimpleName();
    }
}

