/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */
package software.amazon.smithy.model.validation;

import java.math.BigDecimal;
import java.math.BigInteger;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Base64;
import java.util.Collection;
import java.util.Collections;
import java.util.EnumSet;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import software.amazon.smithy.model.Model;
import software.amazon.smithy.model.SourceLocation;
import software.amazon.smithy.model.knowledge.NullableIndex;
import software.amazon.smithy.model.node.Node;
import software.amazon.smithy.model.node.NodeType;
import software.amazon.smithy.model.node.StringNode;
import software.amazon.smithy.model.shapes.BigDecimalShape;
import software.amazon.smithy.model.shapes.BigIntegerShape;
import software.amazon.smithy.model.shapes.BlobShape;
import software.amazon.smithy.model.shapes.BooleanShape;
import software.amazon.smithy.model.shapes.ByteShape;
import software.amazon.smithy.model.shapes.DocumentShape;
import software.amazon.smithy.model.shapes.DoubleShape;
import software.amazon.smithy.model.shapes.FloatShape;
import software.amazon.smithy.model.shapes.IntegerShape;
import software.amazon.smithy.model.shapes.ListShape;
import software.amazon.smithy.model.shapes.LongShape;
import software.amazon.smithy.model.shapes.MapShape;
import software.amazon.smithy.model.shapes.MemberShape;
import software.amazon.smithy.model.shapes.OperationShape;
import software.amazon.smithy.model.shapes.ResourceShape;
import software.amazon.smithy.model.shapes.ServiceShape;
import software.amazon.smithy.model.shapes.Shape;
import software.amazon.smithy.model.shapes.ShapeId;
import software.amazon.smithy.model.shapes.ShapeVisitor;
import software.amazon.smithy.model.shapes.ShortShape;
import software.amazon.smithy.model.shapes.StringShape;
import software.amazon.smithy.model.shapes.StructureShape;
import software.amazon.smithy.model.shapes.TimestampShape;
import software.amazon.smithy.model.shapes.UnionShape;
import software.amazon.smithy.model.validation.node.NodeValidatorPlugin;
import software.amazon.smithy.model.validation.node.TimestampValidationStrategy;
import software.amazon.smithy.utils.ListUtils;
import software.amazon.smithy.utils.SmithyBuilder;

/**
 * Validates {@link Node} values provided for {@link Shape} definitions.
 *
 * <p>This visitor validator is used to ensure that values provided for custom
 * traits and examples are correct for their schema definitions. A map of
 * shape types to a list of additional validators can be provided to perform
 * additional, non-standard, validation of these values. For example, this can
 * be used to provide additional validation needed for custom traits that are
 * applied to the shape of the data.
 */
public final class NodeValidationVisitor implements ShapeVisitor<List<ValidationEvent>> {

    // Built-in validator plugins, obtained once at class initialization and applied by applyPlugins.
    private static final List<NodeValidatorPlugin> BUILTIN = NodeValidatorPlugin.getBuiltins();

    // Model used to resolve member targets and member containers.
    private final Model model;
    // Strategy applied to every shape/value pair before the built-in plugins run.
    private final TimestampValidationStrategy timestampValidationStrategy;
    // ID given to emitted events; setEventId substitutes Validator.MODEL_ERROR for null.
    private String eventId;
    // Node currently being validated; setValue rejects null.
    private Node value;
    // Shape ID attached to every emitted event; may be null.
    private ShapeId eventShapeId;
    // Prefix prepended to event messages; setStartingContext stores "" instead of null.
    private String startingContext;
    // Plugin context carrying the feature flags and the current referring member. Not final because
    // traverse() replaces the context of a child visitor with the parent's instance.
    private NodeValidatorPlugin.Context validationContext;
    // Used to decide whether a member is allowed to hold a null value.
    private final NullableIndex nullableIndex;

    /**
     * Creates a visitor from the state collected by a {@link Builder}.
     *
     * <p>The model and value are checked with {@link SmithyBuilder#requiredState}; the remaining
     * settings go through the corresponding setters so that their null defaults are applied.
     *
     * @param builder Builder holding the configuration.
     */
    private NodeValidationVisitor(Builder builder) {
        this.model = SmithyBuilder.requiredState("model", builder.model);
        this.nullableIndex = NullableIndex.of(model);
        this.validationContext = new NodeValidatorPlugin.Context(model, Feature.enumSet(builder.features));
        this.timestampValidationStrategy = builder.timestampValidationStrategy;
        // The value is assigned exactly once, after requiredState has checked it.
        setValue(SmithyBuilder.requiredState("value", builder.value));
        setStartingContext(builder.contextText);
        setEventShapeId(builder.eventShapeId);
        setEventId(builder.eventId);
    }

    /**
     * Feature flags that adjust how values are validated.
     */
    public enum Feature {
        /**
         * Emit a warning when a range trait is incompatible with a default value of 0.
         *
         * <p>This pattern was common in Smithy 1.0 and earlier and implies that the value is effectively
         * required. However, changing the type of the value by un-boxing it or adjusting the range trait
         * would be a lossy transformation when migrating a model from 1.0 to 2.0.
         */
        RANGE_TRAIT_ZERO_VALUE_WARNING,

        /**
         * Lowers the severity of constraint trait validations to WARNING.
         */
        ALLOW_CONSTRAINT_ERRORS,

        /**
         * Allows null values to be provided for an optional structure member.
         *
         * <p>By default, null values are not allowed for optional types.
         */
        ALLOW_OPTIONAL_NULLS,

        /**
         * Requires that blob values are validly encoded base64 strings.
         *
         * <p>By default, blob values that are not valid base64 encoded strings are allowed.
         */
        REQUIRE_BASE_64_BLOB_VALUES;

        /**
         * Resolves a feature from a string node holding the constant's name.
         *
         * @param node Node expected to be a string node.
         * @return Returns the feature whose name equals the node's string value.
         */
        public static Feature fromNode(Node node) {
            String featureName = node.expectStringNode().getValue();
            return Feature.valueOf(featureName);
        }

        /**
         * Converts a feature to a node holding its string form.
         *
         * @param feature Feature to convert.
         * @return Returns a node created from {@code feature.toString()}.
         */
        public static Node toNode(Feature feature) {
            String text = feature.toString();
            return StringNode.from(text);
        }

        /**
         * Copies a collection of features into an {@link EnumSet}.
         *
         * @param features Features to copy; may be empty.
         * @return Returns a new set containing the given features.
         */
        private static EnumSet<Feature> enumSet(Collection<Feature> features) {
            // Start from an empty set so an empty input needs no special case.
            EnumSet<Feature> result = EnumSet.noneOf(Feature.class);
            result.addAll(features);
            return result;
        }
    }

    /**
     * Creates a new, empty {@link Builder} for a {@link NodeValidationVisitor}.
     *
     * @return Returns the created builder.
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Replaces the Node value that this visitor evaluates.
     *
     * @param value Value to validate from now on; must not be null.
     */
    public void setValue(Node value) {
        Objects.requireNonNull(value);
        this.value = value;
    }

    /**
     * Changes the shape ID that emitted events are associated with.
     *
     * <p>The value is stored as given (no null check) and passed to each created event.
     *
     * @param eventShapeId Shape ID to set.
     */
    public void setEventShapeId(ShapeId eventShapeId) {
        this.eventShapeId = eventShapeId;
    }

    /**
     * Replaces the context text that prefixes the message of every emitted event.
     *
     * @param startingContext Context text to use; {@code null} is stored as an empty string.
     */
    public void setStartingContext(String startingContext) {
        if (startingContext == null) {
            this.startingContext = "";
        } else {
            this.startingContext = startingContext;
        }
    }

    /**
     * Replaces the event ID used for events created by this validator.
     *
     * @param eventId Event ID to use; {@code null} falls back to {@link Validator#MODEL_ERROR}.
     */
    public void setEventId(String eventId) {
        if (eventId == null) {
            this.eventId = Validator.MODEL_ERROR;
        } else {
            this.eventId = eventId;
        }
    }

    /**
     * Creates a child visitor for a nested value.
     *
     * <p>The child inherits this visitor's event ID, event shape ID, model and timestamp strategy, and
     * its context text is this visitor's context extended with {@code segment} (joined with a dot when
     * a context already exists).
     *
     * @param segment Path segment that identifies the nested value.
     * @param node Nested value the child visitor validates.
     * @return Returns the child visitor.
     */
    private NodeValidationVisitor traverse(String segment, Node node) {
        NodeValidationVisitor child = builder()
                .eventShapeId(eventShapeId)
                .eventId(eventId)
                .value(node)
                .model(model)
                .startingContext(startingContext.isEmpty() ? segment : (startingContext + "." + segment))
                .timestampValidationStrategy(timestampValidationStrategy)
                .build();
        // Parent and child share one validation context instance.
        child.validationContext = this.validationContext;
        return child;
    }

    /**
     * Validates a blob value, which must be provided as a string node.
     *
     * <p>When {@link Feature#REQUIRE_BASE_64_BLOB_VALUES} is enabled, the string must also decode as
     * base64; a failed decode produces a single event and the plugins are not run.
     */
    @Override
    public List<ValidationEvent> blobShape(BlobShape shape) {
        if (!value.isStringNode()) {
            return invalidShape(shape, NodeType.STRING);
        }

        if (validationContext.hasFeature(Feature.REQUIRE_BASE_64_BLOB_VALUES)) {
            String text = value.expectStringNode().getValue();
            byte[] encoded = text.getBytes(StandardCharsets.UTF_8);
            try {
                // Only the success or failure of decoding matters; the decoded bytes are discarded.
                Base64.getDecoder().decode(encoded);
            } catch (IllegalArgumentException e) {
                return ListUtils.of(event("Blob value must be a valid base64 string"));
            }
        }

        return applyPlugins(shape);
    }

    /**
     * Validates a boolean value: boolean nodes go to the plugins, anything else is a type mismatch.
     */
    @Override
    public List<ValidationEvent> booleanShape(BooleanShape shape) {
        if (value.isBooleanNode()) {
            return applyPlugins(shape);
        }
        return invalidShape(shape, NodeType.BOOLEAN);
    }

    /**
     * Validates a byte value as a whole number within the range of {@code byte}.
     */
    @Override
    public List<ValidationEvent> byteShape(ByteShape shape) {
        long lowest = Byte.MIN_VALUE;
        long highest = Byte.MAX_VALUE;
        return validateNaturalNumber(shape, lowest, highest);
    }

    /**
     * Validates a short value as a whole number within the range of {@code short}.
     */
    @Override
    public List<ValidationEvent> shortShape(ShortShape shape) {
        long lowest = Short.MIN_VALUE;
        long highest = Short.MAX_VALUE;
        return validateNaturalNumber(shape, lowest, highest);
    }

    /**
     * Validates an integer value as a whole number within the range of {@code int}.
     */
    @Override
    public List<ValidationEvent> integerShape(IntegerShape shape) {
        long lowest = Integer.MIN_VALUE;
        long highest = Integer.MAX_VALUE;
        return validateNaturalNumber(shape, lowest, highest);
    }

    /**
     * Validates a long value as a whole number bounded by {@link Long#MIN_VALUE} and
     * {@link Long#MAX_VALUE}.
     */
    @Override
    public List<ValidationEvent> longShape(LongShape shape) {
        Long lowest = Long.MIN_VALUE;
        Long highest = Long.MAX_VALUE;
        return validateNaturalNumber(shape, lowest, highest);
    }

    /**
     * Validates a bigInteger value as a whole number with no lower or upper bound.
     */
    @Override
    public List<ValidationEvent> bigIntegerShape(BigIntegerShape shape) {
        Long unbounded = null;
        return validateNaturalNumber(shape, unbounded, unbounded);
    }

    /**
     * Validates that the current value is a whole number within optional inclusive bounds.
     *
     * <p>Non-number values are reported as a type mismatch and floating point values are rejected.
     * The bounds comparison uses arbitrary precision so that values outside the range of
     * {@code long} are reported rather than silently wrapped by {@link Number#longValue()}.
     *
     * @param shape Shape the value is validated against.
     * @param min Smallest allowed value, or {@code null} for no lower bound.
     * @param max Largest allowed value, or {@code null} for no upper bound.
     * @return Returns the events from the failed check, or the plugin events when all checks pass.
     */
    private List<ValidationEvent> validateNaturalNumber(Shape shape, Long min, Long max) {
        return value.asNumberNode()
                .map(number -> {
                    if (number.isFloatingPointNumber()) {
                        return ListUtils.of(event(String.format(
                                "%s shapes must not have floating point values, but found `%s` provided for `%s`",
                                shape.getType(),
                                number.getValue(),
                                shape.getId())));
                    }

                    BigInteger numberValue = toBigInteger(number.getValue());
                    if (min != null && numberValue.compareTo(BigInteger.valueOf(min)) < 0) {
                        return ListUtils.of(event(String.format(
                                "%s value must be > %d, but found %d",
                                shape.getType(),
                                min,
                                numberValue)));
                    } else if (max != null && numberValue.compareTo(BigInteger.valueOf(max)) > 0) {
                        return ListUtils.of(event(String.format(
                                "%s value must be < %d, but found %d",
                                shape.getType(),
                                max,
                                numberValue)));
                    } else {
                        return applyPlugins(shape);
                    }
                })
                .orElseGet(() -> invalidShape(shape, NodeType.NUMBER));
    }

    /**
     * Converts a whole number to a {@link BigInteger} without losing magnitude.
     *
     * @param number Number to convert.
     * @return Returns the number as a BigInteger.
     */
    private static BigInteger toBigInteger(Number number) {
        if (number instanceof BigInteger) {
            return (BigInteger) number;
        }
        if (number instanceof BigDecimal) {
            return ((BigDecimal) number).toBigInteger();
        }
        // Any other Number type falls back to longValue(), as the bounds check did previously.
        return BigInteger.valueOf(number.longValue());
    }

    /**
     * Validates a float value.
     *
     * <p>Both number and string nodes are handed to the plugins; any other node type is reported as a
     * mismatch against NUMBER. (String nodes are presumably accepted so non-numeric values such as
     * NaN can be expressed; confirm against the plugins.)
     */
    @Override
    public List<ValidationEvent> floatShape(FloatShape shape) {
        if (value.isNumberNode() || value.isStringNode()) {
            return applyPlugins(shape);
        }
        return invalidShape(shape, NodeType.NUMBER);
    }

    /**
     * Accepts any value for a document shape; no plugins are run and no events are produced.
     */
    @Override
    public List<ValidationEvent> documentShape(DocumentShape shape) {
        // Document values are always valid.
        return Collections.emptyList();
    }

    /**
     * Validates a double value.
     *
     * <p>Both number and string nodes are handed to the plugins; any other node type is reported as a
     * mismatch against NUMBER. (String nodes are presumably accepted so non-numeric values such as
     * NaN can be expressed; confirm against the plugins.)
     */
    @Override
    public List<ValidationEvent> doubleShape(DoubleShape shape) {
        if (value.isNumberNode() || value.isStringNode()) {
            return applyPlugins(shape);
        }
        return invalidShape(shape, NodeType.NUMBER);
    }

    /**
     * Validates a bigDecimal value: number nodes go to the plugins, anything else is a type mismatch.
     */
    @Override
    public List<ValidationEvent> bigDecimalShape(BigDecimalShape shape) {
        if (value.isNumberNode()) {
            return applyPlugins(shape);
        }
        return invalidShape(shape, NodeType.NUMBER);
    }

    /**
     * Validates a string value: string nodes go to the plugins, anything else is a type mismatch.
     */
    @Override
    public List<ValidationEvent> stringShape(StringShape shape) {
        if (value.isStringNode()) {
            return applyPlugins(shape);
        }
        return invalidShape(shape, NodeType.STRING);
    }

    /**
     * Validates a timestamp value.
     *
     * <p>No node type check is made here; the value is passed straight to {@code applyPlugins}, which
     * runs the configured timestamp validation strategy followed by the built-in plugins.
     */
    @Override
    public List<ValidationEvent> timestampShape(TimestampShape shape) {
        return applyPlugins(shape);
    }

    /**
     * Validates a list value: the array itself goes through the plugins, then every element is
     * validated against the list's member.
     */
    @Override
    public List<ValidationEvent> listShape(ListShape shape) {
        return value.asArrayNode()
                .map(array -> {
                    List<ValidationEvent> events = applyPlugins(shape);
                    MemberShape member = shape.getMember();
                    // The element position becomes the context segment (e.g., "foo.0.baz", "foo.1.baz").
                    int position = 0;
                    for (Node element : array.getElements()) {
                        events.addAll(member.accept(traverse(String.valueOf(position), element)));
                        position++;
                    }
                    return events;
                })
                .orElseGet(() -> invalidShape(shape, NodeType.ARRAY));
    }

    /**
     * Validates a map value: the object itself goes through the plugins, then each entry's key and
     * value are validated against the map's key and value members.
     */
    @Override
    public List<ValidationEvent> mapShape(MapShape shape) {
        return value.asObjectNode()
                .map(object -> {
                    List<ValidationEvent> events = applyPlugins(shape);
                    for (Map.Entry<StringNode, Node> pair : object.getMembers().entrySet()) {
                        StringNode keyNode = pair.getKey();
                        String keyText = keyNode.getValue();
                        // Keys get a distinct context so their events can be told apart from value events.
                        NodeValidationVisitor keyVisitor = traverse(keyText + " (map-key)", keyNode);
                        events.addAll(keyVisitor.memberShape(shape.getKey()));
                        NodeValidationVisitor valueVisitor = traverse(keyText, pair.getValue());
                        events.addAll(valueVisitor.memberShape(shape.getValue()));
                    }
                    return events;
                })
                .orElseGet(() -> invalidShape(shape, NodeType.OBJECT));
    }

    /**
     * Validates a structure value.
     *
     * <p>Provided entries that match a member are validated against it; entries that match no member
     * produce a WARNING. Required members missing from the value produce an event whose severity is
     * WARNING when {@link Feature#ALLOW_CONSTRAINT_ERRORS} is enabled and ERROR otherwise.
     */
    @Override
    public List<ValidationEvent> structureShape(StructureShape shape) {
        return value.asObjectNode()
                .map(object -> {
                    List<ValidationEvent> events = applyPlugins(shape);
                    Map<String, MemberShape> members = shape.getAllMembers();

                    // First pass: check everything that was provided.
                    for (Map.Entry<String, Node> provided : object.getStringMap().entrySet()) {
                        String providedName = provided.getKey();
                        if (members.containsKey(providedName)) {
                            MemberShape target = members.get(providedName);
                            events.addAll(traverse(providedName, provided.getValue()).memberShape(target));
                        } else {
                            events.add(unknownMember(providedName, shape, Severity.WARNING));
                        }
                    }

                    // Second pass: report required members that were not provided.
                    for (MemberShape member : members.values()) {
                        if (!member.isRequired() || object.getMember(member.getMemberName()).isPresent()) {
                            continue;
                        }
                        Severity severity = Severity.ERROR;
                        if (this.validationContext.hasFeature(Feature.ALLOW_CONSTRAINT_ERRORS)) {
                            severity = Severity.WARNING;
                        }
                        String message = String.format(
                                "Missing required structure member `%s` for `%s`",
                                member.getMemberName(),
                                shape.getId());
                        events.add(event(message, severity));
                    }
                    return events;
                })
                .orElseGet(() -> invalidShape(shape, NodeType.OBJECT));
    }

    /**
     * Validates a union value.
     *
     * <p>An object with more than one entry produces a single event and its entries are not inspected.
     * Otherwise each entry is validated against the matching member, and an entry that matches no
     * member produces an ERROR. An object with no entries produces no union-specific event.
     */
    @Override
    public List<ValidationEvent> unionShape(UnionShape shape) {
        return value.asObjectNode()
                .map(object -> {
                    List<ValidationEvent> events = applyPlugins(shape);
                    if (object.size() > 1) {
                        events.add(event("union values can contain a value for only a single member"));
                        return events;
                    }

                    Map<String, MemberShape> members = shape.getAllMembers();
                    for (Map.Entry<String, Node> provided : object.getStringMap().entrySet()) {
                        String providedName = provided.getKey();
                        if (members.containsKey(providedName)) {
                            MemberShape target = members.get(providedName);
                            events.addAll(traverse(providedName, provided.getValue()).memberShape(target));
                        } else {
                            events.add(unknownMember(providedName, shape, Severity.ERROR));
                        }
                    }
                    return events;
                })
                .orElseGet(() -> invalidShape(shape, NodeType.OBJECT));
    }

    /**
     * Validates the current value against a member and, when it resolves in the model, the member's
     * target shape.
     *
     * <p>The member itself goes through the plugins first, and a null value is additionally checked
     * with {@link #checkNullMember(MemberShape)}. While the target is visited, the member is recorded
     * on the validation context as the referring member; it is cleared afterwards even if the visit
     * throws, so the shared context is never left pointing at a stale member.
     */
    @Override
    public List<ValidationEvent> memberShape(MemberShape shape) {
        List<ValidationEvent> events = applyPlugins(shape);
        if (value.isNullNode()) {
            events.addAll(checkNullMember(shape));
        }
        model.getShape(shape.getTarget()).ifPresent(target -> {
            // We only need to keep track of a single referring member, so a stack of members or anything like that
            // isn't needed here.
            validationContext.setReferringMember(shape);
            try {
                events.addAll(target.accept(this));
            } finally {
                validationContext.setReferringMember(null);
            }
        });
        return events;
    }

    /**
     * Creates the event reported when a null value is provided for a member that is not nullable.
     *
     * <p>The message depends on whether the member's container is a list, map or structure; members
     * of any other container type, and nullable members, produce no event.
     *
     * @param shape Member that received a null value.
     * @return Returns a list with at most one event.
     */
    public List<ValidationEvent> checkNullMember(MemberShape shape) {
        if (nullableIndex.isMemberNullable(shape)) {
            return ListUtils.of();
        }

        String message;
        switch (model.expectShape(shape.getContainer()).getType()) {
            case LIST:
                message = String.format(
                        "Non-sparse list shape `%s` cannot contain null values",
                        shape.getContainer());
                break;
            case MAP:
                message = String.format(
                        "Non-sparse map shape `%s` cannot contain null values",
                        shape.getContainer());
                break;
            case STRUCTURE:
                message = String.format("Required structure member `%s` for `%s` cannot be null",
                        shape.getMemberName(),
                        shape.getContainer());
                break;
            default:
                return ListUtils.of();
        }
        return ListUtils.of(event(message));
    }

    /**
     * Operation shapes cannot describe a value, so visiting one always yields an invalid-schema event.
     */
    @Override
    public List<ValidationEvent> operationShape(OperationShape shape) {
        return invalidSchema(shape);
    }

    /**
     * Resource shapes cannot describe a value, so visiting one always yields an invalid-schema event.
     */
    @Override
    public List<ValidationEvent> resourceShape(ResourceShape shape) {
        return invalidSchema(shape);
    }

    /**
     * Service shapes cannot describe a value, so visiting one always yields an invalid-schema event.
     */
    @Override
    public List<ValidationEvent> serviceShape(ServiceShape shape) {
        return invalidSchema(shape);
    }

    /**
     * Reports that the current value has the wrong node type for a shape.
     *
     * <p>A null value is accepted without an event when {@link Feature#ALLOW_OPTIONAL_NULLS} is
     * enabled and the shape is either not a member or a member that the nullable index considers
     * nullable. For string, number and boolean values the offending value is appended to the message.
     *
     * @param shape Shape the value was validated against.
     * @param expectedType Node type the shape requires.
     * @return Returns an empty list for an accepted null, otherwise a single mismatch event.
     */
    private List<ValidationEvent> invalidShape(Shape shape, NodeType expectedType) {
        // Non-members are nullable. Members are nullable based on context.
        boolean nullPermitted = value.isNullNode()
                && validationContext.hasFeature(Feature.ALLOW_OPTIONAL_NULLS)
                && (!shape.isMemberShape()
                        || shape.asMemberShape().filter(nullableIndex::isMemberNullable).isPresent());
        if (nullPermitted) {
            return Collections.emptyList();
        }

        String foundValue;
        if (value.isStringNode()) {
            foundValue = ", `" + value.expectStringNode().getValue() + "`";
        } else if (value.isNumberNode()) {
            foundValue = ", `" + value.expectNumberNode().getValue() + "`";
        } else if (value.isBooleanNode()) {
            foundValue = ", `" + value.expectBooleanNode().getValue() + "`";
        } else {
            foundValue = "";
        }

        String mismatch = String.format(
                "Expected %s value for %s shape, `%s`; found %s value",
                expectedType,
                shape.getType(),
                shape.getId(),
                value.getType());
        return ListUtils.of(event(mismatch + foundValue));
    }

    /**
     * Reports that a shape of this type cannot be used to validate a value.
     *
     * @param shape Shape that was visited.
     * @return Returns a single event naming the shape's type.
     */
    private List<ValidationEvent> invalidSchema(Shape shape) {
        String message = "Encountered invalid shape type: " + shape.getType();
        return ListUtils.of(event(message));
    }

    /**
     * Creates the event reported when a value names a member the shape does not define.
     *
     * <p>The event ID is extended with {@code UnknownMember}, the shape ID and the member name.
     *
     * @param memberName Name found in the value.
     * @param shape Shape that lacks the member.
     * @param severity Severity to give the event.
     * @return Returns the created event.
     */
    private ValidationEvent unknownMember(String memberName, Shape shape, Severity severity) {
        String message = String.format("Member `%s` does not exist in `%s`", memberName, shape.getId());
        String shapeIdPart = shape.getId().toString();
        return event(message, severity, "UnknownMember", shapeIdPart, memberName);
    }

    /**
     * Creates an ERROR event located at the current value.
     *
     * @param message Event message, before the context prefix is applied.
     * @param additionalEventIdParts Parts appended to the event ID, separated by dots.
     * @return Returns the created event.
     */
    private ValidationEvent event(String message, String... additionalEventIdParts) {
        return event(message, Severity.ERROR, value.getSourceLocation(), additionalEventIdParts);
    }

    /**
     * Creates an event with the given severity, located at the current value.
     *
     * @param message Event message, before the context prefix is applied.
     * @param severity Severity of the event.
     * @param additionalEventIdParts Parts appended to the event ID, separated by dots.
     * @return Returns the created event.
     */
    private ValidationEvent event(String message, Severity severity, String... additionalEventIdParts) {
        SourceLocation location = value.getSourceLocation();
        return event(message, severity, location, additionalEventIdParts);
    }

    /**
     * Creates an event with an explicit severity and source location.
     *
     * <p>The message is prefixed with the starting context (followed by {@code ": "}) when one is
     * set, and the event carries the configured event shape ID.
     *
     * @param message Event message, before the context prefix is applied.
     * @param severity Severity of the event.
     * @param sourceLocation Location the event points at.
     * @param additionalEventIdParts Parts appended to the event ID, separated by dots.
     * @return Returns the created event.
     */
    private ValidationEvent event(
            String message,
            Severity severity,
            SourceLocation sourceLocation,
            String... additionalEventIdParts
    ) {
        String id = eventId;
        if (additionalEventIdParts.length > 0) {
            id = eventId + "." + String.join(".", additionalEventIdParts);
        }

        String text = message;
        if (!startingContext.isEmpty()) {
            text = startingContext + ": " + message;
        }

        return ValidationEvent.builder()
                .id(id)
                .severity(severity)
                .sourceLocation(sourceLocation)
                .shapeId(eventShapeId)
                .message(text)
                .build();
    }

    /**
     * Runs the timestamp validation strategy and then every built-in plugin against the current value.
     *
     * <p>Each emission from a plugin is converted into an event at the location the plugin reports.
     *
     * @param shape Shape the current value is validated against.
     * @return Returns a new, mutable list of the emitted events (possibly empty).
     */
    private List<ValidationEvent> applyPlugins(Shape shape) {
        List<ValidationEvent> collected = new ArrayList<>();

        timestampValidationStrategy.apply(shape,
                value,
                validationContext,
                (location, severity, message, idParts) -> collected
                        .add(event(message, severity, location.getSourceLocation(), idParts)));

        BUILTIN.forEach(plugin -> plugin.apply(shape,
                value,
                validationContext,
                (location, severity, message, idParts) -> collected
                        .add(event(message, severity, location.getSourceLocation(), idParts))));

        return collected;
    }

    /**
     * Collects the configuration for a {@link NodeValidationVisitor} and creates it.
     */
    public static final class Builder implements SmithyBuilder<NodeValidationVisitor> {
        private final Set<Feature> features = new HashSet<>();
        private TimestampValidationStrategy timestampValidationStrategy = TimestampValidationStrategy.FORMAT;
        private Model model;
        private Node value;
        private ShapeId eventShapeId;
        private String eventId;
        private String contextText;

        Builder() {}

        /**
         * Sets the <strong>required</strong> model used to resolve shapes during validation.
         *
         * @param model Model that contains shapes to validate.
         * @return Returns the builder.
         */
        public Builder model(Model model) {
            this.model = model;
            return this;
        }

        /**
         * Sets the <strong>required</strong> node value to validate.
         *
         * @param value Value to validate; must not be null.
         * @return Returns the builder.
         */
        public Builder value(Node value) {
            Objects.requireNonNull(value);
            this.value = value;
            return this;
        }

        /**
         * Sets an optional custom event ID to use for created validation events.
         *
         * @param id Custom event ID; must not be null.
         * @return Returns the builder.
         */
        public Builder eventId(String id) {
            Objects.requireNonNull(id);
            this.eventId = id;
            return this;
        }

        /**
         * Sets an optional starting context that is prepended to the message of
         * each emitted validation event.
         *
         * @param contextText Starting event message content; must not be null.
         * @return Returns the builder.
         */
        public Builder startingContext(String contextText) {
            Objects.requireNonNull(contextText);
            this.contextText = contextText;
            return this;
        }

        /**
         * Sets an optional shape ID that is used as the shape ID of each
         * validation event emitted by the validator.
         *
         * @param eventShapeId Shape ID to set on every validation event.
         * @return Returns the builder.
         */
        public Builder eventShapeId(ShapeId eventShapeId) {
            this.eventShapeId = eventShapeId;
            return this;
        }

        /**
         * Sets the strategy used to validate timestamps.
         *
         * <p>When this method is not called, timestamps are validated using
         * {@link TimestampValidationStrategy#FORMAT}.
         *
         * @param timestampValidationStrategy Timestamp validation strategy.
         * @return Returns the builder.
         */
        public Builder timestampValidationStrategy(TimestampValidationStrategy timestampValidationStrategy) {
            this.timestampValidationStrategy = timestampValidationStrategy;
            return this;
        }

        /**
         * Toggles {@link Feature#ALLOW_OPTIONAL_NULLS}.
         *
         * @param allowBoxedNull True to add the feature, false to remove it.
         * @return Returns the builder.
         * @deprecated Delegates to {@link #allowOptionalNull(boolean)}; prefer
         *  {@link #addFeature(Feature)} with {@link Feature#ALLOW_OPTIONAL_NULLS}.
         */
        @Deprecated
        public Builder allowBoxedNull(boolean allowBoxedNull) {
            return allowOptionalNull(allowBoxedNull);
        }

        /**
         * Toggles {@link Feature#ALLOW_OPTIONAL_NULLS}.
         *
         * @param allowOptionalNull True to add the feature, false to remove it.
         * @return Returns the builder.
         * @deprecated Prefer {@link #addFeature(Feature)} with {@link Feature#ALLOW_OPTIONAL_NULLS}.
         */
        @Deprecated
        public Builder allowOptionalNull(boolean allowOptionalNull) {
            if (!allowOptionalNull) {
                features.remove(Feature.ALLOW_OPTIONAL_NULLS);
                return this;
            }
            return addFeature(Feature.ALLOW_OPTIONAL_NULLS);
        }

        /**
         * Adds a feature flag to the validator.
         *
         * @param feature Feature to set.
         * @return Returns the builder.
         */
        public Builder addFeature(Feature feature) {
            features.add(feature);
            return this;
        }

        /**
         * Creates the visitor from the collected configuration.
         *
         * @return Returns the created visitor.
         */
        @Override
        public NodeValidationVisitor build() {
            return new NodeValidationVisitor(this);
        }
    }
}
