/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.ops.aggregates.impl;

import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.aggregates.BaseAggregate;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Aggregate op descriptor registered under the name {@code "aggregate_skipgram"} (op number 3).
 *
 * <p>The constructors only pack their parameters, in a fixed positional order, into the
 * inherited argument lists ({@code arguments}, {@code indexingArguments},
 * {@code intArrayArguments}, {@code realArguments}). No computation happens in this class;
 * the order of the {@code add} calls is therefore part of the contract with whatever
 * consumes those lists and must not be changed.
 */
public class AggregateSkipGram
extends BaseAggregate {
    private static final Logger log = LoggerFactory.getLogger(AggregateSkipGram.class);

    /** Upper bound on the value reported by {@link #getThreadsPerInstance()}. */
    private static final int MAX_THREADS_PER_INSTANCE = 768;

    /** Fixed number of bytes added on top of the per-vector part in {@link #getSharedMemorySize()}. */
    private static final int SHARED_MEMORY_OVERHEAD_BYTES = 512;

    /** Index of the optional inference vector inside {@code arguments}. */
    private static final int INFERENCE_VECTOR_ARGUMENT_INDEX = 5;

    /** Index of the "inference vector present" flag inside {@code indexingArguments}. */
    private static final int INFERENCE_FLAG_INDEX = 8;

    /** Vector length this op was created with; drives thread count and shared memory size. */
    private final int vectorLength;

    /**
     * Same as the 14-argument constructor, but additionally stores {@code inferenceVector}
     * in the last array-argument slot and sets the matching flag in the indexing arguments
     * to 1 when the vector is non-null (0 otherwise).
     *
     * @param inferenceVector optional vector; may be {@code null}
     * @throws NullPointerException if {@code syn0} or {@code expTable} is {@code null}
     */
    public AggregateSkipGram(INDArray syn0, INDArray syn1, INDArray syn1Neg, INDArray expTable, INDArray negTable, int idxSyn0, int[] idxSyn1, int[] codes, int negativeRounds, int ngStarter, int vectorLength, double alpha, long nextRandom, int vocabSize, INDArray inferenceVector) {
        this(syn0, syn1, syn1Neg, expTable, negTable, idxSyn0, idxSyn1, codes, negativeRounds, ngStarter, vectorLength, alpha, nextRandom, vocabSize);
        this.arguments.set(INFERENCE_VECTOR_ARGUMENT_INDEX, inferenceVector);
        this.indexingArguments.set(INFERENCE_FLAG_INDEX, inferenceVector == null ? 0 : 1);
    }

    /**
     * Packs the given parameters into the inherited argument lists.
     *
     * <p>Resulting layout:
     * <ul>
     *   <li>{@code indexingArguments}: idxSyn0, vectorLength, idxSyn1.length, negativeRounds,
     *       expTable length, vocabSize, ngStarter, negTable length (0 when {@code negTable}
     *       is null), inference flag (always 0 here)</li>
     *   <li>{@code arguments}: syn0, syn1, expTable, syn1Neg, negTable, {@code null}
     *       placeholder for the inference vector</li>
     *   <li>{@code intArrayArguments}: idxSyn1, codes</li>
     *   <li>{@code realArguments}: alpha, nextRandom (as a double)</li>
     * </ul>
     *
     * @param syn0 must not be {@code null}
     * @param expTable must not be {@code null}
     * @param negTable may be {@code null}, in which case its length is recorded as 0
     * @param idxSyn1 dereferenced for its length, so it must not be {@code null}
     * @throws NullPointerException if {@code syn0}, {@code expTable} or {@code idxSyn1} is {@code null}
     */
    public AggregateSkipGram(@NonNull INDArray syn0, INDArray syn1, INDArray syn1Neg, @NonNull INDArray expTable, INDArray negTable, int idxSyn0, int[] idxSyn1, int[] codes, int negativeRounds, int ngStarter, int vectorLength, double alpha, long nextRandom, int vocabSize) {
        if (syn0 == null) {
            throw new NullPointerException("syn0 is marked @NonNull but is null");
        }
        if (expTable == null) {
            throw new NullPointerException("expTable is marked @NonNull but is null");
        }
        this.indexingArguments.add(idxSyn0);
        this.indexingArguments.add(vectorLength);
        this.indexingArguments.add(idxSyn1.length);
        this.indexingArguments.add(negativeRounds);
        this.indexingArguments.add((int)expTable.length());
        this.indexingArguments.add(vocabSize);
        this.indexingArguments.add(ngStarter);
        this.indexingArguments.add(negTable == null ? 0 : (int)negTable.length());
        // Inference flag; the 15-argument constructor overwrites this slot.
        this.indexingArguments.add(0);
        this.arguments.add(syn0);
        this.arguments.add(syn1);
        this.arguments.add(expTable);
        this.arguments.add(syn1Neg);
        this.arguments.add(negTable);
        // Placeholder for the inference vector; the 15-argument constructor overwrites this slot.
        this.arguments.add(null);
        this.intArrayArguments.add(idxSyn1);
        this.intArrayArguments.add(codes);
        this.realArguments.add(alpha);
        // The long is widened to double here, so magnitudes above 2^53 are not represented exactly.
        this.realArguments.add(Double.valueOf(nextRandom));
        this.vectorLength = vectorLength;
    }

    /**
     * Compact variant that records only primitive and int-array parameters.
     *
     * <p>Resulting layout: {@code indexingArguments} = w1, w2, vectorLength;
     * {@code intArrayArguments} = codes, points; {@code realArguments} = lr.
     *
     * @param negSamples accepted for signature compatibility; not stored anywhere by this constructor
     */
    public AggregateSkipGram(int w1, int w2, int[] codes, int[] points, int negSamples, double lr, int vectorLength) {
        this.indexingArguments.add(w1);
        this.indexingArguments.add(w2);
        this.indexingArguments.add(vectorLength);
        this.intArrayArguments.add(codes);
        this.intArrayArguments.add(points);
        this.realArguments.add(lr);
        // Previously left unset, which made getThreadsPerInstance() report 0 and
        // getSharedMemorySize() ignore the vector length for instances built this way.
        this.vectorLength = vectorLength;
    }

    /**
     * @return {@code vectorLength * Nd4j.sizeOfDataType()} plus a fixed 512-byte overhead
     */
    @Override
    public int getSharedMemorySize() {
        return this.vectorLength * Nd4j.sizeOfDataType() + SHARED_MEMORY_OVERHEAD_BYTES;
    }

    /**
     * @return the vector length, capped at 768
     */
    @Override
    public int getThreadsPerInstance() {
        return Math.min(this.vectorLength, MAX_THREADS_PER_INSTANCE);
    }

    /** @return the fixed op name {@code "aggregate_skipgram"} */
    @Override
    public String name() {
        return "aggregate_skipgram";
    }

    /** @return the fixed op number, 3 */
    @Override
    public int opNum() {
        return 3;
    }

    /** @return 6, matching the number of entries the full constructors put into {@code arguments} */
    @Override
    public int maxArguments() {
        return 6;
    }

    /** @return 0; this class adds no shape arguments */
    @Override
    public int maxShapes() {
        return 0;
    }

    /** @return 2, matching the two int arrays every constructor adds */
    @Override
    public int maxIntArrays() {
        return 2;
    }

    /** @return 40, the declared limit on the size of a single int-array argument */
    @Override
    public int maxIntArraySize() {
        return 40;
    }

    /** @return 10, the declared limit on indexing arguments (the full constructors add 9) */
    @Override
    public int maxIndexArguments() {
        return 10;
    }

    /** @return 2, matching the two real arguments the full constructors add */
    @Override
    public int maxRealArguments() {
        return 2;
    }
}

