package org.apache.spark.ml;

import java.io.Serializable;
import org.apache.spark.annotation.DeveloperApi;
import org.apache.spark.ml.PredictionModel;
import org.apache.spark.ml.Predictor;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.param.shared.HasFeaturesCol;
import org.apache.spark.ml.param.shared.HasLabelCol;
import org.apache.spark.ml.param.shared.HasPredictionCol;
import org.apache.spark.ml.param.shared.HasWeightCol;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.Row$;
import org.apache.spark.sql.functions$;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DoubleType$;
import org.apache.spark.sql.types.StructType;
import scala.Function1;
import scala.MatchError;
import scala.Predef$;
import scala.Some;
import scala.collection.SeqLike;
import scala.collection.immutable.StringOps;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;

/* compiled from: Predictor.scala */
@DeveloperApi
@ScalaSignature(bytes = "\u0006\u0001\u0005Mb!B\u0001\u0003\u0003\u0003Y!!\u0003)sK\u0012L7\r^8s\u0015\t\u0019A!\u0001\u0002nY*\u0011QAB\u0001\u0006gB\f'o\u001b\u0006\u0003\u000f!\ta!\u00199bG\",'\"A\u0005\u0002\u0007=\u0014xm\u0001\u0001\u0016\t1\u0001cfE\n\u0004\u000151\u0003c\u0001\b\u0010#5\t!!\u0003\u0002\u0011\u0005\tIQi\u001d;j[\u0006$xN\u001d\t\u0003%Ma\u0001\u0001B\u0003\u0015\u0001\t\u0007QCA\u0001N#\t1B\u0004\u0005\u0002\u001855\t\u0001DC\u0001\u001a\u0003\u0015\u00198-\u00197b\u0013\tY\u0002DA\u0004O_RD\u0017N\\4\u0011\t9ir$E\u0005\u0003=\t\u0011q\u0002\u0015:fI&\u001cG/[8o\u001b>$W\r\u001c\t\u0003%\u0001\"Q!\t\u0001C\u0002\t\u0012ABR3biV\u0014Xm\u001d+za\u0016\f\"AF\u0012\u0011\u0005]!\u0013BA\u0013\u0019\u0005\r\te.\u001f\t\u0003\u001d\u001dJ!\u0001\u000b\u0002\u0003\u001fA\u0013X\rZ5di>\u0014\b+\u0019:b[NDQA\u000b\u0001\u0005\u0002-\na\u0001P5oSRtD#\u0001\u0017\u0011\u000b9\u0001q$L\t\u0011\u0005IqC!B\u0018\u0001\u0005\u0004\u0001$a\u0002'fCJtWM]\t\u0003-1BQA\r\u0001\u0005\u0002M\n1b]3u\u0019\u0006\u0014W\r\\\"pYR\u0011Q\u0006\u000e\u0005\u0006kE\u0002\rAN\u0001\u0006m\u0006dW/\u001a\t\u0003oyr!\u0001\u000f\u001f\u0011\u0005eBR\"\u0001\u001e\u000b\u0005mR\u0011A\u0002\u001fs_>$h(\u0003\u0002>1\u00051\u0001K]3eK\u001aL!a\u0010!\u0003\rM#(/\u001b8h\u0015\ti\u0004\u0004C\u0003C\u0001\u0011\u00051)\u0001\btKR4U-\u0019;ve\u0016\u001c8i\u001c7\u0015\u00055\"\u0005\"B\u001bB\u0001\u00041\u0004\"\u0002$\u0001\t\u00039\u0015\u0001E:fiB\u0013X\rZ5di&|gnQ8m)\ti\u0003\nC\u00036\u000b\u0002\u0007a\u0007C\u0003K\u0001\u0011\u00053*A\u0002gSR$\"!\u0005'\t\u000b5K\u0005\u0019\u0001(\u0002\u000f\u0011\fG/Y:fiB\u0012qJ\u0016\t\u0004!N+V\"A)\u000b\u0005I#\u0011aA:rY&\u0011A+\u0015\u0002\b\t\u0006$\u0018m]3u!\t\u0011b\u000bB\u0005X\u0019\u0006\u0005\t\u0011!B\u0001E\t\u0019q\fJ\u0019\t\u000be\u0003a\u0011\t.\u0002\t\r|\u0007/\u001f\u000b\u0003[mCQ\u0001\u0018-A\u0002u\u000bQ!\u001a=ue\u0006\u0004\"AX1\u000e\u0003}S!\u0001\u0019\u0002\u0002\u000bA\f'/Y7\n\u0005\t|&\u0001\
u0003)be\u0006lW*\u00199\t\u000b\u0011\u0004a\u0011C3\u0002\u000bQ\u0014\u0018-\u001b8\u0015\u0005E1\u0007\"B'd\u0001\u00049\u0007G\u00015k!\r\u00016+\u001b\t\u0003%)$\u0011b\u001b4\u0002\u0002\u0003\u0005)\u0011\u0001\u0012\u0003\u0007}##\u0007\u0003\u0004n\u0001\u0011\u0005!A\\\u0001\u0011M\u0016\fG/\u001e:fg\u0012\u000bG/\u0019+za\u0016,\u0012a\u001c\t\u0003aNl\u0011!\u001d\u0006\u0003eF\u000bQ\u0001^=qKNL!\u0001^9\u0003\u0011\u0011\u000bG/\u0019+za\u0016DQA\u001e\u0001\u0005B]\fq\u0002\u001e:b]N4wN]7TG\",W.\u0019\u000b\u0003qn\u0004\"\u0001]=\n\u0005i\f(AC*ueV\u001cG\u000fV=qK\")A0\u001ea\u0001q\u000611o\u00195f[\u0006DQA \u0001\u0005\u0012}\fA#\u001a=ue\u0006\u001cG\u000fT1cK2,G\rU8j]R\u001cH\u0003BA\u0001\u00033\u0001b!a\u0001\u0002\n\u00055QBAA\u0003\u0015\r\t9\u0001B\u0001\u0004e\u0012$\u0017\u0002BA\u0006\u0003\u000b\u00111A\u0015#E!\u0011\ty!!\u0006\u000e\u0005\u0005E!bAA\n\u0005\u00059a-Z1ukJ,\u0017\u0002BA\f\u0003#\u0011A\u0002T1cK2,G\rU8j]RDa!T?A\u0002\u0005m\u0001\u0007BA\u000f\u0003C\u0001B\u0001U*\u0002 A\u0019!#!\t\u0005\u0017\u0005\r\u0012\u0011DA\u0001\u0002\u0003\u0015\tA\t\u0002\u0004?\u0012\u001a\u0004f\u0001\u0001\u0002(A!\u0011\u0011FA\u0018\u001b\t\tYCC\u0002\u0002.\u0011\t!\"\u00198o_R\fG/[8o\u0013\u0011\t\t$a\u000b\u0003\u0019\u0011+g/\u001a7pa\u0016\u0014\u0018\t]5")
/* loaded from: input_file:org/apache/spark/ml/Predictor.class */
/**
 * Abstract estimator that fits a {@link PredictionModel} from a dataset holding a label column and a
 * features column.
 *
 * <p>This is the Java rendering of a Scala class that mixes in the {@code PredictorParams} trait. The
 * trait's concrete members are reached through the static forwarders scalac emits on the trait
 * interfaces (methods with a trailing {@code $}, e.g. {@code HasLabelCol.$init$}).
 *
 * @param <FeaturesType> type of the features value handled by the produced model
 * @param <Learner> concrete subclass type, returned by the fluent setters
 * @param <M> type of model produced by {@link #fit(Dataset)}
 */
public abstract class Predictor<FeaturesType, Learner extends Predictor<FeaturesType, Learner, M>, M extends PredictionModel<FeaturesType, M>> extends Estimator<M> implements PredictorParams {
    // Not final: these are assigned by the trait initializers ($init$) through the
    // *_setter_$*_$eq callbacks below, after field initialization has finished.
    private Param<String> predictionCol;
    private Param<String> featuresCol;
    private Param<String> labelCol;

    /** Delegates to the {@code PredictorParams} trait implementation (calling itself would recurse forever). */
    @Override // org.apache.spark.ml.PredictorParams
    public StructType validateAndTransformSchema(StructType structType, boolean z, DataType dataType) {
        return PredictorParams.validateAndTransformSchema$(this, structType, z, dataType);
    }

    /** Returns the prediction column name via the {@code HasPredictionCol} trait implementation. */
    @Override // org.apache.spark.ml.param.shared.HasPredictionCol
    public final String getPredictionCol() {
        return HasPredictionCol.getPredictionCol$(this);
    }

    /** Returns the features column name via the {@code HasFeaturesCol} trait implementation. */
    @Override // org.apache.spark.ml.param.shared.HasFeaturesCol
    public final String getFeaturesCol() {
        return HasFeaturesCol.getFeaturesCol$(this);
    }

    /** Returns the label column name via the {@code HasLabelCol} trait implementation. */
    @Override // org.apache.spark.ml.param.shared.HasLabelCol
    public final String getLabelCol() {
        return HasLabelCol.getLabelCol$(this);
    }

    @Override // org.apache.spark.ml.param.shared.HasPredictionCol
    public final Param<String> predictionCol() {
        return this.predictionCol;
    }

    /** Trait-initializer callback that installs the {@code predictionCol} param. */
    @Override // org.apache.spark.ml.param.shared.HasPredictionCol
    public final void org$apache$spark$ml$param$shared$HasPredictionCol$_setter_$predictionCol_$eq(Param<String> param) {
        this.predictionCol = param;
    }

    @Override // org.apache.spark.ml.param.shared.HasFeaturesCol
    public final Param<String> featuresCol() {
        return this.featuresCol;
    }

    /** Trait-initializer callback that installs the {@code featuresCol} param. */
    @Override // org.apache.spark.ml.param.shared.HasFeaturesCol
    public final void org$apache$spark$ml$param$shared$HasFeaturesCol$_setter_$featuresCol_$eq(Param<String> param) {
        this.featuresCol = param;
    }

    @Override // org.apache.spark.ml.param.shared.HasLabelCol
    public final Param<String> labelCol() {
        return this.labelCol;
    }

    /** Trait-initializer callback that installs the {@code labelCol} param. */
    @Override // org.apache.spark.ml.param.shared.HasLabelCol
    public final void org$apache$spark$ml$param$shared$HasLabelCol$_setter_$labelCol_$eq(Param<String> param) {
        this.labelCol = param;
    }

    /**
     * Sets the label column name.
     *
     * @param str column name
     * @return this instance, typed as the concrete subclass
     */
    @SuppressWarnings("unchecked") // set(...) returns this.type, erased to Params; it is this object.
    public Learner setLabelCol(String str) {
        return (Learner) set(labelCol(), str);
    }

    /**
     * Sets the features column name.
     *
     * @param str column name
     * @return this instance, typed as the concrete subclass
     */
    @SuppressWarnings("unchecked") // set(...) returns this.type, erased to Params; it is this object.
    public Learner setFeaturesCol(String str) {
        return (Learner) set(featuresCol(), str);
    }

    /**
     * Sets the prediction column name.
     *
     * @param str column name
     * @return this instance, typed as the concrete subclass
     */
    @SuppressWarnings("unchecked") // set(...) returns this.type, erased to Params; it is this object.
    public Learner setPredictionCol(String str) {
        return (Learner) set(predictionCol(), str);
    }

    /**
     * Validates the schema, casts the label column (and, for {@link HasWeightCol} subclasses with a
     * non-empty weight column set, the weight column) to double, trains, and returns the model with
     * its parent and param values copied from this estimator.
     *
     * <p>javac generates the erased {@code Model fit(Dataset)} bridge automatically, so it is not
     * declared here (declaring it is a same-erasure name clash).
     *
     * @param dataset training data
     * @return the trained model
     */
    @Override // org.apache.spark.ml.Estimator
    public M fit(Dataset<?> dataset) {
        transformSchema(dataset.schema(), true);
        Dataset<?> casted = castToDouble(dataset, dataset, (String) $(labelCol()));
        if (this instanceof HasWeightCol) {
            Param<String> weightParam = ((HasWeightCol) this).weightCol();
            if (isDefined(weightParam)) {
                String weightName = (String) $(weightParam);
                if (!weightName.isEmpty()) {
                    casted = castToDouble(casted, dataset, weightName);
                }
            }
        }
        return copyValues(train(casted).setParent(this), copyValues$default$2());
    }

    /**
     * Replaces column {@code name} of {@code target} with its value cast to double, keeping the column
     * metadata recorded for {@code name} in {@code source}'s schema.
     */
    private static Dataset<?> castToDouble(Dataset<?> target, Dataset<?> source, String name) {
        return target.withColumn(
            name,
            functions$.MODULE$.col(name).cast(DoubleType$.MODULE$),
            source.schema().apply(name).metadata());
    }

    @Override // org.apache.spark.ml.Estimator, org.apache.spark.ml.PipelineStage, org.apache.spark.ml.param.Params
    public abstract Learner copy(ParamMap paramMap);

    /**
     * Trains a model on a dataset whose label column (and weight column, if any) was already cast to
     * double by {@link #fit(Dataset)}.
     */
    public abstract M train(Dataset<?> dataset);

    /** Data type expected for the features column; {@link VectorUDT} by default. */
    public DataType featuresDataType() {
        return new VectorUDT();
    }

    @Override // org.apache.spark.ml.PipelineStage
    public StructType transformSchema(StructType structType) {
        return validateAndTransformSchema(structType, true, featuresDataType());
    }

    /**
     * Selects the label and features columns and converts every row into a {@link LabeledPoint}.
     *
     * @param dataset input data
     * @return an RDD of labeled points
     * @throws MatchError (lazily, when the RDD is evaluated) for a row whose label is not a
     *     {@code Double} or whose features value is not a {@link Vector}
     */
    public RDD<LabeledPoint> extractLabeledPoints(Dataset<?> dataset) {
        Column label = functions$.MODULE$.col((String) $(labelCol()));
        Column features = functions$.MODULE$.col((String) $(featuresCol()));
        // Spark serializes the map function before shipping it; a plain Java lambda targeting
        // scala.Function1 is not Serializable, so intersect the target type with Serializable.
        Function1<Row, LabeledPoint> toPoint =
            (Function1<Row, LabeledPoint> & Serializable) Predictor::toLabeledPoint;
        return dataset.select(Predef$.MODULE$.wrapRefArray(new Column[]{label, features}))
            .rdd()
            .map(toPoint, ClassTag$.MODULE$.apply(LabeledPoint.class));
    }

    /**
     * Converts a two-field {@code (Double label, Vector features)} row to a {@link LabeledPoint}.
     *
     * @throws MatchError if the row does not have exactly that shape
     */
    private static LabeledPoint toLabeledPoint(Row row) {
        if (row.length() == 2) {
            Object label = row.get(0);
            Object features = row.get(1);
            if (label instanceof Double && features instanceof Vector) {
                return new LabeledPoint((Double) label, (Vector) features);
            }
        }
        throw new MatchError(row);
    }

    public Predictor() {
        // Run the Scala trait initializers; they install the params through the setter callbacks.
        HasLabelCol.$init$((HasLabelCol) this);
        HasFeaturesCol.$init$((HasFeaturesCol) this);
        HasPredictionCol.$init$((HasPredictionCol) this);
        PredictorParams.$init$((PredictorParams) this);
    }
}
