diff --git a/native/cpp/CommonCpp/CryptoTools/CryptoModule.h b/native/cpp/CommonCpp/CryptoTools/CryptoModule.h
--- a/native/cpp/CommonCpp/CryptoTools/CryptoModule.h
+++ b/native/cpp/CommonCpp/CryptoTools/CryptoModule.h
@@ -67,6 +67,8 @@
       const OlmBuffer &oneTimeKey);
   bool hasSessionFor(const std::string &targetDeviceId);
   std::shared_ptr<Session> getSessionByDeviceId(const std::string &deviceId);
+  // Forgets the session held for `deviceId`; does nothing if there is none.
+  void removeSessionByDeviceId(const std::string &deviceId);
   Persist storeAsB64(const std::string &secretKey);
   void restoreFromB64(const std::string &secretKey, Persist persist);
 
diff --git a/native/cpp/CommonCpp/CryptoTools/CryptoModule.cpp b/native/cpp/CommonCpp/CryptoTools/CryptoModule.cpp
--- a/native/cpp/CommonCpp/CryptoTools/CryptoModule.cpp
+++ b/native/cpp/CommonCpp/CryptoTools/CryptoModule.cpp
@@ -309,6 +309,10 @@
   return this->sessions.at(deviceId);
 }
 
+void CryptoModule::removeSessionByDeviceId(const std::string &deviceId) {
+  this->sessions.erase(deviceId);
+}
+
 Persist CryptoModule::storeAsB64(const std::string &secretKey) {
   Persist persist;
   size_t accountPickleLength =
@@ -355,11 +359,8 @@
 
   std::unordered_map<std::string, OlmBuffer>::iterator it;
   for (it = persist.sessions.begin(); it != persist.sessions.end(); ++it) {
-    std::unique_ptr<Session> session = session->restoreFromB64(
-        this->getOlmAccount(),
-        this->keys.identityKeys.data(),
-        secretKey,
-        it->second);
+    std::unique_ptr<Session> session =
+        Session::restoreFromB64(secretKey, it->second);
     this->sessions.insert(make_pair(it->first, move(session)));
   }
 }
@@ -398,48 +399,7 @@
   if (!this->hasSessionFor(targetDeviceId)) {
     throw std::runtime_error{"error decrypt => uninitialized session"};
   }
-  OlmSession *session = this->sessions.at(targetDeviceId)->getOlmSession();
-
-  OlmBuffer utilityBuffer(::olm_utility_size());
-  OlmUtility *olmUtility = ::olm_utility(utilityBuffer.data());
-
-  OlmBuffer messageHashBuffer(::olm_sha256_length(olmUtility));
-  ::olm_sha256(
-      olmUtility,
-      encryptedData.message.data(),
-      encryptedData.message.size(),
-      messageHashBuffer.data(),
-      messageHashBuffer.size());
-
-  OlmBuffer tmpEncryptedMessage(encryptedData.message);
-  size_t maxSize = ::olm_decrypt_max_plaintext_length(
-      session,
-      encryptedData.messageType,
-      tmpEncryptedMessage.data(),
-      tmpEncryptedMessage.size());
-
-  if (maxSize == -1) {
-    throw std::runtime_error{
-        "error decrypt_max_plaintext_length => " +
-        std::string{::olm_session_last_error(session)} + ". Hash: " +
-        std::string{messageHashBuffer.begin(), messageHashBuffer.end()}};
-  }
-
-  OlmBuffer decryptedMessage(maxSize);
-  size_t decryptedSize = ::olm_decrypt(
-      session,
-      encryptedData.messageType,
-      encryptedData.message.data(),
-      encryptedData.message.size(),
-      decryptedMessage.data(),
-      decryptedMessage.size());
-  if (decryptedSize == -1) {
-    throw std::runtime_error{
-        "error decrypt => " + std::string{::olm_session_last_error(session)} +
-        ". Hash: " +
-        std::string{messageHashBuffer.begin(), messageHashBuffer.end()}};
-  }
-  return std::string{(char *)decryptedMessage.data(), decryptedSize};
+  return this->sessions.at(targetDeviceId)->decrypt(encryptedData);
 }
 
 std::string CryptoModule::signMessage(const std::string &message) {
diff --git a/native/cpp/CommonCpp/CryptoTools/Session.h b/native/cpp/CommonCpp/CryptoTools/Session.h
--- a/native/cpp/CommonCpp/CryptoTools/Session.h
+++ b/native/cpp/CommonCpp/CryptoTools/Session.h
@@ -11,15 +11,12 @@
 namespace crypto {
 
 class Session {
-  OlmAccount *ownerUserAccount;
-  std::uint8_t *ownerIdentityKeys;
-
   OlmBuffer olmSessionBuffer;
 
-  Session(OlmAccount *account, std::uint8_t *ownerIdentityKeys)
-      : ownerUserAccount(account), ownerIdentityKeys(ownerIdentityKeys) {
-  }
+  // Private so that a Session can only be obtained from the static factory
+  // functions below.
+  Session() = default;
 
 public:
   static std::unique_ptr<Session> createSessionAsInitializer(
       OlmAccount *account,
@@ -34,12 +31,12 @@
       const OlmBuffer &encryptedMessage,
       const OlmBuffer &idKeys);
   OlmBuffer storeAsB64(const std::string &secretKey);
-  static std::unique_ptr<Session> restoreFromB64(
-      OlmAccount *account,
-      std::uint8_t *ownerIdentityKeys,
-      const std::string &secretKey,
-      OlmBuffer &b64);
+  static std::unique_ptr<Session>
+  restoreFromB64(const std::string &secretKey, OlmBuffer &b64);
   OlmSession *getOlmSession();
+  // Decrypts a message received on this session. Throws std::runtime_error
+  // on failure; `encryptedData` itself is never modified.
+  std::string decrypt(const EncryptedData &encryptedData);
 };
 
 } // namespace crypto
diff --git a/native/cpp/CommonCpp/CryptoTools/Session.cpp b/native/cpp/CommonCpp/CryptoTools/Session.cpp
--- a/native/cpp/CommonCpp/CryptoTools/Session.cpp
+++ b/native/cpp/CommonCpp/CryptoTools/Session.cpp
@@ -17,7 +17,7 @@
     const OlmBuffer &preKeys,
     const OlmBuffer &preKeySignature,
     const OlmBuffer &oneTimeKey) {
-  std::unique_ptr<Session> session(new Session(account, ownerIdentityKeys));
+  std::unique_ptr<Session> session(new Session());
 
   session->olmSessionBuffer.resize(::olm_session_size());
   ::olm_session(session->olmSessionBuffer.data());
@@ -30,7 +30,7 @@
   if (-1 ==
       ::olm_create_outbound_session(
           session->getOlmSession(),
-          session->ownerUserAccount,
+          account,
           idKeys.data() + ID_KEYS_PREFIX_OFFSET,
           KEYSIZE,
           idKeys.data() + SIGNING_KEYS_PREFIX_OFFSET,
@@ -55,7 +55,7 @@
     std::uint8_t *ownerIdentityKeys,
     const OlmBuffer &encryptedMessage,
     const OlmBuffer &idKeys) {
-  std::unique_ptr<Session> session(new Session(account, ownerIdentityKeys));
+  std::unique_ptr<Session> session(new Session());
 
   OlmBuffer tmpEncryptedMessage(encryptedMessage);
   session->olmSessionBuffer.resize(::olm_session_size());
@@ -63,7 +63,7 @@
   if (-1 ==
       ::olm_create_inbound_session(
           session->getOlmSession(),
-          session->ownerUserAccount,
+          account,
           tmpEncryptedMessage.data(),
           encryptedMessage.size())) {
     throw std::runtime_error(
@@ -94,12 +94,9 @@
   return pickle;
 }
 
-std::unique_ptr<Session> Session::restoreFromB64(
-    OlmAccount *account,
-    std::uint8_t *ownerIdentityKeys,
-    const std::string &secretKey,
-    OlmBuffer &b64) {
-  std::unique_ptr<Session> session(new Session(account, ownerIdentityKeys));
+std::unique_ptr<Session>
+Session::restoreFromB64(const std::string &secretKey, OlmBuffer &b64) {
+  std::unique_ptr<Session> session(new Session());
 
   session->olmSessionBuffer.resize(::olm_session_size());
   ::olm_session(session->olmSessionBuffer.data());
@@ -115,5 +112,59 @@
   return session;
 }
 
+// Decrypts `encryptedData` with this olm session and returns the plaintext.
+// olm destroys the ciphertext buffers it is given, so both olm calls operate
+// on private copies and the caller's message is never modified. Throws
+// std::runtime_error with olm's error and the message's olm_sha256 hash.
+std::string Session::decrypt(const EncryptedData &encryptedData) {
+  OlmSession *session = this->getOlmSession();
+
+  OlmBuffer utilityBuffer(::olm_utility_size());
+  OlmUtility *olmUtility = ::olm_utility(utilityBuffer.data());
+
+  OlmBuffer messageHashBuffer(::olm_sha256_length(olmUtility));
+  ::olm_sha256(
+      olmUtility,
+      encryptedData.message.data(),
+      encryptedData.message.size(),
+      messageHashBuffer.data(),
+      messageHashBuffer.size());
+
+  // olm_decrypt_max_plaintext_length() destroys its input, hence the copy.
+  OlmBuffer tmpEncryptedMessage(encryptedData.message);
+  size_t maxSize = ::olm_decrypt_max_plaintext_length(
+      session,
+      encryptedData.messageType,
+      tmpEncryptedMessage.data(),
+      tmpEncryptedMessage.size());
+
+  if (maxSize == ::olm_error()) {
+    throw std::runtime_error{
+        "error decrypt_max_plaintext_length => " +
+        std::string{::olm_session_last_error(session)} + ". Hash: " +
+        std::string{messageHashBuffer.begin(), messageHashBuffer.end()}};
+  }
+
+  OlmBuffer decryptedMessage(maxSize);
+  // olm_decrypt() destroys its input too; refill the scratch copy rather
+  // than handing olm the caller's message.
+  tmpEncryptedMessage = encryptedData.message;
+  size_t decryptedSize = ::olm_decrypt(
+      session,
+      encryptedData.messageType,
+      tmpEncryptedMessage.data(),
+      tmpEncryptedMessage.size(),
+      decryptedMessage.data(),
+      decryptedMessage.size());
+  if (decryptedSize == ::olm_error()) {
+    throw std::runtime_error{
+        "error decrypt => " + std::string{::olm_session_last_error(session)} +
+        ". Hash: " +
+        std::string{messageHashBuffer.begin(), messageHashBuffer.end()}};
+  }
+  return std::string{
+      reinterpret_cast<char *>(decryptedMessage.data()), decryptedSize};
+}
+
 } // namespace crypto
 } // namespace comm