diff --git a/sysaudit/__init__.py b/sysaudit/__init__.py index b043f91..989c6bf 100644 --- a/sysaudit/__init__.py +++ b/sysaudit/__init__.py @@ -1,5 +1,6 @@ -__all__ = ["audit", "addaudithook"] +__all__ = ["audit", "addaudithook", "subscribe", "Span"] +import collections import sys # Python 3.8+ @@ -36,3 +37,80 @@ else: if callback not in _hooks: _hooks.append(callback) + + +_subscriptions = collections.defaultdict(list) +_subscription_hook_active = False + + +def _subscription_hook(event, args): + if event in _subscriptions: + # Grab a copy of hooks so we don't need to lock here + for hook in _subscriptions[event][:]: + hook(args) + + +def subscribe(event, hook): + global _subscriptions + global _subscription_hook_active + + if not _subscription_hook_active: + addaudithook(_subscription_hook) + _subscription_hook_active = True + + if hook not in _subscriptions[event]: + _subscriptions[event].append(hook) + + +class Span: + __slots__ = ("name", "started", "ended") + + class Message: + __slots__ = ("type", "span", "data") + + def __init__(self, type, span, data=None): + self.type = type + self.span = span + self.data = data + + def __str__(self): + return "{0}(type={1!r}, span={2}, data={3!r})".format( + self.__class__.__name__, self.type, self.span, self.data + ) + + def __init__(self, name): + self.name = name + self.started = False + self.ended = False + + @property + def id(self): + return id(self) + + def message(self, type, data=None): + audit(self.name, self.Message(type, self, data)) + + def start(self, data=None): + if not self.started: + self.message("start", data) + self.started = True + + def end(self, data=None): + if not self.ended: + self.message("end", data) + self.ended = True + + def annotate(self, data): + self.message("annotate", data) + + def __enter__(self): + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.end(data=dict(exc_type=exc_type, exc_val=exc_val, exc_tb=exc_tb)) + + def __str__(self): + return "{0}(name={1!r}, id={2!r})".format( + self.__class__.__name__, self.name, self.id + )