mirror of
https://github.com/jopohl/urh.git
synced 2026-03-20 23:17:01 +01:00
make awre use field types
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import numpy as np
|
||||
import time
|
||||
|
||||
from urh.signalprocessing.FieldType import FieldType
|
||||
from urh.util.Logger import logger
|
||||
|
||||
from urh.awre.components.Address import Address
|
||||
@@ -15,7 +16,7 @@ from urh.cythonext import util
|
||||
class FormatFinder(object):
|
||||
MIN_MESSAGES_PER_CLUSTER = 2 # If there is only one message per cluster it is not very significant
|
||||
|
||||
def __init__(self, protocol, participants=None):
|
||||
def __init__(self, protocol, participants=None, field_types=None):
|
||||
"""
|
||||
|
||||
:type protocol: urh.signalprocessing.ProtocolAnalyzer.ProtocolAnalyzer
|
||||
@@ -32,12 +33,15 @@ class FormatFinder(object):
|
||||
|
||||
mt = self.protocol.message_types
|
||||
|
||||
self.preamble_component = Preamble(priority=0, messagetypes=mt)
|
||||
self.length_component = Length(length_cluster=self.len_cluster, priority=1,
|
||||
field_types = FieldType.load_from_xml() if field_types is None else field_types
|
||||
|
||||
self.preamble_component = Preamble(fieldtypes=field_types, priority=0, messagetypes=mt)
|
||||
self.length_component = Length(fieldtypes=field_types, length_cluster=self.len_cluster, priority=1,
|
||||
predecessors=[self.preamble_component], messagetypes=mt)
|
||||
self.address_component = Address(xor_matrix=self.xor_matrix, priority=2, predecessors=[self.preamble_component],
|
||||
messagetypes=mt)
|
||||
self.sequence_number_component = SequenceNumber(priority=3, predecessors=[self.preamble_component])
|
||||
self.address_component = Address(fieldtypes=field_types, xor_matrix=self.xor_matrix, priority=2,
|
||||
predecessors=[self.preamble_component], messagetypes=mt)
|
||||
self.sequence_number_component = SequenceNumber(fieldtypes=field_types, priority=3,
|
||||
predecessors=[self.preamble_component])
|
||||
self.type_component = Type(priority=4, predecessors=[self.preamble_component])
|
||||
self.flags_component = Flags(priority=5, predecessors=[self.preamble_component])
|
||||
|
||||
|
||||
@@ -11,10 +11,16 @@ from urh.signalprocessing.MessageType import MessageType
|
||||
class Address(Component):
|
||||
MIN_ADDRESS_LENGTH = 8 # Address should be at least one byte
|
||||
|
||||
def __init__(self, xor_matrix, priority=2, predecessors=None, enabled=True, backend=None, messagetypes=None):
|
||||
def __init__(self, fieldtypes, xor_matrix, priority=2, predecessors=None, enabled=True, backend=None, messagetypes=None):
|
||||
super().__init__(priority, predecessors, enabled, backend, messagetypes)
|
||||
self.xor_matrix = xor_matrix
|
||||
|
||||
self.dst_field_type = next((ft for ft in fieldtypes if ft.function == ft.Function.DST_ADDRESS), None)
|
||||
self.src_field_type = next((ft for ft in fieldtypes if ft.Function == ft.Function.SRC_ADDRESS), None)
|
||||
|
||||
self.dst_field_name = self.dst_field_type.caption if self.dst_field_type else "DST address"
|
||||
self.src_field_name = self.src_field_type.caption if self.src_field_type else "SRC address"
|
||||
|
||||
def _py_find_field(self, messages, verbose=False):
|
||||
"""
|
||||
|
||||
@@ -163,17 +169,22 @@ class Address(Component):
|
||||
msg = messages[msg_index]
|
||||
|
||||
if msg.message_type.name == "ack":
|
||||
name = "DST address"
|
||||
field_type = self.dst_field_type
|
||||
name = self.dst_field_name
|
||||
elif msg.participant:
|
||||
if rng.hex_value == msg.participant.address_hex:
|
||||
name = "SRC address"
|
||||
name = self.src_field_name
|
||||
field_type = self.src_field_type
|
||||
else:
|
||||
name = "DST address"
|
||||
name = self.dst_field_name
|
||||
field_type = self.dst_field_type
|
||||
else:
|
||||
name = "Address"
|
||||
field_type = None
|
||||
|
||||
if not any(lbl.name == name and lbl.auto_created for lbl in msg.message_type):
|
||||
msg.message_type.add_protocol_label(rng.start, rng.end - 1, name=name, auto_created=True)
|
||||
msg.message_type.add_protocol_label(rng.start, rng.end - 1, name=name,
|
||||
auto_created=True, type=field_type)
|
||||
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -14,9 +14,12 @@ class Length(Component):
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, length_cluster, priority=2, predecessors=None, enabled=True, backend=None, messagetypes=None):
|
||||
def __init__(self, fieldtypes, length_cluster, priority=2, predecessors=None, enabled=True, backend=None, messagetypes=None):
|
||||
super().__init__(priority, predecessors, enabled, backend, messagetypes)
|
||||
|
||||
self.length_field_type = next((ft for ft in fieldtypes if ft.function == ft.Function.LENGTH), None)
|
||||
self.length_field_name = self.length_field_type.caption if self.length_field_type else "Length"
|
||||
|
||||
self.length_cluster = length_cluster
|
||||
"""
|
||||
An example length cluster is
|
||||
@@ -118,7 +121,8 @@ class Length(Component):
|
||||
try:
|
||||
start, end = max(scores, key=scores.__getitem__)
|
||||
if not any(lbl.name == "Length" and lbl.auto_created for lbl in message_type):
|
||||
message_type.add_protocol_label(start=start, end=end - 1, name="Length", auto_created=True)
|
||||
message_type.add_protocol_label(start=start, end=end - 1, name=self.length_field_name,
|
||||
auto_created=True, type=self.length_field_type)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from collections import defaultdict
|
||||
from urh.awre.components.Component import Component
|
||||
from urh.signalprocessing.FieldType import FieldType
|
||||
from urh.signalprocessing.Message import Message
|
||||
|
||||
|
||||
@@ -8,9 +9,24 @@ class Preamble(Component):
|
||||
Assign Preamble and SoF.
|
||||
|
||||
"""
|
||||
def __init__(self, priority=0, predecessors=None, enabled=True, backend=None, messagetypes=None):
|
||||
def __init__(self, fieldtypes, priority=0, predecessors=None, enabled=True, backend=None, messagetypes=None):
|
||||
"""
|
||||
|
||||
:type fieldtypes: list of FieldType
|
||||
:param priority:
|
||||
:param predecessors:
|
||||
:param enabled:
|
||||
:param backend:
|
||||
:param messagetypes:
|
||||
"""
|
||||
super().__init__(priority, predecessors, enabled, backend, messagetypes)
|
||||
|
||||
self.preamble_field_type = next((ft for ft in fieldtypes if ft.function == ft.Function.PREAMBLE), None)
|
||||
self.sync_field_type = next((ft for ft in fieldtypes if ft.function == ft.Function.SYNC), None)
|
||||
|
||||
self.preamble_name = self.preamble_field_type.caption if self.preamble_field_type else "Preamble"
|
||||
self.sync_name = self.sync_field_type.caption if self.sync_field_type else "Synchronization"
|
||||
|
||||
def _py_find_field(self, messages):
|
||||
"""
|
||||
|
||||
@@ -28,7 +44,8 @@ class Preamble(Component):
|
||||
preamble_ends = defaultdict(int)
|
||||
for message_type, ranges in preamble_ranges.items():
|
||||
start, end = max(ranges, key=ranges.count)
|
||||
message_type.add_protocol_label(start=start, end=end, name="Preamble", auto_created=True)
|
||||
message_type.add_protocol_label(start=start, end=end, name=self.preamble_name,
|
||||
auto_created=True, type=self.preamble_field_type)
|
||||
|
||||
preamble_ends[message_type] = end + 1
|
||||
|
||||
@@ -39,7 +56,8 @@ class Preamble(Component):
|
||||
sync_range = self.__find_sync_range(messages, preamble_ends[message_type], search_end)
|
||||
|
||||
if sync_range:
|
||||
message_type.add_protocol_label(start=sync_range[0], end=sync_range[1]-1, name="Synchronization", auto_created=True)
|
||||
message_type.add_protocol_label(start=sync_range[0], end=sync_range[1]-1, name=self.sync_name,
|
||||
auto_created=True, type=self.sync_field_type)
|
||||
|
||||
|
||||
def __find_preamble_range(self, message: Message):
|
||||
|
||||
@@ -1,8 +1,21 @@
|
||||
from urh.awre.components.Component import Component
|
||||
|
||||
class SequenceNumber(Component):
|
||||
def __init__(self, priority=2, predecessors=None, enabled=True, backend=None):
|
||||
def __init__(self, fieldtypes, priority=2, predecessors=None, enabled=True, backend=None):
|
||||
"""
|
||||
|
||||
:type fieldtypes: list of FieldType
|
||||
:param priority:
|
||||
:param predecessors:
|
||||
:param enabled:
|
||||
:param backend:
|
||||
:param messagetypes:
|
||||
"""
|
||||
super().__init__(priority, predecessors, enabled, backend)
|
||||
|
||||
self.seqnr_field_type = next((ft for ft in fieldtypes if ft.function == ft.Function.SEQUENCE_NUMBER), None)
|
||||
self.seqnr_field_name = self.seqnr_field_type.caption if self.seqnr_field_type else "Sequence Number"
|
||||
|
||||
|
||||
def _py_find_field(self, messages):
|
||||
raise NotImplementedError("Todo")
|
||||
@@ -45,6 +45,12 @@ class FieldType(object):
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, FieldType) and self.id_match(other.id)
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.id)
|
||||
|
||||
def __repr__(self):
|
||||
return "FieldType: {0} - {1} ({2})".format(self.function.name, self.caption, self.display_format_index)
|
||||
|
||||
@property
|
||||
def id(self):
|
||||
return self.__id
|
||||
@@ -52,12 +58,6 @@ class FieldType(object):
|
||||
def id_match(self, id):
|
||||
return self.__id == id
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.id)
|
||||
|
||||
def __repr__(self):
|
||||
return "FieldType: {0} - {1} ({2})".format(self.function.name, self.caption, self.display_format_index)
|
||||
|
||||
@staticmethod
|
||||
def default_field_types():
|
||||
"""
|
||||
|
||||
@@ -78,7 +78,8 @@ class MessageType(list):
|
||||
super().append(lbl)
|
||||
self.sort()
|
||||
|
||||
def add_protocol_label(self, start: int, end: int, name=None, color_ind=None, auto_created=False) -> ProtocolLabel:
|
||||
def add_protocol_label(self, start: int, end: int, name=None, color_ind=None,
|
||||
auto_created=False, type:FieldType=None) -> ProtocolLabel:
|
||||
|
||||
name = "" if not name else name
|
||||
used_colors = [p.color_index for p in self]
|
||||
@@ -90,7 +91,8 @@ class MessageType(list):
|
||||
else:
|
||||
color_ind = random.randint(0, len(constants.LABEL_COLORS) - 1)
|
||||
|
||||
proto_label = ProtocolLabel(name=name, start=start, end=end, color_index=color_ind, auto_created=auto_created)
|
||||
proto_label = ProtocolLabel(name=name, start=start, end=end, color_index=color_ind,
|
||||
auto_created=auto_created, type=type)
|
||||
|
||||
if proto_label not in self:
|
||||
self.append(proto_label)
|
||||
|
||||
@@ -18,7 +18,8 @@ class ProtocolLabel(object):
|
||||
DISPLAY_FORMATS = ["Bit", "Hex", "ASCII", "Decimal"]
|
||||
SEARCH_TYPES = ["Number", "Bits", "Hex", "ASCII"]
|
||||
|
||||
def __init__(self, name: str, start: int, end: int, color_index: int, fuzz_created=False, auto_created=False):
|
||||
def __init__(self, name: str, start: int, end: int, color_index: int, fuzz_created=False,
|
||||
auto_created=False, type:FieldType=None):
|
||||
self.__name = name
|
||||
self.start = start
|
||||
self.end = end + 1
|
||||
@@ -32,9 +33,9 @@ class ProtocolLabel(object):
|
||||
|
||||
self.fuzz_created = fuzz_created
|
||||
|
||||
self.__type = None # type: FieldType
|
||||
self.__type = type # type: FieldType
|
||||
|
||||
self.display_format_index = 0
|
||||
self.display_format_index = 0 if type is None else type.display_format_index
|
||||
|
||||
self.auto_created = auto_created
|
||||
|
||||
@@ -86,7 +87,7 @@ class ProtocolLabel(object):
|
||||
return False
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.start == other.start and self.end == other.end and self.name == other.name
|
||||
return self.start == other.start and self.end == other.end and self.name == other.name and self.type == other.type
|
||||
|
||||
def __hash__(self):
|
||||
return hash("{}/{}/{}".format(self.start, self.end, self.name))
|
||||
|
||||
@@ -11,6 +11,7 @@ from urh.awre.components.Length import Length
|
||||
from urh.awre.components.Preamble import Preamble
|
||||
from urh.awre.components.SequenceNumber import SequenceNumber
|
||||
from urh.awre.components.Type import Type
|
||||
from urh.signalprocessing.FieldType import FieldType
|
||||
from urh.signalprocessing.Participant import Participant
|
||||
from urh.signalprocessing.ProtocoLabel import ProtocolLabel
|
||||
from urh.signalprocessing.ProtocolAnalyzer import ProtocolAnalyzer
|
||||
@@ -22,6 +23,15 @@ from urh.signalprocessing.Message import Message
|
||||
from urh.cythonext import util
|
||||
class TestAWRE(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.field_types = FieldType.default_field_types()
|
||||
|
||||
self.preamble_field_type = self.__field_type_with_function(self.field_types, FieldType.Function.PREAMBLE)
|
||||
self.sync_field_type = self.__field_type_with_function(self.field_types, FieldType.Function.SYNC)
|
||||
self.length_field_type = self.__field_type_with_function(self.field_types, FieldType.Function.LENGTH)
|
||||
self.sequence_number_field_type = self.__field_type_with_function(self.field_types, FieldType.Function.SEQUENCE_NUMBER)
|
||||
self.dst_address_field_type = self.__field_type_with_function(self.field_types, FieldType.Function.DST_ADDRESS)
|
||||
self.src_address_field_type = self.__field_type_with_function(self.field_types, FieldType.Function.SRC_ADDRESS)
|
||||
|
||||
self.protocol = ProtocolAnalyzer(None)
|
||||
with open("./data/awre_consistent_addresses.txt") as f:
|
||||
for line in f:
|
||||
@@ -46,15 +56,21 @@ class TestAWRE(unittest.TestCase):
|
||||
for i, message in enumerate(self.zero_crc_protocol.messages):
|
||||
message.participant = alice if i in alice_indices else bob
|
||||
|
||||
@staticmethod
|
||||
def __field_type_with_function(field_types, function) -> FieldType:
|
||||
return next(ft for ft in field_types if ft.function == function)
|
||||
|
||||
def test_build_component_order(self):
|
||||
expected_default = [Preamble(), Length(None), Address(None), SequenceNumber(), Type(), Flags()]
|
||||
expected_default = [Preamble(fieldtypes=[]), Length(fieldtypes=[], length_cluster=None),
|
||||
Address(fieldtypes=[], xor_matrix=None), SequenceNumber(fieldtypes=[]), Type(), Flags()]
|
||||
|
||||
format_finder = FormatFinder(self.protocol)
|
||||
|
||||
for expected, actual in zip(expected_default, format_finder.build_component_order()):
|
||||
assert type(expected) == type(actual)
|
||||
|
||||
expected_swapped = [Preamble(), Address(None), Length(None), SequenceNumber(), Type(), Flags()]
|
||||
expected_swapped = [Preamble(fieldtypes=[]), Address(fieldtypes=[],xor_matrix=None),
|
||||
Length(fieldtypes=[], length_cluster=None), SequenceNumber(fieldtypes=[]), Type(), Flags()]
|
||||
format_finder.length_component.priority = 2
|
||||
format_finder.address_component.priority = 1
|
||||
|
||||
@@ -77,14 +93,20 @@ class TestAWRE(unittest.TestCase):
|
||||
dst_address_start, dst_address_end = 88, 111
|
||||
src_address_start, src_address_end = 112, 135
|
||||
|
||||
preamble_label = ProtocolLabel(name="Preamble", start=preamble_start, end=preamble_end, color_index=0)
|
||||
sync_label = ProtocolLabel(name="Synchronization", start=sync_start, end=sync_end, color_index=1)
|
||||
length_label = ProtocolLabel(name="Length", start=length_start, end=length_end, color_index=2)
|
||||
ack_address_label = ProtocolLabel(name="DST address", start=ack_address_start, end=ack_address_end, color_index=3)
|
||||
dst_address_label = ProtocolLabel(name="DST address", start=dst_address_start, end=dst_address_end, color_index=4)
|
||||
src_address_label = ProtocolLabel(name="SRC address", start=src_address_start, end=src_address_end, color_index=5)
|
||||
preamble_label = ProtocolLabel(name=self.preamble_field_type.caption, type=self.preamble_field_type,
|
||||
start=preamble_start, end=preamble_end, color_index=0)
|
||||
sync_label = ProtocolLabel(name=self.sync_field_type.caption, type=self.sync_field_type,
|
||||
start=sync_start, end=sync_end, color_index=1)
|
||||
length_label = ProtocolLabel(name=self.length_field_type.caption, type=self.length_field_type,
|
||||
start=length_start, end=length_end, color_index=2)
|
||||
ack_address_label = ProtocolLabel(name=self.dst_address_field_type.caption, type=self.dst_address_field_type,
|
||||
start=ack_address_start, end=ack_address_end, color_index=3)
|
||||
dst_address_label = ProtocolLabel(name=self.dst_address_field_type.caption, type=self.dst_address_field_type,
|
||||
start=dst_address_start, end=dst_address_end, color_index=4)
|
||||
src_address_label = ProtocolLabel(name=self.src_address_field_type.caption, type=self.src_address_field_type,
|
||||
start=src_address_start, end=src_address_end, color_index=5)
|
||||
|
||||
ff = FormatFinder(self.protocol, self.participants)
|
||||
ff = FormatFinder(protocol=self.protocol, participants=self.participants, field_types=self.field_types)
|
||||
ff.perform_iteration()
|
||||
|
||||
self.assertIn(preamble_label, self.protocol.default_message_type)
|
||||
@@ -124,8 +146,10 @@ class TestAWRE(unittest.TestCase):
|
||||
sof_start = 11
|
||||
sof_end = 14
|
||||
|
||||
preamble_label = ProtocolLabel(name="Preamble", start=preamble_start, end=preamble_end, color_index=0)
|
||||
sync_label = ProtocolLabel(name="Synchronization", start=sof_start, end=sof_end, color_index=1)
|
||||
preamble_label = ProtocolLabel(name=self.preamble_field_type.caption, type=self.preamble_field_type,
|
||||
start=preamble_start, end=preamble_end, color_index=0)
|
||||
sync_label = ProtocolLabel(name=self.sync_field_type.caption, type=self.sync_field_type,
|
||||
start=sof_start, end=sof_end, color_index=1)
|
||||
|
||||
|
||||
ff = FormatFinder(enocean_protocol, self.participants)
|
||||
@@ -135,7 +159,7 @@ class TestAWRE(unittest.TestCase):
|
||||
|
||||
self.assertIn(preamble_label, enocean_protocol.default_message_type)
|
||||
self.assertIn(sync_label, enocean_protocol.default_message_type)
|
||||
self.assertTrue(not any(lbl.name == "Length" for lbl in enocean_protocol.default_message_type))
|
||||
self.assertTrue(not any(lbl.name == self.length_field_type.caption for lbl in enocean_protocol.default_message_type))
|
||||
self.assertTrue(not any("address" in lbl.name.lower() for lbl in enocean_protocol.default_message_type))
|
||||
|
||||
def test_address_candidate_finding(self):
|
||||
|
||||
Reference in New Issue
Block a user