/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.tensor;

import com.yahoo.tensor.TensorType;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Static helpers that compute the resulting {@link TensorType} of the
 * operations named by each method ({@code map}, {@code reduce}, {@code peek},
 * {@code rename}, {@code cell_cast}, {@code join}, {@code merge} and
 * {@code concat}) from the type(s) of the operands.
 *
 * <p>All methods are stateless; the only side effect is a warning written to
 * the class logger by {@link #reduce} and {@link #rename} when a named
 * dimension does not exist in the input type.
 */
public class TypeResolver {
    private static final Logger logger = Logger.getLogger(TypeResolver.class.getName());

    /** Returns the type used for results that have no dimensions left. */
    private static TensorType scalar() {
        return TensorType.empty;
    }

    /** Returns a new mutable map from dimension name to dimension for all dimensions of the given type. */
    private static Map<String, TensorType.Dimension> dimensionsByName(TensorType type) {
        Map<String, TensorType.Dimension> byName = new HashMap<>();
        for (TensorType.Dimension dim : type.dimensions()) {
            byName.put(dim.name(), dim);
        }
        return byName;
    }

    /**
     * Resolves the result type of a map operation.
     *
     * @param inputType the type of the operand
     * @return {@code inputType} itself when its cell type already equals
     *         {@code largestOf(cellType, FLOAT)}; otherwise a type with the same
     *         dimensions and that larger cell type
     */
    public static TensorType map(TensorType inputType) {
        TensorType.Value original = inputType.valueType();
        TensorType.Value resultCellType = TensorType.Value.largestOf(original, TensorType.Value.FLOAT);
        if (resultCellType == original) {
            return inputType;
        }
        return new TensorType(resultCellType, inputType.dimensions());
    }

    /**
     * Resolves the result type of reducing over the given dimensions.
     *
     * <p>Naming a dimension that is not present in {@code inputType} is logged
     * as a warning and otherwise ignored.
     *
     * @param inputType        the type of the operand
     * @param reduceDimensions names of the dimensions to remove; an empty list
     *                         yields the scalar type
     * @return the scalar type when no dimensions remain, otherwise a type with
     *         the remaining dimensions and cell type
     *         {@code largestOf(inputCellType, FLOAT)}
     */
    public static TensorType reduce(TensorType inputType, List<String> reduceDimensions) {
        if (reduceDimensions.isEmpty()) {
            return scalar();
        }
        Map<String, TensorType.Dimension> remaining = dimensionsByName(inputType);
        for (String name : reduceDimensions) {
            // Map values are never null here, so a null result means "no such dimension".
            if (remaining.remove(name) == null) {
                logger.log(Level.WARNING, "reducing non-existing dimension " + name + " in type " + inputType);
            }
        }
        if (remaining.isEmpty()) {
            return scalar();
        }
        TensorType.Value cellType = TensorType.Value.largestOf(inputType.valueType(), TensorType.Value.FLOAT);
        return new TensorType(cellType, remaining.values());
    }

    /**
     * Resolves the result type of peeking into the given dimensions.
     *
     * @param inputType      the type of the operand
     * @param peekDimensions names of the dimensions to remove; must be non-empty
     *                       and every name must exist in {@code inputType}
     * @return the scalar type when no dimensions remain, otherwise a type with
     *         the remaining dimensions and the cell type of {@code inputType}
     * @throws IllegalArgumentException if {@code peekDimensions} is empty or
     *         names a dimension not present (or no longer present) in the input
     */
    public static TensorType peek(TensorType inputType, List<String> peekDimensions) {
        if (peekDimensions.isEmpty()) {
            throw new IllegalArgumentException("peeking no dimensions makes no sense");
        }
        Map<String, TensorType.Dimension> remaining = dimensionsByName(inputType);
        for (String name : peekDimensions) {
            if (remaining.remove(name) == null) {
                throw new IllegalArgumentException("peeking non-existing dimension " + name + " in type " + inputType);
            }
        }
        if (remaining.isEmpty()) {
            return scalar();
        }
        return new TensorType(inputType.valueType(), remaining.values());
    }

    /**
     * Resolves the result type of renaming dimensions.
     *
     * <p>{@code from.get(i)} is renamed to {@code to.get(i)}. A {@code from}
     * name that does not exist in {@code inputType} is logged as a warning and
     * skipped. Dimensions not mentioned in {@code from} are kept as they are.
     *
     * @param inputType the type of the operand
     * @param from      current dimension names; must be non-empty
     * @param to        new dimension names; must have the same size as {@code from}
     * @return a type with the cell type of {@code inputType} and the renamed dimensions
     * @throws IllegalArgumentException if {@code from} is empty, the list sizes
     *         differ, or the number of resulting dimensions differs from the
     *         number of input dimensions (e.g. two dimensions ending up with
     *         the same name)
     */
    public static TensorType rename(TensorType inputType, List<String> from, List<String> to) {
        if (from.isEmpty()) {
            throw new IllegalArgumentException("renaming no dimensions");
        }
        if (from.size() != to.size()) {
            throw new IllegalArgumentException("bad rename, from size " + from.size() + " != to.size " + to.size());
        }
        Map<String, TensorType.Dimension> remaining = dimensionsByName(inputType);
        Map<String, TensorType.Dimension> renamed = new HashMap<>();
        for (int i = 0; i < from.size(); ++i) {
            String oldName = from.get(i);
            String newName = to.get(i);
            TensorType.Dimension dim = remaining.remove(oldName);
            if (dim == null) {
                logger.log(Level.WARNING, "renaming non-existing dimension " + oldName + " in type " + inputType);
                continue;
            }
            renamed.put(newName, dim.withName(newName));
        }
        // Untouched dimensions are carried over; a name clash with a renamed
        // dimension overwrites it and is detected by the size check below.
        for (TensorType.Dimension keep : remaining.values()) {
            renamed.put(keep.name(), keep);
        }
        if (inputType.dimensions().size() != renamed.size()) {
            throw new IllegalArgumentException("bad rename, lost some dimensions");
        }
        return new TensorType(inputType.valueType(), renamed.values());
    }

    /**
     * Resolves the result type of casting the cells of a tensor.
     *
     * @param inputType  the type of the operand
     * @param toCellType the requested cell type
     * @return a type with the dimensions of {@code inputType} and cell type {@code toCellType}
     * @throws IllegalArgumentException if {@code inputType} has no dimensions
     *         and {@code toCellType} is anything other than {@code DOUBLE}
     */
    public static TensorType cell_cast(TensorType inputType, TensorType.Value toCellType) {
        if (toCellType != TensorType.Value.DOUBLE && inputType.dimensions().isEmpty()) {
            throw new IllegalArgumentException("cannot cast " + inputType + " to valueType " + toCellType);
        }
        return new TensorType(toCellType, inputType.dimensions());
    }

    /** Returns true when both dimensions have the same name, the first is indexed bound and the second is indexed unbound. */
    private static boolean firstIsBoundSecond(TensorType.Dimension first, TensorType.Dimension second) {
        return first.type() == TensorType.Dimension.Type.indexedBound
                && second.type() == TensorType.Dimension.Type.indexedUnbound
                && first.name().equals(second.name());
    }

    /** Returns true when both dimensions are indexed bound with the same name and the first has the strictly smaller size. */
    private static boolean firstIsSmaller(TensorType.Dimension first, TensorType.Dimension second) {
        return first.type() == TensorType.Dimension.Type.indexedBound
                && second.type() == TensorType.Dimension.Type.indexedBound
                && first.name().equals(second.name())
                && first.size().isPresent()
                && second.size().isPresent()
                && first.size().get() < second.size().get();
    }

    /**
     * Resolves the result type of joining two tensors.
     *
     * <p>The result contains the union of the dimensions of both operands.
     * For a dimension name present in both with unequal dimensions: a bound
     * indexed dimension wins over an unbound one, and a mapped dimension wins
     * over an indexed one; any other combination is rejected.
     *
     * <p>The cell type is taken from the operand(s) with rank above zero
     * ({@code DOUBLE} if neither has), and is then raised to at least
     * {@code FLOAT} via {@code largestOf}.
     *
     * @param lhs the type of the left operand
     * @param rhs the type of the right operand
     * @return the joined type
     * @throws IllegalArgumentException if a shared dimension name has
     *         dimensions that cannot be reconciled by the rules above
     */
    public static TensorType join(TensorType lhs, TensorType rhs) {
        TensorType.Value cellType = TensorType.Value.DOUBLE;
        if (lhs.rank() > 0 && rhs.rank() > 0) {
            cellType = TensorType.Value.largestOf(lhs.valueType(), rhs.valueType());
        } else if (lhs.rank() > 0) {
            cellType = lhs.valueType();
        } else if (rhs.rank() > 0) {
            cellType = rhs.valueType();
        }
        cellType = TensorType.Value.largestOf(cellType, TensorType.Value.FLOAT);

        Map<String, TensorType.Dimension> result = dimensionsByName(lhs);
        for (TensorType.Dimension dim : rhs.dimensions()) {
            TensorType.Dimension other = result.get(dim.name());
            if (other == null) {
                result.put(dim.name(), dim);
            } else if (other.equals(dim)) {
                // identical dimension on both sides: nothing to resolve
            } else if (firstIsBoundSecond(dim, other)) {
                result.put(dim.name(), dim);      // bound and unbound -> bound
            } else if (firstIsBoundSecond(other, dim)) {
                result.put(dim.name(), other);    // bound and unbound -> bound
            } else if (dim.isMapped() && other.isIndexed()) {
                result.put(dim.name(), dim);      // mapped and indexed -> mapped
            } else if (dim.isIndexed() && other.isMapped()) {
                result.put(dim.name(), other);    // mapped and indexed -> mapped
            } else {
                throw new IllegalArgumentException("Unequal dimension " + dim.name() + " in " + lhs + " and " + rhs);
            }
        }
        return new TensorType(cellType, result.values());
    }

    /**
     * Resolves the result type of merging two tensors.
     *
     * <p>Both operands must have the same number of dimensions, with equal
     * names at each position; the result is then that of
     * {@link #join(TensorType, TensorType)}.
     *
     * @param lhs the type of the left operand
     * @param rhs the type of the right operand
     * @return the joined type of {@code lhs} and {@code rhs}
     * @throws IllegalArgumentException if the dimension names do not match, or
     *         if {@code join} rejects the combination
     */
    public static TensorType merge(TensorType lhs, TensorType rhs) {
        int dimensionCount = lhs.dimensions().size();
        boolean sameNames = rhs.dimensions().size() == dimensionCount;
        // Stop at the first mismatch; later positions cannot change the outcome.
        for (int i = 0; sameNames && i < dimensionCount; ++i) {
            String lhsName = lhs.dimensions().get(i).name();
            String rhsName = rhs.dimensions().get(i).name();
            sameNames = lhsName.equals(rhsName);
        }
        if (!sameNames) {
            throw new IllegalArgumentException("types in merge() dimensions mismatch: " + lhs + " != " + rhs);
        }
        return join(lhs, rhs);
    }

    /**
     * Resolves the result type of concatenating two tensors along a dimension.
     *
     * <p>An operand lacking the concat dimension is treated as having it with
     * size 1. The concat dimension of the result is unbound if it is unbound in
     * either operand, otherwise bound with the sum of the two sizes.
     *
     * <p>For any other dimension name present in both operands with unequal
     * dimensions: an unbound indexed dimension wins over a bound one, and of
     * two bound dimensions the smaller wins; any other combination is rejected.
     *
     * <p>When both operands have rank above zero and the same cell type, that
     * cell type is kept; when they differ, the result is the largest of the
     * two, raised to at least {@code FLOAT}. With only one operand of rank
     * above zero its cell type is used, and with none {@code DOUBLE}.
     *
     * @param lhs             the type of the left operand
     * @param rhs             the type of the right operand
     * @param concatDimension the name of the dimension to concatenate along
     * @return the concatenated type
     * @throws IllegalArgumentException if the concat dimension is mapped in
     *         either operand, or another shared dimension cannot be reconciled
     */
    public static TensorType concat(TensorType lhs, TensorType rhs, String concatDimension) {
        TensorType.Value cellType = TensorType.Value.DOUBLE;
        if (lhs.rank() > 0 && rhs.rank() > 0) {
            if (lhs.valueType() == rhs.valueType()) {
                cellType = lhs.valueType();
            } else {
                cellType = TensorType.Value.largestOf(lhs.valueType(), rhs.valueType());
                cellType = TensorType.Value.largestOf(cellType, TensorType.Value.FLOAT);
            }
        } else if (lhs.rank() > 0) {
            cellType = lhs.valueType();
        } else if (rhs.rank() > 0) {
            cellType = rhs.valueType();
        }

        // Defaults used when an operand does not have the concat dimension.
        TensorType.Dimension first = TensorType.Dimension.indexed(concatDimension, 1L);
        TensorType.Dimension second = TensorType.Dimension.indexed(concatDimension, 1L);

        Map<String, TensorType.Dimension> result = new HashMap<>();
        for (TensorType.Dimension dim : lhs.dimensions()) {
            if (dim.name().equals(concatDimension)) {
                first = dim;
            } else {
                result.put(dim.name(), dim);
            }
        }
        for (TensorType.Dimension dim : rhs.dimensions()) {
            if (dim.name().equals(concatDimension)) {
                second = dim;
                continue;
            }
            TensorType.Dimension other = result.get(dim.name());
            if (other == null) {
                result.put(dim.name(), dim);
            } else if (other.equals(dim)) {
                // identical dimension on both sides: nothing to resolve
            } else if (firstIsBoundSecond(dim, other)) {
                result.put(dim.name(), other);    // bound and unbound -> unbound
            } else if (firstIsBoundSecond(other, dim)) {
                result.put(dim.name(), dim);      // bound and unbound -> unbound
            } else if (firstIsSmaller(dim, other)) {
                result.put(dim.name(), dim);      // two bound sizes -> the smaller
            } else if (firstIsSmaller(other, dim)) {
                result.put(dim.name(), other);    // two bound sizes -> the smaller
            } else {
                throw new IllegalArgumentException("Unequal dimension " + dim.name() + " in " + lhs + " and " + rhs);
            }
        }

        if (first.type() == TensorType.Dimension.Type.mapped) {
            throw new IllegalArgumentException("Bad concat dimension " + concatDimension + " in lhs: " + lhs);
        }
        if (second.type() == TensorType.Dimension.Type.mapped) {
            throw new IllegalArgumentException("Bad concat dimension " + concatDimension + " in rhs: " + rhs);
        }
        if (first.type() == TensorType.Dimension.Type.indexedUnbound) {
            result.put(concatDimension, first);
        } else if (second.type() == TensorType.Dimension.Type.indexedUnbound) {
            result.put(concatDimension, second);
        } else {
            long concatSize = first.size().get() + second.size().get();
            result.put(concatDimension, TensorType.Dimension.indexed(concatDimension, concatSize));
        }
        return new TensorType(cellType, result.values());
    }
}

