/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer;

import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.TensorTypeParser;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.NodeDef;
import org.tensorflow.framework.TensorShapeProto;

public class OrderedTensorType {
    private final TensorType type;
    private final List<TensorType.Dimension> dimensions;
    private final long[] innerSizesTensorFlow;
    private final long[] innerSizesVespa;
    private final int[] dimensionMap;

    private OrderedTensorType(List<TensorType.Dimension> dimensions) {
        this.dimensions = Collections.unmodifiableList(dimensions);
        this.type = new TensorType.Builder(dimensions).build();
        this.innerSizesTensorFlow = new long[dimensions.size()];
        this.innerSizesVespa = new long[dimensions.size()];
        this.dimensionMap = this.createDimensionMap();
    }

    public TensorType type() {
        return this.type;
    }

    public int rank() {
        return this.dimensions.size();
    }

    public List<TensorType.Dimension> dimensions() {
        return this.dimensions;
    }

    public List<String> dimensionNames() {
        return this.dimensions.stream().map(TensorType.Dimension::name).collect(Collectors.toList());
    }

    private int[] createDimensionMap() {
        int numDimensions = this.dimensions.size();
        if (numDimensions == 0) {
            return null;
        }
        this.innerSizesTensorFlow[numDimensions - 1] = 1L;
        this.innerSizesVespa[numDimensions - 1] = 1L;
        int i = numDimensions - 1;
        while (--i >= 0) {
            this.innerSizesTensorFlow[i] = this.dimensions().get(i + 1).size().orElse(-1L) * this.innerSizesTensorFlow[i + 1];
            this.innerSizesVespa[i] = ((TensorType.Dimension)this.type.dimensions().get(i + 1)).size().orElse(-1L) * this.innerSizesVespa[i + 1];
        }
        int[] mapping = new int[numDimensions];
        block1: for (int i2 = 0; i2 < numDimensions; ++i2) {
            TensorType.Dimension dim1 = this.dimensions().get(i2);
            for (int j = 0; j < numDimensions; ++j) {
                TensorType.Dimension dim2 = (TensorType.Dimension)this.type.dimensions().get(j);
                if (!dim1.equals((Object)dim2)) continue;
                mapping[i2] = j;
                continue block1;
            }
        }
        return mapping;
    }

    public int toDirectIndex(int index) {
        if (this.dimensions.size() == 0) {
            return 0;
        }
        if (this.dimensionMap == null) {
            throw new IllegalArgumentException("Dimension map is not available");
        }
        int directIndex = 0;
        long rest = index;
        for (int i = 0; i < this.dimensions.size(); ++i) {
            long address = rest / this.innerSizesTensorFlow[i];
            directIndex = (int)((long)directIndex + this.innerSizesVespa[this.dimensionMap[i]] * address);
            rest %= this.innerSizesTensorFlow[i];
        }
        return directIndex;
    }

    public boolean equals(Object obj) {
        if (obj == null || !(obj instanceof OrderedTensorType)) {
            return false;
        }
        OrderedTensorType other = (OrderedTensorType)obj;
        if (this.dimensions.size() != this.dimensions.size()) {
            return false;
        }
        List<TensorType.Dimension> thisDimensions = this.dimensions();
        List<TensorType.Dimension> otherDimensions = other.dimensions();
        for (int i = 0; i < thisDimensions.size(); ++i) {
            if (thisDimensions.get(i).equals((Object)otherDimensions.get(i))) continue;
            return false;
        }
        return true;
    }

    public void verifyType(NodeDef node) {
        TensorShapeProto shape = OrderedTensorType.tensorFlowShape(node);
        if (shape != null) {
            if (shape.getDimCount() != this.type.rank()) {
                throw new IllegalArgumentException("TensorFlow shape of '" + node.getName() + "' does not match Vespa shape");
            }
            for (int tensorFlowIndex = 0; tensorFlowIndex < this.dimensions.size(); ++tensorFlowIndex) {
                int vespaIndex = this.dimensionMap[tensorFlowIndex];
                TensorShapeProto.Dim tensorFlowDimension = shape.getDim(tensorFlowIndex);
                TensorType.Dimension vespaDimension = (TensorType.Dimension)this.type().dimensions().get(vespaIndex);
                if (tensorFlowDimension.getSize() == vespaDimension.size().orElse(-1L).longValue()) continue;
                throw new IllegalArgumentException("TensorFlow dimensions of '" + node.getName() + "' does not match Vespa dimensions");
            }
        }
    }

    private static TensorShapeProto tensorFlowShape(NodeDef node) {
        AttrValue attrValueList = (AttrValue)node.getAttrMap().get("_output_shapes");
        if (attrValueList == null) {
            throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' does not exist");
        }
        if (attrValueList.getValueCase() != AttrValue.ValueCase.LIST) {
            throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' is not of expected type");
        }
        List shapeList = attrValueList.getList().getShapeList();
        return (TensorShapeProto)shapeList.get(0);
    }

    public OrderedTensorType rename(DimensionRenamer renamer) {
        ArrayList<TensorType.Dimension> renamedDimensions = new ArrayList<TensorType.Dimension>(this.dimensions.size());
        for (TensorType.Dimension dimension : this.dimensions) {
            String oldName = dimension.name();
            Optional<String> newName = renamer.dimensionNameOf(oldName);
            if (!newName.isPresent()) {
                return this;
            }
            TensorType.Dimension.Type dimensionType = dimension.type();
            if (dimensionType == TensorType.Dimension.Type.indexedBound) {
                renamedDimensions.add(TensorType.Dimension.indexed((String)newName.get(), (long)((Long)dimension.size().get())));
                continue;
            }
            if (dimensionType == TensorType.Dimension.Type.indexedUnbound) {
                renamedDimensions.add(TensorType.Dimension.indexed((String)newName.get()));
                continue;
            }
            if (dimensionType != TensorType.Dimension.Type.mapped) continue;
            renamedDimensions.add(TensorType.Dimension.mapped((String)newName.get()));
        }
        return new OrderedTensorType(renamedDimensions);
    }

    public String toString() {
        return "tensor(" + this.dimensions.stream().map(TensorType.Dimension::toString).collect(Collectors.joining(",")) + ")";
    }

    public static OrderedTensorType fromSpec(String typeSpec) {
        return new OrderedTensorType(TensorTypeParser.dimensionsFromSpec((String)typeSpec));
    }

    public static OrderedTensorType fromTensorFlowType(NodeDef node) {
        return OrderedTensorType.fromTensorFlowType(node, "d");
    }

    public static OrderedTensorType fromTensorFlowType(NodeDef node, String dimensionPrefix) {
        Builder builder = new Builder(node);
        TensorShapeProto shape = OrderedTensorType.tensorFlowShape(node);
        for (int i = 0; i < shape.getDimCount(); ++i) {
            String dimensionName = dimensionPrefix + i;
            TensorShapeProto.Dim tensorFlowDimension = shape.getDim(i);
            if (tensorFlowDimension.getSize() >= 0L) {
                builder.add(TensorType.Dimension.indexed((String)dimensionName, (long)tensorFlowDimension.getSize()));
                continue;
            }
            builder.add(TensorType.Dimension.indexed((String)dimensionName));
        }
        return builder.build();
    }

    public static class Builder {
        private final TensorShapeProto shape;
        private final List<TensorType.Dimension> dimensions;

        public Builder(NodeDef node) {
            this.shape = OrderedTensorType.tensorFlowShape(node);
            this.dimensions = new ArrayList<TensorType.Dimension>(this.shape.getDimCount());
        }

        public Builder add(TensorType.Dimension vespaDimension) {
            int index = this.dimensions.size();
            TensorShapeProto.Dim tensorFlowDimension = this.shape.getDim(index);
            long size = tensorFlowDimension.getSize();
            if (size >= 0L) {
                if (vespaDimension.type() != TensorType.Dimension.Type.indexedBound) {
                    throw new IllegalArgumentException("Non-agreement between TensorFlow and Vespa dimension types");
                }
                if (!vespaDimension.size().isPresent()) {
                    throw new IllegalArgumentException("Tensor dimension is indexed bound but does not have a size");
                }
                if ((Long)vespaDimension.size().get() != size) {
                    throw new IllegalArgumentException("Non-agreement between TensorFlow and Vespa dimension sizes. TensorFlow: " + size + " Vespa: " + vespaDimension.size().get());
                }
            } else if (vespaDimension.type() != TensorType.Dimension.Type.indexedUnbound) {
                throw new IllegalArgumentException("Non-agreement between TensorFlow and Vespa dimension types");
            }
            this.dimensions.add(vespaDimension);
            return this;
        }

        public OrderedTensorType build() {
            return new OrderedTensorType(this.dimensions);
        }
    }
}

