/*
 * Decompiled with CFR 0.152.
 */
package edu.usc.irds.agepredictor.spark.authorage;

import edu.usc.irds.agepredictor.spark.authorage.AgeClassifyContextGeneratorWrapper;
import edu.usc.irds.agepredictor.spark.authorage.AgePredictModel;
import edu.usc.irds.agepredictor.spark.authorage.CreateEvents;
import edu.usc.irds.agepredictor.spark.authorage.EventWrapper;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import opennlp.tools.util.TrainingParameters;
import org.apache.commons.io.FileUtils;
import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.feature.CountVectorizer;
import org.apache.spark.ml.feature.CountVectorizerModel;
import org.apache.spark.ml.feature.Normalizer;
import org.apache.spark.ml.linalg.SparseVector;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.regression.LassoModel;
import org.apache.spark.mllib.regression.LassoWithSGD;
import org.apache.spark.mllib.stat.Statistics;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import scala.Tuple2;

public class AgePredictSGDTrainer {
    public static final String CUTOFF_PARAM = "Cutoff";
    public static final int CUTOFF_DEFAULT = 5;
    public static final String ITERATIONS_PARAM = "Iterations";
    public static final int ITERATIONS_DEFAULT = 100;
    public static final String STEPSIZE_PARAM = "StepSize";
    public static final double STEPSIZE_DEFAULT = 1.0;
    public static final String REG_PARAM = "Regularization";
    public static final double REG_DEFAULT = 0.1;

    public static void generateEvents(SparkSession spark, String dataIn, String tokenizer, String featureGenerators, String outDir) throws IOException {
        AgeClassifyContextGeneratorWrapper wrapper = new AgeClassifyContextGeneratorWrapper(tokenizer, featureGenerators);
        JavaRDD data = spark.sparkContext().textFile(dataIn, 48).toJavaRDD().cache();
        JavaRDD samples = data.map((Function)new CreateEvents(wrapper)).cache();
        JavaRDD validSamples = samples.filter((Function)new Function<EventWrapper, Boolean>(){

            public Boolean call(EventWrapper s) {
                if (s != null) {
                    return s.getValue() != null;
                }
                return false;
            }
        }).repartition(8);
        File dir = new File(outDir);
        if (dir.exists()) {
            FileUtils.cleanDirectory((File)dir);
            FileUtils.forceDelete((File)dir);
        }
        validSamples.saveAsTextFile(outDir);
    }

    private static int getCutoff(Map<String, String> params) {
        String cutoffString = params.get(CUTOFF_PARAM);
        if (cutoffString != null) {
            return Integer.parseInt(cutoffString);
        }
        return 5;
    }

    private static int getIterations(Map<String, String> params) {
        String iterationString = params.get(ITERATIONS_PARAM);
        if (iterationString != null) {
            return Integer.parseInt(iterationString);
        }
        return 100;
    }

    private static double getStepSize(Map<String, String> params) {
        String stepString = params.get(STEPSIZE_PARAM);
        if (stepString != null) {
            return Double.parseDouble(stepString);
        }
        return 1.0;
    }

    private static double getReg(Map<String, String> params) {
        String regString = params.get(REG_PARAM);
        if (regString != null) {
            return Double.parseDouble(regString);
        }
        return 0.1;
    }

    public static AgePredictModel createModel(String languageCode, SparkSession spark, String eventDir, AgeClassifyContextGeneratorWrapper wrapper, TrainingParameters trainParams) throws IOException {
        Map params = trainParams.getSettings();
        int cutoff = AgePredictSGDTrainer.getCutoff(params);
        int iterations = AgePredictSGDTrainer.getIterations(params);
        JavaRDD data = spark.sparkContext().textFile(eventDir, 24).toJavaRDD().cache();
        JavaRDD samples = data.map((Function)new Function<String, Row>(){

            public Row call(String s) {
                if (s == null) {
                    return null;
                }
                String[] parts = s.split(",");
                if (parts.length != 3) {
                    return null;
                }
                try {
                    if (parts[0] != "-1") {
                        Integer value = Integer.parseInt(parts[0]);
                        String[] text = parts[2].split(" ");
                        ArrayList<String> tokens = new ArrayList<String>(Arrays.asList(text));
                        for (int i = 0; i < text.length / 18; ++i) {
                            tokens.add("cat=" + parts[1]);
                        }
                        return RowFactory.create((Object[])new Object[]{value, tokens.toArray()});
                    }
                    return null;
                }
                catch (Exception e) {
                    return null;
                }
            }
        }).cache();
        JavaRDD validSamples = samples.filter((Function)new Function<Row, Boolean>(){

            public Boolean call(Row s) {
                return s != null;
            }
        }).cache();
        samples.unpersist();
        StructType schema = new StructType(new StructField[]{new StructField("value", DataTypes.IntegerType, false, Metadata.empty()), new StructField("context", (DataType)new ArrayType(DataTypes.StringType, true), false, Metadata.empty())});
        Dataset df = spark.createDataFrame(validSamples, schema).cache();
        CountVectorizerModel cvm = new CountVectorizer().setInputCol("context").setOutputCol("feature").setMinDF((double)cutoff).fit(df);
        Normalizer normalizer = ((Normalizer)((Normalizer)new Normalizer().setInputCol("feature")).setOutputCol("normFeature")).setP(1.0);
        Dataset eventDF = cvm.transform(df).select("value", new String[]{"feature"});
        Dataset normDF = normalizer.transform(eventDF).select("value", new String[]{"normFeature"});
        JavaRDD events = normDF.javaRDD().cache();
        eventDF.unpersist();
        normDF.unpersist();
        JavaRDD parsedData = events.map((Function)new Function<Row, LabeledPoint>(){

            public LabeledPoint call(Row r) {
                Integer val = r.getInt(0);
                SparseVector vec = (SparseVector)r.get(1);
                Vector features = Vectors.sparse((int)vec.size(), (int[])vec.indices(), (double[])vec.values());
                return new LabeledPoint((double)val.intValue(), features);
            }
        }).cache();
        double stepSize = AgePredictSGDTrainer.getStepSize(params);
        double regParam = AgePredictSGDTrainer.getReg(params);
        LassoWithSGD algorithm = (LassoWithSGD)new LassoWithSGD().setIntercept(true);
        algorithm.optimizer().setNumIterations(iterations).setStepSize(stepSize).setRegParam(regParam);
        final LassoModel model = (LassoModel)algorithm.run(JavaRDD.toRDD((JavaRDD)parsedData));
        System.out.println("Coefficients: " + Arrays.toString(model.weights().toArray()));
        System.out.println("Intercept: " + model.intercept());
        JavaRDD valuesAndPreds = parsedData.map((Function)new Function<LabeledPoint, Tuple2<Double, Double>>(){

            public Tuple2<Double, Double> call(LabeledPoint point) {
                double prediction = model.predict(point.features());
                System.out.println(prediction + "," + point.label());
                return new Tuple2((Object)prediction, (Object)point.label());
            }
        }).cache();
        double MAE = new JavaDoubleRDD(valuesAndPreds.map((Function)new Function<Tuple2<Double, Double>, Object>(){

            public Object call(Tuple2<Double, Double> pair) {
                return Math.abs((Double)pair._1() - (Double)pair._2());
            }
        }).rdd()).mean();
        JavaRDD vectors = valuesAndPreds.map((Function)new Function<Tuple2<Double, Double>, Vector>(){

            public Vector call(Tuple2<Double, Double> pair) {
                return Vectors.dense((double)((Double)pair._1()), (double[])new double[]{(Double)pair._2()});
            }
        });
        Matrix correlMatrix = Statistics.corr((RDD)vectors.rdd(), (String)"pearson");
        System.out.println("Training Mean Absolute Error: " + MAE);
        System.out.println("Correlation:\n" + correlMatrix.toString());
        HashMap manifestInfoEntries = new HashMap();
        return new AgePredictModel(languageCode, model, cvm.vocabulary(), wrapper);
    }
}

