package com.aliasi.classify;

import com.aliasi.corpus.Corpus;
import com.aliasi.corpus.ObjectHandler;
import com.aliasi.io.Reporter;
import com.aliasi.io.Reporters;
import com.aliasi.stats.Statistics;
import com.aliasi.tokenizer.TokenizerFactory;
import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.Compilable;
import com.aliasi.util.Counter;
import com.aliasi.util.Exceptions;
import com.aliasi.util.Factory;
import com.aliasi.util.Iterators;
import com.aliasi.util.Math;
import com.aliasi.util.ObjectToCounterMap;
import com.aliasi.util.Strings;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.io.ObjectStreamException;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Collections;
import java.util.Formatter;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:com/aliasi/classify/TradNaiveBayesClassifier.class */
public class TradNaiveBayesClassifier implements JointClassifier<CharSequence>, ObjectHandler<Classified<CharSequence>>, Serializable, Compilable {
    static final long serialVersionUID = -300327951207213311L;
    private final Set<String> mCategorySet;
    private final String[] mCategories;
    private final TokenizerFactory mTokenizerFactory;
    private final double mCategoryPrior;
    private final double mTokenInCategoryPrior;
    private Map<String, double[]> mTokenToCountsMap;
    private double[] mTotalCountsPerCategory;
    private double[] mCaseCounts;
    private double mTotalCaseCount;
    private double mLengthNorm;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:com/aliasi/classify/TradNaiveBayesClassifier$CaseProbAccumulator.class */
    /**
     * Handler that sums the log (base 2) case probability of every character
     * sequence it is given, as computed by the wrapped classifier's
     * {@code log2CaseProb} method.
     */
    public static class CaseProbAccumulator implements ObjectHandler<CharSequence> {
        /** Running total of the log2 case probabilities handled so far. */
        double mCaseProb = 0.0d;
        /** Classifier used to score each handled sequence. */
        final TradNaiveBayesClassifier mClassifier;

        /**
         * Creates an accumulator that scores inputs with the given classifier.
         *
         * @param tradNaiveBayesClassifier classifier supplying {@code log2CaseProb}
         */
        CaseProbAccumulator(TradNaiveBayesClassifier tradNaiveBayesClassifier) {
            this.mClassifier = tradNaiveBayesClassifier;
        }

        /**
         * Adds the log2 case probability of the given sequence to the total.
         *
         * @param charSequence text to score
         */
        @Override
        public void handle(CharSequence charSequence) {
            this.mCaseProb += this.mClassifier.log2CaseProb(charSequence);
        }

        /**
         * Returns an adapter that accepts classified sequences, discards the
         * label, and forwards the underlying text to this accumulator.
         *
         * @return handler for labeled data backed by this accumulator
         */
        public ObjectHandler<Classified<CharSequence>> supHandler() {
            return new ObjectHandler<Classified<CharSequence>>() {
                @Override
                public void handle(Classified<CharSequence> classified) {
                    // Must be qualified: a bare this.handle(...) would target the
                    // anonymous handler, which has no handle(CharSequence) overload.
                    CaseProbAccumulator.this.handle(classified.getObject());
                }
            };
        }
    }

    /* loaded from: input_file:com/aliasi/classify/TradNaiveBayesClassifier$CompiledBinaryTradNaiveBayesClassifier.class */
    /**
     * Compiled two-category classifier. Instead of keeping one log probability
     * per category it stores, per token, the difference between the two
     * categories' log probabilities, and turns the summed difference into a
     * pair of probabilities with the logistic function.
     */
    private static class CompiledBinaryTradNaiveBayesClassifier implements JointClassifier<CharSequence> {
        private final TokenizerFactory mTokenizerFactory;
        /** Per-token (category 0 minus category 1) log probability difference, rescaled by 1/LOG2_E. */
        private final Map<String, Double> mTokenToLog2ProbDiff = new HashMap<>();
        /** Category-prior log probability difference, rescaled by 1/LOG2_E. */
        private final double mLog2CatProbDiff;
        /** Length normalization target, or NaN when normalization is disabled. */
        private final double mLengthNorm;
        /** Categories in the order (0, 1), used when category 0 wins. */
        private final String[] mCats01;
        /** Categories in the order (1, 0), used when category 1 wins. */
        private final String[] mCats10;

        /**
         * Builds the difference tables from per-category log2 probabilities.
         *
         * @param strArr the two category names
         * @param tokenizerFactory tokenizer used at classification time
         * @param map token to per-category log2 probabilities
         * @param dArr per-category log2 prior probabilities
         * @param d length normalization value, or NaN for none
         */
        CompiledBinaryTradNaiveBayesClassifier(String[] strArr, TokenizerFactory tokenizerFactory, Map<String, double[]> map, double[] dArr, double d) {
            this.mTokenizerFactory = tokenizerFactory;
            for (Map.Entry<String, double[]> entry : map.entrySet()) {
                double[] log2Probs = entry.getValue();
                // Dividing by LOG2_E presumably converts base-2 logs to natural
                // logs so that exp() can be applied later -- confirm against
                // com.aliasi.util.Math.
                this.mTokenToLog2ProbDiff.put(entry.getKey(), Double.valueOf((log2Probs[0] - log2Probs[1]) / Math.LOG2_E));
            }
            this.mLog2CatProbDiff = (dArr[0] - dArr[1]) / Math.LOG2_E;
            this.mLengthNorm = d;
            this.mCats01 = new String[]{strArr[0], strArr[1]};
            this.mCats10 = new String[]{strArr[1], strArr[0]};
        }

        /**
         * Classifies the text by summing the stored differences of its known
         * tokens, optionally length-normalizing, and adding the prior difference.
         *
         * @param charSequence text to classify
         * @return joint classification over the two categories
         */
        @Override
        public JointClassification classify(CharSequence charSequence) {
            double logOdds = 0.0d;
            char[] chars = Strings.toCharArray(charSequence);
            int tokenCount = 0;
            for (String token : this.mTokenizerFactory.tokenizer(chars, 0, chars.length)) {
                // Unknown tokens still count towards the length used for normalization.
                tokenCount++;
                Double diff = this.mTokenToLog2ProbDiff.get(token);
                if (diff != null) {
                    logOdds += diff.doubleValue();
                }
            }
            if (!Double.isNaN(this.mLengthNorm) && tokenCount > 0) {
                logOdds *= this.mLengthNorm / tokenCount;
            }
            return classification(logOdds + this.mLog2CatProbDiff);
        }

        /**
         * Converts log odds in favour of category 0 into a ranked classification.
         *
         * @param d log odds of category 0 over category 1
         * @return classification with the more probable category first
         */
        JointClassification classification(double d) {
            // java.lang.Math is spelled out because this file imports
            // com.aliasi.util.Math, which shadows the simple name Math.
            double odds = java.lang.Math.exp(d);
            // exp() overflows to +Infinity for large log odds; Infinity/(1+Infinity)
            // would be NaN, so saturate the probability at 1 instead.
            double prob0 = Double.isInfinite(odds) ? 1.0d : odds / (1.0d + odds);
            double prob1 = 1.0d - prob0;
            double log2Prob0 = Math.log2(prob0);
            double log2Prob1 = Math.log2(prob1);
            return prob0 > prob1
                    ? new JointClassification(this.mCats01, new double[]{log2Prob0, log2Prob1})
                    : new JointClassification(this.mCats10, new double[]{log2Prob1, log2Prob0});
        }
    }

    /* loaded from: input_file:com/aliasi/classify/TradNaiveBayesClassifier$CompiledTradNaiveBayesClassifier.class */
    /**
     * Compiled multi-category classifier that works directly from
     * precomputed log (base 2) probabilities.
     */
    private static class CompiledTradNaiveBayesClassifier implements JointClassifier<CharSequence> {
        private final TokenizerFactory mTokenizerFactory;
        private final String[] mCategories;
        /** Token to per-category log2 probability, indexed like {@code mCategories}. */
        private final Map<String, double[]> mTokenToLog2ProbsInCats;
        /** Per-category log2 prior probability. */
        private final double[] mLog2CatProbs;
        /** Length normalization target, or NaN when normalization is disabled. */
        private final double mLengthNorm;

        /**
         * Stores the supplied tables without copying them.
         *
         * @param strArr category names
         * @param tokenizerFactory tokenizer used at classification time
         * @param map token to per-category log2 probabilities
         * @param dArr per-category log2 prior probabilities
         * @param d length normalization value, or NaN for none
         */
        CompiledTradNaiveBayesClassifier(String[] strArr, TokenizerFactory tokenizerFactory, Map<String, double[]> map, double[] dArr, double d) {
            this.mTokenizerFactory = tokenizerFactory;
            this.mCategories = strArr;
            this.mLog2CatProbs = dArr;
            this.mTokenToLog2ProbsInCats = map;
            this.mLengthNorm = d;
        }

        /**
         * Sums the stored log2 probabilities of the known tokens per category,
         * optionally rescales by length norm over token count, and adds the
         * category priors.
         *
         * @param charSequence text to classify
         * @return joint classification built from the per-category scores
         */
        @Override
        public JointClassification classify(CharSequence charSequence) {
            final double[] scores = new double[this.mCategories.length];
            final char[] chars = Strings.toCharArray(charSequence);
            int tokenCount = 0;
            for (String token : this.mTokenizerFactory.tokenizer(chars, 0, chars.length)) {
                double[] tokenLog2Probs = this.mTokenToLog2ProbsInCats.get(token);
                tokenCount++;
                if (tokenLog2Probs == null) {
                    continue; // unknown token: counted for length, no score contribution
                }
                for (int cat = 0; cat < scores.length; cat++) {
                    scores[cat] += tokenLog2Probs[cat];
                }
            }
            if (tokenCount > 0 && !Double.isNaN(this.mLengthNorm)) {
                for (int cat = 0; cat < scores.length; cat++) {
                    scores[cat] *= this.mLengthNorm / tokenCount;
                }
            }
            for (int cat = 0; cat < scores.length; cat++) {
                scores[cat] += this.mLog2CatProbs[cat];
            }
            return JointClassification.create(this.mCategories, scores);
        }
    }

    /* loaded from: input_file:com/aliasi/classify/TradNaiveBayesClassifier$Compiler.class */
    /**
     * Externalizer that writes a classifier in compiled form (log2
     * probabilities rather than counts) and reads it back as one of the
     * compiled classifier implementations.
     */
    static class Compiler extends AbstractExternalizable {
        static final long serialVersionUID = 5689464666886334529L;
        private final TradNaiveBayesClassifier mClassifier;

        /** No-argument constructor; leaves the classifier reference null. */
        public Compiler() {
            this(null);
        }

        /**
         * @param tradNaiveBayesClassifier classifier to compile on write
         */
        public Compiler(TradNaiveBayesClassifier tradNaiveBayesClassifier) {
            this.mClassifier = tradNaiveBayesClassifier;
        }

        /**
         * Writes categories, tokenizer factory, per-token log2 probabilities,
         * per-category log2 priors and the length norm, in that order.
         *
         * @param objectOutput destination stream
         * @throws IOException if the underlying stream fails
         * @throws IllegalArgumentException if a token's log2 probability is positive
         */
        @Override
        public void writeExternal(ObjectOutput objectOutput) throws IOException {
            final TradNaiveBayesClassifier classifier = this.mClassifier;
            final String[] categories = classifier.mCategories;

            objectOutput.writeInt(categories.length);
            for (String category : categories) {
                objectOutput.writeUTF(category);
            }
            AbstractExternalizable.compileOrSerialize(classifier.mTokenizerFactory, objectOutput);

            objectOutput.writeInt(classifier.mTokenToCountsMap.size());
            for (Map.Entry<String, double[]> entry : classifier.mTokenToCountsMap.entrySet()) {
                final String token = entry.getKey();
                final double[] counts = entry.getValue();
                objectOutput.writeUTF(token);
                for (int cat = 0; cat < categories.length; cat++) {
                    double log2Prob = Math.log2(classifier.probTokenByIndexArray(cat, counts));
                    if (log2Prob > 0.0d) {
                        // A positive log means a probability above 1; dump the inputs.
                        throw new IllegalArgumentException("key=" + token + " i=" + cat + " log2Prob=" + log2Prob + " prob=" + classifier.probTokenByIndexArray(cat, counts) + " token counts[" + cat + "]=" + counts[cat] + " totalCatCount=" + classifier.mTotalCountsPerCategory[cat] + " mTokenToCountsMap.size()=" + classifier.mTokenToCountsMap.size());
                    }
                    objectOutput.writeDouble(log2Prob);
                }
            }

            for (int cat = 0; cat < categories.length; cat++) {
                objectOutput.writeDouble(Math.log2(classifier.probCatByIndex(cat)));
            }
            objectOutput.writeDouble(classifier.mLengthNorm);
        }

        /**
         * Reads the format produced by {@code writeExternal} and builds a
         * compiled classifier: the binary variant when there are exactly two
         * categories, the general variant otherwise.
         *
         * @param objectInput source stream
         * @return the compiled classifier
         * @throws ClassNotFoundException if the tokenizer factory class is missing
         * @throws IOException if the underlying stream fails
         */
        @Override
        public Object read(ObjectInput objectInput) throws ClassNotFoundException, IOException {
            final int numCategories = objectInput.readInt();
            final String[] categories = new String[numCategories];
            for (int cat = 0; cat < numCategories; cat++) {
                categories[cat] = objectInput.readUTF();
            }
            final TokenizerFactory tokenizerFactory = (TokenizerFactory) objectInput.readObject();

            final int numTokens = objectInput.readInt();
            final Map<String, double[]> tokenToLog2Probs = new HashMap<String, double[]>((numTokens * 3) / 2);
            for (int t = 0; t < numTokens; t++) {
                String token = objectInput.readUTF();
                double[] log2Probs = new double[numCategories];
                for (int cat = 0; cat < numCategories; cat++) {
                    log2Probs[cat] = objectInput.readDouble();
                }
                tokenToLog2Probs.put(token, log2Probs);
            }

            final double[] log2CatProbs = new double[numCategories];
            for (int cat = 0; cat < numCategories; cat++) {
                log2CatProbs[cat] = objectInput.readDouble();
            }
            final double lengthNorm = objectInput.readDouble();

            if (categories.length == 2) {
                return new CompiledBinaryTradNaiveBayesClassifier(categories, tokenizerFactory, tokenToLog2Probs, log2CatProbs, lengthNorm);
            }
            return new CompiledTradNaiveBayesClassifier(categories, tokenizerFactory, tokenToLog2Probs, log2CatProbs, lengthNorm);
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:com/aliasi/classify/TradNaiveBayesClassifier$EmIterator.class */
    /**
     * Iterator over successive EM epochs. Each call to {@link #bufferNext()}
     * trains a fresh classifier on the labeled data plus the unlabeled data as
     * labeled by the compiled classifier from the previous epoch.
     */
    public static class EmIterator extends Iterators.Buffered<TradNaiveBayesClassifier> {
        private final Factory<TradNaiveBayesClassifier> mClassifierFactory;
        private final Corpus<ObjectHandler<Classified<CharSequence>>> mLabeledData;
        private final Corpus<ObjectHandler<CharSequence>> mUnlabeledData;
        /** Minimum expected count passed to {@code trainConditional}. */
        private final double mMinTokenCount;
        /** Compiled form of the most recently trained classifier. */
        private JointClassifier<CharSequence> mLastClassifier;

        /**
         * Trains the initial classifier on the labeled data and compiles it so
         * the first epoch has something to label the unlabeled data with.
         *
         * @param tradNaiveBayesClassifier initial classifier (trained in place)
         * @param factory source of a fresh classifier for each epoch
         * @param corpus labeled training data
         * @param corpus2 unlabeled training data
         * @param d minimum token count for conditional training
         */
        EmIterator(TradNaiveBayesClassifier tradNaiveBayesClassifier, Factory<TradNaiveBayesClassifier> factory, Corpus<ObjectHandler<Classified<CharSequence>>> corpus, Corpus<ObjectHandler<CharSequence>> corpus2, double d) {
            this.mClassifierFactory = factory;
            this.mLabeledData = corpus;
            this.mUnlabeledData = corpus2;
            this.mMinTokenCount = d;
            trainSup(corpus, tradNaiveBayesClassifier);
            compile(tradNaiveBayesClassifier);
        }

        /**
         * Runs one EM epoch.
         *
         * @return the classifier trained in this epoch
         * @throws IllegalStateException if training or compilation fails with an I/O error
         */
        @Override
        public TradNaiveBayesClassifier bufferNext() {
            TradNaiveBayesClassifier classifier = this.mClassifierFactory.create();
            trainSup(this.mLabeledData, classifier);
            trainUnsup(this.mUnlabeledData, classifier);
            compile(classifier);
            return classifier;
        }

        /** Trains the classifier on every labeled case in the corpus. */
        void trainSup(Corpus<ObjectHandler<Classified<CharSequence>>> corpus, TradNaiveBayesClassifier tradNaiveBayesClassifier) {
            try {
                corpus.visitTrain(tradNaiveBayesClassifier);
            } catch (IOException e) {
                throw new IllegalStateException("Error during labeled training", e);
            }
        }

        /**
         * Trains the classifier on every unlabeled case, weighting each
         * category by the previous epoch's conditional classification.
         */
        void trainUnsup(Corpus<ObjectHandler<CharSequence>> corpus, final TradNaiveBayesClassifier tradNaiveBayesClassifier) {
            try {
                corpus.visitTrain(new ObjectHandler<CharSequence>() {
                    @Override
                    public void handle(CharSequence charSequence) {
                        // The text itself is what gets classified; it must not be
                        // cast to JointClassifier.
                        JointClassification expected = EmIterator.this.mLastClassifier.classify(charSequence);
                        tradNaiveBayesClassifier.trainConditional(charSequence, expected, 1.0d, EmIterator.this.mMinTokenCount);
                    }
                });
            } catch (IOException e) {
                throw new IllegalStateException("Error during unlabeled training", e);
            }
        }

        /**
         * Compiles the classifier and remembers the result for the next epoch.
         * On failure the remembered classifier is cleared before rethrowing.
         */
        @SuppressWarnings("unchecked") // Compiler.read only ever returns a JointClassifier<CharSequence>
        void compile(TradNaiveBayesClassifier tradNaiveBayesClassifier) {
            try {
                this.mLastClassifier = (JointClassifier<CharSequence>) AbstractExternalizable.compile(tradNaiveBayesClassifier);
            } catch (IOException | ClassNotFoundException e) {
                this.mLastClassifier = null;
                throw new IllegalStateException("Error during compilation.", e);
            }
        }
    }

    /* loaded from: input_file:com/aliasi/classify/TradNaiveBayesClassifier$Serializer.class */
    /**
     * Externalizer that writes the full trainable state of a classifier
     * (priors, raw counts, length norm) and reads it back as a
     * {@code TradNaiveBayesClassifier}.
     */
    static class Serializer extends AbstractExternalizable {
        static final long serialVersionUID = -4786039228920809976L;
        private final TradNaiveBayesClassifier mClassifier;

        /**
         * @param tradNaiveBayesClassifier classifier to serialize on write
         */
        public Serializer(TradNaiveBayesClassifier tradNaiveBayesClassifier) {
            this.mClassifier = tradNaiveBayesClassifier;
        }

        /** No-argument constructor; leaves the classifier reference null. */
        public Serializer() {
            this(null);
        }

        /**
         * Reads the format produced by {@code writeExternal}.
         *
         * @param objectInput source stream
         * @return the reconstructed classifier
         * @throws ClassNotFoundException if the tokenizer factory class is missing
         * @throws IOException if the underlying stream fails
         */
        @Override
        public Object read(ObjectInput objectInput) throws ClassNotFoundException, IOException {
            final int numCategories = objectInput.readInt();
            final String[] categories = new String[numCategories];
            for (int cat = 0; cat < numCategories; cat++) {
                categories[cat] = objectInput.readUTF();
            }
            final TokenizerFactory tokenizerFactory = (TokenizerFactory) objectInput.readObject();
            final double categoryPrior = objectInput.readDouble();
            final double tokenInCategoryPrior = objectInput.readDouble();

            final int numTokens = objectInput.readInt();
            final Map<String, double[]> tokenToCounts = new HashMap<String, double[]>((numTokens * 3) / 2);
            for (int t = 0; t < numTokens; t++) {
                String token = objectInput.readUTF();
                tokenToCounts.put(token, readDoubles(objectInput, categories.length));
            }

            final double[] totalCountsPerCategory = readDoubles(objectInput, categories.length);
            final double[] caseCounts = readDoubles(objectInput, categories.length);
            final double totalCaseCount = objectInput.readDouble();
            final double lengthNorm = objectInput.readDouble();
            return new TradNaiveBayesClassifier(categories, tokenizerFactory, categoryPrior, tokenInCategoryPrior, tokenToCounts, totalCountsPerCategory, caseCounts, totalCaseCount, lengthNorm);
        }

        /** Reads {@code length} consecutive doubles into a new array. */
        private static double[] readDoubles(ObjectInput objectInput, int length) throws IOException {
            double[] values = new double[length];
            for (int k = 0; k < length; k++) {
                values[k] = objectInput.readDouble();
            }
            return values;
        }

        /**
         * Writes categories, tokenizer factory, both priors, per-token counts,
         * per-category token totals, per-category case counts, the total case
         * count and the length norm, in that order.
         *
         * @param objectOutput destination stream
         * @throws IOException if the underlying stream fails
         */
        @Override
        public void writeExternal(ObjectOutput objectOutput) throws IOException {
            final TradNaiveBayesClassifier classifier = this.mClassifier;
            final int numCategories = classifier.mCategories.length;

            objectOutput.writeInt(numCategories);
            for (int cat = 0; cat < numCategories; cat++) {
                objectOutput.writeUTF(classifier.mCategories[cat]);
            }
            objectOutput.writeObject(classifier.mTokenizerFactory);
            objectOutput.writeDouble(classifier.mCategoryPrior);
            objectOutput.writeDouble(classifier.mTokenInCategoryPrior);

            objectOutput.writeInt(classifier.mTokenToCountsMap.size());
            for (Map.Entry<String, double[]> entry : classifier.mTokenToCountsMap.entrySet()) {
                objectOutput.writeUTF(entry.getKey());
                double[] counts = entry.getValue();
                for (int cat = 0; cat < numCategories; cat++) {
                    objectOutput.writeDouble(counts[cat]);
                }
            }
            for (int cat = 0; cat < numCategories; cat++) {
                objectOutput.writeDouble(classifier.mTotalCountsPerCategory[cat]);
            }
            for (int cat = 0; cat < numCategories; cat++) {
                objectOutput.writeDouble(classifier.mCaseCounts[cat]);
            }
            objectOutput.writeDouble(classifier.mTotalCaseCount);
            objectOutput.writeDouble(classifier.mLengthNorm);
        }
    }

    /**
     * Returns a multi-line dump of the categories, both priors, the case
     * counts and every token's per-category count.
     *
     * @return human-readable description of the model state
     */
    @Override
    public String toString() {
        final StringBuilder out = new StringBuilder();
        out.append("categories=").append(Arrays.asList(this.mCategories)).append("\n");
        out.append("category Prior=").append(this.mCategoryPrior).append("\n");
        out.append("token in category prior=").append(this.mTokenInCategoryPrior).append("\n");
        out.append("total case count=").append(this.mTotalCaseCount).append("\n");
        for (int cat = 0; cat < this.mCategories.length; cat++) {
            out.append("category count(").append(this.mCategories[cat]).append(")=")
               .append(this.mCaseCounts[cat]).append("\n");
        }
        for (Map.Entry<String, double[]> entry : this.mTokenToCountsMap.entrySet()) {
            final String token = entry.getKey();
            final double[] counts = entry.getValue();
            out.append("token=").append(token).append("\n");
            for (int cat = 0; cat < this.mCategories.length; cat++) {
                out.append("  tokenCount(").append(this.mCategories[cat]).append(",").append(token)
                   .append(")=").append(counts[cat]).append("\n");
            }
        }
        return out.toString();
    }

    /**
     * Restores a classifier from previously saved state; the arguments are
     * stored as given, without validation or copying (apart from building the
     * category set from the array).
     *
     * @param strArr category names
     * @param tokenizerFactory tokenizer factory
     * @param d category prior
     * @param d2 token-in-category prior
     * @param map token to per-category counts
     * @param dArr total token count per category
     * @param dArr2 case count per category
     * @param d3 total case count
     * @param d4 length norm
     */
    private TradNaiveBayesClassifier(String[] strArr, TokenizerFactory tokenizerFactory, double d, double d2, Map<String, double[]> map, double[] dArr, double[] dArr2, double d3, double d4) {
        // Configuration.
        this.mTokenizerFactory = tokenizerFactory;
        this.mCategories = strArr;
        this.mCategorySet = new HashSet<String>(Arrays.asList(strArr));
        this.mCategoryPrior = d;
        this.mTokenInCategoryPrior = d2;
        this.mLengthNorm = d4;
        // Accumulated training statistics.
        this.mTokenToCountsMap = map;
        this.mTotalCountsPerCategory = dArr;
        this.mCaseCounts = dArr2;
        this.mTotalCaseCount = d3;
    }

    /**
     * Creates an untrained classifier with both priors set to 0.5 and length
     * normalization disabled (length norm of NaN).
     *
     * @param set category names; at least two are required
     * @param tokenizerFactory tokenizer factory used to split text into tokens
     */
    public TradNaiveBayesClassifier(Set<String> set, TokenizerFactory tokenizerFactory) {
        this(set, tokenizerFactory, 0.5, 0.5, Double.NaN);
    }

    public TradNaiveBayesClassifier(Set<String> set, TokenizerFactory tokenizerFactory, double d, double d2, double d3) {
        if (set.size() < 2) {
            throw new IllegalArgumentException("Require at least two categorySet. Found categorySet.size()=" + set.size());
        }
        Exceptions.finiteNonNegative("categoryPrior", d);
        Exceptions.finiteNonNegative("tokenInCategoryPrior", d2);
        setLengthNorm(d3);
        this.mTotalCaseCount = 0.0d;
        this.mCategorySet = new HashSet(set);
        this.mCategories = (String[]) this.mCategorySet.toArray(Strings.EMPTY_STRING_ARRAY);
        Arrays.sort(this.mCategories);
        this.mTokenizerFactory = tokenizerFactory;
        this.mCategoryPrior = d;
        this.mTokenInCategoryPrior = d2;
        this.mTokenToCountsMap = new HashMap();
        this.mTotalCountsPerCategory = new double[this.mCategories.length];
        this.mCaseCounts = new double[this.mCategories.length];
    }

    /**
     * Returns a read-only view of the category names.
     *
     * @return unmodifiable view of the category set
     */
    public Set<String> categorySet() {
        final Set<String> readOnlyView = Collections.unmodifiableSet(this.mCategorySet);
        return readOnlyView;
    }

    /**
     * Sets the length normalization value. NaN is accepted and is treated
     * elsewhere in this class as "no normalization".
     *
     * @param d new length norm; finite and positive, or NaN
     * @throws IllegalArgumentException if the value is zero, negative or infinite
     */
    public void setLengthNorm(double d) {
        final boolean acceptable = Double.isNaN(d) || (d > 0.0d && !Double.isInfinite(d));
        if (!acceptable) {
            throw new IllegalArgumentException("Length norm must be finite and positive, or Double.NaN. Found lengthNorm=" + d);
        }
        this.mLengthNorm = d;
    }

    /**
     * Classifies the text from the current counts. For each category the
     * score is the sum of log2 token probabilities of the known tokens,
     * optionally rescaled by length norm over token count, plus the log2
     * category probability.
     *
     * @param charSequence text to classify
     * @return joint classification built from the per-category scores
     */
    @Override // com.aliasi.classify.BaseClassifier
    public JointClassification classify(CharSequence charSequence) {
        final int numCategories = this.mCategories.length;
        final double[] scores = new double[numCategories];
        final char[] chars = Strings.toCharArray(charSequence);

        int tokenCount = 0;
        for (String token : this.mTokenizerFactory.tokenizer(chars, 0, chars.length)) {
            final double[] counts = this.mTokenToCountsMap.get(token);
            tokenCount++;
            if (counts == null) {
                continue; // unknown token: counted for length, no score contribution
            }
            for (int cat = 0; cat < numCategories; cat++) {
                scores[cat] += Math.log2(probTokenByIndexArray(cat, counts));
            }
        }

        if (tokenCount > 0 && !Double.isNaN(this.mLengthNorm)) {
            for (int cat = 0; cat < scores.length; cat++) {
                scores[cat] *= this.mLengthNorm / tokenCount;
            }
        }

        for (int cat = 0; cat < scores.length; cat++) {
            scores[cat] += Math.log2(probCatByIndex(cat));
        }
        return JointClassification.create(this.mCategories, scores);
    }

    /**
     * Returns the current length normalization value.
     *
     * @return the length norm, or NaN if none is set
     */
    public double lengthNorm() {
        final double currentNorm = this.mLengthNorm;
        return currentNorm;
    }

    /**
     * Reports whether the token has an entry in the token count table.
     *
     * @param str token to look up
     * @return {@code true} if the token is a key of the count table
     */
    public boolean isKnownToken(String str) {
        final boolean known = this.mTokenToCountsMap.containsKey(str);
        return known;
    }

    /**
     * Returns a read-only view of the tokens in the token count table.
     *
     * @return unmodifiable view of the count table's key set
     */
    public Set<String> knownTokenSet() {
        final Set<String> tokens = this.mTokenToCountsMap.keySet();
        return Collections.unmodifiableSet(tokens);
    }

    /**
     * Returns the smoothed probability of a known token in a category.
     *
     * @param str token to look up
     * @param str2 category name
     * @return the token probability as computed by {@code probTokenByIndexArray}
     * @throws IllegalArgumentException if the category is unknown, or if the
     *         token has no entry in the count table
     */
    public double probToken(String str, String str2) {
        // The category is resolved first so an unknown category is reported
        // before an unknown token.
        final int categoryIndex = getIndex(str2);
        final double[] counts = this.mTokenToCountsMap.get(str);
        if (counts != null) {
            return probTokenByIndexArray(categoryIndex, counts);
        }
        throw new IllegalArgumentException("Requires known token. Found token=" + str);
    }

    /**
     * Writes a compiled form of this classifier by serializing a
     * {@code Compiler} wrapping it.
     *
     * @param objectOutput destination stream
     * @throws IOException if the underlying stream fails
     */
    @Override // com.aliasi.util.Compilable
    public void compileTo(ObjectOutput objectOutput) throws IOException {
        final Compiler compiler = new Compiler(this);
        objectOutput.writeObject(compiler);
    }

    /**
     * Returns the smoothed probability of the named category.
     *
     * @param str category name
     * @return the category probability as computed by {@code probCatByIndex}
     * @throws IllegalArgumentException if the category is unknown
     */
    public double probCat(String str) {
        final int categoryIndex = getIndex(str);
        return probCatByIndex(categoryIndex);
    }

    /**
     * Trains on a single classified case by unpacking its text and
     * classification and delegating to the two-argument handler.
     *
     * @param classified labeled training case
     */
    @Override // com.aliasi.corpus.ObjectHandler
    public void handle(Classified<CharSequence> classified) {
        final CharSequence text = classified.getObject();
        handle(text, classified.getClassification());
    }

    /**
     * Trains on the text with the given classification and a count of one.
     *
     * @param charSequence training text
     * @param classification classification whose best category is the label
     */
    void handle(CharSequence charSequence, Classification classification) {
        final double singleCase = 1.0d;
        train(charSequence, classification, singleCase);
    }

    /**
     * Trains on the text with fractional counts: each category in the leading
     * ranks of the classification receives a case count equal to its
     * conditional probability times the multiplier. Ranks are taken up to the
     * first one whose weighted probability is below the minimum count.
     *
     * @param charSequence training text
     * @param conditionalClassification classification supplying category weights
     * @param d count multiplier; finite and non-negative
     * @param d2 minimum weighted count for a category to be trained
     * @throws IllegalArgumentException if either numeric argument is negative,
     *         NaN or infinite, or if a category is unknown
     */
    public void trainConditional(CharSequence charSequence, ConditionalClassification conditionalClassification, double d, double d2) {
        if (d < 0.0d || Double.isNaN(d) || Double.isInfinite(d)) {
            throw new IllegalArgumentException("Count multipliers must be finite and non-negative. Found countMultiplier=" + d);
        }
        if (d2 < 0.0d || Double.isNaN(d2) || Double.isInfinite(d2)) {
            throw new IllegalArgumentException("Minimum count must be finite non-negative. Found minCount=" + d2);
        }

        // Number of leading ranks whose weighted probability reaches the minimum.
        int numRanks = 0;
        while (numRanks < conditionalClassification.size()) {
            if (!(conditionalClassification.conditionalProbability(numRanks) * d >= d2)) {
                break;
            }
            numRanks++;
        }

        final ObjectToCounterMap<String> tokenCounts = tokenCountMap(charSequence);
        final double lengthFactor = lengthMultiplier(tokenCounts);

        final int[] categoryIndexes = new int[numRanks];
        final double[] tokenWeights = new double[numRanks];
        for (int rank = 0; rank < numRanks; rank++) {
            final int categoryIndex = getIndex(conditionalClassification.category(rank));
            categoryIndexes[rank] = categoryIndex;
            final double caseWeight = d * conditionalClassification.conditionalProbability(rank);
            this.mTotalCaseCount += caseWeight;
            this.mCaseCounts[categoryIndex] += caseWeight;
            tokenWeights[rank] = lengthFactor * caseWeight;
        }

        for (Map.Entry<String, Counter> entry : tokenCounts.entrySet()) {
            final String token = entry.getKey();
            final double occurrences = entry.getValue().doubleValue();
            double[] counts = this.mTokenToCountsMap.get(token);
            if (counts == null) {
                counts = new double[this.mCategories.length];
                this.mTokenToCountsMap.put(token, counts);
            }
            for (int rank = 0; rank < numRanks; rank++) {
                final double increment = occurrences * tokenWeights[rank];
                final int categoryIndex = categoryIndexes[rank];
                counts[categoryIndex] += increment;
                this.mTotalCountsPerCategory[categoryIndex] += increment;
            }
        }
    }

    /**
     * Adds (or, for a negative count, removes) the text as a training case for
     * the best category of the given classification.
     *
     * <p>The case count and every token count are adjusted by the given count,
     * the latter scaled by the length multiplier. If a negative count would
     * push a token count below zero, the changes made so far are undone and an
     * exception is thrown.
     *
     * @param charSequence training text
     * @param classification classification whose best category is the label
     * @param d count to add; zero is a no-op, negative values decrement
     * @throws IllegalArgumentException if the category is unknown, or if a
     *         decrement would make a case or token count negative
     */
    public void train(CharSequence charSequence, Classification classification, double d) {
        if (d == 0.0d) {
            return;
        }
        String bestCategory = classification.bestCategory();
        int index = getIndex(bestCategory);
        // Reject a decrement larger than the category's current case count.
        if (this.mCaseCounts[index] < (-d)) {
            throw new IllegalArgumentException("Decrement caused negative token count.Revert to previous state. cSeq=" + ((Object) charSequence) + " classification=" + bestCategory + " count=" + d);
        }
        double[] dArr = this.mCaseCounts;
        dArr[index] = dArr[index] + d;
        this.mTotalCaseCount += d;
        // Per-token increment: the count scaled by the length multiplier.
        double lengthMultiplier = lengthMultiplier(tokenCountMap(charSequence)) * d;
        char[] charArray = Strings.toCharArray(charSequence);
        // i counts the tokens already applied, so a rollback knows how far to go.
        int i = 0;
        Iterator<String> it = this.mTokenizerFactory.tokenizer(charArray, 0, charArray.length).iterator();
        while (it.hasNext()) {
            String next = it.next();
            double[] dArr2 = this.mTokenToCountsMap.get(next);
            // A decrement is invalid if the token is unknown or its count is too small.
            if (lengthMultiplier < 0.0d && (dArr2 == null || dArr2[index] < (-lengthMultiplier))) {
                // Roll back: first the case counts ...
                double[] dArr3 = this.mCaseCounts;
                dArr3[index] = dArr3[index] - d;
                this.mTotalCaseCount -= d;
                // ... then re-tokenize and undo the first i token updates.
                int i2 = 0;
                Iterator<String> it2 = this.mTokenizerFactory.tokenizer(charArray, 0, charArray.length).iterator();
                while (it2.hasNext()) {
                    String next2 = it2.next();
                    if (i2 >= i) {
                        break;
                    }
                    i2++;
                    double[] dArr4 = this.mTokenToCountsMap.get(next2);
                    dArr4[index] = dArr4[index] - lengthMultiplier;
                    double[] dArr5 = this.mTotalCountsPerCategory;
                    dArr5[index] = dArr5[index] - lengthMultiplier;
                }
                throw new IllegalArgumentException("Decrement caused negative token count.Revert to previous state. cSeq=" + ((Object) charSequence) + " classification=" + bestCategory + " count=" + d);
            }
            i++;
            // First sighting of this token: create its per-category count row.
            if (dArr2 == null) {
                dArr2 = new double[this.mCategories.length];
                this.mTokenToCountsMap.put(next, dArr2);
            }
            double[] dArr6 = dArr2;
            dArr6[index] = dArr6[index] + lengthMultiplier;
            double[] dArr7 = this.mTotalCountsPerCategory;
            dArr7[index] = dArr7[index] + lengthMultiplier;
        }
    }

    /**
     * Returns the log (base 2) of the sum over categories of the joint
     * probability of the text, i.e. the log2 marginal probability of the case.
     * The sum is computed relative to the largest joint log probability to
     * avoid underflow.
     *
     * @param charSequence text to score
     * @return log2 of the summed joint probabilities; negative infinity when
     *         no category has a joint log probability above negative infinity
     */
    public double log2CaseProb(CharSequence charSequence) {
        final JointClassification classification = classify(charSequence);
        final int numRanks = classification.size();

        double maxLog2Prob = Double.NEGATIVE_INFINITY;
        for (int rank = 0; rank < numRanks; rank++) {
            final double log2Prob = classification.jointLog2Probability(rank);
            if (log2Prob > maxLog2Prob) {
                maxLog2Prob = log2Prob;
            }
        }
        // With no finite maximum the subtraction below would be -inf - -inf = NaN;
        // the sum of zero probabilities has log -inf.
        if (maxLog2Prob == Double.NEGATIVE_INFINITY) {
            return Double.NEGATIVE_INFINITY;
        }

        double scaledSum = 0.0d;
        for (int rank = 0; rank < numRanks; rank++) {
            // java.lang.Math is spelled out because this file imports
            // com.aliasi.util.Math, which shadows the simple name Math.
            scaledSum += java.lang.Math.pow(2.0d, classification.jointLog2Probability(rank) - maxLog2Prob);
        }
        return maxLog2Prob + Math.log2(scaledSum);
    }

    /**
     * Returns the log (base 2) density of the current parameter estimates
     * under the Dirichlet priors: one term for the category distribution and
     * one term per category for its token distribution.
     *
     * @return sum of the {@code Statistics.dirichletLog2Prob} terms
     */
    public double log2ModelProb() {
        final int numCategories = this.mCategories.length;

        final double[] categoryProbs = new double[numCategories];
        for (int cat = 0; cat < numCategories; cat++) {
            categoryProbs[cat] = probCatByIndex(cat);
        }
        double log2Prob = Statistics.dirichletLog2Prob(this.mCategoryPrior, categoryProbs);

        final double[] tokenProbs = new double[this.mTokenToCountsMap.size()];
        for (int cat = 0; cat < numCategories; cat++) {
            int slot = 0;
            for (double[] counts : this.mTokenToCountsMap.values()) {
                // Use the same estimate as classification and compilation. It
                // normalizes by vocabulary size, so the values form a proper
                // distribution over tokens; normalizing by the number of
                // categories does not.
                tokenProbs[slot++] = probTokenByIndexArray(cat, counts);
            }
            log2Prob += Statistics.dirichletLog2Prob(this.mTokenInCategoryPrior, tokenProbs);
        }
        return log2Prob;
    }

    /**
     * Serialization hook: substitutes a {@code Serializer} wrapping this
     * classifier for the classifier itself.
     *
     * @return the serialization proxy
     * @throws ObjectStreamException declared for the serialization contract
     */
    private Object writeReplace() throws ObjectStreamException {
        final Serializer proxy = new Serializer(this);
        return proxy;
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Returns the additively smoothed probability of a token in a category:
     * (token count + prior) / (category token total + vocabulary size * prior).
     *
     * @param i category index
     * @param dArr the token's per-category counts
     * @return smoothed token probability for the category
     */
    public double probTokenByIndexArray(int i, double[] dArr) {
        final double smoothedCount = dArr[i] + this.mTokenInCategoryPrior;
        final double smoothedTotal = this.mTotalCountsPerCategory[i] + (this.mTokenToCountsMap.size() * this.mTokenInCategoryPrior);
        return smoothedCount / smoothedTotal;
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Returns the additively smoothed probability of a category:
     * (case count + prior) / (total case count + number of categories * prior).
     *
     * @param i category index
     * @return smoothed category probability
     */
    public double probCatByIndex(int i) {
        final double smoothedCount = this.mCaseCounts[i] + this.mCategoryPrior;
        final double smoothedTotal = this.mTotalCaseCount + (this.mCategories.length * this.mCategoryPrior);
        return smoothedCount / smoothedTotal;
    }

    /**
     * Tokenizes the text and counts how often each token occurs.
     *
     * @param charSequence text to tokenize
     * @return map from token to its number of occurrences in the text
     */
    private ObjectToCounterMap<String> tokenCountMap(CharSequence charSequence) {
        final char[] chars = Strings.toCharArray(charSequence);
        final ObjectToCounterMap<String> counts = new ObjectToCounterMap<>();
        for (String token : this.mTokenizerFactory.tokenizer(chars, 0, chars.length)) {
            counts.increment(token);
        }
        return counts;
    }

    /**
     * Returns the factor by which token counts are scaled during training:
     * the length norm divided by the total number of tokens, or 1 when
     * length normalization is off (NaN) or there are no tokens.
     *
     * @param objectToCounterMap token counts for one text
     * @return the per-token scaling factor
     */
    private double lengthMultiplier(ObjectToCounterMap<String> objectToCounterMap) {
        if (Double.isNaN(this.mLengthNorm)) {
            return 1.0d;
        }
        int totalTokens = 0;
        for (Counter counter : objectToCounterMap.values()) {
            totalTokens += counter.intValue();
        }
        return totalTokens == 0 ? 1.0d : this.mLengthNorm / totalTokens;
    }

    /**
     * Looks up the position of a category in the sorted category array.
     *
     * @param str category name
     * @return index of the category in {@code mCategories}
     * @throws IllegalArgumentException if the category is not in the array
     */
    private int getIndex(String str) {
        final int position = Arrays.binarySearch(this.mCategories, str);
        if (position >= 0) {
            return position;
        }
        throw new IllegalArgumentException("Unknown category.  Require category in category set. Found category=" + str + " category set=" + this.mCategorySet);
    }

    /**
     * Creates an iterator over EM training epochs. Constructing it trains the
     * initial classifier on the labeled corpus.
     *
     * @param tradNaiveBayesClassifier initial classifier
     * @param factory source of a fresh classifier for each epoch
     * @param corpus labeled training data
     * @param corpus2 unlabeled training data
     * @param d minimum token count for conditional training
     * @return iterator yielding one trained classifier per epoch
     * @throws IOException declared by the signature
     */
    public static Iterator<TradNaiveBayesClassifier> emIterator(TradNaiveBayesClassifier tradNaiveBayesClassifier, Factory<TradNaiveBayesClassifier> factory, Corpus<ObjectHandler<Classified<CharSequence>>> corpus, Corpus<ObjectHandler<CharSequence>> corpus2, double d) throws IOException {
        final EmIterator epochs = new EmIterator(tradNaiveBayesClassifier, factory, corpus, corpus2, d);
        return epochs;
    }

    /**
     * Runs EM training until the relative change in total log probability
     * (model plus data) drops below the tolerance or the epoch limit is hit.
     *
     * @param tradNaiveBayesClassifier initial classifier
     * @param factory source of a fresh classifier for each epoch
     * @param corpus labeled training data
     * @param corpus2 unlabeled training data
     * @param d minimum token count for conditional training
     * @param i maximum number of epochs
     * @param d2 convergence tolerance on the relative log probability change
     * @param reporter progress reporter; {@code null} means silent
     * @return the classifier from the last epoch run, or {@code null} if no
     *         epoch was run
     * @throws IOException if visiting a corpus fails while scoring the data
     */
    public static TradNaiveBayesClassifier emTrain(TradNaiveBayesClassifier tradNaiveBayesClassifier, Factory<TradNaiveBayesClassifier> factory, Corpus<ObjectHandler<Classified<CharSequence>>> corpus, Corpus<ObjectHandler<CharSequence>> corpus2, double d, int i, double d2, Reporter reporter) throws IOException {
        if (reporter == null) {
            reporter = Reporters.silent();
        }
        double lastLogProb = Double.NEGATIVE_INFINITY;
        final Iterator<TradNaiveBayesClassifier> epochs = emIterator(tradNaiveBayesClassifier, factory, corpus, corpus2, d);
        TradNaiveBayesClassifier classifier = null;
        for (int epoch = 0; epochs.hasNext() && epoch < i; epoch++) {
            classifier = epochs.next();
            final double modelLogProb = classifier.log2ModelProb();
            final double dataLogProb = dataProb(classifier, corpus, corpus2);
            final double logProb = modelLogProb + dataLogProb;
            final double relativeDiff = relativeDiff(lastLogProb, logProb);
            if (reporter.isDebugEnabled()) {
                reporter.debug(String.format("epoch=%4d   dataLogProb=%15.2f   modelLogProb=%15.2f   logProb=%15.2f   diff=%15.12f", Integer.valueOf(epoch), Double.valueOf(dataLogProb), Double.valueOf(modelLogProb), Double.valueOf(logProb), Double.valueOf(relativeDiff)));
            }
            // On the first epoch lastLogProb is -Infinity, which makes the
            // relative difference NaN and the comparison false, so at least
            // two epochs run before convergence can be declared.
            if (!Double.isNaN(lastLogProb) && relativeDiff < d2) {
                reporter.info("Converged");
                return classifier;
            }
            lastLogProb = logProb;
        }
        return classifier;
    }

    /**
     * Sums the log2 case probability of every training case in both corpora
     * under the given classifier; labels of the labeled corpus are ignored.
     *
     * @param tradNaiveBayesClassifier classifier used to score each case
     * @param corpus labeled training data
     * @param corpus2 unlabeled training data
     * @return total log2 case probability over both corpora
     * @throws IOException if visiting a corpus fails
     */
    static double dataProb(TradNaiveBayesClassifier tradNaiveBayesClassifier, Corpus<ObjectHandler<Classified<CharSequence>>> corpus, Corpus<ObjectHandler<CharSequence>> corpus2) throws IOException {
        final CaseProbAccumulator accumulator = new CaseProbAccumulator(tradNaiveBayesClassifier);
        final ObjectHandler<Classified<CharSequence>> labeledHandler = accumulator.supHandler();
        corpus.visitTrain(labeledHandler);
        corpus2.visitTrain(accumulator);
        return accumulator.mCaseProb;
    }

    /**
     * Returns the symmetric relative difference of two values:
     * {@code 2 * |d - d2| / (|d| + |d2|)}.
     *
     * <p>Two zeros are defined to differ by zero rather than yielding 0/0.
     * Infinite inputs still produce NaN, which callers can use to detect that
     * one side has no meaningful value yet.
     *
     * @param d first value
     * @param d2 second value
     * @return relative difference, {@code 0.0} if both values are zero
     */
    static double relativeDiff(double d, double d2) {
        // java.lang.Math is spelled out because this file imports
        // com.aliasi.util.Math, which shadows the simple name Math.
        final double magnitude = java.lang.Math.abs(d) + java.lang.Math.abs(d2);
        if (magnitude == 0.0d) {
            return 0.0d;
        }
        return (2.0d * java.lang.Math.abs(d - d2)) / magnitude;
    }
}
