diff --git a/native/expo-modules/comm-expo-package/android/src/main/java/app/comm/android/aescrypto/AESCryptoModule.kt b/native/expo-modules/comm-expo-package/android/src/main/java/app/comm/android/aescrypto/AESCryptoModule.kt
--- a/native/expo-modules/comm-expo-package/android/src/main/java/app/comm/android/aescrypto/AESCryptoModule.kt
+++ b/native/expo-modules/comm-expo-package/android/src/main/java/app/comm/android/aescrypto/AESCryptoModule.kt
@@ -5,6 +5,7 @@
 import expo.modules.kotlin.modules.ModuleDefinition
 import expo.modules.kotlin.typedarray.Uint8Array
 import java.security.SecureRandom
+import java.nio.ByteBuffer
 import javax.crypto.Cipher
 import javax.crypto.KeyGenerator
 import javax.crypto.SecretKey
@@ -64,18 +65,11 @@
     plaintext: Uint8Array,
     destination: Uint8Array
   ) {
-    if (destination.length != plaintext.length + IV_LENGTH + TAG_LENGTH) {
-      throw InvalidDataLengthException()
-    }
-
     val key = rawKey.toAESSecretKey()
     val plaintextBuffer = plaintext.toDirectBuffer()
-    val plaintextBytes = ByteArray(plaintext.byteLength)
-      .also(plaintextBuffer::get)
-    val (iv, ciphertextWithTag) = encryptAES(plaintextBytes, key)
+    val destinationBuffer = destination.toDirectBuffer()
 
-    destination.write(iv, position = 0, size = IV_LENGTH)
-    destination.write(ciphertextWithTag, IV_LENGTH, ciphertextWithTag.size)
+    encryptAES(plaintextBuffer, key, destinationBuffer)
   }
 
   /**
@@ -92,17 +86,11 @@
     sealedData: Uint8Array,
     destination: Uint8Array
   ) {
-    if (destination.byteLength
-      != sealedData.byteLength - IV_LENGTH - TAG_LENGTH) {
-      throw InvalidDataLengthException()
-    }
     val key = rawKey.toAESSecretKey()
-    val input = sealedData.toDirectBuffer()
-    val iv = ByteArray(IV_LENGTH).also(input::get)
-    val ciphertextWithTagBytes = ByteArray(input.remaining()).also(input::get)
-    val plaintext = decryptAES(ciphertextWithTagBytes, key, iv)
+    val sealedDataBuffer = sealedData.toDirectBuffer()
+    val destinationBuffer = destination.toDirectBuffer()
 
-    destination.write(plaintext, position = 0, size = plaintext.size)
+    decryptAES(sealedDataBuffer, key, destinationBuffer)
   }
 
   // endregion
@@ -125,8 +113,9 @@
     plaintext: ByteArray,
   ): ByteArray {
     val secretKey = rawKey.toSecretKey()
-    val (iv, ciphertextWithTag) = encryptAES(plaintext, secretKey)
-    return iv + ciphertextWithTag
+    val plaintextBuffer = ByteBuffer.wrap(plaintext)
+    val cipherText = encryptAES(plaintextBuffer, secretKey)
+    return ByteArray(cipherText.remaining()).also(cipherText::get)
   }
 
   public fun decrypt(
@@ -137,45 +126,79 @@
       throw InvalidDataLengthException()
     }
     val secretKey = rawKey.toSecretKey()
-    val iv = sealedData.copyOfRange(0, IV_LENGTH)
-    val ciphertextWithTag = sealedData.copyOfRange(IV_LENGTH, sealedData.size)
-    return decryptAES(ciphertextWithTag, secretKey, iv)
+    val sealedDataBuffer = ByteBuffer.wrap(sealedData)
+    val plaintext = decryptAES(sealedDataBuffer, secretKey)
+    return ByteArray(plaintext.remaining()).also(plaintext::get)
   }
 }
 
 /**
- * Encrypts given [plaintext] with given [key] using AES-256 GCM algorithm
+ * Encrypts given [plaintext] with given [key] using AES-256 GCM algorithm.
+ * You can optionally pass in a destination buffer of a correct length
+ * that will be filled in with the result of the encryption.
  *
- * @return A pair of:
- * - IV (initialization vector) - 12 bytes long
- * - [ByteArray] containing ciphertext with 16-byte GCM auth tag appended
+ * @return reference to the passed destination buffer or a new [ByteBuffer],
+ * containing sealed data consisting of the following, concatenated in order:
+ * - IV - 12 bytes long
+ * - Ciphertext with 16-byte GCM tag
+ * @throws InvalidDataLengthException if [destination] is passed and its
+ * remaining size differs from plaintext size + IV_LENGTH + TAG_LENGTH
  */
 private fun encryptAES(
-  plaintext: ByteArray,
-  key: SecretKey
-): Pair<ByteArray, ByteArray> {
+  plaintext: ByteBuffer,
+  key: SecretKey,
+  destination: ByteBuffer? = null
+): ByteBuffer {
+  if (destination != null &&
+    destination.remaining() != plaintext.remaining() + IV_LENGTH + TAG_LENGTH) {
+    throw InvalidDataLengthException()
+  }
   val cipher = Cipher.getInstance(CIPHER_TRANSFORMATION_NAME).apply {
     init(Cipher.ENCRYPT_MODE, key)
   }
-  val iv = cipher.iv.copyOf()
-  val ciphertextWithTag = cipher.doFinal(plaintext)
-  return Pair(iv, ciphertextWithTag)
+  val iv = cipher.iv
+  val sealedData = destination ?:
+    ByteBuffer.allocate(iv.size + cipher.getOutputSize(plaintext.remaining()))
+  sealedData.put(iv)
+  cipher.doFinal(plaintext, sealedData)
+  // flip() rather than position(0): getOutputSize() is only an upper bound,
+  // so the limit has to be clamped to the bytes that were actually written.
+  sealedData.flip()
+  return sealedData
 }
 
 /**
  * Does the reverse of the [encryptAES] function.
- * Decrypts the [ciphertext] with given [key] and [iv]
+ * Decrypts the [sealedData] (IV followed by ciphertext with GCM tag)
+ * with given [key].
+ * You can optionally pass in a destination buffer of a correct length
+ * that will be filled in with the result of the decryption.
+ *
+ * @return reference to the passed destination buffer or a new [ByteBuffer]
+ * @throws InvalidDataLengthException if [destination] is passed and its
+ * remaining size differs from sealed data size - IV_LENGTH - TAG_LENGTH
  */
 private fun decryptAES(
-  ciphertextWithTag: ByteArray,
+  sealedData: ByteBuffer,
   key: SecretKey,
-  iv: ByteArray
-): ByteArray {
+  destination: ByteBuffer? = null
+): ByteBuffer {
+  if (destination != null &&
+    destination.remaining() != sealedData.remaining() - IV_LENGTH - TAG_LENGTH) {
+    throw InvalidDataLengthException()
+  }
+  val iv = ByteArray(IV_LENGTH).also(sealedData::get)
   val spec = GCMParameterSpec(TAG_LENGTH * 8, iv)
   val cipher = Cipher.getInstance(CIPHER_TRANSFORMATION_NAME).apply {
     init(Cipher.DECRYPT_MODE, key, spec)
   }
-  return cipher.doFinal(ciphertextWithTag)
+  val plaintext = destination ?:
+    ByteBuffer.allocate(cipher.getOutputSize(sealedData.remaining()))
+  cipher.doFinal(sealedData, plaintext)
+  // flip() rather than position(0): getOutputSize() may over-estimate, and
+  // callers size their result from remaining(), so clamp the limit.
+  plaintext.flip()
+  return plaintext
 }
 
 // endregion