package org.apache.spark.ml.feature;

import java.io.IOException;
import java.io.Serializable;
import org.apache.spark.SparkException;
import org.apache.spark.annotation.Experimental;
import org.apache.spark.ml.Transformer;
import org.apache.spark.ml.attribute.AttributeGroup;
import org.apache.spark.ml.attribute.AttributeGroup$;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.param.ParamValidators$;
import org.apache.spark.ml.param.shared.HasHandleInvalid;
import org.apache.spark.ml.param.shared.HasInputCol;
import org.apache.spark.ml.util.DefaultParamsWritable;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.ml.util.MLReader;
import org.apache.spark.ml.util.MLWritable;
import org.apache.spark.ml.util.MLWriter;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.functions$;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import scala.Function1;
import scala.MatchError;
import scala.Predef$;
import scala.reflect.ScalaSignature;
import scala.reflect.api.Mirror;
import scala.reflect.api.TypeCreator;
import scala.reflect.api.TypeTags;
import scala.reflect.api.Types;
import scala.reflect.api.Universe;
import scala.runtime.BoxesRunTime;

/* compiled from: VectorSizeHint.scala */
@ScalaSignature(bytes = "\u0006\u0001\t\ra\u0001B\u0001\u0003\u00015\u0011aBV3di>\u00148+\u001b>f\u0011&tGO\u0003\u0002\u0004\t\u00059a-Z1ukJ,'BA\u0003\u0007\u0003\tiGN\u0003\u0002\b\u0011\u0005)1\u000f]1sW*\u0011\u0011BC\u0001\u0007CB\f7\r[3\u000b\u0003-\t1a\u001c:h\u0007\u0001\u0019R\u0001\u0001\b\u00135u\u0001\"a\u0004\t\u000e\u0003\u0011I!!\u0005\u0003\u0003\u0017Q\u0013\u0018M\\:g_JlWM\u001d\t\u0003'ai\u0011\u0001\u0006\u0006\u0003+Y\taa\u001d5be\u0016$'BA\f\u0005\u0003\u0015\u0001\u0018M]1n\u0013\tIBCA\u0006ICNLe\u000e];u\u0007>d\u0007CA\n\u001c\u0013\taBC\u0001\tICND\u0015M\u001c3mK&sg/\u00197jIB\u0011a$I\u0007\u0002?)\u0011\u0001\u0005B\u0001\u0005kRLG.\u0003\u0002#?\t)B)\u001a4bk2$\b+\u0019:b[N<&/\u001b;bE2,\u0007\u0002\u0003\u0013\u0001\u0005\u000b\u0007I\u0011I\u0013\u0002\u0007ULG-F\u0001'!\t9\u0003G\u0004\u0002)]A\u0011\u0011\u0006L\u0007\u0002U)\u00111\u0006D\u0001\u0007yI|w\u000e\u001e \u000b\u00035\nQa]2bY\u0006L!a\f\u0017\u0002\rA\u0013X\rZ3g\u0013\t\t$G\u0001\u0004TiJLgn\u001a\u0006\u0003_1B3a\t\u001b;!\t)\u0004(D\u00017\u0015\t9d!\u0001\u0006b]:|G/\u0019;j_:L!!\u000f\u001c\u0003\u000bMKgnY3\"\u0003m\nQA\r\u00184]AB\u0001\"\u0010\u0001\u0003\u0002\u0003\u0006IAJ\u0001\u0005k&$\u0007\u0005K\u0002=iiBQ\u0001\u0011\u0001\u0005\u0002\u0005\u000ba\u0001P5oSRtDC\u0001\"E!\t\u0019\u0005!D\u0001\u0003\u0011\u0015!s\b1\u0001'Q\r!EG\u000f\u0015\u0004\u007fQR\u0004\"\u0002!\u0001\t\u0003AE#\u0001\")\u0007\u001d#$\bC\u0004L\u0001\t\u0007I\u0011\u0001'\u0002\tML'0Z\u000b\u0002\u001bB\u0011ajT\u0007\u0002-%\u0011\u0001K\u0006\u0002\t\u0013:$\b+\u0019:b[\"\u001a!\n\u000e\u001e\t\rM\u0003\u0001\u0015!\u0003N\u0003\u0015\u0019\u0018N_3!Q\r\u0011FG\u000f\u0005\u0006-\u0002!\taV\u0001\bO\u0016$8+\u001b>f+\u0005A\u0006CA-[\u001b\u0005a\u0013BA.-\u0005\rIe\u000e\u001e\u0015\u0004+RR\u0004\"\u00020\u0001\t\u0003y\u0016aB:fiNK'0\u001a\u000b\u0003A\u0006l\u0011\u0001\u0001\u0005\u0006Ev\u0003\r\u0001W\u0001\u0006m\u0006dW/\u001a\u0015\u0004;RR\u0004\"B3\u0001\t\u00031\u00
17aC:fi&s\u0007/\u001e;D_2$\"\u0001Y4\t\u000b\t$\u0007\u0019\u0001\u0014)\u0007\u0011$$\bC\u0004k\u0001\t\u0007I\u0011I6\u0002\u001b!\fg\u000e\u001a7f\u0013:4\u0018\r\\5e+\u0005a\u0007c\u0001(nM%\u0011aN\u0006\u0002\u0006!\u0006\u0014\u0018-\u001c\u0015\u0004SRR\u0004BB9\u0001A\u0003%A.\u0001\biC:$G.Z%om\u0006d\u0017\u000e\u001a\u0011)\u0007A$$\bC\u0003u\u0001\u0011\u0005Q/\u0001\ttKRD\u0015M\u001c3mK&sg/\u00197jIR\u0011\u0001M\u001e\u0005\u0006EN\u0004\rA\n\u0015\u0004gRR\u0004\"B=\u0001\t\u0003R\u0018!\u0003;sC:\u001chm\u001c:n)\rY\u0018\u0011\u0004\t\u0004y\u0006MabA?\u0002\u000e9\u0019a0!\u0003\u000f\u0007}\f9A\u0004\u0003\u0002\u0002\u0005\u0015abA\u0015\u0002\u0004%\t1\"\u0003\u0002\n\u0015%\u0011q\u0001C\u0005\u0004\u0003\u00171\u0011aA:rY&!\u0011qBA\t\u0003\u001d\u0001\u0018mY6bO\u0016T1!a\u0003\u0007\u0013\u0011\t)\"a\u0006\u0003\u0013\u0011\u000bG/\u0019$sC6,'\u0002BA\b\u0003#Aq!a\u0007y\u0001\u0004\ti\"A\u0004eCR\f7/\u001a;1\t\u0005}\u00111\u0006\t\u0007\u0003C\t\u0019#a\n\u000e\u0005\u0005E\u0011\u0002BA\u0013\u0003#\u0011q\u0001R1uCN,G\u000f\u0005\u0003\u0002*\u0005-B\u0002\u0001\u0003\r\u0003[\tI\"!A\u0001\u0002\u000b\u0005\u0011q\u0006\u0002\u0004?\u0012\n\u0014\u0003BA\u0019\u0003o\u00012!WA\u001a\u0013\r\t)\u0004\f\u0002\b\u001d>$\b.\u001b8h!\rI\u0016\u0011H\u0005\u0004\u0003wa#aA!os\"\u001a\u0001\u0010\u000e\u001e\t\u000f\u0005\u0005\u0003\u0001\"\u0003\u0002D\u0005)b/\u00197jI\u0006$XmU2iK6\f\u0017I\u001c3TSj,GCBA#\u0003#\n\t\u0007\u0005\u0003\u0002H\u00055SBAA%\u0015\r\tY\u0005B\u0001\nCR$(/\u001b2vi\u0016LA!a\u0014\u0002J\tq\u0011\t\u001e;sS\n,H/Z$s_V\u0004\b\u0002CA*\u0003\u007f\u0001\r!!\u0016\u0002\rM\u001c\u0007.Z7b!\u0011\t9&!\u0018\u000e\u0005\u0005e#\u0002BA.\u0003#\tQ\u0001^=qKNLA!a\u0018\u0002Z\tQ1\u000b\u001e:vGR$\u0016\u0010]3\t\u0011\u0005\r\u0014q\ba\u0001\u0003\u000b\nQa\u001a:pkBDq!a\u001a\u0001\t\u0003\nI'A\bue\u0006t7OZ8s[N\u001b\u0007.Z7b)\u0011\t)&a\u001b\t\u0011\u0005M\u0013Q\ra\u0001\u0003+BC!!\u001a5u!9\u0011\u0011\u000f\u0
001\u0005B\u0005M\u0014\u0001B2paf$2\u0001YA;\u0011!\t9(a\u001cA\u0002\u0005e\u0014!B3yiJ\f\u0007c\u0001(\u0002|%\u0019\u0011Q\u0010\f\u0003\u0011A\u000b'/Y7NCBDC!a\u001c5u!\u001a\u0001\u0001\u000e\u001e)\u0007\u0001\t)\tE\u00026\u0003\u000fK1!!#7\u00051)\u0005\u0010]3sS6,g\u000e^1m\u000f\u001d\tiI\u0001E\u0001\u0003\u001f\u000baBV3di>\u00148+\u001b>f\u0011&tG\u000fE\u0002D\u0003#3a!\u0001\u0002\t\u0002\u0005M5\u0003CAI\u0003+\u000bY*!)\u0011\u0007e\u000b9*C\u0002\u0002\u001a2\u0012a!\u00118z%\u00164\u0007\u0003\u0002\u0010\u0002\u001e\nK1!a( \u0005U!UMZ1vYR\u0004\u0016M]1ngJ+\u0017\rZ1cY\u0016\u00042!WAR\u0013\r\t)\u000b\f\u0002\r'\u0016\u0014\u0018.\u00197ju\u0006\u0014G.\u001a\u0005\b\u0001\u0006EE\u0011AAU)\t\ty\tC\u0006\u0002.\u0006E%\u0019!C\u0001\u0005\u0005=\u0016AE(Q)&k\u0015j\u0015+J\u0007~KeJV!M\u0013\u0012+\"!!-\u0011\t\u0005M\u0016QX\u0007\u0003\u0003kSA!a.\u0002:\u0006!A.\u00198h\u0015\t\tY,\u0001\u0003kCZ\f\u0017bA\u0019\u00026\"I\u0011\u0011YAIA\u0003%\u0011\u0011W\u0001\u0014\u001fB#\u0016*T%T)&\u001bu,\u0013(W\u00032KE\t\t\u0005\f\u0003\u000b\f\tJ1A\u0005\u0002\t\ty+A\u0007F%J{%kX%O-\u0006c\u0015\n\u0012\u0005\n\u0003\u0013\f\t\n)A\u0005\u0003c\u000ba\"\u0012*S\u001fJ{\u0016J\u0014,B\u0019&#\u0005\u0005C\u0006\u0002N\u0006E%\u0019!C\u0001\u0005\u0005=\u0016\u0001D*L\u0013B{\u0016J\u0014,B\u0019&#\u0005\"CAi\u0003#\u0003\u000b\u0011BAY\u00035\u00196*\u0013)`\u0013:3\u0016\tT%EA!Y\u0011Q[AI\u0005\u0004%\tAAAl\u0003]\u0019X\u000f\u001d9peR,G\rS1oI2,\u0017J\u001c<bY&$7/\u0006\u0002\u0002ZB!\u0011,a7'\u0013\r\ti\u000e\f\u0002\u0006\u0003J\u0014\u0018-\u001f\u0005\n\u0003C\f\t\n)A\u0005\u00033\f\u0001d];qa>\u0014H/\u001a3IC:$G.Z%om\u0006d\u0017\u000eZ:!\u0011!\t)/!%\u0005B\u0005\u001d\u0018\u0001\u00027pC\u0012$2AQAu\u0011\u001d\tY/a9A\u0002\u0019\nA\u0001]1uQ\"\"\u00111\u001d\u001b;\u0011)\t\t0!%\u0002\u0002\u0013%\u00111_\u0001\fe\u0016\fGMU3t_24X\r\u0006\u0002\u0002vB!\u00111WA|\u0013\u0011\tI0!.\u0003\r=\u0013'.Z2uQ\u0011\t\t\n\u000e\u001e)\t\u0005E\
u0015Q\u0011\u0015\u0005\u0003\u0017#$\b\u000b\u0003\u0002\f\u0006\u0015\u0005")
@Experimental
/* loaded from: input_file:org/apache/spark/ml/feature/VectorSizeHint.class */
public class VectorSizeHint extends Transformer implements HasInputCol, HasHandleInvalid, DefaultParamsWritable {
    private final String uid;
    private final IntParam size;
    private final Param<String> handleInvalid;
    private final Param<String> inputCol;

    public static MLReader<VectorSizeHint> read() {
        return VectorSizeHint$.MODULE$.read();
    }

    public static /* bridge */ Object load(String str) {
        return VectorSizeHint$.MODULE$.load(str);
    }

    /* renamed from: load, reason: collision with other method in class */
    public static VectorSizeHint m195load(String str) {
        return VectorSizeHint$.MODULE$.load(str);
    }

    @Override // org.apache.spark.ml.util.DefaultParamsWritable, org.apache.spark.ml.util.MLWritable
    public MLWriter write() {
        MLWriter write;
        write = write();
        return write;
    }

    @Override // org.apache.spark.ml.util.MLWritable
    public void save(String str) throws IOException {
        save(str);
    }

    @Override // org.apache.spark.ml.param.shared.HasHandleInvalid
    public final String getHandleInvalid() {
        String handleInvalid;
        handleInvalid = getHandleInvalid();
        return handleInvalid;
    }

    @Override // org.apache.spark.ml.param.shared.HasInputCol
    public final String getInputCol() {
        String inputCol;
        inputCol = getInputCol();
        return inputCol;
    }

    @Override // org.apache.spark.ml.param.shared.HasHandleInvalid
    public void org$apache$spark$ml$param$shared$HasHandleInvalid$_setter_$handleInvalid_$eq(Param<String> param) {
    }

    @Override // org.apache.spark.ml.param.shared.HasInputCol
    public final Param<String> inputCol() {
        return this.inputCol;
    }

    @Override // org.apache.spark.ml.param.shared.HasInputCol
    public final void org$apache$spark$ml$param$shared$HasInputCol$_setter_$inputCol_$eq(Param<String> param) {
        this.inputCol = param;
    }

    @Override // org.apache.spark.ml.util.Identifiable
    public String uid() {
        return this.uid;
    }

    public IntParam size() {
        return this.size;
    }

    public int getSize() {
        return BoxesRunTime.unboxToInt(getOrDefault(size()));
    }

    public VectorSizeHint setSize(int i) {
        return (VectorSizeHint) set((Param<IntParam>) size(), (IntParam) BoxesRunTime.boxToInteger(i));
    }

    public VectorSizeHint setInputCol(String str) {
        return (VectorSizeHint) set((Param<Param<String>>) inputCol(), (Param<String>) str);
    }

    @Override // org.apache.spark.ml.param.shared.HasHandleInvalid
    public Param<String> handleInvalid() {
        return this.handleInvalid;
    }

    public VectorSizeHint setHandleInvalid(String str) {
        return (VectorSizeHint) set((Param<Param<String>>) handleInvalid(), (Param<String>) str);
    }

    @Override // org.apache.spark.ml.Transformer
    public Dataset<Row> transform(Dataset<?> dataset) {
        Column apply;
        String inputCol = getInputCol();
        int size = getSize();
        String handleInvalid = getHandleInvalid();
        AttributeGroup fromStructField = AttributeGroup$.MODULE$.fromStructField(dataset.schema().apply(inputCol));
        AttributeGroup validateSchemaAndSize = validateSchemaAndSize(dataset.schema(), fromStructField);
        String OPTIMISTIC_INVALID = VectorSizeHint$.MODULE$.OPTIMISTIC_INVALID();
        if (handleInvalid != null ? handleInvalid.equals(OPTIMISTIC_INVALID) : OPTIMISTIC_INVALID == null) {
            if (fromStructField.size() == size) {
                return dataset.toDF();
            }
        }
        String OPTIMISTIC_INVALID2 = VectorSizeHint$.MODULE$.OPTIMISTIC_INVALID();
        if (OPTIMISTIC_INVALID2 != null ? !OPTIMISTIC_INVALID2.equals(handleInvalid) : handleInvalid != null) {
            String ERROR_INVALID = VectorSizeHint$.MODULE$.ERROR_INVALID();
            if (ERROR_INVALID != null ? !ERROR_INVALID.equals(handleInvalid) : handleInvalid != null) {
                String SKIP_INVALID = VectorSizeHint$.MODULE$.SKIP_INVALID();
                if (SKIP_INVALID != null ? !SKIP_INVALID.equals(handleInvalid) : handleInvalid != null) {
                    throw new MatchError(handleInvalid);
                }
                final VectorSizeHint vectorSizeHint = null;
                final VectorSizeHint vectorSizeHint2 = null;
                apply = functions$.MODULE$.udf(vector -> {
                    if (vector == null || vector.size() != size) {
                        return null;
                    }
                    return vector;
                }, scala.reflect.runtime.package$.MODULE$.universe().TypeTag().apply(scala.reflect.runtime.package$.MODULE$.universe().runtimeMirror(VectorSizeHint.class.getClassLoader()), new TypeCreator(vectorSizeHint) { // from class: org.apache.spark.ml.feature.VectorSizeHint$$typecreator3$1
                    public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
                        mirror.universe();
                        return mirror.staticClass("org.apache.spark.ml.linalg.Vector").asType().toTypeConstructor();
                    }
                }), scala.reflect.runtime.package$.MODULE$.universe().TypeTag().apply(scala.reflect.runtime.package$.MODULE$.universe().runtimeMirror(VectorSizeHint.class.getClassLoader()), new TypeCreator(vectorSizeHint2) { // from class: org.apache.spark.ml.feature.VectorSizeHint$$typecreator4$1
                    public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
                        mirror.universe();
                        return mirror.staticClass("org.apache.spark.ml.linalg.Vector").asType().toTypeConstructor();
                    }
                })).apply(Predef$.MODULE$.wrapRefArray(new Column[]{functions$.MODULE$.col(inputCol)}));
            } else {
                final VectorSizeHint vectorSizeHint3 = null;
                final VectorSizeHint vectorSizeHint4 = null;
                apply = functions$.MODULE$.udf(vector2 -> {
                    if (vector2 == null) {
                        throw new SparkException(new StringBuilder(88).append("Got null vector in VectorSizeHint, set `handleInvalid` ").append("to 'skip' to filter invalid rows.").toString());
                    }
                    if (vector2.size() != size) {
                        throw new SparkException(new StringBuilder(46).append("VectorSizeHint Expecting a vector of size ").append(size).append(" but").append(new StringBuilder(5).append(" got ").append(vector2.size()).toString()).toString());
                    }
                    return vector2;
                }, scala.reflect.runtime.package$.MODULE$.universe().TypeTag().apply(scala.reflect.runtime.package$.MODULE$.universe().runtimeMirror(VectorSizeHint.class.getClassLoader()), new TypeCreator(vectorSizeHint3) { // from class: org.apache.spark.ml.feature.VectorSizeHint$$typecreator1$1
                    public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
                        mirror.universe();
                        return mirror.staticClass("org.apache.spark.ml.linalg.Vector").asType().toTypeConstructor();
                    }
                }), scala.reflect.runtime.package$.MODULE$.universe().TypeTag().apply(scala.reflect.runtime.package$.MODULE$.universe().runtimeMirror(VectorSizeHint.class.getClassLoader()), new TypeCreator(vectorSizeHint4) { // from class: org.apache.spark.ml.feature.VectorSizeHint$$typecreator2$1
                    public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
                        mirror.universe();
                        return mirror.staticClass("org.apache.spark.ml.linalg.Vector").asType().toTypeConstructor();
                    }
                })).asNondeterministic().apply(Predef$.MODULE$.wrapRefArray(new Column[]{functions$.MODULE$.col(inputCol)}));
            }
        } else {
            apply = functions$.MODULE$.col(inputCol);
        }
        Dataset<Row> withColumn = dataset.withColumn(inputCol, apply.as(inputCol, validateSchemaAndSize.toMetadata()));
        String SKIP_INVALID2 = VectorSizeHint$.MODULE$.SKIP_INVALID();
        return (handleInvalid != null ? !handleInvalid.equals(SKIP_INVALID2) : SKIP_INVALID2 != null) ? withColumn : withColumn.na().drop(new String[]{inputCol});
    }

    private AttributeGroup validateSchemaAndSize(StructType structType, AttributeGroup attributeGroup) {
        AttributeGroup attributeGroup2;
        int size = getSize();
        String inputCol = getInputCol();
        DataType dataType = structType.apply(getInputCol()).dataType();
        Predef$.MODULE$.require(dataType instanceof VectorUDT, () -> {
            return new StringBuilder(43).append("Input column, ").append(this.getInputCol()).append(" must be of Vector type, got ").append(dataType).toString();
        });
        int size2 = attributeGroup.size();
        if (size == size2) {
            attributeGroup2 = attributeGroup;
        } else {
            if (-1 != size2) {
                throw new IllegalArgumentException(new StringBuilder(49).append("Trying to set size of vectors in `").append(inputCol).append("` to ").append(size).append(" but size ").append(new StringBuilder(16).append("already set to ").append(attributeGroup.size()).append(".").toString()).toString());
            }
            attributeGroup2 = new AttributeGroup(inputCol, size);
        }
        return attributeGroup2;
    }

    @Override // org.apache.spark.ml.PipelineStage
    public StructType transformSchema(StructType structType) {
        int fieldIndex = structType.fieldIndex(getInputCol());
        StructField[] structFieldArr = (StructField[]) structType.fields().clone();
        StructField structField = structFieldArr[fieldIndex];
        structFieldArr[fieldIndex] = structField.copy(structField.copy$default$1(), structField.copy$default$2(), structField.copy$default$3(), validateSchemaAndSize(structType, AttributeGroup$.MODULE$.fromStructField(structField)).toMetadata());
        return new StructType(structFieldArr);
    }

    @Override // org.apache.spark.ml.Transformer, org.apache.spark.ml.PipelineStage, org.apache.spark.ml.param.Params
    public VectorSizeHint copy(ParamMap paramMap) {
        return (VectorSizeHint) defaultCopy(paramMap);
    }

    public VectorSizeHint(String str) {
        this.uid = str;
        org$apache$spark$ml$param$shared$HasInputCol$_setter_$inputCol_$eq(new Param<>(this, "inputCol", "input column name"));
        org$apache$spark$ml$param$shared$HasHandleInvalid$_setter_$handleInvalid_$eq(new Param<>(this, "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). More options may be added later", ParamValidators$.MODULE$.inArray(new String[]{"skip", "error"})));
        MLWritable.$init$(this);
        DefaultParamsWritable.$init$((DefaultParamsWritable) this);
        this.size = new IntParam(this, "size", "Size of vectors in column.", (Function1<Object, Object>) i -> {
            return i >= 0;
        });
        this.handleInvalid = new Param<>(this, "handleInvalid", "How to handle invalid vectors in inputCol. Invalid vectors include nulls and vectors with the wrong size. The options are `skip` (filter out rows with invalid vectors), `error` (throw an error) and `optimistic` (do not check the vector size, and keep all rows). `error` by default.", ParamValidators$.MODULE$.inArray(VectorSizeHint$.MODULE$.supportedHandleInvalids()));
        setDefault(handleInvalid(), VectorSizeHint$.MODULE$.ERROR_INVALID());
    }

    public VectorSizeHint() {
        this(Identifiable$.MODULE$.randomUID("vectSizeHint"));
    }
}
