debug flag and print function for label assigner

This commit is contained in:
jopohl
2016-07-04 09:46:07 +02:00
parent 89647ec22b
commit b6830d0da6
3 changed files with 33 additions and 28 deletions

View File

@@ -15,19 +15,19 @@ class LabelAssigner(object):
self.__blocks = blocks
self.sync_end = None
self.preamble_end = None
self.constant_intervals = defaultdict(set)
self.constant_intervals_per_block = defaultdict(list)
self.common_intervals = defaultdict(set)
self.common_intervals_per_block = defaultdict(list)
@property
def is_initialized(self):
return len(self.constant_intervals) > 0 if len(self.__blocks) > 0 else True
return len(self.common_intervals) > 0 if len(self.__blocks) > 0 else True
def __search_constant_intervals(self):
if self.preamble_end is None:
self.find_preamble()
self.constant_intervals.clear()
self.constant_intervals_per_block.clear()
self.common_intervals.clear()
self.common_intervals_per_block.clear()
for i in range(0, len(self.__blocks)):
for j in range(i + 1, len(self.__blocks)):
@@ -43,29 +43,32 @@ class LabelAssigner(object):
else:
if constant_length > constants.SHORTEST_CONSTANT_IN_BITS:
interval = Interval(self.preamble_end+range_start, self.preamble_end+k-1)
self.constant_intervals[(i,j)].add(interval)
self.constant_intervals_per_block[i].append(interval)
self.constant_intervals_per_block[j].append(interval)
self.common_intervals[(i, j)].add(interval)
self.common_intervals_per_block[i].append(interval)
self.common_intervals_per_block[j].append(interval)
constant_length = 0
range_start = k + 1
if constant_length > constants.SHORTEST_CONSTANT_IN_BITS:
interval = Interval(self.preamble_end+range_start, self.preamble_end+ end)
self.constant_intervals[(i,j)].add(interval)
self.constant_intervals_per_block[i].append(interval)
self.constant_intervals_per_block[j].append(interval)
self.common_intervals[(i, j)].add(interval)
self.common_intervals_per_block[i].append(interval)
self.common_intervals_per_block[j].append(interval)
# for block_index in sorted(self.constant_intervals_per_block):
# interval_info = ""
# for interval in sorted(set(self.constant_intervals_per_block[block_index])):
# interval_info += str(interval) + " (" + str(self.constant_intervals_per_block[block_index].count(interval)) + ") "
#
# print(block_index, interval_info)
#
# for block_index in sorted(self.constant_intervals):
# print(block_index, sorted(r for r in self.constant_intervals[block_index] if r.start != self.preamble_end), end=" ")
# print(" ".join([self.__get_hex_value_for_block(self.__blocks[block_index[0]], interval) for interval in sorted(r for r in self.constant_intervals[block_index] if r.start!=self.preamble_end)]))
def print_common_intervals(self):
print("Raw common intervals\n=================")
for block_index in sorted(self.common_intervals):
print(block_index, sorted(r for r in self.common_intervals[block_index] if r.start != self.preamble_end), end=" ")
print(" ".join([self.__get_hex_value_for_block(self.__blocks[block_index[0]], interval) for interval in sorted(r for r in self.common_intervals[block_index] if r.start != self.preamble_end)]))
print("Merged common intervals\n=================")
for block_index in sorted(self.common_intervals_per_block):
interval_info = ""
for interval in sorted(set(self.common_intervals_per_block[block_index])):
interval_info += str(interval) + " (" + str(self.common_intervals_per_block[block_index].count(interval)) + ") "
print(block_index, interval_info)
def __get_hex_value_for_block(self, block, interval):
start, end = block.convert_range(interval.start + 1, interval.end, from_view=0, to_view=1, decoded=True)
@@ -94,7 +97,7 @@ class LabelAssigner(object):
self.__search_constant_intervals()
possible_sync_pos = defaultdict(int)
for const_range in (cr for const_interval in self.constant_intervals.values() for cr in const_interval):
for const_range in (cr for const_interval in self.common_intervals.values() for cr in const_interval):
const_range = Interval(4 * ((const_range.start + 1) // 4) - 1, 4 * ((const_range.end + 1) // 4) - 1) # align to nibbles
possible_sync_pos[const_range] += int(const_range.start == self.preamble_end)
@@ -117,12 +120,12 @@ class LabelAssigner(object):
common_constant_intervals = set()
if (0,1) not in self.constant_intervals:
if (0,1) not in self.common_intervals:
return []
for candidate in self.constant_intervals[(0,1)]:
for candidate in self.common_intervals[(0, 1)]:
for j in range(1, len(self.__blocks)):
overlapping_intervals = {candidate.find_common_interval(interval) for interval in self.constant_intervals[(0, j)]}
overlapping_intervals = {candidate.find_common_interval(interval) for interval in self.common_intervals[(0, j)]}
overlapping_intervals.discard(None)
if len(overlapping_intervals) == 0:

View File

@@ -861,6 +861,8 @@ class ProtocolAnalyzer(object):
if not decoder_found and fallback:
block.decoder = fallback
def auto_assign_labels(self):
def auto_assign_labels(self, debug=False):
label_assigner = LabelAssigner(self.blocks)
label_assigner.auto_assign_to_labelset(self.default_labelset)
label_assigner.auto_assign_to_labelset(self.default_labelset)
if debug:
label_assigner.print_common_intervals()

View File

@@ -140,7 +140,7 @@ class TestAutoAssignments(unittest.TestCase):
length_end = 71
t = time.time()
self.protocol.auto_assign_labels()
self.protocol.auto_assign_labels(debug=True)
print("Time for auto assigning labels: ", str(time.time()-t)) # 0.020628690719604492
preamble_label = ProtocolLabel(name="Preamble", start=preamble_start, end=preamble_end, val_type_index=0, color_index=0)