package org.apache.spark.ml.util;

import org.apache.spark.SparkException;
import org.apache.spark.internal.Logging;
import org.apache.spark.ml.PredictorParams;
import org.apache.spark.ml.classification.ClassifierParams;
import org.apache.spark.ml.feature.Instance;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.ml.linalg.SparseVector;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.linalg.Vectors$;
import org.apache.spark.ml.param.shared.HasWeightCol;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.Row$;
import org.apache.spark.sql.expressions.UserDefinedFunction;
import org.apache.spark.sql.functions$;
import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DoubleType;
import org.apache.spark.sql.types.DoubleType$;
import org.apache.spark.sql.types.FloatType;
import org.apache.spark.sql.types.IntegerType$;
import org.apache.spark.sql.types.StringType$;
import org.slf4j.Logger;
import scala.Array$;
import scala.Function0;
import scala.Function1;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Some;
import scala.collection.ArrayOps$;
import scala.collection.SeqOps;
import scala.collection.StringOps$;
import scala.collection.immutable.$colon;
import scala.collection.immutable.Nil$;
import scala.reflect.ClassTag$;
import scala.reflect.api.Mirror;
import scala.reflect.api.TypeCreator;
import scala.reflect.api.TypeTags;
import scala.reflect.api.Types;
import scala.reflect.api.Universe;
import scala.reflect.runtime.package$;
import scala.runtime.BoxesRunTime;
import scala.runtime.RichDouble$;
import scala.runtime.ScalaRunTime$;

/* compiled from: DatasetUtils.scala */
/* loaded from: input_file:org/apache/spark/ml/util/DatasetUtils$.class */
public final class DatasetUtils$ implements Logging {
    /** Singleton instance of this class; the constructor is private. */
    public static final DatasetUtils$ MODULE$ = new DatasetUtils$();
    /** Cached vector-validation UDF; assigned inside validateVector$lzycompute(). */
    private static UserDefinedFunction validateVector;
    /** Backing storage read and written by the org$apache$spark$internal$Logging$$log_ accessor pair. */
    private static transient Logger org$apache$spark$internal$Logging$$log_;
    /** Becomes true once {@code validateVector} has been assigned. */
    private static volatile boolean bitmap$0;

    static {
        // Run the Logging initializer against the singleton instance.
        Logging.$init$(MODULE$);
    }

    /** Delegates to {@code Logging.logName$} for this instance. */
    public String logName() {
        return Logging.logName$(this);
    }

    /** Delegates to {@code Logging.log$} and returns the logger it provides. */
    public Logger log() {
        return Logging.log$(this);
    }

    /** Forwards the string supplier to {@code Logging.logInfo$}. */
    public void logInfo(Function0<String> function0) {
        Logging.logInfo$(this, function0);
    }

    /** Forwards the string supplier to {@code Logging.logDebug$}. */
    public void logDebug(Function0<String> function0) {
        Logging.logDebug$(this, function0);
    }

    /** Forwards the string supplier to {@code Logging.logTrace$}. */
    public void logTrace(Function0<String> function0) {
        Logging.logTrace$(this, function0);
    }

    /** Forwards the string supplier to {@code Logging.logWarning$}. */
    public void logWarning(Function0<String> function0) {
        Logging.logWarning$(this, function0);
    }

    /** Forwards the string supplier to {@code Logging.logError$}. */
    public void logError(Function0<String> function0) {
        Logging.logError$(this, function0);
    }

    /** Forwards the string supplier and the throwable to {@code Logging.logInfo$}. */
    public void logInfo(Function0<String> function0, Throwable th) {
        Logging.logInfo$(this, function0, th);
    }

    /** Forwards the string supplier and the throwable to {@code Logging.logDebug$}. */
    public void logDebug(Function0<String> function0, Throwable th) {
        Logging.logDebug$(this, function0, th);
    }

    /** Forwards the string supplier and the throwable to {@code Logging.logTrace$}. */
    public void logTrace(Function0<String> function0, Throwable th) {
        Logging.logTrace$(this, function0, th);
    }

    /** Forwards the string supplier and the throwable to {@code Logging.logWarning$}. */
    public void logWarning(Function0<String> function0, Throwable th) {
        Logging.logWarning$(this, function0, th);
    }

    /** Forwards the string supplier and the throwable to {@code Logging.logError$}. */
    public void logError(Function0<String> function0, Throwable th) {
        Logging.logError$(this, function0, th);
    }

    /** Returns whatever {@code Logging.isTraceEnabled$} reports for this instance. */
    public boolean isTraceEnabled() {
        return Logging.isTraceEnabled$(this);
    }

    /** Delegates to the one-flag overload of {@code Logging.initializeLogIfNecessary$}. */
    public void initializeLogIfNecessary(boolean z) {
        Logging.initializeLogIfNecessary$(this, z);
    }

    /** Delegates to the two-flag overload of {@code Logging.initializeLogIfNecessary$} and returns its result. */
    public boolean initializeLogIfNecessary(boolean z, boolean z2) {
        return Logging.initializeLogIfNecessary$(this, z, z2);
    }

    /** Returns the value supplied by {@code Logging.initializeLogIfNecessary$default$2$}. */
    public boolean initializeLogIfNecessary$default$2() {
        return Logging.initializeLogIfNecessary$default$2$(this);
    }

    /** Delegates to {@code Logging.initializeForcefully$} with both flags unchanged. */
    public void initializeForcefully(boolean z, boolean z2) {
        Logging.initializeForcefully$(this, z, z2);
    }

    /** Returns the logger currently stored in the static backing field (may be null if never set). */
    public Logger org$apache$spark$internal$Logging$$log_() {
        return org$apache$spark$internal$Logging$$log_;
    }

    /** Stores {@code logger} in the static backing field read by the matching getter. */
    public void org$apache$spark$internal$Logging$$log__$eq(Logger logger) {
        org$apache$spark$internal$Logging$$log_ = logger;
    }

    /**
     * Builds a column expression that casts column {@code str} to double and
     * raises an error for values that are null, NaN or infinite.
     *
     * @param str  name of the column to validate
     * @param str2 prefix used in the error messages (for example "Labels")
     * @return the double-cast column, guarded by the two error branches
     */
    public Column checkNonNanValues(String str, String str2) {
        final functions$ sql = functions$.MODULE$;
        final Column value = sql.col(str).cast(DoubleType$.MODULE$);

        // Branch 1: null or NaN.
        final Column missing = value.isNull().$bar$bar(value.isNaN());
        final Column missingError = sql.raise_error(sql.lit(str2 + " MUST NOT be Null or NaN"));

        // Branch 2: negative or positive infinity; the offending value is appended to the message.
        final Column negativeInfinity = value.$eq$eq$eq(BoxesRunTime.boxToDouble(Double.NEGATIVE_INFINITY));
        final Column positiveInfinity = value.$eq$eq$eq(BoxesRunTime.boxToDouble(Double.POSITIVE_INFINITY));
        final Column infinite = negativeInfinity.$bar$bar(positiveInfinity);
        final Column infiniteError = sql.raise_error(sql.concat(ScalaRunTime$.MODULE$.wrapRefArray(
                new Column[]{sql.lit(str2 + " MUST NOT be Infinity, but got "), value})));

        return sql.when(missing, missingError)
                .when(infinite, infiniteError)
                .otherwise(value);
    }

    /**
     * Validates a regression label column by delegating to
     * {@code checkNonNanValues} with the message prefix "Labels".
     *
     * @param str name of the label column
     */
    public Column checkRegressionLabels(String str) {
        return checkNonNanValues(str, "Labels");
    }

    /**
     * Builds a column expression that casts column {@code str} to double and
     * raises an error for invalid classification labels.
     *
     * <p>Null and NaN labels are always rejected. When {@code option} holds 2,
     * only the values 0 and 1 are accepted. Otherwise labels must be integers
     * in {@code [0, n)}, where {@code n} is the value in {@code option}, or
     * {@code Integer.MAX_VALUE} when it is empty.
     *
     * @param str    name of the label column
     * @param option optional number of classes
     * @return the double-cast label column guarded by the error branches
     */
    public Column checkClassificationLabels(String str, Option<Object> option) {
        final functions$ sql = functions$.MODULE$;
        final Column label = sql.col(str).cast(DoubleType$.MODULE$);

        // Binary case: the label must be exactly 0 or 1.
        if ((option instanceof Some) && 2 == BoxesRunTime.unboxToInt(((Some) option).value())) {
            final Column missing = label.isNull().$bar$bar(label.isNaN());
            final Column missingError = sql.raise_error(sql.lit("Labels MUST NOT be Null or NaN"));
            final Column notBinary = label.$eq$bang$eq(BoxesRunTime.boxToInteger(0))
                    .$amp$amp(label.$eq$bang$eq(BoxesRunTime.boxToInteger(1)));
            final Column notBinaryError = sql.raise_error(sql.concat(ScalaRunTime$.MODULE$.wrapRefArray(
                    new Column[]{sql.lit("Labels MUST be in {0, 1}, but got "), label})));
            return sql.when(missing, missingError)
                    .when(notBinary, notBinaryError)
                    .otherwise(label);
        }

        // General case: an absent class count means "no upper bound below Integer.MAX_VALUE".
        final int numClasses = BoxesRunTime.unboxToInt(option.getOrElse(() -> {
            return Integer.MAX_VALUE;
        }));
        Predef$.MODULE$.require(0 < numClasses && numClasses <= Integer.MAX_VALUE);

        final Column missing = label.isNull().$bar$bar(label.isNaN());
        final Column missingError = sql.raise_error(sql.lit("Labels MUST NOT be Null or NaN"));

        final Column outOfRange = label.$less(BoxesRunTime.boxToInteger(0))
                .$bar$bar(label.$greater$eq(BoxesRunTime.boxToInteger(numClasses)));
        final Column outOfRangeError = sql.raise_error(sql.concat(ScalaRunTime$.MODULE$.wrapRefArray(
                new Column[]{sql.lit("Labels MUST be in [0, " + numClasses + "), but got "), label})));

        // A label that changes when cast to integer is not a whole number.
        final Column fractional = label.$eq$bang$eq(label.cast(IntegerType$.MODULE$));
        final Column fractionalError = sql.raise_error(sql.concat(ScalaRunTime$.MODULE$.wrapRefArray(
                new Column[]{sql.lit("Labels MUST be Integers, but got "), label})));

        return sql.when(missing, missingError)
                .when(outOfRange, outOfRangeError)
                .when(fractional, fractionalError)
                .otherwise(label);
    }

    /**
     * Builds a column expression that casts column {@code str} to double and
     * raises an error for weights that are null, NaN, negative or positive infinity.
     *
     * @param str name of the weight column
     * @return the double-cast weight column guarded by the error branches
     */
    public Column checkNonNegativeWeights(String str) {
        final functions$ sql = functions$.MODULE$;
        final Column weight = sql.col(str).cast(DoubleType$.MODULE$);

        final Column missing = weight.isNull().$bar$bar(weight.isNaN());
        final Column missingError = sql.raise_error(sql.lit("Weights MUST NOT be Null or NaN"));

        final Column negativeOrInfinite = weight.$less(BoxesRunTime.boxToInteger(0))
                .$bar$bar(weight.$eq$eq$eq(BoxesRunTime.boxToDouble(Double.POSITIVE_INFINITY)));
        final Column negativeOrInfiniteError = sql.raise_error(sql.concat(ScalaRunTime$.MODULE$.wrapRefArray(
                new Column[]{sql.lit("Weights MUST NOT be Negative or Infinity, but got "), weight})));

        return sql.when(missing, missingError)
                .when(negativeOrInfinite, negativeOrInfiniteError)
                .otherwise(weight);
    }

    /**
     * Option-based variant of weight validation.
     *
     * @param option optional weight column name
     * @return the validated weight column when {@code option} holds a
     *         non-empty name, otherwise the literal {@code 1.0}
     */
    public Column checkNonNegativeWeights(Option<String> option) {
        final boolean present = option instanceof Some;
        final String weightCol = present ? (String) ((Some) option).value() : null;
        // The emptiness test only runs when a value is present, as in a short-circuit match.
        final boolean usable = present
                && StringOps$.MODULE$.nonEmpty$extension(Predef$.MODULE$.augmentString(weightCol));
        return usable
                ? checkNonNegativeWeights(weightCol)
                : functions$.MODULE$.lit(BoxesRunTime.boxToDouble(1.0d));
    }

    /**
     * Wraps a vector column so that evaluation raises an error for null vectors
     * and for vectors rejected by the {@code validateVector} UDF.
     *
     * @param column the vector column to validate
     * @return {@code column} guarded by the two error branches
     */
    public Column checkNonNanVectors(Column column) {
        final functions$ sql = functions$.MODULE$;

        // Null check is built first, before the UDF is touched.
        final Column isNull = column.isNull();
        final Column nullError = sql.raise_error(sql.lit("Vectors MUST NOT be Null"));

        // The UDF returns true for acceptable vectors, so its negation flags bad ones.
        final Column isInvalid = validateVector()
                .apply(ScalaRunTime$.MODULE$.wrapRefArray(new Column[]{column}))
                .unary_$bang();
        final Column invalidError = sql.raise_error(sql.concat(ScalaRunTime$.MODULE$.wrapRefArray(
                new Column[]{
                        sql.lit("Vector values MUST NOT be NaN or Infinity, but got "),
                        column.cast(StringType$.MODULE$)})));

        return sql.when(isNull, nullError)
                .when(isInvalid, invalidError)
                .otherwise(column);
    }

    /**
     * Name-based convenience overload: resolves {@code str} to a column and
     * delegates to {@code checkNonNanVectors(Column)}.
     *
     * @param str name of the vector column
     */
    public Column checkNonNanVectors(String str) {
        return checkNonNanVectors(functions$.MODULE$.col(str));
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v0 */
    /* JADX WARN: Type inference failed for: r0v1, types: [java.lang.Throwable] */
    /* JADX WARN: Type inference failed for: r0v7 */
    private UserDefinedFunction validateVector$lzycompute() {
        ?? r0 = this;
        synchronized (r0) {
            if (!bitmap$0) {
                functions$ functions_ = functions$.MODULE$;
                Function1 function1 = vector -> {
                    return BoxesRunTime.boxToBoolean($anonfun$validateVector$1(vector));
                };
                TypeTags.TypeTag Boolean = package$.MODULE$.universe().TypeTag().Boolean();
                TypeTags universe = package$.MODULE$.universe();
                validateVector = functions_.udf(function1, Boolean, universe.TypeTag().apply(package$.MODULE$.universe().runtimeMirror(getClass().getClassLoader()), new TypeCreator() { // from class: org.apache.spark.ml.util.DatasetUtils$$typecreator1$1
                    public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
                        mirror.universe();
                        return mirror.staticClass("org.apache.spark.ml.linalg.Vector").asType().toTypeConstructor();
                    }
                }));
                r0 = 1;
                bitmap$0 = true;
            }
        }
        return validateVector;
    }

    /**
     * Returns the vector-validation UDF, taking the slow path through
     * {@code validateVector$lzycompute()} while {@code bitmap$0} is still false.
     */
    private UserDefinedFunction validateVector() {
        if (bitmap$0) {
            return validateVector;
        }
        return validateVector$lzycompute();
    }

    /**
     * Selects validated label, weight and feature columns from {@code dataset}
     * and converts every row into an {@code Instance}.
     *
     * <p>Labels use the classification check when {@code predictorParams} is a
     * {@code ClassifierParams}, otherwise the regression check. Weights are
     * validated when the params implement {@code HasWeightCol}; otherwise the
     * literal 1.0 is used.
     *
     * @param predictorParams source of the label, weight and features column names
     * @param dataset         input data
     * @param option          optional number of classes, passed to the classification label check
     * @return an RDD of instances built from (label, weight, features)
     * @throws MatchError from the row mapper when a row is not a (Double, Double, Vector) triple
     */
    public RDD<Instance> extractInstances(PredictorParams predictorParams, Dataset<?> dataset, Option<Object> option) {
        // Columns are built in label, weight, features order.
        final Column labelColumn;
        if (predictorParams instanceof ClassifierParams) {
            labelColumn = checkClassificationLabels(((ClassifierParams) predictorParams).getLabelCol(), option);
        } else {
            labelColumn = checkRegressionLabels(predictorParams.getLabelCol());
        }

        final Column weightColumn;
        if (predictorParams instanceof HasWeightCol) {
            weightColumn = checkNonNegativeWeights(predictorParams.get(((HasWeightCol) predictorParams).weightCol()));
        } else {
            weightColumn = functions$.MODULE$.lit(BoxesRunTime.boxToDouble(1.0d));
        }

        final Column featuresColumn = checkNonNanVectors(predictorParams.getFeaturesCol());

        return dataset.select(ScalaRunTime$.MODULE$.wrapRefArray(new Column[]{labelColumn, weightColumn, featuresColumn})).rdd().map(row -> {
            if (row == null) {
                throw new MatchError(row);
            }
            Some fields = Row$.MODULE$.unapplySeq(row);
            if (fields.isEmpty() || fields.get() == null || ((SeqOps) fields.get()).lengthCompare(3) != 0) {
                throw new MatchError(row);
            }
            Object label = ((SeqOps) fields.get()).apply(0);
            Object weight = ((SeqOps) fields.get()).apply(1);
            Object features = ((SeqOps) fields.get()).apply(2);
            if (!(label instanceof Double) || !(weight instanceof Double) || !(features instanceof Vector)) {
                throw new MatchError(row);
            }
            return new Instance(BoxesRunTime.unboxToDouble(label), BoxesRunTime.unboxToDouble(weight), (Vector) features);
        }, ClassTag$.MODULE$.apply(Instance.class));
    }

    /** Default for the third parameter of {@code extractInstances}: the empty option. */
    public Option<Object> extractInstances$default$3() {
        return None$.MODULE$;
    }

    /**
     * Returns a column expression that exposes column {@code str} of
     * {@code dataset} as an ML {@code Vector}.
     *
     * <ul>
     *   <li>A {@code VectorUDT} column is returned as-is.</li>
     *   <li>An array column with float or double elements is converted to a
     *       dense vector through a UDF.</li>
     *   <li>Any other type is rejected.</li>
     * </ul>
     *
     * @param dataset dataset whose schema is consulted for the column type
     * @param str     name of the column to convert
     * @return a column producing vectors
     * @throws IllegalArgumentException if the column is neither a vector nor an
     *         array of float/double
     */
    public Column columnToVector(Dataset<?> dataset, String str) {
        // Keep the general DataType here: the value may be a VectorUDT, an ArrayType or something else.
        DataType dataType = dataset.schema().apply(str).dataType();
        if (dataType instanceof VectorUDT) {
            return functions$.MODULE$.col(str);
        }
        if (!(dataType instanceof ArrayType)) {
            throw new IllegalArgumentException(new StringBuilder(32).append(dataType).append(" column cannot be cast to Vector").toString());
        }

        // Safe only after the instanceof check above.
        DataType elementType = ((ArrayType) dataType).elementType();
        UserDefinedFunction udf;
        if (elementType instanceof FloatType) {
            // Widen each float into a double array, then build a dense vector.
            udf = functions$.MODULE$.udf(seq -> {
                double[] dArr = (double[]) Array$.MODULE$.ofDim(seq.size(), ClassTag$.MODULE$.Double());
                seq.indices().foreach$mVc$sp(i -> {
                    dArr[i] = BoxesRunTime.unboxToFloat(seq.apply(i));
                });
                return Vectors$.MODULE$.dense(dArr);
            }, package$.MODULE$.universe().TypeTag().apply(package$.MODULE$.universe().runtimeMirror(getClass().getClassLoader()), new TypeCreator() { // from class: org.apache.spark.ml.util.DatasetUtils$$typecreator1$2
                public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
                    mirror.universe();
                    return mirror.staticClass("org.apache.spark.ml.linalg.Vector").asType().toTypeConstructor();
                }
            }), package$.MODULE$.universe().TypeTag().apply(package$.MODULE$.universe().runtimeMirror(getClass().getClassLoader()), new TypeCreator() { // from class: org.apache.spark.ml.util.DatasetUtils$$typecreator2$1
                public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
                    Universe universe = mirror.universe();
                    return universe.internal().reificationSupport().TypeRef(universe.internal().reificationSupport().SingleType(universe.internal().reificationSupport().SingleType(universe.internal().reificationSupport().thisPrefix(mirror.RootClass()), mirror.staticPackage("scala")), mirror.staticModule("scala.package")), universe.internal().reificationSupport().selectType(mirror.staticModule("scala.package").asModule().moduleClass(), "Seq"), new $colon.colon(mirror.staticClass("scala.Float").asType().toTypeConstructor(), Nil$.MODULE$));
                }
            }));
        } else if (elementType instanceof DoubleType) {
            // Doubles can be copied straight into the dense vector's backing array.
            udf = functions$.MODULE$.udf(seq2 -> {
                return Vectors$.MODULE$.dense((double[]) seq2.toArray(ClassTag$.MODULE$.Double()));
            }, package$.MODULE$.universe().TypeTag().apply(package$.MODULE$.universe().runtimeMirror(getClass().getClassLoader()), new TypeCreator() { // from class: org.apache.spark.ml.util.DatasetUtils$$typecreator3$1
                public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
                    mirror.universe();
                    return mirror.staticClass("org.apache.spark.ml.linalg.Vector").asType().toTypeConstructor();
                }
            }), package$.MODULE$.universe().TypeTag().apply(package$.MODULE$.universe().runtimeMirror(getClass().getClassLoader()), new TypeCreator() { // from class: org.apache.spark.ml.util.DatasetUtils$$typecreator4$1
                public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
                    Universe universe = mirror.universe();
                    return universe.internal().reificationSupport().TypeRef(universe.internal().reificationSupport().SingleType(universe.internal().reificationSupport().SingleType(universe.internal().reificationSupport().thisPrefix(mirror.RootClass()), mirror.staticPackage("scala")), mirror.staticModule("scala.package")), universe.internal().reificationSupport().selectType(mirror.staticModule("scala.package").asModule().moduleClass(), "Seq"), new $colon.colon(mirror.staticClass("scala.Double").asType().toTypeConstructor(), Nil$.MODULE$));
                }
            }));
        } else {
            throw new IllegalArgumentException(new StringBuilder(39).append("Array[").append(elementType).append("] column cannot be cast to Vector").toString());
        }
        return udf.apply(ScalaRunTime$.MODULE$.wrapRefArray(new Column[]{functions$.MODULE$.col(str)}));
    }

    /**
     * Selects column {@code str} as an ML vector (via {@code columnToVector})
     * and converts each row's value to an {@code org.apache.spark.mllib.linalg.Vector}.
     *
     * @param dataset input data
     * @param str     name of the vector or array column
     * @return an RDD of mllib vectors
     * @throws MatchError from the row mapper when a row does not hold exactly one ML vector
     */
    public RDD<org.apache.spark.mllib.linalg.Vector> columnToOldVector(Dataset<?> dataset, String str) {
        final Column vectorColumn = columnToVector(dataset, str);
        return dataset.select(ScalaRunTime$.MODULE$.wrapRefArray(new Column[]{vectorColumn})).rdd().map(row -> {
            if (row == null) {
                throw new MatchError(row);
            }
            Some fields = Row$.MODULE$.unapplySeq(row);
            if (fields.isEmpty() || fields.get() == null || ((SeqOps) fields.get()).lengthCompare(1) != 0) {
                throw new MatchError(row);
            }
            Object value = ((SeqOps) fields.get()).apply(0);
            if (!(value instanceof Vector)) {
                throw new MatchError(row);
            }
            return org.apache.spark.mllib.linalg.Vectors$.MODULE$.fromML((Vector) value);
        }, ClassTag$.MODULE$.apply(org.apache.spark.mllib.linalg.Vector.class));
    }

    /**
     * Determines the number of classes for label column {@code str}.
     *
     * <p>If the column metadata declares a class count, that value is returned.
     * Otherwise the count is inferred as {@code max(label) + 1}, where the
     * labels are first validated by {@code checkClassificationLabels} with
     * {@code i} as the class limit.
     *
     * @param dataset input data
     * @param str     name of the label column
     * @param i       largest class count that may be inferred from the data
     * @return the declared or inferred number of classes
     * @throws SparkException           if no maximum label can be computed (empty dataset)
     * @throws IllegalArgumentException if the inferred count is not a valid int or exceeds {@code i}
     */
    public int getNumClasses(Dataset<?> dataset, String str, int i) {
        // The metadata lookup yields an Option: either a declared count or nothing.
        Option<Object> declared = MetadataUtils$.MODULE$.getNumClasses(dataset.schema().apply(str));
        if (declared instanceof Some) {
            return BoxesRunTime.unboxToInt(((Some<Object>) declared).value());
        }
        if (!None$.MODULE$.equals(declared)) {
            throw new MatchError(declared);
        }

        // No metadata: take the maximum validated label from the data itself.
        Row[] maxLabelRow = (Row[]) dataset.select(ScalaRunTime$.MODULE$.wrapRefArray(new Column[]{functions$.MODULE$.max(checkClassificationLabels(str, new Some(BoxesRunTime.boxToInteger(i))))})).take(1);
        if (ArrayOps$.MODULE$.isEmpty$extension(Predef$.MODULE$.refArrayOps(maxLabelRow)) || maxLabelRow[0].get(0) == null) {
            throw new SparkException("ML algorithm was given empty dataset.");
        }

        final double maxLabel = maxLabelRow[0].getDouble(0);
        // Check max + 1 before narrowing so the cast below cannot overflow.
        Predef$.MODULE$.require(RichDouble$.MODULE$.isValidInt$extension(Predef$.MODULE$.doubleWrapper(maxLabel + 1)), () -> {
            return "Classifier found max label value =" + " " + maxLabel + " but requires integers in range [0, ... " + Integer.MAX_VALUE + ")";
        });

        final int numClasses = ((int) maxLabel) + 1;
        Predef$.MODULE$.require(numClasses <= i, () -> {
            return "Classifier inferred " + numClasses + " from label values"
                    + " in column " + str + ", but this exceeded the max numClasses (" + i + ") allowed"
                    + " to be inferred from values.  To avoid this error for labels with > " + i
                    + " classes, specify numClasses explicitly in the metadata; this can be done by applying"
                    + " StringIndexer to the label column.";
        });
        logInfo(() -> {
            return MODULE$.getClass().getCanonicalName() + " inferred " + numClasses + " classes for"
                    + " labelCol=" + str + " since numClasses was not specified in the column metadata.";
        });
        return numClasses;
    }

    /** Default for the third parameter of {@code getNumClasses}: 100. */
    public int getNumClasses$default$3() {
        return 100;
    }

    /**
     * Returns the number of features in vector column {@code str}.
     *
     * <p>The count from the column metadata is used when present. Otherwise the
     * first row is read and the size of its vector is returned.
     *
     * @param dataset input data
     * @param str     name of the vector or array column
     * @return the feature count
     */
    public int getNumFeatures(Dataset<?> dataset, String str) {
        Option<Object> declared = MetadataUtils$.MODULE$.getNumFeatures(dataset.schema().apply(str));
        // The fallback stays inside the supplier so the dataset is only read when metadata is missing.
        Object numFeatures = declared.getOrElse(() -> {
            Column vectorColumn = MODULE$.columnToVector(dataset, str);
            Row firstRow = (Row) dataset.select(ScalaRunTime$.MODULE$.wrapRefArray(new Column[]{vectorColumn})).head();
            return ((Vector) firstRow.getAs(0)).size();
        });
        return BoxesRunTime.unboxToInt(numFeatures);
    }

    /**
     * Body of the {@code validateVector} UDF: reports whether every stored
     * value of {@code vector} is neither NaN nor infinite.
     *
     * @param vector a {@code DenseVector} or {@code SparseVector}
     * @return true when all entries of the vector's {@code values()} array are finite
     * @throws MatchError if {@code vector} is neither dense nor sparse (including null)
     */
    public static final /* synthetic */ boolean $anonfun$validateVector$1(Vector vector) {
        final double[] storedValues;
        if (vector instanceof DenseVector) {
            storedValues = ((DenseVector) vector).values();
        } else if (vector instanceof SparseVector) {
            storedValues = ((SparseVector) vector).values();
        } else {
            throw new MatchError(vector);
        }
        for (double value : storedValues) {
            if (Double.isNaN(value) || Double.isInfinite(value)) {
                return false;
            }
        }
        return true;
    }

    /** Private constructor; the only instance is created for the static {@code MODULE$} field. */
    private DatasetUtils$() {
    }
}
