/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.parameterserver.networking.v2;

import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import org.deeplearning4j.exception.DL4JInvalidConfigException;
import org.deeplearning4j.optimize.api.StepFunction;
import org.deeplearning4j.optimize.solvers.accumulation.GradientsAccumulator;
import org.deeplearning4j.optimize.solvers.accumulation.IndexedTail;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.parameterserver.distributed.v2.transport.UpdatesHandler;
import org.reactivestreams.Subscription;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link UpdatesHandler} that receives encoded update arrays and routes them in one of two ways.
 *
 * <ul>
 *   <li>If a {@link GradientsAccumulator} is configured, each incoming array is appended to an
 *       {@link IndexedTail} buffer. The buffer is created lazily, sized by {@code numWorkers} and
 *       the shape of {@code params}.</li>
 *   <li>Otherwise, if {@code params} and a {@link StepFunction} are configured, each array is
 *       decoded into {@code updates}, and the step function is applied every
 *       {@value #FLUSH_INTERVAL} messages (or on an explicit {@link #flush()}).</li>
 * </ul>
 *
 * <p>While bypass mode is enabled, incoming arrays are ignored.
 */
public class UpdatesConsumer
implements UpdatesHandler {
    private static final Logger log = LoggerFactory.getLogger(UpdatesConsumer.class);

    /** Number of decoded messages after which {@link #onNext(INDArray)} triggers a {@link #flush()}. */
    private static final long FLUSH_INTERVAL = 32L;

    /** Encoding id, read from the array's data buffer, that selects threshold decoding. */
    private static final int ENCODING_THRESHOLD = 0;

    /** Encoding id, read from the array's data buffer, that selects bitmap decoding. */
    private static final int ENCODING_BITMAP = 1;

    protected int numWorkers;
    protected transient INDArray params;
    protected transient INDArray updates;
    protected transient StepFunction stepFunction;
    protected transient GradientsAccumulator accumulator;
    protected final transient AtomicLong updatesCount = new AtomicLong(0L);
    protected final transient AtomicBoolean hasSomething = new AtomicBoolean(false);
    protected final transient AtomicBoolean bypassMode = new AtomicBoolean(false);
    protected final transient AtomicLong denseCounter = new AtomicLong(0L);
    protected final transient AtomicLong sparseCounter = new AtomicLong(0L);
    // volatile: the field is lazily initialised with double-checked locking and read without
    // holding the lock, so it must be volatile for the published IndexedTail to be safely visible.
    protected transient volatile IndexedTail updatesBuffer;

    /** No-op: this consumer does not use the subscription. */
    public void onSubscribe(Subscription subscription) {
    }

    /**
     * Enables or disables bypass mode.
     *
     * @param reallyBypass {@code true} to make {@link #onNext(INDArray)} ignore incoming arrays
     */
    public void bypassMode(boolean reallyBypass) {
        this.bypassMode.set(reallyBypass);
    }

    /**
     * Reports whether bypass mode is enabled. The misspelled name is kept for compatibility.
     *
     * @return {@code true} if incoming arrays are currently ignored
     */
    public boolean isBypassMod() {
        return this.bypassMode.get();
    }

    /**
     * Returns the updates buffer, creating it on first use if an accumulator is configured.
     *
     * @return the buffer, or {@code null} if no accumulator is configured and no buffer was
     *     supplied at construction
     */
    public IndexedTail getUpdatesQueue() {
        return ensureUpdatesBuffer();
    }

    /**
     * Handles one incoming encoded update array.
     *
     * <p>Does nothing while bypass mode is on. If an accumulator is configured, the array is
     * appended to the updates buffer. Otherwise the array is decoded into {@code updates} using the
     * encoding id stored at element 3 of its data buffer, and a {@link #flush()} is triggered every
     * {@value #FLUSH_INTERVAL} decoded messages.
     *
     * @param array the encoded update message
     * @throws RuntimeException wrapping any exception raised while appending to the buffer
     * @throws DL4JInvalidConfigException if the encoding id is neither threshold nor bitmap
     * @throws ND4JIllegalStateException if neither an accumulator nor params and a step function
     *     are configured
     */
    public void onNext(INDArray array) {
        // Initialise eagerly, before the bypass check, exactly as before.
        IndexedTail buffer = ensureUpdatesBuffer();
        if (this.bypassMode.get()) {
            return;
        }
        if (this.accumulator != null) {
            try {
                buffer.put(array);
            }
            catch (Exception e) {
                log.error("", (Throwable)e);
                throw new RuntimeException(e);
            }
        } else if (this.params != null && this.stepFunction != null) {
            synchronized (this) {
                int encoding = array.data().getInt(3L);
                if (encoding == ENCODING_THRESHOLD) {
                    Nd4j.getExecutioner().thresholdDecode(array, this.updates);
                    this.sparseCounter.incrementAndGet();
                } else if (encoding == ENCODING_BITMAP) {
                    Nd4j.getExecutioner().bitmapDecode(array, this.updates);
                    this.denseCounter.incrementAndGet();
                } else {
                    throw new DL4JInvalidConfigException("Unknown compression header received: " + encoding);
                }
                this.hasSomething.set(true);
                if (this.updatesCount.incrementAndGet() % FLUSH_INTERVAL == 0L) {
                    this.flush();
                }
            }
        } else {
            throw new ND4JIllegalStateException("Accumulator & StepFunction is null at the same time");
        }
    }

    /**
     * Lazily creates {@link #updatesBuffer} when an accumulator is configured, using double-checked
     * locking on {@code this} and the volatile field.
     *
     * @return the current buffer, which may be {@code null} if no accumulator is configured
     */
    private IndexedTail ensureUpdatesBuffer() {
        IndexedTail buffer = this.updatesBuffer;
        if (buffer == null && this.accumulator != null) {
            synchronized (this) {
                buffer = this.updatesBuffer;
                if (buffer == null) {
                    buffer = new IndexedTail(this.numWorkers, true, this.params.shape());
                    this.updatesBuffer = buffer;
                }
            }
        }
        return buffer;
    }

    /**
     * Applies any pending decoded updates to {@code params} with the step function.
     *
     * <p>After the step, the executioner is committed and {@code updates} is reset via the memory
     * manager's {@code memset} (presumably zero-filled; confirm against ND4J). Does nothing if
     * {@code params} or {@code updates} is {@code null}, or if nothing has been decoded since the
     * last flush.
     */
    public synchronized void flush() {
        if (this.params != null && this.updates != null && this.hasSomething.get()) {
            this.stepFunction.step(this.params, this.updates);
            Nd4j.getExecutioner().commit();
            if (log.isDebugEnabled()) {
                // Read each counter once so the ratio and the printed counts agree.
                long sparse = this.sparseCounter.get();
                long dense = this.denseCounter.get();
                log.debug("Applying updates. Current ratio: [{}]; Sparse: [{}]; Dense: [{}];", (double)sparse / (double)dense, sparse, dense);
            }
            Nd4j.getMemoryManager().memset(this.updates);
            this.hasSomething.set(false);
        }
    }

    /**
     * Rethrows the upstream error.
     *
     * @throws RuntimeException always, wrapping {@code throwable}
     */
    public void onError(Throwable throwable) {
        throw new RuntimeException(throwable);
    }

    /** No-op. */
    public void onComplete() {
    }

    /**
     * Returns a copy of the current parameters, taken while holding this consumer's lock.
     *
     * @return a duplicate of {@code params} with the same ordering
     */
    public synchronized INDArray getParametersArray() {
        return this.params.dup(this.params.ordering());
    }

    /** @return a new builder for {@link UpdatesConsumer} */
    public static UpdatesConsumerBuilder builder() {
        return new UpdatesConsumerBuilder();
    }

    /**
     * Creates a consumer with all collaborators supplied explicitly.
     *
     * @param numWorkers number of workers, used when lazily creating the updates buffer
     * @param params parameters that decoded updates are applied to
     * @param updates target array for decoded updates
     * @param stepFunction function used to apply {@code updates} to {@code params}
     * @param accumulator if non-null, incoming arrays are buffered instead of decoded
     * @param updatesBuffer pre-created buffer, or {@code null} to create it lazily
     */
    public UpdatesConsumer(int numWorkers, INDArray params, INDArray updates, StepFunction stepFunction, GradientsAccumulator accumulator, IndexedTail updatesBuffer) {
        this.numWorkers = numWorkers;
        this.params = params;
        this.updates = updates;
        this.stepFunction = stepFunction;
        this.accumulator = accumulator;
        this.updatesBuffer = updatesBuffer;
    }

    /** Creates an unconfigured consumer. All collaborators are {@code null}. */
    public UpdatesConsumer() {
    }

    /** Fluent builder for {@link UpdatesConsumer}. Every unset field defaults to zero or {@code null}. */
    public static class UpdatesConsumerBuilder {
        private int numWorkers;
        private INDArray params;
        private INDArray updates;
        private StepFunction stepFunction;
        private GradientsAccumulator accumulator;
        private IndexedTail updatesBuffer;

        UpdatesConsumerBuilder() {
        }

        public UpdatesConsumerBuilder numWorkers(int numWorkers) {
            this.numWorkers = numWorkers;
            return this;
        }

        public UpdatesConsumerBuilder params(INDArray params) {
            this.params = params;
            return this;
        }

        public UpdatesConsumerBuilder updates(INDArray updates) {
            this.updates = updates;
            return this;
        }

        public UpdatesConsumerBuilder stepFunction(StepFunction stepFunction) {
            this.stepFunction = stepFunction;
            return this;
        }

        public UpdatesConsumerBuilder accumulator(GradientsAccumulator accumulator) {
            this.accumulator = accumulator;
            return this;
        }

        public UpdatesConsumerBuilder updatesBuffer(IndexedTail updatesBuffer) {
            this.updatesBuffer = updatesBuffer;
            return this;
        }

        /** @return a new consumer built from the values set on this builder */
        public UpdatesConsumer build() {
            return new UpdatesConsumer(this.numWorkers, this.params, this.updates, this.stepFunction, this.accumulator, this.updatesBuffer);
        }

        public String toString() {
            return "UpdatesConsumer.UpdatesConsumerBuilder(numWorkers=" + this.numWorkers + ", params=" + this.params + ", updates=" + this.updates + ", stepFunction=" + this.stepFunction + ", accumulator=" + this.accumulator + ", updatesBuffer=" + this.updatesBuffer + ")";
        }
    }
}

