/*
 * Artifactory is a binaries repository manager.
 * Copyright (C) 2019 JFrog Ltd.
 *
 * Artifactory is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 * Artifactory is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with Artifactory.  If not, see <http://www.gnu.org/licenses/>.
 */

package org.jfrog.security.crypto;

import org.apache.commons.codec.digest.DigestUtils;
import org.jfrog.security.crypto.encoder.EncryptedString;
import org.jfrog.security.crypto.exception.CryptoRuntimeException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nonnull;
import javax.crypto.Cipher;
import javax.crypto.SecretKey;
import javax.crypto.spec.GCMParameterSpec;
import java.security.GeneralSecurityException;
import java.security.SecureRandom;
import java.util.Arrays;
import java.util.Objects;

import static java.nio.charset.StandardCharsets.UTF_8;

/**
 * This master key encrypter implements decryption/encryption according to the new JFrog specification.
 * See {@link EncryptedString} for more details.
 *
 * @author Yossi Shaul
 * @see EncryptedString
 */
public class JFrogMasterKeyEncrypter {
    private static final Logger log = LoggerFactory.getLogger(JFrogMasterKeyEncrypter.class);

    public static final String AES_CYPHER_TRANSFORM = "AES/GCM/NoPadding";
    public static final String ALG_AES_GCM_128 = "aesgcm128";
    public static final String ALG_AES_GCM_256 = "aesgcm256";
    private static final int GCM_TAG_LENGTH = 16;   // 128 bit tag length
    private static final int GCM_IV_LENGTH = 12;    // 96 bit iv length (aka nonce)

    /**
     * Shared source of IV randomness. {@link SecureRandom} is safe for concurrent use, so one instance is reused
     * instead of constructing (and seeding) a new generator on every encryption.
     */
    private static final SecureRandom RANDOM = new SecureRandom();

    private final SecretKey secretKey;
    /**
     * Key id contains the first 6 characters of the sha256 of the encryption key (lower cased)
     */
    final String keyId;
    /**
     * Algorithm used by this encryptor. Either aesgcm128 or aesgcm256
     */
    final String alg;

    /**
     * Creates a new master key encrypter using the given 128 or 256 bit key. The key is HEX encoded.
     *
     * @param key Hexadecimal encoded 128 or 256 bit key
     */
    public JFrogMasterKeyEncrypter(String key) {
        secretKey = JFrogCryptoHelper.aesFromString(key);
        keyId = calculateKeyId(secretKey);
        // 16 byte key -> AES-128; every other length is labelled AES-256.
        // NOTE(review): assumes aesFromString only yields 16 or 32 byte keys - confirm.
        alg = secretKey.getEncoded().length == 16 ? ALG_AES_GCM_128 : ALG_AES_GCM_256;
    }

    /**
     * Encrypts the given text with AES/GCM using a fresh random IV.
     *
     * @param plainText text to encrypt; encoded as UTF-8 before encryption
     * @return an {@link EncryptedString} carrying this encrypter's key id, algorithm and the IV-prefixed cipher text
     * @throws NullPointerException    if {@code plainText} is null
     * @throws CryptoRuntimeException  if the underlying cipher fails
     */
    @Nonnull
    public EncryptedString encrypt(@Nonnull String plainText) {
        Objects.requireNonNull(plainText, "Cannot encrypt null value");
        byte[] cipherText = encryptInternal(plainText);
        return new EncryptedString(keyId, alg, cipherText);
    }


    /**
     * Decrypts an IV-prefixed AES/GCM cipher text with this encrypter's key.
     *
     * @param cipherText the first {@code GCM_IV_LENGTH} bytes are the IV, the rest is the GCM output
     * @return the decrypted bytes
     * @throws CryptoRuntimeException if decryption or tag verification fails
     */
    private byte[] decryptInternal(byte[] cipherText) {
        try {
            Cipher cipher = Cipher.getInstance(AES_CYPHER_TRANSFORM);
            AesCipherText aesCipherText = AesCipherText.fromCipherTextWithIv(cipherText);
            GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(GCM_TAG_LENGTH * 8, aesCipherText.iv);
            cipher.init(Cipher.DECRYPT_MODE, secretKey, gcmParameterSpec);
            return cipher.doFinal(aesCipherText.cipherNoIv);
        } catch (GeneralSecurityException e) {
            throw new CryptoRuntimeException(e);
        }
    }

    /**
     * Parses the encoded string and decrypts it.
     *
     * @param encrypted encoded encrypted string (see {@link EncryptedString})
     * @return the decrypted plain text
     * @throws CryptoRuntimeException if the input was not encrypted by this encrypter or decryption fails
     */
    public String decrypt(String encrypted) {
        return decrypt(EncryptedString.parse(encrypted));
    }

    /**
     * Decrypts the given encrypted string after verifying it matches this encrypter's key id, algorithm and
     * minimum size.
     *
     * @param encrypted the encrypted value
     * @return the decrypted plain text, decoded as UTF-8
     * @throws CryptoRuntimeException if the input was not encrypted by this encrypter or decryption fails
     */
    public String decrypt(EncryptedString encrypted) {
        if (!isEncryptedByMe(encrypted.encode())) {
            throw new CryptoRuntimeException("Input is not encrypted by current encrypter");
        }
        return new String(decryptInternal(encrypted.getCipherText()), UTF_8);
    }

    /**
     * Checks whether the given encoded string could have been produced by this encrypter.
     * Mismatches are logged at warn level (and with a stack trace at trace level).
     *
     * @param encryptedString encoded encrypted string
     * @return {@code true} if the format, key id (case-insensitive), algorithm (case-insensitive) and cipher text
     * length all match; {@code false} otherwise
     */
    public boolean isEncryptedByMe(String encryptedString) {
        if (!EncryptedString.isEncodedByMe(encryptedString)) {
            return false;
        }
        EncryptedString encrypted = EncryptedString.parse(encryptedString);
        if (!keyId.equalsIgnoreCase(encrypted.getKeyId())) {
            log.warn("Encrypted data with key id {} is not encrypted with current key id of {}",
                    encrypted.getKeyId(), keyId);
            logStackTraceIfNeeded();
            return false;
        }
        if (!alg.equalsIgnoreCase(encrypted.getAlg())) {
            log.warn("Encrypted data with algorithm {} is not encrypted with current algorithm of {}",
                    encrypted.getAlg(), alg);
            logStackTraceIfNeeded();
            return false;
        }
        // Even an empty plain text produces IV + authentication tag bytes, so anything shorter is invalid
        // (this also guarantees decryptInternal can safely split off the IV).
        if (encrypted.getCipherText().length < GCM_IV_LENGTH + GCM_TAG_LENGTH) {
            log.warn("Encrypted data size of {} is smaller than the minimum required of {}",
                    encrypted.getCipherText().length, GCM_IV_LENGTH + GCM_TAG_LENGTH);
            logStackTraceIfNeeded();
            return false;
        }

        return true;
    }

    /**
     * Logs the current thread's stack trace at trace level to help locate the caller of a mismatched decryption.
     * Does nothing unless trace logging is enabled.
     */
    private void logStackTraceIfNeeded() {
        if (log.isTraceEnabled()) {
            StringBuilder stringBuilder = new StringBuilder();
            stringBuilder.append("Decryption mismatch stacktrace:");
            StackTraceElement[] stackTrace = Thread.currentThread().getStackTrace();
            for (StackTraceElement element : stackTrace) {
                stringBuilder.append(System.lineSeparator());
                stringBuilder.append("\tat ").append(element);
            }
            log.trace(stringBuilder.toString());
        }
    }

    /**
     * Encrypts the plain text with AES/GCM and a fresh random IV.
     *
     * @param plainText text to encrypt
     * @return the IV followed by the GCM cipher output
     * @throws CryptoRuntimeException if the underlying cipher fails
     */
    private byte[] encryptInternal(String plainText) {
        try {
            Cipher cipher = Cipher.getInstance(AES_CYPHER_TRANSFORM);
            byte[] iv = generateRandomInitializationVector();
            GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(GCM_TAG_LENGTH * 8, iv);
            cipher.init(Cipher.ENCRYPT_MODE, secretKey, gcmParameterSpec);
            // Must be UTF-8 explicitly: decrypt() decodes with UTF-8, and the platform default charset may differ,
            // which would corrupt non-ASCII text on round-trip.
            byte[] cipherNoIv = cipher.doFinal(plainText.getBytes(UTF_8));
            return AesCipherText.fromIvAndCipher(iv, cipherNoIv).getCipherText();
        } catch (GeneralSecurityException e) {
            throw new CryptoRuntimeException(e);
        }
    }

    /**
     * @return a new random {@code GCM_IV_LENGTH}-byte IV; GCM requires a unique IV per encryption with the same key
     */
    private byte[] generateRandomInitializationVector() {
        byte[] iv = new byte[GCM_IV_LENGTH];
        RANDOM.nextBytes(iv);
        return iv;
    }

    /**
     * @return the first 6 characters of the lower-cased hex SHA-256 of the raw key bytes
     */
    private String calculateKeyId(SecretKey secretKey) {
        return DigestUtils.sha256Hex(secretKey.getEncoded()).substring(0, 6).toLowerCase();
    }

    /**
     * Holds an AES/GCM cipher text split into its IV and the remaining cipher output.
     */
    private static class AesCipherText {
        private final byte[] iv;
        private final byte[] cipherNoIv;

        private static AesCipherText fromIvAndCipher(byte[] iv, byte[] cipherNoIv) {
            return new AesCipherText(iv, cipherNoIv);
        }

        /**
         * Splits a combined cipher text whose first {@code GCM_IV_LENGTH} bytes are the IV.
         * Callers must ensure the input is at least {@code GCM_IV_LENGTH} bytes long.
         */
        private static AesCipherText fromCipherTextWithIv(byte[] cipherText) {
            byte[] iv = extractInitializationVector(cipherText);
            byte[] cipherNoIv = removeInitializationVector(cipherText);
            return new AesCipherText(iv, cipherNoIv);
        }

        private AesCipherText(byte[] iv, byte[] cipherNoIv) {
            this.iv = iv;
            this.cipherNoIv = cipherNoIv;
        }

        /**
         * The aes cipher returns a cipher text without the initialization vector (iv). The iv is needed
         * for decryption and expected to occupy the first 12 bytes of the cipher text.
         *
         * @return a new array containing the iv followed by the cipher output
         */
        private byte[] getCipherText() {
            byte[] cipherText = new byte[iv.length + cipherNoIv.length];
            System.arraycopy(iv, 0, cipherText, 0, iv.length);
            System.arraycopy(cipherNoIv, 0, cipherText, iv.length, cipherNoIv.length);
            return cipherText;
        }

        private static byte[] extractInitializationVector(byte[] cipherText) {
            // first $GCM_IV_LENGTH bytes of the cypherText are the IV (initialization vector)
            return Arrays.copyOfRange(cipherText, 0, GCM_IV_LENGTH);
        }

        private static byte[] removeInitializationVector(byte[] cipherText) {
            return Arrays.copyOfRange(cipherText, GCM_IV_LENGTH, cipherText.length);
        }

    }

}
