refactor(python): better handling of TREZOR_BLE flag

[no changelog]
This commit is contained in:
matejcik
2025-10-14 15:28:27 +02:00
committed by Roman Zeyde
parent 942e6716b3
commit 1a2c9bf438
5 changed files with 25 additions and 35 deletions

View File

@@ -139,13 +139,11 @@ class TrezorConnection:
session_id: bytes | None,
passphrase_on_host: bool,
script: bool,
ble_enabled: bool,
) -> None:
self.path = path
self.session_id = session_id
self.passphrase_on_host = passphrase_on_host
self.script = script
self.ble_enabled = ble_enabled
def get_session(
self,
@@ -219,9 +217,7 @@ class TrezorConnection:
_TRANSPORT = None
try:
# look for transport without prefix search
_TRANSPORT = transport.get_transport(
self.path, prefix_search=False, ble_enabled=self.ble_enabled
)
_TRANSPORT = transport.get_transport(self.path, prefix_search=False)
except Exception:
# most likely not found. try again below.
pass
@@ -229,9 +225,7 @@ class TrezorConnection:
# look for transport with prefix search
# if this fails, we want the exception to bubble up to the caller
if not _TRANSPORT:
_TRANSPORT = transport.get_transport(
self.path, prefix_search=True, ble_enabled=self.ble_enabled
)
_TRANSPORT = transport.get_transport(self.path, prefix_search=True)
_TRANSPORT.open()
atexit.register(_TRANSPORT.close)

View File

@@ -28,6 +28,7 @@ import click
from .. import log, messages, protobuf
from ..transport import DeviceIsBusy, enumerate_devices
from ..transport.session import Session
from ..transport.ble import BleTransport
from ..transport.udp import UdpTransport
from . import (
AliasedGroup,
@@ -175,7 +176,6 @@ def configure_logging(verbose: int) -> None:
"--ble/--no-ble",
help="Enable/disable support for Bluetooth Low Energy.",
is_flag=True,
default=(os.environ.get("TREZOR_BLE") == "1"),
)
@click.option("-v", "--verbose", count=True, help="Show communication messages.")
@click.option(
@@ -210,7 +210,7 @@ def configure_logging(verbose: int) -> None:
def cli_main(
ctx: click.Context,
path: str,
ble: bool,
ble: bool | None,
verbose: int,
is_json: bool,
passphrase_on_host: bool,
@@ -220,6 +220,12 @@ def cli_main(
) -> None:
configure_logging(verbose)
# if BLE was explicitly enabled, raise an error if it's not available
if ble and not BleTransport.ENABLED:
raise click.ClickException("BLE support is unavailable")
BleTransport.ENABLED = ble or (os.environ.get("TREZOR_BLE") == "1")
bytes_session_id: Optional[bytes] = None
if session_id is not None:
try:
@@ -227,9 +233,7 @@ def cli_main(
except ValueError:
raise click.ClickException(f"Not a valid session id: {session_id}")
ctx.obj = TrezorConnection(
path, bytes_session_id, passphrase_on_host, script, ble_enabled=ble
)
ctx.obj = TrezorConnection(path, bytes_session_id, passphrase_on_host, script)
# Optionally record the screen into a specified directory.
if record:
@@ -303,13 +307,13 @@ def list_devices(
) -> Optional[Iterable["Transport"]]:
"""List connected Trezor devices."""
if no_resolve:
for d in enumerate_devices(ble_enabled=obj.ble_enabled):
for d in enumerate_devices():
click.echo(d.get_path())
return
from . import get_client
for transport in enumerate_devices(ble_enabled=obj.ble_enabled):
for transport in enumerate_devices():
try:
transport.open()
client = get_client(transport)

View File

@@ -96,32 +96,28 @@ class Transport:
CHUNK_SIZE: t.ClassVar[int | None]
def all_transports(ble_enabled: bool | None = None) -> t.Iterable[t.Type["Transport"]]:
def all_transports() -> t.Iterable[type[Transport]]:
from .ble import BleTransport
from .bridge import BridgeTransport
from .hid import HidTransport
from .udp import UdpTransport
from .webusb import WebUsbTransport
transports: t.Tuple[t.Type["Transport"], ...] = (
transports: tuple[type[Transport], ...] = (
BridgeTransport,
HidTransport,
UdpTransport,
WebUsbTransport,
BleTransport,
)
if ble_enabled is None:
ble_enabled = os.environ.get("TREZOR_BLE") == "1"
if ble_enabled:
transports += (BleTransport,)
return set(t for t in transports if t.ENABLED)
def enumerate_devices(
models: t.Iterable[TrezorModel] | None = None,
ble_enabled: bool | None = None,
) -> t.Sequence[Transport]:
devices: t.List[Transport] = []
for transport in all_transports(ble_enabled=ble_enabled):
for transport in all_transports():
name = transport.__name__
try:
found = list(transport.enumerate(models))
@@ -135,14 +131,10 @@ def enumerate_devices(
return devices
def get_transport(
path: str | None = None,
prefix_search: bool = False,
ble_enabled: bool | None = None,
) -> Transport:
def get_transport(path: str | None = None, prefix_search: bool = False) -> Transport:
if path is None:
try:
return next(iter(enumerate_devices(ble_enabled=ble_enabled)))
return next(iter(enumerate_devices()))
except StopIteration:
raise TransportException("No Trezor device found") from None
@@ -157,11 +149,7 @@ def get_transport(
"prefix" if prefix_search else "full path", path
)
)
transports = [
t
for t in all_transports(ble_enabled=ble_enabled)
if match_prefix(path, t.PATH_PREFIX)
]
transports = [t for t in all_transports() if match_prefix(path, t.PATH_PREFIX)]
if transports:
return transports[0].find_by_path(path, prefix_search=prefix_search)

View File

@@ -53,7 +53,7 @@ SHUTDOWN_TIMEOUT_SECONDS = 10
class BleTransport(Transport):
ENABLED = True
ENABLED = BLEAK_IMPORTED
PATH_PREFIX = "ble"
CHUNK_SIZE = 244

View File

@@ -33,6 +33,7 @@ from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.device import apply_settings
from trezorlib.device import wipe as wipe_device
from trezorlib.transport import Timeout, enumerate_devices, get_transport
from trezorlib.transport.ble import BleTransport
from trezorlib.transport.thp.protocol_v1 import ProtocolV1Channel, UnexpectedMagicError
# register rewrites before importing from local package
@@ -175,6 +176,9 @@ def _client_from_path(
def _find_client(request: pytest.FixtureRequest, interact: bool) -> Client:
if os.environ.get("TREZOR_BLE") != "1":
BleTransport.ENABLED = False
devices = enumerate_devices()
for device in devices:
return Client(device, auto_interact=not interact, open_transport=True)