From 94bbacdeef73c5a84871b48cb5beeab49ed07fab Mon Sep 17 00:00:00 2001 From: jopohl Date: Wed, 11 Jan 2017 10:31:14 +0100 Subject: [PATCH] make awre use field types --- src/urh/awre/FormatFinder.py | 16 +++++--- src/urh/awre/components/Address.py | 21 +++++++--- src/urh/awre/components/Length.py | 8 +++- src/urh/awre/components/Preamble.py | 24 ++++++++++-- src/urh/awre/components/SequenceNumber.py | 15 ++++++- src/urh/signalprocessing/FieldType.py | 12 +++--- src/urh/signalprocessing/MessageType.py | 6 ++- src/urh/signalprocessing/ProtocoLabel.py | 9 +++-- tests/test_awre.py | 48 +++++++++++++++++------ 9 files changed, 118 insertions(+), 41 deletions(-) diff --git a/src/urh/awre/FormatFinder.py b/src/urh/awre/FormatFinder.py index 456ca119..e3541709 100644 --- a/src/urh/awre/FormatFinder.py +++ b/src/urh/awre/FormatFinder.py @@ -1,6 +1,7 @@ import numpy as np import time +from urh.signalprocessing.FieldType import FieldType from urh.util.Logger import logger from urh.awre.components.Address import Address @@ -15,7 +16,7 @@ from urh.cythonext import util class FormatFinder(object): MIN_MESSAGES_PER_CLUSTER = 2 # If there is only one message per cluster it is not very significant - def __init__(self, protocol, participants=None): + def __init__(self, protocol, participants=None, field_types=None): """ :type protocol: urh.signalprocessing.ProtocolAnalyzer.ProtocolAnalyzer @@ -32,12 +33,15 @@ class FormatFinder(object): mt = self.protocol.message_types - self.preamble_component = Preamble(priority=0, messagetypes=mt) - self.length_component = Length(length_cluster=self.len_cluster, priority=1, + field_types = FieldType.load_from_xml() if field_types is None else field_types + + self.preamble_component = Preamble(fieldtypes=field_types, priority=0, messagetypes=mt) + self.length_component = Length(fieldtypes=field_types, length_cluster=self.len_cluster, priority=1, predecessors=[self.preamble_component], messagetypes=mt) - self.address_component = Address(xor_matrix=self.xor_matrix, priority=2, predecessors=[self.preamble_component], - messagetypes=mt) - self.sequence_number_component = SequenceNumber(priority=3, predecessors=[self.preamble_component]) + self.address_component = Address(fieldtypes=field_types, xor_matrix=self.xor_matrix, priority=2, + predecessors=[self.preamble_component], messagetypes=mt) + self.sequence_number_component = SequenceNumber(fieldtypes=field_types, priority=3, + predecessors=[self.preamble_component]) self.type_component = Type(priority=4, predecessors=[self.preamble_component]) self.flags_component = Flags(priority=5, predecessors=[self.preamble_component]) diff --git a/src/urh/awre/components/Address.py b/src/urh/awre/components/Address.py index 26416673..ba0aa00c 100644 --- a/src/urh/awre/components/Address.py +++ b/src/urh/awre/components/Address.py @@ -11,10 +11,16 @@ from urh.signalprocessing.MessageType import MessageType class Address(Component): MIN_ADDRESS_LENGTH = 8 # Address should be at least one byte - def __init__(self, xor_matrix, priority=2, predecessors=None, enabled=True, backend=None, messagetypes=None): + def __init__(self, fieldtypes, xor_matrix, priority=2, predecessors=None, enabled=True, backend=None, messagetypes=None): super().__init__(priority, predecessors, enabled, backend, messagetypes) self.xor_matrix = xor_matrix + self.dst_field_type = next((ft for ft in fieldtypes if ft.function == ft.Function.DST_ADDRESS), None) + self.src_field_type = next((ft for ft in fieldtypes if ft.Function == ft.Function.SRC_ADDRESS), None) + + self.dst_field_name = self.dst_field_type.caption if self.dst_field_type else "DST address" + self.src_field_name = self.src_field_type.caption if self.src_field_type else "SRC address" + def _py_find_field(self, messages, verbose=False): """ @@ -163,17 +169,22 @@ class Address(Component): msg = messages[msg_index] if msg.message_type.name == "ack": - name = "DST address" + field_type = self.dst_field_type + name = self.dst_field_name elif msg.participant: if rng.hex_value == msg.participant.address_hex: - name = "SRC address" + name = self.src_field_name + field_type = self.src_field_type else: - name = "DST address" + name = self.dst_field_name + field_type = self.dst_field_type else: name = "Address" + field_type = None if not any(lbl.name == name and lbl.auto_created for lbl in msg.message_type): - msg.message_type.add_protocol_label(rng.start, rng.end - 1, name=name, auto_created=True) + msg.message_type.add_protocol_label(rng.start, rng.end - 1, name=name, + auto_created=True, type=field_type) @staticmethod diff --git a/src/urh/awre/components/Length.py b/src/urh/awre/components/Length.py index 26881843..b8894527 100644 --- a/src/urh/awre/components/Length.py +++ b/src/urh/awre/components/Length.py @@ -14,9 +14,12 @@ class Length(Component): """ - def __init__(self, length_cluster, priority=2, predecessors=None, enabled=True, backend=None, messagetypes=None): + def __init__(self, fieldtypes, length_cluster, priority=2, predecessors=None, enabled=True, backend=None, messagetypes=None): super().__init__(priority, predecessors, enabled, backend, messagetypes) + self.length_field_type = next((ft for ft in fieldtypes if ft.function == ft.Function.LENGTH), None) + self.length_field_name = self.length_field_type.caption if self.length_field_type else "Length" + self.length_cluster = length_cluster """ An example length cluster is @@ -118,7 +121,8 @@ class Length(Component): try: start, end = max(scores, key=scores.__getitem__) if not any(lbl.name == "Length" and lbl.auto_created for lbl in message_type): - message_type.add_protocol_label(start=start, end=end - 1, name="Length", auto_created=True) + message_type.add_protocol_label(start=start, end=end - 1, name=self.length_field_name, + auto_created=True, type=self.length_field_type) except ValueError: continue diff --git a/src/urh/awre/components/Preamble.py b/src/urh/awre/components/Preamble.py index 254d2d90..0448f501 100644 --- a/src/urh/awre/components/Preamble.py +++ b/src/urh/awre/components/Preamble.py @@ -1,5 +1,6 @@ from collections import defaultdict from urh.awre.components.Component import Component +from urh.signalprocessing.FieldType import FieldType from urh.signalprocessing.Message import Message @@ -8,9 +9,24 @@ class Preamble(Component): Assign Preamble and SoF. """ - def __init__(self, priority=0, predecessors=None, enabled=True, backend=None, messagetypes=None): + def __init__(self, fieldtypes, priority=0, predecessors=None, enabled=True, backend=None, messagetypes=None): + """ + + :type fieldtypes: list of FieldType + :param priority: + :param predecessors: + :param enabled: + :param backend: + :param messagetypes: + """ super().__init__(priority, predecessors, enabled, backend, messagetypes) + self.preamble_field_type = next((ft for ft in fieldtypes if ft.function == ft.Function.PREAMBLE), None) + self.sync_field_type = next((ft for ft in fieldtypes if ft.function == ft.Function.SYNC), None) + + self.preamble_name = self.preamble_field_type.caption if self.preamble_field_type else "Preamble" + self.sync_name = self.sync_field_type.caption if self.sync_field_type else "Synchronization" + def _py_find_field(self, messages): """ @@ -28,7 +44,8 @@ class Preamble(Component): preamble_ends = defaultdict(int) for message_type, ranges in preamble_ranges.items(): start, end = max(ranges, key=ranges.count) - message_type.add_protocol_label(start=start, end=end, name="Preamble", auto_created=True) + message_type.add_protocol_label(start=start, end=end, name=self.preamble_name, + auto_created=True, type=self.preamble_field_type) preamble_ends[message_type] = end + 1 @@ -39,7 +56,8 @@ class Preamble(Component): sync_range = self.__find_sync_range(messages, preamble_ends[message_type], search_end) if sync_range: - message_type.add_protocol_label(start=sync_range[0], end=sync_range[1]-1, name="Synchronization", auto_created=True) + message_type.add_protocol_label(start=sync_range[0], end=sync_range[1]-1, name=self.sync_name, + auto_created=True, type=self.sync_field_type) def __find_preamble_range(self, message: Message): diff --git a/src/urh/awre/components/SequenceNumber.py b/src/urh/awre/components/SequenceNumber.py index 11340d73..30d06ad1 100644 --- a/src/urh/awre/components/SequenceNumber.py +++ b/src/urh/awre/components/SequenceNumber.py @@ -1,8 +1,21 @@ from urh.awre.components.Component import Component class SequenceNumber(Component): - def __init__(self, priority=2, predecessors=None, enabled=True, backend=None): + def __init__(self, fieldtypes, priority=2, predecessors=None, enabled=True, backend=None): + """ + + :type fieldtypes: list of FieldType + :param priority: + :param predecessors: + :param enabled: + :param backend: + :param messagetypes: + """ super().__init__(priority, predecessors, enabled, backend) + self.seqnr_field_type = next((ft for ft in fieldtypes if ft.function == ft.Function.SEQUENCE_NUMBER), None) + self.seqnr_field_name = self.seqnr_field_type.caption if self.seqnr_field_type else "Sequence Number" + + def _py_find_field(self, messages): raise NotImplementedError("Todo") \ No newline at end of file diff --git a/src/urh/signalprocessing/FieldType.py b/src/urh/signalprocessing/FieldType.py index 07265c93..56c15a56 100644 --- a/src/urh/signalprocessing/FieldType.py +++ b/src/urh/signalprocessing/FieldType.py @@ -45,6 +45,12 @@ class FieldType(object): def __eq__(self, other): return isinstance(other, FieldType) and self.id_match(other.id) + def __hash__(self): + return hash(self.id) + + def __repr__(self): + return "FieldType: {0} - {1} ({2})".format(self.function.name, self.caption, self.display_format_index) + @property def id(self): return self.__id @@ -52,12 +58,6 @@ class FieldType(object): def id_match(self, id): return self.__id == id - def __hash__(self): - return hash(self.id) - - def __repr__(self): - return "FieldType: {0} - {1} ({2})".format(self.function.name, self.caption, self.display_format_index) - @staticmethod def default_field_types(): """ diff --git a/src/urh/signalprocessing/MessageType.py b/src/urh/signalprocessing/MessageType.py index f24cefbc..40f08b34 100644 --- a/src/urh/signalprocessing/MessageType.py +++ b/src/urh/signalprocessing/MessageType.py @@ -78,7 +78,8 @@ class MessageType(list): super().append(lbl) self.sort() - def add_protocol_label(self, start: int, end: int, name=None, color_ind=None, auto_created=False) -> ProtocolLabel: + def add_protocol_label(self, start: int, end: int, name=None, color_ind=None, + auto_created=False, type:FieldType=None) -> ProtocolLabel: name = "" if not name else name used_colors = [p.color_index for p in self] @@ -90,7 +91,8 @@ class MessageType(list): else: color_ind = random.randint(0, len(constants.LABEL_COLORS) - 1) - proto_label = ProtocolLabel(name=name, start=start, end=end, color_index=color_ind, auto_created=auto_created) + proto_label = ProtocolLabel(name=name, start=start, end=end, color_index=color_ind, + auto_created=auto_created, type=type) if proto_label not in self: self.append(proto_label) diff --git a/src/urh/signalprocessing/ProtocoLabel.py b/src/urh/signalprocessing/ProtocoLabel.py index cfaa382a..b9833a9b 100644 --- a/src/urh/signalprocessing/ProtocoLabel.py +++ b/src/urh/signalprocessing/ProtocoLabel.py @@ -18,7 +18,8 @@ class ProtocolLabel(object): DISPLAY_FORMATS = ["Bit", "Hex", "ASCII", "Decimal"] SEARCH_TYPES = ["Number", "Bits", "Hex", "ASCII"] - def __init__(self, name: str, start: int, end: int, color_index: int, fuzz_created=False, auto_created=False): + def __init__(self, name: str, start: int, end: int, color_index: int, fuzz_created=False, + auto_created=False, type:FieldType=None): self.__name = name self.start = start self.end = end + 1 @@ -32,9 +33,9 @@ class ProtocolLabel(object): self.fuzz_created = fuzz_created - self.__type = None # type: FieldType + self.__type = type # type: FieldType - self.display_format_index = 0 + self.display_format_index = 0 if type is None else type.display_format_index self.auto_created = auto_created @@ -86,7 +87,7 @@ class ProtocolLabel(object): return False def __eq__(self, other): - return self.start == other.start and self.end == other.end and self.name == other.name + return self.start == other.start and self.end == other.end and self.name == other.name and self.type == other.type def __hash__(self): return hash("{}/{}/{}".format(self.start, self.end, self.name)) diff --git a/tests/test_awre.py b/tests/test_awre.py index 41c61f03..7f3433cf 100644 --- a/tests/test_awre.py +++ b/tests/test_awre.py @@ -11,6 +11,7 @@ from urh.awre.components.Length import Length from urh.awre.components.Preamble import Preamble from urh.awre.components.SequenceNumber import SequenceNumber from urh.awre.components.Type import Type +from urh.signalprocessing.FieldType import FieldType from urh.signalprocessing.Participant import Participant from urh.signalprocessing.ProtocoLabel import ProtocolLabel from urh.signalprocessing.ProtocolAnalyzer import ProtocolAnalyzer @@ -22,6 +23,15 @@ from urh.signalprocessing.Message import Message from urh.cythonext import util class TestAWRE(unittest.TestCase): def setUp(self): + self.field_types = FieldType.default_field_types() + + self.preamble_field_type = self.__field_type_with_function(self.field_types, FieldType.Function.PREAMBLE) + self.sync_field_type = self.__field_type_with_function(self.field_types, FieldType.Function.SYNC) + self.length_field_type = self.__field_type_with_function(self.field_types, FieldType.Function.LENGTH) + self.sequence_number_field_type = self.__field_type_with_function(self.field_types, FieldType.Function.SEQUENCE_NUMBER) + self.dst_address_field_type = self.__field_type_with_function(self.field_types, FieldType.Function.DST_ADDRESS) + self.src_address_field_type = self.__field_type_with_function(self.field_types, FieldType.Function.SRC_ADDRESS) + self.protocol = ProtocolAnalyzer(None) with open("./data/awre_consistent_addresses.txt") as f: for line in f: @@ -46,15 +56,21 @@ class TestAWRE(unittest.TestCase): for i, message in enumerate(self.zero_crc_protocol.messages): message.participant = alice if i in alice_indices else bob + @staticmethod + def __field_type_with_function(field_types, function) -> FieldType: + return next(ft for ft in field_types if ft.function == function) + def test_build_component_order(self): - expected_default = [Preamble(), Length(None), Address(None), SequenceNumber(), Type(), Flags()] + expected_default = [Preamble(fieldtypes=[]), Length(fieldtypes=[], length_cluster=None), + Address(fieldtypes=[], xor_matrix=None), SequenceNumber(fieldtypes=[]), Type(), Flags()] format_finder = FormatFinder(self.protocol) for expected, actual in zip(expected_default, format_finder.build_component_order()): assert type(expected) == type(actual) - expected_swapped = [Preamble(), Address(None), Length(None), SequenceNumber(), Type(), Flags()] + expected_swapped = [Preamble(fieldtypes=[]), Address(fieldtypes=[],xor_matrix=None), + Length(fieldtypes=[], length_cluster=None), SequenceNumber(fieldtypes=[]), Type(), Flags()] format_finder.length_component.priority = 2 format_finder.address_component.priority = 1 @@ -77,14 +93,20 @@ class TestAWRE(unittest.TestCase): dst_address_start, dst_address_end = 88, 111 src_address_start, src_address_end = 112, 135 - preamble_label = ProtocolLabel(name="Preamble", start=preamble_start, end=preamble_end, color_index=0) - sync_label = ProtocolLabel(name="Synchronization", start=sync_start, end=sync_end, color_index=1) - length_label = ProtocolLabel(name="Length", start=length_start, end=length_end, color_index=2) - ack_address_label = ProtocolLabel(name="DST address", start=ack_address_start, end=ack_address_end, color_index=3) - dst_address_label = ProtocolLabel(name="DST address", start=dst_address_start, end=dst_address_end, color_index=4) - src_address_label = ProtocolLabel(name="SRC address", start=src_address_start, end=src_address_end, color_index=5) + preamble_label = ProtocolLabel(name=self.preamble_field_type.caption, type=self.preamble_field_type, + start=preamble_start, end=preamble_end, color_index=0) + sync_label = ProtocolLabel(name=self.sync_field_type.caption, type=self.sync_field_type, + start=sync_start, end=sync_end, color_index=1) + length_label = ProtocolLabel(name=self.length_field_type.caption, type=self.length_field_type, + start=length_start, end=length_end, color_index=2) + ack_address_label = ProtocolLabel(name=self.dst_address_field_type.caption, type=self.dst_address_field_type, + start=ack_address_start, end=ack_address_end, color_index=3) + dst_address_label = ProtocolLabel(name=self.dst_address_field_type.caption, type=self.dst_address_field_type, + start=dst_address_start, end=dst_address_end, color_index=4) + src_address_label = ProtocolLabel(name=self.src_address_field_type.caption, type=self.src_address_field_type, + start=src_address_start, end=src_address_end, color_index=5) - ff = FormatFinder(self.protocol, self.participants) + ff = FormatFinder(protocol=self.protocol, participants=self.participants, field_types=self.field_types) ff.perform_iteration() self.assertIn(preamble_label, self.protocol.default_message_type) @@ -124,8 +146,10 @@ class TestAWRE(unittest.TestCase): sof_start = 11 sof_end = 14 - preamble_label = ProtocolLabel(name="Preamble", start=preamble_start, end=preamble_end, color_index=0) - sync_label = ProtocolLabel(name="Synchronization", start=sof_start, end=sof_end, color_index=1) + preamble_label = ProtocolLabel(name=self.preamble_field_type.caption, type=self.preamble_field_type, + start=preamble_start, end=preamble_end, color_index=0) + sync_label = ProtocolLabel(name=self.sync_field_type.caption, type=self.sync_field_type, + start=sof_start, end=sof_end, color_index=1) ff = FormatFinder(enocean_protocol, self.participants) @@ -135,7 +159,7 @@ class TestAWRE(unittest.TestCase): self.assertIn(preamble_label, enocean_protocol.default_message_type) self.assertIn(sync_label, enocean_protocol.default_message_type) - self.assertTrue(not any(lbl.name == "Length" for lbl in enocean_protocol.default_message_type)) + self.assertTrue(not any(lbl.name == self.length_field_type.caption for lbl in enocean_protocol.default_message_type)) self.assertTrue(not any("address" in lbl.name.lower() for lbl in enocean_protocol.default_message_type)) def test_address_candidate_finding(self):