diff --git a/bin/greenrpc-server b/bin/greenrpc-server index 9840727..b5d9980 100755 --- a/bin/greenrpc-server +++ b/bin/greenrpc-server @@ -6,7 +6,7 @@ from greenrpc.server import TCPServer, WSGIServer if __name__ == "__main__": parser = argparse.ArgumentParser(description="Start a new GreenRPC TCP Server") - parser.add_argument("module", metavar="", type=str, + parser.add_argument("modules", metavar="", type=str, nargs="+", help="Python module to expose for the RPC Server") default_bind = "127.0.0.1:%s" % (DEFAULT_PORT, ) parser.add_argument("--bind", dest="bind", type=str, default=default_bind, @@ -20,9 +20,9 @@ if __name__ == "__main__": address, _, port = args.bind.partition(":") bind = (address, int(port)) if args.http: - server = WSGIServer(args.module, bind=bind, spawn=args.spawn) + server = WSGIServer(args.modules, bind=bind, spawn=args.spawn) else: - server = TCPServer(args.module, bind=bind, spawn=args.spawn) + server = TCPServer(args.modules, bind=bind, spawn=args.spawn) try: server.serve_forever() except KeyboardInterrupt: diff --git a/greenrpc/base.py b/greenrpc/base.py index 1e1fae7..dac9415 100644 --- a/greenrpc/base.py +++ b/greenrpc/base.py @@ -6,17 +6,32 @@ import msgpack class BaseServer(object): SOCKET_BUFFER_SIZE = 1024 + ALLOWED_TYPES = (types.FunctionType, types.MethodType, types.BuiltinFunctionType, types.BuiltinMethodType) def __init__(self, services): - if isinstance(services, (dict, types.ModuleType)): - self.services = services - elif isinstance(services, basestring): - self.services = __import__(services) - else: + self.services = self.load_services(services) + if not self.services: raise TypeError("First argument to BaseServer.__init__ must be a dict or a string") self.packer = msgpack.Packer() + def load_services(self, module): + services = {} + if isinstance(module, dict): + services.update(module) + elif isinstance(module, types.ModuleType): + for name in dir(module): + if not name.startswith("_"): + attr = getattr(module, name) + if isinstance(attr, self.ALLOWED_TYPES): + services[name] = attr + elif isinstance(module, basestring): + services.update(self.load_services(__import__(module))) + elif isinstance(module, (tuple, list)): + for m in module: + services.update(self.load_services(m)) + return services + def unpack_requests(self, sock): unpacker = msgpack.Unpacker() while True: @@ -42,11 +57,11 @@ class BaseServer(object): if not req_method: result["error"] = "No request method was provided" - elif not hasattr(self.services, req_method): + elif not isinstance(self.services.get(req_method), self.ALLOWED_TYPES): result["error"] = "Unknown request method '%s'" % (req_method, ) else: try: - result["results"] = getattr(self.services, req_method)(*req_args) + result["results"] = self.services[req_method](*req_args) except Exception, e: result["error"] = e.message