package com.aliasi.classify;

import com.aliasi.corpus.Corpus;
import com.aliasi.corpus.ObjectHandler;
import com.aliasi.features.Features;
import com.aliasi.io.Reporter;
import com.aliasi.io.Reporters;
import com.aliasi.matrix.DenseVector;
import com.aliasi.matrix.Vector;
import com.aliasi.stats.AnnealingSchedule;
import com.aliasi.stats.LogisticRegression;
import com.aliasi.stats.RegressionPrior;
import com.aliasi.symbol.MapSymbolTable;
import com.aliasi.symbol.SymbolTable;
import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.Compilable;
import com.aliasi.util.FeatureExtractor;
import com.aliasi.util.ObjectToCounterMap;
import com.aliasi.util.ObjectToDoubleMap;
import com.aliasi.util.ScoredObject;
import java.io.CharArrayWriter;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.io.PrintWriter;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:com/aliasi/classify/LogisticRegressionClassifier.class */
public class LogisticRegressionClassifier<E> implements ConditionalClassifier<E>, Compilable, Serializable {
    /** Serialization version identifier for this class. */
    static final long serialVersionUID = -400005337034204553L;
    /** Underlying regression model; its outcome count must equal the number of category symbols. */
    private final LogisticRegression mModel;
    /** Maps input objects to the feature maps that are converted into vectors for the model. */
    private final FeatureExtractor<? super E> mFeatureExtractor;
    /** Flag handed to {@code Features.toVector} whenever a feature map is converted to a vector. */
    private final boolean mAddInterceptFeature;
    /** Symbol table used to translate feature names into vector dimensions. */
    private final SymbolTable mFeatureSymbolTable;
    /** Category names; position {@code i} names the model outcome with index {@code i}. */
    private final String[] mCategorySymbols;
    /** Feature name that {@code train} registers first in the feature symbol table when an intercept is requested. */
    public static final String INTERCEPT_FEATURE_NAME = "*&^INTERCEPT%$^&**";
    /** Zero-length array passed to {@code List.toArray} to obtain a {@code Vector[]}. */
    static final Vector[] EMPTY_VECTOR_ARRAY = new Vector[0];

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:com/aliasi/classify/LogisticRegressionClassifier$DataExtractor.class */
    /**
     * Corpus handler that turns every classified training instance into a
     * feature vector plus a category identifier, collecting both in the
     * order in which instances are visited.
     *
     * @param <F> type of the objects being classified
     */
    public static class DataExtractor<F> implements ObjectHandler<Classified<F>> {
        final FeatureExtractor<? super F> mFeatureExtractor;
        final SymbolTable mFeatureSymbolTable;
        final SymbolTable mCategorySymbolTable;
        final boolean mAddInterceptFeature;
        final int mNumSymbols;
        /** One vector per handled instance, parallel to {@link #mOutputCategoryList}. */
        final List<Vector> mInputVectorList = new ArrayList<>();
        /** One category identifier per handled instance. */
        final List<Integer> mOutputCategoryList = new ArrayList<>();

        /**
         * @param featureExtractor extractor applied to each handled object
         * @param symbolTable feature symbol table used for vector conversion
         * @param symbolTable2 category symbol table; unseen categories are added to it
         * @param z intercept flag forwarded to {@code Features.toVector}
         * @param i dimensionality argument forwarded to {@code Features.toVector}
         */
        DataExtractor(FeatureExtractor<? super F> featureExtractor, SymbolTable symbolTable, SymbolTable symbolTable2, boolean z, int i) {
            mFeatureExtractor = featureExtractor;
            mFeatureSymbolTable = symbolTable;
            mCategorySymbolTable = symbolTable2;
            mAddInterceptFeature = z;
            mNumSymbols = i;
        }

        /**
         * Records the feature vector and best-category identifier of one instance.
         *
         * @param classified the training instance to record
         */
        @Override
        public void handle(Classified<F> classified) {
            F input = classified.getObject();
            // The category is registered before feature extraction, as in the original ordering.
            int categoryId = mCategorySymbolTable.getOrAddSymbol(classified.getClassification().bestCategory());
            Map<String, ? extends Number> featureMap = mFeatureExtractor.features(input);
            Vector inputVector = Features.toVector(featureMap, mFeatureSymbolTable, mNumSymbols, mAddInterceptFeature);
            mInputVectorList.add(inputVector);
            mOutputCategoryList.add(Integer.valueOf(categoryId));
        }

        /**
         * @return the collected category identifiers as a primitive array, in visiting order
         */
        int[] categories() {
            int[] ids = new int[mOutputCategoryList.size()];
            int position = 0;
            for (Integer id : mOutputCategoryList) {
                ids[position++] = id.intValue();
            }
            return ids;
        }

        /**
         * @return the collected input vectors as an array, in visiting order
         */
        Vector[] inputs() {
            return mInputVectorList.toArray(EMPTY_VECTOR_ARRAY);
        }
    }

    /* loaded from: input_file:com/aliasi/classify/LogisticRegressionClassifier$Externalizer.class */
    /**
     * Serialization proxy that writes the classifier's model, feature
     * extractor, intercept flag, feature symbol table and category symbols,
     * and rebuilds a classifier from that same sequence on read.
     *
     * @param <G> type of the objects classified by the wrapped classifier
     */
    static class Externalizer<G> extends AbstractExternalizable {
        static final long serialVersionUID = -2003123148721825458L;
        final LogisticRegressionClassifier<G> mClassifier;

        /** No-argument constructor; leaves the wrapped classifier {@code null}. */
        public Externalizer() {
            this(null);
        }

        /**
         * @param logisticRegressionClassifier classifier whose state will be written
         */
        public Externalizer(LogisticRegressionClassifier<G> logisticRegressionClassifier) {
            mClassifier = logisticRegressionClassifier;
        }

        /**
         * Writes the classifier state in the order expected by {@link #read(ObjectInput)}.
         *
         * @param objectOutput destination of the serialized state
         * @throws IOException if the underlying output fails
         */
        @Override
        public void writeExternal(ObjectOutput objectOutput) throws IOException {
            LogisticRegressionClassifier<G> source = mClassifier;
            objectOutput.writeObject(source.mModel);
            objectOutput.writeObject(source.mFeatureExtractor);
            objectOutput.writeBoolean(source.mAddInterceptFeature);
            objectOutput.writeObject(source.mFeatureSymbolTable);
            // Categories are stored as a count followed by that many UTF strings.
            String[] categories = source.mCategorySymbols;
            objectOutput.writeInt(categories.length);
            for (String category : categories) {
                objectOutput.writeUTF(category);
            }
        }

        /**
         * Reads the state written by {@link #writeExternal(ObjectOutput)} and
         * constructs the corresponding classifier.
         *
         * @param objectInput source of the serialized state
         * @return the reconstructed classifier
         * @throws IOException if the underlying input fails
         * @throws ClassNotFoundException if a serialized class cannot be resolved
         */
        @Override
        public Object read(ObjectInput objectInput) throws IOException, ClassNotFoundException {
            LogisticRegression model = (LogisticRegression) objectInput.readObject();
            @SuppressWarnings("unchecked") // the stream carries no generic type information
            FeatureExtractor<? super G> extractor = (FeatureExtractor<? super G>) objectInput.readObject();
            boolean addIntercept = objectInput.readBoolean();
            SymbolTable featureTable = (SymbolTable) objectInput.readObject();
            int numCategories = objectInput.readInt();
            String[] categories = new String[numCategories];
            for (int k = 0; k < numCategories; k++) {
                categories[k] = objectInput.readUTF();
            }
            return new LogisticRegressionClassifier<G>(model, extractor, addIntercept, featureTable, categories);
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:com/aliasi/classify/LogisticRegressionClassifier$FeatureCounter.class */
    /**
     * Corpus handler that increments a shared counter once for every feature
     * name present in each handled instance, ignoring the feature values.
     *
     * @param <H> type of the objects being classified
     */
    public static class FeatureCounter<H> implements ObjectHandler<Classified<H>> {
        private final FeatureExtractor<? super H> mFeatureExtractor;
        private final ObjectToCounterMap<String> mFeatureCounter;

        /**
         * @param featureExtractor extractor applied to each handled object
         * @param objectToCounterMap counter that receives one increment per feature name per instance
         */
        FeatureCounter(FeatureExtractor<? super H> featureExtractor, ObjectToCounterMap<String> objectToCounterMap) {
            mFeatureExtractor = featureExtractor;
            mFeatureCounter = objectToCounterMap;
        }

        /**
         * Increments the count of every feature name extracted from the instance.
         *
         * @param classified the training instance whose features are counted
         */
        @Override
        public void handle(Classified<H> classified) {
            Map<String, ? extends Number> featureMap = mFeatureExtractor.features(classified.getObject());
            for (String featureName : featureMap.keySet()) {
                mFeatureCounter.increment(featureName);
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:com/aliasi/classify/LogisticRegressionClassifier$RegressionHandlerAdapter.class */
    /**
     * Adapts a handler of classifiers into a handler of regression models:
     * each model received is wrapped in a classifier built from the stored
     * feature extractor, intercept flag, feature symbol table and category
     * symbols, and that classifier is passed to the wrapped handler.
     *
     * @param <F> type of the objects being classified
     */
    public static class RegressionHandlerAdapter<F> implements ObjectHandler<LogisticRegression> {
        private final ObjectHandler<LogisticRegressionClassifier<F>> mClassifierHandler;
        private final FeatureExtractor<? super F> mFeatureExtractor;
        private final boolean mAddInterceptFeature;
        private final SymbolTable mFeatureSymbolTable;
        private final String[] mCategorySymbols;

        /**
         * @param objectHandler handler that receives the wrapped classifiers
         * @param featureExtractor feature extractor for the wrapped classifiers
         * @param z intercept flag for the wrapped classifiers
         * @param symbolTable feature symbol table for the wrapped classifiers
         * @param strArr category symbols for the wrapped classifiers (stored by reference)
         */
        public RegressionHandlerAdapter(ObjectHandler<LogisticRegressionClassifier<F>> objectHandler, FeatureExtractor<? super F> featureExtractor, boolean z, SymbolTable symbolTable, String[] strArr) {
            mClassifierHandler = objectHandler;
            mFeatureExtractor = featureExtractor;
            mAddInterceptFeature = z;
            mFeatureSymbolTable = symbolTable;
            mCategorySymbols = strArr;
        }

        /**
         * Wraps the model in a classifier and forwards it to the classifier handler.
         *
         * @param logisticRegression the regression model to wrap
         */
        @Override
        public void handle(LogisticRegression logisticRegression) {
            LogisticRegressionClassifier<F> wrapped = new LogisticRegressionClassifier<F>(
                    logisticRegression, mFeatureExtractor, mAddInterceptFeature, mFeatureSymbolTable, mCategorySymbols);
            mClassifierHandler.handle(wrapped);
        }
    }

    LogisticRegressionClassifier(LogisticRegression logisticRegression, FeatureExtractor<? super E> featureExtractor, boolean z, SymbolTable symbolTable, String[] strArr) {
        if (logisticRegression.numOutcomes() != strArr.length) {
            throw new IllegalArgumentException("Number of model outcomes must match category symbols length. Found model.numOutcomes()=" + logisticRegression.numOutcomes() + " categorySymbols.length=" + strArr.length);
        }
        HashSet hashSet = new HashSet();
        for (int i = 0; i < strArr.length; i++) {
            if (!hashSet.add(strArr[i])) {
                throw new IllegalArgumentException("Categories must be unique. Found duplicate category categorySymbols[" + i + "]=" + strArr[i]);
            }
        }
        this.mModel = logisticRegression;
        this.mFeatureExtractor = featureExtractor;
        this.mAddInterceptFeature = z;
        this.mFeatureSymbolTable = symbolTable;
        this.mCategorySymbols = strArr;
    }

    /**
     * Returns the feature symbol table wrapped by
     * {@code MapSymbolTable.unmodifiableView}.
     *
     * @return a view of this classifier's feature symbol table
     */
    public SymbolTable featureSymbolTable() {
        SymbolTable view = MapSymbolTable.unmodifiableView(mFeatureSymbolTable);
        return view;
    }

    /**
     * Returns the category symbols in outcome-index order.
     *
     * <p>The returned list is a read-only view. A bare {@code Arrays.asList}
     * result writes through to its backing array, which would let callers
     * overwrite this classifier's category symbols and bypass the uniqueness
     * check performed at construction.
     *
     * @return an unmodifiable list view of the category symbols
     */
    public List<String> categorySymbols() {
        return Collections.unmodifiableList(Arrays.asList(this.mCategorySymbols));
    }

    /**
     * Returns the regression model this classifier delegates to.
     *
     * @return the underlying model instance (not a copy)
     */
    public LogisticRegression model() {
        return mModel;
    }

    /**
     * Returns the intercept flag this classifier passes to
     * {@code Features.toVector} when converting feature maps.
     *
     * @return the stored intercept flag
     */
    public boolean addInterceptFeature() {
        return mAddInterceptFeature;
    }

    /**
     * Returns a feature extractor typed on {@code E} that delegates every
     * call to this classifier's stored extractor. A new wrapper is created
     * on each call.
     *
     * @return a delegating feature extractor
     */
    public FeatureExtractor<E> featureExtractor() {
        FeatureExtractor<E> delegating = new FeatureExtractor<E>() {
            @Override
            public Map<String, ? extends Number> features(E input) {
                return mFeatureExtractor.features(input);
            }
        };
        return delegating;
    }

    /**
     * Classifies an already-vectorized input.
     *
     * <p>The model's per-outcome scores (presumably conditional probabilities;
     * confirm against the model's documentation) are paired with the category
     * symbols, sorted with {@code ScoredObject.reverseComparator()}, and
     * returned as a classification.
     *
     * @param vector the input vector to classify
     * @return the classification with categories and scores in sorted order
     */
    public ConditionalClassification classifyVector(Vector vector) {
        double[] scores = mModel.classify(vector);
        int numCategories = scores.length;

        @SuppressWarnings("unchecked") // generic array creation is not expressible directly
        ScoredObject<String>[] ranked = (ScoredObject<String>[]) new ScoredObject[numCategories];
        for (int k = 0; k < numCategories; k++) {
            ranked[k] = new ScoredObject<String>(mCategorySymbols[k], scores[k]);
        }
        Arrays.sort(ranked, ScoredObject.reverseComparator());

        String[] rankedCategories = new String[numCategories];
        for (int rank = 0; rank < numCategories; rank++) {
            ScoredObject<String> entry = ranked[rank];
            rankedCategories[rank] = entry.getObject().toString();
            // The score array returned by the model is overwritten in sorted order.
            scores[rank] = entry.score();
        }
        return new ConditionalClassification(rankedCategories, scores);
    }

    /**
     * Classifies a feature map by converting it to a vector with this
     * classifier's feature symbol table and intercept flag.
     *
     * @param map feature names mapped to their values
     * @return the classification of the resulting vector
     */
    public ConditionalClassification classifyFeatures(Map<String, ? extends Number> map) {
        int numDimensions = mFeatureSymbolTable.numSymbols();
        Vector featureVector = Features.toVector(map, mFeatureSymbolTable, numDimensions, mAddInterceptFeature);
        return classifyVector(featureVector);
    }

    /**
     * Classifies an input object by extracting its features and classifying
     * the resulting feature map.
     *
     * @param e the object to classify
     * @return the classification of the object's features
     */
    @Override
    public ConditionalClassification classify(E e) {
        Map<String, ? extends Number> featureMap = mFeatureExtractor.features(e);
        return classifyFeatures(featureMap);
    }

    /**
     * Writes this classifier to the given output through its
     * {@code Externalizer} proxy.
     *
     * @param objectOutput destination for the compiled classifier
     * @throws IOException if the underlying output fails
     */
    @Override
    public void compileTo(ObjectOutput objectOutput) throws IOException {
        Externalizer<E> proxy = new Externalizer<E>(this);
        objectOutput.writeObject(proxy);
    }

    /**
     * Finds the index of a category symbol by linear search.
     *
     * @param str the category name to look up
     * @return the index of the category, or {@code -1} if it is not present
     */
    private int categoryToId(String str) {
        int index = 0;
        for (String symbol : mCategorySymbols) {
            if (symbol.equals(str)) {
                return index;
            }
            index++;
        }
        return -1;
    }

    /**
     * Returns the model weights for one category, keyed by feature name.
     *
     * <p>For the category stored last, an empty map is returned without
     * consulting the model's weight vectors.
     *
     * @param str the category whose feature weights are requested
     * @return a new map from feature name to weight for the category
     * @throws IllegalArgumentException if the category is unknown
     */
    public ObjectToDoubleMap<String> featureValues(String str) {
        int categoryId = categoryToId(str);
        if (categoryId < 0) {
            throw new IllegalArgumentException("Unknown category=" + str);
        }
        ObjectToDoubleMap<String> weights = new ObjectToDoubleMap<String>();
        boolean isLastCategory = (categoryId == mCategorySymbols.length - 1);
        if (!isLastCategory) {
            int numFeatures = mFeatureSymbolTable.numSymbols();
            Vector weightVector = mModel.weightVectors()[categoryId];
            for (int featureId = 0; featureId < numFeatures; featureId++) {
                String featureName = mFeatureSymbolTable.idToSymbol(featureId);
                weights.set(featureName, weightVector.value(featureId));
            }
        }
        return weights;
    }

    /**
     * Renders the category count, the feature count, and, for every category
     * except the last, its feature weights in the order given by
     * {@code keysOrderedByValueList()}.
     *
     * @return a multi-line textual description of this classifier
     */
    public String toString() {
        CharArrayWriter buffer = new CharArrayWriter();
        PrintWriter out = new PrintWriter(buffer);
        List<String> categories = categorySymbols();
        int numCategories = categories.size();
        out.println("NUMBER OF CATEGORIES=" + numCategories);
        out.println("NUMBER OF FEATURES=" + mFeatureSymbolTable.numSymbols());
        // The last category is skipped; featureValues returns no weights for it.
        int numReported = Math.max(0, numCategories - 1);
        for (String category : categories.subList(0, numReported)) {
            out.println("\n  CATEGORY=" + category);
            ObjectToDoubleMap<String> weights = featureValues(category);
            for (String feature : weights.keysOrderedByValueList()) {
                out.printf("%20s %15.6f\n", feature, weights.get(feature));
            }
        }
        out.write('\n');
        return buffer.toString();
    }

    /**
     * Serialization hook that substitutes an {@code Externalizer} proxy for
     * this classifier.
     *
     * @return the proxy object to serialize in place of this instance
     */
    private Object writeReplace() {
        Externalizer<E> proxy = new Externalizer<E>(this);
        return proxy;
    }

    /**
     * Convenience overload of the full {@code train} method: no hot-start
     * classifier, no per-model callback handler, {@code -1} for the sixth
     * argument (which the full method replaces with a value derived from
     * the number of training instances) and {@code 5} for the tenth.
     *
     * @param corpus corpus whose training section supplies the classified instances
     * @param featureExtractor extractor converting inputs to feature maps
     * @param i minimum feature count used to prune features
     * @param z whether the intercept feature is added
     * @param regressionPrior prior forwarded to the full method
     * @param annealingSchedule annealing schedule forwarded to the full method
     * @param d forwarded unchanged to the full method
     * @param i2 forwarded as the full method's eleventh argument
     * @param i3 forwarded as the full method's twelfth argument
     * @param reporter progress reporter, may be {@code null}
     * @param <F> type of the objects being classified
     * @return the trained classifier
     * @throws IOException if visiting the corpus fails
     */
    public static <F> LogisticRegressionClassifier<F> train(Corpus<ObjectHandler<Classified<F>>> corpus, FeatureExtractor<? super F> featureExtractor, int i, boolean z, RegressionPrior regressionPrior, AnnealingSchedule annealingSchedule, double d, int i2, int i3, Reporter reporter) throws IOException {
        LogisticRegressionClassifier<F> noHotStart = null;
        ObjectHandler<LogisticRegressionClassifier<F>> noCallback = null;
        return train(corpus, featureExtractor, i, z, regressionPrior,
                -1, noHotStart, annealingSchedule, d,
                5, i2, i3, noCallback, reporter);
    }

    /**
     * Trains a classifier from the training section of a corpus.
     *
     * <p>The corpus is visited twice: once to count features, so that those
     * seen fewer than {@code i} times can be pruned, and once to convert each
     * instance to a feature vector and category identifier. The vectors are
     * then passed to {@code LogisticRegression.estimate}.
     *
     * <p>If a hot-start classifier is supplied, an initial model is built by
     * copying its weights, matched by feature name, for every category that
     * has a weight vector in the new model and is known to the hot-start
     * classifier. All other weights start at zero.
     *
     * @param corpus corpus whose training section supplies the classified instances
     * @param featureExtractor extractor converting inputs to feature maps
     * @param i minimum feature count used to prune features
     * @param z whether the intercept feature is registered and added to vectors
     * @param regressionPrior prior forwarded to {@code LogisticRegression.estimate}
     * @param i2 forwarded to {@code LogisticRegression.estimate}; when {@code -1}
     *     it is replaced by {@code max(1, numTrainingInstances / 50)}
     *     (presumably a block size; confirm against that method's documentation)
     * @param logisticRegressionClassifier optional hot-start classifier, or {@code null} for a cold start
     * @param annealingSchedule annealing schedule forwarded to the estimator
     * @param d forwarded unchanged to the estimator
     * @param i3 forwarded unchanged to the estimator
     * @param i4 forwarded unchanged to the estimator
     * @param i5 forwarded unchanged to the estimator
     * @param objectHandler optional handler that receives a classifier for each
     *     model the estimator reports, or {@code null}
     * @param reporter progress reporter; a silent reporter is used if {@code null}
     * @param <F> type of the objects being classified
     * @return the trained classifier
     * @throws IllegalArgumentException if the training section yields no instances
     * @throws IOException if visiting the corpus fails
     */
    public static <F> LogisticRegressionClassifier<F> train(Corpus<ObjectHandler<Classified<F>>> corpus, FeatureExtractor<? super F> featureExtractor, int i, boolean z, RegressionPrior regressionPrior, int i2, LogisticRegressionClassifier<F> logisticRegressionClassifier, AnnealingSchedule annealingSchedule, double d, int i3, int i4, int i5, ObjectHandler<LogisticRegressionClassifier<F>> objectHandler, Reporter reporter) throws IOException {
        MapSymbolTable featureSymbolTable = new MapSymbolTable();
        MapSymbolTable categorySymbolTable = new MapSymbolTable();
        if (reporter == null) {
            reporter = Reporters.silent();
        }
        if (z) {
            // Registered first so the intercept occupies the first feature identifier.
            featureSymbolTable.getOrAddSymbol(INTERCEPT_FEATURE_NAME);
        }
        reporter.info("Feature Extractor class=" + featureExtractor.getClass());
        reporter.info("min feature count=" + i);
        reporter.info("Extracting Training Data");

        reporter.debug("  Counting features");
        ObjectToCounterMap<String> featureCounts = new ObjectToCounterMap<String>();
        corpus.visitTrain(new FeatureCounter<F>(featureExtractor, featureCounts));

        reporter.debug("  Pruning features");
        featureCounts.prune(i);
        for (String feature : featureCounts.keySet()) {
            featureSymbolTable.getOrAddSymbol(feature);
        }

        reporter.debug("  Extracting vectors");
        DataExtractor<F> dataExtractor = new DataExtractor<F>(featureExtractor, featureSymbolTable, categorySymbolTable, z, featureSymbolTable.numSymbols());
        corpus.visitTrain(dataExtractor);
        Vector[] inputs = dataExtractor.inputs();
        int[] categories = dataExtractor.categories();
        if (inputs.length == 0) {
            // Fail with a clear message instead of an index error on inputs[0] below.
            throw new IllegalArgumentException("Training corpus must contain at least one classified training instance.");
        }
        int numDimensions = inputs[0].numDimensions();

        String[] categorySymbols = new String[categorySymbolTable.numSymbols()];
        for (int k = 0; k < categorySymbols.length; k++) {
            categorySymbols[k] = categorySymbolTable.idToSymbol(k);
        }

        LogisticRegression hotStartModel = null;
        if (logisticRegressionClassifier != null) {
            reporter.debug("hot starting");
            Set<String> hotStartCategories = new HashSet<String>(logisticRegressionClassifier.categorySymbols());
            // One weight vector per category except the last one.
            Vector[] hotStartWeights = new Vector[categorySymbols.length - 1];
            for (int k = 0; k < hotStartWeights.length; k++) {
                hotStartWeights[k] = new DenseVector(numDimensions);
            }
            // Seed every weight vector. The previous bound of length - 1 left the
            // final vector at zero, and copied nothing at all with two categories.
            for (int k = 0; k < hotStartWeights.length; k++) {
                String category = categorySymbols[k];
                if (!hotStartCategories.contains(category)) {
                    continue;
                }
                ObjectToDoubleMap<String> hotStartValues = logisticRegressionClassifier.featureValues(category);
                for (int dim = 0; dim < numDimensions; dim++) {
                    String feature = featureSymbolTable.idToSymbol(dim);
                    hotStartWeights[k].setValue(dim, hotStartValues.getValue(feature));
                }
            }
            hotStartModel = new LogisticRegression(hotStartWeights);
        }
        reporter.info(logisticRegressionClassifier != null ? "Hot start" : "Cold start");

        RegressionHandlerAdapter<F> regressionHandler = null;
        if (objectHandler != null) {
            regressionHandler = new RegressionHandlerAdapter<F>(objectHandler, featureExtractor, z, featureSymbolTable, categorySymbols);
        }
        reporter.info(regressionHandler != null
                ? "Regression callback handler class=" + regressionHandler.getClass()
                : "Regression callback handler=null");

        if (i2 == -1) {
            i2 = Math.max(1, categories.length / 50);
        }
        LogisticRegression model = LogisticRegression.estimate(inputs, categories, regressionPrior, i2, hotStartModel, annealingSchedule, d, i3, i4, i5, regressionHandler, reporter);
        return new LogisticRegressionClassifier<F>(model, featureExtractor, z, featureSymbolTable, categorySymbols);
    }

    // The decompiler emitted three synthetic bridge overloads of classify(Object) at
    // this point, returning ScoredClassification, RankedClassification and
    // Classification. They are intentionally not declared in source: methods with the
    // same signature that differ only in return type cannot be written by hand, and the
    // compiler generates the equivalent bridge methods automatically from classify(E),
    // whose ConditionalClassification return type covariantly overrides each inherited
    // classify method.
}
