/*
 * Decompiled with CFR 0.152.
 */
package edu.usc.irds.agepredictor.spark.authorage;

import edu.usc.irds.agepredictor.spark.authorage.AgeClassifyContextGeneratorWrapper;
import edu.usc.irds.agepredictor.spark.authorage.AgeClassifyModelWrapper;
import edu.usc.irds.agepredictor.spark.authorage.AgePredictModel;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import opennlp.tools.authorage.AgeClassifyME;
import opennlp.tools.util.featuregen.FeatureGenerator;
import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.feature.CountVectorizerModel;
import org.apache.spark.ml.feature.Normalizer;
import org.apache.spark.ml.linalg.SparseVector;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.regression.LassoModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import scala.Tuple2;

public class AgePredictEvaluator {
    public static void evaluate(SparkSession spark, File classifyModel, File linModel, File report, String dataIn) throws IOException {
        AgePredictModel model = AgePredictModel.readModel(linModel);
        final AgeClassifyModelWrapper wrapper = classifyModel == null ? null : new AgeClassifyModelWrapper(classifyModel);
        JavaRDD data = spark.sparkContext().textFile(dataIn, 8).toJavaRDD().cache();
        final AgeClassifyContextGeneratorWrapper contextGen = model.getContext();
        JavaRDD samples = data.map((Function)new Function<String, Row>(){

            public Row call(String s) throws IOException {
                FeatureGenerator[] featureGenerators;
                String label = s.split("\t", 2)[0];
                String text = s.split("\t", 2)[1];
                String[] tokens = contextGen.getTokenizer().tokenize(text);
                String category = null;
                if (wrapper != null) {
                    AgeClassifyME classify = wrapper.getClassifier();
                    double[] prob = classify.getProbabilities(tokens);
                    category = classify.getBestCategory(prob);
                }
                ArrayList<String> context = new ArrayList<String>();
                for (FeatureGenerator featureGenerator : featureGenerators = contextGen.getFeatureGenerators()) {
                    Collection extractedFeatures = featureGenerator.extractFeatures(tokens);
                    context.addAll(extractedFeatures);
                }
                if (category != null) {
                    for (int i = 0; i < tokens.length / 18; ++i) {
                        context.add("cat=" + category);
                    }
                }
                if (context.size() > 0) {
                    try {
                        int age = Integer.valueOf(label);
                        return RowFactory.create((Object[])new Object[]{age, context.toArray()});
                    }
                    catch (Exception e) {
                        return null;
                    }
                }
                return null;
            }
        });
        JavaRDD validSamples = samples.filter((Function)new Function<Row, Boolean>(){

            public Boolean call(Row s) {
                return s != null;
            }
        }).cache();
        samples.unpersist();
        StructType schema = new StructType(new StructField[]{new StructField("value", DataTypes.IntegerType, false, Metadata.empty()), new StructField("context", (DataType)new ArrayType(DataTypes.StringType, true), false, Metadata.empty())});
        Dataset df = spark.createDataFrame(validSamples, schema).cache();
        System.out.println("Vocab: " + model.getVocabulary());
        CountVectorizerModel cvm = new CountVectorizerModel(model.getVocabulary()).setInputCol("context").setOutputCol("feature");
        Normalizer normalizer = ((Normalizer)((Normalizer)new Normalizer().setInputCol("feature")).setOutputCol("norm")).setP(1.0);
        Dataset eventDF = cvm.transform(df).select("value", new String[]{"feature"});
        JavaRDD events = normalizer.transform(eventDF).select("value", new String[]{"norm"}).javaRDD().cache();
        eventDF.unpersist();
        JavaRDD parsedData = events.map((Function)new Function<Row, LabeledPoint>(){

            public LabeledPoint call(Row r) {
                Integer val = r.getInt(0);
                SparseVector vec = (SparseVector)r.get(1);
                Vector features = Vectors.sparse((int)vec.size(), (int[])vec.indices(), (double[])vec.values());
                return new LabeledPoint((double)val.intValue(), features);
            }
        });
        parsedData.cache();
        final LassoModel reg = model.getModel();
        JavaRDD valuesAndPreds = parsedData.map((Function)new Function<LabeledPoint, Tuple2<Double, Double>>(){

            public Tuple2<Double, Double> call(LabeledPoint point) {
                double prediction = reg.predict(point.features());
                return new Tuple2((Object)prediction, (Object)point.label());
            }
        }).cache();
        double MAE = new JavaDoubleRDD(valuesAndPreds.map((Function)new Function<Tuple2<Double, Double>, Object>(){

            public Object call(Tuple2<Double, Double> pair) {
                return Math.abs((Double)pair._1() - (Double)pair._2());
            }
        }).rdd()).mean();
        if (report != null) {
            Iterator iterator = valuesAndPreds.toLocalIterator();
            report.createNewFile();
            FileWriter writer = new FileWriter(report);
            while (iterator.hasNext()) {
                Tuple2 pair = (Tuple2)iterator.next();
                writer.write(pair._1() + "," + pair._2() + "\n");
            }
            writer.close();
        }
        System.out.println("Mean Absolute Error: " + MAE);
    }
}

