package org.apache.spark.ml.classification;

import org.apache.spark.ml.PredictionModel;
import org.apache.spark.ml.classification.ClassificationModel;
import org.apache.spark.ml.feature.Instance;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.shared.HasRawPredictionCol;
import org.apache.spark.ml.util.SchemaUtils$;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.functions$;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.StructType;
import scala.Predef$;
import scala.collection.immutable.StringOps;
import scala.reflect.ScalaSignature;
import scala.reflect.api.Mirror;
import scala.reflect.api.TypeCreator;
import scala.reflect.api.Types;
import scala.reflect.api.Universe;
import scala.reflect.runtime.package$;
import scala.runtime.BoxesRunTime;

/* compiled from: Classifier.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005-b!\u0002\u0006\f\u0003\u00031\u0002\"B\u001a\u0001\t\u0003!\u0004\"B\u001b\u0001\t\u00031\u0004\"\u0002#\u0001\r\u0003)\u0005\"B%\u0001\t\u0003R\u0005\"B+\u0001\t\u00032\u0006\"\u00029\u0001\t\u000b\n\b\"\u0002=\u0001\t\u0003J\bBB@\u0001\r\u0003\t\t\u0001C\u0004\u0002$\u0001!\t\"!\n\u0003'\rc\u0017m]:jM&\u001c\u0017\r^5p]6{G-\u001a7\u000b\u00051i\u0011AD2mCN\u001c\u0018NZ5dCRLwN\u001c\u0006\u0003\u001d=\t!!\u001c7\u000b\u0005A\t\u0012!B:qCJ\\'B\u0001\n\u0014\u0003\u0019\t\u0007/Y2iK*\tA#A\u0002pe\u001e\u001c\u0001!F\u0002\u0018=-\u001a2\u0001\u0001\r1!\u0011I\"\u0004\b\u0016\u000e\u00035I!aG\u0007\u0003\u001fA\u0013X\rZ5di&|g.T8eK2\u0004\"!\b\u0010\r\u0001\u0011)q\u0004\u0001b\u0001A\taa)Z1ukJ,7\u000fV=qKF\u0011\u0011e\n\t\u0003E\u0015j\u0011a\t\u0006\u0002I\u0005)1oY1mC&\u0011ae\t\u0002\b\u001d>$\b.\u001b8h!\t\u0011\u0003&\u0003\u0002*G\t\u0019\u0011I\\=\u0011\u0005uYC!\u0002\u0017\u0001\u0005\u0004i#!A'\u0012\u0005\u0005r\u0003\u0003B\u0018\u00019)j\u0011a\u0003\t\u0003_EJ!AM\u0006\u0003!\rc\u0017m]:jM&,'\u000fU1sC6\u001c\u0018A\u0002\u001fj]&$h\bF\u0001/\u0003M\u0019X\r\u001e*boB\u0013X\rZ5di&|gnQ8m)\tQs\u0007C\u00039\u0005\u0001\u0007\u0011(A\u0003wC2,X\r\u0005\u0002;\u0003:\u00111h\u0010\t\u0003y\rj\u0011!\u0010\u0006\u0003}U\ta\u0001\u0010:p_Rt\u0014B\u0001!$\u0003\u0019\u0001&/\u001a3fM&\u0011!i\u0011\u0002\u0007'R\u0014\u0018N\\4\u000b\u0005\u0001\u001b\u0013A\u00038v[\u000ec\u0017m]:fgV\ta\t\u0005\u0002#\u000f&\u0011\u0001j\t\u0002\u0004\u0013:$\u0018a\u0004;sC:\u001chm\u001c:n'\u000eDW-\\1\u0015\u0005-\u001b\u0006C\u0001'R\u001b\u0005i%B\u0001(P\u0003\u0015!\u0018\u0010]3t\u0015\t\u0001v\"A\u0002tc2L!AU'\u0003\u0015M#(/^2u)f\u0004X\rC\u0003U\t\u0001\u00071*\u0001\u0004tG\",W.Y\u0001\niJ\fgn\u001d4pe6$\"a\u00164\u0011\u0005a\u001bgBA-b\u001d\tQ\u0006M\u0004\u0002\\?:\u0011AL\u0018\b\u0003yuK\u0011\u0001F\u0005\u0003%MI!\u0001E\t\n\u0005A{\u0011B\u00012P\u0003\u001d\u0001\u0018mY6bO\u0016L!\u0001Z3\u0003\u0013\u001
1\u000bG/\u0019$sC6,'B\u00012P\u0011\u00159W\u00011\u0001i\u0003\u001d!\u0017\r^1tKR\u0004$!\u001b8\u0011\u0007)\\W.D\u0001P\u0013\tawJA\u0004ECR\f7/\u001a;\u0011\u0005uqG!C8g\u0003\u0003\u0005\tQ!\u0001!\u0005\ryF\u0005N\u0001\u000eiJ\fgn\u001d4pe6LU\u000e\u001d7\u0015\u0005]\u0013\b\"B4\u0007\u0001\u0004\u0019\bG\u0001;w!\rQ7.\u001e\t\u0003;Y$\u0011b\u001e:\u0002\u0002\u0003\u0005)\u0011\u0001\u0011\u0003\u0007}#S'A\u0004qe\u0016$\u0017n\u0019;\u0015\u0005il\bC\u0001\u0012|\u0013\ta8E\u0001\u0004E_V\u0014G.\u001a\u0005\u0006}\u001e\u0001\r\u0001H\u0001\tM\u0016\fG/\u001e:fg\u0006Q\u0001O]3eS\u000e$(+Y<\u0015\t\u0005\r\u0011q\u0002\t\u0005\u0003\u000b\tY!\u0004\u0002\u0002\b)\u0019\u0011\u0011B\u0007\u0002\r1Lg.\u00197h\u0013\u0011\ti!a\u0002\u0003\rY+7\r^8s\u0011\u0015q\b\u00021\u0001\u001dQ\u0015A\u00111CA\u0010!\u0011\t)\"a\u0007\u000e\u0005\u0005]!bAA\r\u001f\u0005Q\u0011M\u001c8pi\u0006$\u0018n\u001c8\n\t\u0005u\u0011q\u0003\u0002\u0006'&t7-Z\u0011\u0003\u0003C\tQa\r\u00181]A\naB]1xeA\u0014X\rZ5di&|g\u000eF\u0002{\u0003OAq!!\u000b\n\u0001\u0004\t\u0019!A\u0007sC^\u0004&/\u001a3jGRLwN\u001c")
/* loaded from: input_file:org/apache/spark/ml/classification/ClassificationModel.class */
/**
 * Abstract base for classification models: a {@link PredictionModel} that can additionally emit a
 * raw-prediction vector column and turns raw predictions into labels via {@link #raw2prediction}.
 *
 * <p>Subclasses supply {@link #numClasses()} and {@link #predictRaw(Object)}.
 *
 * @param <FeaturesType> type of a single value of the features column
 * @param <M> concrete model type; returned by setters so calls can be chained
 */
public abstract class ClassificationModel<FeaturesType, M extends ClassificationModel<FeaturesType, M>>
        extends PredictionModel<FeaturesType, M> implements ClassifierParams {

    /**
     * Backing field for the {@code rawPredictionCol} param. Not {@code final}: it is assigned by the
     * generated trait-field setter in this class, and Java forbids assigning a final field outside a
     * constructor or initializer.
     */
    private Param<String> rawPredictionCol;

    /**
     * Super accessor required by {@link ClassifierParams}: runs the parent class's
     * ({@link PredictionModel}) implementation of {@code validateAndTransformSchema}.
     */
    @Override
    public StructType org$apache$spark$ml$classification$ClassifierParams$$super$validateAndTransformSchema(
            StructType structType, boolean z, DataType dataType) {
        // Must be a super call. Dispatching virtually reaches this class's override, which delegates
        // back into ClassifierParams and (presumably) to this accessor again -> unbounded recursion.
        return super.validateAndTransformSchema(structType, z, dataType);
    }

    /** Delegates schema validation/transformation to the {@link ClassifierParams} trait implementation. */
    @Override
    public StructType validateAndTransformSchema(StructType structType, boolean z, DataType dataType) {
        return ClassifierParams.validateAndTransformSchema$(this, structType, z, dataType);
    }

    /** Delegates to the {@link ClassifierParams} trait implementation. */
    @Override
    public RDD<Instance> extractInstances(Dataset<?> dataset, int i) {
        return ClassifierParams.extractInstances$(this, dataset, i);
    }

    /**
     * Returns the current value of the {@link #rawPredictionCol()} param.
     *
     * @return the raw-prediction column name; an empty string means the column is not produced
     */
    @Override
    public final String getRawPredictionCol() {
        // Read the param directly; the previous body called getRawPredictionCol() on itself and
        // recursed until StackOverflowError.
        return (String) $(rawPredictionCol());
    }

    /** @return the param holding the raw-prediction output column name */
    @Override
    public final Param<String> rawPredictionCol() {
        return this.rawPredictionCol;
    }

    /**
     * Generated trait-field setter for {@code rawPredictionCol}; presumably invoked by
     * {@code HasRawPredictionCol.$init$}, which the constructor calls.
     */
    @Override
    public final void org$apache$spark$ml$param$shared$HasRawPredictionCol$_setter_$rawPredictionCol_$eq(
            Param<String> param) {
        this.rawPredictionCol = param;
    }

    /**
     * Sets the raw-prediction output column name. An empty string disables that column in
     * {@link #transform(Dataset)}.
     *
     * @param value the column name
     * @return this model, typed as {@code M}
     */
    @SuppressWarnings("unchecked") // set(...) returns this as Params; by declaration this is an M.
    public M setRawPredictionCol(String value) {
        return (M) set(rawPredictionCol(), value);
    }

    /**
     * @return number of classes; written into the output schema metadata by {@link #transformSchema}
     */
    public abstract int numClasses();

    /**
     * Derives the output schema from the parent's, then records {@link #numClasses()} on each
     * enabled output column: as the number of values of the prediction column and as the
     * attribute-group size of the raw-prediction column (via {@code SchemaUtils}).
     */
    @Override
    public StructType transformSchema(StructType structType) {
        StructType outputSchema = super.transformSchema(structType);

        String predictionColName = (String) $(predictionCol());
        if (!predictionColName.isEmpty()) {
            // Build on outputSchema, not the input schema: passing the input schema here discarded
            // everything super.transformSchema had just produced.
            outputSchema = SchemaUtils$.MODULE$.updateNumValues(outputSchema, predictionColName, numClasses());
        }

        String rawPredictionColName = (String) $(rawPredictionCol());
        if (!rawPredictionColName.isEmpty()) {
            outputSchema = SchemaUtils$.MODULE$.updateAttributeGroupSize(
                    outputSchema, rawPredictionColName, numClasses());
        }
        return outputSchema;
    }

    /**
     * Appends the enabled output columns to {@code dataset}.
     *
     * <ul>
     *   <li>raw prediction: {@link #predictRaw} applied to the features column;</li>
     *   <li>prediction: {@link #raw2prediction} applied to the raw-prediction column when that
     *       column is produced, otherwise {@link #predict} applied to the features column.</li>
     * </ul>
     *
     * <p>A column is skipped when its name is the empty string. If neither is produced, a warning
     * is logged and the input is returned as a DataFrame.
     */
    @Override
    @SuppressWarnings("unchecked") // UDF inputs are declared as Any; cast back to the erased FeaturesType.
    public Dataset<Row> transform(Dataset<?> dataset) {
        StructType outputSchema = transformSchema(dataset.schema(), true);
        Dataset<?> outputData = dataset;
        int numColsOutput = 0;

        String rawPredictionColName = getRawPredictionCol();
        // Scala's null-safe `!= ""`: only the empty string disables the column.
        if (!"".equals(rawPredictionColName)) {
            Column rawPredictionColumn = functions$.MODULE$.udf(
                            (Object features) -> predictRaw((FeaturesType) features),
                            package$.MODULE$.universe().TypeTag().apply(
                                    package$.MODULE$.universe().runtimeMirror(ClassificationModel.class.getClassLoader()),
                                    new VectorTypeCreator()),
                            package$.MODULE$.universe().TypeTag().Any())
                    .apply(Predef$.MODULE$.wrapRefArray(new Column[] {functions$.MODULE$.col(getFeaturesCol())}));
            outputData = outputData.withColumn(
                    rawPredictionColName, rawPredictionColumn, outputSchema.apply(rawPredictionColName).metadata());
            numColsOutput++;
        }

        String predictionColName = getPredictionCol();
        if (!"".equals(predictionColName)) {
            Column predictionColumn;
            if ("".equals(rawPredictionColName)) {
                // No raw column to reuse: predict straight from the features.
                predictionColumn = functions$.MODULE$.udf(
                                (Object features) -> BoxesRunTime.boxToDouble(predict((FeaturesType) features)),
                                package$.MODULE$.universe().TypeTag().Double(),
                                package$.MODULE$.universe().TypeTag().Any())
                        .apply(Predef$.MODULE$.wrapRefArray(new Column[] {functions$.MODULE$.col(getFeaturesCol())}));
            } else {
                // Reuse the raw predictions computed above instead of calling predictRaw again.
                predictionColumn = functions$.MODULE$.udf(
                                (Vector rawPrediction) -> BoxesRunTime.boxToDouble(raw2prediction(rawPrediction)),
                                package$.MODULE$.universe().TypeTag().Double(),
                                package$.MODULE$.universe().TypeTag().apply(
                                        package$.MODULE$.universe().runtimeMirror(ClassificationModel.class.getClassLoader()),
                                        new VectorTypeCreator()))
                        .apply(Predef$.MODULE$.wrapRefArray(new Column[] {functions$.MODULE$.col(rawPredictionColName)}));
            }
            outputData = outputData.withColumn(
                    predictionColName, predictionColumn, outputSchema.apply(predictionColName).metadata());
            numColsOutput++;
        }

        if (numColsOutput == 0) {
            logWarning(() -> uid() + ": ClassificationModel.transform() does nothing"
                    + " because no output columns were set.");
        }
        return outputData.toDF();
    }

    /**
     * Not supported: this class overrides {@link #transform(Dataset)} directly.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public final Dataset<Row> transformImpl(Dataset<?> dataset) {
        throw new UnsupportedOperationException("transformImpl is not supported in " + getClass());
    }

    /**
     * Predicts the label for one feature value as {@code raw2prediction(predictRaw(features))}.
     */
    @Override
    public double predict(FeaturesType features) {
        return raw2prediction(predictRaw(features));
    }

    /**
     * Computes the raw prediction vector for one feature value. Presumably one entry per class
     * ({@link #transformSchema} declares an attribute group of size {@link #numClasses()} for this
     * column) — confirm per subclass.
     */
    public abstract Vector predictRaw(FeaturesType features);

    /**
     * Default decision rule: the index of the largest entry of {@code rawPrediction}
     * ({@link Vector#argmax()}), widened to a double.
     */
    public double raw2prediction(Vector rawPrediction) {
        return rawPrediction.argmax();
    }

    /** Runs the Scala trait initializers for {@link HasRawPredictionCol} and {@link ClassifierParams}. */
    public ClassificationModel() {
        HasRawPredictionCol.$init$(this);
        ClassifierParams.$init$(this);
    }

    /**
     * Supplies the runtime type of {@link Vector} for the Scala reflection {@code TypeTag}s used by the
     * UDFs in {@link #transform(Dataset)}. Static so instances hold no hidden reference to a model.
     */
    private static final class VectorTypeCreator extends TypeCreator {
        @Override
        public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
            return mirror.staticClass("org.apache.spark.ml.linalg.Vector").asType().toTypeConstructor();
        }
    }
}
