/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.searchlib.rankingexpression.integration.tensorflow;

import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import org.tensorflow.Tensor;

public class TensorConverter {
    public com.yahoo.tensor.Tensor toVespaTensor(Tensor<?> tfTensor) {
        TensorType type = this.toVespaTensorType(tfTensor.shape());
        Values values = this.readValuesOf(tfTensor);
        IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of((TensorType)type);
        for (int i = 0; i < values.size(); ++i) {
            builder.cellByDirectIndex((long)i, values.get(i));
        }
        return builder.build();
    }

    private TensorType toVespaTensorType(long[] shape) {
        TensorType.Builder b = new TensorType.Builder();
        int dimensionIndex = 0;
        for (long dimensionSize : shape) {
            if (dimensionSize == 0L) {
                dimensionSize = 1L;
            }
            b.indexed("d" + dimensionIndex++, dimensionSize);
        }
        return b.build();
    }

    private Values readValuesOf(Tensor<?> tfTensor) {
        switch (tfTensor.dataType()) {
            case DOUBLE: {
                return new DoubleValues(tfTensor);
            }
            case FLOAT: {
                return new FloatValues(tfTensor);
            }
        }
        throw new IllegalArgumentException("Cannot convert a tensor with elements of type " + tfTensor.dataType() + " to a Vespa tensor");
    }

    private static class FloatValues
    extends Values {
        private final FloatBuffer values;

        FloatValues(Tensor<?> tfTensor) {
            super(tfTensor.numElements());
            this.values = FloatBuffer.allocate(tfTensor.numElements());
            tfTensor.writeTo(this.values);
        }

        @Override
        double get(int i) {
            return this.values.get(i);
        }
    }

    private static class DoubleValues
    extends Values {
        private final DoubleBuffer values;

        DoubleValues(Tensor<?> tfTensor) {
            super(tfTensor.numElements());
            this.values = DoubleBuffer.allocate(tfTensor.numElements());
            tfTensor.writeTo(this.values);
        }

        @Override
        double get(int i) {
            return this.values.get(i);
        }
    }

    private static abstract class Values {
        private final int size;

        protected Values(int size) {
            this.size = size;
        }

        abstract double get(int var1);

        int size() {
            return this.size;
        }
    }
}

