mirror of
https://github.com/trezor/trezor-firmware.git
synced 2026-02-20 00:33:30 +01:00
refactor(python): better handling of TREZOR_BLE flag
[no changelog]
This commit is contained in:
@@ -139,13 +139,11 @@ class TrezorConnection:
|
||||
session_id: bytes | None,
|
||||
passphrase_on_host: bool,
|
||||
script: bool,
|
||||
ble_enabled: bool,
|
||||
) -> None:
|
||||
self.path = path
|
||||
self.session_id = session_id
|
||||
self.passphrase_on_host = passphrase_on_host
|
||||
self.script = script
|
||||
self.ble_enabled = ble_enabled
|
||||
|
||||
def get_session(
|
||||
self,
|
||||
@@ -219,9 +217,7 @@ class TrezorConnection:
|
||||
_TRANSPORT = None
|
||||
try:
|
||||
# look for transport without prefix search
|
||||
_TRANSPORT = transport.get_transport(
|
||||
self.path, prefix_search=False, ble_enabled=self.ble_enabled
|
||||
)
|
||||
_TRANSPORT = transport.get_transport(self.path, prefix_search=False)
|
||||
except Exception:
|
||||
# most likely not found. try again below.
|
||||
pass
|
||||
@@ -229,9 +225,7 @@ class TrezorConnection:
|
||||
# look for transport with prefix search
|
||||
# if this fails, we want the exception to bubble up to the caller
|
||||
if not _TRANSPORT:
|
||||
_TRANSPORT = transport.get_transport(
|
||||
self.path, prefix_search=True, ble_enabled=self.ble_enabled
|
||||
)
|
||||
_TRANSPORT = transport.get_transport(self.path, prefix_search=True)
|
||||
|
||||
_TRANSPORT.open()
|
||||
atexit.register(_TRANSPORT.close)
|
||||
|
||||
@@ -28,6 +28,7 @@ import click
|
||||
from .. import log, messages, protobuf
|
||||
from ..transport import DeviceIsBusy, enumerate_devices
|
||||
from ..transport.session import Session
|
||||
from ..transport.ble import BleTransport
|
||||
from ..transport.udp import UdpTransport
|
||||
from . import (
|
||||
AliasedGroup,
|
||||
@@ -175,7 +176,6 @@ def configure_logging(verbose: int) -> None:
|
||||
"--ble/--no-ble",
|
||||
help="Enable/disable support for Bluetooth Low Energy.",
|
||||
is_flag=True,
|
||||
default=(os.environ.get("TREZOR_BLE") == "1"),
|
||||
)
|
||||
@click.option("-v", "--verbose", count=True, help="Show communication messages.")
|
||||
@click.option(
|
||||
@@ -210,7 +210,7 @@ def configure_logging(verbose: int) -> None:
|
||||
def cli_main(
|
||||
ctx: click.Context,
|
||||
path: str,
|
||||
ble: bool,
|
||||
ble: bool | None,
|
||||
verbose: int,
|
||||
is_json: bool,
|
||||
passphrase_on_host: bool,
|
||||
@@ -220,6 +220,12 @@ def cli_main(
|
||||
) -> None:
|
||||
configure_logging(verbose)
|
||||
|
||||
# if BLE was explicitly enabled, raise an error if it's not available
|
||||
if ble and not BleTransport.ENABLED:
|
||||
raise click.ClickException("BLE support is unavailable")
|
||||
|
||||
BleTransport.ENABLED = ble or (os.environ.get("TREZOR_BLE") == "1")
|
||||
|
||||
bytes_session_id: Optional[bytes] = None
|
||||
if session_id is not None:
|
||||
try:
|
||||
@@ -227,9 +233,7 @@ def cli_main(
|
||||
except ValueError:
|
||||
raise click.ClickException(f"Not a valid session id: {session_id}")
|
||||
|
||||
ctx.obj = TrezorConnection(
|
||||
path, bytes_session_id, passphrase_on_host, script, ble_enabled=ble
|
||||
)
|
||||
ctx.obj = TrezorConnection(path, bytes_session_id, passphrase_on_host, script)
|
||||
|
||||
# Optionally record the screen into a specified directory.
|
||||
if record:
|
||||
@@ -303,13 +307,13 @@ def list_devices(
|
||||
) -> Optional[Iterable["Transport"]]:
|
||||
"""List connected Trezor devices."""
|
||||
if no_resolve:
|
||||
for d in enumerate_devices(ble_enabled=obj.ble_enabled):
|
||||
for d in enumerate_devices():
|
||||
click.echo(d.get_path())
|
||||
return
|
||||
|
||||
from . import get_client
|
||||
|
||||
for transport in enumerate_devices(ble_enabled=obj.ble_enabled):
|
||||
for transport in enumerate_devices():
|
||||
try:
|
||||
transport.open()
|
||||
client = get_client(transport)
|
||||
|
||||
@@ -96,32 +96,28 @@ class Transport:
|
||||
CHUNK_SIZE: t.ClassVar[int | None]
|
||||
|
||||
|
||||
def all_transports(ble_enabled: bool | None = None) -> t.Iterable[t.Type["Transport"]]:
|
||||
def all_transports() -> t.Iterable[type[Transport]]:
|
||||
from .ble import BleTransport
|
||||
from .bridge import BridgeTransport
|
||||
from .hid import HidTransport
|
||||
from .udp import UdpTransport
|
||||
from .webusb import WebUsbTransport
|
||||
|
||||
transports: t.Tuple[t.Type["Transport"], ...] = (
|
||||
transports: tuple[type[Transport], ...] = (
|
||||
BridgeTransport,
|
||||
HidTransport,
|
||||
UdpTransport,
|
||||
WebUsbTransport,
|
||||
BleTransport,
|
||||
)
|
||||
if ble_enabled is None:
|
||||
ble_enabled = os.environ.get("TREZOR_BLE") == "1"
|
||||
if ble_enabled:
|
||||
transports += (BleTransport,)
|
||||
return set(t for t in transports if t.ENABLED)
|
||||
|
||||
|
||||
def enumerate_devices(
|
||||
models: t.Iterable[TrezorModel] | None = None,
|
||||
ble_enabled: bool | None = None,
|
||||
) -> t.Sequence[Transport]:
|
||||
devices: t.List[Transport] = []
|
||||
for transport in all_transports(ble_enabled=ble_enabled):
|
||||
for transport in all_transports():
|
||||
name = transport.__name__
|
||||
try:
|
||||
found = list(transport.enumerate(models))
|
||||
@@ -135,14 +131,10 @@ def enumerate_devices(
|
||||
return devices
|
||||
|
||||
|
||||
def get_transport(
|
||||
path: str | None = None,
|
||||
prefix_search: bool = False,
|
||||
ble_enabled: bool | None = None,
|
||||
) -> Transport:
|
||||
def get_transport(path: str | None = None, prefix_search: bool = False) -> Transport:
|
||||
if path is None:
|
||||
try:
|
||||
return next(iter(enumerate_devices(ble_enabled=ble_enabled)))
|
||||
return next(iter(enumerate_devices()))
|
||||
except StopIteration:
|
||||
raise TransportException("No Trezor device found") from None
|
||||
|
||||
@@ -157,11 +149,7 @@ def get_transport(
|
||||
"prefix" if prefix_search else "full path", path
|
||||
)
|
||||
)
|
||||
transports = [
|
||||
t
|
||||
for t in all_transports(ble_enabled=ble_enabled)
|
||||
if match_prefix(path, t.PATH_PREFIX)
|
||||
]
|
||||
transports = [t for t in all_transports() if match_prefix(path, t.PATH_PREFIX)]
|
||||
if transports:
|
||||
return transports[0].find_by_path(path, prefix_search=prefix_search)
|
||||
|
||||
|
||||
@@ -53,7 +53,7 @@ SHUTDOWN_TIMEOUT_SECONDS = 10
|
||||
|
||||
|
||||
class BleTransport(Transport):
|
||||
ENABLED = True
|
||||
ENABLED = BLEAK_IMPORTED
|
||||
PATH_PREFIX = "ble"
|
||||
CHUNK_SIZE = 244
|
||||
|
||||
|
||||
@@ -33,6 +33,7 @@ from trezorlib.debuglink import TrezorClientDebugLink as Client
|
||||
from trezorlib.device import apply_settings
|
||||
from trezorlib.device import wipe as wipe_device
|
||||
from trezorlib.transport import Timeout, enumerate_devices, get_transport
|
||||
from trezorlib.transport.ble import BleTransport
|
||||
from trezorlib.transport.thp.protocol_v1 import ProtocolV1Channel, UnexpectedMagicError
|
||||
|
||||
# register rewrites before importing from local package
|
||||
@@ -175,6 +176,9 @@ def _client_from_path(
|
||||
|
||||
|
||||
def _find_client(request: pytest.FixtureRequest, interact: bool) -> Client:
|
||||
if os.environ.get("TREZOR_BLE") != "1":
|
||||
BleTransport.ENABLED = False
|
||||
|
||||
devices = enumerate_devices()
|
||||
for device in devices:
|
||||
return Client(device, auto_interact=not interact, open_transport=True)
|
||||
|
||||
Reference in New Issue
Block a user