diff --git a/python/src/trezorlib/thp/channel.py b/python/src/trezorlib/thp/channel.py index 2d0317acf4..00ef7e5002 100644 --- a/python/src/trezorlib/thp/channel.py +++ b/python/src/trezorlib/thp/channel.py @@ -300,7 +300,8 @@ class Channel: if e.code == exceptions.ThpErrorCode.DEVICE_LOCKED: raise DeviceLockedError from e raise - self._send_ack(message) + if not self.is_ack_piggybacking_allowed: + self._send_ack(message) if not message.is_handshake_init_response(): raise ProtocolError(f"Not a valid handshake init response: {message}") @@ -408,9 +409,17 @@ class Channel: continue raise - def _send_ack(self, acked_message: Message) -> None: - ack = control_byte.make_ack_for(acked_message.ctrl_byte) - ack_message = Message(ack, acked_message.cid, b"") + def _send_ack(self, acked_message: Message | None) -> None: + if self.is_ack_piggybacking_allowed and self._active_workflow is not None: + return + + if acked_message is not None: + ack = control_byte.make_ack_for(acked_message.ctrl_byte) + ack_message = Message(ack, acked_message.cid, b"") + else: + ack = control_byte.make_ack(not self.sync_bit_receive) + ack_message = Message(ack, self.channel_id, b"") + thp_io.write_payload_to_wire(self.transport, ack_message) def _read_ack(self, message: Message) -> None: @@ -432,6 +441,23 @@ class Channel: f"Failed to read ACK in {retries} retries for message: {message}" ) + @contextmanager + def piggyback_acks(self, marker: object) -> t.Generator[None, None, None]: + # Make sure the previous workflow is over. + assert self._active_workflow is None + self._active_workflow = marker + # Skip explicit ACKs during this workflow + try: + yield + finally: + active = self._active_workflow + self._active_workflow = None + assert active is marker + if self.is_ack_piggybacking_allowed: + # Explicitly ACK the latest received message. The device may restart + # the event loop, so the next request will be sent in a separate message. + self._send_ack(None) + def write_chunk(self, data: bytes, /) -> None: self._assert_handshake_done() encrypted_data = self.noise.encrypt(data) diff --git a/python/src/trezorlib/thp/client.py b/python/src/trezorlib/thp/client.py index 8c240223c9..778b0e6066 100644 --- a/python/src/trezorlib/thp/client.py +++ b/python/src/trezorlib/thp/client.py @@ -188,6 +188,19 @@ class TrezorClientThp(client.TrezorClient[ThpSession]): else: self._session_message_queue[session_id].append(msg) + def _call( + self, + session: ThpSession, + msg: client.MessageType, + *, + expect: type[client.MT] = client.MessageType, + timeout: float | None = None, + ) -> client.MT: + with self.channel.piggyback_acks(msg): + return super()._call( + session=session, msg=msg, expect=expect, timeout=timeout + ) + @staticmethod def detect_model(props: messages.ThpDeviceProperties) -> models.TrezorModel: internal_model = props.internal_model