feat(protocol): add THP pairing messages and state

This commit is contained in:
Szymon Lesisz
2024-08-28 12:07:15 +02:00
committed by Szymon Lesisz
parent f16730bebb
commit 2d26b5c928
7 changed files with 283 additions and 35 deletions

View File

@@ -1,4 +1,10 @@
import { ThpCredentials, ThpDeviceProperties, ThpMessageSyncBit } from './messages';
import {
ThpCredentials,
ThpDeviceProperties,
ThpHandshakeCredentials,
ThpMessageSyncBit,
ThpPairingMethod,
} from './messages';
export type ThpStateSerialized = {
properties?: ThpDeviceProperties;
@@ -11,15 +17,22 @@ export type ThpStateSerialized = {
expectedResponses: number[]; // expected responses from the device
};
export type ThpPhase = 'handshake' | 'pairing' | 'paired';
export class ThpState {
private _properties?: ThpDeviceProperties;
private _pairingCredentials: ThpCredentials[] = [];
private _phase: ThpPhase = 'handshake';
private _isPaired: boolean = false;
private _handshakeCredentials?: ThpHandshakeCredentials;
private _channel: Buffer = Buffer.alloc(0);
private _sendBit: ThpMessageSyncBit = 0;
private _sendNonce: number = 0;
private _recvBit: ThpMessageSyncBit = 0;
private _recvNonce: number = 1;
private _expectedResponses: number[] = [];
private _selectedMethod?: ThpPairingMethod;
private _nfcSecret?: Buffer;
get properties() {
return this._properties;
@@ -29,6 +42,34 @@ export class ThpState {
this._properties = props;
}
get phase() {
return this._phase;
}
setPhase(phase: ThpPhase) {
this._phase = phase;
}
get isPaired() {
return this._isPaired;
}
get isAutoconnectPaired() {
return this._isPaired && this._pairingCredentials[0]?.autoconnect;
}
setIsPaired(isPaired: boolean) {
this._isPaired = isPaired;
}
get pairingMethod() {
return this._selectedMethod;
}
setPairingMethod(method: ThpPairingMethod) {
this._selectedMethod = method;
}
get pairingCredentials() {
return this._pairingCredentials;
}
@@ -41,6 +82,23 @@ export class ThpState {
}
}
setNfcSecret(secret: Buffer) {
this._nfcSecret = secret;
}
get nfcSecret() {
return this._nfcSecret;
}
get nfcData() {
if (this._selectedMethod === ThpPairingMethod.NFC && this._nfcSecret) {
return Buffer.concat([
this._nfcSecret,
this.handshakeCredentials!.handshakeHash.subarray(0, 16),
]);
}
}
get channel() {
return this._channel;
}
@@ -104,6 +162,33 @@ export class ThpState {
}
}
get handshakeCredentials() {
return this._handshakeCredentials;
}
updateHandshakeCredentials(newCredentials: Partial<ThpHandshakeCredentials>) {
if (!this._handshakeCredentials) {
this._handshakeCredentials = {
pairingMethods: [],
handshakeHash: Buffer.alloc(0),
handshakeCommitment: Buffer.alloc(0),
codeEntryChallenge: Buffer.alloc(0),
trezorEncryptedStaticPubkey: Buffer.alloc(0),
hostEncryptedStaticPubkey: Buffer.alloc(0),
staticKey: Buffer.alloc(0),
hostStaticPublicKey: Buffer.alloc(0),
hostKey: Buffer.alloc(0),
trezorKey: Buffer.alloc(0),
trezorCpacePublicKey: Buffer.alloc(0),
};
}
this._handshakeCredentials = {
...this._handshakeCredentials,
...newCredentials,
};
}
serialize(): ThpStateSerialized {
return {
properties: this._properties,
@@ -158,6 +243,9 @@ export class ThpState {
}
resetState() {
this._phase = 'handshake';
this._isPaired = false;
this._handshakeCredentials = undefined;
this._channel = Buffer.alloc(0);
this._sendBit = 0;
this._sendNonce = 0;
@@ -165,6 +253,8 @@ export class ThpState {
this._recvNonce = 1;
this._expectedResponses = [];
this._pairingCredentials = [];
this._selectedMethod = undefined;
this._nfcSecret = undefined;
}
toString() {

View File

@@ -1,13 +1,21 @@
import { ThpState } from './ThpState';
import {
CRC_LENGTH,
TAG_LENGTH,
THP_CONTROL_BYTE_DECRYPTED,
THP_CONTROL_BYTE_ENCRYPTED,
THP_CREATE_CHANNEL_RESPONSE,
THP_ERROR_HEADER_BYTE,
THP_HANDSHAKE_COMPLETION_RESPONSE,
THP_HANDSHAKE_INIT_RESPONSE,
THP_READ_ACK_HEADER_BYTE,
} from './constants';
import { aesgcm } from './crypto';
import { TransportProtocolDecode } from '../types';
import { crc32 } from './crypto/crc32';
import { ThpError, ThpMessageResponse } from './messages';
import { getHandshakeHash, getTrezorState } from './crypto/pairing';
import { getIvFromNonce } from './crypto/tools';
import { ThpDeviceProperties, ThpError, ThpMessageResponse } from './messages';
import { clearControlBit, readThpHeader } from './utils';
type ThpMessage = ReturnType<TransportProtocolDecode> & {
@@ -26,6 +34,14 @@ type ProtobufDecoder = (
type MessageV2 = ReturnType<TransportProtocolDecode>;
const decipherMessage = (key: Buffer, recvNonce: number, payload: Buffer, tag: Buffer) => {
const aes = aesgcm(key, getIvFromNonce(recvNonce));
aes.auth(Buffer.alloc(0));
const trezorMaskedStaticPubkey = aes.decrypt(payload, tag);
return trezorMaskedStaticPubkey.subarray(1); // NOTE: remove session_id (first byte)
};
// TODO: link-to-public-docs
// https://www.notion.so/satoshilabs/THP-Specification-2-0-18fdc5260606806ab573d0a7cba1897a
// example: 41ffff0020639ba57ff4e0c2343c830a0454335731180220002802280328042801c0171551
@@ -35,13 +51,12 @@ type MessageV2 = ReturnType<TransportProtocolDecode>;
const createChannelResponse = (
{ payload }: ThpMessage,
protobufDecoder: ProtobufDecoder,
): ThpMessageResponse => {
): ThpMessageResponse<'ThpCreateChannelResponse'> => {
const nonce = payload.subarray(0, 8);
const channel = payload.subarray(8, 10);
const props = payload.subarray(10, payload.length - CRC_LENGTH);
const properties = protobufDecoder('ThpDeviceProperties', props).message;
// TODO: add-crypto
// const handshakeHash = handleCreateChannelResponse(props);
const properties = protobufDecoder('ThpDeviceProperties', props).message as ThpDeviceProperties;
const handshakeHash = getHandshakeHash(props);
return {
type: 'ThpCreateChannelResponse',
@@ -49,10 +64,60 @@ const createChannelResponse = (
nonce,
channel,
properties,
// TODO: add-crypto
// handshakeHash,
handshakeHash,
},
} as any;
};
};
const readHandshakeInitResponse = ({
payload,
}: ThpMessage): ThpMessageResponse<'ThpHandshakeInitResponse'> => {
const trezorEphemeralPubkey = payload.subarray(0, 32);
const trezorEncryptedStaticPubkey = payload.subarray(32, 32 + 48);
const tag = payload.subarray(32 + 48, 32 + 48 + TAG_LENGTH);
return {
type: 'ThpHandshakeInitResponse',
message: {
trezorEphemeralPubkey,
trezorEncryptedStaticPubkey,
tag,
},
};
};
const readHandshakeCompletionResponse = ({
payload,
thpState,
}: ThpMessage): ThpMessageResponse<'ThpHandshakeCompletionResponse'> => {
const state = getTrezorState(thpState.handshakeCredentials!, payload);
return {
type: 'ThpHandshakeCompletionResponse',
message: {
state,
},
};
};
const readProtobufMessage = (
{ payload, thpState }: ThpMessage,
protobufDecoder: ProtobufDecoder,
): ThpMessageResponse => {
const tagPos = payload.length - TAG_LENGTH - CRC_LENGTH;
const cipheredMessage = payload.subarray(0, tagPos);
const tag = payload.subarray(tagPos, payload.length - CRC_LENGTH);
const decipheredMessage = decipherMessage(
thpState.handshakeCredentials!.trezorKey,
thpState.recvNonce,
cipheredMessage,
tag,
);
const messageType = decipheredMessage.readUInt16BE(0);
const messagePayload = decipheredMessage.subarray(2);
return protobufDecoder(messageType, messagePayload) as ThpMessageResponse;
};
const decodeReadAck = (): ThpMessageResponse => ({
@@ -144,7 +209,7 @@ export const decode = (
thpState?: ThpState,
): ThpMessageResponse => {
if (!thpState) {
throw new Error('Cannot decode THP message without ThpState');
throw new Error('ThpStateMissing');
}
validateCrc(decodedMessage);
@@ -169,5 +234,22 @@ export const decode = (
return createChannelResponse(message, protobufDecoder);
}
if (magic === THP_HANDSHAKE_INIT_RESPONSE) {
return readHandshakeInitResponse(message);
}
if (magic === THP_HANDSHAKE_COMPLETION_RESPONSE) {
return readHandshakeCompletionResponse(message);
}
if (magic === THP_CONTROL_BYTE_ENCRYPTED) {
return readProtobufMessage(message, protobufDecoder);
}
// TODO: decrypted message decoding (not implemented in FW)
if (magic === THP_CONTROL_BYTE_DECRYPTED) {
return readProtobufMessage(message, protobufDecoder);
}
throw new Error('Unknown message type: ' + magic);
};

View File

@@ -5,9 +5,12 @@ import {
THP_CONTROL_BYTE_ENCRYPTED,
THP_CREATE_CHANNEL_REQUEST,
THP_DEFAULT_CHANNEL,
THP_HANDSHAKE_COMPLETION_REQUEST,
THP_HANDSHAKE_INIT_REQUEST,
THP_READ_ACK_HEADER_BYTE,
} from './constants';
import { crc32 } from './crypto/crc32';
import { aesgcm, crc32 } from './crypto';
import { getIvFromNonce } from './crypto/tools';
import { addAckBit, addSequenceBit, getControlBit, isThpMessageName } from './utils';
// @trezor/protobuf encodeMessage without direct reference to protobuf root
@@ -19,6 +22,16 @@ type ProtobufEncoder = (
message: Buffer;
};
const cipherMessage = (key: Buffer, sendNonce: number, handshakeHash: Buffer, payload: Buffer) => {
// Set encrypted_payload = AES-GCM-ENCRYPT(key=k, IV=0^96, ad=h, plaintext=payload_binary).
const aes = aesgcm(key, getIvFromNonce(sendNonce));
aes.auth(handshakeHash);
const encryptedPayload = aes.encrypt(payload);
const encryptedPayloadTag = aes.finish();
return Buffer.concat([encryptedPayload, encryptedPayloadTag]);
};
// utility for **RequestPayload inputs/params
const getBytesFromField = (data: Record<string, unknown>, fieldName: string) => {
const value = data[fieldName];
@@ -39,10 +52,38 @@ const createChannelRequestPayload = (data: Record<string, unknown>) => {
return nonce;
};
export const encodePayload = (name: string, data: Record<string, unknown>, _thpState: ThpState) => {
const handshakeInitRequestPayload = (data: Record<string, unknown>, _thpState: ThpState) => {
const key = getBytesFromField(data, 'key');
if (!key) {
throw new Error('ThpHandshakeInitRequest missing key field');
}
return key;
};
const handshakeCompletionRequestPayload = (data: Record<string, unknown>) => {
const hostPubkey = getBytesFromField(data, 'hostPubkey');
if (!hostPubkey) {
throw new Error('ThpHandshakeCompletionRequest missing hostPubkey field');
}
const encryptedPayload = getBytesFromField(data, 'encryptedPayload');
if (!encryptedPayload) {
throw new Error('ThpHandshakeCompletionRequest missing encryptedPayload field');
}
return Buffer.concat([hostPubkey, encryptedPayload]);
};
export const encodePayload = (name: string, data: Record<string, unknown>, thpState: ThpState) => {
if (name === 'ThpCreateChannelRequest') {
return createChannelRequestPayload(data);
}
if (name === 'ThpHandshakeInitRequest') {
return handshakeInitRequestPayload(data, thpState);
}
if (name === 'ThpHandshakeCompletionRequest') {
return handshakeCompletionRequestPayload(data);
}
return Buffer.alloc(0);
};
@@ -64,17 +105,47 @@ const createChannelRequest = (data: Buffer, channel: Buffer) => {
return Buffer.concat([message, crc]);
};
const handshakeInitRequest = (data: Buffer, channel: Buffer) => {
const length = Buffer.alloc(2);
length.writeUInt16BE(data.length + CRC_LENGTH);
const magic = Buffer.from([THP_HANDSHAKE_INIT_REQUEST]);
const message = Buffer.concat([magic, channel, length, data]);
const crc = crc32(message);
return Buffer.concat([message, crc]);
};
const handshakeCompletionRequest = (data: Buffer, channel: Buffer, sendBit: number) => {
const length = Buffer.alloc(2);
length.writeUInt16BE(data.length + CRC_LENGTH);
const magic = addSequenceBit(THP_HANDSHAKE_COMPLETION_REQUEST, sendBit);
const message = Buffer.concat([magic, channel, length, data]);
const crc = crc32(message);
return Buffer.concat([message, crc]);
};
const encodeThpMessage = (
messageType: string,
data: Buffer,
channel: Buffer,
_thpState: ThpState,
thpState: ThpState,
) => {
if (messageType === 'ThpCreateChannelRequest') {
return createChannelRequest(data, channel);
}
throw new Error(`Unknown ThpMessage type ${messageType}`);
if (messageType === 'ThpHandshakeInitRequest') {
return handshakeInitRequest(data, channel);
}
if (messageType === 'ThpHandshakeCompletionRequest') {
return handshakeCompletionRequest(data, channel, thpState.sendBit || 0);
}
throw new Error(`Unknown Thp message type ${messageType}`);
};
// TODO: link-to-public-docs
@@ -86,25 +157,24 @@ export const encodeProtobufMessage = (
thpState?: ThpState,
) => {
if (!thpState) {
throw new Error('ThpState missing');
throw new Error('ThpStateMissing');
}
const length = Buffer.alloc(2);
length.writeUInt16BE(1 + 2 + data.length + TAG_LENGTH + CRC_LENGTH); // 1 session_id + 2 messageType + protobuf len + 16 tag + 4 crc
// TODO: distinguish encrypted and decrypted messages (not implemented in FW)
const magic = addSequenceBit(THP_CONTROL_BYTE_ENCRYPTED, thpState.sendBit);
const header = Buffer.concat([magic, channel]);
const messageTypeBytes = Buffer.alloc(2);
messageTypeBytes.writeUInt16BE(messageType);
// TODO: add-crypto
// const cipheredMessage = cipherMessage(
// thpState.handshakeCredentials.hostKey,
// thpState.sendNonce,
// Buffer.alloc(0),
// Buffer.concat([thpState.sessionId, messageTypeBytes, data]),
// );
const cipheredMessage = Buffer.concat([Buffer.alloc(0), messageTypeBytes, data]);
const cipheredMessage = cipherMessage(
thpState.handshakeCredentials!.hostKey,
thpState.sendNonce,
Buffer.alloc(0),
Buffer.concat([thpState.sessionId, messageTypeBytes, data]),
);
const message = Buffer.concat([header, length, cipheredMessage]);
const crc = crc32(message);
@@ -150,10 +220,9 @@ export const encode = (options: {
data: Record<string, unknown>;
thpState?: ThpState;
protobufEncoder: ProtobufEncoder;
header?: Buffer;
}) => {
if (!options.thpState) {
throw new Error('ThpState missing');
throw new Error('ThpStateMissing');
}
const channel = options.thpState.channel || THP_DEFAULT_CHANNEL;

View File

@@ -2,7 +2,18 @@ export * from './decode';
export * from './encode';
export * from './messages';
export * from './utils';
export * as constants from './constants';
export {
getCpaceHostKeys,
getSharedSecret,
getHandshakeHash,
handleHandshakeInit,
validateCodeEntryTag,
validateQrCodeTag,
validateNfcTag,
} from './crypto/pairing';
export { ThpState } from './ThpState';
export { getCurve25519KeyPair } from './crypto/curve25519';
export const name = 'thp';

View File

@@ -37,14 +37,9 @@ export type ThpHandshakeInitRequest = {
};
export type ThpHandshakeInitResponse = {
handshakeHash: Buffer;
trezorEphemeralPubkey: Buffer;
trezorEncryptedStaticPubkey: Buffer;
trezorMaskedStaticPubkey: Buffer;
tag: Buffer;
hostEncryptedStaticPubkey: Buffer;
hostKey: Buffer;
trezorKey: Buffer;
};
export type ThpHandshakeCompletionRequest = {
@@ -54,7 +49,6 @@ export type ThpHandshakeCompletionRequest = {
export type ThpHandshakeCompletionResponse = {
state: 0 | 1;
tag: Buffer;
};
export type ThpMessageType = ThpProtobufMessageType & {

View File

@@ -4,13 +4,13 @@ describe('pairing', () => {
it('findKnownPairingCredentials', () => {
const knownCredentials = [
{
trezor_static_pubkey:
trezor_static_public_key:
'1317c99c16fce04935782ed250cf0cacb12216f739cea55257258a2ff9440763',
credential:
'0a0f0a0d5472657a6f72436f6e6e6563741220f69918996c0afa1045b3625d06e7e816b0c4c4bd3902dfd4cad068b3f2425ec8',
},
{
trezor_static_pubkey:
trezor_static_public_key:
'2bcdbc9fd7949c3f37aa80a53801f52ec554facfe76118030926294250fd6838',
credential:
'0a110a0d5472657a6f72436f6e6e65637410011220b97509ef252b07dcc70071c9d13dd70746d8a9fb671765049ca74e58b9058d6b',

View File

@@ -51,9 +51,11 @@ describe('protocol-thp', () => {
expect(decoded.message).toMatchObject({
channel: Buffer.from('3c83', 'hex'),
nonce,
handshakeHash: Buffer.from(
'a1615c6dc1c2a2df8155ea5b54d0f39c320d4908e4b1eaf8d24ab5c4b466e947',
'hex',
),
// properties asserted below
// TODO: add-crypto
// handshakeHash
});
// @ts-expect-error