/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.searchlib.rankingexpression.integration.tensorflow;

import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.NodeDef;
import org.tensorflow.framework.TensorProto;
import org.tensorflow.framework.TensorShapeProto;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;

/**
 * Converts the values of TensorFlow node attributes ({@link AttrValue}) to Vespa tensors.
 */
public class AttrValueConverter {

    /**
     * Returns the given attribute of a TensorFlow node as a Vespa tensor.
     * <p>
     * Tensor-valued attributes become an indexed tensor whose dimensions are named "d0", "d1", ...
     * (see {@link #toVespaTensorType}). Boolean, float and integer attributes become a rank-0 tensor
     * holding the value, with booleans mapped to 1.0 (true) and 0.0 (false).
     *
     * @param tfNode the TensorFlow node holding the attribute
     * @param attr the name of the attribute to convert
     * @return the attribute value as a Vespa tensor
     * @throws IllegalArgumentException if the node has no attribute with this name, or if the
     *         attribute's value type or tensor data type is not supported
     */
    public static Tensor toVespaTensor(NodeDef tfNode, String attr) {
        if (!tfNode.getAttrMap().containsKey(attr)) {
            throw new IllegalArgumentException(tfNode.getName() + " has no attribute called " + attr);
        }
        AttrValue attrValue = tfNode.getAttrMap().get(attr);
        switch (attrValue.getValueCase()) {
            case TENSOR:
                return buildFromTensor(attrValue);
            case B:
                return buildFromSingleValue(attrValue.getB() ? 1.0 : 0.0);
            case F:
                return buildFromSingleValue(attrValue.getF());
            case I:
                return buildFromSingleValue(attrValue.getI());
            default:
                throw new IllegalArgumentException(tfNode.getName() + ": unsupported attribute type: '" + attrValue.getValueCase().toString() + "'");
        }
    }

    /** Returns a rank-0 (dimensionless) tensor holding the given value. */
    private static Tensor buildFromSingleValue(double value) {
        TensorType type = new TensorType.Builder().build();
        IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type);
        builder.cellByDirectIndex(0L, value);
        return builder.build();
    }

    /**
     * Builds an indexed tensor from a tensor-valued attribute, writing cells in the proto's
     * flattened order.
     * <p>
     * This follows TensorFlow's own decoding rules for {@link TensorProto}:
     * <ul>
     *   <li>values beyond the number of cells in the shape are ignored;</li>
     *   <li>if fewer values than cells are given, the last value is repeated over the remaining cells.
     *       TensorFlow uses this, for example, to store a constant-filled tensor as a single value.</li>
     * </ul>
     */
    private static Tensor buildFromTensor(AttrValue attrValue) {
        TensorProto tensorProto = attrValue.getTensor();
        TensorType type = toVespaTensorType(tensorProto.getTensorShape());
        IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type);
        Values values = valuesOf(tensorProto);

        int valueCount = values.size();
        long cellCount = cellCount(tensorProto.getTensorShape());
        // An unknown cell count (negative) means all supplied values are written, as before.
        int copyCount = cellCount < 0 ? valueCount : (int) Math.min(valueCount, cellCount);
        for (int i = 0; i < copyCount; i++) {
            builder.cellByDirectIndex(i, values.get(i));
        }
        // Repeat the last supplied value over the remaining cells.
        // If no values are given at all, the cells are left unset.
        if (copyCount > 0 && cellCount > copyCount) {
            double last = values.get(copyCount - 1);
            for (long i = copyCount; i < cellCount; i++) {
                builder.cellByDirectIndex(i, last);
            }
        }
        return builder.build();
    }

    /**
     * Returns the number of cells described by the given shape.
     * Returns -1 if the rank or any dimension size is unknown.
     * An empty dimension list (a scalar) has one cell.
     */
    private static long cellCount(TensorShapeProto shapeProto) {
        if (shapeProto.getUnknownRank()) {
            return -1;
        }
        long count = 1;
        for (TensorShapeProto.Dim dimension : shapeProto.getDimList()) {
            if (dimension.getSize() < 0) {
                return -1;
            }
            count *= dimension.getSize();
        }
        return count;
    }

    /**
     * Returns an accessor for the values of the given tensor proto.
     * <p>
     * Values are read from the raw {@code tensor_content} bytes when those are present. Otherwise
     * they are read from the repeated field matching the proto's data type.
     *
     * @throws IllegalArgumentException if the data type is not supported
     */
    private static Values valuesOf(TensorProto tensorProto) {
        if (!tensorProto.getTensorContent().isEmpty()) {
            return new RawContentValues(tensorProto);
        }
        switch (tensorProto.getDtype()) {
            case DT_BOOL:
                return new BoolValues(tensorProto);
            case DT_HALF:
                return new HalfValues(tensorProto);
            // TensorFlow stores all these narrow integer types in int_val.
            case DT_INT8:
            case DT_UINT8:
            case DT_INT16:
            case DT_UINT16:
            case DT_INT32:
                return new IntValues(tensorProto);
            case DT_INT64:
                return new Int64Values(tensorProto);
            case DT_FLOAT:
                return new FloatValues(tensorProto);
            case DT_DOUBLE:
                return new DoubleValues(tensorProto);
            default:
                throw unsupportedDataType(tensorProto);
        }
    }

    /**
     * Returns the byte width of one element of the given proto's data type, as laid out in
     * {@code tensor_content}.
     *
     * @throws IllegalArgumentException if the data type is not supported
     */
    private static int rawElementSize(TensorProto tensorProto) {
        switch (tensorProto.getDtype()) {
            case DT_BOOL:
            case DT_INT8:
            case DT_UINT8:
                return 1;
            case DT_HALF:
            case DT_INT16:
            case DT_UINT16:
                return 2;
            case DT_INT32:
            case DT_FLOAT:
                return 4;
            case DT_INT64:
            case DT_DOUBLE:
                return 8;
            default:
                throw unsupportedDataType(tensorProto);
        }
    }

    private static IllegalArgumentException unsupportedDataType(TensorProto tensorProto) {
        return new IllegalArgumentException("Unsupported data type in attribute tensor import: " + tensorProto.getDtype());
    }

    /**
     * Decodes an IEEE 754 half-precision (binary16) value from the low 16 bits of {@code bits}.
     * TensorFlow stores DT_HALF values as these raw bit patterns, not as numeric values.
     */
    private static float halfToFloat(int bits) {
        int sign = (bits >>> 15) & 0x1;
        int exponent = (bits >>> 10) & 0x1f;
        int mantissa = bits & 0x3ff;
        float magnitude;
        if (exponent == 0) {
            // Zero or subnormal: mantissa * 2^-24
            magnitude = Math.scalb((float) mantissa, -24);
        } else if (exponent == 0x1f) {
            // Infinity or NaN
            magnitude = mantissa == 0 ? Float.POSITIVE_INFINITY : Float.NaN;
        } else {
            // Normal: (1 + mantissa / 1024) * 2^(exponent - 15)
            magnitude = Math.scalb(1.0f + mantissa / 1024.0f, exponent - 15);
        }
        return sign == 0 ? magnitude : -magnitude;
    }

    /**
     * Converts a TensorFlow shape to a Vespa tensor type with one indexed dimension per TensorFlow
     * dimension, named "d0", "d1", ... in order.
     * <p>
     * Dimensions with a known (non-negative) size become bound. Dimensions with a negative size
     * (unknown in TensorFlow) become unbound.
     */
    public static TensorType toVespaTensorType(TensorShapeProto shapeProto) {
        TensorType.Builder b = new TensorType.Builder();
        for (TensorShapeProto.Dim dimension : shapeProto.getDimList()) {
            // Keep the size as a long: narrowing to int could truncate large sizes or make them negative.
            long dimensionSize = dimension.getSize();
            if (dimensionSize >= 0) {
                b.indexed("d" + b.rank(), dimensionSize);
            } else {
                b.indexed("d" + b.rank());
            }
        }
        return b.build();
    }

    /** Read-only, index-based access to the flattened values of a tensor proto, as doubles. */
    private static abstract class Values {
        protected final TensorProto tensorProto;

        protected Values(TensorProto tensorProto) {
            this.tensorProto = tensorProto;
        }

        /** Returns the value at the given flattened index. */
        abstract double get(int i);

        /** Returns the number of values available. */
        abstract int size();
    }

    /** Values held in {@code double_val}. */
    private static class DoubleValues extends Values {
        DoubleValues(TensorProto tensorProto) {
            super(tensorProto);
        }

        @Override
        double get(int i) {
            return tensorProto.getDoubleVal(i);
        }

        @Override
        int size() {
            return tensorProto.getDoubleValCount();
        }
    }

    /** Values held in {@code float_val}. */
    private static class FloatValues extends Values {
        FloatValues(TensorProto tensorProto) {
            super(tensorProto);
        }

        @Override
        double get(int i) {
            return tensorProto.getFloatVal(i);
        }

        @Override
        int size() {
            return tensorProto.getFloatValCount();
        }
    }

    /** Values held in {@code int64_val}. */
    private static class Int64Values extends Values {
        Int64Values(TensorProto tensorProto) {
            super(tensorProto);
        }

        @Override
        double get(int i) {
            return tensorProto.getInt64Val(i);
        }

        @Override
        int size() {
            return tensorProto.getInt64ValCount();
        }
    }

    /** Values held in {@code int_val}. */
    private static class IntValues extends Values {
        IntValues(TensorProto tensorProto) {
            super(tensorProto);
        }

        @Override
        double get(int i) {
            return tensorProto.getIntVal(i);
        }

        @Override
        int size() {
            return tensorProto.getIntValCount();
        }
    }

    /** Values held in {@code half_val} as raw binary16 bit patterns, decoded to their numeric value. */
    private static class HalfValues extends Values {
        HalfValues(TensorProto tensorProto) {
            super(tensorProto);
        }

        @Override
        double get(int i) {
            return halfToFloat(tensorProto.getHalfVal(i) & 0xFFFF);
        }

        @Override
        int size() {
            return tensorProto.getHalfValCount();
        }
    }

    /** Values held in {@code bool_val}, mapped to 1.0 (true) and 0.0 (false). */
    private static class BoolValues extends Values {
        BoolValues(TensorProto tensorProto) {
            super(tensorProto);
        }

        @Override
        double get(int i) {
            return tensorProto.getBoolVal(i) ? 1.0 : 0.0;
        }

        @Override
        int size() {
            return tensorProto.getBoolValCount();
        }
    }

    /**
     * Values packed in the raw {@code tensor_content} bytes, with each element taking
     * {@link #rawElementSize} bytes.
     * NOTE(review): TensorFlow writes tensor_content in host byte order. Little-endian is assumed
     * here, which matches mainstream platforms.
     */
    private static class RawContentValues extends Values {
        private final ByteBuffer content;
        private final int elementSize;

        RawContentValues(TensorProto tensorProto) {
            super(tensorProto);
            this.elementSize = rawElementSize(tensorProto);
            // Copy into a fresh array so that absolute indexing starts at 0,
            // however the ByteString is backed.
            this.content = ByteBuffer.wrap(tensorProto.getTensorContent().toByteArray())
                                     .order(ByteOrder.LITTLE_ENDIAN);
        }

        @Override
        double get(int i) {
            int offset = i * elementSize;
            switch (tensorProto.getDtype()) {
                case DT_BOOL:
                    return content.get(offset) != 0 ? 1.0 : 0.0;
                case DT_INT8:
                    return content.get(offset);
                case DT_UINT8:
                    return content.get(offset) & 0xFF;
                case DT_HALF:
                    return halfToFloat(content.getShort(offset) & 0xFFFF);
                case DT_INT16:
                    return content.getShort(offset);
                case DT_UINT16:
                    return content.getShort(offset) & 0xFFFF;
                case DT_INT32:
                    return content.getInt(offset);
                case DT_INT64:
                    return content.getLong(offset);
                case DT_FLOAT:
                    return content.getFloat(offset);
                case DT_DOUBLE:
                    return content.getDouble(offset);
                default:
                    // Unreachable: the constructor already rejected unsupported data types.
                    throw unsupportedDataType(tensorProto);
            }
        }

        @Override
        int size() {
            return content.capacity() / elementSize;
        }
    }
}

