package org.apache.ignite.ml.naivebayes.compound;

import java.lang.invoke.SerializedLambda;
import java.util.Collection;
import java.util.Collections;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.naivebayes.discrete.DiscreteNaiveBayesModel;
import org.apache.ignite.ml.naivebayes.discrete.DiscreteNaiveBayesTrainer;
import org.apache.ignite.ml.naivebayes.gaussian.GaussianNaiveBayesModel;
import org.apache.ignite.ml.naivebayes.gaussian.GaussianNaiveBayesTrainer;
import org.apache.ignite.ml.preprocessing.Preprocessor;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;

/* loaded from: input_file:org/apache/ignite/ml/naivebayes/compound/CompoundNaiveBayesTrainer.class */
/**
 * Trainer for {@link CompoundNaiveBayesModel}.
 * <p>
 * Trains an optional Gaussian naive Bayes sub-model and an optional discrete naive Bayes sub-model and assembles
 * them into one compound model. Each sub-trainer receives the feature vector with its configured "skip" feature ids
 * removed. A sub-trainer that is not configured ({@code null}) is not trained.
 */
public class CompoundNaiveBayesTrainer extends SingleLabelDatasetTrainer<CompoundNaiveBayesModel> {
    /** Prior class probabilities pushed into both sub-trainers; {@code null} leaves their own setting untouched. */
    private double[] priorProbabilities;

    /** Trainer for the Gaussian part of the model, or {@code null} if that part is not trained. */
    private GaussianNaiveBayesTrainer gaussianNaiveBayesTrainer;

    /** Trainer for the discrete part of the model, or {@code null} if that part is not trained. */
    private DiscreteNaiveBayesTrainer discreteNaiveBayesTrainer;

    /** Feature indices removed from the vector before it is passed to the Gaussian trainer. */
    private Collection<Integer> gaussianFeatureIdsToSkip = Collections.emptyList();

    /** Feature indices removed from the vector before it is passed to the discrete trainer. */
    private Collection<Integer> discreteFeatureIdsToSkip = Collections.emptyList();

    /**
     * Trains a new compound model from scratch.
     *
     * @param datasetBuilder Dataset builder.
     * @param preprocessor Upstream preprocessor producing labeled vectors.
     * @param <K> Type of a key in upstream data.
     * @param <V> Type of a value in upstream data.
     * @return Trained compound model.
     */
    @Override
    public <K, V> CompoundNaiveBayesModel fitWithInitializedDeployingContext(DatasetBuilder<K, V> datasetBuilder,
        Preprocessor<K, V> preprocessor) {
        return updateModel(null, datasetBuilder, preprocessor);
    }

    /**
     * Checks whether an existing compound model can be updated.
     * <p>
     * Every configured sub-trainer must accept its corresponding sub-model. A sub-trainer that is not configured
     * imposes no constraint; without the {@code null} guards this method would throw whenever only one of the two
     * sub-trainers is set, a setup that {@link #updateModel} explicitly supports.
     *
     * @param mdl Model to check.
     * @return {@code true} if every configured sub-trainer reports its sub-model as updateable.
     */
    @Override
    public boolean isUpdateable(CompoundNaiveBayesModel mdl) {
        return (gaussianNaiveBayesTrainer == null || gaussianNaiveBayesTrainer.isUpdateable(mdl.getGaussianModel()))
            && (discreteNaiveBayesTrainer == null || discreteNaiveBayesTrainer.isUpdateable(mdl.getDiscreteModel()));
    }

    /**
     * Sets the learning environment builder and returns this trainer with its concrete type.
     * <p>
     * NOTE: this is the decompiler-renamed form of {@code withEnvironmentBuilder}. The name is kept because the
     * superclass method it delegates to carries the same renamed identifier.
     *
     * @param envBuilder Learning environment builder.
     * @return This trainer.
     */
    @Override
    public CompoundNaiveBayesTrainer withEnvironmentBuilder2(LearningEnvironmentBuilder envBuilder) {
        return (CompoundNaiveBayesTrainer) super.withEnvironmentBuilder2(envBuilder);
    }

    /**
     * Trains or updates each configured sub-model and assembles the resulting compound model.
     * <p>
     * When prior probabilities are set, they are written into the configured sub-trainers via
     * {@code setPriorProbabilities}. This mutates the trainer instances that were passed in.
     *
     * @param mdl Existing model to update, or {@code null} to train every configured sub-model from scratch.
     * @param datasetBuilder Dataset builder.
     * @param preprocessor Upstream preprocessor producing labeled vectors.
     * @param <K> Type of a key in upstream data.
     * @param <V> Type of a value in upstream data.
     * @return New compound model holding the trained sub-models.
     */
    @Override
    public <K, V> CompoundNaiveBayesModel updateModel(CompoundNaiveBayesModel mdl, DatasetBuilder<K, V> datasetBuilder,
        Preprocessor<K, V> preprocessor) {
        CompoundNaiveBayesModel compoundMdl = new CompoundNaiveBayesModel().withPriorProbabilities(priorProbabilities);

        if (gaussianNaiveBayesTrainer != null) {
            if (priorProbabilities != null)
                gaussianNaiveBayesTrainer.setPriorProbabilities(priorProbabilities);

            GaussianNaiveBayesModel gaussianMdl = mdl == null
                ? (GaussianNaiveBayesModel) gaussianNaiveBayesTrainer.fit(datasetBuilder,
                    preprocessor.map(skipFeatures(gaussianFeatureIdsToSkip)))
                : (GaussianNaiveBayesModel) gaussianNaiveBayesTrainer.update(mdl.getGaussianModel(), datasetBuilder,
                    preprocessor.map(skipFeatures(gaussianFeatureIdsToSkip)));

            compoundMdl.withGaussianModel(gaussianMdl)
                .withGaussianFeatureIdsToSkip(gaussianFeatureIdsToSkip)
                .withLabels(gaussianMdl.getLabels())
                .withPriorProbabilities(priorProbabilities);
        }

        if (discreteNaiveBayesTrainer != null) {
            if (priorProbabilities != null)
                discreteNaiveBayesTrainer.setPriorProbabilities(priorProbabilities);

            DiscreteNaiveBayesModel discreteMdl = mdl == null
                ? (DiscreteNaiveBayesModel) discreteNaiveBayesTrainer.fit(datasetBuilder,
                    preprocessor.map(skipFeatures(discreteFeatureIdsToSkip)))
                : (DiscreteNaiveBayesModel) discreteNaiveBayesTrainer.update(mdl.getDiscreteModel(), datasetBuilder,
                    preprocessor.map(skipFeatures(discreteFeatureIdsToSkip)));

            // When both parts are trained, the discrete model's labels overwrite the Gaussian ones.
            compoundMdl.withDiscreteModel(discreteMdl)
                .withDiscreteFeatureIdsToSkip(discreteFeatureIdsToSkip)
                .withLabels(discreteMdl.getLabels())
                .withPriorProbabilities(priorProbabilities);
        }

        return compoundMdl;
    }

    /**
     * Sets the prior class probabilities used by both sub-trainers.
     *
     * @param priorProbabilities Prior probabilities; copied defensively. {@code null} clears a previously set value.
     * @return This trainer.
     */
    public CompoundNaiveBayesTrainer withPriorProbabilities(double[] priorProbabilities) {
        this.priorProbabilities = priorProbabilities == null ? null : priorProbabilities.clone();
        return this;
    }

    /**
     * Sets the trainer for the Gaussian part of the model.
     *
     * @param gaussianNaiveBayesTrainer Gaussian trainer, or {@code null} to skip that part.
     * @return This trainer.
     */
    public CompoundNaiveBayesTrainer withGaussianNaiveBayesTrainer(GaussianNaiveBayesTrainer gaussianNaiveBayesTrainer) {
        this.gaussianNaiveBayesTrainer = gaussianNaiveBayesTrainer;
        return this;
    }

    /**
     * Sets the trainer for the discrete part of the model.
     *
     * @param discreteNaiveBayesTrainer Discrete trainer, or {@code null} to skip that part.
     * @return This trainer.
     */
    public CompoundNaiveBayesTrainer withDiscreteNaiveBayesTrainer(DiscreteNaiveBayesTrainer discreteNaiveBayesTrainer) {
        this.discreteNaiveBayesTrainer = discreteNaiveBayesTrainer;
        return this;
    }

    /**
     * Sets the feature indices hidden from the Gaussian trainer.
     *
     * @param ids Feature indices to skip; {@code null} means "skip nothing".
     * @return This trainer.
     */
    public CompoundNaiveBayesTrainer withGaussianFeatureIdsToSkip(Collection<Integer> ids) {
        this.gaussianFeatureIdsToSkip = ids == null ? Collections.<Integer>emptyList() : ids;
        return this;
    }

    /**
     * Sets the feature indices hidden from the discrete trainer.
     *
     * @param ids Feature indices to skip; {@code null} means "skip nothing".
     * @return This trainer.
     */
    public CompoundNaiveBayesTrainer withDiscreteFeatureIdsToSkip(Collection<Integer> ids) {
        this.discreteFeatureIdsToSkip = ids == null ? Collections.<Integer>emptyList() : ids;
        return this;
    }

    /**
     * Builds a mapping that drops the given feature indices from each labeled vector and keeps the label.
     * <p>
     * Skipped positions are marked in a mask before the output array is sized. Duplicate, {@code null} or
     * out-of-range ids therefore cannot make the output array too short (or its size negative), as the former
     * {@code size - ids.size()} sizing did. The mask also avoids an O(size * ids) {@code contains} scan.
     * <p>
     * The returned lambda is serializable through {@link IgniteFunction}. The captured collection must itself be
     * serializable for the lambda to serialize.
     *
     * @param featureIdsToSkip Feature indices to drop.
     * @return Mapping function over labeled vectors.
     */
    private static IgniteFunction<LabeledVector<Object>, LabeledVector<Object>> skipFeatures(
        Collection<Integer> featureIdsToSkip) {
        return labeledVector -> {
            int size = labeledVector.features().size();

            boolean[] skip = new boolean[size];
            int skipped = 0;
            for (Integer id : featureIdsToSkip) {
                if (id != null && id >= 0 && id < size && !skip[id]) {
                    skip[id] = true;
                    skipped++;
                }
            }

            double[] kept = new double[size - skipped];
            int next = 0;
            for (int i = 0; i < size; i++) {
                if (!skip[i])
                    kept[next++] = labeledVector.get(i);
            }

            return new LabeledVector<>(VectorUtils.of(kept), labeledVector.label());
        };
    }
}
