/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.optimize.solvers.accumulation;

import com.google.common.util.concurrent.AtomicDouble;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.deeplearning4j.optimize.solvers.accumulation.GradientsAccumulator;
import org.deeplearning4j.optimize.solvers.accumulation.MessageHandler;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.compression.NDArrayCompressor;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link MessageHandler} that encodes gradient updates and hands the encoded message to a
 * {@link GradientsAccumulator} via {@code receiveUpdate}.
 *
 * <p>Encoding state is kept per calling thread in the {@code ThreadLocal} fields below: whether the
 * thread is currently in bitmap mode, its current threshold, an iteration counter, and the
 * iteration at which its threshold was last lowered. A thread starts in bitmap mode at
 * {@link #threshold} and then switches between bitmap and threshold encoding depending on how
 * compact the previous encoding was.
 */
public class EncodingHandler implements MessageHandler {
    private static final Logger log = LoggerFactory.getLogger(EncodingHandler.class);

    /** Receiver of encoded messages; assigned in {@link #initialize(GradientsAccumulator)}. */
    protected transient GradientsAccumulator accumulator;
    /** Starting threshold for every thread's encoding state. */
    protected double threshold;
    /** The threshold is only lowered while the lowered value stays above this bound. */
    protected double minThreshold;
    /** Amount subtracted from a thread's current threshold on each step-down. */
    protected double thresholdStep;
    /**
     * A step-down is only considered when the threshold-encoded length, as a percentage of the
     * update length, is below this value.
     */
    protected double stepTrigger;
    /**
     * In threshold mode, every {@code shakeFrequency}-th iteration is bitmap-encoded at one third of
     * the current threshold instead; {@code 0} disables this.
     */
    protected int shakeFrequency;
    /** Number of iterations that must pass after a step-down before the next one is allowed. */
    protected int stepDelay;
    /**
     * Optional fraction of the update length; when non-null, the resulting element count is passed
     * as the {@code boundary} argument of {@code thresholdEncode} (presumably an upper limit on
     * encoded elements — confirm against ND4J).
     */
    protected Double boundary = null;
    /**
     * "THRESHOLD" compressor configured in {@link #initialize}; not referenced elsewhere in this class.
     */
    protected NDArrayCompressor compressor;
    /** Element-count form of {@link #boundary}, computed once from the first update seen; -1 = unset. */
    protected AtomicInteger atomicBoundary = new AtomicInteger(-1);
    /** Per-thread count of {@link #encodeUpdates} calls. */
    protected ThreadLocal<AtomicLong> iterations = new ThreadLocal<>();
    /** Per-thread iteration number at which the threshold was last lowered. */
    protected ThreadLocal<AtomicLong> lastStep = new ThreadLocal<>();
    /** Per-thread current encoding threshold. */
    protected ThreadLocal<AtomicDouble> currentThreshold = new ThreadLocal<>();
    /** Per-thread flag: {@code true} while the thread uses bitmap encoding. */
    protected ThreadLocal<AtomicBoolean> bitmapMode = new ThreadLocal<>();

    /** Creates a handler with threshold {@code 0.001} and no adaptive stepping, shaking or boundary. */
    public EncodingHandler() {
        this(0.001);
    }

    /**
     * Creates a handler with a fixed threshold and no boundary.
     *
     * @param threshold encoding threshold
     */
    public EncodingHandler(double threshold) {
        this(threshold, null);
    }

    /**
     * Creates a handler with a fixed threshold: {@code minThreshold == threshold} and a zero step
     * mean the threshold is never lowered, and {@code shakeFrequency == 0} disables shaking.
     *
     * @param threshold encoding threshold
     * @param boundary optional fraction of the update length passed to threshold encoding, or null
     */
    public EncodingHandler(double threshold, Double boundary) {
        this(threshold, threshold, 0.0, 0.0, 0, 0, boundary);
    }

    /**
     * Creates a handler with adaptive threshold stepping and no boundary.
     *
     * @param threshold starting threshold
     * @param minThreshold lower bound for step-downs
     * @param thresholdStep amount subtracted per step-down
     * @param stepTrigger encoded-length percentage below which a step-down is considered
     * @param stepDelay minimum iterations between step-downs
     * @param shakeFrequency period of bitmap "shake" iterations in threshold mode; 0 disables
     */
    public EncodingHandler(double threshold, double minThreshold, double thresholdStep, double stepTrigger, int stepDelay, int shakeFrequency) {
        this(threshold, minThreshold, thresholdStep, stepTrigger, stepDelay, shakeFrequency, null);
    }

    /**
     * Creates a fully configured handler.
     *
     * @param threshold starting threshold
     * @param minThreshold lower bound for step-downs
     * @param thresholdStep amount subtracted per step-down
     * @param stepTrigger encoded-length percentage below which a step-down is considered
     * @param stepDelay minimum iterations between step-downs
     * @param shakeFrequency period of bitmap "shake" iterations in threshold mode; 0 disables
     * @param boundary optional fraction of the update length passed to threshold encoding, or null
     */
    public EncodingHandler(double threshold, double minThreshold, double thresholdStep, double stepTrigger, int stepDelay, int shakeFrequency, Double boundary) {
        this.threshold = threshold;
        this.minThreshold = minThreshold;
        this.stepTrigger = stepTrigger;
        this.stepDelay = stepDelay;
        this.thresholdStep = thresholdStep;
        this.shakeFrequency = shakeFrequency;
        this.boundary = boundary;
    }

    /**
     * Binds this handler to its accumulator and configures the "THRESHOLD" compressor.
     *
     * @param accumulator receiver of encoded messages
     * @throws NullPointerException if {@code accumulator} is null
     * @throws ND4JIllegalStateException if no "THRESHOLD" compressor is registered
     */
    @Override
    public void initialize(@NonNull GradientsAccumulator accumulator) {
        if (accumulator == null) {
            throw new NullPointerException("accumulator");
        }
        this.accumulator = accumulator;
        this.compressor = Nd4j.getCompressor().getCompressor("THRESHOLD");
        if (this.compressor == null) {
            throw new ND4JIllegalStateException("Can't find Threshold compressor implementation!");
        }
        this.compressor.configure(new Object[]{this.threshold});
    }

    /**
     * Encodes {@code updates} using the calling thread's current encoding mode.
     *
     * <p>In bitmap mode the updates are bitmap-encoded at the current threshold, and the thread
     * switches to threshold mode when {@code bitmapEncode} returns a value below half the bitmap
     * buffer length. In threshold mode the updates are threshold-encoded, except on every
     * {@code shakeFrequency}-th iteration, which is bitmap-encoded at one third of the current
     * threshold. The thread switches back to bitmap mode when the threshold-encoded length exceeds
     * the bitmap buffer length, and may lower its threshold (see {@link #stepTrigger},
     * {@link #stepDelay}, {@link #thresholdStep}, {@link #minThreshold}).
     *
     * @param updates the updates to encode
     * @return the encoded array, or {@code null} if {@code thresholdEncode} returned {@code null}
     */
    public INDArray encodeUpdates(INDArray updates) {
        initThreadState();
        long iteration = this.iterations.get().incrementAndGet();
        if (this.boundary != null && this.atomicBoundary.get() < 0) {
            // Resolved once, from the first update seen by any thread.
            this.atomicBoundary.compareAndSet(-1, (int) ((double) updates.lengthLong() * this.boundary));
        }
        if (this.bitmapMode.get().get()) {
            return encodeInBitmapMode(updates);
        }
        if (this.shakeFrequency != 0 && iteration % (long) this.shakeFrequency == 0L) {
            // Periodic "shake": bitmap-encode at a lower threshold without changing mode.
            INDArray encoded = allocateBitmapTarget(updates);
            Nd4j.getExecutioner().bitmapEncode(updates, encoded, this.currentThreshold.get().get() / 3.0);
            return encoded;
        }
        return encodeInThresholdMode(updates, iteration);
    }

    /** Lazily creates the calling thread's encoding state (bitmap mode, starting threshold). */
    private void initThreadState() {
        if (this.bitmapMode.get() == null) {
            this.bitmapMode.set(new AtomicBoolean(true));
            this.currentThreshold.set(new AtomicDouble(this.threshold));
            this.iterations.set(new AtomicLong(0L));
            this.lastStep.set(new AtomicLong(0L));
        }
    }

    /** Bitmap-encodes {@code updates}; switches the thread to threshold mode if the result was small. */
    private INDArray encodeInBitmapMode(INDArray updates) {
        INDArray encoded = allocateBitmapTarget(updates);
        long values = Nd4j.getExecutioner().bitmapEncode(updates, encoded, this.currentThreshold.get().get());
        if (values < bitmapLength(updates) / 2L) {
            this.bitmapMode.get().set(false);
            log.info("Switched to threshold encoding");
        }
        return encoded;
    }

    /**
     * Threshold-encodes {@code updates}, switching to bitmap mode if the encoded length exceeds the
     * bitmap buffer length, and possibly lowering the thread's threshold.
     *
     * @return the encoded array, or {@code null} if {@code thresholdEncode} returned {@code null}
     */
    private INDArray encodeInThresholdMode(INDArray updates, long iteration) {
        INDArray encoded = Nd4j.getExecutioner().thresholdEncode(updates, this.currentThreshold.get().get(),
                this.boundary == null ? null : Integer.valueOf(this.atomicBoundary.get()));
        if (encoded == null) {
            return null;
        }
        // The first int of the encoded buffer is read as the encoded length.
        double encLen = encoded.data().getInt(0L);
        double encodingRatio = encLen * 100.0 / (double) updates.lengthLong();
        if (encLen > (double) bitmapLength(updates)) {
            this.bitmapMode.get().set(true);
        }
        maybeStepDownThreshold(iteration, encodingRatio);
        return encoded;
    }

    /**
     * Lowers the thread's threshold by {@link #thresholdStep} when the result stays above
     * {@link #minThreshold}, more than {@link #stepDelay} iterations have passed since the last
     * step-down, and {@code encodingRatio} is below {@link #stepTrigger}.
     */
    private void maybeStepDownThreshold(long iteration, double encodingRatio) {
        AtomicDouble current = this.currentThreshold.get();
        AtomicLong last = this.lastStep.get();
        if (this.minThreshold <= current.get()
                && this.minThreshold < current.get() - this.thresholdStep
                && iteration > last.get() + (long) this.stepDelay
                && encodingRatio < this.stepTrigger) {
            current.addAndGet(-this.thresholdStep);
            // Copy the value: installing the iterations counter itself here would alias the two
            // counters, and the stepDelay check above could never pass again.
            last.set(iteration);
            log.info("Threshold steps down to {}", current.get());
        }
    }

    /** Creates an int-typed target array of {@link #bitmapLength} elements with {@code updates}' shape info. */
    private static INDArray allocateBitmapTarget(INDArray updates) {
        DataBuffer buffer = Nd4j.getDataBufferFactory().createInt(bitmapLength(updates));
        return Nd4j.createArrayFromShapeBuffer(buffer, updates.shapeInfoDataBuffer());
    }

    /** Number of int elements allocated for a bitmap-encoded form of {@code updates}. */
    private static long bitmapLength(INDArray updates) {
        return updates.lengthLong() / 16L + 5L;
    }

    /**
     * Not supported by this handler.
     *
     * @deprecated always throws {@link UnsupportedOperationException}
     */
    @Deprecated
    public INDArray decodeUpdates(INDArray message) {
        throw new UnsupportedOperationException();
    }

    /**
     * Passes an encoded message to the accumulator.
     *
     * @param message encoded updates
     */
    protected void sendMessage(INDArray message) {
        this.accumulator.receiveUpdate(message);
    }

    /**
     * Encodes {@code updates} and, if encoding produced a message, sends it to the accumulator.
     *
     * @param updates the updates to broadcast
     * @return {@code true} if a message was sent, {@code false} if encoding returned {@code null}
     */
    @Override
    public boolean broadcastUpdates(INDArray updates) {
        INDArray message = this.encodeUpdates(updates);
        if (message != null) {
            this.sendMessage(message);
            return true;
        }
        return false;
    }
}

