/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.rankingexpression.importer.tensorflow;

import ai.vespa.rankingexpression.importer.OrderedTensorType;
import ai.vespa.rankingexpression.importer.tensorflow.TypeConverter;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import org.tensorflow.Tensor;
import org.tensorflow.framework.TensorProto;

public class TensorConverter {
    public static com.yahoo.tensor.Tensor toVespaTensor(Tensor<?> tfTensor) {
        return TensorConverter.toVespaTensor(tfTensor, "d");
    }

    private static com.yahoo.tensor.Tensor toVespaTensor(Tensor<?> tfTensor, String dimensionPrefix) {
        TensorType type = TypeConverter.typeFrom(tfTensor, dimensionPrefix);
        Values values = TensorConverter.readValuesOf(tfTensor);
        IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of((TensorType)type);
        for (int i = 0; i < values.size(); ++i) {
            builder.cellByDirectIndex((long)i, values.get(i));
        }
        return builder.build();
    }

    static com.yahoo.tensor.Tensor toVespaTensor(Tensor<?> tfTensor, OrderedTensorType type) {
        Values values = TensorConverter.readValuesOf(tfTensor);
        IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of((TensorType)type.type());
        for (int i = 0; i < values.size(); ++i) {
            builder.cellByDirectIndex((long)type.toDirectIndex(i), values.get(i));
        }
        return builder.build();
    }

    static com.yahoo.tensor.Tensor toVespaTensor(TensorProto tensorProto, TensorType type) {
        IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of((TensorType)type);
        Values values = TensorConverter.readValuesOf(tensorProto);
        for (int i = 0; i < values.size(); ++i) {
            builder.cellByDirectIndex((long)i, values.get(i));
        }
        return builder.build();
    }

    public static Long tensorSize(TensorType type) {
        Long size = 1L;
        for (TensorType.Dimension dimension : type.dimensions()) {
            size = size * TensorConverter.dimensionSize(dimension);
        }
        return size;
    }

    private static Long dimensionSize(TensorType.Dimension dim) {
        return (Long)dim.size().orElseThrow(() -> new IllegalArgumentException("Dimension has no size"));
    }

    private static Values readValuesOf(Tensor<?> tfTensor) {
        switch (tfTensor.dataType()) {
            case DOUBLE: {
                return new DoubleValues(tfTensor);
            }
            case FLOAT: {
                return new FloatValues(tfTensor);
            }
            case BOOL: {
                return new BoolValues(tfTensor);
            }
            case UINT8: {
                return new IntValues(tfTensor);
            }
            case INT32: {
                return new IntValues(tfTensor);
            }
            case INT64: {
                return new LongValues(tfTensor);
            }
        }
        throw new IllegalArgumentException("Cannot convert a tensor with elements of type " + tfTensor.dataType() + " to a Vespa tensor");
    }

    private static Values readValuesOf(TensorProto tensorProto) {
        switch (tensorProto.getDtype()) {
            case DT_BOOL: {
                return new ProtoBoolValues(tensorProto);
            }
            case DT_HALF: {
                return new ProtoHalfValues(tensorProto);
            }
            case DT_INT16: 
            case DT_INT32: {
                return new ProtoIntValues(tensorProto);
            }
            case DT_INT64: {
                return new ProtoInt64Values(tensorProto);
            }
            case DT_FLOAT: {
                return new ProtoFloatValues(tensorProto);
            }
            case DT_DOUBLE: {
                return new ProtoDoubleValues(tensorProto);
            }
        }
        throw new IllegalArgumentException("Unsupported data type in attribute tensor import");
    }

    private static class ProtoDoubleValues
    extends ProtoValues {
        ProtoDoubleValues(TensorProto tensorProto) {
            super(tensorProto);
        }

        @Override
        double get(int i) {
            return this.tensorProto.getDoubleVal(i);
        }

        @Override
        int size() {
            return this.tensorProto.getDoubleValCount();
        }
    }

    private static class ProtoFloatValues
    extends ProtoValues {
        ProtoFloatValues(TensorProto tensorProto) {
            super(tensorProto);
        }

        @Override
        double get(int i) {
            return this.tensorProto.getFloatVal(i);
        }

        @Override
        int size() {
            return this.tensorProto.getFloatValCount();
        }
    }

    private static class ProtoInt64Values
    extends ProtoValues {
        ProtoInt64Values(TensorProto tensorProto) {
            super(tensorProto);
        }

        @Override
        double get(int i) {
            return this.tensorProto.getInt64Val(i);
        }

        @Override
        int size() {
            return this.tensorProto.getInt64ValCount();
        }
    }

    private static class ProtoIntValues
    extends ProtoValues {
        ProtoIntValues(TensorProto tensorProto) {
            super(tensorProto);
        }

        @Override
        double get(int i) {
            return this.tensorProto.getIntVal(i);
        }

        @Override
        int size() {
            return this.tensorProto.getIntValCount();
        }
    }

    private static class ProtoHalfValues
    extends ProtoValues {
        ProtoHalfValues(TensorProto tensorProto) {
            super(tensorProto);
        }

        @Override
        double get(int i) {
            return this.tensorProto.getHalfVal(i);
        }

        @Override
        int size() {
            return this.tensorProto.getHalfValCount();
        }
    }

    private static class ProtoBoolValues
    extends ProtoValues {
        ProtoBoolValues(TensorProto tensorProto) {
            super(tensorProto);
        }

        @Override
        double get(int i) {
            return this.tensorProto.getBoolVal(i) ? 1.0 : 0.0;
        }

        @Override
        int size() {
            return this.tensorProto.getBoolValCount();
        }
    }

    private static abstract class ProtoValues
    extends Values {
        final TensorProto tensorProto;

        ProtoValues(TensorProto tensorProto) {
            this.tensorProto = tensorProto;
        }
    }

    private static class LongValues
    extends TensorFlowValues {
        private final LongBuffer values;

        LongValues(Tensor<?> tfTensor) {
            super(tfTensor.numElements());
            this.values = LongBuffer.allocate(tfTensor.numElements());
            tfTensor.writeTo(this.values);
        }

        @Override
        double get(int i) {
            return this.values.get(i);
        }
    }

    private static class IntValues
    extends TensorFlowValues {
        private final IntBuffer values;

        IntValues(Tensor<?> tfTensor) {
            super(tfTensor.numElements());
            this.values = IntBuffer.allocate(tfTensor.numElements());
            tfTensor.writeTo(this.values);
        }

        @Override
        double get(int i) {
            return this.values.get(i);
        }
    }

    private static class BoolValues
    extends TensorFlowValues {
        private final ByteBuffer values;

        BoolValues(Tensor<?> tfTensor) {
            super(tfTensor.numElements());
            this.values = ByteBuffer.allocate(tfTensor.numElements());
            tfTensor.writeTo(this.values);
        }

        @Override
        double get(int i) {
            return this.values.get(i);
        }
    }

    private static class FloatValues
    extends TensorFlowValues {
        private final FloatBuffer values;

        FloatValues(Tensor<?> tfTensor) {
            super(tfTensor.numElements());
            this.values = FloatBuffer.allocate(tfTensor.numElements());
            tfTensor.writeTo(this.values);
        }

        @Override
        double get(int i) {
            return this.values.get(i);
        }
    }

    private static class DoubleValues
    extends TensorFlowValues {
        private final DoubleBuffer values;

        DoubleValues(Tensor<?> tfTensor) {
            super(tfTensor.numElements());
            this.values = DoubleBuffer.allocate(tfTensor.numElements());
            tfTensor.writeTo(this.values);
        }

        @Override
        double get(int i) {
            return this.values.get(i);
        }
    }

    private static abstract class TensorFlowValues
    extends Values {
        private final int size;

        TensorFlowValues(int size) {
            this.size = size;
        }

        @Override
        int size() {
            return this.size;
        }
    }

    private static abstract class Values {
        private Values() {
        }

        abstract double get(int var1);

        abstract int size();
    }
}

