chore(tests): update to kwargs usage and new btc.sign_tx API

This commit is contained in:
matejcik
2020-09-15 13:06:41 +02:00
committed by matejcik
parent b41021a5fb
commit 08d896f2f9
42 changed files with 361 additions and 263 deletions

View File

@@ -15,6 +15,7 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import logging
import textwrap
from collections import namedtuple
from copy import deepcopy
@@ -245,6 +246,96 @@ class DebugUI:
return self.passphrase
class MessageFilter:
def __init__(self, message_type, **fields):
self.message_type = message_type
self.fields = {}
self.update_fields(**fields)
def update_fields(self, **fields):
for name, value in fields.items():
try:
self.fields[name] = self.from_message_or_type(value)
except TypeError:
self.fields[name] = value
return self
@classmethod
def from_message_or_type(cls, message_or_type):
if isinstance(message_or_type, cls):
return message_or_type
if isinstance(message_or_type, protobuf.MessageType):
return cls.from_message(message_or_type)
if isinstance(message_or_type, type) and issubclass(
message_or_type, protobuf.MessageType
):
return cls(message_or_type)
raise TypeError("Invalid kind of expected response")
@classmethod
def from_message(cls, message):
fields = {}
for field in message.keys():
value = getattr(message, field)
if value in (None, []):
continue
fields[field] = value
return cls(type(message), **fields)
def match(self, message):
if type(message) != self.message_type:
return False
for field, expected_value in self.fields.items():
actual_value = getattr(message, field, None)
if isinstance(expected_value, MessageFilter):
if not expected_value.match(actual_value):
return False
elif expected_value != actual_value:
return False
return True
def format(self, maxwidth=80):
fields = []
for fname, ftype, _ in self.message_type.get_fields().values():
if fname not in self.fields:
continue
value = self.fields[fname]
if isinstance(ftype, protobuf.EnumType) and isinstance(value, int):
field_str = ftype.to_str(value)
elif isinstance(value, MessageFilter):
field_str = value.format(maxwidth - 4)
elif isinstance(value, protobuf.MessageType):
field_str = protobuf.format_message(value)
else:
field_str = repr(value)
field_str = textwrap.indent(field_str, " ").lstrip()
fields.append((fname, field_str))
pairs = ["{}={}".format(k, v) for k, v in fields]
oneline_str = ", ".join(pairs)
if len(oneline_str) < maxwidth:
return "{}({})".format(self.message_type.__name__, oneline_str)
else:
item = []
item.append("{}(".format(self.message_type.__name__))
for pair in pairs:
item.append(" {}".format(pair))
item.append(")")
return "\n".join(item)
class MessageFilterGenerator:
def __getattr__(self, key):
message_type = getattr(messages, key)
return MessageFilter(message_type).update_fields
message_filters = MessageFilterGenerator()
class TrezorClientDebugLink(TrezorClient):
# This class implements automatic responses
# and other functionality for unit tests
@@ -417,13 +508,15 @@ class TrezorClientDebugLink(TrezorClient):
raise RuntimeError("Must be called inside 'with' statement")
# make sure all items are (bool, message) tuples
expected_with_validity = [
expected_with_validity = (
e if isinstance(e, tuple) else (True, e) for e in expected
]
)
# only apply those items that are (True, message)
self.expected_responses = [
expected for valid, expected in expected_with_validity if valid
MessageFilter.from_message_or_type(expected)
for valid, expected in expected_with_validity
if valid
]
self.current_response = 0
@@ -469,23 +562,7 @@ class TrezorClientDebugLink(TrezorClient):
for i in range(start_at, stop_at):
exp = self.expected_responses[i]
prefix = " " if i != self.current_response else ">>> "
set_fields = {
key: value
for key, value in exp.__dict__.items()
if value is not None and value != []
}
oneline_str = ", ".join("{}={!r}".format(*i) for i in set_fields.items())
if len(oneline_str) < 60:
output.append(
"{}{}({})".format(prefix, exp.__class__.__name__, oneline_str)
)
else:
item = []
item.append("{}{}(".format(prefix, exp.__class__.__name__))
for key, value in set_fields.items():
item.append("{} {}={!r}".format(prefix, key, value))
item.append("{})".format(prefix))
output.append("\n".join(item))
output.append(textwrap.indent(exp.format(), prefix))
if stop_at < len(self.expected_responses):
omitted = len(self.expected_responses) - stop_at
output.append(" (...{} following responses omitted)".format(omitted))
@@ -493,7 +570,7 @@ class TrezorClientDebugLink(TrezorClient):
output.append("")
if msg is not None:
output.append("Actually received:")
output.append(protobuf.format_message(msg))
output.append(textwrap.indent(protobuf.format_message(msg), " "))
else:
output.append("This message was never received.")
raise AssertionError("\n".join(output))
@@ -511,15 +588,9 @@ class TrezorClientDebugLink(TrezorClient):
expected = self.expected_responses[self.current_response]
if msg.__class__ != expected.__class__:
if not expected.match(msg):
self._raise_unexpected_response(msg)
for field, value in expected.__dict__.items():
if value is None or value == []:
continue
if getattr(msg, field) != value:
self._raise_unexpected_response(msg)
self.current_response += 1
def mnemonic_callback(self, _):