/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.hmm;

import com.aliasi.hmm.HiddenMarkovModel;
import com.aliasi.hmm.TagWordLattice;
import com.aliasi.symbol.SymbolTable;
import com.aliasi.tag.MarginalTagger;
import com.aliasi.tag.NBestTagger;
import com.aliasi.tag.ScoredTagging;
import com.aliasi.tag.TagLattice;
import com.aliasi.tag.Tagger;
import com.aliasi.tag.Tagging;
import com.aliasi.util.BoundedPriorityQueue;
import com.aliasi.util.Iterators;
import com.aliasi.util.Scored;
import com.aliasi.util.ScoredObject;
import com.aliasi.util.Strings;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
public class HmmDecoder
implements Tagger<String>,
NBestTagger<String>,
MarginalTagger<String> {
    private final HiddenMarkovModel mHmm;
    private Map<String, double[]> mEmissionCache;
    private Map<String, double[]> mEmissionLog2Cache;
    private double mLog2EmissionBeam;
    private double mLog2Beam;

    public HmmDecoder(HiddenMarkovModel hmm) {
        this(hmm, null, null);
    }

    public HmmDecoder(HiddenMarkovModel hmm, Map<String, double[]> emissionCache, Map<String, double[]> emissionLog2Cache) {
        this(hmm, emissionCache, emissionLog2Cache, Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY);
    }

    public HmmDecoder(HiddenMarkovModel hmm, Map<String, double[]> emissionCache, Map<String, double[]> emissionLog2Cache, double log2Beam, double log2EmissionBeam) {
        this.mHmm = hmm;
        this.mEmissionCache = emissionCache;
        this.mEmissionLog2Cache = emissionLog2Cache;
        this.setLog2Beam(log2Beam);
        this.setLog2EmissionBeam(log2EmissionBeam);
    }

    public HiddenMarkovModel getHmm() {
        return this.mHmm;
    }

    public Map<String, double[]> emissionCache() {
        return this.mEmissionCache;
    }

    public Map<String, double[]> emissionLog2Cache() {
        return this.mEmissionLog2Cache;
    }

    public void setEmissionCache(Map<String, double[]> cache) {
        this.mEmissionCache = cache;
    }

    public void setLog2EmissionBeam(double log2EmissionBeam) {
        if (log2EmissionBeam <= 0.0 || Double.isNaN(log2EmissionBeam)) {
            String msg = "Beam width must be a positive number. Found log2EmissionBeam=" + log2EmissionBeam;
            throw new IllegalArgumentException(msg);
        }
        this.mLog2EmissionBeam = log2EmissionBeam;
    }

    public void setLog2Beam(double log2Beam) {
        if (log2Beam <= 0.0 || Double.isNaN(log2Beam)) {
            String msg = "Beam width must be a positive number. Found log2EmissionBeam=" + log2Beam;
            throw new IllegalArgumentException(msg);
        }
        this.mLog2Beam = log2Beam;
    }

    public void setEmissionLog2Cache(Map<String, double[]> cache) {
        this.mEmissionLog2Cache = cache;
    }

    double[] cachedEmitProbs(String emission) {
        double[] emitProbs = this.mEmissionCache.get(emission);
        if (emitProbs != null) {
            return emitProbs;
        }
        emitProbs = this.computeEmitProbs(emission);
        this.mEmissionCache.put(emission, emitProbs);
        return emitProbs;
    }

    double[] computeEmitProbs(String emission) {
        int numTags = this.mHmm.stateSymbolTable().numSymbols();
        double[] emitProbs = new double[numTags];
        for (int i = 0; i < numTags; ++i) {
            emitProbs[i] = this.mHmm.emitProb(i, (CharSequence)emission);
        }
        return emitProbs;
    }

    double[] emitProbs(String emission) {
        return this.mEmissionCache == null ? this.computeEmitProbs(emission) : this.cachedEmitProbs(emission);
    }

    double[] cachedEmitLog2Probs(String emission) {
        double[] emitLog2Probs = this.mEmissionLog2Cache.get(emission);
        if (emitLog2Probs != null) {
            return emitLog2Probs;
        }
        emitLog2Probs = this.computeEmitLog2Probs(emission);
        this.mEmissionLog2Cache.put(emission, emitLog2Probs);
        return emitLog2Probs;
    }

    double[] computeEmitLog2Probs(String emission) {
        int numTags = this.mHmm.stateSymbolTable().numSymbols();
        double[] emitLog2Probs = new double[numTags];
        for (int i = 0; i < numTags; ++i) {
            emitLog2Probs[i] = this.mHmm.emitLog2Prob(i, (CharSequence)emission);
        }
        HmmDecoder.additiveBeamPrune(emitLog2Probs, this.mLog2EmissionBeam);
        return emitLog2Probs;
    }

    static void additiveBeamPrune(double[] emitLog2Probs, double beam) {
        int i;
        if (beam == Double.POSITIVE_INFINITY) {
            return;
        }
        double best = emitLog2Probs[0];
        for (i = 1; i < emitLog2Probs.length; ++i) {
            if (!(emitLog2Probs[i] > best)) continue;
            best = emitLog2Probs[i];
        }
        for (i = 1; i < emitLog2Probs.length; ++i) {
            if (!(emitLog2Probs[i] + beam < best)) continue;
            emitLog2Probs[i] = Double.NEGATIVE_INFINITY;
        }
    }

    double[] emitLog2Probs(String emission) {
        return this.mEmissionLog2Cache == null ? this.computeEmitLog2Probs(emission) : this.cachedEmitLog2Probs(emission);
    }

    TagWordLattice lattice(String[] emissions) {
        int numTokens = emissions.length;
        int numTags = this.mHmm.stateSymbolTable().numSymbols();
        if (numTokens == 0) {
            return new TagWordLattice(emissions, this.mHmm.stateSymbolTable(), new double[numTags], new double[numTags], new double[0][numTags][numTags]);
        }
        double[] starts = new double[numTags];
        double[] emitProbs = this.emitProbs(emissions[0]);
        for (int tagId = 0; tagId < numTags; ++tagId) {
            starts[tagId] = this.mHmm.startProb(tagId) * emitProbs[tagId];
        }
        double[][][] transitions = new double[numTokens][][];
        for (int i = 1; i < numTokens; ++i) {
            double[][] transitionsI = new double[numTags][];
            transitions[i] = transitionsI;
            double[] emitProbs2 = this.emitProbs(emissions[i]);
            for (int prevTagId = 0; prevTagId < numTags; ++prevTagId) {
                double[] transitionsIPrevTag = new double[numTags];
                transitions[i][prevTagId] = transitionsIPrevTag;
                for (int tagId = 0; tagId < numTags; ++tagId) {
                    double transitEstimate = this.mHmm.transitProb(prevTagId, tagId);
                    transitionsIPrevTag[tagId] = transitEstimate * emitProbs2[tagId];
                }
            }
        }
        double[] ends = new double[numTags];
        for (int tagId = 0; tagId < numTags; ++tagId) {
            ends[tagId] = this.mHmm.endProb(tagId);
        }
        return new TagWordLattice(emissions, this.mHmm.stateSymbolTable(), starts, ends, transitions);
    }

    String[] firstBest(String[] emissions) {
        if (emissions.length == 0) {
            return Strings.EMPTY_STRING_ARRAY;
        }
        return new Viterbi(emissions).bestStates();
    }

    Iterator<ScoredObject<String[]>> nBest(String[] emissions) {
        if (emissions.length == 0) {
            ScoredObject<String[]> result = new ScoredObject<String[]>(Strings.EMPTY_STRING_ARRAY, 0.0);
            return Iterators.singleton(result);
        }
        Viterbi viterbiLattice = new Viterbi(emissions);
        return new NBestIterator(viterbiLattice, Integer.MAX_VALUE);
    }

    Iterator<ScoredObject<String[]>> nBest(String[] emissions, int maxN) {
        if (emissions.length == 0) {
            ScoredObject<String[]> result = new ScoredObject<String[]>(Strings.EMPTY_STRING_ARRAY, 0.0);
            return Iterators.singleton(result);
        }
        Viterbi viterbiLattice = new Viterbi(emissions);
        return new NBestIterator(viterbiLattice, maxN);
    }

    Iterator<ScoredObject<String[]>> nBestConditional(String[] emissions) {
        Iterator<ScoredObject<String[]>> nBestIterator = this.nBest(emissions);
        double jointLog2Prob = this.lattice(emissions).log2Total();
        return new JointIterator(nBestIterator, jointLog2Prob);
    }

    @Override
    public Tagging<String> tag(List<String> tokens) {
        String[] tokenArray = tokens.toArray(Strings.EMPTY_STRING_ARRAY);
        String[] tags = this.firstBest(tokenArray);
        return new Tagging<String>(Arrays.asList(tokenArray), Arrays.asList(tags));
    }

    @Override
    public Iterator<ScoredTagging<String>> tagNBest(List<String> tokens, int maxResults) {
        String[] tokenArray = tokens.toArray(Strings.EMPTY_STRING_ARRAY);
        Iterator<ScoredObject<String[]>> it = this.nBest(tokenArray, maxResults);
        return new TaggingIteratorAdapter(tokens, it, maxResults);
    }

    @Override
    public Iterator<ScoredTagging<String>> tagNBestConditional(List<String> tokens, int maxResults) {
        String[] tokenArray = tokens.toArray(Strings.EMPTY_STRING_ARRAY);
        Iterator<ScoredObject<String[]>> it = this.nBestConditional(tokenArray);
        return new TaggingIteratorAdapter(tokens, it, maxResults);
    }

    @Override
    public TagLattice<String> tagMarginal(List<String> tokens) {
        String[] tokenArray = tokens.toArray(Strings.EMPTY_STRING_ARRAY);
        return this.lattice(tokenArray);
    }

    void unprunedSources(double[] sources, int[] survivors, double beam) {
        double best = sources[0];
        for (int i = 0; i < sources.length; ++i) {
            if (!(sources[i] > best)) continue;
            best = sources[i];
        }
        int next = 0;
        for (int i = 0; i < sources.length; ++i) {
            if (!(sources[i] + beam >= best)) continue;
            survivors[next++] = i;
        }
        survivors[next] = -1;
    }

    /*
     * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
     */
    private static final class JointIterator
    extends Iterators.Modifier<ScoredObject<String[]>> {
        final double mLog2TotalProb;

        JointIterator(Iterator<ScoredObject<String[]>> nBestIterator, double log2TotalProb) {
            super(nBestIterator);
            this.mLog2TotalProb = log2TotalProb;
        }

        @Override
        public ScoredObject<String[]> modify(ScoredObject<String[]> jointObj) {
            String[] tags = jointObj.getObject();
            double log2JointProb = jointObj.score();
            double log2CondProb = log2JointProb - this.mLog2TotalProb;
            return new ScoredObject<String[]>(tags, log2CondProb);
        }
    }

    /*
     * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
     */
    private final class State
    implements Scored {
        private final double mScore;
        private final double mContScore;
        private final int mTagId;
        private final State mPreviousState;
        private final int mEmissionIndex;

        State(int emissionIndex, double score, double contScore, int tagId, State previousState) {
            this.mEmissionIndex = emissionIndex;
            this.mScore = score;
            this.mContScore = contScore;
            this.mTagId = tagId;
            this.mPreviousState = previousState;
        }

        public int emissionIndex() {
            return this.mEmissionIndex;
        }

        @Override
        public double score() {
            return this.mScore + this.mContScore;
        }

        ScoredObject<String[]> result(int numTags) {
            return new ScoredObject<String[]>(this.tags(numTags), this.score());
        }

        String[] tags(int numTags) {
            SymbolTable symTable = HmmDecoder.this.mHmm.stateSymbolTable();
            String[] tags = new String[numTags];
            State state = this;
            for (int i = 0; i < numTags; ++i) {
                tags[i] = symTable.idToSymbol(state.mTagId);
                state = state.mPreviousState;
            }
            return tags;
        }
    }

    /*
     * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
     */
    private class NBestIterator
    extends Iterators.Buffered<ScoredObject<String[]>> {
        private final Viterbi mViterbi;
        private final BoundedPriorityQueue<State> mPQ;

        NBestIterator(Viterbi vit, int maxSize) {
            this.mViterbi = vit;
            this.mPQ = new BoundedPriorityQueue(ScoredObject.comparator(), maxSize);
            String[] emissions = vit.mEmissions;
            int numStates = HmmDecoder.this.mHmm.stateSymbolTable().numSymbols();
            int numEmits = emissions.length;
            int lastEmitIndex = numEmits - 1;
            for (int tagId = 0; tagId < numStates; ++tagId) {
                double contScore = vit.mLattice[lastEmitIndex][tagId];
                if (!(contScore > Double.NEGATIVE_INFINITY)) continue;
                double score = 0.0;
                this.mPQ.offer(new State(lastEmitIndex, score, contScore, tagId, null));
            }
        }

        @Override
        public ScoredObject<String[]> bufferNext() {
            int numTags = HmmDecoder.this.mHmm.stateSymbolTable().numSymbols();
            int numEmissions = this.mViterbi.mEmissions.length;
            int lastEmitIndex = numEmissions - 1;
            while (!this.mPQ.isEmpty()) {
                State st = this.mPQ.poll();
                int emitIndex = st.emissionIndex();
                if (emitIndex == 0) {
                    this.mPQ.setMaxSize(this.mPQ.maxSize() - 1);
                    return st.result(numEmissions);
                }
                String emission = this.mViterbi.mEmissions[emitIndex];
                int emitTagId = st.mTagId;
                double score = st.mScore;
                if (emitIndex == lastEmitIndex) {
                    score += HmmDecoder.this.mHmm.endLog2Prob(emitTagId);
                }
                int emitIndexMinus1 = emitIndex - 1;
                double emitLog2Prob = HmmDecoder.this.mHmm.emitLog2Prob(emitTagId, (CharSequence)emission);
                for (int tagId = 0; tagId < numTags; ++tagId) {
                    double nextScore = score + HmmDecoder.this.mHmm.transitLog2Prob(tagId, emitTagId) + emitLog2Prob;
                    double contScore = this.mViterbi.mLattice[emitIndexMinus1][tagId];
                    if (!(nextScore > Double.NEGATIVE_INFINITY) || !(contScore > Double.NEGATIVE_INFINITY)) continue;
                    this.mPQ.offer(new State(emitIndexMinus1, nextScore, contScore, tagId, st));
                }
            }
            return null;
        }
    }

    private class Viterbi {
        private final String[] mEmissions;
        private final double[][] mLattice;
        private final int[][] mBackPts;

        Viterbi(String[] emissions) {
            this.mEmissions = emissions;
            HiddenMarkovModel hmm = HmmDecoder.this.mHmm;
            int numStates = hmm.stateSymbolTable().numSymbols();
            int numEmits = emissions.length;
            double[][] lattice = new double[numEmits][numStates];
            this.mLattice = lattice;
            int[][] backPts = new int[numEmits][numStates];
            this.mBackPts = backPts;
            if (emissions.length == 0) {
                return;
            }
            double[] emitLog2Probs = HmmDecoder.this.emitLog2Probs(emissions[0]);
            for (int stateId = 0; stateId < numStates; ++stateId) {
                lattice[0][stateId] = emitLog2Probs[stateId] + hmm.startLog2Prob(stateId);
            }
            int[] unprunedSources = new int[numStates + 1];
            for (int i = 1; i < numEmits; ++i) {
                double[] lastSlice = lattice[i - 1];
                HmmDecoder.this.unprunedSources(lastSlice, unprunedSources, HmmDecoder.this.mLog2Beam);
                double[] emitLog2Probs2 = HmmDecoder.this.emitLog2Probs(emissions[i]);
                for (int targetId = 0; targetId < numStates; ++targetId) {
                    if (Double.NEGATIVE_INFINITY != emitLog2Probs2[targetId]) {
                        double best = Double.NEGATIVE_INFINITY;
                        int bk = 0;
                        int next = 0;
                        while (unprunedSources[next] != -1) {
                            int sourceId = unprunedSources[next];
                            double est = lastSlice[sourceId] + hmm.transitLog2Prob(sourceId, targetId);
                            if (est > best) {
                                best = est;
                                bk = sourceId;
                            }
                            ++next;
                        }
                        lattice[i][targetId] = best + emitLog2Probs2[targetId];
                        backPts[i][targetId] = bk;
                        continue;
                    }
                    lattice[i][targetId] = Double.NEGATIVE_INFINITY;
                    backPts[i][targetId] = 0;
                }
            }
            double[] lastColumn = lattice[numEmits - 1];
            for (int i = 0; i < numStates; ++i) {
                int n = i;
                lastColumn[n] = lastColumn[n] + hmm.endLog2Prob(i);
            }
        }

        String[] bestStates() {
            int i;
            HiddenMarkovModel hmm = HmmDecoder.this.mHmm;
            int numStates = hmm.stateSymbolTable().numSymbols();
            int numEmits = this.mEmissions.length;
            if (numEmits == 0) {
                return Strings.EMPTY_STRING_ARRAY;
            }
            int[][] backPts = this.mBackPts;
            double[][] lattice = this.mLattice;
            int[] bestStateIds = new int[numEmits];
            int bestStateId = 0;
            double[] lastCol = lattice[numEmits - 1];
            for (i = 1; i < numStates; ++i) {
                if (!(lastCol[i] > lastCol[bestStateId])) continue;
                bestStateId = i;
            }
            bestStateIds[numEmits - 1] = bestStateId;
            i = numEmits;
            while (--i > 0) {
                bestStateIds[i - 1] = backPts[i][bestStateIds[i]];
            }
            String[] bestStates = new String[numEmits];
            SymbolTable st = hmm.stateSymbolTable();
            for (int i2 = 0; i2 < bestStates.length; ++i2) {
                bestStates[i2] = st.idToSymbol(bestStateIds[i2]);
            }
            return bestStates;
        }
    }

    /*
     * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
     */
    static class TaggingIteratorAdapter
    implements Iterator<ScoredTagging<String>> {
        private final Iterator<ScoredObject<String[]>> mIt;
        private final List<String> mTokens;
        private final int mMaxResults;
        private int mResults = 0;

        TaggingIteratorAdapter(List<String> tokens, Iterator<ScoredObject<String[]>> it, int maxResults) {
            this.mTokens = tokens;
            this.mIt = it;
            this.mMaxResults = maxResults;
        }

        @Override
        public ScoredTagging<String> next() {
            ScoredObject<String[]> so = this.mIt.next();
            double score = so.score();
            String[] tags = so.getObject();
            List<String> tagList = Arrays.asList(tags);
            ++this.mResults;
            return new ScoredTagging<String>(this.mTokens, tagList, score);
        }

        @Override
        public boolean hasNext() {
            return this.mResults < this.mMaxResults && this.mIt.hasNext();
        }

        @Override
        public void remove() {
            this.mIt.remove();
        }
    }
}

