Backport module for sys.audit and sys.addaudithook mechanism
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 

337 lines
9.3 KiB

"""This script contains the actual auditing tests.
It should not be imported directly, but should be run by the test_audit
module with arguments identifying each test.
"""
import contextlib
import typing
import os
import sys
import sysaudit
class TestHook:
    """Used in standard hook tests to collect any logged events.

    Should be used in a with block to ensure that it has no impact
    after the test completes. Nothing here removes the hook from
    sysaudit, so "closing" it only makes it ignore subsequent events.
    """

    def __init__(self, raise_on_events=None, exc_type=RuntimeError):
        """Create a recording hook.

        :param raise_on_events: a single event name, or a collection of event
            names, for which the hook raises ``exc_type`` after recording the
            event. A bare string is treated as ONE event name.
        :param exc_type: exception class raised for matching events.
        """
        if isinstance(raise_on_events, str):
            # Without this, ``event in "some.event"`` would be a substring
            # test and unrelated events (e.g. "op" vs "open") would match.
            raise_on_events = (raise_on_events,)
        self.raise_on_events = raise_on_events or ()
        self.exc_type = exc_type
        # (event, args) pairs in the order they were received.
        self.seen = []
        self.closed = False

    def __enter__(self, *a):  # type: (typing.Any) -> TestHook
        sysaudit.addaudithook(self)
        return self

    def __exit__(self, *a):
        self.close()

    def close(self):
        """Stop recording (and raising on) events."""
        self.closed = True

    @property
    def seen_events(self):
        """Names of all recorded events, in order."""
        return [i[0] for i in self.seen]

    def __call__(
        self, event, args
    ):  # type: (str, typing.Tuple[typing.Any, ...]) -> None
        if self.closed:
            return
        self.seen.append((event, args))
        if event in self.raise_on_events:
            raise self.exc_type("saw event " + event)
# Minimal assertion helpers: this script runs outside unittest, so the
# TestCase.assert* methods are not available.
def assertEqual(x, y):
    """Raise AssertionError when ``x != y``."""
    mismatch = x != y
    if mismatch:
        message = "{!r} should equal {!r}".format(x, y)
        raise AssertionError(message)
def assertIn(el, series):
    """Raise AssertionError unless ``el`` is a member of ``series``."""
    if el in series:
        return
    raise AssertionError("{!r} should be in {!r}".format(el, series))
def assertNotIn(el, series):
    """Raise AssertionError if ``el`` is a member of ``series``."""
    present = el in series
    if not present:
        return
    raise AssertionError("{!r} should not be in {!r}".format(el, series))
def assertSequenceEqual(x, y):
    """Raise AssertionError unless ``x`` and ``y`` have equal length and
    pairwise-equal elements (the container types may differ)."""
    if len(x) != len(y) or any(left != right for left, right in zip(x, y)):
        raise AssertionError("{!r} should equal {!r}".format(x, y))
@contextlib.contextmanager
def assertRaises(ex_type):
    """Context manager asserting that its body raises exactly ``ex_type``.

    The exception must be of that exact type (subclasses do not count).
    An AssertionError raised by the body is always propagated unchanged.

    Explicit ``raise AssertionError`` is used instead of ``assert`` so the
    checks still run when Python is started with ``-O``.

    :raises AssertionError: if the body raises nothing, or raises an
        exception of a different type.
    """
    try:
        yield
    except BaseException as ex:
        if isinstance(ex, AssertionError):
            raise
        if type(ex) is not ex_type:
            raise AssertionError("{} should be {}".format(ex, ex_type))
    else:
        raise AssertionError("expected {}".format(ex_type))
def test_basic():
    """An audited event reaches the hook with its name and argument tuple."""
    with TestHook() as hook:
        sysaudit.audit("test_event", 1, 2, 3)
        first_event, first_args = hook.seen[0]
        assertEqual(first_event, "test_event")
        assertEqual(first_args, (1, 2, 3))
def test_block_add_hook():
    """A hook raising on "sys.addaudithook" blocks later hooks.

    Raising an exception should prevent a new hook from being added,
    but will not propagate out.
    """
    # Pass a set rather than a bare string so events are matched by exact
    # name instead of substring containment.
    with TestHook(raise_on_events={"sys.addaudithook"}) as hook1:
        with TestHook() as hook2:
            sysaudit.audit("test_event")
            assertIn("test_event", hook1.seen_events)
            assertNotIn("test_event", hook2.seen_events)
def test_block_add_hook_baseexception():
    """Raising BaseException from a hook propagates out when adding a hook."""
    # Event names are given as a set so matching is exact, not substring.
    with assertRaises(BaseException):
        with TestHook(
            raise_on_events={"sys.addaudithook"}, exc_type=BaseException
        ):
            # Adding this next hook should raise BaseException
            with TestHook():
                pass
def test_pickle():
    """A hook raising on "pickle.find_class" blocks unpickling of globals."""
    import pickle

    class PicklePrint:
        # Reduces to the call str("Pwned!"), so loading the pickle has to
        # resolve the global ``str``.
        def __reduce_ex__(self, p):
            return str, ("Pwned!",)

    payload_1 = pickle.dumps(PicklePrint())
    payload_2 = pickle.dumps(("a", "b", "c", 1, 2, 3))

    # Before we add the hook, ensure our malicious pickle loads
    assertEqual("Pwned!", pickle.loads(payload_1))

    # A set (not a bare string) keeps event matching exact.
    with TestHook(raise_on_events={"pickle.find_class"}):
        with assertRaises(RuntimeError):
            # With the hook enabled, loading globals is not allowed
            pickle.loads(payload_1)
        # pickles with no globals are okay
        pickle.loads(payload_2)
def test_monkeypatch():
    """Class and instance attribute writes raise "object.__setattr__" events."""

    class A:
        pass

    class B:
        pass

    class C(A):
        pass

    a = A()

    with TestHook() as hook:
        # Catch name changes
        C.__name__ = "X"
        # Catch type changes
        C.__bases__ = (B,)  # noqa
        # Ensure bypassing __setattr__ is still caught
        type.__dict__["__bases__"].__set__(C, (B,))
        # Catch attribute replacement
        C.__init__ = B.__init__
        # Catch attribute addition
        C.new_attr = 123  # noqa
        # Catch class changes
        a.__class__ = B  # noqa

    # Use a distinct name for the event arguments: reusing ``a`` here would
    # shadow the instance ``a`` referenced in the expected list (and on
    # Python 2 the comprehension variable would leak and overwrite it).
    actual = [
        (event_args[0], event_args[1])
        for event, event_args in hook.seen
        if event == "object.__setattr__"
    ]
    assertSequenceEqual(
        [(C, "__name__"), (C, "__bases__"), (C, "__bases__"), (a, "__class__")],
        actual,
    )
def test_open():
    """Every open-style call raises an "open" audit event that the hook can block.

    sys.argv[2] is a file path supplied on the command line. The hook raises
    on every "open" event, so each attempted call must fail with RuntimeError.
    """
    # SSLContext.load_dh_params uses _Py_fopen_obj rather than normal open()
    try:
        import ssl

        load_dh_params = ssl.create_default_context().load_dh_params
    except ImportError:
        load_dh_params = None

    # Try a range of "open" functions; with the hook active, all should fail.
    with TestHook(raise_on_events={"open"}) as hook:
        for call in [
            (open, sys.argv[2], "r"),
            (open, sys.executable, "rb"),
            (open, 3, "wb"),
            (open, sys.argv[2], "w", -1, None, None, None, False, lambda *a: 1),
            (load_dh_params, sys.argv[2]),
        ]:
            opener, call_args = call[0], call[1:]
            if not opener:
                # ssl could not be imported, so there is no load_dh_params.
                continue
            with assertRaises(RuntimeError):
                opener(*call_args)

    open_events = [event_args for event, event_args in hook.seen if event == "open"]
    # Events carrying a mode string vs. events whose mode slot is empty.
    actual_mode = [(ea[0], ea[1]) for ea in open_events if ea[1]]
    actual_flag = [(ea[0], ea[2]) for ea in open_events if not ea[1]]

    expected_mode = [
        (sys.argv[2], "r"),
        (sys.executable, "r"),
        (3, "w"),
        (sys.argv[2], "w"),
    ]
    if load_dh_params:
        expected_mode.append((sys.argv[2], "rb"))
    assertSequenceEqual(expected_mode, actual_mode)
    assertSequenceEqual([], actual_flag)
def test_cantrace():
    """The hook's ``__cantrace__`` attribute controls whether calls into it are traced.

    Only "call" events for frames running TestHook.__call__ are recorded; the
    comments below give how many traced calls each step is expected to add.
    """
    traced = []

    def trace(frame, event, *args):
        if frame.f_code == TestHook.__call__.__code__:
            traced.append(event)

    # sys.settrace() returns None, so the previous trace function (e.g. a
    # coverage tool or debugger) must be fetched with sys.gettrace() in order
    # to restore it afterwards.
    old = sys.gettrace()
    sys.settrace(trace)
    try:
        with TestHook() as hook:
            # No traced call
            eval("1")

            # No traced call
            hook.__cantrace__ = False  # noqa
            eval("2")

            # One traced call
            hook.__cantrace__ = True  # noqa
            eval("3")

            # Two traced calls (writing to private member, eval)
            hook.__cantrace__ = 1  # noqa
            eval("4")

            # One traced call (writing to private member)
            hook.__cantrace__ = 0  # noqa
    finally:
        sys.settrace(old)

    assertSequenceEqual(["call"] * 4, traced)
def test_mmap():
    """Creating an mmap raises an audit event whose first args are (-1, 8) here."""
    import mmap

    with TestHook() as hook:
        mapping = mmap.mmap(-1, 8)
    # Release the anonymous mapping instead of leaking it. The hook is
    # already closed, so this cannot add recorded events.
    mapping.close()
    assertEqual(hook.seen[0][1][:2], (-1, 8))
def test_excepthook():  # type: () -> None
    """Raise an uncaught RuntimeError so the "sys.excepthook" audit event fires.

    The hook validates the event arguments and prints the event name and the
    exception repr; that output is presumably checked by the test_audit driver.
    """

    def excepthook(exc_type, exc_value, exc_tb):
        # Silence the RuntimeError raised on purpose below; defer anything
        # else to the interpreter's default excepthook.
        if exc_type is RuntimeError:
            return
        sys.__excepthook__(exc_type, exc_value, exc_tb)

    def hook(event, args):  # type: (str, typing.Tuple[typing.Any, ...]) -> None
        if event != "sys.excepthook":
            return
        exc_class, exc_value = args[1], args[2]
        if not isinstance(exc_value, exc_class):
            raise TypeError(
                "Expected isinstance({!r}, {!r})".format(exc_value, exc_class)
            )
        if args[0] != excepthook:
            raise ValueError("Expected {} == {}".format(args[0], excepthook))
        print(event, repr(exc_value))

    sysaudit.addaudithook(hook)
    sys.excepthook = excepthook
    raise RuntimeError("fatal-error")
def test_unraisablehook():  # type: () -> None
    """Report an unraisable exception via _testcapi and print the
    "sys.unraisablehook" audit event's details."""
    from _testcapi import write_unraisable_exc  # type: ignore

    def unraisablehook(hookargs):  # noqa: F841
        # Deliberately ignore the reported exception.
        pass

    def hook(event, args):  # type: (str, typing.Tuple[typing.Any, ...]) -> None
        if event != "sys.unraisablehook":
            return
        if args[0] != unraisablehook:
            raise ValueError("Expected {} == {}".format(args[0], unraisablehook))
        unraisable = args[1]
        print(event, repr(unraisable.exc_value), unraisable.err_msg)

    sysaudit.addaudithook(hook)
    sys.unraisablehook = unraisablehook  # type: ignore [attr-defined]
    write_unraisable_exc(RuntimeError("nonfatal-error"), "for audit hook test", None)
def test_winreg():  # type: () -> None
    """Print every "winreg." audit event raised by a few registry calls."""
    from winreg import OpenKey, EnumKey, CloseKey, HKEY_LOCAL_MACHINE  # type: ignore

    def hook(event, args):  # type: (str, typing.Tuple[typing.Any, ...]) -> None
        if event.startswith("winreg."):
            print(event, args)

    sysaudit.addaudithook(hook)

    key = OpenKey(HKEY_LOCAL_MACHINE, "Software")
    EnumKey(key, 0)

    # Enumerating an out-of-range index must fail with OSError.
    try:
        EnumKey(key, 10000)
    except OSError:
        pass
    else:
        raise RuntimeError("Expected EnumKey(HKLM, 10000) to fail")

    # Detach the raw handle from the key object and close it explicitly.
    CloseKey(key.Detach())
def test_socket():  # type: () -> None
    """Print every "socket." audit event raised by a few basic socket calls."""
    import socket

    def hook(event, args):  # type: (str, typing.Tuple[typing.Any, ...]) -> None
        if not event.startswith("socket."):
            return
        print(event, args)

    sysaudit.addaudithook(hook)

    socket.gethostname()

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # Whether bind succeeds does not matter; only its audit event does.
        sock.bind(("127.0.0.1", 8080))
    except Exception:
        pass
    finally:
        sock.close()
if __name__ == "__main__":
    # The first command-line argument names the module-level test function to
    # run; individual tests may read further arguments (test_open reads
    # sys.argv[2] as a file path).
    if len(sys.argv) < 2:
        sys.exit("usage: {} TEST_NAME [ARG ...]".format(sys.argv[0]))
    test = sys.argv[1]
    if not callable(globals().get(test)):
        sys.exit("unknown test: {!r}".format(test))
    globals()[test]()