diff --git a/CODEOWNERS b/CODEOWNERS index c67e8f021b..c1b2870c21 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -93,3 +93,5 @@ /packages/schema-utils @martykan /packages/connect-explorer-theme @martykan + +/packages/websocket-client @mroz22 @marekrjpolak diff --git a/packages/websocket-client/README.md b/packages/websocket-client/README.md new file mode 100644 index 0000000000..72cc5da59e --- /dev/null +++ b/packages/websocket-client/README.md @@ -0,0 +1,6 @@ +# @trezor/websocket-client + +[![NPM](https://img.shields.io/npm/v/@trezor/websocket-client.svg)](https://www.npmjs.org/package/@trezor/websocket-client) +[![Known Vulnerabilities](https://snyk.io/test/github/trezor/trezor-suite/badge.svg?targetFile=packages/websocket-client/package.json)](https://snyk.io/test/github/trezor/trezor-suite/badge.svg?targetFile=packages/websocket-client/package.json) + +Shared websocket client implementation diff --git a/packages/websocket-client/package.json b/packages/websocket-client/package.json new file mode 100644 index 0000000000..30f950dbf5 --- /dev/null +++ b/packages/websocket-client/package.json @@ -0,0 +1,56 @@ +{ + "name": "@trezor/websocket-client", + "version": "1.0.0", + "author": "Trezor ", + "homepage": "https://github.com/trezor/trezor-suite/tree/develop/packages/websocket", + "description": "Shared websocket client implementation", + "npmPublishAccess": "public", + "license": "SEE LICENSE IN LICENSE.md", + "repository": { + "type": "git", + "url": "git://github.com/trezor/trezor-suite.git" + }, + "bugs": { + "url": "https://github.com/trezor/trezor-suite/issues" + }, + "sideEffects": false, + "main": "src/index", + "browser": { + "ws": "./src/ws-browser" + }, + "react-native": { + "__comment__": "Hotfix for issue where RN metro bundler resolve relatives paths wrong", + "ws": "@trezor/websocket-client/src/ws-native.ts" + }, + "publishConfig": { + "main": "./lib/index.js", + "types": "lib/index.d.ts", + "typings": "lib/index.d.ts", + "browser": { + "ws": "./lib/ws-browser.js" + }, + "react-native": { + "__comment__": "Hotfix for issue where RN metro bundler resolve relatives paths wrong", + "ws": "@trezor/websocket-client/lib/ws-native.js" + } + }, + "files": [ + "lib/", + "!**/*.map" + ], + "scripts": { + "depcheck": "yarn g:depcheck", + "test:unit": "jest -c ../../jest.config.base.js", + "type-check": "yarn g:tsc --build", + "build:lib": "yarn g:rimraf lib && yarn g:tsc --build tsconfig.lib.json && ../../scripts/replace-imports.sh ./lib", + "prepublishOnly": "yarn tsx ../../scripts/prepublishNPM.js", + "prepublish": "yarn tsx ../../scripts/prepublish.js" + }, + "dependencies": { + "@trezor/utils": "workspace:*", + "ws": "^8.18.0" + }, + "peerDependencies": { + "tslib": "^2.6.2" + } +} diff --git a/packages/websocket-client/src/client.ts b/packages/websocket-client/src/client.ts new file mode 100644 index 0000000000..199484142a --- /dev/null +++ b/packages/websocket-client/src/client.ts @@ -0,0 +1,244 @@ +import WebSocket from 'ws'; + +import { TypedEmitter, createDeferred, createDeferredManager } from '@trezor/utils'; + +type WebsocketOptions = { + url: string; + timeout?: number; + agent?: WebSocket.ClientOptions['agent']; + headers?: WebSocket.ClientOptions['headers']; +}; + +type Options = WebsocketOptions & { + pingTimeout?: number; + connectionTimeout?: number; + keepAlive?: boolean; + onSending?: (message: Record) => void; +}; + +const DEFAULT_TIMEOUT = 20 * 1000; +const DEFAULT_PING_TIMEOUT = 50 * 1000; + +type WebsocketClientEvents = { + error: string; + disconnected: undefined; +}; + +export type WebsocketRequest = Record; +export type WebsocketResponse = WebSocket.Data; + +export abstract class WebsocketClient> extends TypedEmitter< + Events & WebsocketClientEvents +> { + readonly options: Options; + + public readonly messages; + private readonly emitter: TypedEmitter = this; + + private ws?: WebSocket; + private pingTimeout?: ReturnType; + private connectPromise?: Promise; + + protected abstract createWebsocket(): WebSocket; + protected abstract ping(): Promise; + + constructor(options: Options) { + super(); + this.options = options; + this.messages = createDeferredManager({ + timeout: this.options.timeout || DEFAULT_TIMEOUT, + onTimeout: this.onTimeout.bind(this), + }); + } + + protected initWebsocket({ url, timeout, headers, agent }: WebsocketOptions) { + // url validation + if (typeof url !== 'string') { + throw new Error('websocket_no_url'); + } + if (url.startsWith('http')) { + url = url.replace('http', 'ws'); + } + + return new WebSocket(url, { timeout, headers, agent }); + } + + private setPingTimeout() { + clearTimeout(this.pingTimeout); + + const doPing = () => { + if (this.isConnected()) { + return this.onPing().catch(() => {}); + } + }; + + this.pingTimeout = this.isConnected() + ? setTimeout(doPing, this.options.pingTimeout || DEFAULT_PING_TIMEOUT) + : undefined; + } + + protected onPing() { + return this.ping(); + } + + private onTimeout() { + const { ws } = this; + if (!ws) return; + this.messages.rejectAll(new Error('websocket_timeout')); + ws.close(); + } + + private onError() { + this.onClose(); + } + + protected sendMessage(message: WebsocketRequest) { + const { ws } = this; + if (!ws || !this.isConnected()) throw new Error('websocket_not_initialized'); + const { promiseId, promise } = this.messages.create(); + + const req = { id: promiseId.toString(), ...message }; + + this.setPingTimeout(); + + this.options.onSending?.(message); + + ws.send(JSON.stringify(req)); + + return promise; + } + + protected sendRawMessage(message: WebSocket.Data) { + const { ws } = this; + if (!ws || !this.isConnected()) throw new Error('websocket_not_initialized'); + + ws.send(message, { + binary: typeof message !== 'string', + }); + + this.setPingTimeout(); + } + + // TODO: data type generic + // `messageValidation` - additionally validates `data` in the subclass + // returns `payload` or throws error to automatically resolve/reject pending message + // returns `undefined` to resolve pending message manually in the subclass + protected onMessage( + message: WebsocketResponse, + messageValidation?: (data: Record) => Record | void, + ) { + try { + const data = JSON.parse(message.toString()); + const messageId = Number(data.id); + try { + const payload = messageValidation ? messageValidation(data) : data; + if (payload) { + this.messages.resolve(messageId, payload); + } + } catch (error) { + this.messages.reject(messageId, error); + } + } catch { + // empty + } finally { + this.setPingTimeout(); + } + } + + async connect() { + // if connecting already, just return the promise + if (this.connectPromise) { + return this.connectPromise; + } + + if (this.isConnected()) return Promise.resolve(); + + if (this.ws?.readyState === WebSocket.CLOSING) { + await new Promise(resolve => this.emitter.once('disconnected', resolve)); + } + + // create deferred promise + const dfd = createDeferred(); + this.connectPromise = dfd.promise; + + const ws = this.createWebsocket(); + + // set connection timeout before WebSocket initialization + const connectionTimeout = setTimeout( + () => { + ws.emit('error', new Error('websocket_timeout')); + try { + ws.once('error', () => {}); // hack; ws throws uncaughtably when there's no error listener + ws.close(); + } catch { + // empty + } + }, + this.options.connectionTimeout || this.options.timeout || DEFAULT_TIMEOUT, + ); + + ws.once('error', error => { + clearTimeout(connectionTimeout); + this.onClose(); + dfd.reject(new Error(error.message)); + }); + ws.on('open', () => { + clearTimeout(connectionTimeout); + this.init(); + dfd.resolve(); + }); + + this.ws = ws; + + // wait for onopen event + return dfd.promise.finally(() => { + this.connectPromise = undefined; + }); + } + + private init() { + const { ws } = this; + if (!ws || !this.isConnected()) { + throw Error('Websocket init cannot be called'); + } + + // remove previous listeners and add new listeners + ws.removeAllListeners(); + ws.on('error', _error => this.onError()); + ws.on('message', message => this.onMessage(message)); + ws.on('close', () => { + this.emitter.emit('disconnected'); + this.onClose(); + }); + } + + disconnect() { + if (this.isConnected()) { + const disconnectPromise = new Promise(resolve => { + this.ws?.once('close', resolve); + }); + this.ws?.close(); + + return disconnectPromise; + } + + return Promise.resolve(); + } + + isConnected() { + return this.ws?.readyState === WebSocket.OPEN; + } + + private onClose() { + clearTimeout(this.pingTimeout); + + this.ws?.removeAllListeners(); + this.messages.rejectAll(new Error('Websocket closed unexpectedly')); + } + + dispose() { + this.removeAllListeners(); + this.disconnect(); + this.onClose(); + } +} diff --git a/packages/websocket-client/src/index.ts b/packages/websocket-client/src/index.ts new file mode 100644 index 0000000000..d7240b7a39 --- /dev/null +++ b/packages/websocket-client/src/index.ts @@ -0,0 +1 @@ +export { WebsocketClient, type WebsocketRequest, type WebsocketResponse } from './client'; diff --git a/packages/websocket-client/src/ws-browser.ts b/packages/websocket-client/src/ws-browser.ts new file mode 100644 index 0000000000..e4990b1740 --- /dev/null +++ b/packages/websocket-client/src/ws-browser.ts @@ -0,0 +1,61 @@ +import { EventEmitter } from 'events'; + +/** + * Provides `EventEmitter` interface for native browser `WebSocket`, + * same, as `ws` package provides. + */ +class WSWrapper extends EventEmitter { + private _ws: WebSocket; + static CONNECTING = 0; + static OPEN = 1; + static CLOSING = 2; + static CLOSED = 3; + + constructor(url: string, _protocols: any, _websocketOptions: any) { + super(); + + this._ws = new WebSocket(url); + + this._ws.onclose = () => { + this.emit('close'); + }; + + this._ws.onopen = () => { + this.emit('open'); + }; + + // WebSocket error Event does not contain any useful description. + // https://websockets.spec.whatwg.org//#dom-websocket-onerror + // If the user agent was required to fail the WebSocket connection, + // or if the WebSocket connection was closed after being flagged as full, + // fire an event named error at the WebSocket object. + // https://stackoverflow.com/a/31003057 + this._ws.onerror = _event => { + this.emit('error', new Error(`WsWrapper error. Ready state: ${this.readyState}`)); + }; + + this._ws.onmessage = message => { + this.emit('message', message.data); + }; + } + + close() { + if (this.readyState === WSWrapper.OPEN) { + this._ws.close(); + } + } + + send(message: any) { + if (this.readyState !== WSWrapper.OPEN) { + throw new Error(`Connection is not open. state: ${this.readyState}`); + } + this._ws.send(message); + } + + get readyState() { + return this._ws.readyState; + } +} + +// eslint-disable-next-line import/no-default-export +export default WSWrapper; diff --git a/packages/websocket-client/src/ws-native.ts b/packages/websocket-client/src/ws-native.ts new file mode 100644 index 0000000000..f168545666 --- /dev/null +++ b/packages/websocket-client/src/ws-native.ts @@ -0,0 +1,67 @@ +import { EventEmitter } from 'events'; + +/** + * Provides `EventEmitter` interface for React Native `WebSocket`, + * same, as `ws` package provides. + */ +class WSWrapper extends EventEmitter { + private _ws: WebSocket; + static CONNECTING = 0; + static OPEN = 1; + static CLOSING = 2; + static CLOSED = 3; + + constructor(url: string, _protocols: any, _websocketOptions: any) { + super(); + + // React Native WebSocket is able to accept headers compared to the native browser `WebSocket`. + // @ts-expect-error + this._ws = new WebSocket(url, ['wss'], { + headers: { + 'User-Agent': 'Trezor Suite Native', + }, + }); + + this._ws.onclose = () => { + this.emit('close'); + }; + + this._ws.onopen = () => { + this.emit('open'); + }; + + // WebSocket error Event does not contain any useful description. + // https://websockets.spec.whatwg.org//#dom-websocket-onerror + // If the user agent was required to fail the WebSocket connection, + // or if the WebSocket connection was closed after being flagged as full, + // fire an event named error at the WebSocket object. + // https://stackoverflow.com/a/31003057 + this._ws.onerror = _event => { + this.emit('error', new Error(`WsWrapper error. Ready state: ${this.readyState}`)); + }; + + this._ws.onmessage = message => { + this.emit('message', message.data); + }; + } + + close() { + if (this.readyState === WSWrapper.OPEN) { + this._ws.close(); + } + } + + send(message: any) { + if (this.readyState !== WSWrapper.OPEN) { + throw new Error(`Connection is not open. state: ${this.readyState}`); + } + this._ws.send(message); + } + + get readyState() { + return this._ws.readyState; + } +} + +// eslint-disable-next-line import/no-default-export +export default WSWrapper; diff --git a/packages/websocket-client/tests/client.test.ts b/packages/websocket-client/tests/client.test.ts new file mode 100644 index 0000000000..ed0e46485d --- /dev/null +++ b/packages/websocket-client/tests/client.test.ts @@ -0,0 +1,164 @@ +import { ServerOptions, WebSocket } from 'ws'; + +import { WebsocketClient } from '../src/client'; + +class Client extends WebsocketClient<{ 'foo-event': 'bar-event' }> { + createWebsocket() { + return this.initWebsocket(this.options); + } + ping() { + return this.sendMessage({ method: 'ping' }); + } + sendMessage(message: Record) { + return super.sendMessage(message); + } +} + +class Server extends WebSocket.Server { + private _url: string; + fixtures?: any[]; + + constructor(options: ServerOptions, callback?: () => void) { + super(options, callback); + + this._url = `ws://localhost:${options.port}`; + this.on('connection', ws => { + ws.on('message', data => this.sendResponse(ws, data)); + }); + } + + public getUrl() { + return this._url; + } + + private sendResponse(client: WebSocket, data: any) { + const request = JSON.parse(data); + const { id, method } = request; + let response; + + if (method === 'init') { + response = { success: true }; + } + + if (method === 'ping') { + response = { success: true }; + } + + if (!response) { + response = { + success: false, + error: { message: `unknown response for method ${method}` }, + }; + } + + client.send(JSON.stringify({ ...response, id })); + } +} + +const createServer = async () => { + const port = 12345; + const server = new Server({ port }); + await new Promise((resolve, reject) => { + server.once('listening', () => resolve()); + server.once('error', error => reject(error)); + }); + + return { server, url: `ws://localhost:${port}` }; +}; + +describe('WebsocketClient', () => { + let server: Server; + beforeAll(async () => { + const r = await createServer(); + server = r.server; + }); + + afterAll(() => { + server.close(); + }); + + it('success', async () => { + const cli = new Client({ url: server.getUrl(), pingTimeout: 500 }); + await cli.connect(); + + // types check: + cli.on('foo-event', event => { + if (event === 'bar-event') { + // + } + }); + + const resp = await cli.sendMessage({ method: 'init' }); + expect(resp.success).toEqual(true); + + await cli.disconnect(); + }); + + it('ping', async () => { + jest.useFakeTimers(); + + const cli = new Client({ url: server.getUrl(), pingTimeout: 5000 }); + const pingSpy = jest.spyOn(cli, 'ping'); + await cli.connect(); + + // call first messages to init ping + const resp = await cli.sendMessage({ method: 'init' }); + expect(resp.success).toEqual(true); + // wait for ping + await jest.advanceTimersByTimeAsync(4 * 5000); + expect(pingSpy).toHaveBeenCalledTimes(4); + + await cli.disconnect(); + + pingSpy.mockRestore(); + jest.useRealTimers(); + }); + + it('reconnect with sync disconnect()', async () => { + const cli = new Client({ url: server.getUrl() }); + await cli.connect(); + cli.disconnect(); // NOTE: intentionally not awaited + await cli.connect(); + + const resp = await cli.sendMessage({ method: 'init' }); + expect(resp.success).toEqual(true); + + cli.disconnect(); + }); + + it('client.disconnect()', async () => { + const cli = new Client({ url: server.getUrl() }); + const disconnectedSpy = jest.fn(); + cli.on('disconnected', disconnectedSpy); + + // calling before connection + await cli.disconnect(); + expect(disconnectedSpy).toHaveBeenCalledTimes(0); + + await cli.connect(); + await cli.disconnect(); + expect(disconnectedSpy).toHaveBeenCalledTimes(1); + }); + + it('client.dispose()', async () => { + const cli = new Client({ url: server.getUrl() }); + const disconnectedSpy = jest.fn(); + cli.on('disconnected', disconnectedSpy); + + // calling before connection + cli.dispose(); + expect(disconnectedSpy).toHaveBeenCalledTimes(0); + + // set listener again, previous .dispose removed it + cli.on('disconnected', disconnectedSpy); + await cli.connect(); + cli.dispose(); + expect(disconnectedSpy).toHaveBeenCalledTimes(0); + }); + + it('throws connection error', async () => { + const cli = new Client({ url: 'invalid-url' }); + + await expect(() => cli.connect()).rejects.toThrow('invalid-url'); + }); +}); diff --git a/packages/websocket-client/tsconfig.json b/packages/websocket-client/tsconfig.json new file mode 100644 index 0000000000..0ec4519c33 --- /dev/null +++ b/packages/websocket-client/tsconfig.json @@ -0,0 +1,5 @@ +{ + "extends": "../../tsconfig.base.json", + "compilerOptions": { "outDir": "libDev" }, + "references": [{ "path": "../utils" }] +} diff --git a/packages/websocket-client/tsconfig.lib.json b/packages/websocket-client/tsconfig.lib.json new file mode 100644 index 0000000000..c9e91da72f --- /dev/null +++ b/packages/websocket-client/tsconfig.lib.json @@ -0,0 +1,14 @@ +{ + "extends": "../../tsconfig.lib.json", + "compilerOptions": { + "outDir": "./lib", + "lib": ["webworker"], + "types": ["jest", "node", "web"] + }, + "include": ["./src"], + "references": [ + { + "path": "../utils" + } + ] +} diff --git a/yarn.lock b/yarn.lock index a776e12c90..006c130ab2 100644 --- a/yarn.lock +++ b/yarn.lock @@ -12953,6 +12953,17 @@ __metadata: languageName: unknown linkType: soft +"@trezor/websocket-client@workspace:packages/websocket-client": + version: 0.0.0-use.local + resolution: "@trezor/websocket-client@workspace:packages/websocket-client" + dependencies: + "@trezor/utils": "workspace:*" + ws: "npm:^8.18.0" + peerDependencies: + tslib: ^2.6.2 + languageName: unknown + linkType: soft + "@trysound/sax@npm:0.2.0": version: 0.2.0 resolution: "@trysound/sax@npm:0.2.0"