/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.ml;

import com.facebook.presto.ml.ClassifierFeatureTransformer;
import com.facebook.presto.ml.FeatureUnitNormalizer;
import com.facebook.presto.ml.FeatureVector;
import com.facebook.presto.ml.FeatureVectorUnitNormalizer;
import com.facebook.presto.ml.Model;
import com.facebook.presto.ml.RegressorFeatureTransformer;
import com.facebook.presto.ml.StringClassifierAdapter;
import com.facebook.presto.ml.SvmClassifier;
import com.facebook.presto.ml.SvmRegressor;
import com.facebook.presto.spi.block.Block;
import com.facebook.presto.spi.type.BigintType;
import com.facebook.presto.spi.type.DoubleType;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.base.Throwables;
import com.google.common.collect.BiMap;
import com.google.common.collect.ImmutableBiMap;
import com.google.common.collect.ImmutableList;
import com.google.common.hash.HashCode;
import com.google.common.hash.Hashing;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Objects;

/**
 * Helpers for converting {@link Model} instances to and from a versioned, checksummed binary
 * format, plus a helper for turning a map {@link Block} into a {@link FeatureVector}.
 *
 * <p>Layout of a single serialized model (byte offsets):
 * <pre>
 *   0        int       format version (currently 1)
 *   4        byte[32]  SHA-256 of every byte from offset 36 to the end of the slice
 *   36       int       algorithm id (key of MODEL_SERIALIZATION_IDS)
 *   40       int       hyperparameter length H
 *   44       byte[H]   hyperparameters
 *   44 + H   long      model data length D
 *   52 + H   byte[D]   model data as returned by Model.getSerializedData()
 * </pre>
 * Note that the version field is not covered by the hash.
 */
public final class ModelUtils {
    private static final int VERSION_OFFSET = 0;
    private static final int HASH_OFFSET = 4;
    private static final int ALGORITHM_OFFSET = 36;
    private static final int HYPERPARAMETER_LENGTH_OFFSET = 40;
    private static final int HYPERPARAMETERS_OFFSET = 44;
    private static final int CURRENT_FORMAT_VERSION = 1;

    // Length of a SHA-256 digest; the hash field exactly fills the gap before the algorithm id.
    private static final int SHA256_LENGTH = ALGORITHM_OFFSET - HASH_OFFSET;
    private static final int SIZE_OF_INT = 4;
    private static final int SIZE_OF_LONG = 8;

    /**
     * Stable numeric ids written into the serialized form. Ids must never be reassigned,
     * otherwise previously serialized models would deserialize as the wrong algorithm.
     */
    @VisibleForTesting
    static final BiMap<Class<? extends Model>, Integer> MODEL_SERIALIZATION_IDS = ImmutableBiMap.<Class<? extends Model>, Integer>builder()
            .put(SvmClassifier.class, 1)
            .put(SvmRegressor.class, 2)
            .put(FeatureVectorUnitNormalizer.class, 3)
            .put(ClassifierFeatureTransformer.class, 4)
            .put(RegressorFeatureTransformer.class, 5)
            .put(FeatureUnitNormalizer.class, 6)
            .put(StringClassifierAdapter.class, 7)
            .build();

    private ModelUtils() {
    }

    /**
     * Serializes a model into the binary layout documented on this class.
     *
     * @param model the model to serialize; its exact runtime class must be registered in
     *        {@code MODEL_SERIALIZATION_IDS}
     * @return a newly allocated slice holding the serialized model
     * @throws NullPointerException if {@code model} is null or its class has no registered id
     */
    public static Slice serialize(Model model) {
        Objects.requireNonNull(model, "model is null");
        Integer id = MODEL_SERIALIZATION_IDS.get(model.getClass());
        Objects.requireNonNull(id, "id is null");

        // The hyperparameter section is part of the format, but serialize() always writes it empty.
        byte[] hyperparameters = new byte[0];
        byte[] data = model.getSerializedData();

        int dataLengthOffset = HYPERPARAMETERS_OFFSET + hyperparameters.length;
        int dataOffset = dataLengthOffset + SIZE_OF_LONG;
        Slice slice = Slices.allocate(dataOffset + data.length);

        slice.setInt(VERSION_OFFSET, CURRENT_FORMAT_VERSION);
        slice.setInt(ALGORITHM_OFFSET, id);
        slice.setInt(HYPERPARAMETER_LENGTH_OFFSET, hyperparameters.length);
        slice.setBytes(HYPERPARAMETERS_OFFSET, hyperparameters);
        slice.setLong(dataLengthOffset, data.length);
        slice.setBytes(dataOffset, data);

        // The hash covers every byte after the hash field, so it has to be written last.
        byte[] modelHash = computeHash(slice).asBytes();
        Preconditions.checkState(modelHash.length == SHA256_LENGTH, "sha256 hash code expected to be 32 bytes");
        slice.setBytes(HASH_OFFSET, modelHash);
        return slice;
    }

    /**
     * Returns the hash stored in the header of a serialized model. The hash is read as-is,
     * not recomputed or verified.
     */
    public static HashCode modelHash(Slice slice) {
        return HashCode.fromBytes(slice.getBytes(HASH_OFFSET, SHA256_LENGTH));
    }

    /**
     * Deserializes a model from a byte array produced by {@link #serialize(Model)}.
     *
     * @see #deserialize(Slice)
     */
    public static Model deserialize(byte[] data) {
        return deserialize(Slices.wrappedBuffer(data));
    }

    /**
     * Deserializes a model from a slice produced by {@link #serialize(Model)}.
     *
     * <p>The algorithm class is looked up by id. Its public static {@code deserialize(byte[])}
     * method is then invoked reflectively with the model data.
     *
     * @throws IllegalArgumentException if the slice is truncated, has an unsupported version,
     *         fails the hash check, or declares out-of-range lengths
     * @throws NullPointerException if the algorithm id is not registered
     */
    public static Model deserialize(Slice slice) {
        Preconditions.checkArgument(slice.length() >= HYPERPARAMETERS_OFFSET, "Serialized model is too short: %s bytes", slice.length());

        int version = slice.getInt(VERSION_OFFSET);
        Preconditions.checkArgument(version == CURRENT_FORMAT_VERSION, "Unsupported version: %s", version);

        HashCode expectedHash = modelHash(slice);
        HashCode actualHash = computeHash(slice);
        Preconditions.checkArgument(actualHash.equals(expectedHash), "model hash does not match data");

        int id = slice.getInt(ALGORITHM_OFFSET);
        Class<? extends Model> algorithm = MODEL_SERIALIZATION_IDS.inverse().get(id);
        Objects.requireNonNull(algorithm, String.format("Unsupported algorithm %d", id));

        int hyperparameterLength = slice.getInt(HYPERPARAMETER_LENGTH_OFFSET);
        // The hyperparameter bytes plus the 8-byte data-length field must fit inside the slice;
        // this also keeps the offset arithmetic below from overflowing.
        Preconditions.checkArgument(
                hyperparameterLength >= 0 && hyperparameterLength <= slice.length() - HYPERPARAMETERS_OFFSET - SIZE_OF_LONG,
                "Invalid hyperparameter length: %s",
                hyperparameterLength);
        // Hyperparameter bytes are skipped: nothing in this method consumes them.
        int dataLengthOffset = HYPERPARAMETERS_OFFSET + hyperparameterLength;
        long dataLength = slice.getLong(dataLengthOffset);
        int dataOffset = dataLengthOffset + SIZE_OF_LONG;
        // Reject lengths that are negative, run past the slice, or would be truncated by the int cast.
        Preconditions.checkArgument(
                dataLength >= 0 && dataLength <= slice.length() - dataOffset,
                "Invalid model data length: %s",
                dataLength);
        byte[] data = slice.getBytes(dataOffset, (int) dataLength);

        try {
            Method deserialize = algorithm.getMethod("deserialize", byte[].class);
            return (Model) deserialize.invoke(null, new Object[] {data});
        }
        catch (IllegalAccessException | NoSuchMethodException | InvocationTargetException e) {
            throw Throwables.propagate(e);
        }
    }

    /**
     * Serializes several models into one byte array.
     *
     * <p>Layout: an int model count N, then N int lengths, then the N serialized models
     * back to back in argument order.
     */
    public static byte[] serializeModels(Model... models) {
        int headerSize = SIZE_OF_INT + SIZE_OF_INT * models.length;
        List<byte[]> serializedModels = new ArrayList<>(models.length);
        int size = headerSize;
        for (Model model : models) {
            byte[] bytes = serialize(model).getBytes();
            size += bytes.length;
            serializedModels.add(bytes);
        }

        Slice slice = Slices.allocate(size);
        slice.setInt(0, models.length);
        int offset = headerSize;
        for (int i = 0; i < serializedModels.size(); i++) {
            byte[] bytes = serializedModels.get(i);
            slice.setInt(SIZE_OF_INT * (i + 1), bytes.length);
            slice.setBytes(offset, bytes);
            offset += bytes.length;
        }
        return slice.getBytes();
    }

    /**
     * Inverse of {@link #serializeModels(Model...)}.
     *
     * @return the models in their original order, as an immutable list
     * @throws IllegalArgumentException if the model count is negative or the header does not fit
     */
    public static List<Model> deserializeModels(byte[] bytes) {
        Slice slice = Slices.wrappedBuffer(bytes);
        int numModels = slice.getInt(0);
        // Computed in long so a corrupt count cannot overflow the header size.
        long headerSize = SIZE_OF_INT + (long) SIZE_OF_INT * numModels;
        Preconditions.checkArgument(numModels >= 0 && headerSize <= slice.length(), "Invalid model count: %s", numModels);

        ImmutableList.Builder<Model> models = ImmutableList.builder();
        int offset = (int) headerSize;
        for (int i = 0; i < numModels; i++) {
            int length = slice.getInt(SIZE_OF_INT * (i + 1));
            models.add(deserialize(slice.getBytes(offset, length)));
            offset += length;
        }
        return models.build();
    }

    /**
     * Builds a feature vector from a map block.
     *
     * <p>The block is read as alternating BIGINT key and DOUBLE value positions. A null block
     * yields an empty feature vector.
     */
    public static FeatureVector toFeatures(Block map) {
        HashMap<Integer, Double> features = new HashMap<>();
        if (map != null) {
            for (int position = 0; position < map.getPositionCount(); position += 2) {
                // NOTE(review): keys are narrowed from long to int; assumes feature ids fit in an int.
                int key = (int) BigintType.BIGINT.getLong(map, position);
                double value = DoubleType.DOUBLE.getDouble(map, position + 1);
                features.put(key, value);
            }
        }
        return new FeatureVector(features);
    }

    /** SHA-256 over everything after the hash field, i.e. from the algorithm id to the end. */
    private static HashCode computeHash(Slice slice) {
        return Hashing.sha256().hashBytes(slice.getBytes(ALGORITHM_OFFSET, slice.length() - ALGORITHM_OFFSET));
    }
}

