chore(protocol): validate control byte (magic) in protocol v2 decode

This commit is contained in:
Szymon Lesisz
2025-12-08 12:11:16 +01:00
committed by Szymon Lesisz
parent 8c979e83ad
commit c149719480
4 changed files with 62 additions and 6 deletions

View File

@@ -1,3 +1,19 @@
export const HEADER_SIZE = 1 + 2; // 1: control_byte + 2: channel
export const MESSAGE_LEN_SIZE = 2;
export const MESSAGE_TYPE = 'TrezorHostProtocolMessage';
// https://github.com/trezor/trezor-firmware/blob/dc7dccfa6c6121333f732c39565878fc46e67085/core/src/trezor/wire/thp/__init__.py
export const THP_CONTROL_BYTE = {
HANDSHAKE_INIT_REQ: 0x00, // out
HANDSHAKE_INIT_RES: 0x01, // in
HANDSHAKE_COMP_REQ: 0x02, // out
HANDSHAKE_COMP_RES: 0x03, // in
ENCRYPTED: 0x04, // in / out
ACK_MESSAGE: 0x20, // in / out
CHANNEL_ALLOCATION_REQ: 0x40, // out
CHANNEL_ALLOCATION_RES: 0x41, // in
ERROR: 0x42, // in
PING: 0x43, // out
PONG: 0x44, // in
CONTINUATION_PACKET: 0x80, // in / out
};

View File

@@ -1,8 +1,41 @@
import * as ERRORS from '../errors';
import { HEADER_SIZE, MESSAGE_LEN_SIZE, MESSAGE_TYPE } from './constants';
import { HEADER_SIZE, MESSAGE_LEN_SIZE, MESSAGE_TYPE, THP_CONTROL_BYTE } from './constants';
import { getHeaders } from './encode';
import { TransportProtocolDecode } from '../types';
// TODO: link-to-public-docs
// https://github.com/trezor/trezor-firmware/blob/m1nd3r/thp-documentation/docs/common/thp/specification.md#transport-packet-structure
export const decodeCtrlByte = (ctrlByte: number) => {
// DATA message
const dataType = ctrlByte & 0xe7;
switch (dataType) {
case THP_CONTROL_BYTE.HANDSHAKE_COMP_REQ:
case THP_CONTROL_BYTE.HANDSHAKE_COMP_RES:
case THP_CONTROL_BYTE.HANDSHAKE_INIT_REQ:
case THP_CONTROL_BYTE.HANDSHAKE_INIT_RES:
case THP_CONTROL_BYTE.ENCRYPTED:
return dataType;
}
// ACK message
const ackType = ctrlByte & 0xf7;
if (ackType === THP_CONTROL_BYTE.ACK_MESSAGE) {
return ackType;
}
// Unmasked message
switch (ctrlByte) {
case THP_CONTROL_BYTE.CHANNEL_ALLOCATION_REQ:
case THP_CONTROL_BYTE.CHANNEL_ALLOCATION_RES:
case THP_CONTROL_BYTE.PING:
case THP_CONTROL_BYTE.PONG:
case THP_CONTROL_BYTE.ERROR:
return ctrlByte;
}
return undefined;
};
// Parses raw input from Trezor and returns some information about the whole message
export const decode: TransportProtocolDecode = bytes => {
const buffer = Buffer.from(bytes);
@@ -12,13 +45,18 @@ export const decode: TransportProtocolDecode = bytes => {
throw new Error(ERRORS.PROTOCOL_MALFORMED);
}
const messageType = decodeCtrlByte(buffer.readUInt8());
if (messageType === undefined) {
throw new Error(ERRORS.PROTOCOL_MALFORMED);
}
const [header, chunkHeader] = getHeaders(buffer);
return {
header,
chunkHeader,
length: buffer.readUint16BE(HEADER_SIZE),
messageType: MESSAGE_TYPE, // will be decoded by `protocol-thp`
messageType: MESSAGE_TYPE, // will be decoded by `protocol-thp`, TODO messageType
payload: buffer.subarray(HEADER_SIZE + MESSAGE_LEN_SIZE),
};
};

View File

@@ -1,6 +1,5 @@
import * as ERRORS from '../errors';
import { HEADER_SIZE, MESSAGE_LEN_SIZE, MESSAGE_TYPE } from './constants';
import { THP_CONTINUATION_PACKET } from '../protocol-thp/constants';
import { HEADER_SIZE, MESSAGE_LEN_SIZE, MESSAGE_TYPE, THP_CONTROL_BYTE } from './constants';
import { TransportProtocol } from '../types';
const getChunkHeader = (data: Buffer) => {
@@ -10,7 +9,7 @@ const getChunkHeader = (data: Buffer) => {
}
const channel = data.subarray(1, HEADER_SIZE);
const header = Buffer.concat([Buffer.from([THP_CONTINUATION_PACKET]), channel]);
const header = Buffer.concat([Buffer.from([THP_CONTROL_BYTE.CONTINUATION_PACKET]), channel]);
return header;
};

View File

@@ -41,8 +41,11 @@ describe('protocol-v2', () => {
it('decode with error', () => {
expect(() => decode(Buffer.alloc(0))).toThrow('Malformed protocol format');
// CONTINUATION_PACKET
expect(() => decode(Buffer.from('8012380000', 'hex'))).toThrow('Malformed protocol format');
// unrecognized chunk
expect(() => decode(Buffer.from('9912380000', 'hex'))).toThrow('Malformed protocol format');
});
it('getHeaders', () => {
expect(getHeaders(Buffer.from('0412380000', 'hex'))).toEqual([
Buffer.from('041238', 'hex'),