diff --git a/packages/react-native-usb/android/src/main/java/io/trezor/rnusb/ReactNativeUsbModule.kt b/packages/react-native-usb/android/src/main/java/io/trezor/rnusb/ReactNativeUsbModule.kt index 58436c43b9..ae7193368a 100644 --- a/packages/react-native-usb/android/src/main/java/io/trezor/rnusb/ReactNativeUsbModule.kt +++ b/packages/react-native-usb/android/src/main/java/io/trezor/rnusb/ReactNativeUsbModule.kt @@ -32,6 +32,14 @@ const val LOG_TAG = "ReactNativeUsb" // TODO: get interface index by claimed interface const val INTERFACE_INDEX = 0 +const val CHUNK_SIZE = 64 + +// Protocol v1 constants +const val PROTOCOL_V1_MAGIC_BYTE: Byte = 0x3F // '?' character - first byte of v1 messages and chunk header + +// Protocol v2 (THP) constants +const val THP_CONTINUATION_PACKET: Byte = 0x80.toByte() // Continuation packet marker for THP +const val THP_HEADER_SIZE = 3 // control_byte + 2 bytes channel class ReactNativeUsbModule : Module() { private val moduleCoroutineScope = CoroutineScope(Dispatchers.IO) @@ -88,10 +96,14 @@ class ReactNativeUsbModule : Module() { return@AsyncFunction releaseInterface(deviceName, interfaceNumber) } - AsyncFunction("transferOut") { deviceName: String, endpointNumber: Int, data: String, promise: Promise -> + AsyncFunction("transferOut") { deviceName: String, endpointNumber: Int, data: ByteArray, promise: Promise -> + // Clone the data before entering async code, as the original ByteArray from JS + // may become invalid after switching threads (GC may collect it). + val dataClone = data.copyOf() + withModuleScope(promise) { try { - val result = transferOut(deviceName, endpointNumber, data) + val result = transferOut(deviceName, endpointNumber, dataClone) promise.resolve(result) } catch (e: Exception) { promise.reject("USB Write Error", e.message, e) @@ -305,15 +317,85 @@ class ReactNativeUsbModule : Module() { usbConnection.releaseInterface(usbInterface) } - private fun transferOut(deviceName: String, endpointNumber: Int, data: String): Int { - Log.d(LOG_TAG, "Transfering data to device $deviceName") - Log.d(LOG_TAG, "data: $data") - // split string into array of numbers and then convert numbers to byte array - val dataByteArray = data.split(",").map { it.toInt().toByte() }.toByteArray() - Log.d(LOG_TAG, "dataByteArray: $dataByteArray") + /** + * Detects the protocol and returns the appropriate chunk header. + * - Protocol v1: First byte is 0x3F ('?'), chunk header is 1 byte (0x3F) + * - Protocol v2 (THP): First byte is a THP control byte, chunk header is 3 bytes (0x80 + channel) + */ + private fun getChunkHeader(data: ByteArray): ByteArray { + val firstByte = data[0] + + return if (firstByte == PROTOCOL_V1_MAGIC_BYTE) { + // Protocol v1: chunk header is just 0x3F + Log.d(LOG_TAG, "Detected protocol v1, using 1-byte chunk header") + byteArrayOf(PROTOCOL_V1_MAGIC_BYTE) + } else { + // Protocol v2 (THP): chunk header is 0x80 + channel (bytes 1-2) + // Message format: [control_byte, channel_high, channel_low, ...] + if (data.size < THP_HEADER_SIZE) { + Log.w(LOG_TAG, "Data too short for THP header, falling back to v1 header") + byteArrayOf(PROTOCOL_V1_MAGIC_BYTE) + } else { + val channelHigh = data[1] + val channelLow = data[2] + Log.d(LOG_TAG, "Detected protocol v2 (THP), using 3-byte chunk header with channel ${channelHigh.toInt() and 0xFF}:${channelLow.toInt() and 0xFF}") + byteArrayOf(THP_CONTINUATION_PACKET, channelHigh, channelLow) + } + } + } + + /** + * Creates 64-byte chunks from data for USB transfer. + * + * Supports Protocol v1 (1-byte header) and Protocol v2/THP (3-byte header) + */ + private fun createChunks(data: ByteArray): List { + val dataSize = data.size + + // Single chunk case - just pad to CHUNK_SIZE + if (dataSize <= CHUNK_SIZE) { + val chunk = ByteArray(CHUNK_SIZE) + System.arraycopy(data, 0, chunk, 0, dataSize) + return listOf(chunk) + } + + // Multi-chunk case - detect protocol and build chunks + val chunkHeader = getChunkHeader(data) + val chunkHeaderSize = chunkHeader.size + val chunkDataSize = CHUNK_SIZE - chunkHeaderSize + + val chunks = mutableListOf() + var dataOffset = 0 + + // First chunk: no header, up to 64 bytes + val firstChunk = ByteArray(CHUNK_SIZE) + System.arraycopy(data, 0, firstChunk, 0, CHUNK_SIZE) + chunks.add(firstChunk) + dataOffset = CHUNK_SIZE + + // Subsequent chunks: header + data + while (dataOffset < dataSize) { + val chunk = ByteArray(CHUNK_SIZE) + System.arraycopy(chunkHeader, 0, chunk, 0, chunkHeaderSize) + val bytesToCopy = minOf(chunkDataSize, dataSize - dataOffset) + System.arraycopy(data, dataOffset, chunk, chunkHeaderSize, bytesToCopy) + chunks.add(chunk) + dataOffset += bytesToCopy + } + + return chunks + } + + /** + * Transfers data to USB device, chunking if necessary. + * Uses bulkTransfer for efficient sequential transfers. + */ + private fun transferOut(deviceName: String, endpointNumber: Int, data: ByteArray): Int { + val transferStartTime = System.nanoTime() + val device = getDeviceByName(deviceName) val usbConnection = openedConnections.getOrPut(device.deviceName) { - Log.d(LOG_TAG, "Reopening device ${device.deviceName}") + Log.d(LOG_TAG, "transferOut: Reopening device ${device.deviceName}") usbManager.openDevice(device) ?: throw Exception("Failed to open device ${device.deviceName}") } @@ -322,12 +404,32 @@ class ReactNativeUsbModule : Module() { Log.e(LOG_TAG, "Failed to get endpoint $endpointNumber for device ${device.deviceName}") throw Exception("Failed to get endpoint $endpointNumber for device ${device.deviceName}") } - val result = usbConnection.bulkTransfer(usbEndpoint, dataByteArray, dataByteArray.size, 0) - Log.d(LOG_TAG, "Transfered data to device ${device.deviceName}: $result") - return result + + val chunks = createChunks(data) + val totalChunks = chunks.size + val bulkTimeout = 0 // 0 means no timeout + + for ((index, chunk) in chunks.withIndex()) { + val bytesWritten = usbConnection.bulkTransfer(usbEndpoint, chunk, CHUNK_SIZE, bulkTimeout) + + if (bytesWritten < 0) { + Log.e(LOG_TAG, "transferOut: FAILED chunk ${index + 1}/$totalChunks, error: $bytesWritten") + throw Exception("USB transfer failed for chunk ${index + 1}/$totalChunks, error: $bytesWritten") + } + + if (bytesWritten != CHUNK_SIZE) { + Log.e(LOG_TAG, "transferOut: FAILED chunk ${index + 1}/$totalChunks incomplete: $bytesWritten/$CHUNK_SIZE") + throw Exception("USB transfer incomplete for chunk ${index + 1}/$totalChunks: $bytesWritten/$CHUNK_SIZE bytes") + } + } + + val durationMs = (System.nanoTime() - transferStartTime) / 1_000_000.0 + Log.d(LOG_TAG, "transferOut: Done - ${data.size} bytes, $totalChunks chunks, %.1f ms".format(durationMs)) + + return data.size } - private fun transferIn(deviceName: String, endpointNumber: Int, length: Int): IntArray { + private fun transferIn(deviceName: String, endpointNumber: Int, length: Int): ByteArray { Log.d(LOG_TAG, "Reading data from device $deviceName") val device = getDeviceByName(deviceName) @@ -366,13 +468,9 @@ class ReactNativeUsbModule : Module() { Log.e(LOG_TAG, "Failed to transfer data from device ${device.deviceName}") throw Exception("Failed to transfer data from device ${device.deviceName}") } - Log.d(LOG_TAG, "Read data from device ${device.deviceName}: ${buffer.array()}") - // convert buffer to Array - val bufferArray = buffer.array() - Log.d(LOG_TAG, "bufferArray: ${bufferArray.toList()}") - // convert Array to IntArray - val bufferIntArray = bufferArray.map { it.toInt() }.toIntArray() - return bufferIntArray + + Log.d(LOG_TAG, "Read data from device ${device.deviceName}: $length bytes") + return buffer.array() } private fun getOpenedConnection(deviceName: String): UsbDeviceConnection { diff --git a/packages/react-native-usb/src/ReactNativeUsbModule.ts b/packages/react-native-usb/src/ReactNativeUsbModule.ts index 413d21815c..9384c4f694 100644 --- a/packages/react-native-usb/src/ReactNativeUsbModule.ts +++ b/packages/react-native-usb/src/ReactNativeUsbModule.ts @@ -14,8 +14,8 @@ declare class ReactNativeUsbModuleDeclaration extends NativeModule close: (deviceName: string) => Promise; claimInterface: (deviceName: string, interfaceNumber: number) => Promise; releaseInterface: (deviceName: string, interfaceNumber: number) => Promise; - transferIn: (deviceName: string, endpointNumber: number, length: number) => Promise; - transferOut: (deviceName: string, endpointNumber: number, data: string) => Promise; + transferIn: (deviceName: string, endpointNumber: number, length: number) => Promise; + transferOut: (deviceName: string, endpointNumber: number, data: Uint8Array) => Promise; setPriorityMode: (isInPriorityMode: boolean) => void; } diff --git a/packages/react-native-usb/src/index.ts b/packages/react-native-usb/src/index.ts index 7859223c83..3a5c4295d5 100644 --- a/packages/react-native-usb/src/index.ts +++ b/packages/react-native-usb/src/index.ts @@ -3,7 +3,7 @@ import { EventSubscription } from 'expo-modules-core'; import { NativeDevice, OnConnectEvent, WebUSBDevice } from './ReactNativeUsb.types'; import { ReactNativeUsbModule } from './ReactNativeUsbModule'; -const DEBUG_LOGS = false; +const DEBUG_LOGS = true; const debugLog = (...args: any[]) => { if (DEBUG_LOGS) { @@ -34,11 +34,11 @@ const transferIn = async (deviceName: string, endpointNumber: number, length: nu debugLog('JS: USB read error: ', error); throw error; }) - .then((result: number[]) => { - debugLog('JS: Native USB read result:', JSON.stringify(result)); + .then(result => { + debugLog('JS: Native USB read result length:', result.length); return { - data: new Uint8Array(result), + data: result, status: 'ok', }; }); @@ -54,7 +54,12 @@ const transferOut = async ( ) => { try { const perf = performance.now(); - await ReactNativeUsbModule.transferOut(deviceName, endpointNumber, data.toString()); + // Ensure we pass a Uint8Array directly to native code (maps to kotlin.ByteArray) + const uint8Data = + data instanceof Uint8Array + ? data + : new Uint8Array(ArrayBuffer.isView(data) ? data.buffer : data); + await ReactNativeUsbModule.transferOut(deviceName, endpointNumber, uint8Data); debugLog('JS: USB write time', performance.now() - perf); return { status: 'ok' }; diff --git a/packages/transport-native-usb/src/nativeUsb.ts b/packages/transport-native-usb/src/nativeUsb.ts index 57b0b5bb5c..3502a86602 100644 --- a/packages/transport-native-usb/src/nativeUsb.ts +++ b/packages/transport-native-usb/src/nativeUsb.ts @@ -18,5 +18,9 @@ export class NativeUsbTransport extends AbstractApiTransport { logger, ...rest, }); + + // Let the native Kotlin code handle the chunking. + // It significantly improves the performance of writes during FW update. + this.api.nativeWriteChunking = true; } } diff --git a/packages/transport/src/api/abstract.ts b/packages/transport/src/api/abstract.ts index c5b12ca7ce..647724d598 100644 --- a/packages/transport/src/api/abstract.ts +++ b/packages/transport/src/api/abstract.ts @@ -141,6 +141,11 @@ export abstract class AbstractApi extends TypedEmitter<{ */ public abstract chunkSize: number; + /** + * send whole data in one chunk and let the native code handle the chunking (used only in React Native) + */ + public nativeWriteChunking: boolean = false; + protected success(payload: T): Success { return success(payload); } diff --git a/packages/transport/src/api/usb.ts b/packages/transport/src/api/usb.ts index 1c6b128604..9df83c2a99 100644 --- a/packages/transport/src/api/usb.ts +++ b/packages/transport/src/api/usb.ts @@ -251,8 +251,15 @@ export class UsbApi extends AbstractApi { if (!device) { return this.error({ error: ERRORS.DEVICE_NOT_FOUND }); } - const newArray = new Uint8Array(this.chunkSize); - newArray.set(new Uint8Array(buffer)); + + let newArray: Uint8Array; + if (this.nativeWriteChunking) { + // Pass the full buffer for native chunking + newArray = new Uint8Array(buffer); + } else { + newArray = new Uint8Array(this.chunkSize); + newArray.set(new Uint8Array(buffer)); + } const timeout = setTimeout(() => { this.logger?.debug('usb: device.transfer out take suspiciously long. timing out.'); diff --git a/packages/transport/src/transports/abstractApi.ts b/packages/transport/src/transports/abstractApi.ts index e76c449b1a..cd6ab89600 100644 --- a/packages/transport/src/transports/abstractApi.ts +++ b/packages/transport/src/transports/abstractApi.ts @@ -247,7 +247,11 @@ export abstract class AbstractApiTransport extends AbstractTransport { thpState, }); const [, chunkHeader] = protocol.getHeaders(bytes); - const chunks = createChunks(bytes, chunkHeader, this.api.chunkSize); + const chunks = createChunks( + bytes, + chunkHeader, + this.api.nativeWriteChunking ? 0 : this.api.chunkSize, + ); let progress = 0; const apiWrite = (chunk: Buffer, attemptSignal?: AbortSignal) => { if (chunks.length > 1) { @@ -341,7 +345,11 @@ export abstract class AbstractApiTransport extends AbstractTransport { }); const [_, chunkHeader] = protocol.getHeaders(bytes); - const chunks = createChunks(bytes, chunkHeader, this.api.chunkSize); + const chunks = createChunks( + bytes, + chunkHeader, + this.api.nativeWriteChunking ? 0 : this.api.chunkSize, + ); let progress = 0; const apiWrite = (chunk: Buffer) => { if (chunks.length > 1) {