/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.ml;

import com.facebook.presto.ml.Classifier;
import com.facebook.presto.ml.Dataset;
import com.facebook.presto.ml.FeatureTransformation;
import com.facebook.presto.ml.FeatureVector;
import com.facebook.presto.ml.Model;
import com.facebook.presto.ml.ModelUtils;
import com.facebook.presto.ml.Types;
import com.facebook.presto.ml.type.ClassifierType;
import com.facebook.presto.ml.type.ModelType;
import java.util.List;
import java.util.Objects;

/**
 * A {@link Classifier} that applies a {@link FeatureTransformation} to its input before
 * delegating to an underlying {@code Classifier<Integer>}.
 *
 * <p>Both the wrapped classifier and the transformation are serialized together. The
 * classifier is written first and the transformation second, and {@link #deserialize(byte[])}
 * reads them back in that same order.
 */
public class ClassifierFeatureTransformer
implements Classifier<Integer> {
    private final Classifier<Integer> classifier;
    private final FeatureTransformation transformation;

    /**
     * Creates a classifier that transforms features before classifying them.
     *
     * @param classifier the classifier that receives the transformed features
     * @param transformation the transformation applied to features and datasets first
     * @throws NullPointerException if {@code classifier} or {@code transformation} is null
     */
    public ClassifierFeatureTransformer(Classifier<Integer> classifier, FeatureTransformation transformation) {
        this.classifier = Objects.requireNonNull(classifier, "classifier is null");
        this.transformation = Objects.requireNonNull(transformation, "transformation is null");
    }

    /**
     * Returns the model type of this classifier.
     *
     * @return always {@code ClassifierType.BIGINT_CLASSIFIER}
     */
    @Override
    public ModelType getType() {
        return ClassifierType.BIGINT_CLASSIFIER;
    }

    /**
     * Serializes the wrapped classifier followed by the transformation.
     *
     * @return the bytes produced by {@code ModelUtils.serializeModels(classifier, transformation)}
     */
    @Override
    public byte[] getSerializedData() {
        return ModelUtils.serializeModels(this.classifier, this.transformation);
    }

    /**
     * Reconstructs an instance from data produced by {@link #getSerializedData()}.
     *
     * @param data the serialized models; model 0 must be a {@link Classifier} and model 1
     *     a {@link FeatureTransformation}
     * @return the reconstructed classifier
     * @throws IllegalArgumentException if {@code data} decodes to fewer than two models
     */
    public static ClassifierFeatureTransformer deserialize(byte[] data) {
        List<Model> models = ModelUtils.deserializeModels(data);
        if (models.size() < 2) {
            throw new IllegalArgumentException(String.format(
                    "expected a classifier and a transformation (2 models), but found %d model(s)",
                    models.size()));
        }
        // Classifier.class can only express the raw type. The classifier's element type is not
        // checked at runtime here; getSerializedData() always writes a Classifier<Integer> first.
        @SuppressWarnings("unchecked")
        Classifier<Integer> delegate = Types.checkType(models.get(0), Classifier.class, "model 0");
        FeatureTransformation featureTransformation = Types.checkType(models.get(1), FeatureTransformation.class, "model 1");
        return new ClassifierFeatureTransformer(delegate, featureTransformation);
    }

    /**
     * Transforms {@code features} and classifies the result with the wrapped classifier.
     *
     * @param features the untransformed feature vector
     * @return the wrapped classifier's prediction for the transformed features
     */
    @Override
    public Integer classify(FeatureVector features) {
        return this.classifier.classify(this.transformation.transform(features));
    }

    /**
     * Trains the transformation on {@code dataset}, then trains the wrapped classifier on the
     * transformed dataset.
     *
     * <p>The order matters: the transformation must be trained before it is used to transform
     * the classifier's training data.
     *
     * @param dataset the untransformed training data
     */
    @Override
    public void train(Dataset dataset) {
        this.transformation.train(dataset);
        this.classifier.train(this.transformation.transform(dataset));
    }
}

