/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.tensor;

import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.MixedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.TensorTypeParser;
import com.yahoo.tensor.serialization.JsonFormat;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.function.Consumer;

class TensorParser {
    // Package-private no-argument constructor with an empty body.
    // Every parsing entry point visible in this file is static.
    TensorParser() {
    }

    /**
     * Parses a tensor from its string form, adding context to any parse failure.
     *
     * @param tensorString the string to parse
     * @param explicitType the type the caller requires, if any
     * @return the tensor produced by {@link #tensorFromBody}
     * @throws IllegalArgumentException if parsing fails with an IllegalArgumentException; the thrown
     *         exception names the input (and the explicit type when given) and carries the
     *         original failure as its cause
     */
    static Tensor tensorFrom(String tensorString, Optional<TensorType> explicitType) {
        try {
            return tensorFromBody(tensorString, explicitType);
        }
        catch (IllegalArgumentException cause) {
            StringBuilder message = new StringBuilder();
            message.append("Could not parse '").append(tensorString).append("' as a tensor");
            if (explicitType.isPresent()) {
                message.append(" of type ").append(explicitType.get());
            }
            throw new IllegalArgumentException(message.toString(), cause);
        }
    }

    /**
     * Parses a tensor string, dispatching on the shape of the value part.
     *
     * <p>If the trimmed input starts with "tensor", everything before the first ':' is parsed as
     * the type spec and everything after it as the value. Otherwise the whole input is the value
     * and {@code explicitType} (possibly empty) is the type. The value is then routed by its
     * first character: '{' to the mapped or mixed form, '[' to the dense form, otherwise it is
     * tried as a hex-encoded dense value and finally as a plain number.
     *
     * @param tensorString the string to parse
     * @param explicitType the type the caller requires, if any
     * @return the parsed tensor
     * @throws IllegalArgumentException if the type prefix has no ':' separator, if the type in the
     *         string differs from {@code explicitType}, or if the value cannot be parsed
     */
    static Tensor tensorFromBody(String tensorString, Optional<TensorType> explicitType) {
        tensorString = tensorString.trim();

        Optional<TensorType> type;
        String valueString;
        List<String> dimensionOrder;
        if (tensorString.startsWith("tensor")) {
            int colonIndex = tensorString.indexOf(':');
            // Without this check substring(0, -1) throws StringIndexOutOfBoundsException, which
            // callers catching IllegalArgumentException (such as tensorFrom) would not see.
            if (colonIndex < 0) {
                throw new IllegalArgumentException("Expected a ':' separating the tensor type from its value, but found none in '" + tensorString + "'");
            }
            String typeString = tensorString.substring(0, colonIndex);
            ArrayList<String> orderFromSpec = new ArrayList<String>();
            TensorType typeFromString = TensorTypeParser.fromSpec(typeString, orderFromSpec);
            if (explicitType.isPresent() && !explicitType.get().equals(typeFromString)) {
                throw new IllegalArgumentException("Got tensor with type string '" + typeString + "', but was passed type " + explicitType.get());
            }
            type = Optional.of(typeFromString);
            valueString = tensorString.substring(colonIndex + 1);
            dimensionOrder = orderFromSpec;
        } else {
            type = explicitType;
            valueString = tensorString;
            dimensionOrder = null; // no type spec in the string, so no explicit dimension order
        }
        valueString = valueString.trim();

        if (valueString.startsWith("{")) {
            String afterBrace = valueString.substring(1).trim();
            // The mapped form is chosen when there is no type, the type has rank 0, the first
            // entry is an explicit {dimension:label} address, or the value is the empty "{}".
            boolean mappedForm = type.isEmpty()
                                 || type.get().rank() == 0
                                 || afterBrace.startsWith("{")
                                 || afterBrace.equals("}");
            if (mappedForm) {
                return tensorFromMappedValueString(valueString, type);
            }
            return tensorFromMixedValueString(valueString, type, dimensionOrder);
        }
        if (valueString.startsWith("[")) {
            return tensorFromDenseValueString(valueString, type, dimensionOrder);
        }

        Optional<Tensor> t = maybeFromBinaryValueString(valueString, type, dimensionOrder);
        if (t.isPresent()) {
            return t.get();
        }
        if (explicitType.isPresent() && !explicitType.get().equals(TensorType.empty)) {
            throw new IllegalArgumentException("Got a zero-dimensional tensor value ('" + tensorString + "') where type " + explicitType.get() + " is required");
        }
        try {
            // NOTE(review): this parses the whole trimmed input rather than valueString, so an
            // input with a type prefix followed by a bare number ends up in the catch branch
            // below. Confirm that this is intended.
            return Tensor.Builder.of(TensorType.empty).cell(Double.parseDouble(tensorString), new long[0]).build();
        }
        catch (NumberFormatException e) {
            throw new IllegalArgumentException("Excepted a number or a string starting by {, [ or tensor(...):, got '" + tensorString + "'");
        }
    }

    /**
     * Derives a tensor type from a mapped-form value string by running a
     * {@link MappedValueTypeParser} over it, which registers one mapped dimension per
     * dimension name found in the first cell address.
     *
     * @param valueString the mapped-form value string
     * @return the type built from the dimensions the parser registered
     */
    private static TensorType typeFromMappedValueString(String valueString) {
        TensorType.Builder typeBuilder = new TensorType.Builder();
        new MappedValueTypeParser(valueString, typeBuilder).parse();
        return typeBuilder.build();
    }

    /**
     * Parses a value string on the mapped form, "{{dimension:label,...}:value, ...}".
     *
     * @param valueString the mapped-form value string; surrounding whitespace is ignored
     * @param type the tensor type to build; when empty, the type is inferred from the value string
     * @return the parsed tensor
     * @throws IllegalArgumentException if the value string is malformed
     */
    private static Tensor tensorFromMappedValueString(String valueString, Optional<TensorType> type) {
        try {
            final String trimmed = valueString.trim();
            // orElseGet defers the type-inference parse to the case where no type was supplied;
            // orElse would run that extra pass over the string on every call.
            TensorType resolvedType = type.orElseGet(() -> typeFromMappedValueString(trimmed));
            Tensor.Builder builder = Tensor.Builder.of(resolvedType);
            MappedValueParser parser = new MappedValueParser(trimmed, builder);
            parser.parse();
            return builder.build();
        }
        catch (NumberFormatException e) {
            throw new IllegalArgumentException("Excepted a number or a string starting by '{' or 'tensor('");
        }
    }

    /**
     * Parses a value string on the mixed form, "{label:value, ...}" with possibly nested labels
     * and dense subspaces.
     *
     * <p>When the type has no mapped dimension, the value is handed to
     * {@link SingleUnboundParser}, provided that parser can handle the type. Otherwise it is
     * handed to {@link GenericMixedValueParser}.
     *
     * @param valueString the mixed-form value string; surrounding whitespace is ignored
     * @param type the tensor type, required for this form
     * @param dimensionOrder the dimension order taken from the type spec, or null if none was given
     * @return the parsed tensor
     * @throws IllegalArgumentException if no type is given, the value is not enclosed in braces,
     *         the type is unsuitable for this form, or the value is malformed
     */
    private static Tensor tensorFromMixedValueString(String valueString, Optional<TensorType> type, List<String> dimensionOrder) {
        if (type.isEmpty()) {
            throw new IllegalArgumentException("The mixed tensor form requires an explicit tensor type on the form 'tensor(dimensions):...");
        }
        long numMappedDims = type.get().dimensions().stream().filter(d -> d.isMapped()).count();
        try {
            valueString = valueString.trim();
            // Both braces are required. The previous condition negated only the startsWith
            // test, so a value missing its closing brace was never rejected here.
            if (!(valueString.startsWith("{") && valueString.endsWith("}"))) {
                throw new IllegalArgumentException("A mixed tensor must be enclosed in {}");
            }
            Tensor.Builder builder = Tensor.Builder.of(type.get());
            if (numMappedDims == 0L) {
                if (!SingleUnboundParser.canHandle(type.get())) {
                    throw new IllegalArgumentException("No suitable dimension in " + type.get() + " for parsing a tensor on the mixed form: Should have one mapped dimension");
                }
                SingleUnboundParser parser = new SingleUnboundParser(valueString, builder);
                parser.parse();
            } else {
                GenericMixedValueParser parser = new GenericMixedValueParser(valueString, dimensionOrder, builder);
                parser.parse();
            }
            return builder.build();
        }
        catch (NumberFormatException e) {
            throw new IllegalArgumentException("Excepted a number or a string starting by '{' or 'tensor('");
        }
    }

    /**
     * Tells whether a string has exactly the right number of hex digits to encode every cell of
     * the given type.
     *
     * @param type the tensor type the string should encode
     * @param valueString the candidate hex string
     * @return true if the type has at least one dimension, every dimension has a non-zero size,
     *         and the string consists of exactly (cell count * 2 * cell size) characters that
     *         {@link Character#digit(int, int)} accepts in radix 16
     */
    private static boolean validHexString(TensorType type, String valueString) {
        if (type.dimensions().isEmpty()) {
            return false;
        }
        long sz = 1L;
        for (TensorType.Dimension d : type.dimensions()) {
            // A dimension without a size counts as 0, which makes the product 0.
            sz *= d.size().orElse(0L).longValue();
        }
        if (sz == 0L) {
            return false;
        }
        // Compared as long: narrowing the expected count to int could wrap for very large
        // tensors and then match a much shorter string by accident.
        long expectedHexDigits = sz * 2L * type.valueType().sizeOfCell();
        if (valueString.length() != expectedHexDigits) {
            return false;
        }
        return valueString.chars().allMatch(ch -> Character.digit(ch, 16) != -1);
    }

    /**
     * Attempts to read the value as a hex-encoded dense tensor.
     *
     * @param valueString the candidate hex string
     * @param optType the tensor type, if known
     * @param dimensionOrder the dimension order taken from the type spec, or null if none was given
     * @return the parsed tensor when a type is present and the string is a valid hex encoding
     *         for it, otherwise empty
     */
    private static Optional<Tensor> maybeFromBinaryValueString(String valueString, Optional<TensorType> optType, List<String> dimensionOrder) {
        boolean isHexEncoded = optType.isPresent() && validHexString(optType.get(), valueString);
        if (!isHexEncoded) {
            return Optional.empty();
        }
        Tensor decoded = tensorFromDenseValueString(valueString, optType, dimensionOrder);
        return Optional.of(decoded);
    }

    /**
     * Parses a value string on the dense form.
     *
     * <p>If any dimension of the type lacks a size, the value is parsed with
     * {@link UnboundDenseValueParser} and the result is checked against the dimensions that do
     * have a size. Otherwise it is parsed with {@link DenseValueParser}.
     *
     * @param valueString the dense-form value string
     * @param type the tensor type, required for this form
     * @param dimensionOrder the dimension order taken from the type spec, or null if none was given
     * @return the parsed tensor
     * @throws IllegalArgumentException if no type is given or the value is malformed
     */
    private static Tensor tensorFromDenseValueString(String valueString, Optional<TensorType> type, List<String> dimensionOrder) {
        TensorType denseType = type.orElseThrow(() ->
                new IllegalArgumentException("The dense tensor form requires an explicit tensor type on the form 'tensor(dimensions):..."));
        IndexedTensor.Builder builder = IndexedTensor.Builder.of(denseType);

        boolean hasUnsizedDimension = false;
        for (TensorType.Dimension dimension : denseType.dimensions()) {
            if (dimension.size().isEmpty()) {
                hasUnsizedDimension = true;
                break;
            }
        }

        if (!hasUnsizedDimension) {
            new DenseValueParser(valueString, dimensionOrder, (IndexedTensor.BoundBuilder)builder).parse();
            return builder.build();
        }
        new UnboundDenseValueParser(valueString, builder).parse();
        return checkBoundDimensionSizes(builder.build());
    }

    /**
     * Verifies that every dimension with a declared size has exactly that size in the tensor.
     *
     * @param tensor the tensor to check
     * @return the same tensor, if all sized dimensions match
     * @throws IllegalArgumentException on the first dimension whose actual size differs from the
     *         size declared in the type
     */
    private static Tensor checkBoundDimensionSizes(IndexedTensor tensor) {
        TensorType type = tensor.type();
        int position = 0;
        for (TensorType.Dimension dimension : type.dimensions()) {
            int i = position++;
            if (dimension.size().isEmpty()) {
                continue; // no declared size, so nothing to verify
            }
            long declared = dimension.size().get().longValue();
            if (declared != tensor.dimensionSizes().size(i)) {
                throw new IllegalArgumentException("Unexpected size " + tensor.dimensionSizes().size(i) + " for dimension " + dimension.name() + " for type " + type);
            }
        }
        return tensor;
    }

    /**
     * Reads the first cell address of a mapped-form value string and registers each dimension
     * name found there as a mapped dimension on a type builder.
     */
    private static class MappedValueTypeParser extends ValueParser {
        private final TensorType.Builder builder;

        public MappedValueTypeParser(String string, TensorType.Builder builder) {
            super(string);
            this.builder = builder;
        }

        /** Consumes the opening '{' of the value and then the first address, if there is one. */
        public void parse() {
            consume('{');
            consumeLabels();
        }

        /**
         * Consumes one "{dimension:label, ...}" address, adding a mapped dimension to the builder
         * for every dimension name. Does nothing if the next character is not '{'.
         */
        private void consumeLabels() {
            boolean hasAddress = consumeOptional('{');
            if (hasAddress) {
                while (true) {
                    if (consumeOptional('}')) {
                        break;
                    }
                    String dimensionName = consumeIdentifier();
                    consume(':');
                    consumeLabel(); // the label value is skipped; only the name defines the type
                    builder.mapped(dimensionName);
                    consumeOptional(',');
                }
            }
        }
    }

    /**
     * Parses a value string on the mapped form, "{{dimension:label,...}:value, ...}", adding one
     * cell to the builder per entry.
     */
    private static class MappedValueParser extends ValueParser {
        private final Tensor.Builder builder;

        public MappedValueParser(String string, Tensor.Builder builder) {
            super(string);
            this.builder = builder;
        }

        /**
         * Parses the whole string.
         *
         * @throws IllegalArgumentException if the closing '}' is missing or anything other than
         *         whitespace follows it
         */
        private void parse() {
            consume('{');
            boolean expectMore = true;
            while (expectMore && position < string.length()) {
                skipSpace();
                if (string.charAt(position) == '}') {
                    break;
                }
                TensorAddress address = consumeLabels();
                // The ':' is mandatory after a non-empty address and optional after an empty one.
                if (address.isEmpty()) {
                    consumeOptional(':');
                } else {
                    consume(':');
                }
                consumeCellValue(address);
                expectMore = consumeOptional(',');
            }

            boolean closed = consumeOptional('}');
            if (!closed) {
                throw new IllegalArgumentException("A mapped tensor string must end by '}'");
            }
            skipSpace();
            if (position < string.length()) {
                throw new IllegalArgumentException("Garbage after mapped tensor string: " + string.substring(position));
            }
        }

        /** Consumes the next number and stores it in the builder at the given address. */
        private void consumeCellValue(TensorAddress address) {
            consumeNumber(builder.type().valueType(),
                          f -> builder.cell(address, f.floatValue()),
                          d -> builder.cell(address, (double)d));
        }

        /**
         * Consumes one "{dimension:label, ...}" address.
         *
         * @return the address read, or the address of a fresh builder if the next character is
         *         not '{'
         */
        private TensorAddress consumeLabels() {
            TensorAddress.Builder address = new TensorAddress.Builder(builder.type());
            if (consumeOptional('{')) {
                while (!consumeOptional('}')) {
                    String dimensionName = consumeIdentifier();
                    consume(':');
                    String label = consumeLabel();
                    address.add(dimensionName, label);
                    consumeOptional(',');
                }
            }
            return address.build();
        }

        /**
         * Parses a dense subspace starting at the current position and advances past it.
         * No call to this method is visible in this file.
         */
        private void parseDenseSubspace(TensorAddress mappedAddress, List<String> denseDimensionOrder) {
            IndexedTensor.DirectIndexBuilder subspaceBuilder = ((MixedTensor.BoundBuilder)builder).denseSubspaceBuilder(mappedAddress);
            String remainder = string.substring(position);
            DenseValueParser subspaceParser = new DenseValueParser(remainder, denseDimensionOrder, subspaceBuilder);
            subspaceParser.parse();
            // The sub-parser counts from the start of the remainder, so its position is an offset.
            position += subspaceParser.position();
        }
    }

    /**
     * Parses a "{label:value, ...}" value string for a type with a single indexed dimension
     * that has no size, using each label as the address in that dimension.
     */
    private static class SingleUnboundParser extends ValueParser {
        private final Tensor.Builder builder;

        public SingleUnboundParser(String string, Tensor.Builder builder) {
            super(string);
            this.builder = builder;
        }

        /** Parses the whole string, adding one cell to the builder per "label:value" entry. */
        private void parse() {
            TensorType tensorType = builder.type();
            String dimensionName = tensorType.dimensions().get(0).name();
            skipSpace();
            consume('{');
            skipSpace();
            while (position + 1 < string.length()) {
                String label = consumeLabel();
                consume(':');
                TensorAddress address = new TensorAddress.Builder(tensorType).add(dimensionName, label).build();
                consumeNumber(address);
                // Each entry must be followed by either ',' or the closing '}'.
                boolean moreEntries = consumeOptional(',');
                if (!moreEntries) {
                    consume('}');
                }
                skipSpace();
            }
        }

        /**
         * @return true if the type has rank 1 and its only dimension is indexed and has no size
         */
        static boolean canHandle(TensorType type) {
            if (type.rank() == 1) {
                TensorType.Dimension only = type.dimensions().get(0);
                return only.isIndexed() && only.size().isEmpty();
            }
            return false;
        }

        /** Consumes the next number and stores it in the builder at the given address. */
        private void consumeNumber(TensorAddress address) {
            consumeNumber(builder.type().valueType(),
                          f -> builder.cell(address, f.floatValue()),
                          d -> builder.cell(address, (double)d));
        }
    }

    /**
     * Parses the mixed form for types with at least one mapped dimension: nested
     * "{label:{label:...}}" levels, one per mapped dimension, ending in either a number or a
     * dense subspace.
     */
    private static class GenericMixedValueParser extends ValueParser {
        private final Tensor.Builder builder;
        private final TensorType type;
        private final List<TensorType.Dimension> mappedDimensions;
        private final TensorType mappedSubtype;
        private final List<String> denseDimensionOrder;

        /**
         * @param string the mixed-form value string
         * @param dimensionOrder the dimension names in order, or null to use the order of the
         *        builder's type
         * @param builder the builder receiving the parsed cells
         */
        public GenericMixedValueParser(String string, List<String> dimensionOrder, Tensor.Builder builder) {
            super(string);
            this.builder = builder;
            this.type = builder.type();
            List<String> orderedNames = findOrder(dimensionOrder, type);
            this.mappedDimensions = findMapped(orderedNames, type);
            this.mappedSubtype = MixedTensor.createPartialType(type.valueType(), mappedDimensions);
            // The dense order is the full order with every mapped dimension name taken out.
            List<String> denseNames = new ArrayList<String>(orderedNames);
            for (TensorType.Dimension mappedDimension : mappedDimensions) {
                denseNames.remove(mappedDimension.name());
            }
            this.denseDimensionOrder = denseNames;
        }

        /** Returns the given order, or the dimension names of the type when the order is null. */
        private static final List<String> findOrder(List<String> dimensionOrder, TensorType type) {
            if (dimensionOrder != null) {
                return dimensionOrder;
            }
            return type.dimensions().stream().map(d -> d.name()).toList();
        }

        /**
         * Returns the mapped dimensions of the type, in the given name order.
         *
         * @throws IllegalArgumentException if a name is not a dimension of the type
         */
        private static final List<TensorType.Dimension> findMapped(List<String> dimensionOrder, TensorType type) {
            List<TensorType.Dimension> mapped = new ArrayList<TensorType.Dimension>();
            for (String name : dimensionOrder) {
                TensorType.Dimension dimension = type.dimension(name)
                        .orElseThrow(() -> new IllegalArgumentException("bad dimension " + name));
                if (dimension.isMapped()) {
                    mapped.add(dimension);
                }
            }
            return mapped;
        }

        /** Parses the whole string: '{', comma-separated top-level subspaces, '}'. */
        private void parse() {
            consume('{');
            skipSpace();
            boolean expectMore = true;
            while (expectMore && position + 1 < string.length()) {
                TensorAddress.Builder address = new TensorAddress.Builder(mappedSubtype);
                parseSubspace(address, 0);
                expectMore = consumeOptional(',');
            }
            consume('}');
        }

        /**
         * Parses one "label:..." entry at the given nesting level, recursing into nested braces.
         *
         * @param addrBuilder the address builder shared by all levels of this entry
         * @param level the index into the mapped dimensions that this label belongs to
         * @throws IllegalArgumentException if the nesting is deeper or shallower than the number
         *         of mapped dimensions
         */
        private void parseSubspace(TensorAddress.Builder addrBuilder, int level) {
            if (level >= mappedDimensions.size()) {
                throw new IllegalArgumentException("Too many nested {label:...} levels");
            }
            String label = consumeLabel();
            addrBuilder.add(mappedDimensions.get(level).name(), label);
            consume(':');

            int nextLevel = level + 1;
            if (consumeOptional('{')) {
                // Nested entries: at least one, separated by commas, closed by '}'.
                do {
                    parseSubspace(addrBuilder, nextLevel);
                } while (consumeOptional(','));
                consume('}');
                return;
            }

            if (nextLevel < mappedDimensions.size()) {
                throw new IllegalArgumentException("Not enough nested {label:...} levels");
            }
            TensorAddress mappedAddress = addrBuilder.build();
            // A rank above the number of labels read means a dense subspace follows.
            if (builder.type().rank() > nextLevel) {
                parseDenseSubspace(mappedAddress, denseDimensionOrder);
            } else {
                consumeNumber(mappedAddress);
            }
        }

        /** Parses a dense subspace starting at the current position and advances past it. */
        private void parseDenseSubspace(TensorAddress mappedAddress, List<String> denseDimensionOrder) {
            IndexedTensor.DirectIndexBuilder subspaceBuilder = ((MixedTensor.BoundBuilder)builder).denseSubspaceBuilder(mappedAddress);
            String remainder = string.substring(position);
            DenseValueParser subspaceParser = new DenseValueParser(remainder, denseDimensionOrder, subspaceBuilder);
            subspaceParser.parse();
            // The sub-parser counts from the start of the remainder, so its position is an offset.
            position += subspaceParser.position();
        }

        /** Consumes the next number and stores it in the builder at the given address. */
        private void consumeNumber(TensorAddress address) {
            consumeNumber(builder.type().valueType(),
                          f -> builder.cell(address, f.floatValue()),
                          d -> builder.cell(address, (double)d));
        }
    }

    /**
     * Parses nested "[...]" lists into an indexed tensor builder, tracking the current index in
     * each dimension while descending.
     */
    private static class UnboundDenseValueParser extends ValueParser {
        private final IndexedTensor.Builder builder;
        /** The current index in each dimension; updated in place during parsing. */
        private final long[] indexes;

        public UnboundDenseValueParser(String string, IndexedTensor.Builder builder) {
            super(string);
            this.builder = builder;
            this.indexes = new long[builder.type().dimensions().size()];
        }

        /** Parses the whole string, starting with the list for dimension 0. */
        public void parse() {
            consumeList(0);
        }

        /**
         * Consumes one bracketed list for the given dimension. Elements are numbers in the
         * innermost dimension and nested lists otherwise; commas between them are optional.
         */
        private void consumeList(int dimension) {
            consume('[');
            indexes[dimension] = 0L;
            while (!atListEnd()) {
                if (!isInnerMostDimension(dimension)) {
                    consumeList(dimension + 1);
                } else {
                    consumeNumber();
                }
                indexes[dimension] += 1L;
                consumeOptional(',');
            }
            consume(']');
        }

        /** Consumes the next number and stores it in the builder at the current indexes. */
        private void consumeNumber() {
            consumeNumber(builder.type().valueType(),
                          f -> builder.cell(f.floatValue(), indexes),
                          d -> builder.cell((double)d, indexes));
        }

        private boolean isInnerMostDimension(int dimension) {
            return dimension + 1 == indexes.length;
        }

        /**
         * Skips whitespace and reports whether the next character closes the current list.
         *
         * @throws IllegalArgumentException if the string ends before a ']' is found
         */
        protected boolean atListEnd() {
            skipSpace();
            if (position < string.length()) {
                return string.charAt(position) == ']';
            }
            throw new IllegalArgumentException("At value position " + position + ": Expected a ']' but got the end of the string");
        }
    }

    /**
     * Parses a dense value into a direct-index builder. The value is either a hex string that
     * encodes all cells, or a bracketed list of numbers with or without nested brackets.
     *
     * <p>The string may extend beyond the dense value (as when parsing a subspace of a mixed
     * tensor); {@link #position()} tells how much of it was consumed.
     */
    private static class DenseValueParser extends ValueParser {
        private final IndexedTensor.DirectIndexBuilder builder;
        private final IndexedTensor.Indexes indexes;
        /** Whether the value nests brackets inside its outermost bracket. */
        private final boolean hasInnerStructure;

        /**
         * @param string the string starting with the dense value
         * @param dimensionOrder the dimension order passed on to {@code IndexedTensor.Indexes.of}
         * @param builder the builder receiving the parsed cells
         */
        public DenseValueParser(String string, List<String> dimensionOrder, IndexedTensor.DirectIndexBuilder builder) {
            super(string);
            this.builder = builder;
            this.indexes = IndexedTensor.Indexes.of(builder.type(), dimensionOrder);
            this.hasInnerStructure = DenseValueParser.hasInnerStructure(string);
        }

        /**
         * Parses the dense value at the start of the string.
         *
         * @throws IllegalArgumentException if the value is malformed, including when the string
         *         is empty or contains only whitespace
         */
        public void parse() {
            skipSpace();
            // The bounds check keeps an empty or all-whitespace string out of charAt, so it
            // reaches consume('[') below and fails with the regular IllegalArgumentException.
            if (position < string.length() && string.charAt(position) != '[') {
                int stopPos = stopCharIndex(position);
                String hexToken = string.substring(position, stopPos);
                if (TensorParser.validHexString(builder.type(), hexToken)) {
                    double[] values = JsonFormat.decodeHexString(hexToken, builder.type().valueType());
                    int i = 0;
                    while (indexes.hasNext()) {
                        indexes.next();
                        builder.cellByDirectIndex(indexes.toSourceValueIndex(), values[i++]);
                    }
                    if (i != values.length) {
                        throw new IllegalStateException("consume " + i + " values out of " + values.length);
                    }
                    position = stopPos;
                    return;
                }
            }

            // A flat list has one pair of brackets around all numbers; a nested list has its
            // brackets consumed per dimension inside the loop instead.
            if (!hasInnerStructure) {
                consume('[');
            }
            while (indexes.hasNext()) {
                indexes.next();
                for (int i = 0; i < indexes.nextDimensionsAtStart() && hasInnerStructure; ++i) {
                    consume('[');
                }
                consumeNumber();
                for (int i = 0; i < indexes.nextDimensionsAtEnd() && hasInnerStructure; ++i) {
                    consume(']');
                }
                if (indexes.hasNext()) {
                    consume(',');
                }
            }
            if (!hasInnerStructure) {
                consume(']');
            }
        }

        /** Returns the current parse position within the string given to the constructor. */
        public int position() {
            return position;
        }

        /**
         * Tells whether a '[' occurs after the first character and before the first ']' that
         * follows the first character.
         *
         * @return false for an empty or all-whitespace string
         */
        private static boolean hasInnerStructure(String valueString) {
            valueString = valueString.trim();
            // substring(1) on an empty string throws StringIndexOutOfBoundsException; an empty
            // value has no structure and is rejected later by parse().
            if (valueString.isEmpty()) {
                return false;
            }
            valueString = valueString.substring(1);
            int firstLeftBracket = valueString.indexOf('[');
            return firstLeftBracket >= 0 && firstLeftBracket < valueString.indexOf(']');
        }

        /** Consumes the next number and stores it in the builder at the current source index. */
        protected void consumeNumber() {
            consumeNumber(builder.type().valueType(),
                          f -> builder.cellByDirectIndex(indexes.toSourceValueIndex(), f.floatValue()),
                          d -> builder.cellByDirectIndex(indexes.toSourceValueIndex(), (double)d));
        }
    }

    /**
     * Base class for the value parsers: holds the string being parsed and the current position
     * in it, and provides the shared token-level operations.
     */
    private static abstract class ValueParser {
        protected final String string;
        protected int position = 0;

        protected ValueParser(String string) {
            this.string = string;
        }

        /** Advances the position past any whitespace. */
        protected void skipSpace() {
            int cursor = position;
            int end = string.length();
            while (cursor < end && Character.isWhitespace(string.charAt(cursor))) {
                cursor++;
            }
            position = cursor;
        }

        /**
         * Skips whitespace and consumes the given character.
         *
         * @throws IllegalArgumentException if the string has ended or the next character differs
         */
        protected void consume(char character) {
            skipSpace();
            String prefix = "At value position " + position + ": Expected a '" + character + "' but got ";
            if (position >= string.length()) {
                throw new IllegalArgumentException(prefix + "the end of the string");
            }
            char actual = string.charAt(position);
            if (actual != character) {
                throw new IllegalArgumentException(prefix + "'" + actual + "'");
            }
            position++;
        }

        /**
         * Consumes everything from the current position up to the next stop character.
         *
         * @throws IllegalArgumentException if there is no stop character before the end
         */
        protected String consumeIdentifier() {
            int end = requiredNextStopCharIndex(position);
            String identifier = string.substring(position, end);
            position = end;
            return identifier;
        }

        /**
         * Consumes a label: text in single quotes, text in double quotes, or an unquoted
         * identifier. Quotes are not part of the returned label.
         *
         * @throws IllegalArgumentException if a quoted label is not terminated
         */
        protected String consumeLabel() {
            if (consumeOptional('\'')) {
                return consumeUntilQuote('\'', ": A label quoted by a tick (') must end by another tick");
            }
            if (consumeOptional('\"')) {
                return consumeUntilQuote('\"', ": A label quoted by a double quote (\") must end by another double quote");
            }
            return consumeIdentifier();
        }

        /**
         * Consumes up to and including the next occurrence of the quote character and returns
         * the text before it. The opening quote must already have been consumed.
         */
        private String consumeUntilQuote(char quote, String unterminatedMessage) {
            int closing = string.indexOf(quote, position);
            if (closing < 0) {
                throw new IllegalArgumentException("At value position " + position + unterminatedMessage);
            }
            String label = string.substring(position, closing);
            position = closing + 1;
            return label;
        }

        /**
         * Skips whitespace, consumes the text up to the next stop character, and passes it as a
         * number to the consumer matching the cell value type: the double consumer for DOUBLE,
         * the float consumer for FLOAT, BFLOAT16 and INT8. The position ends up at the stop
         * character whether or not parsing succeeds.
         *
         * @throws IllegalArgumentException if there is no stop character, the text is not a
         *         valid number, or the value type is not one of those listed
         */
        protected void consumeNumber(TensorType.Value cellValueType, Consumer<Float> consumeFloat, Consumer<Double> consumeDouble) {
            skipSpace();
            int numberEnd = requiredNextStopCharIndex(position);
            String numberText = string.substring(position, numberEnd);
            try {
                switch (cellValueType) {
                    case DOUBLE:
                        consumeDouble.accept(Double.parseDouble(numberText));
                        break;
                    case FLOAT:
                    case BFLOAT16:
                    case INT8:
                        consumeFloat.accept(Float.valueOf(Float.parseFloat(numberText)));
                        break;
                    default:
                        throw new IllegalArgumentException(cellValueType + " is not supported");
                }
            }
            catch (NumberFormatException e) {
                // The position still points at the start of the number here; it moves in finally.
                throw new IllegalArgumentException("At value position " + position + ": '" + numberText + "' is not a valid " + cellValueType);
            }
            finally {
                position = numberEnd;
            }
        }

        /**
         * Skips whitespace and consumes the given character if it is next.
         *
         * @return true if the character was consumed
         */
        protected boolean consumeOptional(char character) {
            skipSpace();
            boolean matches = position < string.length() && string.charAt(position) == character;
            if (matches) {
                position++;
            }
            return matches;
        }

        /**
         * Returns the index of the first stop character at or after pos: whitespace, ',', ']',
         * '}' or ':'. Returns the scan position reached if there is none.
         */
        protected int stopCharIndex(int pos) {
            int index = pos;
            int end = string.length();
            while (index < end && !isStopChar(string.charAt(index))) {
                index++;
            }
            return index;
        }

        private static boolean isStopChar(char ch) {
            return Character.isWhitespace(ch) || ch == ',' || ch == ']' || ch == '}' || ch == ':';
        }

        /**
         * Like {@link #stopCharIndex(int)}, but fails if the scan reaches the end of the string.
         *
         * @throws IllegalArgumentException if no stop character is found
         */
        protected int requiredNextStopCharIndex(int pos) {
            int stop = stopCharIndex(pos);
            if (stop == string.length()) {
                throw new IllegalArgumentException("Malformed tensor string '" + string + "': Expected a ',', ']' or '}', ':' after position " + stop);
            }
            return stop;
        }
    }
}

