diff --git a/tests/gnrc_rpl_srh/tests/01-run.py b/tests/gnrc_rpl_srh/tests/01-run.py index 087f5914bc..af318822be 100755 --- a/tests/gnrc_rpl_srh/tests/01-run.py +++ b/tests/gnrc_rpl_srh/tests/01-run.py @@ -11,13 +11,12 @@ import random import re import sys import subprocess -import threading from scapy.all import Ether, IPv6, UDP, \ IPv6ExtHdrHopByHop, IPv6ExtHdrDestOpt, \ IPv6ExtHdrFragment, IPv6ExtHdrRouting, \ ICMPv6ParamProblem, ICMPv6TimeExceeded, \ - sendp, srp1, sniff + sendp, srp1, AsyncSniffer from testrunner import run, check_unittests @@ -31,44 +30,33 @@ EXT_HDR_NH = { } -class Sniffer(threading.Thread): +class Sniffer(object): def __init__(self, iface, *args, **kwargs): - super().__init__(*args, **kwargs) - self.stop_filter = None - self.stopped = False self.iface = iface - self.ps = [] - self.enter_loop = threading.Event() - self.sniff_results = threading.Event() - - def run(self): - while True: - self.enter_loop.wait() - self.enter_loop.clear() - if self.stopped: - return - if self.stop_filter: - self.ps = sniff(stop_filter=self.stop_filter, - iface=self.iface, timeout=5) - self.stop_filter = None - self.sniff_results.set() + self.sniffer = None + self.stop_filter = None def start_sniff(self, stop_filter): + assert self.sniffer is None self.stop_filter = stop_filter - self.enter_loop.set() + self.sniffer = AsyncSniffer( + iface=self.iface, + stop_filter=stop_filter, + ) + self.sniffer.start() - def wait_for_sniff_results(self): - res = [] - if self.sniff_results.wait(5): - self.sniff_results.clear() - res = self.ps - self.ps = [] - return res - - def stop(self): - self.stopped = True - self.enter_loop.set() - self.join() + def wait_for_sniff_results(self, timeout=5): + assert self.sniffer is not None + self.sniffer.join(timeout=timeout) + sniffer = self.sniffer + self.sniffer = None + if sniffer.results is None: + return [] + return [p for p in sniffer.results + # filter out packets only belonging to stop_filter if + # it existed + if sniffer.kwargs.get("stop_filter") is None or + sniffer.kwargs["stop_filter"](p)] sniffer = None @@ -347,7 +335,6 @@ def testfunc(child): child.expect(r"(?Pfe80::[A-Fa-f:0-9]+)\s") lladdr_dst = child.match.group("lladdr").lower() sniffer = Sniffer(tap) - sniffer.start() def run(func): if child.logfile == sys.stdout: @@ -372,7 +359,6 @@ def testfunc(child): run(test_seq_left_0) run(test_time_exc) print("SUCCESS") - sniffer.stop() if __name__ == "__main__":