/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.parameterserver.distributed.training.impl;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.RandomUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.parameterserver.distributed.enums.ExecutionMode;
import org.nd4j.parameterserver.distributed.logic.completion.FrameCompletionHandler;
import org.nd4j.parameterserver.distributed.logic.completion.RequestDescriptor;
import org.nd4j.parameterserver.distributed.logic.storage.WordVectorStorage;
import org.nd4j.parameterserver.distributed.messages.VoidAggregation;
import org.nd4j.parameterserver.distributed.messages.aggregations.DotAggregation;
import org.nd4j.parameterserver.distributed.messages.complete.FrameCompleteMessage;
import org.nd4j.parameterserver.distributed.messages.intercom.DistributedSgDotMessage;
import org.nd4j.parameterserver.distributed.messages.requests.SkipGramRequestMessage;
import org.nd4j.parameterserver.distributed.training.BaseTrainer;
import org.nd4j.parameterserver.distributed.training.chains.SkipGramChain;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Distributed SkipGram trainer.
 *
 * <p>A training round is split in two phases. {@link #startTraining} (or {@link #pickTraining},
 * for a request that arrives already split up) registers a {@link SkipGramChain} keyed by
 * (originatorId, taskId), and a {@link DistributedSgDotMessage} is dispatched for the request.
 * When the matching {@link DotAggregation} arrives, {@link #aggregationFinished} attaches it to the
 * chain and {@link #finishTraining} applies the hierarchical-softmax and/or negative-sampling
 * updates to the stored weight tables, reports frame progress and drops the chain.
 */
public class SkipGramTrainer extends BaseTrainer<SkipGramRequestMessage> {
    private static final Logger log = LoggerFactory.getLogger(SkipGramTrainer.class);

    /** Dot products in {@code [-HS_MAX_EXP, HS_MAX_EXP)} are mapped onto the exp lookup table. */
    private static final float HS_MAX_EXP = 6.0f;

    /** Chains awaiting (or processing) their dot aggregation, keyed by (originatorId, taskId). */
    protected Map<RequestDescriptor, SkipGramChain> chains = new ConcurrentHashMap<RequestDescriptor, SkipGramChain>();

    /** Number of finished training rounds; used for periodic progress logging. */
    protected AtomicLong cntRounds = new AtomicLong(0L);

    /**
     * Starts a training round for {@code message}.
     *
     * <p>When negative sampling is requested, {@code negSamples} random rows (distinct from
     * {@code w1}) are drawn and stored on the message, prefixed by {@code w1} itself. A
     * {@link DistributedSgDotMessage} covering {@code points ++ negatives} is then handed to the
     * transport: {@code putMessage} in AVERAGING mode, {@code sendMessage} in SHARDED mode.
     *
     * @param message the SkipGram request; its points and codes must have equal length
     * @throws RuntimeException if points and codes lengths differ
     * @throws IllegalStateException if negative sampling is requested but syn0 has fewer than 2 rows
     */
    @Override
    public void startTraining(SkipGramRequestMessage message) {
        // Validate before touching any state, so a malformed request neither mutates the
        // message nor leaves an orphaned chain in `chains`.
        if (message.getPoints().length != message.getCodes().length) {
            throw new RuntimeException("Mismatiching points/codes lengths here!");
        }

        int[] rowSyn0 = new int[0];
        int[] rowSyn1 = message.getPoints();
        if (message.getNegSamples() > 0) {
            int rows = this.storage.getArray(WordVectorStorage.SYN_0).rows();
            int[] negatives = drawNegatives(message.getW1(), message.getNegSamples(), rows);
            // Negative dots follow the HS dots in the aggregated result: points first, then negatives.
            rowSyn1 = ArrayUtils.addAll(rowSyn1, negatives);
            message.setNegatives(negatives);
        }

        SkipGramChain chain = new SkipGramChain(message.getOriginatorId(), message.getTaskId(), message.getFrameId());
        chain.addElement(message);
        this.chains.put(RequestDescriptor.createDescriptor(message.getOriginatorId(), message.getTaskId()), chain);

        boolean useHs = message.getCodes() != null && message.getCodes().length > 0;
        DistributedSgDotMessage ddm = new DistributedSgDotMessage(message.getTaskId(), rowSyn0, rowSyn1,
                message.getW1(), message.getW2(), message.getCodes(), useHs, message.getNegSamples(),
                (float) message.getAlpha());
        ddm.setTargetId((short) -1);
        ddm.setOriginatorId(message.getOriginatorId());

        if (this.voidConfiguration.getExecutionMode() == ExecutionMode.AVERAGING) {
            this.transport.putMessage(ddm);
        } else if (this.voidConfiguration.getExecutionMode() == ExecutionMode.SHARDED) {
            this.transport.sendMessage(ddm);
        }
    }

    /**
     * Draws the negative-sampling rows for one request.
     *
     * @param w1 the row that must not be drawn as a negative; stored at index 0 of the result
     * @param negSamples number of random negatives to draw
     * @param rows number of rows in syn0; negatives are drawn from {@code [0, rows)}
     * @return array of length {@code negSamples + 1}: {@code w1} followed by the random rows
     * @throws IllegalStateException if {@code rows < 2}, where the rejection loop could never finish
     */
    private static int[] drawNegatives(int w1, int negSamples, int rows) {
        if (rows < 2) {
            throw new IllegalStateException("Negative sampling needs at least 2 rows in syn0, got " + rows);
        }
        int[] result = new int[negSamples + 1];
        result[0] = w1;
        for (int e = 1; e < result.length; e++) {
            int rnd;
            do {
                rnd = RandomUtils.nextInt(0, rows);
            } while (rnd == w1);
            result[e] = rnd;
        }
        return result;
    }

    /**
     * Registers a chain for a request that was already split up elsewhere, unless a chain for the
     * same (originatorId, taskId) is already registered.
     *
     * @param message the request to track
     * @throws NullPointerException if {@code message} is null
     */
    @Override
    public void pickTraining(@NonNull SkipGramRequestMessage message) {
        if (message == null) {
            throw new NullPointerException("message is marked non-null but is null");
        }
        RequestDescriptor descriptor = RequestDescriptor.createDescriptor(message.getOriginatorId(), message.getTaskId());
        // containsKey avoids building a chain in the common case; putIfAbsent keeps the
        // check-then-insert atomic if two messages for the same task race.
        if (!this.chains.containsKey(descriptor)) {
            this.chains.putIfAbsent(descriptor, new SkipGramChain(message));
        }
    }

    /** @return simple class name of the request type handled by this trainer */
    @Override
    public String targetMessageClass() {
        return SkipGramRequestMessage.class.getSimpleName();
    }

    /**
     * Attaches a finished dot aggregation to its chain and completes the training round.
     *
     * @param aggregation the aggregation; expected to be a {@link DotAggregation}
     * @throws NullPointerException if {@code aggregation} is null
     * @throws RuntimeException if no chain is registered for the aggregation's originator/task
     */
    @Override
    public void aggregationFinished(@NonNull VoidAggregation aggregation) {
        if (aggregation == null) {
            throw new NullPointerException("aggregation is marked non-null but is null");
        }
        SkipGramChain chain = this.chains.get(RequestDescriptor.createDescriptor(aggregation.getOriginatorId(), aggregation.getTaskId()));
        if (chain == null) {
            throw new RuntimeException("sI_" + this.transport.getShardIndex() + " Unable to find chain for specified originatorId: [" + aggregation.getOriginatorId() + "]; taskId: [" + aggregation.getTaskId() + "]");
        }
        chain.addElement((DotAggregation) aggregation);
        this.finishTraining(aggregation.getOriginatorId(), aggregation.getTaskId());
    }

    /**
     * Completes a training round.
     *
     * <p>This applies the weight updates, notifies the completion handler about frame progress,
     * and removes the chain.
     *
     * @param originatorId originator of the request
     * @param taskId task id of the request
     * @throws RuntimeException if no chain is registered for (originatorId, taskId)
     */
    @Override
    public void finishTraining(long originatorId, long taskId) {
        RequestDescriptor chainDesc = RequestDescriptor.createDescriptor(originatorId, taskId);
        SkipGramChain chain = this.chains.get(chainDesc);
        if (chain == null) {
            throw new RuntimeException("Unable to find chain for specified taskId: [" + taskId + "]");
        }

        applyGradients(chain);
        notifyFrameProgress(chain, taskId);

        long rounds = this.cntRounds.incrementAndGet();
        if (rounds % 100000L == 0L) {
            log.info("{} training rounds finished...", Long.valueOf(rounds));
        }
        // Drop the finished chain; otherwise `chains` grows without bound.
        this.chains.remove(chainDesc);
    }

    /**
     * Applies HS and negative-sampling updates for one chain to syn1 / syn1Neg.
     *
     * <p>The error accumulated in {@code neu1e} is then added to the syn0 row of {@code w2}.
     * Dots are read from the chain's aggregated result: HS dots at
     * {@code [0, codes.length)}, negative dots at {@code [codes.length, codes.length + negSamples + 1)}.
     */
    private void applyGradients(SkipGramChain chain) {
        SkipGramRequestMessage sgrm = chain.getRequestMessage();
        double alpha = sgrm.getAlpha();
        INDArray expTable = this.storage.getArray(WordVectorStorage.EXP_TABLE);
        long expLength = expTable.length();
        INDArray dots = chain.getDotAggregation().getAccumulatedResult();
        INDArray syn0 = this.storage.getArray(WordVectorStorage.SYN_0);
        INDArray syn1 = this.storage.getArray(WordVectorStorage.SYN_1);
        INDArray syn1Neg = this.storage.getArray(WordVectorStorage.SYN_1_NEGATIVE);
        // Row view of syn0 for w2: read by both loops, written once at the end.
        INDArray syn0Row = syn0.getRow((long) sgrm.getW2());
        INDArray neu1e = Nd4j.create((int) syn0.columns());
        boolean updated = false;

        // Hierarchical softmax: one dot per code/point pair.
        byte[] codes = sgrm.getCodes();
        int[] points = sgrm.getPoints();
        for (int e = 0; e < codes.length; e++) {
            float dot = dots.getFloat((long) e);
            if (dot < -HS_MAX_EXP || dot >= HS_MAX_EXP) {
                continue;
            }
            int idx = expTableIndex(dot, expLength);
            if (idx < 0) {
                continue;
            }
            double f = expTable.getFloat((long) idx);
            double g = ((double) (1 - codes[e]) - f) * alpha;
            INDArray syn1Row = syn1.getRow((long) points[e]);
            // Order matters: neu1e must accumulate the syn1 row before that row is updated.
            Nd4j.getBlasWrapper().axpy(Double.valueOf(g), syn1Row, neu1e);
            Nd4j.getBlasWrapper().axpy(Double.valueOf(g), syn0Row, syn1Row);
            updated = true;
        }

        // Negative sampling: negatives[0] is the positive word (label 1), the rest are label 0.
        if (sgrm.getNegSamples() > 0) {
            int[] negatives = sgrm.getNegatives();
            int offset = codes.length; // negative dots follow the HS dots
            for (int cnt = 0; cnt < sgrm.getNegSamples() + 1; cnt++) {
                float dot = dots.getFloat((long) (offset + cnt));
                float label = cnt == 0 ? 1.0f : 0.0f;
                double g;
                if (dot > HS_MAX_EXP) {
                    g = (double) (label - 1.0f) * alpha;
                } else if (dot < -HS_MAX_EXP) {
                    g = (double) label * alpha;
                } else {
                    int idx = expTableIndex(dot, expLength);
                    if (idx < 0) {
                        continue;
                    }
                    g = ((double) label - expTable.getDouble((long) idx)) * alpha;
                }
                INDArray negRow = syn1Neg.getRow((long) negatives[cnt]);
                Nd4j.getBlasWrapper().axpy(Double.valueOf(g), negRow, neu1e);
                Nd4j.getBlasWrapper().axpy(Double.valueOf(g), syn0Row, negRow);
                updated = true;
            }
        }

        if (updated) {
            Nd4j.getBlasWrapper().axpy(Double.valueOf(1.0), neu1e, syn0Row);
        }
    }

    /**
     * Maps a dot product onto an index of the exp lookup table.
     *
     * @return the index, or -1 if it falls outside {@code [0, tableLength)}
     */
    private static int expTableIndex(float dot, long tableLength) {
        int idx = (int) ((double) (dot + HS_MAX_EXP) * ((double) ((float) tableLength / HS_MAX_EXP) / 2.0));
        return idx < 0 || (long) idx >= tableLength ? -1 : idx;
    }

    /**
     * Reports this task to the completion handler if this shard tracks the chain's frame.
     *
     * <p>When the frame is complete, a {@link FrameCompleteMessage} is sent to the frame's originator.
     */
    private void notifyFrameProgress(SkipGramChain chain, long taskId) {
        RequestDescriptor descriptor = RequestDescriptor.createDescriptor(chain.getOriginatorId(), chain.getFrameId());
        if (!this.completionHandler.isTrackingFrame(descriptor)) {
            log.info("sI_{} isn't tracking this frame: Originator: {}, frameId: {}, taskId: {}", new Object[]{this.transport.getShardIndex(), chain.getOriginatorId(), chain.getFrameId(), taskId});
            return;
        }
        this.completionHandler.notifyFrame(chain.getOriginatorId(), chain.getFrameId(), chain.getTaskId());
        if (!this.completionHandler.isCompleted(descriptor)) {
            return;
        }
        FrameCompletionHandler.FrameDescriptor frameDescriptor = this.completionHandler.getCompletedFrameInfo(descriptor);
        if (frameDescriptor == null) {
            // Completed-frame info was already consumed, presumably by a concurrent caller.
            log.warn("Frame double spending detected");
            return;
        }
        FrameCompleteMessage fcm = new FrameCompleteMessage(chain.getFrameId());
        fcm.setOriginatorId(frameDescriptor.getFrameOriginatorId());
        this.transport.sendMessage(fcm);
    }
}

