/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.tensor.serialization;

import com.yahoo.io.GrowableByteBuffer;
import com.yahoo.tensor.MixedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.serialization.BinaryFormat;
import com.yahoo.tensor.serialization.DenseBinaryFormat;
import com.yahoo.tensor.serialization.MixedBinaryFormat;
import com.yahoo.tensor.serialization.SparseBinaryFormat;
import java.util.Optional;

public class TypedBinaryFormat {
    private static final int SPARSE_BINARY_FORMAT_TYPE = 1;
    private static final int DENSE_BINARY_FORMAT_TYPE = 2;
    private static final int MIXED_BINARY_FORMAT_TYPE = 3;
    private static final int SPARSE_BINARY_FORMAT_WITH_CELLTYPE = 5;
    private static final int DENSE_BINARY_FORMAT_WITH_CELLTYPE = 6;
    private static final int MIXED_BINARY_FORMAT_WITH_CELLTYPE = 7;
    private static final int DOUBLE_VALUE_TYPE = 0;
    private static final int FLOAT_VALUE_TYPE = 1;
    private static final int BFLOAT16_VALUE_TYPE = 2;
    private static final int INT8_VALUE_TYPE = 3;

    public static byte[] encode(Tensor tensor) {
        GrowableByteBuffer buffer = new GrowableByteBuffer();
        return TypedBinaryFormat.asByteArray(TypedBinaryFormat.encode(tensor, buffer));
    }

    public static GrowableByteBuffer encode(Tensor tensor, GrowableByteBuffer buffer) {
        BinaryFormat encoder = TypedBinaryFormat.getFormatEncoder(buffer, tensor);
        encoder.encode(buffer, tensor);
        return buffer;
    }

    public static Tensor decode(Optional<TensorType> type, GrowableByteBuffer buffer) {
        BinaryFormat decoder = TypedBinaryFormat.getFormatDecoder(buffer);
        return decoder.decode(type, buffer);
    }

    private static BinaryFormat getFormatEncoder(GrowableByteBuffer buffer, Tensor tensor) {
        boolean isMixed;
        boolean hasMappedDimensions = tensor.type().hasMappedDimensions();
        boolean hasIndexedDimensions = tensor.type().hasIndexedDimensions();
        boolean bl = isMixed = hasMappedDimensions && hasIndexedDimensions;
        if (tensor instanceof MixedTensor && !isMixed && hasIndexedDimensions) {
            isMixed = true;
        }
        if (isMixed && tensor.type().valueType() == TensorType.Value.DOUBLE) {
            TypedBinaryFormat.encodeFormatType(buffer, 3);
            return new MixedBinaryFormat();
        }
        if (isMixed) {
            TypedBinaryFormat.encodeFormatType(buffer, 7);
            TypedBinaryFormat.encodeValueType(buffer, tensor.type().valueType());
            return new MixedBinaryFormat(tensor.type().valueType());
        }
        if (hasIndexedDimensions && tensor.type().valueType() == TensorType.Value.DOUBLE) {
            TypedBinaryFormat.encodeFormatType(buffer, 2);
            return new DenseBinaryFormat();
        }
        if (hasIndexedDimensions) {
            TypedBinaryFormat.encodeFormatType(buffer, 6);
            TypedBinaryFormat.encodeValueType(buffer, tensor.type().valueType());
            return new DenseBinaryFormat(tensor.type().valueType());
        }
        if (tensor.type().valueType() == TensorType.Value.DOUBLE) {
            TypedBinaryFormat.encodeFormatType(buffer, 1);
            return new SparseBinaryFormat();
        }
        TypedBinaryFormat.encodeFormatType(buffer, 5);
        TypedBinaryFormat.encodeValueType(buffer, tensor.type().valueType());
        return new SparseBinaryFormat(tensor.type().valueType());
    }

    private static BinaryFormat getFormatDecoder(GrowableByteBuffer buffer) {
        int formatType = TypedBinaryFormat.decodeFormatType(buffer);
        switch (formatType) {
            case 1: {
                return new SparseBinaryFormat();
            }
            case 2: {
                return new DenseBinaryFormat();
            }
            case 3: {
                return new MixedBinaryFormat();
            }
            case 5: {
                return new SparseBinaryFormat(TypedBinaryFormat.decodeValueType(buffer));
            }
            case 6: {
                return new DenseBinaryFormat(TypedBinaryFormat.decodeValueType(buffer));
            }
            case 7: {
                return new MixedBinaryFormat(TypedBinaryFormat.decodeValueType(buffer));
            }
        }
        throw new IllegalArgumentException("Binary format type " + formatType + " is unknown");
    }

    private static void encodeFormatType(GrowableByteBuffer buffer, int formatType) {
        buffer.putInt1_4Bytes(formatType);
    }

    private static int decodeFormatType(GrowableByteBuffer buffer) {
        return buffer.getInt1_4Bytes();
    }

    private static void encodeValueType(GrowableByteBuffer buffer, TensorType.Value valueType) {
        switch (valueType) {
            case DOUBLE: {
                buffer.putInt1_4Bytes(0);
                break;
            }
            case FLOAT: {
                buffer.putInt1_4Bytes(1);
                break;
            }
            case BFLOAT16: {
                buffer.putInt1_4Bytes(2);
                break;
            }
            case INT8: {
                buffer.putInt1_4Bytes(3);
                break;
            }
            default: {
                throw new IllegalArgumentException("Attempt to encode unknown tensor value type: " + valueType);
            }
        }
    }

    private static TensorType.Value decodeValueType(GrowableByteBuffer buffer) {
        int valueType = buffer.getInt1_4Bytes();
        switch (valueType) {
            case 0: {
                return TensorType.Value.DOUBLE;
            }
            case 1: {
                return TensorType.Value.FLOAT;
            }
            case 2: {
                return TensorType.Value.BFLOAT16;
            }
            case 3: {
                return TensorType.Value.INT8;
            }
        }
        throw new IllegalArgumentException("Received tensor value type '" + valueType + "'. Only 0(double), 1(float), 2(bfloat16), or 3(int8) is legal.");
    }

    private static byte[] asByteArray(GrowableByteBuffer buffer) {
        buffer.flip();
        byte[] result = new byte[buffer.remaining()];
        buffer.get(result);
        return result;
    }

    static short bFloat16BitsFromFloat(float val) {
        return (short)(Float.floatToRawIntBits(val) >>> 16);
    }

    static float floatFromBFloat16Bits(short bits) {
        return Float.intBitsToFloat(bits << 16);
    }
}

