/*
 * Decompiled with CFR 0.152.
 */
package io.trino.plugin.ml;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import io.airlift.slice.BasicSliceInput;
import io.airlift.slice.DynamicSliceOutput;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.trino.plugin.ml.Classifier;
import io.trino.plugin.ml.Dataset;
import io.trino.plugin.ml.FeatureVector;
import io.trino.plugin.ml.Model;
import io.trino.plugin.ml.ModelUtils;
import io.trino.plugin.ml.type.ClassifierType;
import io.trino.plugin.ml.type.ModelType;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

/**
 * Adapts an integer-labelled {@link Classifier} so it can be used as a {@code Classifier<String>}.
 *
 * <p>The wrapped classifier predicts integer class ids. This adapter maps those ids back to
 * string labels through a label enumeration.
 */
public class StringClassifierAdapter
implements Classifier<String> {
    private final Classifier<Integer> classifier;
    /** Maps integer class id to string label; extended by {@link #train(Dataset)}. */
    private final Map<Integer, String> labelEnumeration;

    /**
     * Creates an adapter with an empty, mutable label enumeration, ready to be trained.
     *
     * @param classifier the integer-labelled classifier to wrap; must not be null
     * @throws NullPointerException if {@code classifier} is null
     */
    public StringClassifierAdapter(Classifier<Integer> classifier) {
        this(classifier, new HashMap<>());
    }

    /**
     * Creates an adapter backed by the given label enumeration.
     *
     * <p>The map is stored by reference, not copied. {@link #train(Dataset)} writes into it,
     * so it must be mutable if this adapter is going to be trained.
     *
     * @param classifier the integer-labelled classifier to wrap; must not be null
     * @param labelEnumeration mapping from integer class id to string label; must not be null
     * @throws NullPointerException if either argument is null
     */
    public StringClassifierAdapter(Classifier<Integer> classifier, Map<Integer, String> labelEnumeration) {
        this.classifier = Objects.requireNonNull(classifier, "classifier is null");
        this.labelEnumeration = Objects.requireNonNull(labelEnumeration, "labelEnumeration is null");
    }

    @Override
    public ModelType getType() {
        return ClassifierType.VARCHAR_CLASSIFIER;
    }

    /**
     * Serializes the wrapped classifier together with the label enumeration.
     *
     * <p>Layout (every int is written with {@code DynamicSliceOutput.appendInt}):
     * <ol>
     *   <li>int: byte length of the serialized classifier, followed by those bytes</li>
     *   <li>int: number of label entries</li>
     *   <li>per entry: int class id, int UTF-8 byte length of the label, then the label bytes</li>
     * </ol>
     * {@link #deserialize(byte[])} reads this layout back.
     *
     * @return the serialized adapter
     */
    @Override
    public byte[] getSerializedData() {
        byte[] classifierBytes = ModelUtils.serialize(this.classifier).getBytes();
        // Initial capacity is only a sizing hint: the classifier bytes plus ~64 bytes per label.
        DynamicSliceOutput output = new DynamicSliceOutput(classifierBytes.length + 64 * this.labelEnumeration.size());
        output.appendInt(classifierBytes.length);
        output.appendBytes(classifierBytes);
        output.appendInt(this.labelEnumeration.size());
        for (Map.Entry<Integer, String> entry : this.labelEnumeration.entrySet()) {
            output.appendInt(entry.getKey());
            byte[] labelBytes = entry.getValue().getBytes(StandardCharsets.UTF_8);
            output.appendInt(labelBytes.length);
            output.appendBytes(labelBytes);
        }
        return output.slice().getBytes();
    }

    /**
     * Reconstructs an adapter from bytes produced by {@link #getSerializedData()}.
     *
     * <p>The returned adapter holds an immutable label enumeration. Calling
     * {@link #train(Dataset)} on it therefore fails with {@link UnsupportedOperationException}.
     *
     * @param data serialized adapter bytes
     * @return the deserialized adapter
     * @throws ClassCastException if the embedded model is not a {@link Classifier}
     * @throws IllegalArgumentException if the data contains duplicate class ids
     */
    public static StringClassifierAdapter deserialize(byte[] data) {
        Slice slice = Slices.wrappedBuffer(data);
        BasicSliceInput input = slice.getInput();

        int classifierLength = input.readInt();
        Model model = ModelUtils.deserialize(input.readSlice(classifierLength));
        // The cast is checked against Classifier at runtime. The Integer type argument cannot
        // be checked because of erasure; the serialized form is assumed to hold an
        // integer-labelled classifier, as written by getSerializedData().
        @SuppressWarnings("unchecked")
        Classifier<Integer> classifier = (Classifier<Integer>) model;

        int numEnumerations = input.readInt();
        ImmutableMap.Builder<Integer, String> labels = ImmutableMap.builder();
        for (int i = 0; i < numEnumerations; i++) {
            int key = input.readInt();
            int valueLength = input.readInt();
            labels.put(key, input.readSlice(valueLength).toStringUtf8());
        }
        return new StringClassifierAdapter(classifier, labels.buildOrThrow());
    }

    /**
     * Classifies the features with the wrapped classifier and returns the matching string label.
     *
     * @param features the feature vector to classify
     * @return the string label for the predicted class id
     * @throws IllegalStateException if the predicted class id has no entry in the label enumeration
     */
    @Override
    public String classify(FeatureVector features) {
        int prediction = this.classifier.classify(features);
        Preconditions.checkState(this.labelEnumeration.containsKey(prediction), "classifier predicted an unknown class %s", prediction);
        return this.labelEnumeration.get(prediction);
    }

    /**
     * Trains the adapter on a dataset.
     *
     * <p>First merges the dataset's label enumeration into this adapter's map, then trains
     * the wrapped classifier. The merge happens first, so labels are recorded even if
     * training then fails.
     *
     * @param dataset the training data
     * @throws UnsupportedOperationException if the backing label map is immutable, as it is
     *         for adapters created by {@link #deserialize(byte[])}
     */
    @Override
    public void train(Dataset dataset) {
        this.labelEnumeration.putAll(dataset.getLabelEnumeration());
        this.classifier.train(dataset);
    }
}

