/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.searchlib.rankingexpression.integration.onnx.importer;

import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer;
import com.yahoo.tensor.TensorType;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import onnx.Onnx;

/**
 * A tensor type that remembers the dimension order used by an ONNX model.
 * <p>
 * The Vespa {@link TensorType} built from the dimensions (see {@link #type()}) may order its
 * dimensions differently from the ONNX model. This class keeps both views:
 * <ul>
 *   <li>the ONNX ordering, in {@link #dimensions()};</li>
 *   <li>the Vespa type, in {@link #type()};</li>
 *   <li>the bookkeeping needed to translate a flat ONNX-ordered cell index into the flat
 *       Vespa-ordered cell index ({@link #toDirectIndex(int)}).</li>
 * </ul>
 * Instances are immutable.
 */
public class OrderedTensorType {
    /** The Vespa tensor type built from {@link #dimensions}; its dimension order may differ. */
    private final TensorType type;
    /** The dimensions in ONNX order (unmodifiable, privately owned copy). */
    private final List<TensorType.Dimension> dimensions;
    /** innerSizesOnnx[i] is the stride of ONNX dimension i in a row-major, ONNX-ordered buffer. */
    private final long[] innerSizesOnnx;
    /** innerSizesVespa[i] is the stride of Vespa dimension i in a row-major, Vespa-ordered buffer. */
    private final long[] innerSizesVespa;
    /** dimensionMap[i] is the index in type.dimensions() of ONNX dimension i; null for rank 0. */
    private final int[] dimensionMap;

    private OrderedTensorType(List<TensorType.Dimension> dimensions) {
        // Defensive copy: Builder.build() hands over the builder's own mutable list, and a later
        // Builder.add() must not alter this instance after its type and strides are computed.
        this.dimensions = Collections.unmodifiableList(new ArrayList<TensorType.Dimension>(dimensions));
        this.type = new TensorType.Builder(this.dimensions).build();
        // The stride arrays must exist before createDimensionMap() fills them in.
        this.innerSizesOnnx = new long[this.dimensions.size()];
        this.innerSizesVespa = new long[this.dimensions.size()];
        this.dimensionMap = createDimensionMap();
    }

    /** Returns the Vespa tensor type corresponding to this. */
    public TensorType type() {
        return type;
    }

    /** Returns the number of dimensions. */
    public int rank() {
        return dimensions.size();
    }

    /** Returns the dimensions in ONNX order, as an unmodifiable list. */
    public List<TensorType.Dimension> dimensions() {
        return dimensions;
    }

    /** Returns the dimension names in ONNX order. */
    public List<String> dimensionNames() {
        return dimensions.stream().map(TensorType.Dimension::name).collect(Collectors.toList());
    }

    /**
     * Fills in the stride arrays and computes the ONNX-to-Vespa dimension index mapping.
     * <p>
     * Strides are row-major: the last dimension has stride 1. An unbound dimension contributes a
     * size of -1, so the strides are only meaningful when every dimension except possibly the
     * first is bound. The first dimension's size is never used.
     *
     * @return the mapping from ONNX dimension index to Vespa dimension index, or null for rank 0
     */
    private int[] createDimensionMap() {
        int rank = dimensions.size();
        if (rank == 0) {
            return null;
        }
        List<TensorType.Dimension> vespaDimensions = type.dimensions();
        innerSizesOnnx[rank - 1] = 1L;
        innerSizesVespa[rank - 1] = 1L;
        for (int i = rank - 2; i >= 0; --i) {
            innerSizesOnnx[i] = dimensions.get(i + 1).size().orElse(-1L) * innerSizesOnnx[i + 1];
            innerSizesVespa[i] = vespaDimensions.get(i + 1).size().orElse(-1L) * innerSizesVespa[i + 1];
        }
        int[] mapping = new int[rank];
        for (int onnxIndex = 0; onnxIndex < rank; ++onnxIndex) {
            TensorType.Dimension dimension = dimensions.get(onnxIndex);
            int vespaIndex = vespaDimensions.indexOf(dimension);
            if (vespaIndex < 0) {
                // The Vespa type is built from these very dimensions, so this is an invariant failure.
                throw new IllegalStateException("Dimension " + dimension + " is not present in " + type);
            }
            mapping[onnxIndex] = vespaIndex;
        }
        return mapping;
    }

    /**
     * Converts a flat cell index into an ONNX-ordered buffer of this type into the flat index of
     * the same cell in a Vespa-ordered buffer.
     * <p>
     * Both buffers are row-major in their respective dimension orders. The result is only
     * meaningful when the strides are, i.e. when all dimensions except possibly the first are
     * bound.
     *
     * @param index the flat cell index in ONNX dimension order
     * @return the flat cell index in Vespa dimension order; 0 for rank-0 types
     * @throws IllegalArgumentException if the dimension map is unavailable
     */
    public int toDirectIndex(int index) {
        if (dimensions.isEmpty()) {
            return 0;
        }
        if (dimensionMap == null) {
            throw new IllegalArgumentException("Dimension map is not available");
        }
        long directIndex = 0;
        long rest = index;
        for (int i = 0; i < dimensions.size(); ++i) {
            long address = rest / innerSizesOnnx[i];   // coordinate along ONNX dimension i
            directIndex += innerSizesVespa[dimensionMap[i]] * address;
            rest %= innerSizesOnnx[i];
        }
        return (int) directIndex;
    }

    /**
     * Two ordered tensor types are equal when they have the same dimensions in the same (ONNX)
     * order.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof OrderedTensorType)) {
            return false;
        }
        // List.equals checks both the size and each element in order.
        return dimensions.equals(((OrderedTensorType) obj).dimensions);
    }

    /** Consistent with {@link #equals(Object)}: hashes the ordered dimension list. */
    @Override
    public int hashCode() {
        return dimensions.hashCode();
    }

    /**
     * Verifies that an ONNX type's shape agrees with this type.
     * <p>
     * The check compares the rank and, for each ONNX dimension, its {@code dim_value} with the
     * size of the corresponding Vespa dimension. An unbound Vespa dimension counts as -1.
     *
     * @param typeProto the ONNX type to check
     * @throws IllegalArgumentException if the rank or any dimension size differs
     */
    public void verifyType(Onnx.TypeProto typeProto) {
        Onnx.TensorShapeProto shape = typeProto.getTensorType().getShape();
        if (shape == null) {
            return;
        }
        if (shape.getDimCount() != type.rank()) {
            throw new IllegalArgumentException("Onnx shape of rank " + shape.getDimCount() +
                                               " does not match Vespa shape " + type);
        }
        List<TensorType.Dimension> vespaDimensions = type.dimensions();
        for (int onnxIndex = 0; onnxIndex < dimensions.size(); ++onnxIndex) {
            long onnxSize = shape.getDim(onnxIndex).getDimValue();
            TensorType.Dimension vespaDimension = vespaDimensions.get(dimensionMap[onnxIndex]);
            long vespaSize = vespaDimension.size().orElse(-1L);
            if (onnxSize != vespaSize) {
                throw new IllegalArgumentException("Onnx dimension " + onnxIndex + " of size " + onnxSize +
                                                   " does not match Vespa dimension " + vespaDimension);
            }
        }
    }

    /**
     * Returns a copy of this with every dimension renamed according to the given renamer.
     * <p>
     * The ONNX order and each dimension's kind and size are preserved.
     *
     * @param renamer supplies the new name for each dimension name
     * @return the renamed type, or this unchanged if the renamer has no name for some dimension
     */
    public OrderedTensorType rename(DimensionRenamer renamer) {
        List<TensorType.Dimension> renamedDimensions = new ArrayList<TensorType.Dimension>(dimensions.size());
        for (TensorType.Dimension dimension : dimensions) {
            Optional<String> newName = renamer.dimensionNameOf(dimension.name());
            if (!newName.isPresent()) {
                return this;
            }
            String name = newName.get();
            TensorType.Dimension.Type dimensionType = dimension.type();
            if (dimensionType == TensorType.Dimension.Type.indexedBound) {
                renamedDimensions.add(TensorType.Dimension.indexed(name, dimension.size().get()));
            } else if (dimensionType == TensorType.Dimension.Type.indexedUnbound) {
                renamedDimensions.add(TensorType.Dimension.indexed(name));
            } else if (dimensionType == TensorType.Dimension.Type.mapped) {
                renamedDimensions.add(TensorType.Dimension.mapped(name));
            }
            // NOTE(review): a dimension of any other type is dropped here (kept from the original);
            // confirm TensorType.Dimension.Type has no further values.
        }
        return new OrderedTensorType(renamedDimensions);
    }

    /** Creates an ordered type from an ONNX type, naming its dimensions "d0", "d1", ... */
    public static OrderedTensorType fromOnnxType(Onnx.TypeProto type) {
        return fromOnnxType(type, "d");
    }

    /**
     * Creates an ordered type from an ONNX type.
     * <p>
     * Dimension i is named {@code dimensionPrefix + i}. It is indexed and bound when its ONNX
     * {@code dim_value} is non-negative, and indexed unbound otherwise.
     */
    public static OrderedTensorType fromOnnxType(Onnx.TypeProto type, String dimensionPrefix) {
        Onnx.TensorShapeProto shape = type.getTensorType().getShape();
        Builder builder = new Builder(shape);
        for (int i = 0; i < shape.getDimCount(); ++i) {
            String dimensionName = dimensionPrefix + i;
            long size = shape.getDim(i).getDimValue();
            builder.add(size >= 0L ? TensorType.Dimension.indexed(dimensionName, size)
                                   : TensorType.Dimension.indexed(dimensionName));
        }
        return builder.build();
    }

    /**
     * Creates an ordered type from a list of dimension sizes.
     * <p>
     * Dimension i is named {@code dimensionPrefix + i}. It is indexed and bound for a
     * non-negative size, and indexed unbound otherwise.
     */
    public static OrderedTensorType fromOnnxType(List<Long> dims, String dimensionPrefix) {
        Builder builder = new Builder();
        for (int i = 0; i < dims.size(); ++i) {
            String dimensionName = dimensionPrefix + i;
            long size = dims.get(i);
            builder.add(size >= 0L ? TensorType.Dimension.indexed(dimensionName, size)
                                   : TensorType.Dimension.indexed(dimensionName));
        }
        return builder.build();
    }

    /**
     * Returns a type with the same order and sizes as the given one, with dimensions renamed to
     * "d0", "d1", ....
     * <p>
     * Every resulting dimension is indexed. It is bound when the source dimension has a
     * non-negative size, and unbound otherwise.
     */
    public static OrderedTensorType standardType(OrderedTensorType type) {
        Builder builder = new Builder();
        List<TensorType.Dimension> sourceDimensions = type.dimensions();
        for (int i = 0; i < sourceDimensions.size(); ++i) {
            TensorType.Dimension dim = sourceDimensions.get(i);
            String dimensionName = "d" + i;
            if (dim.size().isPresent() && dim.size().get() >= 0L) {
                builder.add(TensorType.Dimension.indexed(dimensionName, dim.size().get()));
            } else {
                builder.add(TensorType.Dimension.indexed(dimensionName));
            }
        }
        return builder.build();
    }

    /**
     * Builds an {@link OrderedTensorType} one dimension at a time, in ONNX order.
     * <p>
     * When constructed with an ONNX shape, every added dimension is checked against the
     * corresponding ONNX dimension:
     * <ul>
     *   <li>a non-negative {@code dim_value} requires an indexed bound dimension of that size;</li>
     *   <li>a negative {@code dim_value} requires an indexed unbound dimension.</li>
     * </ul>
     */
    public static class Builder {
        /** The ONNX shape to validate against, or null for no validation. */
        private final Onnx.TensorShapeProto shape;
        private final List<TensorType.Dimension> dimensions;

        /** Creates a builder that validates each added dimension against the given ONNX shape. */
        public Builder(Onnx.TensorShapeProto shape) {
            this.shape = shape;
            this.dimensions = new ArrayList<TensorType.Dimension>(shape.getDimCount());
        }

        /** Creates a builder that performs no validation. */
        public Builder() {
            this.shape = null;
            this.dimensions = new ArrayList<TensorType.Dimension>();
        }

        /**
         * Appends the next dimension.
         *
         * @param vespaDimension the dimension to append
         * @return this, for chaining
         * @throws IllegalArgumentException if a shape was given and the dimension disagrees with
         *         it, or the shape has no more dimensions
         */
        public Builder add(TensorType.Dimension vespaDimension) {
            if (shape != null) {
                int index = dimensions.size();
                if (index >= shape.getDimCount()) {
                    throw new IllegalArgumentException("Cannot add dimension " + vespaDimension +
                                                       ": the Onnx shape has only " + shape.getDimCount() +
                                                       " dimensions");
                }
                long size = shape.getDim(index).getDimValue();
                if (size >= 0L) {
                    if (vespaDimension.type() != TensorType.Dimension.Type.indexedBound) {
                        throw new IllegalArgumentException("Non-agreement between Onnx and Vespa dimension types");
                    }
                    if (!vespaDimension.size().isPresent()) {
                        throw new IllegalArgumentException("Tensor dimension is indexed bound but does not have a size");
                    }
                    long vespaSize = vespaDimension.size().get();
                    if (vespaSize != size) {
                        throw new IllegalArgumentException("Non-agreement between Onnx and Vespa dimension sizes. Onnx: " +
                                                           size + " Vespa: " + vespaSize);
                    }
                } else if (vespaDimension.type() != TensorType.Dimension.Type.indexedUnbound) {
                    throw new IllegalArgumentException("Non-agreement between Onnx and Vespa dimension types");
                }
            }
            dimensions.add(vespaDimension);
            return this;
        }

        /** Returns a new ordered type; the builder may be reused without affecting it. */
        public OrderedTensorType build() {
            return new OrderedTensorType(dimensions);
        }
    }
}

