package org.apache.spark.ml.tree.impl;

import org.apache.spark.SparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.internal.Logging;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
import org.apache.spark.ml.regression.DecisionTreeRegressor;
import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer;
import org.apache.spark.mllib.tree.configuration.Algo$;
import org.apache.spark.mllib.tree.configuration.BoostingStrategy;
import org.apache.spark.mllib.tree.configuration.Strategy;
import org.apache.spark.mllib.tree.impurity.Variance$;
import org.apache.spark.mllib.tree.loss.Loss;
import org.apache.spark.rdd.RDD;
import org.apache.spark.rdd.RDD$;
import org.apache.spark.storage.StorageLevel;
import org.apache.spark.storage.StorageLevel$;
import org.slf4j.Logger;
import scala.Array$;
import scala.Enumeration;
import scala.Function0;
import scala.Predef$;
import scala.StringContext;
import scala.Tuple2;
import scala.math.Ordering$Double$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.IntRef;
import scala.runtime.ObjectRef;
import scala.runtime.RichInt$;
import scala.runtime.ScalaRunTime$;

/* compiled from: GradientBoostedTrees.scala */
/* loaded from: input_file:org/apache/spark/ml/tree/impl/GradientBoostedTrees$.class */
/**
 * Decompiled module (singleton) class for the Scala object
 * {@code org.apache.spark.ml.tree.impl.GradientBoostedTrees}: trains an ensemble of regression
 * trees by gradient boosting ({@link #boost}) and evaluates such ensembles on an RDD
 * ({@link #computeError}, {@link #evaluateEachIteration}).
 *
 * <p>NOTE(review): this is decompiler output, not hand-written Java. Constructs such as
 * {@code Logging.class.logName(this)} and the assignment to the {@code final} field
 * {@code MODULE$} inside the constructor are presumably renderings of Scala trait-implementation
 * calls and module initialization; they do not compile as plain Java. Much of the behavior lives
 * in synthetic closure classes ({@code GradientBoostedTrees$$anonfun$N}) whose bodies are not
 * visible here, so comments about what those closures compute are hedged.
 */
public final class GradientBoostedTrees$ implements Logging {
    /** The single instance; assigned by the private constructor, which the static initializer runs. */
    public static final GradientBoostedTrees$ MODULE$ = null;
    // Backing field for the logger, read/written by the two Logging accessor methods below.
    private transient Logger org$apache$spark$internal$Logging$$log_;

    static {
        // Creating the instance is what populates MODULE$ (the constructor does the assignment).
        new GradientBoostedTrees$();
    }

    // ---- Logging trait plumbing --------------------------------------------------------------
    // Each method below forwards to Logging.class.<same name>(this, ...), i.e. the Logging trait's
    // implementation, passing this object as the receiver.

    public Logger org$apache$spark$internal$Logging$$log_() {
        return this.org$apache$spark$internal$Logging$$log_;
    }

    public void org$apache$spark$internal$Logging$$log__$eq(Logger logger) {
        this.org$apache$spark$internal$Logging$$log_ = logger;
    }

    public String logName() {
        return Logging.class.logName(this);
    }

    public Logger log() {
        return Logging.class.log(this);
    }

    public void logInfo(Function0<String> function0) {
        Logging.class.logInfo(this, function0);
    }

    public void logDebug(Function0<String> function0) {
        Logging.class.logDebug(this, function0);
    }

    public void logTrace(Function0<String> function0) {
        Logging.class.logTrace(this, function0);
    }

    public void logWarning(Function0<String> function0) {
        Logging.class.logWarning(this, function0);
    }

    public void logError(Function0<String> function0) {
        Logging.class.logError(this, function0);
    }

    public void logInfo(Function0<String> function0, Throwable th) {
        Logging.class.logInfo(this, function0, th);
    }

    public void logDebug(Function0<String> function0, Throwable th) {
        Logging.class.logDebug(this, function0, th);
    }

    public void logTrace(Function0<String> function0, Throwable th) {
        Logging.class.logTrace(this, function0, th);
    }

    public void logWarning(Function0<String> function0, Throwable th) {
        Logging.class.logWarning(this, function0, th);
    }

    public void logError(Function0<String> function0, Throwable th) {
        Logging.class.logError(this, function0, th);
    }

    public boolean isTraceEnabled() {
        return Logging.class.isTraceEnabled(this);
    }

    public void initializeLogIfNecessary(boolean z) {
        Logging.class.initializeLogIfNecessary(this, z);
    }

    // ---- Training entry points ---------------------------------------------------------------

    /**
     * Trains a gradient-boosted ensemble with validation-based early stopping disabled.
     *
     * <p>Regression input is boosted as-is. Classification input is first mapped through
     * {@code GradientBoostedTrees$$anonfun$2} (presumably a label remapping suited to the loss;
     * TODO confirm against the closure) before boosting. Any other algo is rejected.
     *
     * @param rdd training data
     * @param boostingStrategy boosting parameters; {@code treeStrategy().algo()} selects the branch
     * @param j random seed forwarded to {@link #boost}
     * @return (trees, tree weights) exactly as returned by {@link #boost}
     * @throws IllegalArgumentException if the algo is neither Regression nor Classification
     */
    public Tuple2<DecisionTreeRegressionModel[], double[]> run(RDD<LabeledPoint> rdd, BoostingStrategy boostingStrategy, long j) {
        Tuple2<DecisionTreeRegressionModel[], double[]> boost;
        Enumeration.Value algo = boostingStrategy.treeStrategy().algo();
        Enumeration.Value Regression = Algo$.MODULE$.Regression();
        // Null-safe inequality (Scala '!='): true when algo is NOT Regression.
        if (Regression != null ? !Regression.equals(algo) : algo != null) {
            Enumeration.Value Classification = Algo$.MODULE$.Classification();
            if (Classification != null ? !Classification.equals(algo) : algo != null) {
                throw new IllegalArgumentException(new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"", " is not supported by gradient boosting."})).s(Predef$.MODULE$.genericWrapArray(new Object[]{algo})));
            }
            RDD<LabeledPoint> map = rdd.map(new GradientBoostedTrees$$anonfun$2(), ClassTag$.MODULE$.apply(LabeledPoint.class));
            // The same RDD is passed as the validation input; boost() only evaluates it when the
            // validation flag (false here) is set.
            boost = boost(map, map, boostingStrategy, false, j);
        } else {
            boost = boost(rdd, rdd, boostingStrategy, false, j);
        }
        return boost;
    }

    /**
     * Trains a gradient-boosted ensemble, using {@code rdd2} for validation-based early stopping.
     *
     * <p>Same dispatch as {@link #run}: Regression input is used as-is, Classification input
     * (both training and validation) is first mapped through {@code GradientBoostedTrees$$anonfun$3}
     * / {@code $$anonfun$4} (presumably the same label remapping as in {@code run}; TODO confirm),
     * and any other algo is rejected.
     *
     * <p>NOTE(review): the rejection message here reads "... is not supported by the gradient
     * boosting." while {@link #run} says "... is not supported by gradient boosting."; the wording
     * is inconsistent between the two entry points.
     *
     * @param rdd training data
     * @param rdd2 validation data
     * @param boostingStrategy boosting parameters; {@code validationTol()} drives early stopping
     * @param j random seed forwarded to {@link #boost}
     * @return (trees, tree weights) from {@link #boost}, truncated to the best validation iteration
     * @throws IllegalArgumentException if the algo is neither Regression nor Classification
     */
    public Tuple2<DecisionTreeRegressionModel[], double[]> runWithValidation(RDD<LabeledPoint> rdd, RDD<LabeledPoint> rdd2, BoostingStrategy boostingStrategy, long j) {
        Tuple2<DecisionTreeRegressionModel[], double[]> boost;
        Enumeration.Value algo = boostingStrategy.treeStrategy().algo();
        Enumeration.Value Regression = Algo$.MODULE$.Regression();
        // Null-safe inequality (Scala '!='): true when algo is NOT Regression.
        if (Regression != null ? !Regression.equals(algo) : algo != null) {
            Enumeration.Value Classification = Algo$.MODULE$.Classification();
            if (Classification != null ? !Classification.equals(algo) : algo != null) {
                throw new IllegalArgumentException(new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"", " is not supported by the gradient boosting."})).s(Predef$.MODULE$.genericWrapArray(new Object[]{algo})));
            }
            boost = boost(rdd.map(new GradientBoostedTrees$$anonfun$3(), ClassTag$.MODULE$.apply(LabeledPoint.class)), rdd2.map(new GradientBoostedTrees$$anonfun$4(), ClassTag$.MODULE$.apply(LabeledPoint.class)), boostingStrategy, true, j);
        } else {
            boost = boost(rdd, rdd2, boostingStrategy, true, j);
        }
        return boost;
    }

    // ---- Prediction / error helpers ----------------------------------------------------------

    /**
     * Maps every point of {@code rdd} to a pair via
     * {@code GradientBoostedTrees$$anonfun$computeInitialPredictionAndError$1}, for an ensemble that
     * so far contains only {@code decisionTreeRegressionModel} with weight {@code d}.
     *
     * <p>Callers average the second component and treat it as an error; the first component is
     * presumably the weighted tree prediction (TODO confirm against the closure).
     *
     * @param rdd data points
     * @param d weight of the tree (callers in this class pass 1.0 or the first tree's weight)
     * @param decisionTreeRegressionModel the first tree of the ensemble
     * @param loss loss used by the closure to compute the error component
     * @return RDD of (prediction?, error) pairs, element-aligned with {@code rdd}
     */
    public RDD<Tuple2<Object, Object>> computeInitialPredictionAndError(RDD<LabeledPoint> rdd, double d, DecisionTreeRegressionModel decisionTreeRegressionModel, Loss loss) {
        return rdd.map(new GradientBoostedTrees$$anonfun$computeInitialPredictionAndError$1(d, decisionTreeRegressionModel, loss), ClassTag$.MODULE$.apply(Tuple2.class));
    }

    /**
     * Produces the pair RDD after adding one more tree to the ensemble: zips {@code rdd} with the
     * current pair RDD {@code rdd2} element-wise and applies {@code GradientBoostedTrees$$anonfun$5}
     * per partition with the new tree, its weight {@code d} and {@code loss} (presumably adding the
     * weighted tree prediction and recomputing the error; TODO confirm against the closure).
     *
     * <p>NOTE(review): relies on {@code rdd.zip(rdd2)} lining up element by element. The call sites
     * visible in {@link #boost} pass a {@code rdd2} derived from {@code rdd}.
     *
     * @param rdd data points
     * @param rdd2 current pairs, element-aligned with {@code rdd}
     * @param d weight of the new tree
     * @param decisionTreeRegressionModel the new tree
     * @param loss loss used by the closure
     * @return the updated pair RDD
     */
    public RDD<Tuple2<Object, Object>> updatePredictionError(RDD<LabeledPoint> rdd, RDD<Tuple2<Object, Object>> rdd2, double d, DecisionTreeRegressionModel decisionTreeRegressionModel, Loss loss) {
        RDD zip = rdd.zip(rdd2, ClassTag$.MODULE$.apply(Tuple2.class));
        return zip.mapPartitions(new GradientBoostedTrees$$anonfun$5(d, decisionTreeRegressionModel, loss), zip.mapPartitions$default$2(), ClassTag$.MODULE$.apply(Tuple2.class));
    }

    /**
     * Adds one tree's weighted contribution to an existing prediction.
     *
     * @param vector feature vector
     * @param d current (accumulated) prediction
     * @param decisionTreeRegressionModel tree whose root node is evaluated on {@code vector}
     * @param d2 weight of that tree
     * @return {@code d + d2 * tree prediction for vector}
     */
    public double updatePrediction(Vector vector, double d, DecisionTreeRegressionModel decisionTreeRegressionModel, double d2) {
        return d + (decisionTreeRegressionModel.rootNode().predictImpl(vector).prediction() * d2);
    }

    /**
     * Returns the mean, over {@code rdd}, of the per-point value computed by
     * {@code GradientBoostedTrees$$anonfun$computeError$1} from the full ensemble (all trees and
     * their weights) and {@code loss}; presumably the mean loss of the ensemble (TODO confirm).
     *
     * @param rdd data points
     * @param decisionTreeRegressionModelArr trees of the ensemble
     * @param dArr tree weights, parallel to the tree array
     * @param loss loss used by the closure
     * @return mean of the per-point values
     */
    public double computeError(RDD<LabeledPoint> rdd, DecisionTreeRegressionModel[] decisionTreeRegressionModelArr, double[] dArr, Loss loss) {
        return RDD$.MODULE$.doubleRDDToDoubleRDDFunctions(rdd.map(new GradientBoostedTrees$$anonfun$computeError$1(decisionTreeRegressionModelArr, dArr, loss), ClassTag$.MODULE$.Double())).mean();
    }

    /**
     * Computes one error value per tree of the ensemble.
     *
     * <p>Element 0 is the mean error when only the first tree (with weight {@code dArr[0]}) is
     * used. Elements 1..length-1 are filled by {@code GradientBoostedTrees$$anonfun$evaluateEachIteration$1},
     * which receives the (prediction?, error) RDD holder {@code create}, the broadcast trees and the
     * weights; presumably it adds tree i incrementally and records the mean error (TODO confirm).
     * For Classification the input is first mapped through {@code GradientBoostedTrees$$anonfun$7}.
     *
     * <p>NOTE(review): {@code decisionTreeRegressionModelArr[0]} and {@code dArr[0]} are read
     * unconditionally, so an empty tree array throws {@code ArrayIndexOutOfBoundsException}. The
     * broadcast is unpersisted only on normal completion (not in a {@code finally}).
     *
     * @param rdd data points
     * @param decisionTreeRegressionModelArr trees of the ensemble
     * @param dArr tree weights, parallel to the tree array
     * @param loss loss used to compute errors
     * @param value algo of the ensemble; Classification triggers the input mapping above
     * @return array of length {@code decisionTreeRegressionModelArr.length} with the error per iteration
     */
    public double[] evaluateEachIteration(RDD<LabeledPoint> rdd, DecisionTreeRegressionModel[] decisionTreeRegressionModelArr, double[] dArr, Loss loss, Enumeration.Value value) {
        SparkContext sparkContext = rdd.sparkContext();
        Enumeration.Value Classification = Algo$.MODULE$.Classification();
        // Input used unchanged unless the algo is Classification.
        RDD<LabeledPoint> map = (Classification != null ? !Classification.equals(value) : value != null) ? rdd : rdd.map(new GradientBoostedTrees$$anonfun$7(), ClassTag$.MODULE$.apply(LabeledPoint.class));
        int length = decisionTreeRegressionModelArr.length;
        double[] dArr2 = (double[]) Array$.MODULE$.fill(length, new GradientBoostedTrees$$anonfun$1(), ClassTag$.MODULE$.Double());
        // Mutable holder so the per-iteration closure can replace the current pair RDD.
        ObjectRef create = ObjectRef.create(computeInitialPredictionAndError(map, dArr[0], decisionTreeRegressionModelArr[0], loss));
        // Mean of the error component (pair values) after the first tree.
        dArr2[0] = RDD$.MODULE$.doubleRDDToDoubleRDDFunctions(RDD$.MODULE$.rddToPairRDDFunctions((RDD) create.elem, ClassTag$.MODULE$.Double(), ClassTag$.MODULE$.Double(), Ordering$Double$.MODULE$).values()).mean();
        Broadcast broadcast = sparkContext.broadcast(decisionTreeRegressionModelArr, ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(DecisionTreeRegressionModel.class)));
        // for (i <- 1 until length): remaining entries are filled by the closure.
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(1), length).foreach$mVc$sp(new GradientBoostedTrees$$anonfun$evaluateEachIteration$1(loss, map, dArr2, dArr, create, broadcast));
        broadcast.unpersist();
        return dArr2;
    }

    // ---- Core boosting loop ------------------------------------------------------------------

    /**
     * Runs gradient boosting and returns the trees with their weights.
     *
     * <p>Outline:
     * <ol>
     *   <li>Validates {@code boostingStrategy}; copies its tree strategy and forces algo = Regression
     *       and impurity = Variance, so every tree is a regression tree.</li>
     *   <li>Persists {@code rdd} with MEMORY_AND_DISK if it has no storage level yet, and remembers
     *       to unpersist it at the end.</li>
     *   <li>Trains tree 0 on {@code rdd} with seed {@code j} and weight 1.0.</li>
     *   <li>For m = 1 .. numIterations-1 (until early stop): builds a new training set by zipping the
     *       current pair RDD with {@code rdd} through {@code GradientBoostedTrees$$anonfun$9(loss)}
     *       (presumably the loss-gradient residuals; TODO confirm), trains tree m with seed
     *       {@code j + m} and weight {@code learningRate}, and updates the training pair RDD.</li>
     *   <li>If {@code z}: also updates the validation pair RDD and its mean error each iteration.
     *       Stops when the improvement over the best error so far is below
     *       {@code validationTol * max(currentError, 0.01)}, and remembers the tree count of the best
     *       error seen.</li>
     * </ol>
     * Both pair RDDs go through {@code PeriodicRDDCheckpointer}s whose checkpoints are deleted at
     * the end.
     *
     * <p>NOTE(review): checkpoint deletion and the unpersist of {@code rdd} run only on normal
     * completion; an exception during training leaves them in place (no {@code finally}).
     *
     * @param rdd training data
     * @param rdd2 validation data (only evaluated when {@code z} is true)
     * @param boostingStrategy boosting parameters (iterations, loss, learning rate, tree strategy,
     *        validation tolerance)
     * @param z whether to use {@code rdd2} for early stopping
     * @param j base random seed; tree m uses {@code j + m}
     * @return (trees, weights); when {@code z} is true, both arrays are truncated to the best
     *         iteration count, otherwise they have length numIterations
     */
    public Tuple2<DecisionTreeRegressionModel[], double[]> boost(RDD<LabeledPoint> rdd, RDD<LabeledPoint> rdd2, BoostingStrategy boostingStrategy, boolean z, long j) {
        boolean z2;
        TimeTracker timeTracker = new TimeTracker();
        timeTracker.start("total");
        timeTracker.start("init");
        boostingStrategy.assertValid();
        int numIterations = boostingStrategy.numIterations();
        DecisionTreeRegressionModel[] decisionTreeRegressionModelArr = new DecisionTreeRegressionModel[numIterations];
        double[] dArr = new double[numIterations];
        Loss loss = boostingStrategy.loss();
        double learningRate = boostingStrategy.learningRate();
        // Work on a copy so the caller's tree strategy is not mutated by the overrides below.
        Strategy copy = boostingStrategy.treeStrategy().copy();
        double validationTol = boostingStrategy.validationTol();
        // Every boosted tree is fit as a regression tree with variance impurity.
        copy.algo_$eq(Algo$.MODULE$.Regression());
        copy.impurity_$eq(Variance$.MODULE$);
        copy.assertValid();
        // Persist the input only if the caller has not already done so; z2/z3 records that we did.
        StorageLevel storageLevel = rdd.getStorageLevel();
        StorageLevel NONE = StorageLevel$.MODULE$.NONE();
        if (storageLevel != null ? !storageLevel.equals(NONE) : NONE != null) {
            z2 = false;
        } else {
            rdd.persist(StorageLevel$.MODULE$.MEMORY_AND_DISK());
            z2 = true;
        }
        boolean z3 = z2;
        // One checkpointer for the training pair RDD, one for the validation pair RDD.
        PeriodicRDDCheckpointer periodicRDDCheckpointer = new PeriodicRDDCheckpointer(copy.getCheckpointInterval(), rdd.sparkContext());
        PeriodicRDDCheckpointer periodicRDDCheckpointer2 = new PeriodicRDDCheckpointer(copy.getCheckpointInterval(), rdd.sparkContext());
        timeTracker.stop("init");
        logDebug(new GradientBoostedTrees$$anonfun$boost$1());
        logDebug(new GradientBoostedTrees$$anonfun$boost$2());
        logDebug(new GradientBoostedTrees$$anonfun$boost$3());
        // Iteration 0: first tree, trained on the raw input with weight 1.0.
        timeTracker.start("building tree 0");
        DecisionTreeRegressionModel train = new DecisionTreeRegressor().setSeed(j).train(rdd, copy);
        decisionTreeRegressionModelArr[0] = train;
        dArr[0] = 1.0d;
        // Mutable holder for the training pair RDD; logging closures capture it by reference.
        ObjectRef create = ObjectRef.create(computeInitialPredictionAndError(rdd, 1.0d, train, loss));
        periodicRDDCheckpointer.update((RDD) create.elem);
        logDebug(new GradientBoostedTrees$$anonfun$boost$4(create));
        timeTracker.stop("building tree 0");
        // Validation pair RDD: constructed unconditionally, but only checkpointed/evaluated when z.
        RDD<Tuple2<Object, Object>> computeInitialPredictionAndError = computeInitialPredictionAndError(rdd2, 1.0d, train, loss);
        if (z) {
            periodicRDDCheckpointer2.update(computeInitialPredictionAndError);
        }
        // Best validation error so far (mean of the pair values); 0.0 when not validating.
        double mean = z ? RDD$.MODULE$.doubleRDDToDoubleRDDFunctions(RDD$.MODULE$.rddToPairRDDFunctions(computeInitialPredictionAndError, ClassTag$.MODULE$.Double(), ClassTag$.MODULE$.Double(), Ordering$Double$.MODULE$).values()).mean() : 0.0d;
        // i: number of trees to keep when validating (tree count at the best validation error).
        int i = 1;
        // create2.elem: current iteration index m; boxed so logging closures can read it.
        IntRef create2 = IntRef.create(1);
        // z4: early-stop flag.
        boolean z4 = false;
        while (create2.elem < numIterations && !z4) {
            // New training targets derived from the current pair RDD and the input via anonfun$9(loss).
            RDD<LabeledPoint> map = ((RDD) create.elem).zip(rdd, ClassTag$.MODULE$.apply(LabeledPoint.class)).map(new GradientBoostedTrees$$anonfun$9(loss), ClassTag$.MODULE$.apply(LabeledPoint.class));
            timeTracker.start(new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"building tree ", ""})).s(Predef$.MODULE$.genericWrapArray(new Object[]{BoxesRunTime.boxToInteger(create2.elem)})));
            logDebug(new GradientBoostedTrees$$anonfun$boost$5());
            logDebug(new GradientBoostedTrees$$anonfun$boost$6(create2));
            logDebug(new GradientBoostedTrees$$anonfun$boost$7());
            // Per-tree seed: base seed offset by the iteration index.
            DecisionTreeRegressionModel train2 = new DecisionTreeRegressor().setSeed(j + create2.elem).train(map, copy);
            timeTracker.stop(new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"building tree ", ""})).s(Predef$.MODULE$.genericWrapArray(new Object[]{BoxesRunTime.boxToInteger(create2.elem)})));
            decisionTreeRegressionModelArr[create2.elem] = train2;
            dArr[create2.elem] = learningRate;
            // Fold the new tree into the training pair RDD and checkpoint it periodically.
            create.elem = updatePredictionError(rdd, (RDD) create.elem, dArr[create2.elem], decisionTreeRegressionModelArr[create2.elem], loss);
            periodicRDDCheckpointer.update((RDD) create.elem);
            logDebug(new GradientBoostedTrees$$anonfun$boost$8(create));
            if (z) {
                computeInitialPredictionAndError = updatePredictionError(rdd2, computeInitialPredictionAndError, dArr[create2.elem], decisionTreeRegressionModelArr[create2.elem], loss);
                periodicRDDCheckpointer2.update(computeInitialPredictionAndError);
                // mean2: current validation error.
                double mean2 = RDD$.MODULE$.doubleRDDToDoubleRDDFunctions(RDD$.MODULE$.rddToPairRDDFunctions(computeInitialPredictionAndError, ClassTag$.MODULE$.Double(), ClassTag$.MODULE$.Double(), Ordering$Double$.MODULE$).values()).mean();
                // Relative-tolerance stop: improvement over the best error is too small. The
                // 0.01 floor keeps the threshold from collapsing when the error is near zero.
                if (mean - mean2 < validationTol * Math.max(mean2, 0.01d)) {
                    z4 = true;
                } else if (mean2 < mean) {
                    mean = mean2;
                    i = create2.elem + 1;
                }
            }
            create2.elem++;
        }
        timeTracker.stop("total");
        logInfo(new GradientBoostedTrees$$anonfun$boost$9());
        logInfo(new GradientBoostedTrees$$anonfun$boost$10(timeTracker));
        periodicRDDCheckpointer.deleteAllCheckpoints();
        periodicRDDCheckpointer2.deleteAllCheckpoints();
        // Undo our own persist only; a caller-persisted input is left as it was.
        if (z3) {
            rdd.unpersist(rdd.unpersist$default$1());
        } else {
            // No-op: decompiler artifact of the Scala if/else expression yielding Unit.
            BoxedUnit boxedUnit = BoxedUnit.UNIT;
        }
        // With validation, keep only the first i trees/weights (best validation error).
        return z ? new Tuple2<>(Predef$.MODULE$.refArrayOps(decisionTreeRegressionModelArr).slice(0, i), Predef$.MODULE$.doubleArrayOps(dArr).slice(0, i)) : new Tuple2<>(decisionTreeRegressionModelArr, dArr);
    }

    /** Module constructor: publishes this instance as MODULE$, then runs the Logging trait initializer. */
    private GradientBoostedTrees$() {
        MODULE$ = this;
        Logging.class.$init$(this);
    }
}
