package org.apache.spark.ml.feature;

import java.io.IOException;
import org.apache.spark.annotation.Experimental;
import org.apache.spark.ml.Estimator;
import org.apache.spark.ml.feature.VectorIndexerParams;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.param.shared.HasInputCol;
import org.apache.spark.ml.param.shared.HasOutputCol;
import org.apache.spark.ml.util.DefaultParamsWritable;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.ml.util.MLReader;
import org.apache.spark.ml.util.MLWritable;
import org.apache.spark.ml.util.MLWriter;
import org.apache.spark.ml.util.SchemaUtils$;
import org.apache.spark.mllib.linalg.DenseVector;
import org.apache.spark.mllib.linalg.SparseVector;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.util.collection.OpenHashSet;
import scala.Array$;
import scala.MatchError;
import scala.Predef$;
import scala.Serializable;
import scala.Tuple2;
import scala.collection.immutable.Map;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: VectorIndexer.scala */
@ScalaSignature(bytes = "\u0006\u0001\u00055g\u0001B\u0001\u0003\u00015\u0011QBV3di>\u0014\u0018J\u001c3fq\u0016\u0014(BA\u0002\u0005\u0003\u001d1W-\u0019;ve\u0016T!!\u0002\u0004\u0002\u00055d'BA\u0004\t\u0003\u0015\u0019\b/\u0019:l\u0015\tI!\"\u0001\u0004ba\u0006\u001c\u0007.\u001a\u0006\u0002\u0017\u0005\u0019qN]4\u0004\u0001M!\u0001A\u0004\f\u001a!\ry\u0001CE\u0007\u0002\t%\u0011\u0011\u0003\u0002\u0002\n\u000bN$\u0018.\\1u_J\u0004\"a\u0005\u000b\u000e\u0003\tI!!\u0006\u0002\u0003%Y+7\r^8s\u0013:$W\r_3s\u001b>$W\r\u001c\t\u0003']I!\u0001\u0007\u0002\u0003'Y+7\r^8s\u0013:$W\r_3s!\u0006\u0014\u0018-\\:\u0011\u0005iiR\"A\u000e\u000b\u0005q!\u0011\u0001B;uS2L!AH\u000e\u0003+\u0011+g-Y;miB\u000b'/Y7t/JLG/\u00192mK\"A\u0001\u0005\u0001BC\u0002\u0013\u0005\u0013%A\u0002vS\u0012,\u0012A\t\t\u0003G%r!\u0001J\u0014\u000e\u0003\u0015R\u0011AJ\u0001\u0006g\u000e\fG.Y\u0005\u0003Q\u0015\na\u0001\u0015:fI\u00164\u0017B\u0001\u0016,\u0005\u0019\u0019FO]5oO*\u0011\u0001&\n\u0005\t[\u0001\u0011\t\u0011)A\u0005E\u0005!Q/\u001b3!\u0011\u0015y\u0003\u0001\"\u00011\u0003\u0019a\u0014N\\5u}Q\u0011\u0011G\r\t\u0003'\u0001AQ\u0001\t\u0018A\u0002\tBQa\f\u0001\u0005\u0002Q\"\u0012!\r\u0005\u0006m\u0001!\taN\u0001\u0011g\u0016$X*\u0019=DCR,wm\u001c:jKN$\"\u0001O\u001d\u000e\u0003\u0001AQAO\u001bA\u0002m\nQA^1mk\u0016\u0004\"\u0001\n\u001f\n\u0005u*#aA%oi\")q\b\u0001C\u0001\u0001\u0006Y1/\u001a;J]B,HoQ8m)\tA\u0014\tC\u0003;}\u0001\u0007!\u0005C\u0003D\u0001\u0011\u0005A)\u0001\u0007tKR|U\u000f\u001e9vi\u000e{G\u000e\u0006\u00029\u000b\")!H\u0011a\u0001E!)q\t\u0001C!\u0011\u0006\u0019a-\u001b;\u0015\u0005II\u0005\"\u0002&G\u0001\u0004Y\u0015a\u00023bi\u0006\u001cX\r\u001e\t\u0003\u0019>k\u0011!\u0014\u0006\u0003\u001d\u001a\t1a]9m\u0013\t\u0001VJA\u0005ECR\fgI]1nK\")!\u000b\u0001C!'\u0006yAO]1og\u001a|'/\\*dQ\u0016l\u0017\r\u0006\u0002U5B\u0011Q\u000bW\u0007\u0002-*\u0011q+T\u0001\u0006if\u0004Xm]\u0005\u00033Z\u0013!b\u0015;sk\u000e$H+\u001f9f\u0011\u0015Y\u0016\u000b1\u0001U\u0003\u0019\u0
0198\r[3nC\")Q\f\u0001C!=\u0006!1m\u001c9z)\t\tt\fC\u0003a9\u0002\u0007\u0011-A\u0003fqR\u0014\u0018\r\u0005\u0002cK6\t1M\u0003\u0002e\t\u0005)\u0001/\u0019:b[&\u0011am\u0019\u0002\t!\u0006\u0014\u0018-\\'ba\"\u0012\u0001\u0001\u001b\t\u0003S2l\u0011A\u001b\u0006\u0003W\u001a\t!\"\u00198o_R\fG/[8o\u0013\ti'N\u0001\u0007FqB,'/[7f]R\fGnB\u0003p\u0005!\u0005\u0001/A\u0007WK\u000e$xN]%oI\u0016DXM\u001d\t\u0003'E4Q!\u0001\u0002\t\u0002I\u001cB!]:wsB\u0011A\u0005^\u0005\u0003k\u0016\u0012a!\u00118z%\u00164\u0007c\u0001\u000exc%\u0011\u0001p\u0007\u0002\u0016\t\u00164\u0017-\u001e7u!\u0006\u0014\u0018-\\:SK\u0006$\u0017M\u00197f!\t!#0\u0003\u0002|K\ta1+\u001a:jC2L'0\u00192mK\")q&\u001dC\u0001{R\t\u0001\u000f\u0003\u0004��c\u0012\u0005\u0013\u0011A\u0001\u0005Y>\fG\rF\u00022\u0003\u0007Aa!!\u0002\u007f\u0001\u0004\u0011\u0013\u0001\u00029bi\"DSA`A\u0005\u0003\u001f\u00012![A\u0006\u0013\r\tiA\u001b\u0002\u0006'&t7-Z\u0011\u0003\u0003#\tQ!\r\u00187]A2a!!\u0006r\t\u0005]!!D\"bi\u0016<wN]=Ti\u0006$8o\u0005\u0003\u0002\u0014ML\bbCA\u000e\u0003'\u0011)\u0019!C\u0005\u0003;\t1B\\;n\r\u0016\fG/\u001e:fgV\t1\b\u0003\u0006\u0002\"\u0005M!\u0011!Q\u0001\nm\nAB\\;n\r\u0016\fG/\u001e:fg\u0002B1\"!\n\u0002\u0014\t\u0015\r\u0011\"\u0003\u0002\u001e\u0005iQ.\u0019=DCR,wm\u001c:jKND!\"!\u000b\u0002\u0014\t\u0005\t\u0015!\u0003<\u00039i\u0017\r_\"bi\u0016<wN]5fg\u0002BqaLA\n\t\u0003\ti\u0003\u0006\u0004\u00020\u0005M\u0012Q\u0007\t\u0005\u0003c\t\u0019\"D\u0001r\u0011\u001d\tY\"a\u000bA\u0002mBq!!\n\u0002,\u0001\u00071\b\u0003\u0006\u0002:\u0005M!\u0019!C\u0005\u0003w\t\u0001CZ3biV\u0014XMV1mk\u0016\u001cV\r^:\u0016\u0005\u0005u\u0002#\u0002\u0013\u0002@\u0005\r\u0013bAA!K\t)\u0011I\u001d:bsB1\u0011QIA'\u0003#j!!a\u0012\u000b\t\u0005%\u00131J\u0001\u000bG>dG.Z2uS>t'B\u0001\u000f\u0007\u0013\u0011\ty%a\u0012\u0003\u0017=\u0003XM\u001c%bg\"\u001cV\r\u001e\t\u0004I\u0005M\u0013bAA+K\t1Ai\\;cY\u0016D\u0011\"!\u0017\u0002\u0014\u0001\u0006I!!\u0010\u0002#\u0019,\u0017\r^;sKZ\u000bG.^3TKR\u001c\b
\u0005\u0003\u0005\u0002^\u0005MA\u0011AA0\u0003\u0015iWM]4f)\u0011\ty#!\u0019\t\u0011\u0005\r\u00141\fa\u0001\u0003_\tQa\u001c;iKJD\u0001\"a\u001a\u0002\u0014\u0011\u0005\u0011\u0011N\u0001\nC\u0012$g+Z2u_J$B!a\u001b\u0002rA\u0019A%!\u001c\n\u0007\u0005=TE\u0001\u0003V]&$\b\u0002CA:\u0003K\u0002\r!!\u001e\u0002\u0003Y\u0004B!a\u001e\u0002\u00026\u0011\u0011\u0011\u0010\u0006\u0005\u0003w\ni(\u0001\u0004mS:\fGn\u001a\u0006\u0004\u0003\u007f2\u0011!B7mY&\u0014\u0017\u0002BAB\u0003s\u0012aAV3di>\u0014\b\u0002CAD\u0003'!\t!!#\u0002\u001f\u001d,GoQ1uK\u001e|'/_'baN,\"!a#\u0011\r\r\niiOAI\u0013\r\tyi\u000b\u0002\u0004\u001b\u0006\u0004\bCB\u0012\u0002\u000e\u0006E3\b\u0003\u0005\u0002\u0016\u0006MA\u0011BAL\u00039\tG\r\u001a#f]N,g+Z2u_J$B!a\u001b\u0002\u001a\"A\u00111TAJ\u0001\u0004\ti*\u0001\u0002emB!\u0011qOAP\u0013\u0011\t\t+!\u001f\u0003\u0017\u0011+gn]3WK\u000e$xN\u001d\u0005\t\u0003K\u000b\u0019\u0002\"\u0003\u0002(\u0006y\u0011\r\u001a3Ta\u0006\u00148/\u001a,fGR|'\u000f\u0006\u0003\u0002l\u0005%\u0006\u0002CAV\u0003G\u0003\r!!,\u0002\u0005M4\b\u0003BA<\u0003_KA!!-\u0002z\ta1\u000b]1sg\u00164Vm\u0019;pe\"I\u0011QW9\u0002\u0002\u0013%\u0011qW\u0001\fe\u0016\fGMU3t_24X\r\u0006\u0002\u0002:B!\u00111XAc\u001b\t\tiL\u0003\u0003\u0002@\u0006\u0005\u0017\u0001\u00027b]\u001eT!!a1\u0002\t)\fg/Y\u0005\u0005\u0003\u000f\fiL\u0001\u0004PE*,7\r\u001e\u0015\u0006c\u0006%\u0011q\u0002\u0015\u0006]\u0006%\u0011q\u0002")
@Experimental
/* loaded from: input_file:org/apache/spark/ml/feature/VectorIndexer.class */
/*
 * Estimator that scans a vector column and produces a VectorIndexerModel holding, per feature,
 * a category map for the features that look categorical.
 *
 * NOTE: this is JADX-decompiled bytecode of VectorIndexer.scala, not hand-written Java.
 *  - The final param fields below are assigned from the Scala trait setter methods
 *    (..._setter_$..._$eq). The JVM tolerates this, but javac rejects assigning a final field
 *    outside a constructor, so this file does not compile as ordinary Java.
 *  - Closures are compiled into synthetic classes (e.g. VectorIndexer$$anonfun$3) that are not
 *    part of this file; comments that describe what those closures do are inferences and are
 *    marked as such.
 *  - The unusual generic casts inside the set(...) calls look like decompiler artifacts of
 *    generic erasure.
 */
public class VectorIndexer extends Estimator<VectorIndexerModel> implements VectorIndexerParams, DefaultParamsWritable {
    // Identifier of this pipeline stage, fixed at construction.
    private final String uid;
    // Backing fields for the params declared by VectorIndexerParams / HasOutputCol / HasInputCol;
    // they are filled in through the trait setter methods further down.
    private final IntParam maxCategories;
    private final Param<String> outputCol;
    private final Param<String> inputCol;

    /* compiled from: VectorIndexer.scala */
    /* loaded from: input_file:org/apache/spark/ml/feature/VectorIndexer$CategoryStats.class */
    /*
     * Mutable accumulator that keeps, for each of numFeatures features, the set of distinct values
     * seen so far (boxed doubles in an OpenHashSet).
     *
     * addDenseVector/addSparseVector stop adding to a feature's set once it holds more than
     * maxCategories values, so through those paths a set never exceeds maxCategories + 1 entries.
     * fit() presumably builds one instance per partition (inside VectorIndexer$$anonfun$3) and
     * combines them with merge() (inside VectorIndexer$$anonfun$4). Those closure bodies are not
     * visible here.
     */
    public static class CategoryStats implements Serializable {
        // scalac name-mangles these private members (org$apache$...$$name), presumably so that the
        // synthetic closure classes can reach them through the public accessors below.
        private final int org$apache$spark$ml$feature$VectorIndexer$CategoryStats$$numFeatures;
        private final int org$apache$spark$ml$feature$VectorIndexer$CategoryStats$$maxCategories;
        // One set of distinct observed values per feature, indexed by feature position.
        private final OpenHashSet<Object>[] featureValueSets;

        /** Number of features every added vector must have. */
        public int org$apache$spark$ml$feature$VectorIndexer$CategoryStats$$numFeatures() {
            return this.org$apache$spark$ml$feature$VectorIndexer$CategoryStats$$numFeatures;
        }

        /** Distinct-value threshold used to cap the per-feature value sets. */
        public int org$apache$spark$ml$feature$VectorIndexer$CategoryStats$$maxCategories() {
            return this.org$apache$spark$ml$feature$VectorIndexer$CategoryStats$$maxCategories;
        }

        /** Per-feature distinct-value sets (internal, mutable, not copied). */
        private OpenHashSet<Object>[] featureValueSets() {
            return this.featureValueSets;
        }

        /**
         * Folds {@code categoryStats} into this instance and returns {@code this}.
         *
         * <p>Zips the two featureValueSets arrays position by position and passes each
         * (this set, other set) pair to the synthetic closure
         * VectorIndexer$CategoryStats$$anonfun$merge$1. That closure presumably adds the other
         * set's values into this one, possibly applying the same maxCategories cap. Its body is
         * not visible here. Scala's zip stops at the shorter array, so a numFeatures mismatch is
         * not detected by this method.
         *
         * @param categoryStats stats to fold in; expected to have the same numFeatures
         * @return this instance, after mutation
         */
        public CategoryStats merge(CategoryStats categoryStats) {
            Predef$.MODULE$.refArrayOps((Object[]) Predef$.MODULE$.refArrayOps(featureValueSets()).zip(Predef$.MODULE$.wrapRefArray(categoryStats.featureValueSets()), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class)))).foreach(new VectorIndexer$CategoryStats$$anonfun$merge$1(this));
            return this;
        }

        /**
         * Records every component of {@code vector} in the corresponding feature's value set.
         *
         * @param vector dense or sparse vector whose size must equal numFeatures
         * @throws IllegalArgumentException if the size differs (Predef.require; the message is
         *         built lazily by VectorIndexer$CategoryStats$$anonfun$addVector$1)
         * @throws scala.MatchError if the vector is neither a DenseVector nor a SparseVector
         */
        public void addVector(Vector vector) {
            Predef$.MODULE$.require(vector.size() == org$apache$spark$ml$feature$VectorIndexer$CategoryStats$$numFeatures(), new VectorIndexer$CategoryStats$$anonfun$addVector$1(this, vector));
            // Decompiled Scala pattern match on the vector's runtime type. The BoxedUnit locals
            // appear to be unused residue of the match expression's Unit result.
            if (vector instanceof DenseVector) {
                addDenseVector((DenseVector) vector);
                BoxedUnit boxedUnit = BoxedUnit.UNIT;
            } else {
                if (!(vector instanceof SparseVector)) {
                    throw new MatchError(vector);
                }
                addSparseVector((SparseVector) vector);
                BoxedUnit boxedUnit2 = BoxedUnit.UNIT;
            }
        }

        /**
         * Builds the per-feature category maps from the collected value sets.
         *
         * <p>Pipeline: pair each value set with its feature index (zipWithIndex), keep the pairs
         * accepted by VectorIndexer$CategoryStats$$anonfun$getCategoryMaps$1, transform each kept
         * pair with VectorIndexer$CategoryStats$$anonfun$getCategoryMaps$2, then collect the
         * resulting tuples into an immutable Scala Map.
         *
         * <p>Assumption (closure bodies not visible): the filter presumably keeps only features
         * with at most maxCategories distinct values, and the map step presumably assigns each
         * distinct value a category index.
         *
         * @return map keyed by boxed feature index; the inner map's key/value types are erased to
         *         Object here (presumably feature value to category index; verify against
         *         VectorIndexerModel)
         */
        public Map<Object, Map<Object, Object>> getCategoryMaps() {
            return Predef$.MODULE$.refArrayOps((Object[]) Predef$.MODULE$.refArrayOps((Object[]) Predef$.MODULE$.refArrayOps((Object[]) Predef$.MODULE$.refArrayOps(featureValueSets()).zipWithIndex(Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class)))).filter(new VectorIndexer$CategoryStats$$anonfun$getCategoryMaps$1(this))).map(new VectorIndexer$CategoryStats$$anonfun$getCategoryMaps$2(this), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class)))).toMap(Predef$.MODULE$.$conforms());
        }

        /**
         * Adds each component of {@code denseVector} to its feature's value set.
         *
         * <p>The {@code <=} test lets a set grow to at most maxCategories + 1 entries. That is
         * enough to show the limit was exceeded; after that, further values for the feature are
         * skipped, which bounds memory.
         */
        private void addDenseVector(DenseVector denseVector) {
            int size = denseVector.size();
            for (int i = 0; i < size; i++) {
                if (featureValueSets()[i].size() <= org$apache$spark$ml$feature$VectorIndexer$CategoryStats$$maxCategories()) {
                    featureValueSets()[i].add(BoxesRunTime.boxToDouble(denseVector.apply(i)));
                }
            }
        }

        /**
         * Adds every component of {@code sparseVector}, including implicit zeros, to its feature's
         * value set. It uses the same maxCategories + 1 cap as addDenseVector.
         *
         * <p>{@code i2} walks every position 0..size-1, while {@code i} is a cursor over the
         * stored (explicit) entries. This relies on {@code indices()} being in ascending order.
         * If it is not, some stored entries are never matched and those positions are recorded
         * as 0.0.
         */
        private void addSparseVector(SparseVector sparseVector) {
            double d;
            int i = 0;
            int size = sparseVector.size();
            for (int i2 = 0; i2 < size; i2++) {
                if (i >= sparseVector.indices().length || i2 != sparseVector.indices()[i]) {
                    // Position not stored explicitly: implicit zero.
                    d = 0.0d;
                } else {
                    // Consume the stored entry at the current cursor, then advance the cursor.
                    i++;
                    d = sparseVector.values()[i - 1];
                }
                double d2 = d;
                if (featureValueSets()[i2].size() <= org$apache$spark$ml$feature$VectorIndexer$CategoryStats$$maxCategories()) {
                    featureValueSets()[i2].add(BoxesRunTime.boxToDouble(d2));
                }
            }
        }

        /**
         * Creates an empty accumulator.
         *
         * <p>Array.fill evaluates VectorIndexer$CategoryStats$$anonfun$5 once per feature.
         * Presumably each call yields a fresh, empty OpenHashSet.
         *
         * @param i number of features (numFeatures)
         * @param i2 distinct-value threshold (maxCategories)
         */
        public CategoryStats(int i, int i2) {
            this.org$apache$spark$ml$feature$VectorIndexer$CategoryStats$$numFeatures = i;
            this.org$apache$spark$ml$feature$VectorIndexer$CategoryStats$$maxCategories = i2;
            this.featureValueSets = (OpenHashSet[]) Array$.MODULE$.fill(i, new VectorIndexer$CategoryStats$$anonfun$5(this), ClassTag$.MODULE$.apply(OpenHashSet.class));
        }
    }

    /** Returns the reader for this class; delegates to the companion object VectorIndexer$. */
    public static MLReader<VectorIndexer> read() {
        return VectorIndexer$.MODULE$.read();
    }

    /**
     * Loads a VectorIndexer previously saved at {@code str}; delegates to the companion object.
     *
     * @param str path to load from
     */
    public static VectorIndexer load(String str) {
        return VectorIndexer$.MODULE$.load(str);
    }

    /** Returns a writer, using DefaultParamsWritable's trait implementation (Cclass). */
    @Override // org.apache.spark.ml.util.DefaultParamsWritable, org.apache.spark.ml.util.MLWritable
    public MLWriter write() {
        return DefaultParamsWritable.Cclass.write(this);
    }

    /**
     * Saves this stage to {@code str}, using MLWritable's trait implementation (Cclass).
     *
     * @throws IOException as declared by the decompiled signature
     */
    @Override // org.apache.spark.ml.util.MLWritable
    public void save(String str) throws IOException {
        MLWritable.Cclass.save(this, str);
    }

    /** Param bounding the number of distinct values a feature may have to be treated as categorical. */
    @Override // org.apache.spark.ml.feature.VectorIndexerParams
    public IntParam maxCategories() {
        return this.maxCategories;
    }

    /**
     * Scala trait setter that initialises the maxCategories field. Presumably it is called from
     * VectorIndexerParams.Cclass.$init$ during construction. Note that it assigns a final field.
     */
    @Override // org.apache.spark.ml.feature.VectorIndexerParams
    public void org$apache$spark$ml$feature$VectorIndexerParams$_setter_$maxCategories_$eq(IntParam intParam) {
        this.maxCategories = intParam;
    }

    /** Current value of maxCategories, resolved by the VectorIndexerParams trait implementation. */
    @Override // org.apache.spark.ml.feature.VectorIndexerParams
    public int getMaxCategories() {
        return VectorIndexerParams.Cclass.getMaxCategories(this);
    }

    /** Param naming the output column. */
    @Override // org.apache.spark.ml.param.shared.HasOutputCol
    public final Param<String> outputCol() {
        return this.outputCol;
    }

    /**
     * Scala trait setter that initialises the outputCol field. Presumably it is called from
     * HasOutputCol.Cclass.$init$ during construction. Note that it assigns a final field.
     */
    @Override // org.apache.spark.ml.param.shared.HasOutputCol
    public final void org$apache$spark$ml$param$shared$HasOutputCol$_setter_$outputCol_$eq(Param param) {
        this.outputCol = param;
    }

    /** Current value of outputCol, resolved by the HasOutputCol trait implementation. */
    @Override // org.apache.spark.ml.param.shared.HasOutputCol
    public final String getOutputCol() {
        return HasOutputCol.Cclass.getOutputCol(this);
    }

    /** Param naming the input vector column. */
    @Override // org.apache.spark.ml.param.shared.HasInputCol
    public final Param<String> inputCol() {
        return this.inputCol;
    }

    /**
     * Scala trait setter that initialises the inputCol field. Presumably it is called from
     * HasInputCol.Cclass.$init$ during construction. Note that it assigns a final field.
     */
    @Override // org.apache.spark.ml.param.shared.HasInputCol
    public final void org$apache$spark$ml$param$shared$HasInputCol$_setter_$inputCol_$eq(Param param) {
        this.inputCol = param;
    }

    /** Current value of inputCol, resolved by the HasInputCol trait implementation. */
    @Override // org.apache.spark.ml.param.shared.HasInputCol
    public final String getInputCol() {
        return HasInputCol.Cclass.getInputCol(this);
    }

    /** Identifier of this stage, as passed to the constructor. */
    @Override // org.apache.spark.ml.util.Identifiable
    public String uid() {
        return this.uid;
    }

    /**
     * Sets maxCategories and returns the result of set(...) cast to VectorIndexer (presumably
     * {@code this}, for chaining). The int is boxed before being stored.
     */
    public VectorIndexer setMaxCategories(int i) {
        return (VectorIndexer) set((Param<IntParam>) maxCategories(), (IntParam) BoxesRunTime.boxToInteger(i));
    }

    /** Sets inputCol; returns the result of set(...) cast to VectorIndexer (presumably this). */
    public VectorIndexer setInputCol(String str) {
        return (VectorIndexer) set((Param<Param<String>>) inputCol(), (Param<String>) str);
    }

    /** Sets outputCol; returns the result of set(...) cast to VectorIndexer (presumably this). */
    public VectorIndexer setOutputCol(String str) {
        return (VectorIndexer) set((Param<Param<String>>) outputCol(), (Param<String>) str);
    }

    /* JADX WARN: Can't rename method to resolve collision */
    /**
     * Scans {@code dataFrame} and returns a VectorIndexerModel holding the learned category maps.
     *
     * <p>Steps:
     * <ol>
     *   <li>Validate the schema via the inherited two-argument transformSchema overload
     *       (presumably the {@code true} flag enables logging).</li>
     *   <li>Take one row of the input column; numFeatures is the size of that first vector.</li>
     *   <li>Map the input column to an RDD of Vectors (VectorIndexer$$anonfun$2).</li>
     *   <li>Build per-partition stats. VectorIndexer$$anonfun$3 receives numFeatures and the
     *       unboxed maxCategories, and presumably creates a CategoryStats and feeds it each
     *       vector.</li>
     *   <li>Reduce the per-partition stats with VectorIndexer$$anonfun$4 (presumably
     *       CategoryStats.merge).</li>
     *   <li>Wrap getCategoryMaps() in a model whose parent is this estimator, then copy this
     *       estimator's param values onto it.</li>
     * </ol>
     *
     * <p>The dataset is evaluated twice: once for take(1) and once for the full reduce.
     * Vectors whose size differs from the first vector's size should be rejected by
     * CategoryStats.addVector's size check, assuming the partition closure calls it.
     *
     * @param dataFrame input data; must contain the inputCol vector column
     * @return the fitted model
     * @throws IllegalArgumentException if the dataset has no rows (require; message from
     *         VectorIndexer$$anonfun$fit$1). Schema problems surface from transformSchema.
     */
    @Override // org.apache.spark.ml.Estimator
    public VectorIndexerModel fit(DataFrame dataFrame) {
        transformSchema(dataFrame.schema(), true);
        Row[] take = dataFrame.select((String) $(inputCol()), Predef$.MODULE$.wrapRefArray(new String[0])).take(1);
        Predef$.MODULE$.require(take.length == 1, new VectorIndexer$$anonfun$fit$1(this));
        int size = ((Vector) take[0].getAs(0)).size();
        RDD map = dataFrame.select((String) $(inputCol()), Predef$.MODULE$.wrapRefArray(new String[0])).map(new VectorIndexer$$anonfun$2(this), ClassTag$.MODULE$.apply(Vector.class));
        return (VectorIndexerModel) copyValues(new VectorIndexerModel(uid(), size, ((CategoryStats) map.mapPartitions(new VectorIndexer$$anonfun$3(this, size, BoxesRunTime.unboxToInt($(maxCategories()))), map.mapPartitions$default$2(), ClassTag$.MODULE$.apply(CategoryStats.class)).reduce(new VectorIndexer$$anonfun$4(this))).getCategoryMaps()).setParent(this), copyValues$default$2());
    }

    /**
     * Validates {@code structType} and returns it with the output column appended.
     *
     * <p>The {@code ...$default$4()} calls supply the Scala default arguments for the fourth
     * parameter of the SchemaUtils helpers.
     *
     * @param structType input schema
     * @return the input schema plus outputCol of type VectorUDT
     * @throws IllegalArgumentException if inputCol or outputCol is not set (Predef.require).
     *         Other validation failures come from SchemaUtils.checkColumnType / appendColumn,
     *         whose behavior is not visible here.
     */
    @Override // org.apache.spark.ml.PipelineStage
    public StructType transformSchema(StructType structType) {
        DataType vectorUDT = new VectorUDT();
        Predef$.MODULE$.require(isDefined(inputCol()), new VectorIndexer$$anonfun$transformSchema$2(this));
        Predef$.MODULE$.require(isDefined(outputCol()), new VectorIndexer$$anonfun$transformSchema$3(this));
        // The input column must be a vector column; the output column has the same vector type.
        SchemaUtils$.MODULE$.checkColumnType(structType, (String) $(inputCol()), vectorUDT, SchemaUtils$.MODULE$.checkColumnType$default$4());
        return SchemaUtils$.MODULE$.appendColumn(structType, (String) $(outputCol()), vectorUDT, SchemaUtils$.MODULE$.appendColumn$default$4());
    }

    /** Returns a copy of this estimator with {@code paramMap} applied, via defaultCopy. */
    @Override // org.apache.spark.ml.Estimator, org.apache.spark.ml.PipelineStage, org.apache.spark.ml.param.Params
    public VectorIndexer copy(ParamMap paramMap) {
        return (VectorIndexer) defaultCopy(paramMap);
    }

    /**
     * Creates an estimator with the given uid.
     *
     * <p>The Cclass.$init$ calls run the Scala trait initialisers in declaration order. They
     * presumably create the param objects through the _setter_$ methods above.
     *
     * @param str stage identifier
     */
    public VectorIndexer(String str) {
        this.uid = str;
        HasInputCol.Cclass.$init$(this);
        HasOutputCol.Cclass.$init$(this);
        VectorIndexerParams.Cclass.$init$(this);
        MLWritable.Cclass.$init$(this);
        DefaultParamsWritable.Cclass.$init$(this);
    }

    /** Creates an estimator with a random uid generated from the prefix "vecIdx". */
    public VectorIndexer() {
        this(Identifiable$.MODULE$.randomUID("vecIdx"));
    }
}
