/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.tensor;

import com.yahoo.tensor.TensorTypeParser;
import com.yahoo.text.Ascii7BitMatcher;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

public class TensorType {

    /** Validates dimension names (used by {@code Dimension.requireIdentifier}). */
    static Ascii7BitMatcher labelMatcher = new Ascii7BitMatcher("-_@" + Ascii7BitMatcher.charsAndNumbers(), "_@$" + Ascii7BitMatcher.charsAndNumbers());

    /** The rank-0 tensor type with {@code double} values. */
    public static final TensorType empty = new TensorType(Value.DOUBLE, Collections.emptyList());

    private final Value valueType;

    /** The dimensions of this type, sorted by name; immutable. */
    private final List<Dimension> dimensions;

    private final TensorType mappedSubtype;
    private final TensorType indexedSubtype;

    /**
     * Creates a tensor type.
     *
     * @param valueType the cell value type
     * @param dimensions the dimensions of this type in any order; they are stored sorted by name
     */
    public TensorType(Value valueType, Collection<Dimension> dimensions) {
        this.valueType = valueType;
        List<Dimension> dimensionList = new ArrayList<>(dimensions);
        Collections.sort(dimensionList);
        this.dimensions = List.copyOf(dimensionList);

        if (dimensionList.isEmpty()) {
            // A rank-0 type is its own mapped and indexed subtype. The static 'empty' field must not be
            // referenced here: while 'empty' itself is being constructed during class initialization
            // that field is still null, which would leave empty.mappedSubtype() returning null.
            this.mappedSubtype = this;
            this.indexedSubtype = this;
        } else if (dimensionList.stream().allMatch(Dimension::isIndexed)) {
            this.mappedSubtype = empty;
            this.indexedSubtype = this;
        } else if (dimensionList.stream().noneMatch(Dimension::isIndexed)) {
            this.mappedSubtype = this;
            this.indexedSubtype = empty;
        } else {
            // Mixed type: split into its non-indexed and indexed parts, keeping the value type.
            this.mappedSubtype = new TensorType(valueType, dimensionList.stream().filter(d -> !d.isIndexed()).toList());
            this.indexedSubtype = new TensorType(valueType, dimensionList.stream().filter(Dimension::isIndexed).toList());
        }
    }

    /**
     * Returns the largest value type among the given types, ignoring rank-0 types.
     * Returns {@link Value#DOUBLE} if no type has rank above 0.
     */
    public static Value combinedValueType(TensorType ... types) {
        List<Value> valueTypes = new ArrayList<>();
        for (TensorType type : types) {
            // Rank-0 types do not constrain the combined value type.
            if (type.rank() > 0) {
                valueTypes.add(type.valueType());
            }
        }
        return Value.largestOf(valueTypes);
    }

    /** Parses a tensor type spec string, delegating to {@code TensorTypeParser.fromSpec}. */
    public static TensorType fromSpec(String specString) {
        return TensorTypeParser.fromSpec(specString);
    }

    /** Returns the cell value type of this. */
    public Value valueType() {
        return this.valueType;
    }

    /**
     * Returns the type made of the non-indexed dimensions of this.
     * This is this itself if all dimensions are non-indexed.
     */
    public TensorType mappedSubtype() {
        return this.mappedSubtype;
    }

    /**
     * Returns the type made of the indexed dimensions of this.
     * This is this itself if all dimensions are indexed.
     */
    public TensorType indexedSubtype() {
        return this.indexedSubtype;
    }

    /** Returns the number of dimensions of this. */
    public int rank() {
        return this.dimensions.size();
    }

    /** Returns an immutable list of the dimensions of this, sorted by name. */
    public List<Dimension> dimensions() {
        return this.dimensions;
    }

    /** Returns a new set containing the names of the dimensions of this. */
    public Set<String> dimensionNames() {
        return this.dimensions.stream().map(Dimension::name).collect(Collectors.toSet());
    }

    /** Returns the dimension with the given name, or empty if this type has no such dimension. */
    public Optional<Dimension> dimension(String name) {
        return this.indexOfDimension(name).map(i -> this.dimensions.get(i));
    }

    /** Returns the position of the dimension with the given name in {@link #dimensions()}, or empty if absent. */
    public Optional<Integer> indexOfDimension(String dimension) {
        for (int i = 0; i < this.dimensions.size(); ++i) {
            if (this.dimensions.get(i).name().equals(dimension)) {
                return Optional.of(i);
            }
        }
        return Optional.empty();
    }

    /**
     * Returns the size of the named dimension, or empty if this type has no such dimension
     * or the dimension has no size.
     */
    public Optional<Long> sizeOfDimension(String dimension) {
        return this.dimension(dimension).flatMap(Dimension::size);
    }

    /**
     * Returns whether a tensor of this type can be assigned to the given type: the value type is
     * not larger, dimension names and kinds match, and bound sizes in the generalization are equal here.
     */
    public boolean isAssignableTo(TensorType generalization) {
        return this.isConvertibleOrAssignableTo(generalization, false, true);
    }

    /**
     * Like {@link #isAssignableTo}, but a bound dimension here may be smaller than the
     * corresponding bound dimension in the generalization.
     */
    public boolean isConvertibleTo(TensorType generalization) {
        return this.isConvertibleOrAssignableTo(generalization, true, true);
    }

    /**
     * Like {@link #isAssignableTo}, but dimension names are not compared, only the dimensions
     * at the same (name-sorted) positions.
     */
    public boolean isRenamableTo(TensorType other) {
        return this.isConvertibleOrAssignableTo(other, false, false);
    }

    /**
     * Shared implementation of the assignability checks.
     *
     * @param convertible if true, a bound size here may be smaller than the generalization's; otherwise it must be equal
     * @param considerName whether dimension names must match
     */
    private boolean isConvertibleOrAssignableTo(TensorType generalization, boolean convertible, boolean considerName) {
        if (!generalization.valueType().isEqualOrLargerThan(this.valueType)) {
            return false;
        }
        if (generalization.dimensions().size() != this.dimensions().size()) {
            return false;
        }
        for (int i = 0; i < generalization.dimensions().size(); ++i) {
            Dimension thisDimension = this.dimensions().get(i);
            Dimension generalizationDimension = generalization.dimensions().get(i);
            if (thisDimension.isIndexed() != generalizationDimension.isIndexed()) {
                return false;
            }
            if (considerName && !thisDimension.name().equals(generalizationDimension.name())) {
                return false;
            }
            // An unsized dimension in the generalization accepts any size here.
            if (generalizationDimension.size().isEmpty()) {
                continue;
            }
            if (thisDimension.size().isEmpty()) {
                return false;
            }
            long thisSize = thisDimension.size().get();
            long generalizationSize = generalizationDimension.size().get();
            if (convertible ? thisSize > generalizationSize : thisSize != generalizationSize) {
                return false;
            }
        }
        return true;
    }

    /** Returns the type spec of this, e.g. {@code tensor<float>(x[3],y{})}; the value type is omitted when it is double. */
    @Override
    public String toString() {
        String valueTypePart = this.valueType == Value.DOUBLE ? "" : "<" + this.valueType.id() + ">";
        return "tensor" + valueTypePart + "(" + this.dimensions.stream().map(Dimension::toString).collect(Collectors.joining(",")) + ")";
    }

    /**
     * Two types are equal if they have the same value type and dimensions.
     * All rank-0 types are equal regardless of value type.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        TensorType other = (TensorType)o;
        if (this.rank() == 0 && other.rank() == 0) {
            return true;
        }
        if (this.valueType != other.valueType) {
            return false;
        }
        return this.dimensions.equals(other.dimensions);
    }

    /**
     * Returns whether the given type has the same number of dimensions as this, with the same names
     * in the same order. Dimension kinds, sizes and value types are not compared.
     */
    public boolean mathematicallyEquals(TensorType other) {
        if (this.dimensions().size() != other.dimensions().size()) {
            return false;
        }
        for (int i = 0; i < this.dimensions().size(); ++i) {
            if (!this.dimensions().get(i).name().equals(other.dimensions().get(i).name())) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns the most specific type that both this and the given type are assignable to,
     * if it can be formed dimension by dimension.
     *
     * Dimension names must match pairwise. Mapped dimensions must pair with mapped dimensions,
     * and indexed with indexed. Two bound sizes must be equal. A bound/unbound pair
     * generalizes to the unbound dimension. The value type is the larger of the two.
     *
     * @return the generalization, or empty if the types cannot be generalized this way
     */
    public Optional<TensorType> dimensionwiseGeneralizationWith(TensorType other) {
        if (this.equals(other)) {
            return Optional.of(this);
        }
        if (this.dimensions.size() != other.dimensions.size()) {
            return Optional.empty();
        }
        Builder b = new Builder(Value.largestOf(this.valueType, other.valueType));
        for (int i = 0; i < this.dimensions.size(); ++i) {
            Dimension thisDim = this.dimensions().get(i);
            Dimension otherDim = other.dimensions().get(i);
            if (!thisDim.name().equals(otherDim.name())) {
                return Optional.empty();
            }
            if (thisDim.isIndexed() && otherDim.isIndexed()) {
                Optional<Long> thisSize = thisDim.size();
                Optional<Long> otherSize = otherDim.size();
                if (thisSize.isPresent() && otherSize.isPresent() && !thisSize.get().equals(otherSize.get())) {
                    return Optional.empty();
                }
                // When exactly one side is unbound, the unbound dimension is the more general one.
                b.dimension(thisSize.isPresent() && otherSize.isEmpty() ? otherDim : thisDim);
            } else if (!thisDim.isIndexed() && !otherDim.isIndexed()) {
                b.dimension(thisDim);
            } else {
                return Optional.empty();
            }
        }
        return Optional.of(b.build());
    }

    @Override
    public int hashCode() {
        // equals() treats every rank-0 type as equal regardless of value type, so the value type
        // must not contribute to the hash of a rank-0 type (equal objects must have equal hashes).
        Value hashedValueType = this.rank() == 0 ? Value.DOUBLE : this.valueType;
        return Objects.hash(this.dimensions, hashedValueType);
    }

    /** The cell value types a tensor can have. */
    public enum Value {
        DOUBLE("double"),
        FLOAT("float"),
        BFLOAT16("bfloat16"),
        INT8("int8");

        private final String id;

        Value(String id) {
            this.id = id;
        }

        /** Returns the id of this value type as used in type specs, e.g. {@code "float"}. */
        public String id() {
            return this.id;
        }

        /** Returns whether this value type is the same as, or larger than, the given one. */
        public boolean isEqualOrLargerThan(Value other) {
            return this == other || Value.largestOf(this, other) == this;
        }

        /** Returns the largest of the given value types, or {@link #DOUBLE} if the list is empty. */
        public static Value largestOf(List<Value> values) {
            if (values.isEmpty()) {
                return DOUBLE;
            }
            Value largest = null;
            for (Value value : values) {
                largest = largest == null ? value : Value.largestOf(largest, value);
            }
            return largest;
        }

        /** Returns the larger of two value types, ordered DOUBLE > FLOAT > BFLOAT16 > INT8. */
        public static Value largestOf(Value value1, Value value2) {
            if (value1 == DOUBLE || value2 == DOUBLE) {
                return DOUBLE;
            }
            if (value1 == FLOAT || value2 == FLOAT) {
                return FLOAT;
            }
            if (value1 == BFLOAT16 || value2 == BFLOAT16) {
                return BFLOAT16;
            }
            if (value1 == INT8 && value2 == INT8) {
                return INT8;
            }
            throw new IllegalArgumentException("Cannot find largest of " + value1 + " and " + value2);
        }

        /**
         * Returns the lowercase name of this value type. This returns {@link #id()}, which is the
         * lowercase constant name, rather than {@code name().toLowerCase()}. The latter depends on
         * the default locale (e.g. Turkish would turn {@code INT8} into {@code ınt8}).
         */
        @Override
        public String toString() {
            return this.id;
        }

        /** Returns the value type with the given id, e.g. {@code "bfloat16"}. */
        public static Value fromId(String valueTypeString) {
            for (Value value : Value.values()) {
                if (value.id.equals(valueTypeString)) {
                    return value;
                }
            }
            throw new IllegalArgumentException("Value type must be either 'double', 'float', 'bfloat16', or 'int8' but was '" + valueTypeString + "'");
        }
    }

    /**
     * A named dimension of a tensor type. Dimensions are ordered by name only.
     * Equality also requires the same concrete class (and, for bound dimensions, the same size).
     */
    public static abstract class Dimension implements Comparable<Dimension> {

        private final String name;

        private Dimension(String name) {
            this.name = Dimension.requireIdentifier(name);
        }

        public final String name() {
            return this.name;
        }

        /** Returns the size of this dimension if it is bound, or empty otherwise. */
        public abstract Optional<Long> size();

        public abstract Type type();

        /** Returns a dimension of the same kind (and size) as this, with the given name. */
        public abstract Dimension withName(String name);

        /** Returns true if this is a bound or unbound indexed dimension. */
        public boolean isIndexed() {
            return this.type() == Type.indexedBound || this.type() == Type.indexedUnbound;
        }

        /** Returns true if this is a mapped dimension. */
        public boolean isMapped() {
            return this.type() == Type.mapped;
        }

        /**
         * Combines this with another dimension of the same name. Mapped wins over indexed,
         * and unbound wins over bound. For two bound dimensions, the smaller is returned if
         * {@code allowDifferentSizes}; otherwise the sizes must be equal.
         *
         * @throws IllegalArgumentException if both are bound with different sizes and different sizes are not allowed
         */
        Dimension combineWith(Optional<Dimension> other, boolean allowDifferentSizes) {
            if (other.isEmpty()) {
                return this;
            }
            Dimension that = other.get();
            if (this instanceof MappedDimension) {
                return this;
            }
            if (that instanceof MappedDimension) {
                return that;
            }
            if (this instanceof IndexedUnboundDimension) {
                return this;
            }
            if (that instanceof IndexedUnboundDimension) {
                return that;
            }
            IndexedBoundDimension thisIb = (IndexedBoundDimension)this;
            IndexedBoundDimension otherIb = (IndexedBoundDimension)that;
            if (allowDifferentSizes) {
                return thisIb.size().get() < otherIb.size().get() ? thisIb : otherIb;
            }
            if (!thisIb.size().equals(otherIb.size())) {
                throw new IllegalArgumentException("Unequal dimension sizes in " + thisIb + " and " + otherIb);
            }
            return thisIb;
        }

        @Override
        public abstract String toString();

        @Override
        public boolean equals(Object other) {
            if (this == other) {
                return true;
            }
            if (other == null || this.getClass() != other.getClass()) {
                return false;
            }
            return this.name.equals(((Dimension)other).name);
        }

        @Override
        public int hashCode() {
            return this.name.hashCode();
        }

        /** Orders dimensions by name only (note: not consistent with {@link #equals}). */
        @Override
        public int compareTo(Dimension other) {
            return this.name.compareTo(other.name);
        }

        /** Returns a new bound indexed dimension. */
        public static Dimension indexed(String name, long size) {
            return new IndexedBoundDimension(name, size);
        }

        /** Returns a new unbound indexed dimension. */
        public static Dimension indexed(String name) {
            return new IndexedUnboundDimension(name);
        }

        /** Returns a new mapped dimension. */
        public static Dimension mapped(String name) {
            return new MappedDimension(name);
        }

        /** Returns the name unchanged if it is non-null and accepted by {@code labelMatcher}, throws otherwise. */
        private static String requireIdentifier(String name) {
            if (name == null) {
                throw new IllegalArgumentException("A dimension name cannot be null");
            }
            if (!labelMatcher.matches(name)) {
                throw new IllegalArgumentException("A dimension name must be an identifier or integer, not '" + name + "'");
            }
            return name;
        }

        /** The kinds of dimension. */
        public enum Type {
            indexedBound,
            indexedUnbound,
            mapped
        }
    }

    /**
     * Builds tensor types. Dimensions are kept by name; {@code add}-style methods reject
     * duplicate names, while {@link #set} replaces an existing dimension of the same name.
     */
    public static class Builder {

        private final Map<String, Dimension> dimensions = new LinkedHashMap<>();
        private final Value valueType;

        /** Creates a builder for a type with {@code double} values. */
        public Builder() {
            this(Value.DOUBLE);
        }

        /** Creates a builder for a type with the given value type. */
        public Builder(Value valueType) {
            this.valueType = valueType;
        }

        /** Creates a builder containing the combined dimensions of the given types, allowing different sizes. */
        public Builder(TensorType ... types) {
            this(true, types);
        }

        /**
         * Creates a builder containing the combined dimensions of the given types, with their
         * combined value type. Same-named dimensions are merged as described in {@code Dimension.combineWith}.
         */
        public Builder(boolean allowDifferentSizes, TensorType ... types) {
            this.valueType = TensorType.combinedValueType(types);
            for (TensorType type : types) {
                this.addDimensionsOf(type, allowDifferentSizes);
            }
        }

        /** Creates a builder for a {@code double} type containing the given dimensions. */
        public Builder(Iterable<Dimension> dimensions) {
            this(Value.DOUBLE, dimensions);
        }

        /**
         * Creates a builder with the given value type containing the given dimensions.
         *
         * @throws IllegalArgumentException if two dimensions have the same name
         */
        public Builder(Value valueType, Iterable<Dimension> dimensions) {
            this.valueType = valueType;
            for (Dimension dimension : dimensions) {
                this.dimension(dimension);
            }
        }

        private void addDimensionsOf(TensorType type, boolean allowDifferentSizes) {
            for (Dimension dimension : type.dimensions) {
                Optional<Dimension> existing = Optional.ofNullable(this.dimensions.get(dimension.name()));
                this.set(dimension.combineWith(existing, allowDifferentSizes));
            }
        }

        /** Returns the number of dimensions added so far. */
        public int rank() {
            return this.dimensions.size();
        }

        private Builder add(Dimension dimension) {
            Objects.requireNonNull(dimension, "A dimension cannot be null");
            if (this.dimensions.containsKey(dimension.name())) {
                throw new IllegalArgumentException("Could not add dimension " + dimension + " as this dimension is already present");
            }
            this.dimensions.put(dimension.name(), dimension);
            return this;
        }

        /** Adds the given dimension, replacing any existing dimension with the same name. */
        public Builder set(Dimension dimension) {
            Objects.requireNonNull(dimension, "A dimension cannot be null");
            this.dimensions.put(dimension.name(), dimension);
            return this;
        }

        /** Adds a bound indexed dimension; throws IllegalArgumentException if the name is already present. */
        public Builder indexed(String name, long size) {
            return this.add(new IndexedBoundDimension(name, size));
        }

        /** Adds an unbound indexed dimension; throws IllegalArgumentException if the name is already present. */
        public Builder indexed(String name) {
            return this.add(new IndexedUnboundDimension(name));
        }

        /** Adds a mapped dimension; throws IllegalArgumentException if the name is already present. */
        public Builder mapped(String name) {
            return this.add(new MappedDimension(name));
        }

        /** Adds the given dimension; throws IllegalArgumentException if the name is already present. */
        public Builder dimension(Dimension dimension) {
            return this.add(dimension);
        }

        /** Returns the dimension with the given name added so far, or empty if none. */
        public Optional<Dimension> getDimension(String dimension) {
            return Optional.ofNullable(this.dimensions.get(dimension));
        }

        /**
         * Adds a dimension of the given type. Only mapped and unbound indexed types are supported,
         * since a bound dimension needs a size.
         */
        public Builder dimension(String name, Dimension.Type type) {
            switch (type) {
                case mapped -> this.mapped(name);
                case indexedUnbound -> this.indexed(name);
                default -> throw new IllegalArgumentException("This can not create a dimension of type " + type);
            }
            return this;
        }

        /** Returns a new tensor type with the value type and dimensions of this builder. */
        public TensorType build() {
            return new TensorType(this.valueType, this.dimensions.values());
        }
    }

    /** A mapped dimension, written {@code name{}}; it has no size. */
    public static class MappedDimension extends Dimension {

        private MappedDimension(String name) {
            super(name);
        }

        @Override
        public Optional<Long> size() {
            return Optional.empty();
        }

        @Override
        public Dimension.Type type() {
            return Dimension.Type.mapped;
        }

        @Override
        public MappedDimension withName(String name) {
            return new MappedDimension(name);
        }

        @Override
        public String toString() {
            return this.name() + "{}";
        }
    }

    /** An indexed dimension without a size, written {@code name[]}. */
    public static class IndexedUnboundDimension extends Dimension {

        private IndexedUnboundDimension(String name) {
            super(name);
        }

        @Override
        public Optional<Long> size() {
            return Optional.empty();
        }

        @Override
        public Dimension.Type type() {
            return Dimension.Type.indexedUnbound;
        }

        @Override
        public IndexedUnboundDimension withName(String name) {
            return new IndexedUnboundDimension(name);
        }

        @Override
        public String toString() {
            return this.name() + "[]";
        }
    }

    /** An indexed dimension with a size in [1, Integer.MAX_VALUE], written {@code name[size]}. */
    public static class IndexedBoundDimension extends Dimension {

        private final long size;

        private IndexedBoundDimension(String name, long size) {
            super(name);
            if (size < 1L) {
                throw new IllegalArgumentException("Size of bound dimension '" + name + "' must be at least 1");
            }
            if (size > Integer.MAX_VALUE) {
                throw new IllegalArgumentException("Size of bound dimension '" + name + "' cannot be larger than 2147483647");
            }
            this.size = size;
        }

        @Override
        public Optional<Long> size() {
            return Optional.of(this.size);
        }

        @Override
        public Dimension.Type type() {
            return Dimension.Type.indexedBound;
        }

        @Override
        public IndexedBoundDimension withName(String name) {
            return new IndexedBoundDimension(name, this.size);
        }

        @Override
        public String toString() {
            return this.name() + "[" + this.size + "]";
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            if (!super.equals(o)) {
                return false;
            }
            IndexedBoundDimension that = (IndexedBoundDimension)o;
            return this.size == that.size;
        }

        @Override
        public int hashCode() {
            return 31 * super.hashCode() + Long.hashCode(this.size);
        }
    }
}

