package edu.usc.irds.agepredictor.authorage;

import edu.usc.irds.agepredictor.spark.authorage.AgePredictModel;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import opennlp.tools.authorage.AgeClassifyME;
import opennlp.tools.authorage.AgeClassifyModel;
import opennlp.tools.util.InvalidFormatException;
import opennlp.tools.util.featuregen.FeatureGenerator;
import org.apache.spark.ml.feature.CountVectorizerModel;
import org.apache.spark.ml.feature.Normalizer;
import org.apache.spark.ml.linalg.SparseVector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

/* loaded from: input_file:edu/usc/irds/agepredictor/authorage/AgePredicterLocal.class */
/**
 * Predicts an age value for a piece of text using a local Spark session.
 *
 * <p>An instance loads two model files: an {@link AgeClassifyModel} (wrapped in an
 * {@link AgeClassifyME}) that yields a best category for the tokenized text, and an
 * {@link AgePredictModel} that supplies the tokenizer, the feature generators, the
 * vocabulary and the final predictor. {@link #predictAge(String)} turns the text into a
 * count vector over that vocabulary, L1-normalizes it and asks the model for a prediction.
 */
public class AgePredicterLocal {
    /** Classification model file used by the no-argument constructor. */
    private static final String DEFAULT_CLASSIFY_MODEL_PATH = "./model/classify-bigram.bin";

    /** Regression model file used by the no-argument constructor. */
    private static final String DEFAULT_REGRESSION_MODEL_PATH = "./model/regression-global.bin";

    /**
     * One {@code cat=<category>} feature is added per this many tokens (integer division),
     * so texts shorter than this get no category feature at all.
     * NOTE(review): presumably a weighting chosen at training time - confirm before changing.
     */
    private static final int TOKENS_PER_CATEGORY_FEATURE = 18;

    private static final String DOCUMENT_COLUMN = "document";
    private static final String TEXT_COLUMN = "text";
    private static final String FEATURE_COLUMN = "feature";
    private static final String NORM_FEATURE_COLUMN = "normFeature";

    private final SparkSession spark;
    private final AgeClassifyModel classifyModel;
    private final AgeClassifyME classify;
    private final AgePredictModel model;

    /**
     * Creates a predictor from the default model files under {@code ./model}.
     *
     * @throws InvalidFormatException propagated from model loading
     * @throws IOException propagated from model loading
     */
    public AgePredicterLocal() throws InvalidFormatException, IOException {
        this(DEFAULT_CLASSIFY_MODEL_PATH, DEFAULT_REGRESSION_MODEL_PATH);
    }

    /**
     * Creates a predictor from explicit model files and obtains (or reuses) a Spark session
     * with master {@code local} and application name {@code AgePredict}.
     *
     * @param str path of the classification model file
     * @param str2 path of the prediction model file read via {@link AgePredictModel#readModel}
     * @throws InvalidFormatException propagated from model loading
     * @throws IOException propagated from model loading
     */
    public AgePredicterLocal(String str, String str2) throws InvalidFormatException, IOException {
        this.spark = SparkSession.builder().master("local").appName("AgePredict").getOrCreate();
        this.classifyModel = new AgeClassifyModel(new File(str));
        this.classify = new AgeClassifyME(this.classifyModel);
        this.model = AgePredictModel.readModel(new File(str2));
    }

    /**
     * Predicts the age for the given text.
     *
     * @param str the text to analyze; must not be {@code null}
     * @return the value returned by the loaded model's {@code predict} for the normalized
     *         feature vector of the text
     * @throws NullPointerException if {@code str} is {@code null}
     * @throws IllegalArgumentException if no features can be extracted from {@code str}
     *         (for example, when it is empty); previously this surfaced as an opaque failure
     *         when reading the first row of an empty data set
     * @throws InvalidFormatException declared for compatibility with existing callers
     * @throws IOException declared for compatibility with existing callers
     */
    public double predictAge(String str) throws InvalidFormatException, IOException {
        if (str == null) {
            throw new NullPointerException("str must not be null");
        }
        String[] tokens = this.model.getContext().getTokenizer().tokenize(str);
        ArrayList<String> features = extractFeatures(tokens);
        if (features.isEmpty()) {
            // Without this guard the data frame below would have no rows and first() would fail.
            throw new IllegalArgumentException(
                    "No features could be extracted from the given text: '" + str + "'");
        }

        ArrayList<Row> rows = new ArrayList<>();
        rows.add(RowFactory.create(new Object[]{str, features.toArray()}));

        SparseVector normalized = toNormalizedVector(rows);
        // The predictor is fed an mllib vector, so rebuild one from the ml vector's parts.
        return this.model.getModel().predict(
                Vectors.sparse(normalized.size(), normalized.indices(), normalized.values())
                        .compressed());
    }

    /**
     * Collects the output of every configured feature generator for the tokens and, when a
     * best category is available, appends the {@code cat=<category>} feature once per
     * {@link #TOKENS_PER_CATEGORY_FEATURE} tokens.
     */
    private ArrayList<String> extractFeatures(String[] tokens) {
        ArrayList<String> features = new ArrayList<>();
        for (FeatureGenerator featureGenerator : this.model.getContext().getFeatureGenerators()) {
            features.addAll(featureGenerator.extractFeatures(tokens));
        }

        String bestCategory = this.classify.getBestCategory(this.classify.getProbabilities(tokens));
        if (bestCategory != null) {
            int repetitions = tokens.length / TOKENS_PER_CATEGORY_FEATURE;
            for (int i = 0; i < repetitions; i++) {
                features.add("cat=" + bestCategory);
            }
        }
        return features;
    }

    /**
     * Runs the rows through a count vectorizer built from the model vocabulary followed by an
     * L1 normalizer, and returns the normalized vector of the first row.
     */
    private SparseVector toNormalizedVector(ArrayList<Row> rows) {
        StructType schema = new StructType(new StructField[]{
                new StructField(DOCUMENT_COLUMN, DataTypes.StringType, false, Metadata.empty()),
                new StructField(TEXT_COLUMN, new ArrayType(DataTypes.StringType, true), false,
                        Metadata.empty())});

        CountVectorizerModel vectorizer = new CountVectorizerModel(this.model.getVocabulary())
                .setInputCol(TEXT_COLUMN)
                .setOutputCol(FEATURE_COLUMN);
        Normalizer normalizer = new Normalizer()
                .setInputCol(FEATURE_COLUMN)
                .setOutputCol(NORM_FEATURE_COLUMN)
                .setP(1.0d);

        Row first = (Row) normalizer
                .transform(vectorizer.transform(this.spark.createDataFrame(rows, schema)))
                .javaRDD()
                .first();
        return (SparseVector) first.getAs(NORM_FEATURE_COLUMN);
    }

    /**
     * Command-line entry point: joins the arguments with single spaces into one text (or uses
     * a built-in sample sentence when there are none), predicts its age with the default
     * models and prints the result to standard output.
     *
     * @param strArr words of the text to analyze
     * @throws Exception if model loading or prediction fails
     */
    public static void main(String[] strArr) throws Exception {
        String str = "I am very very old person";
        if (strArr.length > 0) {
            StringBuilder sb = new StringBuilder();
            for (int i = 0; i < strArr.length; i++) {
                // Separator only between arguments, so the text has no trailing space.
                if (i > 0) {
                    sb.append(" ");
                }
                sb.append(strArr[i]);
            }
            str = sb.toString();
        }
        double predictAge = new AgePredicterLocal().predictAge(str);
        System.out.println("\n===================\n");
        System.out.println(String.format("Text received- '%s' \n Predicted Age - %f%n", str, Double.valueOf(predictAge)));
        System.out.println("\n===================\n");
    }
}
