diff --git a/sysaudit/__init__.py b/sysaudit/__init__.py index b043f91..aad360d 100644 --- a/sysaudit/__init__.py +++ b/sysaudit/__init__.py @@ -1,38 +1,83 @@ __all__ = ["audit", "addaudithook"] +import os import sys # Python 3.8+ # DEV: We could check `sys.version_info >= (3, 8)`, but if auditing ever gets # back ported we want to take advantage of that +std_audit = None +std_addaudithook = None if hasattr(sys, "audit") and hasattr(sys, "addaudithook"): - audit = sys.audit - addaudithook = sys.addaudithook -else: + std_audit = sys.audit + std_addaudithook = sys.addaudithook + +# Try to import Cython version +csysaudit_audit = None +csysaudit_addaudithook = None +try: + from . import _csysaudit + + csysaudit_audit = _csysaudit.audit + csysaudit_addaudithook = _csysaudit.addaudithook +except ImportError: + pass + + +# Pure-python implementation +_hooks = list() + + +def py_audit(event, *args): + global _hooks + + for hook in _hooks: + hook(event, args) + + +def py_addaudithook(callback): + global _hooks + + # https://docs.python.org/3.8/library/sys.html#sys.addaudithook + # Raise an auditing event `sys.addaudithook` with no arguments. + # If any existing hooks raise an exception derived from RuntimeError, + # the new hook will not be added and the exception suppressed. + # As a result, callers cannot assume that their hook has been added + # unless they control all existing hooks. try: - from ._csysaudit import audit, addaudithook - except ImportError: - _hooks = list() - - def audit(event, *args): - global _hooks - # Grab a copy of hooks so we don't need to lock here - for hook in _hooks[:]: - hook(event, args) - - def addaudithook(callback): - global _hooks - - # https://docs.python.org/3.8/library/sys.html#sys.addaudithook - # Raise an auditing event `sys.addaudithook` with no arguments. - # If any existing hooks raise an exception derived from RuntimeError, - # the new hook will not be added and the exception suppressed. - # As a result, callers cannot assume that their hook has been added - # unless they control all existing hooks. - try: - audit("sys.addaudithook") - except RuntimeError: - return - - if callback not in _hooks: - _hooks.append(callback) + audit("sys.addaudithook") + except RuntimeError: + return + + _hooks.append(callback) + + +# Choose the best implementation +# DEV: We still import/create all of them +# so we can easily access each implementation +# for testing +SYSAUDIT_IMPL = os.getenv("SYSAUDIT_IMPL") +if SYSAUDIT_IMPL: + if SYSAUDIT_IMPL == "stdlib": + audit = std_audit + addaudithook = std_addaudithook + elif SYSAUDIT_IMPL == "csysaudit": + audit = csysaudit_audit + addaudithook = csysaudit_addaudithook + elif SYSAUDIT_IMPL == "pysysaudit": + audit = py_audit + addaudithook = py_addaudithook + else: + raise ValueError( + "SYSAUDIT_IMPL must be one of ('stdlib', 'csysaudit', 'pysysaudit')" + ) +else: + if std_audit and std_addaudithook: + audit = std_audit + addaudithook = std_addaudithook + elif csysaudit_audit and csysaudit_addaudithook: + audit = csysaudit_audit + addaudithook = csysaudit_addaudithook + else: + audit = py_audit + addaudithook = py_addaudithook diff --git a/sysaudit/__init__.pyi b/sysaudit/__init__.pyi index cd66550..a423560 100644 --- a/sysaudit/__init__.pyi +++ b/sysaudit/__init__.pyi @@ -4,3 +4,15 @@ def audit(event: str, *args: typing.Any) -> None: ... def addaudithook( hook: typing.Callable[[str, typing.Tuple[typing.Any, ...]], None] ) -> None: ... + +_audit_fn = typing.Callable[[str, typing.Any], None] +_addaudithook_fn = typing.Callable[ + [typing.Callable[[str, typing.Tuple[typing.Any, ...]], None]], None +] + +std_audit = typing.Optional[_audit_fn] +std_addaudithook = typing.Optional[_addaudithook_fn] +csysaudit_audit = typing.Optional[_audit_fn] +csysaudit_addaudithook = typing.Optional[_addaudithook_fn] +py_audit = typing.Optional[_audit_fn] +py_addaudithook = typing.Optional[_addaudithook_fn] diff --git a/tests/audit-tests.py b/tests/audit-tests.py index 85dc781..dc74e1b 100644 --- a/tests/audit-tests.py +++ b/tests/audit-tests.py @@ -7,6 +7,7 @@ module with arguments identifying each test. import contextlib import typing +import os import sys import sysaudit diff --git a/tests/test_audit.py b/tests/test_audit.py index 9942e78..e65b6f4 100644 --- a/tests/test_audit.py +++ b/tests/test_audit.py @@ -8,7 +8,6 @@ import subprocess import sys import unittest - AUDIT_TESTS_PY = os.path.abspath( os.path.join(os.path.dirname(__file__), "audit-tests.py") ) @@ -17,13 +16,20 @@ skip_old_py = unittest.skipIf( sys.version_info < (3, 8), "Skipping tests testing built-in events" ) +if sys.version_info < (3, 8): + IMPLEMENTATIONS = ("csysaudit", "pysysaudit") +else: + IMPLEMENTATIONS = ("stdlib", "csysaudit", "pysysaudit") + class AuditTest(unittest.TestCase): - def do_test(self, *args): + def _do_test(self, *args, **kwargs): popen_kwargs = dict( stdout=subprocess.PIPE, stderr=subprocess.PIPE, ) + if "impl" in kwargs: + popen_kwargs["env"] = dict(SYSAUDIT_IMPL=kwargs["impl"]) if sys.version_info >= (3, 6): popen_kwargs["encoding"] = "utf-8" @@ -36,13 +42,19 @@ class AuditTest(unittest.TestCase): if p.returncode: self.fail("".join(p.stderr)) - def run_python(self, *args): + def do_test(self, *args): + for impl in IMPLEMENTATIONS: + return self._do_test(*args, impl=impl) + + def _run_python(self, *args, **kwargs): events = [] popen_kwargs = dict( stdout=subprocess.PIPE, stderr=subprocess.PIPE, ) + if "impl" in kwargs: + popen_kwargs["env"] = dict(SYSAUDIT_IMPL=kwargs["impl"]) if sys.version_info >= (3, 6): popen_kwargs["encoding"] = "utf-8" @@ -57,6 +69,10 @@ class AuditTest(unittest.TestCase): "".join(p.stderr), ) + def run_python(self, *args): + for impl in IMPLEMENTATIONS: + return self._run_python(*args, impl=impl) + def test_basic(self): self.do_test("test_basic") diff --git a/tests/test_import.py b/tests/test_import.py index 7842bf5..c602172 100644 --- a/tests/test_import.py +++ b/tests/test_import.py @@ -4,9 +4,17 @@ import sysaudit def test_module(): # type: () -> None + assert sysaudit.audit is not None + assert sysaudit.addaudithook is not None + if sys.version_info >= (3, 8, 0): + assert sysaudit.std_audit == sys.audit # type: ignore + assert sysaudit.std_addaudithook == sys.addaudithook # type: ignore assert sysaudit.audit == sys.audit # type: ignore [attr-defined] assert sysaudit.addaudithook == sys.addaudithook # type: ignore [attr-defined] else: - assert sysaudit.audit == sysaudit._csysaudit.audit - assert sysaudit.addaudithook == sysaudit._csysaudit.addaudithook + assert sysaudit.audit == sysaudit.csysaudit_audit + assert sysaudit.addaudithook == sysaudit.csysaudit_addaudithook + + assert sysaudit.py_audit is not None + assert sysaudit.py_addaudithook is not None