/*
 * Copyright (c) 2023 MarkLogic Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.marklogic.xcc.impl.handlers;

import com.marklogic.http.HttpChannel;
import com.marklogic.xcc.Request;
import com.marklogic.xcc.RequestOptions;
import com.marklogic.xcc.ResultSequence;
import com.marklogic.xcc.exceptions.MLCloudRequestException;
import com.marklogic.xcc.exceptions.RequestException;
import com.marklogic.xcc.impl.ContentSourceImpl;
import com.marklogic.xcc.impl.Credentials;
import com.marklogic.xcc.impl.SessionImpl;
import com.marklogic.xcc.spi.ConnectionProvider;
import com.marklogic.xcc.spi.ServerConnection;

import java.io.IOException;
import java.net.HttpURLConnection;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;

public class MLCloudRequestController extends AbstractRequestController {
    private static final int MAX_RETRY = 5;
    private static final int RETRY_DELAY_MILLIS = 125;
    private static final Map<Integer, ResponseHandler> handlers =
        new HashMap<>();
    private Credentials.MLCloudAuthConfig mlCloudAuthConfig;
    private String boundary;

    static {
        addDefaultHandler(handlers, new UnrecognizedCodeHandler());
        // 200 OK --> Request success
        addHandler(handlers, HttpURLConnection.HTTP_OK,
            new MLCloudSuccessResponseHandler());
        // 401 Unauthorized --> Invalid API key specified/session token expired
        addHandler(handlers, HttpURLConnection.HTTP_UNAUTHORIZED,
            new UnauthorizedHandler());
        // 403 Access Denied --> Access denied to the given resource
        addHandler(handlers, HttpURLConnection.HTTP_FORBIDDEN,
            new MLCloudForbiddenHandler());
        // 404 Not Found --> Invalid request URL
        addHandler(handlers, HttpURLConnection.HTTP_NOT_FOUND,
            new NotFoundCodeHandler());
        // 410 HTTP Gone --> Tenant expiry
        addHandler(handlers, HttpURLConnection.HTTP_GONE,
            new MLCloudGoneHandler());
        //5xx General server errors
        addHandler(handlers, HttpURLConnection.HTTP_INTERNAL_ERROR,
            new ServerExceptionHandler());
        addHandler(handlers, HttpURLConnection.HTTP_BAD_GATEWAY,
            new ServiceUnavailableHandler());
        addHandler(handlers, HttpURLConnection.HTTP_UNAVAILABLE,
            new ServiceUnavailableHandler());
        addHandler(handlers, HttpURLConnection.HTTP_GATEWAY_TIMEOUT,
            new ServiceUnavailableHandler());
    }

    public MLCloudRequestController(
        Credentials.MLCloudAuthConfig mlCloudAuthConfig) {
        super(handlers, DEFAULT_SERVER_PATH,
            mlCloudAuthConfig.getTokenEndpoint());
        this.mlCloudAuthConfig = mlCloudAuthConfig;
    }

    @Override
    public ResultSequence runRequest(
        ConnectionProvider provider, Request request, Logger logger)
        throws RequestException {
        SessionImpl session = (SessionImpl)request.getSession();
        RequestOptions options = request.getEffectiveOptions();
        ServerConnection connection = null;
        RequestException re = null;
        int t = 0;
        for (; t < MAX_RETRY; t++) {
            if (t > 0) {
                if (logger.isLoggable(Level.FINE)) {
                    logger.log(Level.FINE, "Retrying connecting to MarkLogic " +
                        "Cloud (" + t + ").");
                }
            }
            try {
                sleepFor(interTryDelay(RETRY_DELAY_MILLIS, t), logger);
                connection = provider.obtainConnection(session, request, logger);
                ResultSequence rs = serverDialog(connection, request, options,
                    logger);
                if ((rs == null) || rs.isCached()) {
                    provider.returnConnection(connection, logger);
                }
                return rs;
            } catch (RequestException e) {
                logger.log(Level.WARNING, "Request exception connecting to " +
                    "MarkLogic Cloud. Cannot obtain session token.", e);
                provider.returnConnection(connection, logger);
                if (e instanceof MLCloudRequestException) {
                    if (!e.isRetryable()) throw e;
                    if (logger.isLoggable(Level.FINE)) {
                        logger.log(Level.FINE, "Retryable exception caught.", e);
                    }
                    re = e;
                } else throw e;
            } catch (IOException e) {
                logger.log(Level.WARNING, "Connection IOException caught. " +
                    "Cannot obtain session token.", e);
                if (connection != null) {
                    provider.returnErrorConnection(connection, e, logger);
                }
                re = new MLCloudRequestException(e.getMessage(), request, e,
                    true);
            } catch (Exception e) {
                logger.log(Level.WARNING, "Exception connecting to " +
                    "MarkLogic Cloud. Cannot obtain session token.", e);
                re = new MLCloudRequestException(e.getMessage(), request, e,
                    true);
            }
        }
        logger.log(Level.WARNING, "Automatic connecting to MarkLogic Cloud " +
            "retries (" + t + ") exhausted, throwing: " + re, re);
        throw re;
    }

    @Override
    public ResultSequence serverDialog(ServerConnection connection,
                                       Request request, RequestOptions options,
                                       Logger logger)
        throws IOException, RequestException{
        SessionImpl session = (SessionImpl)request.getSession();
        byte[] bodyBytes = buildFormBody().getBytes();
        String reqUri = addTokenDurationToPath(httpPath);
        HttpChannel http = buildChannel(connection, reqUri, session,
            bodyBytes.length, options, logger);
        http.write(bodyBytes);
        if (HttpChannel.isUseHTTP()) {
            http.write("\r\n".getBytes());
        }
        int code = http.getResponseCode();
        ResponseHandler handler = findHandler(code);
        handler.handleResponse(http, code, request, mlCloudAuthConfig, logger);
        return null;
    }

    private HttpChannel buildChannel(ServerConnection connection, String path,
                                     SessionImpl session, int bufferSize,
                                     RequestOptions options, Logger logger) {
        String method = "POST";
        HttpChannel http = new HttpChannel(connection.channel(), method, path,
            bufferSize, options.getTimeoutMillis(), logger);

        // Add MLCloud headers
        ContentSourceImpl contentSource =
            (ContentSourceImpl)session.getContentSource();
        if(HttpChannel.isUseHTTP()) {
            ConnectionProvider cp = contentSource.getConnectionProvider();
            http.setRequestHeader("Host", cp.getHostName() + ":" + cp.getPort());
        }
        http.setRequestHeader("User-Agent", session.userAgentString());
        http.setRequestHeader("Accept", session.getAcceptedContentTypes());
        http.setRequestContentType("multipart/form-data; boundary=" + boundary);
        return http;
    }

    private String buildFormBody() {
        HttpChannel.MultipartFormBody formBody =
            new HttpChannel.MultipartFormBody();
        this.boundary = formBody.getBoundary();
        formBody.addTextBody("key",
            new String(mlCloudAuthConfig.getApiKey()));
        formBody.addTextBody("grant_type", mlCloudAuthConfig.getGrantType());
        return formBody.buildFormBody();
    }

    private String addTokenDurationToPath(String httpPath) {
        StringBuilder sb = new StringBuilder(httpPath);
        int tokenDuration = mlCloudAuthConfig.getTokenDuration();
        if (tokenDuration != Credentials.DEFAULT_TOKEN_DURATION) {
            sb.append("?duration=");
            sb.append(tokenDuration);
        }
        return sb.toString();
    }
}
