/*
 * Decompiled with CFR 0.152.
 */
package software.amazon.smithy.aws.apigateway.openapi;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import java.util.TreeSet;
import java.util.logging.Logger;
import software.amazon.smithy.aws.apigateway.openapi.ApiGatewayConfig;
import software.amazon.smithy.aws.apigateway.openapi.ApiGatewayMapper;
import software.amazon.smithy.model.Model;
import software.amazon.smithy.model.knowledge.TopDownIndex;
import software.amazon.smithy.model.node.Node;
import software.amazon.smithy.model.node.ObjectNode;
import software.amazon.smithy.model.node.ToNode;
import software.amazon.smithy.model.shapes.OperationShape;
import software.amazon.smithy.model.shapes.ToShapeId;
import software.amazon.smithy.model.traits.CorsTrait;
import software.amazon.smithy.model.traits.Trait;
import software.amazon.smithy.openapi.fromsmithy.Context;
import software.amazon.smithy.openapi.model.OpenApi;
import software.amazon.smithy.openapi.model.OperationObject;
import software.amazon.smithy.openapi.model.ParameterObject;
import software.amazon.smithy.openapi.model.PathItem;
import software.amazon.smithy.openapi.model.Ref;
import software.amazon.smithy.openapi.model.ResponseObject;
import software.amazon.smithy.utils.ListUtils;
import software.amazon.smithy.utils.SetUtils;
import software.amazon.smithy.utils.SmithyInternalApi;

/**
 * Adds an {@code x-amazon-apigateway-cors} extension to OpenAPI models generated for
 * API Gateway HTTP APIs whose service shape has the {@code cors} trait.
 *
 * <p>An object-valued {@code x-amazon-apigateway-cors} extension that is already present
 * is left untouched. When the service does not use HTTP credentials, {@code *} is used for
 * the allowed methods. It is also used for the allowed and exposed headers, unless the cors
 * trait lists additional headers of that kind. Otherwise, explicit values are collected from
 * the model and the generated OpenAPI document. When HTTP credentials are used,
 * {@code allowCredentials} is set to {@code true}.
 */
@SmithyInternalApi
public final class CorsHttpIntegration implements ApiGatewayMapper {

    private static final Logger LOGGER = Logger.getLogger(CorsHttpIntegration.class.getName());
    private static final String CORS_HTTP_EXTENSION = "x-amazon-apigateway-cors";
    private static final String WILDCARD = "*";

    /**
     * Restricts this mapper to HTTP APIs.
     *
     * @return a list containing only {@link ApiGatewayConfig.ApiType#HTTP}.
     */
    @Override
    public List<ApiGatewayConfig.ApiType> getApiTypes() {
        return ListUtils.of(ApiGatewayConfig.ApiType.HTTP);
    }

    /**
     * Adds the CORS extension when the service has the {@code cors} trait.
     *
     * @param context conversion context providing the service shape and model.
     * @param openapi OpenAPI model produced so far.
     * @return the updated model, or {@code openapi} unchanged when the service has no
     *     cors trait or an object-valued CORS extension is already set.
     */
    @Override
    public OpenApi after(Context<? extends Trait> context, OpenApi openapi) {
        return context.getService()
                .getTrait(CorsTrait.class)
                .map(corsTrait -> addCors(context, openapi, corsTrait))
                .orElse(openapi);
    }

    /**
     * Builds the CORS configuration object and attaches it as an extension.
     */
    private OpenApi addCors(Context<? extends Trait> context, OpenApi openapi, CorsTrait trait) {
        // Never overwrite an object-valued CORS configuration that is already present.
        if (openapi.getExtension(CORS_HTTP_EXTENSION).flatMap(Node::asObjectNode).isPresent()) {
            return openapi;
        }

        // Evaluated in this order so that log output order matches the member order below.
        Set<String> allowedMethods = getMethodsUsedInApi(context, openapi);
        Set<String> allowedHeaders = getAllowedHeaders(context, trait, openapi);
        Set<String> exposedHeaders = getExposedHeaders(context, trait, openapi);

        ObjectNode.Builder cors = Node.objectNodeBuilder()
                .withMember("allowOrigins", Node.fromStrings(trait.getOrigin()))
                .withMember("maxAge", trait.getMaxAge())
                .withMember("allowMethods", Node.fromStrings(allowedMethods))
                .withMember("exposeHeaders", Node.fromStrings(exposedHeaders))
                .withMember("allowHeaders", Node.fromStrings(allowedHeaders));
        if (context.usesHttpCredentials()) {
            cors.withMember("allowCredentials", true);
        }

        return openapi.toBuilder().putExtension(CORS_HTTP_EXTENSION, cors.build()).build();
    }

    /**
     * Decides whether a CORS value can be the wildcard {@code *}, and logs the reason.
     *
     * @param context conversion context, used to check for HTTP credentials.
     * @param hasExplicitValues whether the cors trait explicitly lists values for this header.
     * @param headerName CORS response header name, used only in log messages.
     * @return {@code true} if {@code *} should be used; {@code false} if a value must be generated.
     */
    private static boolean shouldUseWildcard(Context<?> context, boolean hasExplicitValues, String headerName) {
        boolean usesCredentials = context.usesHttpCredentials();
        if (!hasExplicitValues && !usesCredentials) {
            LOGGER.info("Using * for " + headerName + " because the service does not use HTTP credentials");
            return true;
        }
        if (usesCredentials) {
            LOGGER.info("Generating a value for " + headerName + " because the service uses HTTP credentials");
        } else {
            // Previously logged as "uses HTTP credentials", which was false in this case.
            LOGGER.info("Generating a value for " + headerName
                    + " because the cors trait explicitly lists additional headers");
        }
        return false;
    }

    /**
     * Returns {@code *}, or the sorted, upper-cased HTTP methods used by any path
     * (excluding OPTIONS) when the service uses HTTP credentials.
     */
    private <T extends Trait> Set<String> getMethodsUsedInApi(Context<T> context, OpenApi openApi) {
        if (shouldUseWildcard(context, false, "Access-Control-Allow-Methods")) {
            return SetUtils.of(WILDCARD);
        }
        Set<String> methods = new TreeSet<>();
        for (PathItem pathItem : openApi.getPaths().values()) {
            for (String method : pathItem.getOperations().keySet()) {
                // OPTIONS is deliberately left out of the generated list.
                if (!method.equalsIgnoreCase("OPTIONS")) {
                    methods.add(method.toUpperCase(Locale.ENGLISH));
                }
            }
        }
        return methods;
    }

    /**
     * Returns {@code *}, or the sorted set of request headers collected from the cors trait,
     * security schemes, protocol request headers, and OpenAPI header parameters.
     */
    private <T extends Trait> Set<String> getAllowedHeaders(Context<T> context, CorsTrait corsTrait, OpenApi openApi) {
        Set<String> headers = new TreeSet<>(corsTrait.getAdditionalAllowedHeaders());
        if (shouldUseWildcard(context, !headers.isEmpty(), "Access-Control-Allow-Headers")) {
            return SetUtils.of(WILDCARD);
        }
        headers.addAll(context.getAllSecuritySchemeRequestHeaders());
        TopDownIndex topDownIndex = TopDownIndex.of(context.getModel());
        for (OperationShape operation : topDownIndex.getContainedOperations(context.getService())) {
            headers.addAll(context.getOpenApiProtocol().getProtocolRequestHeaders(context, operation));
        }
        for (PathItem item : openApi.getPaths().values()) {
            // Path-level parameters are references and must be resolved first.
            headers.addAll(getHeadersFromParameterRefs(openApi, item.getParameters()));
            for (OperationObject operationObject : item.getOperations().values()) {
                headers.addAll(getHeadersFromParameters(operationObject.getParameters()));
            }
        }
        return headers;
    }

    /**
     * Returns {@code *}, or the sorted set of response headers collected from the cors trait,
     * security schemes, protocol response headers, and OpenAPI response headers.
     */
    private <T extends Trait> Set<String> getExposedHeaders(Context<T> context, CorsTrait corsTrait, OpenApi openApi) {
        Set<String> headers = new TreeSet<>(corsTrait.getAdditionalExposedHeaders());
        if (shouldUseWildcard(context, !headers.isEmpty(), "Access-Control-Expose-Headers")) {
            return SetUtils.of(WILDCARD);
        }
        headers.addAll(context.getAllSecuritySchemeResponseHeaders());
        TopDownIndex topDownIndex = TopDownIndex.of(context.getModel());
        for (OperationShape operation : topDownIndex.getContainedOperations(context.getService())) {
            headers.addAll(context.getOpenApiProtocol().getProtocolResponseHeaders(context, operation));
        }
        for (PathItem item : openApi.getPaths().values()) {
            for (OperationObject operationObject : item.getOperations().values()) {
                for (ResponseObject responseObject : operationObject.getResponses().values()) {
                    headers.addAll(responseObject.getHeaders().keySet());
                }
            }
        }
        return headers;
    }

    /**
     * Resolves parameter references against the document's components, then extracts header names.
     */
    private Set<String> getHeadersFromParameterRefs(OpenApi openApi, Collection<Ref<ParameterObject>> params) {
        List<ParameterObject> resolved = new ArrayList<>(params.size());
        for (Ref<ParameterObject> ref : params) {
            resolved.add(ref.deref(openApi.getComponents()));
        }
        return getHeadersFromParameters(resolved);
    }

    /**
     * Returns the sorted names of parameters whose {@code in} value is {@code "header"}.
     */
    private Set<String> getHeadersFromParameters(Collection<ParameterObject> params) {
        Set<String> result = new TreeSet<>();
        for (ParameterObject param : params) {
            if (param.getIn().filter("header"::equals).isPresent()) {
                param.getName().ifPresent(result::add);
            }
        }
        return result;
    }
}

