/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.optimize.solvers.accumulation;

import java.util.List;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import lombok.NonNull;
import org.deeplearning4j.optimize.api.StepFunction;
import org.deeplearning4j.optimize.solvers.accumulation.GradientsAccumulator;
import org.deeplearning4j.optimize.solvers.accumulation.IndexedTail;
import org.deeplearning4j.optimize.solvers.accumulation.LocalHandler;
import org.deeplearning4j.optimize.solvers.accumulation.MessageHandler;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Gradients accumulator that synchronizes a fixed number of worker threads ({@code parties})
 * with a shared {@link CyclicBarrier}.
 *
 * <p>{@link #storeUpdate} collects one update per worker. A single "leader" thread sums them
 * into {@code storage} and hands the result to the {@link MessageHandler}. Updates arriving via
 * {@link #receiveUpdate} are summed into {@code updates}. {@link #applyUpdate} lets every worker
 * apply that shared sum to its parameters before the leader zeroes it.
 *
 * <p>The leader of each round is the first thread to CAS its id into {@code firstOne}. Every
 * round is bracketed by two barrier waits, so all {@code parties} threads must call the same
 * method each round.
 */
public class BasicGradientsAccumulator implements GradientsAccumulator {
    private static final Logger log = LoggerFactory.getLogger(BasicGradientsAccumulator.class);

    /** Transport used to broadcast accumulated updates; initialized with this accumulator. */
    protected MessageHandler handler;
    /** External gradients source, replaceable via {@link #setExternalSource}. */
    protected transient IndexedTail gradients;
    /** Sum of the per-worker candidates of the current {@link #storeUpdate} round (lazily created). */
    protected transient INDArray storage;
    /** Sum of updates received through {@link #receiveUpdate} (lazily created). */
    protected transient INDArray updates;
    /** Number of successful broadcasts made by this accumulator. */
    protected transient AtomicLong ownCounter = new AtomicLong(0L);
    /** Number of updates received through {@link #receiveUpdate}. */
    protected transient AtomicLong extCounter = new AtomicLong(0L);
    protected long[] shape;
    protected char ordering;
    protected int parties = 0;
    protected CyclicBarrier barrier;
    /** Thread id of the current round's leader, or -1 when no round is in progress. */
    protected AtomicLong firstOne = new AtomicLong(-1L);
    /** Per-worker updates submitted in the current {@link #storeUpdate} round. */
    protected List<INDArray> candidates = new CopyOnWriteArrayList<INDArray>();
    /** Guards {@link #updates}: workers read under the read lock, receivers write under the write lock. */
    protected ReentrantReadWriteLock updatesLock = new ReentrantReadWriteLock();
    /** True when {@link #updates} holds data that has not yet been applied. */
    protected AtomicBoolean hasSomething = new AtomicBoolean(false);

    /**
     * Creates an accumulator for {@code parties} workers that uses a {@link LocalHandler}.
     *
     * @param parties number of worker threads that will call this accumulator each round
     */
    public BasicGradientsAccumulator(int parties) {
        this(parties, new LocalHandler());
    }

    /**
     * Creates an accumulator for {@code parties} workers.
     *
     * @param parties number of worker threads that will call this accumulator each round
     * @param handler transport for accumulated updates; must not be null
     * @throws NullPointerException if {@code handler} is null
     */
    public BasicGradientsAccumulator(int parties, @NonNull MessageHandler handler) {
        if (handler == null) {
            throw new NullPointerException("handler is marked non-null but is null");
        }
        this.gradients = new IndexedTail(parties);
        this.parties = parties;
        this.barrier = new CyclicBarrier(parties);
        this.handler = handler;
        // Publish `this` to the handler only after every field it might touch is assigned.
        this.handler.initialize(this);
    }

    @Override
    public IndexedTail getExternalSource() {
        return gradients;
    }

    /**
     * Applies the accumulated external updates to {@code params} using {@code function}.
     * Every party must call this; {@code grad} and {@code isFinalStep} are not used here.
     */
    @Override
    public void applyUpdate(StepFunction function, INDArray params, INDArray grad, boolean isFinalStep) {
        applySharedUpdates(function, params, false, 0.0);
    }

    /** No-op for this implementation. */
    @Override
    public void markExternalUpdates(boolean updatesAvailable) {
    }

    /**
     * Applies the accumulated external updates to {@code params} scaled by {@code alpha}.
     * Every party must call this; {@code grad} is not used here.
     */
    @Override
    public void applyUpdate(StepFunction function, INDArray params, INDArray grad, double alpha) {
        applySharedUpdates(function, params, true, alpha);
    }

    /**
     * Shared implementation of both {@code applyUpdate} overloads.
     *
     * <p>Each party steps its own {@code params} with the shared {@code updates} while holding
     * the read lock, so no receiver can mutate the sum mid-round. After the first barrier, the
     * round leader zeroes {@code updates}. The second barrier keeps any party from starting the
     * next round before the reset is visible.
     *
     * <p>NOTE(review): {@link ReentrantReadWriteLock} in non-fair mode may block new readers while
     * a writer is queued. If {@link #receiveUpdate} runs on another thread after some parties hold
     * the read lock but before others acquire it, the round could deadlock. Confirm how the
     * configured {@link MessageHandler} calls {@code receiveUpdate}.
     */
    private void applySharedUpdates(StepFunction function, INDArray params, boolean useAlpha, double alpha) {
        final long threadId = Thread.currentThread().getId();
        try {
            updatesLock.readLock().lock();
            try {
                firstOne.compareAndSet(-1L, threadId);
                if (hasSomething.get()) {
                    if (useAlpha) {
                        function.step(params, updates, alpha);
                    } else {
                        function.step(params, updates);
                    }
                }
                // Wait until every party has applied the shared updates before the leader clears them.
                barrier.await();
                if (firstOne.get() == threadId) {
                    // `updates` is created lazily by receiveUpdate and may still be absent.
                    if (updates != null) {
                        updates.assign(0.0);
                    }
                    hasSomething.set(false);
                    firstOne.set(-1L);
                }
            } finally {
                // Always release, even if step() or await() threw, so receivers are not starved forever.
                updatesLock.readLock().unlock();
            }
            barrier.await();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (BrokenBarrierException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Submits this party's update for the current round. Once all parties have arrived, the round
     * leader sums the candidates into {@code storage} and broadcasts the sum through the handler.
     *
     * @param array this party's update
     * @param iterationNumber iteration number forwarded to the handler
     * @param epochNumber epoch number forwarded to the handler
     */
    @Override
    public void storeUpdate(INDArray array, int iterationNumber, int epochNumber) {
        final long threadId = Thread.currentThread().getId();
        try {
            Nd4j.getExecutioner().commit();
            firstOne.compareAndSet(-1L, threadId);
            candidates.add(array);
            // All candidates must be present before the leader sums them.
            barrier.await();
            if (firstOne.get() == threadId) {
                try {
                    if (storage == null) {
                        shape = array.shape();
                        ordering = array.ordering();
                        // Allocate outside any workspace so the buffer outlives the current scope.
                        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                            storage = Nd4j.create(shape, ordering);
                        }
                    }
                    Nd4j.accumulate(storage, candidates);
                    Nd4j.getExecutioner().commit();
                    if (handler.broadcastUpdates(storage, iterationNumber, epochNumber)) {
                        ownCounter.getAndIncrement();
                    }
                } finally {
                    // Reset round state even if accumulation/broadcast failed, so it is not reused.
                    firstOne.set(-1L);
                    candidates.clear();
                }
            }
            barrier.await();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (BrokenBarrierException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Adds an incoming update to the shared {@code updates} sum under the write lock, creating
     * the sum (outside any workspace) on first use.
     *
     * @param array update to add; its shape and ordering define the sum on first use
     */
    @Override
    public void receiveUpdate(INDArray array) {
        extCounter.getAndIncrement();
        updatesLock.writeLock().lock();
        try {
            if (updates == null) {
                try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                    updates = Nd4j.create(array.shape(), array.ordering());
                }
            }
            updates.addi(array);
            Nd4j.getExecutioner().commit();
            hasSomething.set(true);
        } finally {
            updatesLock.writeLock().unlock();
        }
    }

    /** Zeroes both {@code storage} and {@code updates} (when allocated) and clears the pending flag. */
    @Override
    public void reset() {
        updatesLock.writeLock().lock();
        try {
            if (storage != null) {
                storage.assign(0.0f);
            }
            if (updates != null) {
                updates.assign(0.0f);
            }
            // Nothing is pending once the sum has been zeroed.
            hasSomething.set(false);
        } finally {
            updatesLock.writeLock().unlock();
        }
    }

    /** No-op for this implementation. */
    @Override
    public void touch() {
    }

    @Override
    public void setExternalSource(IndexedTail source) {
        this.gradients = source;
    }

    /**
     * Always returns {@code false}.
     *
     * <p>NOTE(review): this does not consult {@link #hasSomething}. Confirm whether callers rely
     * on this stub before changing it.
     */
    @Override
    public boolean hasAnything() {
        return false;
    }
}

