feat(python): skip THP ACKs on TrezorClientThp._call()

Due to event loop restart, the last response from the device will be explicitly ACKed.

It will allow piggyback `ButtonRequest` THP ACKs using corresponding `ButtonAck` messages.
This commit is contained in:
Roman Zeyde
2026-02-18 16:19:56 +01:00
parent 117f83e0ca
commit 3b59a26d0d
2 changed files with 43 additions and 4 deletions

View File

@@ -300,7 +300,8 @@ class Channel:
if e.code == exceptions.ThpErrorCode.DEVICE_LOCKED:
raise DeviceLockedError from e
raise
self._send_ack(message)
if not self.is_ack_piggybacking_allowed:
self._send_ack(message)
if not message.is_handshake_init_response():
raise ProtocolError(f"Not a valid handshake init response: {message}")
@@ -408,9 +409,17 @@ class Channel:
continue
raise
def _send_ack(self, acked_message: Message) -> None:
ack = control_byte.make_ack_for(acked_message.ctrl_byte)
ack_message = Message(ack, acked_message.cid, b"")
def _send_ack(self, acked_message: Message | None) -> None:
if self.is_ack_piggybacking_allowed and self._active_workflow is not None:
return
if acked_message is not None:
ack = control_byte.make_ack_for(acked_message.ctrl_byte)
ack_message = Message(ack, acked_message.cid, b"")
else:
ack = control_byte.make_ack(not self.sync_bit_receive)
ack_message = Message(ack, self.channel_id, b"")
thp_io.write_payload_to_wire(self.transport, ack_message)
def _read_ack(self, message: Message) -> None:
@@ -432,6 +441,23 @@ class Channel:
f"Failed to read ACK in {retries} retries for message: {message}"
)
@contextmanager
def piggyback_acks(self, marker: object) -> t.Generator[None, None, None]:
# Make sure the previous workflow is over.
assert self._active_workflow is None
self._active_workflow = marker
# Skip explicit ACKs during this workflow
try:
yield
finally:
active = self._active_workflow
self._active_workflow = None
assert active is marker
if self.is_ack_piggybacking_allowed:
# Explicitly ACK the latest received message. The device may restart
# the event loop, so the next request will be sent in a separate message.
self._send_ack(None)
def write_chunk(self, data: bytes, /) -> None:
self._assert_handshake_done()
encrypted_data = self.noise.encrypt(data)

View File

@@ -188,6 +188,19 @@ class TrezorClientThp(client.TrezorClient[ThpSession]):
else:
self._session_message_queue[session_id].append(msg)
def _call(
self,
session: ThpSession,
msg: client.MessageType,
*,
expect: type[client.MT] = client.MessageType,
timeout: float | None = None,
) -> client.MT:
with self.channel.piggyback_acks(msg):
return super()._call(
session=session, msg=msg, expect=expect, timeout=timeout
)
@staticmethod
def detect_model(props: messages.ThpDeviceProperties) -> models.TrezorModel:
internal_model = props.internal_model