/*
 * Decompiled with CFR 0.152.
 */
package com.hankcs.hanlp.model.perceptron;

import com.hankcs.hanlp.HanLP;
import com.hankcs.hanlp.classification.utilities.io.ConsoleLogger;
import com.hankcs.hanlp.collection.trie.DoubleArrayTrie;
import com.hankcs.hanlp.corpus.document.sentence.Sentence;
import com.hankcs.hanlp.model.perceptron.InstanceConsumer;
import com.hankcs.hanlp.model.perceptron.common.FrequencyMap;
import com.hankcs.hanlp.model.perceptron.feature.ImmutableFeatureHashMap;
import com.hankcs.hanlp.model.perceptron.feature.MutableFeatureMap;
import com.hankcs.hanlp.model.perceptron.instance.Instance;
import com.hankcs.hanlp.model.perceptron.instance.InstanceHandler;
import com.hankcs.hanlp.model.perceptron.model.AveragedPerceptron;
import com.hankcs.hanlp.model.perceptron.model.LinearModel;
import com.hankcs.hanlp.model.perceptron.model.StructuredPerceptron;
import com.hankcs.hanlp.model.perceptron.tagset.TagSet;
import com.hankcs.hanlp.model.perceptron.utility.IOUtility;
import com.hankcs.hanlp.model.perceptron.utility.Utility;
import java.io.BufferedReader;
import java.io.IOException;
import java.util.LinkedList;

public abstract class PerceptronTrainer
extends InstanceConsumer {
    protected abstract TagSet createTagSet();

    public Result train(String trainingFile, String developFile, String modelFile, double compressRatio, int maxIteration, int threadNum) throws IOException {
        LinearModel model;
        if (developFile == null) {
            developFile = trainingFile;
        }
        TagSet tagSet = this.createTagSet();
        MutableFeatureMap mutableFeatureMap = new MutableFeatureMap(tagSet);
        ConsoleLogger logger = new ConsoleLogger();
        logger.start("\u5f00\u59cb\u52a0\u8f7d\u8bad\u7ec3\u96c6...\n", new Object[0]);
        Instance[] instances = this.loadTrainInstances(trainingFile, mutableFeatureMap);
        tagSet.lock();
        logger.finish("\n\u52a0\u8f7d\u5b8c\u6bd5\uff0c\u5b9e\u4f8b\u4e00\u5171%d\u53e5\uff0c\u7279\u5f81\u603b\u6570%d\n", instances.length, mutableFeatureMap.size() * tagSet.size());
        ImmutableFeatureHashMap immutableFeatureMap = new ImmutableFeatureHashMap(mutableFeatureMap.featureIdMap.entrySet(), tagSet);
        mutableFeatureMap = null;
        double[] accuracy = null;
        if (threadNum == 1) {
            model = new AveragedPerceptron(immutableFeatureMap);
            double[] total = new double[((AveragedPerceptron)model).parameter.length];
            int[] timestamp = new int[((AveragedPerceptron)model).parameter.length];
            int current = 0;
            for (int iter = 1; iter <= maxIteration; ++iter) {
                Utility.shuffleArray(instances);
                for (Instance instance : instances) {
                    ++current;
                    int[] guessLabel = new int[instance.length()];
                    model.viterbiDecode(instance, guessLabel);
                    for (int i = 0; i < instance.length(); ++i) {
                        int[] featureVector = instance.getFeatureAt(i);
                        int[] goldFeature = new int[featureVector.length];
                        int[] predFeature = new int[featureVector.length];
                        for (int j = 0; j < featureVector.length - 1; ++j) {
                            goldFeature[j] = featureVector[j] * tagSet.size() + instance.tagArray[i];
                            predFeature[j] = featureVector[j] * tagSet.size() + guessLabel[i];
                        }
                        goldFeature[featureVector.length - 1] = (i == 0 ? tagSet.bosId() : instance.tagArray[i - 1]) * tagSet.size() + instance.tagArray[i];
                        predFeature[featureVector.length - 1] = (i == 0 ? tagSet.bosId() : guessLabel[i - 1]) * tagSet.size() + guessLabel[i];
                        ((AveragedPerceptron)model).update(goldFeature, predFeature, total, timestamp, current);
                    }
                }
                accuracy = trainingFile.equals(developFile) ? IOUtility.evaluate(instances, model) : this.evaluate(developFile, model);
                System.out.printf("Iter#%d - ", iter);
                this.printAccuracy(accuracy);
            }
            ((AveragedPerceptron)model).average(total, timestamp, current);
            accuracy = trainingFile.equals(developFile) ? IOUtility.evaluate(instances, model) : this.evaluate(developFile, model);
            System.out.print("AP - ");
            this.printAccuracy(accuracy);
            logger.start("\u4ee5\u538b\u7f29\u6bd4 %.2f \u4fdd\u5b58\u6a21\u578b\u5230 %s ... ", compressRatio, modelFile);
            model.save(modelFile, immutableFeatureMap.featureIdMap.entrySet(), compressRatio);
            logger.finish(" \u4fdd\u5b58\u5b8c\u6bd5\n", new Object[0]);
            if (compressRatio == 0.0) {
                return new Result(model, accuracy);
            }
        } else {
            StructuredPerceptron[] models = new StructuredPerceptron[threadNum];
            for (int i = 0; i < models.length; ++i) {
                models[i] = new StructuredPerceptron(immutableFeatureMap);
            }
            TrainingWorker[] workers = new TrainingWorker[threadNum];
            int job = instances.length / threadNum;
            for (int iter = 1; iter <= maxIteration; ++iter) {
                Utility.shuffleArray(instances);
                try {
                    int i;
                    for (i = 0; i < workers.length; ++i) {
                        workers[i] = new TrainingWorker(instances, i * job, i == workers.length - 1 ? instances.length : (i + 1) * job, models[i]);
                        workers[i].start();
                    }
                    for (TrainingWorker worker : workers) {
                        worker.join();
                    }
                    int j = 0;
                    while (j < models[0].parameter.length) {
                        for (int i2 = 1; i2 < models.length; ++i2) {
                            int n = j;
                            models[0].parameter[n] = models[0].parameter[n] + models[i2].parameter[j];
                        }
                        int n = j++;
                        models[0].parameter[n] = models[0].parameter[n] / (float)threadNum;
                    }
                    for (i = 1; i < models.length; ++i) {
                        System.arraycopy(models[0].parameter, 0, models[i].parameter, 0, models[0].parameter.length);
                    }
                    accuracy = trainingFile.equals(developFile) ? IOUtility.evaluate(instances, models[0]) : this.evaluate(developFile, models[0]);
                    System.out.printf("Iter#%d - ", iter);
                    this.printAccuracy(accuracy);
                    continue;
                }
                catch (InterruptedException e) {
                    System.err.printf("\u7ebf\u7a0b\u540c\u6b65\u5f02\u5e38\uff0c\u8bad\u7ec3\u5931\u8d25\n", new Object[0]);
                    e.printStackTrace();
                    return null;
                }
            }
            logger.start("\u4ee5\u538b\u7f29\u6bd4 %.2f \u4fdd\u5b58\u6a21\u578b\u5230 %s ... ", compressRatio, modelFile);
            models[0].save(modelFile, immutableFeatureMap.featureIdMap.entrySet(), compressRatio, HanLP.Config.DEBUG);
            logger.finish(" \u4fdd\u5b58\u5b8c\u6bd5\n", new Object[0]);
            if (compressRatio == 0.0) {
                return new Result(models[0], accuracy);
            }
        }
        model = new LinearModel(modelFile);
        if (compressRatio > 0.0) {
            accuracy = this.evaluate(developFile, model);
            System.out.printf("\n%.2f compressed model - ", compressRatio);
            this.printAccuracy(accuracy);
        }
        return new Result(model, accuracy);
    }

    private void printAccuracy(double[] accuracy) {
        if (accuracy.length == 3) {
            System.out.printf("P:%.2f R:%.2f F:%.2f\n", accuracy[0], accuracy[1], accuracy[2]);
        } else {
            System.out.printf("P:%.2f\n", accuracy[0]);
        }
    }

    protected Instance[] loadTrainInstances(String trainingFile, final MutableFeatureMap mutableFeatureMap) throws IOException {
        final LinkedList instanceList = new LinkedList();
        IOUtility.loadInstance(trainingFile, new InstanceHandler(){

            @Override
            public boolean process(Sentence sentence) {
                Utility.normalize(sentence);
                instanceList.add(PerceptronTrainer.this.createInstance(sentence, mutableFeatureMap));
                return false;
            }
        });
        Instance[] instances = new Instance[instanceList.size()];
        instanceList.toArray(instances);
        return instances;
    }

    private static DoubleArrayTrie<Integer> loadDictionary(String trainingFile, String dictionaryFile) throws IOException {
        FrequencyMap dictionaryMap = new FrequencyMap();
        if (dictionaryFile == null) {
            System.out.printf("\u4ece\u8bad\u7ec3\u6587\u4ef6%s\u4e2d\u7edf\u8ba1\u8bcd\u5e93...\n", trainingFile);
            PerceptronTrainer.loadWordFromFile(trainingFile, dictionaryMap, true);
        } else {
            System.out.printf("\u4ece\u5916\u90e8\u8bcd\u5178%s\u4e2d\u52a0\u8f7d\u8bcd\u5e93...\n", trainingFile);
            PerceptronTrainer.loadWordFromFile(dictionaryFile, dictionaryMap, false);
        }
        DoubleArrayTrie<Integer> dat = new DoubleArrayTrie<Integer>();
        dat.build(dictionaryMap);
        System.out.printf("\u52a0\u8f7d\u5b8c\u6bd5\uff0c\u8bcd\u5e93\u603b\u8bcd\u6570\uff1a%d\uff0c\u603b\u8bcd\u9891\uff1a%d\n", dictionaryMap.size(), dictionaryMap.totalFrequency);
        return dat;
    }

    public Result train(String trainingFile, String modelFile) throws IOException {
        return this.train(trainingFile, trainingFile, modelFile);
    }

    public Result train(String trainingFile, String developFile, String modelFile) throws IOException {
        return this.train(trainingFile, developFile, modelFile, 0.1, 10, Runtime.getRuntime().availableProcessors());
    }

    private static void loadWordFromFile(String path, FrequencyMap storage, boolean segmented) throws IOException {
        String line;
        BufferedReader br = IOUtility.newBufferedReader(path);
        while ((line = br.readLine()) != null) {
            if (segmented) {
                for (String word : IOUtility.readLineToArray(line)) {
                    storage.add(word);
                }
                continue;
            }
            if ((line = line.trim()).length() == 0) continue;
            storage.add(line);
        }
        br.close();
    }

    private static class TrainingWorker
    extends Thread {
        private Instance[] instances;
        private int start;
        private int end;
        private StructuredPerceptron model;

        public TrainingWorker(Instance[] instances, int start, int end, StructuredPerceptron model) {
            this.instances = instances;
            this.start = start;
            this.end = end;
            this.model = model;
        }

        @Override
        public void run() {
            for (int s = this.start; s < this.end; ++s) {
                Instance instance = this.instances[s];
                this.model.update(instance);
            }
        }
    }

    public static class Result {
        LinearModel model;
        double[] prf;

        public Result(LinearModel model, double[] prf) {
            this.model = model;
            this.prf = prf;
        }

        public double getAccuracy() {
            if (this.prf.length == 3) {
                return this.prf[2];
            }
            return this.prf[0];
        }

        public LinearModel getModel() {
            return this.model;
        }
    }
}

