mirror of
https://github.com/trezor/trezor-firmware.git
synced 2026-03-13 18:58:48 +01:00
chore(tests): update to kwargs usage and new btc.sign_tx API
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
import logging
|
||||
import textwrap
|
||||
from collections import namedtuple
|
||||
from copy import deepcopy
|
||||
|
||||
@@ -245,6 +246,96 @@ class DebugUI:
|
||||
return self.passphrase
|
||||
|
||||
|
||||
class MessageFilter:
|
||||
def __init__(self, message_type, **fields):
|
||||
self.message_type = message_type
|
||||
self.fields = {}
|
||||
self.update_fields(**fields)
|
||||
|
||||
def update_fields(self, **fields):
|
||||
for name, value in fields.items():
|
||||
try:
|
||||
self.fields[name] = self.from_message_or_type(value)
|
||||
except TypeError:
|
||||
self.fields[name] = value
|
||||
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def from_message_or_type(cls, message_or_type):
|
||||
if isinstance(message_or_type, cls):
|
||||
return message_or_type
|
||||
if isinstance(message_or_type, protobuf.MessageType):
|
||||
return cls.from_message(message_or_type)
|
||||
if isinstance(message_or_type, type) and issubclass(
|
||||
message_or_type, protobuf.MessageType
|
||||
):
|
||||
return cls(message_or_type)
|
||||
raise TypeError("Invalid kind of expected response")
|
||||
|
||||
@classmethod
|
||||
def from_message(cls, message):
|
||||
fields = {}
|
||||
for field in message.keys():
|
||||
value = getattr(message, field)
|
||||
if value in (None, []):
|
||||
continue
|
||||
fields[field] = value
|
||||
return cls(type(message), **fields)
|
||||
|
||||
def match(self, message):
|
||||
if type(message) != self.message_type:
|
||||
return False
|
||||
|
||||
for field, expected_value in self.fields.items():
|
||||
actual_value = getattr(message, field, None)
|
||||
if isinstance(expected_value, MessageFilter):
|
||||
if not expected_value.match(actual_value):
|
||||
return False
|
||||
elif expected_value != actual_value:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def format(self, maxwidth=80):
|
||||
fields = []
|
||||
for fname, ftype, _ in self.message_type.get_fields().values():
|
||||
if fname not in self.fields:
|
||||
continue
|
||||
value = self.fields[fname]
|
||||
if isinstance(ftype, protobuf.EnumType) and isinstance(value, int):
|
||||
field_str = ftype.to_str(value)
|
||||
elif isinstance(value, MessageFilter):
|
||||
field_str = value.format(maxwidth - 4)
|
||||
elif isinstance(value, protobuf.MessageType):
|
||||
field_str = protobuf.format_message(value)
|
||||
else:
|
||||
field_str = repr(value)
|
||||
field_str = textwrap.indent(field_str, " ").lstrip()
|
||||
fields.append((fname, field_str))
|
||||
|
||||
pairs = ["{}={}".format(k, v) for k, v in fields]
|
||||
oneline_str = ", ".join(pairs)
|
||||
if len(oneline_str) < maxwidth:
|
||||
return "{}({})".format(self.message_type.__name__, oneline_str)
|
||||
else:
|
||||
item = []
|
||||
item.append("{}(".format(self.message_type.__name__))
|
||||
for pair in pairs:
|
||||
item.append(" {}".format(pair))
|
||||
item.append(")")
|
||||
return "\n".join(item)
|
||||
|
||||
|
||||
class MessageFilterGenerator:
|
||||
def __getattr__(self, key):
|
||||
message_type = getattr(messages, key)
|
||||
return MessageFilter(message_type).update_fields
|
||||
|
||||
|
||||
message_filters = MessageFilterGenerator()
|
||||
|
||||
|
||||
class TrezorClientDebugLink(TrezorClient):
|
||||
# This class implements automatic responses
|
||||
# and other functionality for unit tests
|
||||
@@ -417,13 +508,15 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
raise RuntimeError("Must be called inside 'with' statement")
|
||||
|
||||
# make sure all items are (bool, message) tuples
|
||||
expected_with_validity = [
|
||||
expected_with_validity = (
|
||||
e if isinstance(e, tuple) else (True, e) for e in expected
|
||||
]
|
||||
)
|
||||
|
||||
# only apply those items that are (True, message)
|
||||
self.expected_responses = [
|
||||
expected for valid, expected in expected_with_validity if valid
|
||||
MessageFilter.from_message_or_type(expected)
|
||||
for valid, expected in expected_with_validity
|
||||
if valid
|
||||
]
|
||||
|
||||
self.current_response = 0
|
||||
@@ -469,23 +562,7 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
for i in range(start_at, stop_at):
|
||||
exp = self.expected_responses[i]
|
||||
prefix = " " if i != self.current_response else ">>> "
|
||||
set_fields = {
|
||||
key: value
|
||||
for key, value in exp.__dict__.items()
|
||||
if value is not None and value != []
|
||||
}
|
||||
oneline_str = ", ".join("{}={!r}".format(*i) for i in set_fields.items())
|
||||
if len(oneline_str) < 60:
|
||||
output.append(
|
||||
"{}{}({})".format(prefix, exp.__class__.__name__, oneline_str)
|
||||
)
|
||||
else:
|
||||
item = []
|
||||
item.append("{}{}(".format(prefix, exp.__class__.__name__))
|
||||
for key, value in set_fields.items():
|
||||
item.append("{} {}={!r}".format(prefix, key, value))
|
||||
item.append("{})".format(prefix))
|
||||
output.append("\n".join(item))
|
||||
output.append(textwrap.indent(exp.format(), prefix))
|
||||
if stop_at < len(self.expected_responses):
|
||||
omitted = len(self.expected_responses) - stop_at
|
||||
output.append(" (...{} following responses omitted)".format(omitted))
|
||||
@@ -493,7 +570,7 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
output.append("")
|
||||
if msg is not None:
|
||||
output.append("Actually received:")
|
||||
output.append(protobuf.format_message(msg))
|
||||
output.append(textwrap.indent(protobuf.format_message(msg), " "))
|
||||
else:
|
||||
output.append("This message was never received.")
|
||||
raise AssertionError("\n".join(output))
|
||||
@@ -511,15 +588,9 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
|
||||
expected = self.expected_responses[self.current_response]
|
||||
|
||||
if msg.__class__ != expected.__class__:
|
||||
if not expected.match(msg):
|
||||
self._raise_unexpected_response(msg)
|
||||
|
||||
for field, value in expected.__dict__.items():
|
||||
if value is None or value == []:
|
||||
continue
|
||||
if getattr(msg, field) != value:
|
||||
self._raise_unexpected_response(msg)
|
||||
|
||||
self.current_response += 1
|
||||
|
||||
def mnemonic_callback(self, _):
|
||||
|
||||
Reference in New Issue
Block a user