diff --git a/packages/lib/utils/ipc/RemoteMessenger.test.ts b/packages/lib/utils/ipc/RemoteMessenger.test.ts index cf81ac808..5b558f71d 100644 --- a/packages/lib/utils/ipc/RemoteMessenger.test.ts +++ b/packages/lib/utils/ipc/RemoteMessenger.test.ts @@ -147,4 +147,34 @@ describe('RemoteMessenger', () => { expect(await remoteApi.subObject.multiplyRounded(1.1, 2)).toBe(2); expect(await remoteApi.subObject.multiplyRounded.call(remoteApi.subObject, 3.1, 4.2)).toBe(12); }); + + it('should delete callbacks when dropped remotely', async () => { + const testApi = { + test: jest.fn(), + }; + + type ApiType = typeof testApi; + const messenger1 = new TestMessenger('testid', testApi); + const messenger2 = new TestMessenger('testid', testApi); + + messenger1.connectTo(messenger2); + + const callback = async () => {}; + messenger1.remoteApi.test(callback); + + // Callbacks should be stored with the source messenger + const callbackId = messenger1.getIdForCallback_(callback); + expect(callbackId).toBeTruthy(); + expect(messenger2.getIdForCallback_(callback)).toBe(undefined); + + // Dropping a callback at the remote messenger should clear the + // callback on the original messenger + messenger2.mockCallbackDropped(callbackId); + + // To avoid random test failure, wait for a round-tip before checking + // whether the callback is still registered. + await messenger1.remoteApi.test(async ()=>{}); + + expect(messenger1.getIdForCallback_(callback)).toBe(undefined); + }); }); diff --git a/packages/lib/utils/ipc/RemoteMessenger.ts b/packages/lib/utils/ipc/RemoteMessenger.ts index eaa376768..627af1403 100644 --- a/packages/lib/utils/ipc/RemoteMessenger.ts +++ b/packages/lib/utils/ipc/RemoteMessenger.ts @@ -9,6 +9,7 @@ enum MessageType { ErrorResponse = 'ErrorResponse', ReturnValueResponse = 'ReturnValueResponse', CloseChannel = 'CloseChannel', + OnCallbackDropped = 'OnCallbackDropped', } type RemoteReadyMessage = Readonly<{ @@ -55,11 +56,16 @@ type CloseChannelMessage = Readonly<{ kind: MessageType.CloseChannel; }>; +type CallbackDroppedMessage = Readonly<{ + kind: MessageType.OnCallbackDropped; + callbackIds: string[]; +}>; + type BaseMessage = Readonly<{ channelId: string; }>; -type InternalMessage = (RemoteReadyMessage|CloseChannelMessage|InvokeMethodMessage|ErrorResponse|ReturnValueResponse) & BaseMessage; +type InternalMessage = (RemoteReadyMessage|CloseChannelMessage|InvokeMethodMessage|ErrorResponse|ReturnValueResponse|CallbackDroppedMessage) & BaseMessage; // Listeners for a remote method to resolve or reject. type OnMethodResolveListener = (returnValue: SerializableDataAndCallbacks)=> void; @@ -68,6 +74,13 @@ type OnRemoteReadyListener = ()=> void; type OnAllMethodsRespondedToListener = ()=> void; +// TODO: Remove after upgrading nodejs/browser types sufficiently +// (FinalizationRegistry is supported in modern browsers). +declare class FinalizationRegistry { + public constructor(onDrop: any); + public register(v: any, id: string): void; +} + // A thin wrapper around postMessage. A script within `targetWindow` should // also construct a RemoteMessenger (with IncomingMessageType and // OutgoingMessageType reversed). @@ -75,6 +88,7 @@ export default abstract class RemoteMessenger { private resolveMethodCallbacks: Record = Object.create(null); private rejectMethodCallbacks: Record = Object.create(null); private argumentCallbacks: Map = new Map(); + private callbackTracker: FinalizationRegistry|undefined = undefined; private numberUnrespondedToMethods = 0; private noWaitingMethodsListeners: OnAllMethodsRespondedToListener[] = []; @@ -121,6 +135,14 @@ export default abstract class RemoteMessenger { }); }; this.remoteApi = makeApiFor([]) as RemoteInterface; + + if (typeof FinalizationRegistry !== 'undefined') { + // Creating a FinalizationRegistry allows us to track **local** deletions of callbacks. + // We can then inform the remote so that it can free the corresponding remote callback. + this.callbackTracker = new FinalizationRegistry((callbackId: string) => { + this.dropRemoteCallback_(callbackId); + }); + } } private createResponseId(methodPath: string[]) { @@ -133,6 +155,26 @@ export default abstract class RemoteMessenger { } } + private lastCallbackDropTime_ = 0; + private bufferedDroppedCallbackIds_: string[] = []; + // protected: For testing + protected dropRemoteCallback_(callbackId: string) { + this.bufferedDroppedCallbackIds_.push(callbackId); + if (!this.isRemoteReady) return; + // Don't send too many messages. On mobile platforms, each + // message has overhead and .dropRemoteCallback is called + // frequently. + if (Date.now() - this.lastCallbackDropTime_ < 10000) return; + + this.postMessage({ + kind: MessageType.OnCallbackDropped, + callbackIds: this.bufferedDroppedCallbackIds_, + channelId: this.channelId, + }); + this.bufferedDroppedCallbackIds_ = []; + this.lastCallbackDropTime_ = Date.now(); + } + private async invokeRemoteMethod(methodPath: string[], args: SerializableDataAndCallbacks[]) { // Function arguments can't be transferred using standard .postMessage calls. // As such, we assign them IDs and transfer the IDs instead: @@ -191,6 +233,10 @@ export default abstract class RemoteMessenger { return this.invokeRemoteMethod(['__callbacks', callbackId], callbackArgs); }; + private trackCallbackFinalization = (callbackId: string, callback: any) => { + this.callbackTracker?.register(callback, callbackId); + }; + // Calls a local method and sends the result to the remote connection. private async invokeLocalMethod(message: InvokeMethodMessage) { try { @@ -239,6 +285,7 @@ export default abstract class RemoteMessenger { message.arguments.serializable, message.arguments.callbacks, this.onInvokeCallback, + this.trackCallbackFinalization, ); let result; @@ -336,6 +383,7 @@ export default abstract class RemoteMessenger { message.returnValue.serializable, message.returnValue.callbacks, this.onInvokeCallback, + this.trackCallbackFinalization, ); this.resolveMethodCallbacks[message.responseId](returnValue); @@ -347,6 +395,12 @@ export default abstract class RemoteMessenger { this.onMethodRespondedTo(message.responseId); } + private async onRemoteCallbackDropped(message: CallbackDroppedMessage) { + for (const id of message.callbackIds) { + this.argumentCallbacks.delete(id); + } + } + private async onRemoteReadyToReceive(message: RemoteReadyMessage) { if (this.isRemoteReady && !message.requiresResponse) { return; @@ -431,6 +485,8 @@ export default abstract class RemoteMessenger { await this.onRemoteReject(asInternalMessage); } else if (asInternalMessage.kind === MessageType.RemoteReady) { await this.onRemoteReadyToReceive(asInternalMessage); + } else if (asInternalMessage.kind === MessageType.OnCallbackDropped) { + await this.onRemoteCallbackDropped(asInternalMessage); } else { // Have TypeScript verify that the above cases are exhaustive const exhaustivenessCheck: never = asInternalMessage; @@ -494,4 +550,15 @@ export default abstract class RemoteMessenger { protected abstract postMessage(message: InternalMessage): void; protected abstract onClose(): void; + + + // For testing + public getIdForCallback_(callback: TransferableCallback) { + for (const [id, otherCallback] of this.argumentCallbacks) { + if (otherCallback === callback) { + return id; + } + } + return undefined; + } } diff --git a/packages/lib/utils/ipc/TestMessenger.ts b/packages/lib/utils/ipc/TestMessenger.ts index 070d817bc..6703bd865 100644 --- a/packages/lib/utils/ipc/TestMessenger.ts +++ b/packages/lib/utils/ipc/TestMessenger.ts @@ -33,4 +33,11 @@ export default class TestMessenger extends Remo protected override onClose(): void { this.remoteMessenger = null; } + + + // Test utility methods + // + public mockCallbackDropped(callbackId: string) { + this.dropRemoteCallback_(callbackId); + } } diff --git a/packages/lib/utils/ipc/utils/mergeCallbacksAndSerializable.test.ts b/packages/lib/utils/ipc/utils/mergeCallbacksAndSerializable.test.ts index 651ab0fad..1103bf3b8 100644 --- a/packages/lib/utils/ipc/utils/mergeCallbacksAndSerializable.test.ts +++ b/packages/lib/utils/ipc/utils/mergeCallbacksAndSerializable.test.ts @@ -27,7 +27,7 @@ describe('mergeCallbacksAndSerializable', () => { }; const callMethodWithId = jest.fn(); - const merged: any = mergeCallbacksAndSerializable(data, callbacks, callMethodWithId); + const merged: any = mergeCallbacksAndSerializable(data, callbacks, callMethodWithId, ()=>{}); // Should have created functions merged.foo.fn1(3, 4); diff --git a/packages/lib/utils/ipc/utils/mergeCallbacksAndSerializable.ts b/packages/lib/utils/ipc/utils/mergeCallbacksAndSerializable.ts index 33d790512..bfdad4e79 100644 --- a/packages/lib/utils/ipc/utils/mergeCallbacksAndSerializable.ts +++ b/packages/lib/utils/ipc/utils/mergeCallbacksAndSerializable.ts @@ -2,6 +2,9 @@ import { CallbackIds, SerializableData, SerializableDataAndCallbacks } from '../ type CallMethodWithIdCallback = (id: string, args: SerializableDataAndCallbacks[])=> Promise; +// Intended to be used to track callbacks for garbage collection +type OnAfterCallbackCreated = (callbackId: string, callbackRef: ()=> any)=> void; + // Below, we use TypeScript syntax to specify the return type of mergeCallbacksAndSerializable // based on the type of its arguments. // @@ -10,25 +13,37 @@ type CallMethodWithIdCallback = (id: string, args: SerializableDataAndCallbacks[ // // eslint-disable-next-line no-redeclare function mergeCallbacksAndSerializable( - serializable: SerializableData[], callbacks: CallbackIds[], callMethodWithId: CallMethodWithIdCallback, + serializable: SerializableData[], + callbacks: CallbackIds[], + callMethodWithId: CallMethodWithIdCallback, + afterCallbackCreated: OnAfterCallbackCreated, ): SerializableDataAndCallbacks[]; // eslint-disable-next-line no-redeclare function mergeCallbacksAndSerializable( - serializable: SerializableData, callbacks: CallbackIds, callMethodWithId: CallMethodWithIdCallback, + serializable: SerializableData, + callbacks: CallbackIds, + callMethodWithId: CallMethodWithIdCallback, + afterCallbackCreated: OnAfterCallbackCreated, ): SerializableDataAndCallbacks; // eslint-disable-next-line no-redeclare function mergeCallbacksAndSerializable( - serializable: SerializableData|SerializableData[], callbacks: CallbackIds|CallbackIds[], callMethodWithId: CallMethodWithIdCallback, + serializable: SerializableData|SerializableData[], + callbacks: CallbackIds|CallbackIds[], + callMethodWithId: CallMethodWithIdCallback, + afterCallbackCreated: OnAfterCallbackCreated, ): SerializableDataAndCallbacks|SerializableDataAndCallbacks[] { const mergeCallbackAndSerializable = (serializableObj: SerializableData, callbackObj: CallbackIds): SerializableDataAndCallbacks => { if (typeof callbackObj === 'string') { const callbackId = callbackObj; - return (...args: SerializableDataAndCallbacks[]) => { + const callback = (...args: SerializableDataAndCallbacks[]) => { return callMethodWithId(callbackId, args); }; + afterCallbackCreated(callbackId, callback); + + return callback; } else if (typeof serializableObj === 'object' && serializableObj !== null) { // typeof(null) is object if (typeof callbackObj !== 'object') { throw new Error('Callback arguments should be an object (and thus match the type of serializableArgs)');