Skip to content

Commit dc1cccf

Browse files
authored
Handle late-arriving m.room_key.withheld messages (#4310)
* Restructure eventsPendingKey to remove sender key For withheld notices, we don't necessarily receive the sender key, so we'll jhave to do without it. * Re-decrypt events when we receive a withheld notice * Extend test to cover late-arriving withheld notices * update unit tests
1 parent d32f398 commit dc1cccf

File tree

4 files changed

+60
-27
lines changed

4 files changed

+60
-27
lines changed

spec/integ/crypto/crypto.spec.ts

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2343,13 +2343,12 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("crypto (%s)", (backend: string,
23432343
])(
23442344
"Decryption fails with withheld error if a withheld notice with code '%s' is received",
23452345
(withheldCode, expectedMessage, expectedErrorCode) => {
2346-
// TODO: test arrival after the event too.
2347-
it.each(["before"])("%s the event", async (when) => {
2346+
it.each(["before", "after"])("%s the event", async (when) => {
23482347
expectAliceKeyQuery({ device_keys: { "@alice:localhost": {} }, failures: {} });
23492348
await startClientAndAwaitFirstSync();
23502349

23512350
// A promise which resolves, with the MatrixEvent which wraps the event, once the decryption fails.
2352-
const awaitDecryption = emitPromise(aliceClient, MatrixEventEvent.Decrypted);
2351+
let awaitDecryption = emitPromise(aliceClient, MatrixEventEvent.Decrypted);
23532352

23542353
// Send Alice an encrypted room event which looks like it was encrypted with a megolm session
23552354
async function sendEncryptedEvent() {
@@ -2393,6 +2392,9 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("crypto (%s)", (backend: string,
23932392
await sendEncryptedEvent();
23942393
} else {
23952394
await sendEncryptedEvent();
2395+
// Make sure that the first attempt to decrypt has happened before the withheld arrives
2396+
await awaitDecryption;
2397+
awaitDecryption = emitPromise(aliceClient, MatrixEventEvent.Decrypted);
23962398
await sendWithheldMessage();
23972399
}
23982400

spec/unit/rust-crypto/rust-crypto.spec.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ describe("initRustCrypto", () => {
9595
deleteSecretsFromInbox: jest.fn(),
9696
registerReceiveSecretCallback: jest.fn(),
9797
registerDevicesUpdatedCallback: jest.fn(),
98+
registerRoomKeysWithheldCallback: jest.fn(),
9899
outgoingRequests: jest.fn(),
99100
isBackupEnabled: jest.fn().mockResolvedValue(false),
100101
verifyBackup: jest.fn().mockResolvedValue({ trusted: jest.fn().mockReturnValue(false) }),

src/rust-crypto/index.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,9 @@ async function initOlmMachine(
174174
await olmMachine.registerRoomKeyUpdatedCallback((sessions: RustSdkCryptoJs.RoomKeyInfo[]) =>
175175
rustCrypto.onRoomKeysUpdated(sessions),
176176
);
177+
await olmMachine.registerRoomKeysWithheldCallback((withheld: RustSdkCryptoJs.RoomKeyWithheldInfo[]) =>
178+
rustCrypto.onRoomKeysWithheld(withheld),
179+
);
177180
await olmMachine.registerUserIdentityUpdatedCallback((userId: RustSdkCryptoJs.UserId) =>
178181
rustCrypto.onUserIdentityUpdated(userId),
179182
);

src/rust-crypto/rust-crypto.ts

Lines changed: 51 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1486,7 +1486,7 @@ export class RustCrypto extends TypedEventEmitter<RustCryptoEvents, RustCryptoEv
14861486
this.logger.debug(
14871487
`Got update for session ${key.senderKey.toBase64()}|${key.sessionId} in ${key.roomId.toString()}`,
14881488
);
1489-
const pendingList = this.eventDecryptor.getEventsPendingRoomKey(key);
1489+
const pendingList = this.eventDecryptor.getEventsPendingRoomKey(key.roomId.toString(), key.sessionId);
14901490
if (pendingList.length === 0) return;
14911491

14921492
this.logger.debug(
@@ -1507,6 +1507,37 @@ export class RustCrypto extends TypedEventEmitter<RustCryptoEvents, RustCryptoEv
15071507
}
15081508
}
15091509

1510+
/**
1511+
* Callback for `OlmMachine.registerRoomKeyWithheldCallback`.
1512+
*
1513+
* Called by the rust sdk whenever we are told that a key has been withheld. We see if we had any events that
1514+
* failed to decrypt for the given session, and update their status if so.
1515+
*
1516+
* @param withheld - Details of the withheld sessions.
1517+
*/
1518+
public async onRoomKeysWithheld(withheld: RustSdkCryptoJs.RoomKeyWithheldInfo[]): Promise<void> {
1519+
for (const session of withheld) {
1520+
this.logger.debug(`Got withheld message for session ${session.sessionId} in ${session.roomId.toString()}`);
1521+
const pendingList = this.eventDecryptor.getEventsPendingRoomKey(
1522+
session.roomId.toString(),
1523+
session.sessionId,
1524+
);
1525+
if (pendingList.length === 0) return;
1526+
1527+
// The easiest way to update the status of the event is to have another go at decrypting it.
1528+
this.logger.debug(
1529+
"Retrying decryption on events:",
1530+
pendingList.map((e) => `${e.getId()}`),
1531+
);
1532+
1533+
for (const ev of pendingList) {
1534+
ev.attemptDecryption(this, { isRetry: true }).catch((_e) => {
1535+
// It's somewhat expected that we still can't decrypt here.
1536+
});
1537+
}
1538+
}
1539+
}
1540+
15101541
/**
15111542
* Callback for `OlmMachine.registerUserIdentityUpdatedCallback`
15121543
*
@@ -1683,7 +1714,7 @@ class EventDecryptor {
16831714
/**
16841715
* Events which we couldn't decrypt due to unknown sessions / indexes.
16851716
*
1686-
* Map from senderKey to sessionId to Set of MatrixEvents
1717+
* Map from roomId to sessionId to Set of MatrixEvents
16871718
*/
16881719
private eventsPendingKey = new MapWithDefault<string, MapWithDefault<string, Set<MatrixEvent>>>(
16891720
() => new MapWithDefault<string, Set<MatrixEvent>>(() => new Set()),
@@ -1843,54 +1874,50 @@ class EventDecryptor {
18431874
* Look for events which are waiting for a given megolm session
18441875
*
18451876
* Returns a list of events which were encrypted by `session` and could not be decrypted
1846-
*
1847-
* @param session -
18481877
*/
1849-
public getEventsPendingRoomKey(session: RustSdkCryptoJs.RoomKeyInfo): MatrixEvent[] {
1850-
const senderPendingEvents = this.eventsPendingKey.get(session.senderKey.toBase64());
1851-
if (!senderPendingEvents) return [];
1878+
public getEventsPendingRoomKey(roomId: string, sessionId: string): MatrixEvent[] {
1879+
const roomPendingEvents = this.eventsPendingKey.get(roomId);
1880+
if (!roomPendingEvents) return [];
18521881

1853-
const sessionPendingEvents = senderPendingEvents.get(session.sessionId);
1882+
const sessionPendingEvents = roomPendingEvents.get(sessionId);
18541883
if (!sessionPendingEvents) return [];
18551884

1856-
const roomId = session.roomId.toString();
1857-
return [...sessionPendingEvents].filter((ev) => ev.getRoomId() === roomId);
1885+
return [...sessionPendingEvents];
18581886
}
18591887

18601888
/**
18611889
* Add an event to the list of those awaiting their session keys.
18621890
*/
18631891
private addEventToPendingList(event: MatrixEvent): void {
1864-
const content = event.getWireContent();
1865-
const senderKey = content.sender_key;
1866-
const sessionId = content.session_id;
1892+
const roomId = event.getRoomId();
1893+
// We shouldn't have events without a room id here.
1894+
if (!roomId) return;
18671895

1868-
const senderPendingEvents = this.eventsPendingKey.getOrCreate(senderKey);
1869-
const sessionPendingEvents = senderPendingEvents.getOrCreate(sessionId);
1896+
const roomPendingEvents = this.eventsPendingKey.getOrCreate(roomId);
1897+
const sessionPendingEvents = roomPendingEvents.getOrCreate(event.getWireContent().session_id);
18701898
sessionPendingEvents.add(event);
18711899
}
18721900

18731901
/**
18741902
* Remove an event from the list of those awaiting their session keys.
18751903
*/
18761904
private removeEventFromPendingList(event: MatrixEvent): void {
1877-
const content = event.getWireContent();
1878-
const senderKey = content.sender_key;
1879-
const sessionId = content.session_id;
1905+
const roomId = event.getRoomId();
1906+
if (!roomId) return;
18801907

1881-
const senderPendingEvents = this.eventsPendingKey.get(senderKey);
1882-
if (!senderPendingEvents) return;
1908+
const roomPendingEvents = this.eventsPendingKey.getOrCreate(roomId);
1909+
if (!roomPendingEvents) return;
18831910

1884-
const sessionPendingEvents = senderPendingEvents.get(sessionId);
1911+
const sessionPendingEvents = roomPendingEvents.get(event.getWireContent().session_id);
18851912
if (!sessionPendingEvents) return;
18861913

18871914
sessionPendingEvents.delete(event);
18881915

18891916
// also clean up the higher-level maps if they are now empty
18901917
if (sessionPendingEvents.size === 0) {
1891-
senderPendingEvents.delete(sessionId);
1892-
if (senderPendingEvents.size === 0) {
1893-
this.eventsPendingKey.delete(senderKey);
1918+
roomPendingEvents.delete(event.getWireContent().session_id);
1919+
if (roomPendingEvents.size === 0) {
1920+
this.eventsPendingKey.delete(roomId);
18941921
}
18951922
}
18961923
}

0 commit comments

Comments
 (0)