// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
////////////////////////////////////////////////////////////////////////////////

package com.google.crypto.tink.hybrid.subtle;

import com.google.crypto.tink.Aead;
import com.google.crypto.tink.HybridDecrypt;
import com.google.crypto.tink.aead.subtle.AeadFactory;
import com.google.crypto.tink.subtle.Hkdf;
import java.math.BigInteger;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.security.InvalidKeyException;
import java.security.PrivateKey;
import java.security.interfaces.RSAKey;
import java.security.interfaces.RSAPrivateKey;

/**
 * Hybrid decryption using RSA-KEM (as defined in Shoup's ISO standard proposal) as KEM, an AEAD as
 * DEM, and HKDF as KDF.
 *
 * <p>A ciphertext is expected to be the concatenation of two parts: a KEM token whose length equals
 * the byte size of the recipient's RSA modulus, followed by the DEM payload.
 *
 * <p>Shoup's ISO standard proposal is available at https://www.shoup.net/iso/std6.pdf.
 */
public final class RsaKemHybridDecrypt implements HybridDecrypt {
  private final PrivateKey recipientPrivateKey;
  private final String hkdfHmacAlgo;
  private final byte[] hkdfSalt;
  private final AeadFactory aeadFactory;
  // Byte size of the RSA modulus, computed once; this is also the length of the KEM token.
  private final int modSizeInBytes;

  /**
   * Shared initializer behind the public constructor and {@link #create}.
   *
   * <p>Callers must guarantee that {@code recipientPrivateKey} also implements {@link RSAKey};
   * otherwise the cast below fails with a {@link ClassCastException}.
   */
  private RsaKemHybridDecrypt(
      final PrivateKey recipientPrivateKey,
      String hkdfHmacAlgo,
      final byte[] hkdfSalt,
      AeadFactory aeadFactory)
      throws GeneralSecurityException {
    BigInteger mod = ((RSAKey) recipientPrivateKey).getModulus();
    RsaKem.validateRsaModulus(mod);

    this.recipientPrivateKey = recipientPrivateKey;
    // Defensive copy: the caller keeps a reference to its array, and a later mutation of it must
    // not change the keys this instance derives. A null salt is passed through unchanged.
    this.hkdfSalt = (hkdfSalt == null) ? null : hkdfSalt.clone();
    this.hkdfHmacAlgo = hkdfHmacAlgo;
    this.aeadFactory = aeadFactory;
    this.modSizeInBytes = RsaKem.bigIntSizeInBytes(mod);
  }

  /**
   * Creates a decrypter for the given RSA private key.
   *
   * @param recipientPrivateKey the recipient's RSA private key; its modulus is checked with {@code
   *     RsaKem.validateRsaModulus}
   * @param hkdfHmacAlgo the HMAC algorithm name handed to HKDF
   * @param hkdfSalt the salt handed to HKDF; the array is copied
   * @param aeadFactory supplies the DEM key size and creates the AEAD used to decrypt the payload
   * @throws GeneralSecurityException if the modulus validation fails
   */
  public RsaKemHybridDecrypt(
      final RSAPrivateKey recipientPrivateKey,
      String hkdfHmacAlgo,
      final byte[] hkdfSalt,
      AeadFactory aeadFactory)
      throws GeneralSecurityException {
    this((PrivateKey) recipientPrivateKey, hkdfHmacAlgo, hkdfSalt, aeadFactory);
  }

  /**
   * This alternate factory method is to support Android KeyStore, whose RSA private key class does
   * not implement RSAPrivateKey.
   *
   * @param recipientPrivateKey should implement both PrivateKey and RSAKey.
   * @param hkdfHmacAlgo the HMAC algorithm name handed to HKDF
   * @param hkdfSalt the salt handed to HKDF; the array is copied
   * @param aeadFactory supplies the DEM key size and creates the AEAD used to decrypt the payload
   * @throws InvalidKeyException if {@code recipientPrivateKey} does not implement {@link RSAKey}
   * @throws GeneralSecurityException if the modulus validation fails
   */
  public static RsaKemHybridDecrypt create(
      final PrivateKey recipientPrivateKey,
      String hkdfHmacAlgo,
      final byte[] hkdfSalt,
      AeadFactory aeadFactory)
      throws GeneralSecurityException {
    if (!(recipientPrivateKey instanceof RSAKey)) {
      throw new InvalidKeyException("Must be an RSA private key");
    }
    return new RsaKemHybridDecrypt(recipientPrivateKey, hkdfHmacAlgo, hkdfSalt, aeadFactory);
  }

  /**
   * Decrypts a hybrid ciphertext made of a KEM token followed by a DEM payload.
   *
   * @param ciphertext the KEM token (as many bytes as the RSA modulus) followed by the DEM payload
   * @param contextInfo context data fed into the key derivation together with the shared secret
   *     and the salt
   * @return the plaintext returned by the AEAD for the DEM payload
   * @throws GeneralSecurityException if {@code ciphertext} is shorter than the RSA modulus, or if
   *     any of the RSA, HKDF or AEAD steps fails
   */
  @Override
  public byte[] decrypt(final byte[] ciphertext, final byte[] contextInfo)
      throws GeneralSecurityException {
    if (ciphertext.length < modSizeInBytes) {
      // Plain concatenation keeps the digits locale-independent (String.format localizes %d
      // according to the default locale).
      throw new GeneralSecurityException(
          "Ciphertext must be of at least size "
              + modSizeInBytes
              + " bytes, but got "
              + ciphertext.length);
    }

    // Get the first modSizeInBytes bytes of ciphertext.
    ByteBuffer cipherBuffer = ByteBuffer.wrap(ciphertext);
    byte[] token = new byte[modSizeInBytes];
    cipherBuffer.get(token);

    // KEM: decrypt the token to obtain the raw shared secret.
    byte[] sharedSecret = RsaKem.rsaDecrypt(recipientPrivateKey, token);

    // KDF: derive a DEM key from the shared secret, salt, and contextInfo.
    byte[] demKey =
        Hkdf.computeHkdf(
            hkdfHmacAlgo, sharedSecret, hkdfSalt, contextInfo, aeadFactory.getKeySizeInBytes());

    // DEM: decrypt the payload, i.e. everything after the token.
    Aead aead = aeadFactory.createAead(demKey);
    byte[] demPayload = new byte[cipherBuffer.remaining()];
    cipherBuffer.get(demPayload);
    return aead.decrypt(demPayload, RsaKem.EMPTY_AAD);
  }
}
