/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.fasttext;

import ai.djl.Device;
import ai.djl.fasttext.FtTrainingMode;
import ai.djl.training.DataManager;
import ai.djl.training.TrainingConfig;
import ai.djl.training.evaluator.Evaluator;
import ai.djl.training.initializer.Initializer;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.optimizer.Optimizer;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Objects;

/**
 * A {@link TrainingConfig} describing a fastText training run.
 *
 * <p>The configuration is rendered into a {@code fasttext} command line by {@link
 * #toCommand(String)}. Numeric options left at a negative value and string options left {@code
 * null} are treated as unset and are omitted from that command, leaving fastText to choose them.
 *
 * <p>Instances are created through {@link #builder()}.
 */
public class FtTrainingConfig implements TrainingConfig {
    private final FtTrainingMode trainingMode;
    private final Path outputDir;
    private final String modelName;
    private final int epoch;
    private final int minWordCount;
    private final int minLabelCount;
    private final int maxNgramLength;
    private final int minCharLength;
    private final int maxCharLength;
    private final int bucket;
    private final float samplingThreshold;
    private final String labelPrefix;
    private final float learningRate;
    private final int learningRateUpdateRate;
    private final int wordVecSize;
    private final int contextWindow;
    private final int numNegativesSampled;
    private final int threads;
    private final String loss;

    /**
     * Copies every option from the given builder.
     *
     * @param builder the builder holding the configured options
     */
    FtTrainingConfig(Builder builder) {
        this.trainingMode = builder.trainingMode;
        this.outputDir = builder.outputDir;
        this.modelName = builder.modelName;
        this.epoch = builder.epoch;
        this.minWordCount = builder.minWordCount;
        this.minLabelCount = builder.minLabelCount;
        this.maxNgramLength = builder.maxNgramLength;
        this.minCharLength = builder.minCharLength;
        this.maxCharLength = builder.maxCharLength;
        this.bucket = builder.bucket;
        this.samplingThreshold = builder.samplingThreshold;
        this.labelPrefix = builder.labelPrefix;
        this.learningRate = builder.learningRate;
        this.learningRateUpdateRate = builder.learningRateUpdateRate;
        this.wordVecSize = builder.wordVecSize;
        this.contextWindow = builder.contextWindow;
        this.numNegativesSampled = builder.numNegativesSampled;
        this.threads = builder.threads;
        this.loss = builder.loss;
    }

    /**
     * Returns the training mode; its lower-cased name is the fastText sub-command.
     *
     * @return the training mode
     */
    public FtTrainingMode getTrainingMode() {
        return trainingMode;
    }

    /**
     * Returns the directory against which the model name is resolved to form {@code -output}.
     *
     * @return the output directory, or {@code null} if unset
     */
    public Path getOutputDir() {
        return outputDir;
    }

    /**
     * Returns the model name, resolved against the output directory to form {@code -output}.
     *
     * @return the model name, or {@code null} if unset
     */
    public String getModelName() {
        return modelName;
    }

    /**
     * Returns the epoch count passed as {@code -epoch}.
     *
     * @return the epoch count, or a negative value if unset
     */
    public int getEpoch() {
        return epoch;
    }

    /**
     * Returns the minimum word count passed as {@code -minCount}.
     *
     * @return the minimum word count, or a negative value if unset
     */
    public int getMinWordCount() {
        return minWordCount;
    }

    /**
     * Returns the minimum label count passed as {@code -minCountLabel}.
     *
     * @return the minimum label count, or a negative value if unset
     */
    public int getMinLabelCount() {
        return minLabelCount;
    }

    /**
     * Returns the maximum word n-gram length passed as {@code -wordNgrams}.
     *
     * @return the maximum word n-gram length, or a negative value if unset
     */
    public int getMaxNgramLength() {
        return maxNgramLength;
    }

    /**
     * Returns the minimum character n-gram length passed as {@code -minn}.
     *
     * @return the minimum character n-gram length, or a negative value if unset
     */
    public int getMinCharLength() {
        return minCharLength;
    }

    /**
     * Returns the maximum character n-gram length passed as {@code -maxn}.
     *
     * @return the maximum character n-gram length, or a negative value if unset
     */
    public int getMaxCharLength() {
        return maxCharLength;
    }

    /**
     * Returns the bucket count passed as {@code -bucket}.
     *
     * @return the bucket count, or a negative value if unset
     */
    public int getBucket() {
        return bucket;
    }

    /**
     * Returns the sampling threshold passed as {@code -t}.
     *
     * @return the sampling threshold, or a negative value if unset
     */
    public float getSamplingThreshold() {
        return samplingThreshold;
    }

    /**
     * Returns the label prefix passed as {@code -label}.
     *
     * @return the label prefix, or {@code null} if unset
     */
    public String getLabelPrefix() {
        return labelPrefix;
    }

    /**
     * Returns the learning rate passed as {@code -lr}.
     *
     * @return the learning rate, or a negative value if unset
     */
    public float getLearningRate() {
        return learningRate;
    }

    /**
     * Returns the learning-rate update rate passed as {@code -lrUpdateRate}.
     *
     * @return the learning-rate update rate, or a negative value if unset
     */
    public int getLearningRateUpdateRate() {
        return learningRateUpdateRate;
    }

    /**
     * Returns the word vector size passed as {@code -dim}.
     *
     * @return the word vector size, or a negative value if unset
     */
    public int getWordVecSize() {
        return wordVecSize;
    }

    /**
     * Returns the context window size passed as {@code -ws}.
     *
     * @return the context window size, or a negative value if unset
     */
    public int getContextWindow() {
        return contextWindow;
    }

    /**
     * Returns the number of negatives sampled, passed as {@code -neg}.
     *
     * @return the number of negatives sampled, or a negative value if unset
     */
    public int getNumNegativesSampled() {
        return numNegativesSampled;
    }

    /**
     * Returns the thread count passed as {@code -thread}.
     *
     * @return the thread count, or a negative value if unset
     */
    public int getThreads() {
        return threads;
    }

    /**
     * Returns the lower-cased {@link FtLoss} name passed as {@code -loss}.
     *
     * @return the loss name, or {@code null} if unset
     */
    public String getLoss() {
        return loss;
    }

    /**
     * Returns no devices; this configuration targets the fastText command line (see {@link
     * #toCommand(String)}) rather than DJL devices.
     *
     * @return an empty array
     */
    public Device[] getDevices() {
        return new Device[0];
    }

    /**
     * Returns no initializer; none is configured for fastText training.
     *
     * @return {@code null}
     */
    public Initializer getInitializer() {
        return null;
    }

    /**
     * Returns no optimizer; none is configured for fastText training.
     *
     * @return {@code null}
     */
    public Optimizer getOptimizer() {
        return null;
    }

    /**
     * Returns no DJL loss function; the fastText loss is exposed through {@link #getLoss()}.
     *
     * @return {@code null}
     */
    public Loss getLossFunction() {
        return null;
    }

    /**
     * Returns no data manager; none is configured for fastText training.
     *
     * @return {@code null}
     */
    public DataManager getDataManager() {
        return null;
    }

    /**
     * Returns the evaluators, of which there are none for fastText training.
     *
     * @return an empty, immutable list (never {@code null})
     */
    public List<Evaluator> getEvaluators() {
        return Collections.emptyList();
    }

    /**
     * Returns the training listeners, of which there are none for fastText training.
     *
     * @return an empty, immutable list (never {@code null})
     */
    public List<TrainingListener> getTrainingListeners() {
        return Collections.emptyList();
    }

    /**
     * Builds the {@code fasttext} command line for this configuration.
     *
     * <p>The result has the form {@code fasttext <mode> -input <input> -output <model> [options]},
     * where {@code <mode>} is the lower-cased training mode name and {@code <model>} is the absolute
     * path of the model name resolved against the output directory. Only options that have been set
     * (non-negative numbers, non-null strings) are appended, in a fixed order.
     *
     * @param input the value passed verbatim as {@code -input} (presumably the training data path)
     * @return the program name followed by its arguments
     * @throws NullPointerException if {@code input}, the output directory, the model name or the
     *     training mode is {@code null}
     */
    public String[] toCommand(String input) {
        // Fail fast with a descriptive message instead of an opaque NPE inside Path.resolve, or a
        // null argv entry that only fails once the process is launched.
        Objects.requireNonNull(input, "input");
        Objects.requireNonNull(outputDir, "outputDir must be set");
        Objects.requireNonNull(modelName, "modelName must be set");
        Objects.requireNonNull(trainingMode, "trainingMode must not be null");

        Path modelFile = outputDir.resolve(modelName).toAbsolutePath();
        List<String> cmd = new ArrayList<>();
        cmd.add("fasttext");
        // Locale.ROOT: with a Turkish default locale, "SUPERVISED".toLowerCase() yields
        // "supervısed" (dotless i), which fastText would reject.
        cmd.add(trainingMode.name().toLowerCase(Locale.ROOT));
        cmd.add("-input");
        cmd.add(input);
        cmd.add("-output");
        cmd.add(modelFile.toString());
        addOption(cmd, "-epoch", epoch);
        addOption(cmd, "-minCount", minWordCount);
        addOption(cmd, "-minCountLabel", minLabelCount);
        addOption(cmd, "-wordNgrams", maxNgramLength);
        addOption(cmd, "-minn", minCharLength);
        addOption(cmd, "-maxn", maxCharLength);
        addOption(cmd, "-bucket", bucket);
        addOption(cmd, "-t", samplingThreshold);
        addOption(cmd, "-label", labelPrefix);
        addOption(cmd, "-lr", learningRate);
        addOption(cmd, "-lrUpdateRate", learningRateUpdateRate);
        addOption(cmd, "-dim", wordVecSize);
        addOption(cmd, "-ws", contextWindow);
        addOption(cmd, "-neg", numNegativesSampled);
        addOption(cmd, "-thread", threads);
        addOption(cmd, "-loss", loss);
        return cmd.toArray(new String[0]);
    }

    /** Appends {@code flag value} to {@code cmd} when the int option is set (non-negative). */
    private static void addOption(List<String> cmd, String flag, int value) {
        if (value >= 0) {
            cmd.add(flag);
            cmd.add(String.valueOf(value));
        }
    }

    /**
     * Appends {@code flag value} to {@code cmd} when the float option is set (non-negative; NaN
     * fails the comparison and is therefore treated as unset).
     */
    private static void addOption(List<String> cmd, String flag, float value) {
        if (value >= 0.0f) {
            cmd.add(flag);
            cmd.add(String.valueOf(value));
        }
    }

    /** Appends {@code flag value} to {@code cmd} when the string option is set (non-null). */
    private static void addOption(List<String> cmd, String flag, String value) {
        if (value != null) {
            cmd.add(flag);
            cmd.add(value);
        }
    }

    /**
     * Creates a new builder with every option unset and the training mode set to {@link
     * FtTrainingMode#SUPERVISED}.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /** Loss functions selectable for fastText; each is passed as its lower-case name via {@code -loss}. */
    public enum FtLoss {
        /** Negative sampling ({@code ns}). */
        NS,
        /** Hierarchical softmax ({@code hs}). */
        HS,
        /** Softmax ({@code softmax}). */
        SOFTMAX,
        /** One-vs-all ({@code ova}). */
        OVA
    }

    /** Builder for {@link FtTrainingConfig}; numeric options default to -1 (unset). */
    public static final class Builder {
        FtTrainingMode trainingMode = FtTrainingMode.SUPERVISED;
        Path outputDir;
        String modelName;
        int epoch = -1;
        int minWordCount = -1;
        int minLabelCount = -1;
        int maxNgramLength = -1;
        int minCharLength = -1;
        int maxCharLength = -1;
        int bucket = -1;
        float samplingThreshold = -1.0f;
        String labelPrefix;
        float learningRate = -1.0f;
        int learningRateUpdateRate = -1;
        int wordVecSize = -1;
        int contextWindow = -1;
        int numNegativesSampled = -1;
        int threads = -1;
        String loss;

        Builder() {
        }

        /**
         * Sets the directory the model name is resolved against; required by {@code toCommand}.
         *
         * @param outputDir the output directory
         * @return this builder
         */
        public Builder setOutputDir(Path outputDir) {
            this.outputDir = outputDir;
            return this;
        }

        /**
         * Sets the model name resolved against the output directory; required by {@code toCommand}.
         *
         * @param modelName the model name
         * @return this builder
         */
        public Builder setModelName(String modelName) {
            this.modelName = modelName;
            return this;
        }

        /**
         * Sets the training mode (default {@link FtTrainingMode#SUPERVISED}).
         *
         * @param trainingMode the training mode
         * @return this builder
         */
        public Builder optTrainingMode(FtTrainingMode trainingMode) {
            this.trainingMode = trainingMode;
            return this;
        }

        /**
         * Sets the value for {@code -epoch}; a negative value leaves it unset.
         *
         * @param epoch the epoch count
         * @return this builder
         */
        public Builder optEpoch(int epoch) {
            this.epoch = epoch;
            return this;
        }

        /**
         * Sets the value for {@code -minCount}; a negative value leaves it unset.
         *
         * @param minWordCount the minimum word count
         * @return this builder
         */
        public Builder optMinWordCount(int minWordCount) {
            this.minWordCount = minWordCount;
            return this;
        }

        /**
         * Sets the value for {@code -minCountLabel}; a negative value leaves it unset.
         *
         * @param minLabelCount the minimum label count
         * @return this builder
         */
        public Builder optMinLabelCount(int minLabelCount) {
            this.minLabelCount = minLabelCount;
            return this;
        }

        /**
         * Sets the value for {@code -wordNgrams}; a negative value leaves it unset.
         *
         * @param maxNgramLength the maximum word n-gram length
         * @return this builder
         */
        public Builder optMaxNGramLength(int maxNgramLength) {
            this.maxNgramLength = maxNgramLength;
            return this;
        }

        /**
         * Sets the value for {@code -minn}; a negative value leaves it unset.
         *
         * @param minCharLength the minimum character n-gram length
         * @return this builder
         */
        public Builder optMinCharLength(int minCharLength) {
            this.minCharLength = minCharLength;
            return this;
        }

        /**
         * Sets the value for {@code -maxn}; a negative value leaves it unset.
         *
         * @param maxCharLength the maximum character n-gram length
         * @return this builder
         */
        public Builder optMaxCharLength(int maxCharLength) {
            this.maxCharLength = maxCharLength;
            return this;
        }

        /**
         * Sets the value for {@code -bucket}; a negative value leaves it unset.
         *
         * @param bucket the bucket count
         * @return this builder
         */
        public Builder optBucket(int bucket) {
            this.bucket = bucket;
            return this;
        }

        /**
         * Sets the value for {@code -t}; a negative value leaves it unset.
         *
         * @param samplingThreshold the sampling threshold
         * @return this builder
         */
        public Builder optSamplingThreshold(float samplingThreshold) {
            this.samplingThreshold = samplingThreshold;
            return this;
        }

        /**
         * Sets the value for {@code -label}; {@code null} leaves it unset.
         *
         * @param labelPrefix the label prefix
         * @return this builder
         */
        public Builder optLabelPrefix(String labelPrefix) {
            this.labelPrefix = labelPrefix;
            return this;
        }

        /**
         * Sets the value for {@code -lr}; a negative value leaves it unset.
         *
         * @param learningRate the learning rate
         * @return this builder
         */
        public Builder optLearningRate(float learningRate) {
            this.learningRate = learningRate;
            return this;
        }

        /**
         * Sets the value for {@code -lrUpdateRate}; a negative value leaves it unset.
         *
         * @param learningRateUpdateRate the learning-rate update rate
         * @return this builder
         */
        public Builder optLearningRateUpdateRate(int learningRateUpdateRate) {
            this.learningRateUpdateRate = learningRateUpdateRate;
            return this;
        }

        /**
         * Sets the value for {@code -dim}; a negative value leaves it unset.
         *
         * @param wordVecSize the word vector size
         * @return this builder
         */
        public Builder optWordVecSize(int wordVecSize) {
            this.wordVecSize = wordVecSize;
            return this;
        }

        /**
         * Sets the value for {@code -ws}; a negative value leaves it unset.
         *
         * @param contextWindow the context window size
         * @return this builder
         */
        public Builder optContextWindow(int contextWindow) {
            this.contextWindow = contextWindow;
            return this;
        }

        /**
         * Sets the value for {@code -neg}; a negative value leaves it unset.
         *
         * @param numNegativesSampled the number of negatives sampled
         * @return this builder
         */
        public Builder optNumNegativesSampled(int numNegativesSampled) {
            this.numNegativesSampled = numNegativesSampled;
            return this;
        }

        /**
         * Sets the value for {@code -thread}; a negative value leaves it unset.
         *
         * @param threads the thread count
         * @return this builder
         */
        public Builder optThreads(int threads) {
            this.threads = threads;
            return this;
        }

        /**
         * Sets the loss passed as {@code -loss}, as the constant's lower-case name.
         *
         * @param loss the loss function
         * @return this builder
         * @throws NullPointerException if {@code loss} is {@code null}
         */
        public Builder optLoss(FtLoss loss) {
            // Locale.ROOT keeps the CLI value locale-independent.
            this.loss = Objects.requireNonNull(loss, "loss").name().toLowerCase(Locale.ROOT);
            return this;
        }

        /**
         * Creates the configuration from the options set so far.
         *
         * @return a new {@link FtTrainingConfig}
         */
        public FtTrainingConfig build() {
            return new FtTrainingConfig(this);
        }
    }
}

