make awre use field types

This commit is contained in:
jopohl
2017-01-11 10:31:14 +01:00
parent bce77afa93
commit 94bbacdeef
9 changed files with 118 additions and 41 deletions

View File

@@ -1,6 +1,7 @@
import numpy as np
import time
from urh.signalprocessing.FieldType import FieldType
from urh.util.Logger import logger
from urh.awre.components.Address import Address
@@ -15,7 +16,7 @@ from urh.cythonext import util
class FormatFinder(object):
MIN_MESSAGES_PER_CLUSTER = 2 # If there is only one message per cluster it is not very significant
def __init__(self, protocol, participants=None):
def __init__(self, protocol, participants=None, field_types=None):
"""
:type protocol: urh.signalprocessing.ProtocolAnalyzer.ProtocolAnalyzer
@@ -32,12 +33,15 @@ class FormatFinder(object):
mt = self.protocol.message_types
self.preamble_component = Preamble(priority=0, messagetypes=mt)
self.length_component = Length(length_cluster=self.len_cluster, priority=1,
field_types = FieldType.load_from_xml() if field_types is None else field_types
self.preamble_component = Preamble(fieldtypes=field_types, priority=0, messagetypes=mt)
self.length_component = Length(fieldtypes=field_types, length_cluster=self.len_cluster, priority=1,
predecessors=[self.preamble_component], messagetypes=mt)
self.address_component = Address(xor_matrix=self.xor_matrix, priority=2, predecessors=[self.preamble_component],
messagetypes=mt)
self.sequence_number_component = SequenceNumber(priority=3, predecessors=[self.preamble_component])
self.address_component = Address(fieldtypes=field_types, xor_matrix=self.xor_matrix, priority=2,
predecessors=[self.preamble_component], messagetypes=mt)
self.sequence_number_component = SequenceNumber(fieldtypes=field_types, priority=3,
predecessors=[self.preamble_component])
self.type_component = Type(priority=4, predecessors=[self.preamble_component])
self.flags_component = Flags(priority=5, predecessors=[self.preamble_component])

View File

@@ -11,10 +11,16 @@ from urh.signalprocessing.MessageType import MessageType
class Address(Component):
MIN_ADDRESS_LENGTH = 8 # Address should be at least one byte
def __init__(self, xor_matrix, priority=2, predecessors=None, enabled=True, backend=None, messagetypes=None):
def __init__(self, fieldtypes, xor_matrix, priority=2, predecessors=None, enabled=True, backend=None, messagetypes=None):
super().__init__(priority, predecessors, enabled, backend, messagetypes)
self.xor_matrix = xor_matrix
self.dst_field_type = next((ft for ft in fieldtypes if ft.function == ft.Function.DST_ADDRESS), None)
self.src_field_type = next((ft for ft in fieldtypes if ft.Function == ft.Function.SRC_ADDRESS), None)
self.dst_field_name = self.dst_field_type.caption if self.dst_field_type else "DST address"
self.src_field_name = self.src_field_type.caption if self.src_field_type else "SRC address"
def _py_find_field(self, messages, verbose=False):
"""
@@ -163,17 +169,22 @@ class Address(Component):
msg = messages[msg_index]
if msg.message_type.name == "ack":
name = "DST address"
field_type = self.dst_field_type
name = self.dst_field_name
elif msg.participant:
if rng.hex_value == msg.participant.address_hex:
name = "SRC address"
name = self.src_field_name
field_type = self.src_field_type
else:
name = "DST address"
name = self.dst_field_name
field_type = self.dst_field_type
else:
name = "Address"
field_type = None
if not any(lbl.name == name and lbl.auto_created for lbl in msg.message_type):
msg.message_type.add_protocol_label(rng.start, rng.end - 1, name=name, auto_created=True)
msg.message_type.add_protocol_label(rng.start, rng.end - 1, name=name,
auto_created=True, type=field_type)
@staticmethod

View File

@@ -14,9 +14,12 @@ class Length(Component):
"""
def __init__(self, length_cluster, priority=2, predecessors=None, enabled=True, backend=None, messagetypes=None):
def __init__(self, fieldtypes, length_cluster, priority=2, predecessors=None, enabled=True, backend=None, messagetypes=None):
super().__init__(priority, predecessors, enabled, backend, messagetypes)
self.length_field_type = next((ft for ft in fieldtypes if ft.function == ft.Function.LENGTH), None)
self.length_field_name = self.length_field_type.caption if self.length_field_type else "Length"
self.length_cluster = length_cluster
"""
An example length cluster is
@@ -118,7 +121,8 @@ class Length(Component):
try:
start, end = max(scores, key=scores.__getitem__)
if not any(lbl.name == "Length" and lbl.auto_created for lbl in message_type):
message_type.add_protocol_label(start=start, end=end - 1, name="Length", auto_created=True)
message_type.add_protocol_label(start=start, end=end - 1, name=self.length_field_name,
auto_created=True, type=self.length_field_type)
except ValueError:
continue

View File

@@ -1,5 +1,6 @@
from collections import defaultdict
from urh.awre.components.Component import Component
from urh.signalprocessing.FieldType import FieldType
from urh.signalprocessing.Message import Message
@@ -8,9 +9,24 @@ class Preamble(Component):
Assign Preamble and SoF.
"""
def __init__(self, priority=0, predecessors=None, enabled=True, backend=None, messagetypes=None):
def __init__(self, fieldtypes, priority=0, predecessors=None, enabled=True, backend=None, messagetypes=None):
"""
:type fieldtypes: list of FieldType
:param priority:
:param predecessors:
:param enabled:
:param backend:
:param messagetypes:
"""
super().__init__(priority, predecessors, enabled, backend, messagetypes)
self.preamble_field_type = next((ft for ft in fieldtypes if ft.function == ft.Function.PREAMBLE), None)
self.sync_field_type = next((ft for ft in fieldtypes if ft.function == ft.Function.SYNC), None)
self.preamble_name = self.preamble_field_type.caption if self.preamble_field_type else "Preamble"
self.sync_name = self.sync_field_type.caption if self.sync_field_type else "Synchronization"
def _py_find_field(self, messages):
"""
@@ -28,7 +44,8 @@ class Preamble(Component):
preamble_ends = defaultdict(int)
for message_type, ranges in preamble_ranges.items():
start, end = max(ranges, key=ranges.count)
message_type.add_protocol_label(start=start, end=end, name="Preamble", auto_created=True)
message_type.add_protocol_label(start=start, end=end, name=self.preamble_name,
auto_created=True, type=self.preamble_field_type)
preamble_ends[message_type] = end + 1
@@ -39,7 +56,8 @@ class Preamble(Component):
sync_range = self.__find_sync_range(messages, preamble_ends[message_type], search_end)
if sync_range:
message_type.add_protocol_label(start=sync_range[0], end=sync_range[1]-1, name="Synchronization", auto_created=True)
message_type.add_protocol_label(start=sync_range[0], end=sync_range[1]-1, name=self.sync_name,
auto_created=True, type=self.sync_field_type)
def __find_preamble_range(self, message: Message):

View File

@@ -1,8 +1,21 @@
from urh.awre.components.Component import Component
class SequenceNumber(Component):
def __init__(self, priority=2, predecessors=None, enabled=True, backend=None):
def __init__(self, fieldtypes, priority=2, predecessors=None, enabled=True, backend=None):
"""
:type fieldtypes: list of FieldType
:param priority:
:param predecessors:
:param enabled:
:param backend:
:param messagetypes:
"""
super().__init__(priority, predecessors, enabled, backend)
self.seqnr_field_type = next((ft for ft in fieldtypes if ft.function == ft.Function.SEQUENCE_NUMBER), None)
self.seqnr_field_name = self.seqnr_field_type.caption if self.seqnr_field_type else "Sequence Number"
def _py_find_field(self, messages):
raise NotImplementedError("Todo")

View File

@@ -45,6 +45,12 @@ class FieldType(object):
def __eq__(self, other):
return isinstance(other, FieldType) and self.id_match(other.id)
def __hash__(self):
return hash(self.id)
def __repr__(self):
return "FieldType: {0} - {1} ({2})".format(self.function.name, self.caption, self.display_format_index)
@property
def id(self):
return self.__id
@@ -52,12 +58,6 @@ class FieldType(object):
def id_match(self, id):
return self.__id == id
def __hash__(self):
return hash(self.id)
def __repr__(self):
return "FieldType: {0} - {1} ({2})".format(self.function.name, self.caption, self.display_format_index)
@staticmethod
def default_field_types():
"""

View File

@@ -78,7 +78,8 @@ class MessageType(list):
super().append(lbl)
self.sort()
def add_protocol_label(self, start: int, end: int, name=None, color_ind=None, auto_created=False) -> ProtocolLabel:
def add_protocol_label(self, start: int, end: int, name=None, color_ind=None,
auto_created=False, type:FieldType=None) -> ProtocolLabel:
name = "" if not name else name
used_colors = [p.color_index for p in self]
@@ -90,7 +91,8 @@ class MessageType(list):
else:
color_ind = random.randint(0, len(constants.LABEL_COLORS) - 1)
proto_label = ProtocolLabel(name=name, start=start, end=end, color_index=color_ind, auto_created=auto_created)
proto_label = ProtocolLabel(name=name, start=start, end=end, color_index=color_ind,
auto_created=auto_created, type=type)
if proto_label not in self:
self.append(proto_label)

View File

@@ -18,7 +18,8 @@ class ProtocolLabel(object):
DISPLAY_FORMATS = ["Bit", "Hex", "ASCII", "Decimal"]
SEARCH_TYPES = ["Number", "Bits", "Hex", "ASCII"]
def __init__(self, name: str, start: int, end: int, color_index: int, fuzz_created=False, auto_created=False):
def __init__(self, name: str, start: int, end: int, color_index: int, fuzz_created=False,
auto_created=False, type:FieldType=None):
self.__name = name
self.start = start
self.end = end + 1
@@ -32,9 +33,9 @@ class ProtocolLabel(object):
self.fuzz_created = fuzz_created
self.__type = None # type: FieldType
self.__type = type # type: FieldType
self.display_format_index = 0
self.display_format_index = 0 if type is None else type.display_format_index
self.auto_created = auto_created
@@ -86,7 +87,7 @@ class ProtocolLabel(object):
return False
def __eq__(self, other):
return self.start == other.start and self.end == other.end and self.name == other.name
return self.start == other.start and self.end == other.end and self.name == other.name and self.type == other.type
def __hash__(self):
return hash("{}/{}/{}".format(self.start, self.end, self.name))

View File

@@ -11,6 +11,7 @@ from urh.awre.components.Length import Length
from urh.awre.components.Preamble import Preamble
from urh.awre.components.SequenceNumber import SequenceNumber
from urh.awre.components.Type import Type
from urh.signalprocessing.FieldType import FieldType
from urh.signalprocessing.Participant import Participant
from urh.signalprocessing.ProtocoLabel import ProtocolLabel
from urh.signalprocessing.ProtocolAnalyzer import ProtocolAnalyzer
@@ -22,6 +23,15 @@ from urh.signalprocessing.Message import Message
from urh.cythonext import util
class TestAWRE(unittest.TestCase):
def setUp(self):
self.field_types = FieldType.default_field_types()
self.preamble_field_type = self.__field_type_with_function(self.field_types, FieldType.Function.PREAMBLE)
self.sync_field_type = self.__field_type_with_function(self.field_types, FieldType.Function.SYNC)
self.length_field_type = self.__field_type_with_function(self.field_types, FieldType.Function.LENGTH)
self.sequence_number_field_type = self.__field_type_with_function(self.field_types, FieldType.Function.SEQUENCE_NUMBER)
self.dst_address_field_type = self.__field_type_with_function(self.field_types, FieldType.Function.DST_ADDRESS)
self.src_address_field_type = self.__field_type_with_function(self.field_types, FieldType.Function.SRC_ADDRESS)
self.protocol = ProtocolAnalyzer(None)
with open("./data/awre_consistent_addresses.txt") as f:
for line in f:
@@ -46,15 +56,21 @@ class TestAWRE(unittest.TestCase):
for i, message in enumerate(self.zero_crc_protocol.messages):
message.participant = alice if i in alice_indices else bob
@staticmethod
def __field_type_with_function(field_types, function) -> FieldType:
return next(ft for ft in field_types if ft.function == function)
def test_build_component_order(self):
expected_default = [Preamble(), Length(None), Address(None), SequenceNumber(), Type(), Flags()]
expected_default = [Preamble(fieldtypes=[]), Length(fieldtypes=[], length_cluster=None),
Address(fieldtypes=[], xor_matrix=None), SequenceNumber(fieldtypes=[]), Type(), Flags()]
format_finder = FormatFinder(self.protocol)
for expected, actual in zip(expected_default, format_finder.build_component_order()):
assert type(expected) == type(actual)
expected_swapped = [Preamble(), Address(None), Length(None), SequenceNumber(), Type(), Flags()]
expected_swapped = [Preamble(fieldtypes=[]), Address(fieldtypes=[],xor_matrix=None),
Length(fieldtypes=[], length_cluster=None), SequenceNumber(fieldtypes=[]), Type(), Flags()]
format_finder.length_component.priority = 2
format_finder.address_component.priority = 1
@@ -77,14 +93,20 @@ class TestAWRE(unittest.TestCase):
dst_address_start, dst_address_end = 88, 111
src_address_start, src_address_end = 112, 135
preamble_label = ProtocolLabel(name="Preamble", start=preamble_start, end=preamble_end, color_index=0)
sync_label = ProtocolLabel(name="Synchronization", start=sync_start, end=sync_end, color_index=1)
length_label = ProtocolLabel(name="Length", start=length_start, end=length_end, color_index=2)
ack_address_label = ProtocolLabel(name="DST address", start=ack_address_start, end=ack_address_end, color_index=3)
dst_address_label = ProtocolLabel(name="DST address", start=dst_address_start, end=dst_address_end, color_index=4)
src_address_label = ProtocolLabel(name="SRC address", start=src_address_start, end=src_address_end, color_index=5)
preamble_label = ProtocolLabel(name=self.preamble_field_type.caption, type=self.preamble_field_type,
start=preamble_start, end=preamble_end, color_index=0)
sync_label = ProtocolLabel(name=self.sync_field_type.caption, type=self.sync_field_type,
start=sync_start, end=sync_end, color_index=1)
length_label = ProtocolLabel(name=self.length_field_type.caption, type=self.length_field_type,
start=length_start, end=length_end, color_index=2)
ack_address_label = ProtocolLabel(name=self.dst_address_field_type.caption, type=self.dst_address_field_type,
start=ack_address_start, end=ack_address_end, color_index=3)
dst_address_label = ProtocolLabel(name=self.dst_address_field_type.caption, type=self.dst_address_field_type,
start=dst_address_start, end=dst_address_end, color_index=4)
src_address_label = ProtocolLabel(name=self.src_address_field_type.caption, type=self.src_address_field_type,
start=src_address_start, end=src_address_end, color_index=5)
ff = FormatFinder(self.protocol, self.participants)
ff = FormatFinder(protocol=self.protocol, participants=self.participants, field_types=self.field_types)
ff.perform_iteration()
self.assertIn(preamble_label, self.protocol.default_message_type)
@@ -124,8 +146,10 @@ class TestAWRE(unittest.TestCase):
sof_start = 11
sof_end = 14
preamble_label = ProtocolLabel(name="Preamble", start=preamble_start, end=preamble_end, color_index=0)
sync_label = ProtocolLabel(name="Synchronization", start=sof_start, end=sof_end, color_index=1)
preamble_label = ProtocolLabel(name=self.preamble_field_type.caption, type=self.preamble_field_type,
start=preamble_start, end=preamble_end, color_index=0)
sync_label = ProtocolLabel(name=self.sync_field_type.caption, type=self.sync_field_type,
start=sof_start, end=sof_end, color_index=1)
ff = FormatFinder(enocean_protocol, self.participants)
@@ -135,7 +159,7 @@ class TestAWRE(unittest.TestCase):
self.assertIn(preamble_label, enocean_protocol.default_message_type)
self.assertIn(sync_label, enocean_protocol.default_message_type)
self.assertTrue(not any(lbl.name == "Length" for lbl in enocean_protocol.default_message_type))
self.assertTrue(not any(lbl.name == self.length_field_type.caption for lbl in enocean_protocol.default_message_type))
self.assertTrue(not any("address" in lbl.name.lower() for lbl in enocean_protocol.default_message_type))
def test_address_candidate_finding(self):