diff --git a/native/expo-modules/comm-expo-package/android/src/main/java/app/comm/android/aescrypto/AESCryptoModule.kt b/native/expo-modules/comm-expo-package/android/src/main/java/app/comm/android/aescrypto/AESCryptoModule.kt --- a/native/expo-modules/comm-expo-package/android/src/main/java/app/comm/android/aescrypto/AESCryptoModule.kt +++ b/native/expo-modules/comm-expo-package/android/src/main/java/app/comm/android/aescrypto/AESCryptoModule.kt @@ -11,6 +11,8 @@ import javax.crypto.SecretKey import javax.crypto.spec.GCMParameterSpec import javax.crypto.spec.SecretKeySpec +import android.util.Log +import java.util.Base64 private const val ALGORITHM_AES = "AES" private const val CIPHER_TRANSFORMATION_NAME = "AES/GCM/NoPadding" @@ -26,6 +28,7 @@ Name("AESCrypto") Function("generateKey", this@AESCryptoModule::generateKey) + Function("generateIV", this@AESCryptoModule::generateIV) Function("encrypt", this@AESCryptoModule::encrypt) Function("decrypt", this@AESCryptoModule::decrypt) } @@ -49,6 +52,16 @@ destination.write(keyBytes, position = 0, size = keyBytes.size) } + private fun generateIV(destination: Uint8Array) { + if (destination.byteLength != IV_LENGTH) { + throw InvalidInitializationVectorLengthException() + } + + val randomBytes = ByteArray(IV_LENGTH) + secureRandom.nextBytes(randomBytes) + destination.write(randomBytes, position = 0, size = randomBytes.size) + } + /** * Encrypts given [plaintext] with provided key and saves encrypted results * (sealed data) into [destination]. After the encryption, the destination @@ -63,13 +76,20 @@ private fun encrypt( rawKey: Uint8Array, plaintext: Uint8Array, - destination: Uint8Array + destination: Uint8Array, + initializationVector: Uint8Array, ) { val key = rawKey.toAESSecretKey() val plaintextBuffer = plaintext.toDirectBuffer() val destinationBuffer = destination.toDirectBuffer() + + val ivBuffer = if (initializationVector.byteLength > 0) { + initializationVector.toDirectBuffer() + } else { + null + } - encryptAES(plaintextBuffer, key, destinationBuffer) + encryptAES(plaintextBuffer, key, destinationBuffer, ivBuffer) } /** @@ -177,16 +197,32 @@ private fun encryptAES( plaintext: ByteBuffer, key: SecretKey, - destination: ByteBuffer? = null + destination: ByteBuffer? = null, + initializationVector: ByteBuffer? = null ): ByteBuffer { if (destination != null && destination.remaining() != plaintext.remaining() + IV_LENGTH + TAG_LENGTH) { throw InvalidDataLengthException() } - val cipher = Cipher.getInstance(CIPHER_TRANSFORMATION_NAME).apply { - init(Cipher.ENCRYPT_MODE, key) + + if(initializationVector != null && initializationVector.remaining() != IV_LENGTH) { + throw InvalidInitializationVectorLengthException() } - val iv = cipher.iv + + val cipher = if (initializationVector != null) { + val customIV = ByteArray(initializationVector.remaining()).also(initializationVector::get); + val spec = GCMParameterSpec(TAG_LENGTH * 8, customIV); + + Cipher.getInstance(CIPHER_TRANSFORMATION_NAME).apply { + init(Cipher.ENCRYPT_MODE, key, spec) + } + } else { + Cipher.getInstance(CIPHER_TRANSFORMATION_NAME).apply { + init(Cipher.ENCRYPT_MODE, key) + } + } + + val iv = cipher.iv; val sealedData = destination ?: ByteBuffer.allocate(iv.size + cipher.getOutputSize(plaintext.remaining())) sealedData.put(iv); @@ -258,4 +294,7 @@ private class InvalidDataLengthException : CodedException("Source or destination array has invalid length") +private class InvalidInitializationVectorLengthException : + CodedException("Initialization vector has invalid length") + // endregion diff --git a/native/expo-modules/comm-expo-package/ios/AESCryptoModule.swift b/native/expo-modules/comm-expo-package/ios/AESCryptoModule.swift --- a/native/expo-modules/comm-expo-package/ios/AESCryptoModule.swift +++ b/native/expo-modules/comm-expo-package/ios/AESCryptoModule.swift @@ -15,14 +15,22 @@ byteLength: destination.byteLength) } + Function("generateIV") { (destination: Uint8Array) throws in + try generateIV(destinationPtr: destination.rawBufferPtr(), + byteLength: destination.byteLength) + } + Function("encrypt") { (rawKey: Uint8Array, plaintext: Uint8Array, - destination: Uint8Array) throws in + destination: Uint8Array, + initializationVector: Uint8Array) throws in try encrypt(rawKey: rawKey.data(), plaintext: plaintext.data(), plaintextLength: plaintext.byteLength, destinationPtr: destination.rawBufferPtr(), - destinationLength: destination.byteLength) + destinationLength: destination.byteLength, + initializationVector: initializationVector.byteLength > 0 + ? initializationVector.data(): nil) } Function("decrypt") { (rawKey: Uint8Array, @@ -138,18 +146,36 @@ } } +private func generateIV(destinationPtr: UnsafeMutableRawBufferPointer, + byteLength: Int) throws { + guard byteLength == IV_LENGTH else { + throw InvalidInitializationVectorLengthException() + } + + let iv = AES.GCM.Nonce() + iv.withUnsafeBytes { bytes in + let _ = bytes.copyBytes(to: destinationPtr) + } +} + private func encrypt(rawKey: Data, plaintext: Data, plaintextLength: Int, destinationPtr: UnsafeMutableRawBufferPointer, - destinationLength: Int) throws { + destinationLength: Int, + initializationVector: Data? = nil) throws { guard destinationLength == plaintextLength + IV_LENGTH + TAG_LENGTH else { throw InvalidDataLengthException() } let key = SymmetricKey(data: rawKey) - let iv = AES.GCM.Nonce() + let iv = if let data = initializationVector { + try AES.GCM.Nonce(data: data) + } else { + AES.GCM.Nonce() + } + let encryptionResult = try AES.GCM.seal(plaintext, using: key, nonce: iv) @@ -195,6 +221,12 @@ } } +private class InvalidInitializationVectorLengthException: Exception { + override var reason: String { + "Initialization vector has invalid length" + } +} + private class EncryptionFailedException: GenericException { override var reason: String { "Failed to encrypt data: \(param)" diff --git a/native/utils/aes-crypto-module.js b/native/utils/aes-crypto-module.js --- a/native/utils/aes-crypto-module.js +++ b/native/utils/aes-crypto-module.js @@ -9,10 +9,12 @@ const AESCryptoModule: { +generateKey: (destination: Uint8Array) => void, + +generateIV: (destination: Uint8Array) => void, +encrypt: ( key: Uint8Array, data: Uint8Array, destination: Uint8Array, + initializationVector: Uint8Array, ) => void, +decrypt: ( key: Uint8Array, @@ -27,9 +29,20 @@ return keyBuffer; } -export function encrypt(key: Uint8Array, data: Uint8Array): Uint8Array { +export function generateIV(): Uint8Array { + const ivBuffer = new Uint8Array(IV_LENGTH); + AESCryptoModule.generateIV(ivBuffer); + return ivBuffer; +} + +export function encrypt( + key: Uint8Array, + data: Uint8Array, + initializationVector?: ?Uint8Array, +): Uint8Array { const sealedDataBuffer = new Uint8Array(data.length + IV_LENGTH + TAG_LENGTH); - AESCryptoModule.encrypt(key, data, sealedDataBuffer); + const iv = initializationVector ?? new Uint8Array(0); + AESCryptoModule.encrypt(key, data, sealedDataBuffer, iv); return sealedDataBuffer; }