/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.ops.aggregates.impl;

import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.aggregates.BaseAggregate;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Aggregate op named {@code "aggregate_cbow"} with op number 4.
 *
 * <p>All inputs are packed positionally into the argument lists inherited from
 * {@link BaseAggregate}: 12 indexing slots, 6 array slots, 3 int-array slots and
 * 2 real slots (matching the {@code max*} limits declared below). NOTE(review): the
 * slot positions appear to be a fixed contract with whatever consumes this aggregate,
 * so the order in which slots are filled must not change.
 */
public class AggregateCBOW extends BaseAggregate {

    /** Upper bound on the value reported by {@link #getThreadsPerInstance()}. */
    private static final int MAX_THREADS_PER_INSTANCE = 768;

    /** Copy of the {@code vectorLength} constructor argument, used for the sizing hints. */
    private final int vectorLength;

    /**
     * Builds the op and also fills the label, training and inference slots.
     *
     * <p>Delegates to the 15-argument constructor, which validates {@code syn0} and
     * {@code expTable} and reserves placeholder values in indexing slots 9-11 and
     * argument slot 5. This constructor then overwrites those placeholders. All
     * parameters not listed here are handled exactly as in the 15-argument constructor.
     *
     * @param numLabels       written to indexing slot 9
     * @param trainWords      written to indexing slot 10 as 1 ({@code true}) or 0
     * @param inferenceVector written to argument slot 5 (may be {@code null}); indexing
     *                        slot 11 records 1 if it is present, 0 otherwise
     * @throws NullPointerException if {@code syn0}, {@code expTable}, {@code idxSyn0} or
     *                              {@code idxSyn1} is {@code null}
     */
    public AggregateCBOW(@NonNull INDArray syn0, INDArray syn1, INDArray syn1Neg, @NonNull INDArray expTable, INDArray negTable, int wordIdx, int[] idxSyn0, int[] idxSyn1, int[] codes, int negativeRounds, int ngStarter, int vectorLength, double alpha, long nextRandom, int vocabSize, int numLabels, boolean trainWords, INDArray inferenceVector) {
        this(syn0, syn1, syn1Neg, expTable, negTable, wordIdx, idxSyn0, idxSyn1, codes,
                negativeRounds, ngStarter, vectorLength, alpha, nextRandom, vocabSize);
        // No null checks needed here: the delegated constructor has already rejected
        // a null syn0/expTable with the same messages before we get this far.
        final int trainFlag = trainWords ? 1 : 0;
        final int inferenceFlag = (inferenceVector != null) ? 1 : 0;

        this.indexingArguments.set(9, numLabels);
        this.indexingArguments.set(10, trainFlag);
        this.indexingArguments.set(11, inferenceFlag);
        this.arguments.set(5, inferenceVector);
    }

    /**
     * Builds the op with default values in the label, training and inference slots.
     *
     * <p>Indexing slots 9, 10 and 11 receive 0, 1 and 0, and argument slot 5 receives
     * {@code null}. The 18-argument constructor overwrites these.
     *
     * @param syn0           argument slot 0; must not be {@code null}
     * @param syn1           argument slot 1
     * @param syn1Neg        argument slot 3
     * @param expTable       argument slot 2; its length (narrowed to {@code int}) goes to
     *                       indexing slot 3; must not be {@code null}
     * @param negTable       argument slot 4; its length (0 if {@code null}) goes to
     *                       indexing slot 6
     * @param wordIdx        indexing slot 8
     * @param idxSyn0        int-array slot 0; its length goes to indexing slot 7
     * @param idxSyn1        int-array slot 1; its length goes to indexing slot 1
     * @param codes          int-array slot 2
     * @param negativeRounds indexing slot 2
     * @param ngStarter      indexing slot 5
     * @param vectorLength   indexing slot 0; also drives {@link #getSharedMemorySize()} and
     *                       {@link #getThreadsPerInstance()}
     * @param alpha          real slot 0
     * @param nextRandom     real slot 1, widened to {@code double}
     * @param vocabSize      indexing slot 4
     * @throws NullPointerException if {@code syn0}, {@code expTable}, {@code idxSyn0} or
     *                              {@code idxSyn1} is {@code null}
     */
    public AggregateCBOW(@NonNull INDArray syn0, INDArray syn1, INDArray syn1Neg, @NonNull INDArray expTable, INDArray negTable, int wordIdx, int[] idxSyn0, int[] idxSyn1, int[] codes, int negativeRounds, int ngStarter, int vectorLength, double alpha, long nextRandom, int vocabSize) {
        if (syn0 == null) {
            throw new NullPointerException("syn0 is marked @NonNull but is null");
        }
        if (expTable == null) {
            throw new NullPointerException("expTable is marked @NonNull but is null");
        }
        this.vectorLength = vectorLength;

        // Indexing slots, in positional order (the comment gives the slot number).
        final int[] indexing = {
            vectorLength,                                   // 0
            idxSyn1.length,                                 // 1
            negativeRounds,                                 // 2
            (int) expTable.length(),                        // 3
            vocabSize,                                      // 4
            ngStarter,                                      // 5
            (negTable != null) ? (int) negTable.length() : 0, // 6
            idxSyn0.length,                                 // 7
            wordIdx,                                        // 8
            0,                                              // 9: numLabels placeholder
            1,                                              // 10: trainWords placeholder
            0                                               // 11: inference flag placeholder
        };
        for (int slotValue : indexing) {
            this.indexingArguments.add(slotValue);
        }

        // Array slots 0-5; slot 5 is a placeholder for the optional inference vector.
        final INDArray[] arrays = {syn0, syn1, expTable, syn1Neg, negTable, null};
        for (INDArray array : arrays) {
            this.arguments.add(array);
        }

        final int[][] intArrays = {idxSyn0, idxSyn1, codes};
        for (int[] intArray : intArrays) {
            this.intArrayArguments.add(intArray);
        }

        this.realArguments.add(alpha);
        // long -> double widening: values beyond 2^53 in magnitude lose precision here.
        this.realArguments.add(Double.valueOf((double) nextRandom));
    }

    /** @return the op name, {@code "aggregate_cbow"} */
    @Override
    public String name() {
        return "aggregate_cbow";
    }

    /** @return the numeric id of this aggregate op, {@code 4} */
    @Override
    public int opNum() {
        return 4;
    }

    /** @return the number of array argument slots used (6) */
    @Override
    public int maxArguments() {
        return 6;
    }

    /** @return 0, since no shape arguments are used */
    @Override
    public int maxShapes() {
        return 0;
    }

    /** @return the number of int-array slots used (3) */
    @Override
    public int maxIntArrays() {
        return 3;
    }

    /** @return the declared maximum int-array length (40) */
    @Override
    public int maxIntArraySize() {
        return 40;
    }

    /** @return the number of indexing slots used (12) */
    @Override
    public int maxIndexArguments() {
        return 12;
    }

    /** @return the number of real-valued slots used (2) */
    @Override
    public int maxRealArguments() {
        return 2;
    }

    /**
     * Computes the shared-memory size hint.
     *
     * @return two vectors' worth of elements at the current data-type size, plus 512
     */
    @Override
    public int getSharedMemorySize() {
        final int bytesPerVector = this.vectorLength * Nd4j.sizeOfDataType();
        return bytesPerVector * 2 + 512;
    }

    /**
     * Computes the threads-per-instance hint.
     *
     * @return {@code vectorLength}, capped at 768
     */
    @Override
    public int getThreadsPerInstance() {
        return Math.min(this.vectorLength, MAX_THREADS_PER_INSTANCE);
    }
}

