/*
 * Decompiled with CFR 0.152.
 */
package com.hankcs.hanlp.model.maxent;

import com.hankcs.hanlp.collection.dartsclone.Pair;
import com.hankcs.hanlp.collection.trie.DoubleArrayTrie;
import com.hankcs.hanlp.corpus.io.ByteArray;
import com.hankcs.hanlp.model.maxent.Context;
import com.hankcs.hanlp.model.maxent.EvalParameters;
import com.hankcs.hanlp.model.maxent.UniformPrior;
import com.hankcs.hanlp.utility.Predefine;
import com.hankcs.hanlp.utility.TextUtility;
import java.io.BufferedReader;
import java.io.Closeable;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.StringTokenizer;
import java.util.TreeMap;

/**
 * A maximum entropy classifier that maps a set of string predicates (the
 * "context") to a probability for every known outcome.
 *
 * <p>A model is obtained through {@link #load(String)}, which prefers the binary
 * cache {@code <path>.bin} and falls back to parsing the text model at
 * {@code <path>} (writing the cache as a side effect).
 */
public class MaxEntModel {
    /** Correction constant read from the model file; forwarded to {@link EvalParameters}. */
    int correctionConstant;
    /** Correction parameter read from the model file; forwarded to {@link EvalParameters}. */
    double correctionParam;
    /** Prior used to initialise the per-outcome sums before predicate weights are added. */
    UniformPrior prior;
    /** Outcome labels; index {@code i} names the outcome whose probability is {@code eval(..)[i]}. */
    protected String[] outcomeNames;
    /** Per-predicate parameters plus the correction settings used by {@link #eval(int[], double[], EvalParameters)}. */
    EvalParameters evalParams;
    /** Maps a predicate string to its index into the parameter array. */
    DoubleArrayTrie<Integer> pmap;

    /**
     * Evaluates a context and returns one probability per outcome.
     *
     * @param context predicate strings describing the event to classify
     * @return a freshly allocated array of probabilities, indexed like {@link #outcomeNames}
     */
    public final double[] eval(String[] context) {
        double[] outsums = new double[this.evalParams.getNumOutcomes()];
        return this.eval(context, outsums);
    }

    /**
     * Evaluates a context and pairs every outcome label with its probability.
     *
     * @param context predicate strings describing the event to classify
     * @return one (label, probability) pair per outcome, in outcome-index order
     */
    public final List<Pair<String, Double>> predict(String[] context) {
        double[] probabilities = this.eval(context);
        ArrayList<Pair<String, Double>> result = new ArrayList<Pair<String, Double>>(this.outcomeNames.length);
        for (int i = 0; i < probabilities.length; ++i) {
            result.add(new Pair<String, Double>(this.outcomeNames[i], probabilities[i]));
        }
        return result;
    }

    /**
     * Evaluates a context and returns the most probable outcome.
     *
     * @param context predicate strings describing the event to classify
     * @return the first pair with the strictly greatest probability, or {@code null}
     *         when no probability compares greater than {@code -1.0} (for example
     *         when there are no outcomes or every probability is NaN)
     */
    public final Pair<String, Double> predictBest(String[] context) {
        Pair<String, Double> best = null;
        double bestP = -1.0;
        for (Pair<String, Double> candidate : this.predict(context)) {
            double p = candidate.getSecond();
            if (p > bestP) {
                bestP = p;
                best = candidate;
            }
        }
        return best;
    }

    /**
     * Collection-based convenience overload of {@link #predict(String[])}.
     *
     * @param context predicate strings describing the event to classify
     * @return one (label, probability) pair per outcome, in outcome-index order
     */
    public final List<Pair<String, Double>> predict(Collection<String> context) {
        String[] asArray = context.toArray(new String[0]);
        return this.predict(asArray);
    }

    /**
     * Evaluates a context, writing the result into a caller-supplied buffer.
     *
     * <p>Predicates that are not present in {@link #pmap} are ignored.
     *
     * @param context predicate strings describing the event to classify; must not be {@code null}
     * @param outsums buffer with one slot per outcome; it is handed to the prior and
     *                then overwritten with the resulting probabilities
     * @return {@code outsums}, now holding the normalised probabilities
     */
    public final double[] eval(String[] context, double[] outsums) {
        assert (context != null);
        int[] predicateIds = new int[context.length];
        for (int i = 0; i < context.length; ++i) {
            Integer id = this.pmap.get(context[i]);
            // -1 marks a predicate the model has never seen; the static eval skips it.
            predicateIds[i] = id == null ? -1 : id;
        }
        this.prior.logPrior(outsums);
        return MaxEntModel.eval(predicateIds, outsums, this.evalParams);
    }

    /**
     * Core evaluation routine: accumulates predicate weights onto the prior,
     * exponentiates, and normalises so the entries sum to one.
     *
     * @param context predicate indices into {@code model.getParams()}; negative entries are skipped
     * @param prior   per-outcome starting sums; modified in place and returned
     * @param model   parameters and correction settings to evaluate with
     * @return {@code prior}, now holding the normalised probabilities
     */
    public static double[] eval(int[] context, double[] prior, EvalParameters model) {
        Context[] params = model.getParams();
        int numOutcomes = model.getNumOutcomes();
        // Number of active predicates seen per outcome; feeds the correction term below.
        int[] numfeats = new int[numOutcomes];
        for (int predicate : context) {
            if (predicate < 0) {
                continue;
            }
            Context predParams = params[predicate];
            int[] activeOutcomes = predParams.getOutcomes();
            double[] activeParameters = predParams.getParameters();
            for (int ai = 0; ai < activeOutcomes.length; ++ai) {
                int oid = activeOutcomes[ai];
                numfeats[oid]++;
                prior[oid] += activeParameters[ai];
            }
        }

        double correctionParam = model.getCorrectionParam();
        double correctionConstant = model.getCorrectionConstant();
        double constantInverse = model.getConstantInverse();
        boolean useCorrection = correctionParam != 0.0;

        double normal = 0.0;
        for (int oid = 0; oid < numOutcomes; ++oid) {
            double exponent = prior[oid] * constantInverse;
            if (useCorrection) {
                exponent += (1.0 - (double) numfeats[oid] / correctionConstant) * correctionParam;
            }
            prior[oid] = Math.exp(exponent);
            normal += prior[oid];
        }
        for (int oid = 0; oid < numOutcomes; ++oid) {
            prior[oid] /= normal;
        }
        return prior;
    }

    /**
     * Parses a text model and, while doing so, writes the binary cache
     * {@code path + ".bin"} that {@link #create(ByteArray)} can read back.
     *
     * <p>The text file is read line by line in this order: one ignored line, the
     * correction constant, the correction parameter, the number of outcomes and
     * their labels, the number of outcome patterns and one space-separated integer
     * line per pattern (first integer: how many predicates share the pattern;
     * remaining integers: the outcome ids), the number of predicates and their
     * labels, and finally one parameter value per line.
     *
     * <p>Both streams are always closed. If anything fails after the cache file has
     * been opened, the incomplete cache is deleted so that a later
     * {@link #load(String)} does not pick up a truncated file.
     *
     * @param path location of the UTF-8 text model
     * @return the loaded model, or {@code null} if reading or writing failed (the
     *         failure is logged at severe level)
     */
    public static MaxEntModel create(String path) {
        MaxEntModel m = new MaxEntModel();
        String binPath = path + ".bin";
        BufferedReader br = null;
        DataOutputStream out = null;
        boolean cacheComplete = false;
        try {
            br = new BufferedReader(new InputStreamReader((InputStream) new FileInputStream(path), "UTF-8"));
            out = new DataOutputStream(new FileOutputStream(binPath));

            // The first line is read and discarded.
            br.readLine();
            m.correctionConstant = Integer.parseInt(br.readLine());
            out.writeInt(m.correctionConstant);
            m.correctionParam = Double.parseDouble(br.readLine());
            out.writeDouble(m.correctionParam);

            int numOutcomes = Integer.parseInt(br.readLine());
            out.writeInt(numOutcomes);
            String[] outcomeLabels = new String[numOutcomes];
            m.outcomeNames = outcomeLabels;
            for (int i = 0; i < numOutcomes; ++i) {
                outcomeLabels[i] = br.readLine();
                TextUtility.writeString(outcomeLabels[i], out);
            }

            int numOCTypes = Integer.parseInt(br.readLine());
            out.writeInt(numOCTypes);
            int[][] outcomePatterns = new int[numOCTypes][];
            for (int i = 0; i < numOCTypes; ++i) {
                StringTokenizer tok = new StringTokenizer(br.readLine(), " ");
                int[] infoInts = new int[tok.countTokens()];
                out.writeInt(infoInts.length);
                for (int j = 0; tok.hasMoreTokens(); ++j) {
                    infoInts[j] = Integer.parseInt(tok.nextToken());
                    out.writeInt(infoInts[j]);
                }
                outcomePatterns[i] = infoInts;
            }

            int numPreds = Integer.parseInt(br.readLine());
            out.writeInt(numPreds);
            m.pmap = new DoubleArrayTrie<Integer>();
            // Sorted map: the trie is built from it and the ids are written in key order.
            TreeMap<String, Integer> tmpMap = new TreeMap<String, Integer>();
            for (int i = 0; i < numPreds; ++i) {
                String label = br.readLine();
                // Message reads: "Duplicate key: <label>, please train with -Dfile.encoding=UTF-8".
                assert (!tmpMap.containsKey(label)) : "\u91cd\u590d\u7684\u952e\uff1a " + label + " \u8bf7\u4f7f\u7528 -Dfile.encoding=UTF-8 \u8bad\u7ec3";
                TextUtility.writeString(label, out);
                tmpMap.put(label, i);
            }
            m.pmap.build(tmpMap);
            for (Map.Entry<String, Integer> entry : tmpMap.entrySet()) {
                out.writeInt(entry.getValue());
            }
            m.pmap.save(out);

            Context[] params = new Context[numPreds];
            int pid = 0;
            for (int[] pattern : outcomePatterns) {
                // One outcome-id array is shared by every predicate of this pattern.
                int[] outcomeIds = outcomeIdsOf(pattern);
                for (int j = 0; j < pattern[0]; ++j) {
                    double[] contextParameters = new double[outcomeIds.length];
                    for (int k = 0; k < contextParameters.length; ++k) {
                        contextParameters[k] = Double.parseDouble(br.readLine());
                        out.writeDouble(contextParameters[k]);
                    }
                    params[pid++] = new Context(outcomeIds, contextParameters);
                }
            }
            installParameters(m, outcomeLabels, params);

            // Close inside the try so that a failing close is reported like any other error.
            out.close();
            cacheComplete = true;
        }
        catch (Exception e) {
            // Message reads: "Failed to load the maximum entropy model from <path>!".
            Predefine.logger.severe("\u4ece" + path + "\u52a0\u8f7d\u6700\u5927\u71b5\u6a21\u578b\u5931\u8d25\uff01" + TextUtility.exceptionToString(e));
            return null;
        }
        finally {
            closeQuietly(br);
            closeQuietly(out);
            if (out != null && !cacheComplete) {
                // Best effort: the cache was opened (and therefore truncated) by this
                // call but not finished, so remove it rather than leave a corrupt file.
                new File(binPath).delete();
            }
        }
        return m;
    }

    /**
     * Reconstructs a model from the binary form written by {@link #create(String)}.
     *
     * @param byteArray cursor positioned at the start of the serialized model
     * @return the loaded model
     */
    public static MaxEntModel create(ByteArray byteArray) {
        MaxEntModel m = new MaxEntModel();
        m.correctionConstant = byteArray.nextInt();
        m.correctionParam = byteArray.nextDouble();

        int numOutcomes = byteArray.nextInt();
        String[] outcomeLabels = new String[numOutcomes];
        m.outcomeNames = outcomeLabels;
        for (int i = 0; i < numOutcomes; ++i) {
            outcomeLabels[i] = byteArray.nextString();
        }

        int numOCTypes = byteArray.nextInt();
        int[][] outcomePatterns = new int[numOCTypes][];
        for (int i = 0; i < numOCTypes; ++i) {
            int[] infoInts = new int[byteArray.nextInt()];
            for (int j = 0; j < infoInts.length; ++j) {
                infoInts[j] = byteArray.nextInt();
            }
            outcomePatterns[i] = infoInts;
        }

        int numPreds = byteArray.nextInt();
        m.pmap = new DoubleArrayTrie<Integer>();
        for (int i = 0; i < numPreds; ++i) {
            // The labels are not kept, but they must still be consumed to advance
            // the cursor to the id table that follows them.
            byteArray.nextString();
        }
        Integer[] v = new Integer[numPreds];
        for (int i = 0; i < v.length; ++i) {
            v[i] = byteArray.nextInt();
        }
        m.pmap.load(byteArray, v);

        Context[] params = new Context[numPreds];
        int pid = 0;
        for (int[] pattern : outcomePatterns) {
            // One outcome-id array is shared by every predicate of this pattern.
            int[] outcomeIds = outcomeIdsOf(pattern);
            for (int j = 0; j < pattern[0]; ++j) {
                double[] contextParameters = new double[outcomeIds.length];
                for (int k = 0; k < contextParameters.length; ++k) {
                    contextParameters[k] = byteArray.nextDouble();
                }
                params[pid++] = new Context(outcomeIds, contextParameters);
            }
        }
        installParameters(m, outcomeLabels, params);
        return m;
    }

    /**
     * Loads a model, preferring the binary cache next to the text model.
     *
     * @param txtPath location of the text model; the cache is looked up at {@code txtPath + ".bin"}
     * @return the loaded model, or {@code null} if the cache is unavailable and the
     *         text model could not be parsed
     */
    public static MaxEntModel load(String txtPath) {
        ByteArray byteArray = ByteArray.createByteArray(txtPath + ".bin");
        if (byteArray == null) {
            return MaxEntModel.create(txtPath);
        }
        return MaxEntModel.create(byteArray);
    }

    /** Returns a copy of {@code pattern} without its leading count element. */
    private static int[] outcomeIdsOf(int[] pattern) {
        int[] ids = new int[pattern.length - 1];
        System.arraycopy(pattern, 1, ids, 0, ids.length);
        return ids;
    }

    /** Attaches the prior and the evaluation parameters to a model whose scalar fields are already set. */
    private static void installParameters(MaxEntModel m, String[] outcomeLabels, Context[] params) {
        m.prior = new UniformPrior();
        m.prior.setLabels(outcomeLabels);
        m.evalParams = new EvalParameters(params, m.correctionParam, m.correctionConstant, outcomeLabels.length);
    }

    /** Closes {@code resource} if it is non-null, ignoring any failure to close. */
    private static void closeQuietly(Closeable resource) {
        if (resource == null) {
            return;
        }
        try {
            resource.close();
        }
        catch (IOException ignored) {
            // Nothing useful can be done here; the primary error (if any) was already logged.
        }
    }
}

