package org.jfrog.security.common;

import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.apache.commons.codec.DecoderException;
import org.apache.commons.io.FileUtils;
import org.apache.commons.lang3.StringUtils;
import org.bouncycastle.util.encoders.Hex;
import org.jfrog.security.crypto.EncryptionWrapper;
import org.jfrog.security.crypto.EncryptionWrapperFactory;
import org.jfrog.security.crypto.SecretProvider;
import org.jfrog.security.file.SecurityFolderHelper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nullable;
import javax.crypto.SecretKey;
import java.io.File;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.function.BooleanSupplier;
import java.util.function.Supplier;

import static org.apache.commons.codec.binary.Hex.decodeHex;

/**
 * Static helpers for persisting a secret key to a file, reading it back, and waiting (with a timeout)
 * for a key to become available, either as a file or through a supplier.
 *
 * @author gidis
 */
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public class KeyUtils {
    private static final Logger log = LoggerFactory.getLogger(KeyUtils.class);
    private static final int LOG_PRINT_INTERVAL_SECONDS = 5;

    /**
     * Writes the hex-encoded bytes of the wrapper's secret key to {@code keyFile} and then applies
     * {@link SecurityFolderHelper#PERMISSIONS_MODE_600} to it. The parent directory is created when missing.
     *
     * @param keyFile    destination file
     * @param keyWrapper wrapper holding the key; it is cast to {@link SecretProvider} to obtain the secret
     * @throws RuntimeException if the key cannot be extracted, written, or its permissions cannot be set
     */
    public static void saveKeyToFile(File keyFile, EncryptionWrapper keyWrapper) {
        // getParentFile() returns null for a path without a parent component (e.g. "key.file"):
        // in that case there is no directory to create.
        File parentDir = keyFile.getParentFile();
        // Re-check isDirectory() after a failed mkdirs() so a concurrent creation is not reported as a failure.
        if (parentDir != null && !parentDir.exists() && !parentDir.mkdirs() && !parentDir.isDirectory()) {
            log.warn("Unable to create directory {}!", parentDir.getPath());
        }
        try {
            SecretKey secretKey = ((SecretProvider) keyWrapper).getSecret();
            byte[] encodedBytes = Hex.encode(secretKey.getEncoded());
            // NOTE(review): permissions are applied only after the content is written, so the file briefly
            // exists with whatever default permissions the platform assigns - confirm this is acceptable.
            Path keyPath = Files.write(keyFile.toPath(), encodedBytes);
            SecurityFolderHelper
                    .setPermissionsOnSecurityFile(keyPath, SecurityFolderHelper.PERMISSIONS_MODE_600);
        } catch (Exception e) {
            throw new RuntimeException("Could not save key " + keyFile.getName(), e);
        }
    }

    /**
     * Reads the whole file as UTF-8 text and trims leading/trailing whitespace.
     *
     * @param joinKeyFile file to read
     * @return the trimmed file content
     * @throws IOException if the file cannot be read
     */
    public static String readKeyFromFile(File joinKeyFile) throws IOException {
        // Explicit charset: the single-argument overload depends on the platform default encoding.
        return StringUtils.trim(FileUtils.readFileToString(joinKeyFile, StandardCharsets.UTF_8));
    }

    /**
     * Builds an encryption wrapper from {@code key} using
     * {@link EncryptionWrapperFactory#aesKeyWrapperFromString(String)} and saves it with
     * {@link #saveKeyToFile(File, EncryptionWrapper)}.
     *
     * @param keyFile destination file
     * @param key     key string handed to the wrapper factory
     */
    public static void initKey(File keyFile, String key) {
        EncryptionWrapper keyWrapper = EncryptionWrapperFactory.aesKeyWrapperFromString(key);
        KeyUtils.saveKeyToFile(keyFile, keyWrapper);
    }

    /**
     * Waits until {@code keyFile} exists, polling once per second, for at most {@code timeoutMillis}.
     *
     * @param keyName       name of the key, used in log and error messages
     * @param timeoutMillis maximum time to wait, in milliseconds
     * @param keyFile       file whose existence is awaited
     * @throws IllegalStateException if the file still does not exist when waiting ends
     */
    public static void waitForKey(String keyName, long timeoutMillis, File keyFile) {
        if (!waitUntilKeyIsAvailable(keyName, timeoutMillis, keyFile::exists, 1)) {
            throw keyResolutionFailure(keyName, null);
        }
        resolvedSuccessfullyLog(keyName);
    }

    /**
     * Waits until {@code keyRetriever} yields a key, polling once per second, for at most
     * {@code timeoutMillis}, then validates that the key is non-blank and hex encoded.
     *
     * @param keyName       name of the key, used in log and error messages
     * @param timeoutMillis maximum time to wait, in milliseconds
     * @param keyRetriever  supplier returning the key when it is available
     * @return the retrieved key
     * @throws IllegalStateException if no key is available when waiting ends, or the key is blank or not
     *                               hex encoded
     */
    public static String waitForKey(String keyName, long timeoutMillis, Supplier<Optional<String>> keyRetriever) {
        waitUntilKeyIsAvailable(keyName, timeoutMillis, () -> keyRetriever.get().isPresent(), 1);
        // Query the supplier again to obtain the value itself; it may still be empty after a timeout.
        String key = keyRetriever.get().orElseThrow(() -> keyResolutionFailure(keyName, null));
        validateKey(keyName, key);
        resolvedSuccessfullyLog(keyName);
        return key;
    }

    /**
     * Verifies that the given string can be decoded as hexadecimal.
     *
     * @param joinKey string to check
     * @throws DecoderException if the string is not valid hex
     */
    public static void validateHexEncoding(String joinKey) throws DecoderException {
        decodeHex(joinKey.toCharArray());
    }

    /**
     * Wait until key is available or until timed out (first). Does not throw exception if key is missing.
     * Waiting also ends early when the current thread is interrupted; the interrupt flag is left set.
     * Do not use this method directly, use {@link KeyUtils#waitForKey(String, long, File)}
     *
     * @param keyName                 name of the key, used in log messages
     * @param timeoutMillis           maximum time to wait, in milliseconds
     * @param keyExistenceSupplier    The key existence supplier
     * @param existenceCheckDelaySecs delay between two existence checks, in seconds
     * @return true if the key exists
     */
    static boolean waitUntilKeyIsAvailable(String keyName, long timeoutMillis, BooleanSupplier keyExistenceSupplier,
            int existenceCheckDelaySecs) {
        long startTime = System.currentTimeMillis();
        long now = startTime;
        long lastWaitingLogEntryTime = startTime;
        long timeoutSeconds = TimeUnit.MILLISECONDS.toSeconds(timeoutMillis);
        if (log.isDebugEnabled()) {
            log.debug("{}Resolving {} key with {} seconds timeout", getPrefixForLogs(keyName), keyName,
                    timeoutSeconds);
        }
        while (!keyExistenceSupplier.getAsBoolean() && now - startTime < timeoutMillis) {
            long secondsSinceLastLog = TimeUnit.MILLISECONDS.toSeconds(now - lastWaitingLogEntryTime);
            if (secondsSinceLastLog >= LOG_PRINT_INTERVAL_SECONDS) {
                if (log.isInfoEnabled()) {
                    // Report the total time spent waiting, not just the time since the previous log entry.
                    log.info("{}{} key is missing. Pending for {} seconds with {} seconds timeout",
                            getPrefixForLogs(keyName), StringUtils.capitalize(keyName),
                            TimeUnit.MILLISECONDS.toSeconds(now - startTime), timeoutSeconds);
                }
                lastWaitingLogEntryTime = now;
            }
            sleep(existenceCheckDelaySecs, keyName);
            if (Thread.currentThread().isInterrupted()) {
                // sleep() restored the interrupt flag; any further sleep would fail immediately,
                // so stop polling instead of spinning until the timeout.
                break;
            }
            now = System.currentTimeMillis();
        }
        return keyExistenceSupplier.getAsBoolean();
    }

    /**
     * Sleeps for the given number of seconds. If the thread is interrupted, a warning is logged and the
     * interrupt flag is restored so callers can observe the interruption.
     *
     * @param secToSleep number of seconds to sleep
     * @param keyName    name of the key being waited for, used in log messages
     */
    public static void sleep(long secToSleep, String keyName) {
        try {
            // toMillis saturates instead of overflowing for very large values.
            Thread.sleep(TimeUnit.SECONDS.toMillis(secToSleep));
        } catch (InterruptedException e) {
            log.warn("{}Sleep interrupted while waiting for {} key", getPrefixForLogs(keyName), keyName);
            log.debug("", e);
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Ensures the key is non-blank and hex encoded.
     *
     * @throws IllegalStateException if the key is blank or not valid hex
     */
    private static void validateKey(String keyName, String key) {
        if (StringUtils.isBlank(key)) {
            throw keyResolutionFailure(keyName,
                    new IllegalStateException("Corrupted " + keyName + " key: Empty " + keyName + " key"));
        }
        try {
            validateHexEncoding(key);
        } catch (DecoderException e) {
            // Keep the decoder failure as the cause; the key value itself is never put in the message.
            throw keyResolutionFailure(keyName,
                    new IllegalStateException("Corrupted " + keyName + " key: key must be hex encoded", e));
        }
    }

    /** Logs, at debug level, that the key was resolved. */
    private static void resolvedSuccessfullyLog(String keyName) {
        if (log.isDebugEnabled()) {
            log.debug("{}{} key resolved successfully", getPrefixForLogs(keyName), StringUtils.capitalize(keyName));
        }
    }

    /**
     * Builds the exception describing a key resolution failure. The exception is returned, not thrown.
     *
     * @param keyName   name of the key, used in the message
     * @param exception underlying failure, or {@code null} when the key is simply missing
     * @return the exception for the caller to throw; when {@code exception} is given, the error is also
     *         logged and {@code exception} is attached as the cause
     */
    private static IllegalStateException keyResolutionFailure(String keyName, @Nullable Throwable exception) {
        if (exception == null) {
            return new IllegalStateException(
                    getPrefixForLogs(keyName) + "Failed resolving " + keyName + " key; Missing " + keyName + " key");
        }
        if (log.isDebugEnabled()) {
            log.debug("", exception);
        }
        String errorMessage = String
                .format("%sFailed resolving %s key; %s", getPrefixForLogs(keyName), keyName,
                        exception.getMessage());
        log.error(errorMessage);
        return new IllegalStateException(errorMessage, exception);
    }

    /** Returns the log prefix for the given key name: "Cluster join: " for the join key, empty otherwise. */
    private static String getPrefixForLogs(String keyName) {
        // Constant-first comparison so a null key name yields an empty prefix instead of an NPE.
        return "join".equalsIgnoreCase(keyName) ? "Cluster join: " : "";
    }
}
