mirror of
https://github.com/trezor/trezor-suite.git
synced 2026-02-20 00:33:07 +01:00
feat: improve USB write performance in mobile app
This commit is contained in:
committed by
Tomáš Martykán
parent
237c94425d
commit
5ccfee82eb
@@ -32,6 +32,14 @@ const val LOG_TAG = "ReactNativeUsb"
|
||||
|
||||
// TODO: get interface index by claimed interface
|
||||
const val INTERFACE_INDEX = 0
|
||||
const val CHUNK_SIZE = 64
|
||||
|
||||
// Protocol v1 constants
|
||||
const val PROTOCOL_V1_MAGIC_BYTE: Byte = 0x3F // '?' character - first byte of v1 messages and chunk header
|
||||
|
||||
// Protocol v2 (THP) constants
|
||||
const val THP_CONTINUATION_PACKET: Byte = 0x80.toByte() // Continuation packet marker for THP
|
||||
const val THP_HEADER_SIZE = 3 // control_byte + 2 bytes channel
|
||||
|
||||
class ReactNativeUsbModule : Module() {
|
||||
private val moduleCoroutineScope = CoroutineScope(Dispatchers.IO)
|
||||
@@ -88,10 +96,14 @@ class ReactNativeUsbModule : Module() {
|
||||
return@AsyncFunction releaseInterface(deviceName, interfaceNumber)
|
||||
}
|
||||
|
||||
AsyncFunction("transferOut") { deviceName: String, endpointNumber: Int, data: String, promise: Promise ->
|
||||
AsyncFunction("transferOut") { deviceName: String, endpointNumber: Int, data: ByteArray, promise: Promise ->
|
||||
// Clone the data before entering async code, as the original ByteArray from JS
|
||||
// may become invalid after switching threads (GC may collect it).
|
||||
val dataClone = data.copyOf()
|
||||
|
||||
withModuleScope(promise) {
|
||||
try {
|
||||
val result = transferOut(deviceName, endpointNumber, data)
|
||||
val result = transferOut(deviceName, endpointNumber, dataClone)
|
||||
promise.resolve(result)
|
||||
} catch (e: Exception) {
|
||||
promise.reject("USB Write Error", e.message, e)
|
||||
@@ -305,15 +317,85 @@ class ReactNativeUsbModule : Module() {
|
||||
usbConnection.releaseInterface(usbInterface)
|
||||
}
|
||||
|
||||
private fun transferOut(deviceName: String, endpointNumber: Int, data: String): Int {
|
||||
Log.d(LOG_TAG, "Transfering data to device $deviceName")
|
||||
Log.d(LOG_TAG, "data: $data")
|
||||
// split string into array of numbers and then convert numbers to byte array
|
||||
val dataByteArray = data.split(",").map { it.toInt().toByte() }.toByteArray()
|
||||
Log.d(LOG_TAG, "dataByteArray: $dataByteArray")
|
||||
/**
|
||||
* Detects the protocol and returns the appropriate chunk header.
|
||||
* - Protocol v1: First byte is 0x3F ('?'), chunk header is 1 byte (0x3F)
|
||||
* - Protocol v2 (THP): First byte is a THP control byte, chunk header is 3 bytes (0x80 + channel)
|
||||
*/
|
||||
private fun getChunkHeader(data: ByteArray): ByteArray {
|
||||
val firstByte = data[0]
|
||||
|
||||
return if (firstByte == PROTOCOL_V1_MAGIC_BYTE) {
|
||||
// Protocol v1: chunk header is just 0x3F
|
||||
Log.d(LOG_TAG, "Detected protocol v1, using 1-byte chunk header")
|
||||
byteArrayOf(PROTOCOL_V1_MAGIC_BYTE)
|
||||
} else {
|
||||
// Protocol v2 (THP): chunk header is 0x80 + channel (bytes 1-2)
|
||||
// Message format: [control_byte, channel_high, channel_low, ...]
|
||||
if (data.size < THP_HEADER_SIZE) {
|
||||
Log.w(LOG_TAG, "Data too short for THP header, falling back to v1 header")
|
||||
byteArrayOf(PROTOCOL_V1_MAGIC_BYTE)
|
||||
} else {
|
||||
val channelHigh = data[1]
|
||||
val channelLow = data[2]
|
||||
Log.d(LOG_TAG, "Detected protocol v2 (THP), using 3-byte chunk header with channel ${channelHigh.toInt() and 0xFF}:${channelLow.toInt() and 0xFF}")
|
||||
byteArrayOf(THP_CONTINUATION_PACKET, channelHigh, channelLow)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates 64-byte chunks from data for USB transfer.
|
||||
*
|
||||
* Supports Protocol v1 (1-byte header) and Protocol v2/THP (3-byte header)
|
||||
*/
|
||||
private fun createChunks(data: ByteArray): List<ByteArray> {
|
||||
val dataSize = data.size
|
||||
|
||||
// Single chunk case - just pad to CHUNK_SIZE
|
||||
if (dataSize <= CHUNK_SIZE) {
|
||||
val chunk = ByteArray(CHUNK_SIZE)
|
||||
System.arraycopy(data, 0, chunk, 0, dataSize)
|
||||
return listOf(chunk)
|
||||
}
|
||||
|
||||
// Multi-chunk case - detect protocol and build chunks
|
||||
val chunkHeader = getChunkHeader(data)
|
||||
val chunkHeaderSize = chunkHeader.size
|
||||
val chunkDataSize = CHUNK_SIZE - chunkHeaderSize
|
||||
|
||||
val chunks = mutableListOf<ByteArray>()
|
||||
var dataOffset = 0
|
||||
|
||||
// First chunk: no header, up to 64 bytes
|
||||
val firstChunk = ByteArray(CHUNK_SIZE)
|
||||
System.arraycopy(data, 0, firstChunk, 0, CHUNK_SIZE)
|
||||
chunks.add(firstChunk)
|
||||
dataOffset = CHUNK_SIZE
|
||||
|
||||
// Subsequent chunks: header + data
|
||||
while (dataOffset < dataSize) {
|
||||
val chunk = ByteArray(CHUNK_SIZE)
|
||||
System.arraycopy(chunkHeader, 0, chunk, 0, chunkHeaderSize)
|
||||
val bytesToCopy = minOf(chunkDataSize, dataSize - dataOffset)
|
||||
System.arraycopy(data, dataOffset, chunk, chunkHeaderSize, bytesToCopy)
|
||||
chunks.add(chunk)
|
||||
dataOffset += bytesToCopy
|
||||
}
|
||||
|
||||
return chunks
|
||||
}
|
||||
|
||||
/**
|
||||
* Transfers data to USB device, chunking if necessary.
|
||||
* Uses bulkTransfer for efficient sequential transfers.
|
||||
*/
|
||||
private fun transferOut(deviceName: String, endpointNumber: Int, data: ByteArray): Int {
|
||||
val transferStartTime = System.nanoTime()
|
||||
|
||||
val device = getDeviceByName(deviceName)
|
||||
val usbConnection = openedConnections.getOrPut(device.deviceName) {
|
||||
Log.d(LOG_TAG, "Reopening device ${device.deviceName}")
|
||||
Log.d(LOG_TAG, "transferOut: Reopening device ${device.deviceName}")
|
||||
usbManager.openDevice(device) ?: throw Exception("Failed to open device ${device.deviceName}")
|
||||
}
|
||||
|
||||
@@ -322,12 +404,32 @@ class ReactNativeUsbModule : Module() {
|
||||
Log.e(LOG_TAG, "Failed to get endpoint $endpointNumber for device ${device.deviceName}")
|
||||
throw Exception("Failed to get endpoint $endpointNumber for device ${device.deviceName}")
|
||||
}
|
||||
val result = usbConnection.bulkTransfer(usbEndpoint, dataByteArray, dataByteArray.size, 0)
|
||||
Log.d(LOG_TAG, "Transfered data to device ${device.deviceName}: $result")
|
||||
return result
|
||||
|
||||
val chunks = createChunks(data)
|
||||
val totalChunks = chunks.size
|
||||
val bulkTimeout = 0 // 0 means no timeout
|
||||
|
||||
for ((index, chunk) in chunks.withIndex()) {
|
||||
val bytesWritten = usbConnection.bulkTransfer(usbEndpoint, chunk, CHUNK_SIZE, bulkTimeout)
|
||||
|
||||
if (bytesWritten < 0) {
|
||||
Log.e(LOG_TAG, "transferOut: FAILED chunk ${index + 1}/$totalChunks, error: $bytesWritten")
|
||||
throw Exception("USB transfer failed for chunk ${index + 1}/$totalChunks, error: $bytesWritten")
|
||||
}
|
||||
|
||||
if (bytesWritten != CHUNK_SIZE) {
|
||||
Log.e(LOG_TAG, "transferOut: FAILED chunk ${index + 1}/$totalChunks incomplete: $bytesWritten/$CHUNK_SIZE")
|
||||
throw Exception("USB transfer incomplete for chunk ${index + 1}/$totalChunks: $bytesWritten/$CHUNK_SIZE bytes")
|
||||
}
|
||||
}
|
||||
|
||||
val durationMs = (System.nanoTime() - transferStartTime) / 1_000_000.0
|
||||
Log.d(LOG_TAG, "transferOut: Done - ${data.size} bytes, $totalChunks chunks, %.1f ms".format(durationMs))
|
||||
|
||||
return data.size
|
||||
}
|
||||
|
||||
private fun transferIn(deviceName: String, endpointNumber: Int, length: Int): IntArray {
|
||||
private fun transferIn(deviceName: String, endpointNumber: Int, length: Int): ByteArray {
|
||||
Log.d(LOG_TAG, "Reading data from device $deviceName")
|
||||
val device = getDeviceByName(deviceName)
|
||||
|
||||
@@ -366,13 +468,9 @@ class ReactNativeUsbModule : Module() {
|
||||
Log.e(LOG_TAG, "Failed to transfer data from device ${device.deviceName}")
|
||||
throw Exception("Failed to transfer data from device ${device.deviceName}")
|
||||
}
|
||||
Log.d(LOG_TAG, "Read data from device ${device.deviceName}: ${buffer.array()}")
|
||||
// convert buffer to Array
|
||||
val bufferArray = buffer.array()
|
||||
Log.d(LOG_TAG, "bufferArray: ${bufferArray.toList()}")
|
||||
// convert Array to IntArray
|
||||
val bufferIntArray = bufferArray.map { it.toInt() }.toIntArray()
|
||||
return bufferIntArray
|
||||
|
||||
Log.d(LOG_TAG, "Read data from device ${device.deviceName}: $length bytes")
|
||||
return buffer.array()
|
||||
}
|
||||
|
||||
private fun getOpenedConnection(deviceName: String): UsbDeviceConnection {
|
||||
|
||||
@@ -14,8 +14,8 @@ declare class ReactNativeUsbModuleDeclaration extends NativeModule<DeviceEvents>
|
||||
close: (deviceName: string) => Promise<void>;
|
||||
claimInterface: (deviceName: string, interfaceNumber: number) => Promise<void>;
|
||||
releaseInterface: (deviceName: string, interfaceNumber: number) => Promise<void>;
|
||||
transferIn: (deviceName: string, endpointNumber: number, length: number) => Promise<number[]>;
|
||||
transferOut: (deviceName: string, endpointNumber: number, data: string) => Promise<void>;
|
||||
transferIn: (deviceName: string, endpointNumber: number, length: number) => Promise<Uint8Array>;
|
||||
transferOut: (deviceName: string, endpointNumber: number, data: Uint8Array) => Promise<void>;
|
||||
setPriorityMode: (isInPriorityMode: boolean) => void;
|
||||
}
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import { EventSubscription } from 'expo-modules-core';
|
||||
import { NativeDevice, OnConnectEvent, WebUSBDevice } from './ReactNativeUsb.types';
|
||||
import { ReactNativeUsbModule } from './ReactNativeUsbModule';
|
||||
|
||||
const DEBUG_LOGS = false;
|
||||
const DEBUG_LOGS = true;
|
||||
|
||||
const debugLog = (...args: any[]) => {
|
||||
if (DEBUG_LOGS) {
|
||||
@@ -34,11 +34,11 @@ const transferIn = async (deviceName: string, endpointNumber: number, length: nu
|
||||
debugLog('JS: USB read error: ', error);
|
||||
throw error;
|
||||
})
|
||||
.then((result: number[]) => {
|
||||
debugLog('JS: Native USB read result:', JSON.stringify(result));
|
||||
.then(result => {
|
||||
debugLog('JS: Native USB read result length:', result.length);
|
||||
|
||||
return {
|
||||
data: new Uint8Array(result),
|
||||
data: result,
|
||||
status: 'ok',
|
||||
};
|
||||
});
|
||||
@@ -54,7 +54,12 @@ const transferOut = async (
|
||||
) => {
|
||||
try {
|
||||
const perf = performance.now();
|
||||
await ReactNativeUsbModule.transferOut(deviceName, endpointNumber, data.toString());
|
||||
// Ensure we pass a Uint8Array directly to native code (maps to kotlin.ByteArray)
|
||||
const uint8Data =
|
||||
data instanceof Uint8Array
|
||||
? data
|
||||
: new Uint8Array(ArrayBuffer.isView(data) ? data.buffer : data);
|
||||
await ReactNativeUsbModule.transferOut(deviceName, endpointNumber, uint8Data);
|
||||
debugLog('JS: USB write time', performance.now() - perf);
|
||||
|
||||
return { status: 'ok' };
|
||||
|
||||
@@ -18,5 +18,9 @@ export class NativeUsbTransport extends AbstractApiTransport {
|
||||
logger,
|
||||
...rest,
|
||||
});
|
||||
|
||||
// Let the native Kotlin code handle the chunking.
|
||||
// It significantly improves the performance of writes during FW update.
|
||||
this.api.nativeWriteChunking = true;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -141,6 +141,11 @@ export abstract class AbstractApi extends TypedEmitter<{
|
||||
*/
|
||||
public abstract chunkSize: number;
|
||||
|
||||
/**
|
||||
* send whole data in one chunk and let the native code handle the chunking (used only in React Native)
|
||||
*/
|
||||
public nativeWriteChunking: boolean = false;
|
||||
|
||||
protected success<T>(payload: T): Success<T> {
|
||||
return success(payload);
|
||||
}
|
||||
|
||||
@@ -251,8 +251,15 @@ export class UsbApi extends AbstractApi {
|
||||
if (!device) {
|
||||
return this.error({ error: ERRORS.DEVICE_NOT_FOUND });
|
||||
}
|
||||
const newArray = new Uint8Array(this.chunkSize);
|
||||
newArray.set(new Uint8Array(buffer));
|
||||
|
||||
let newArray: Uint8Array<ArrayBuffer>;
|
||||
if (this.nativeWriteChunking) {
|
||||
// Pass the full buffer for native chunking
|
||||
newArray = new Uint8Array(buffer);
|
||||
} else {
|
||||
newArray = new Uint8Array(this.chunkSize);
|
||||
newArray.set(new Uint8Array(buffer));
|
||||
}
|
||||
|
||||
const timeout = setTimeout(() => {
|
||||
this.logger?.debug('usb: device.transfer out take suspiciously long. timing out.');
|
||||
|
||||
@@ -247,7 +247,11 @@ export abstract class AbstractApiTransport extends AbstractTransport {
|
||||
thpState,
|
||||
});
|
||||
const [, chunkHeader] = protocol.getHeaders(bytes);
|
||||
const chunks = createChunks(bytes, chunkHeader, this.api.chunkSize);
|
||||
const chunks = createChunks(
|
||||
bytes,
|
||||
chunkHeader,
|
||||
this.api.nativeWriteChunking ? 0 : this.api.chunkSize,
|
||||
);
|
||||
let progress = 0;
|
||||
const apiWrite = (chunk: Buffer, attemptSignal?: AbortSignal) => {
|
||||
if (chunks.length > 1) {
|
||||
@@ -341,7 +345,11 @@ export abstract class AbstractApiTransport extends AbstractTransport {
|
||||
});
|
||||
const [_, chunkHeader] = protocol.getHeaders(bytes);
|
||||
|
||||
const chunks = createChunks(bytes, chunkHeader, this.api.chunkSize);
|
||||
const chunks = createChunks(
|
||||
bytes,
|
||||
chunkHeader,
|
||||
this.api.nativeWriteChunking ? 0 : this.api.chunkSize,
|
||||
);
|
||||
let progress = 0;
|
||||
const apiWrite = (chunk: Buffer) => {
|
||||
if (chunks.length > 1) {
|
||||
|
||||
Reference in New Issue
Block a user