diff --git a/fluent/handler.py b/fluent/handler.py index 78eb1d8..960d074 100644 --- a/fluent/handler.py +++ b/fluent/handler.py @@ -115,12 +115,14 @@ def __init__(self, host='localhost', port=24224, timeout=3.0, + packager='msgpack', verbose=False, buffer_overflow_handler=None): self.tag = tag self.sender = sender.FluentSender(tag, host=host, port=port, + packager=packager, timeout=timeout, verbose=verbose, buffer_overflow_handler=buffer_overflow_handler) logging.Handler.__init__(self) @@ -132,7 +134,7 @@ def emit(self, record): def close(self): self.acquire() try: - self.sender._close() + self.sender.close() logging.Handler.close(self) finally: self.release() diff --git a/fluent/sender.py b/fluent/sender.py index 68fa9ac..597c509 100644 --- a/fluent/sender.py +++ b/fluent/sender.py @@ -1,13 +1,15 @@ # -*- coding: utf-8 -*- from __future__ import print_function -import socket import threading import time import traceback +import json import msgpack +from fluent.transport import Transport, TransportError + _global_sender = None @@ -27,14 +29,17 @@ def setup(tag, **kwargs): def get_global_sender(): return _global_sender + def close(): get_global_sender().close() + class FluentSender(object): def __init__(self, tag, host='localhost', port=24224, + packager="msgpack", bufmax=1 * 1024 * 1024, timeout=3.0, verbose=False, @@ -48,17 +53,18 @@ def __init__(self, self.timeout = timeout self.verbose = verbose self.buffer_overflow_handler = buffer_overflow_handler + self.packager = self.get_packager(packager) - self.socket = None self.pendings = None self.lock = threading.Lock() self._last_error_threadlocal = threading.local() + self.transport = Transport(self.host, self.port, self.timeout) try: - self._reconnect() - except socket.error: + self.transport.connect() + except TransportError: # will be retried in emit() - self._close() + self.transport.close() def emit(self, label, data): cur_time = int(time.time()) @@ -80,16 +86,15 @@ def close(self): try: if self.pendings: try: - self._send_data(self.pendings) + self.transport.send(self.pendings) except Exception: self._call_buffer_overflow_handler(self.pendings) - self._close() + self.transport.close() self.pendings = None finally: self.lock.release() - def _make_packet(self, label, timestamp, data): if label: tag = '.'.join((self.tag, label)) @@ -98,7 +103,16 @@ def _make_packet(self, label, timestamp, data): packet = (tag, timestamp, data) if self.verbose: print(packet) - return msgpack.packb(packet) + return self.packager(packet) + + def get_packager(self, name): + if name == 'json': + return json.dumps + + if name == 'msgpack': + return msgpack.packb + + raise RuntimeError("Unknown packager: {}", name) def _send(self, bytes_): self.lock.acquire() @@ -114,18 +128,17 @@ def _send_internal(self, bytes_): bytes_ = self.pendings try: - self._send_data(bytes_) + self.transport.send(bytes_) # send finished self.pendings = None return True - except socket.error as e: - #except Exception as e: + except TransportError as e: self.last_error = e - # close socket - self._close() + # close transport + self.transport.close() # clear buffer if it exceeds max bufer size if self.pendings and (len(self.pendings) > self.bufmax): @@ -136,24 +149,6 @@ def _send_internal(self, bytes_): return False - def _send_data(self, bytes_): - # reconnect if possible - self._reconnect() - # send message - self.socket.sendall(bytes_) - - def _reconnect(self): - if not self.socket: - if self.host.startswith('unix://'): - sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - sock.settimeout(self.timeout) - sock.connect(self.host[len('unix://'):]) - else: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(self.timeout) - sock.connect((self.host, self.port)) - self.socket = sock - def _call_buffer_overflow_handler(self, pending_events): try: if self.buffer_overflow_handler: @@ -168,13 +163,8 @@ def last_error(self): @last_error.setter def last_error(self, err): - self._last_error_threadlocal.exception = err + self._last_error_threadlocal.exception = err - def clear_last_error(self, _thread_id = None): + def clear_last_error(self, _thread_id=None): if hasattr(self._last_error_threadlocal, 'exception'): delattr(self._last_error_threadlocal, 'exception') - - def _close(self): - if self.socket: - self.socket.close() - self.socket = None diff --git a/fluent/transport.py b/fluent/transport.py new file mode 100644 index 0000000..630a7b8 --- /dev/null +++ b/fluent/transport.py @@ -0,0 +1,69 @@ +# encoding=utf-8 + +import socket + +try: + from urllib.parse import urlparse +except ImportError: + from urlparse import urlparse + + +class Transport(object): + def __init__(self, host, port, timeout): + self.host = host + self.port = port + self.timeout = timeout + + self._conn = None + + def close(self): + if self._conn: + self._conn.close() + self._conn = None + + def connect(self): + if self._conn: + return + + family, socket_type, addr = get_connection_params(self.host, self.port) + self._conn = socket.socket(family, socket_type) + self._conn.connect(addr) + self._conn.settimeout(self.timeout) + + def send(self, data): + self.connect() + self._conn.sendall(data.encode('utf-8')) + + +def get_connection_params(url, port=0): + parsed = urlparse(url) + + port = parsed.port or port or 0 + + scheme = parsed.scheme.lower() + if scheme == 'unix': + family = socket.AF_UNIX + socket_type = socket.SOCK_STREAM + addr = parsed.hostname + + elif scheme == 'udp': + family = socket.AF_INET + socket_type = socket.SOCK_DGRAM + addr = (parsed.hostname, port) + + elif scheme in ('tcp', ''): + family = socket.AF_INET + socket_type = socket.SOCK_STREAM + addr = (parsed.hostname or parsed.path, port) + + else: + raise TransportError( + "Unknown connection protocol: url={}, port={}".format( + url, port, + ) + ) + + return family, socket_type, addr + + +TransportError = socket.error diff --git a/tests/mockserver.py b/tests/mockserver.py index b385d0e..4d77e2e 100644 --- a/tests/mockserver.py +++ b/tests/mockserver.py @@ -1,54 +1,108 @@ # -*- coding: utf-8 -*- +import socket + try: from cStringIO import StringIO as BytesIO except ImportError: from io import BytesIO -import socket -import threading -import time - from msgpack import Unpacker +from fluent.transport import get_connection_params + -class MockRecvServer(threading.Thread): - """ - Single threaded server accepts one connection and recv until EOF. - """ - def __init__(self, host='localhost', port=0): - if host.startswith('unix://'): - self._sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - self._sock.bind(host[len('unix://'):]) +def create_server(host, port=0): + family, socket_type, addr = get_connection_params(host, port) + + if socket_type == UDPServer.SOCKET_TYPE: + conn_type = UDPServer + else: + if family == socket.AF_UNIX: + conn_type = UnixSocketServer else: - self._sock = socket.socket() - self._sock.bind((host, port)) - self.port = self._sock.getsockname()[1] - self._sock.listen(1) - self._buf = BytesIO() + conn_type = TCPServer + + conn = conn_type(family, addr) + conn.listen() + return conn + - threading.Thread.__init__(self) - self.start() +class Server(object): + SOCKET_TYPE = "" - def run(self): - sock = self._sock - con, _ = sock.accept() + def __init__(self, family, addr): + self._family = family + self._addr = addr + + self._sock = socket.socket(self._family, self.SOCKET_TYPE) + self._sock.bind(self._addr) + self._sock.settimeout(0.5) + + def listen(self): + # Okay move along, move along people, there's nothing to see here! + pass + + def recv(self, qty_messages=1): + data = BytesIO() while True: - data = con.recv(4096) - if not data: + chunk = self.recv_raw() + data.seek(0, 2) + data.write(chunk) + + data.seek(0) + messages = list(Unpacker(data, encoding='utf-8')) + if len(messages) >= qty_messages: break - self._buf.write(data) - con.close() - sock.close() - self._sock = None - - def wait(self): - while self._sock: - time.sleep(0.1) - - def get_recieved(self): - self.wait() - self._buf.seek(0) - # TODO: have to process string encoding properly. currently we assume - # that all encoding is utf-8. - return list(Unpacker(self._buf, encoding='utf-8')) + + return list(messages) + + def recv_raw(self, limit=1024): + raise NotImplementedError + + def close(self): + self._sock.close() + + def addr(self): + raise NotImplementedError + + +class TCPServer(Server): + SOCKET_TYPE = socket.SOCK_STREAM + + def __init__(self, *args, **kwargs): + super(TCPServer, self).__init__(*args, **kwargs) + + self.accepted_connection = None + + def listen(self): + self._sock.listen(1) + + def recv_raw(self, limit=1024): + if not self.accepted_connection: + self.accepted_connection, _ = self._sock.accept() + + return self.accepted_connection.recv(limit) + + def close(self): + super(TCPServer, self).close() + if self.accepted_connection: + self.accepted_connection.close() + + def addr(self): + return "tcp://{}:{}".format(*self._sock.getsockname()) + + +class UnixSocketServer(TCPServer): + def addr(self): + return "unix://{}".format(self._sock.getsockname()) + + +class UDPServer(Server): + SOCKET_TYPE = socket.SOCK_DGRAM + + def recv_raw(self, limit=1024): + return self._sock.recv(limit) + + def addr(self): + return "udp://{}:{}".format(*self._sock.getsockname()) diff --git a/tests/test_event.py b/tests/test_event.py index 494b0f2..bd24480 100644 --- a/tests/test_event.py +++ b/tests/test_event.py @@ -3,19 +3,26 @@ import unittest from fluent import event, sender -from tests import mockserver -class TestException(BaseException): pass +from tests.mockserver import create_server + + +class TestException(BaseException): + pass + + +class BaseTestEvent(object): + ADDR = "" -class TestEvent(unittest.TestCase): def setUp(self): - self._server = mockserver.MockRecvServer('localhost') - sender.setup('app', port=self._server.port) + self._server = create_server(self.ADDR) + sender.setup('app', host=self._server.addr()) def tearDown(self): from fluent.sender import _set_global_sender sender.close() _set_global_sender(None) + self._server.close() def test_logging(self): # XXX: This tests succeeds even if the fluentd connection failed @@ -45,7 +52,6 @@ def test_no_last_error_on_successful_event(self): sender.close() @unittest.skip("This test failed with 'TypeError: catching classes that do not inherit from BaseException is not allowed' so skipped") - #@patch('fluent.sender.socket') def test_connect_exception_during_event_send(self, mock_socket): # Make the socket.socket().connect() call raise a custom exception mock_connect = mock_socket.socket.return_value.connect @@ -54,7 +60,7 @@ def test_connect_exception_during_event_send(self, mock_socket): # Force the socket to reconnect while trying to emit the event global_sender = sender.get_global_sender() - global_sender._close() + global_sender.transport.close() event.Event('unfollow', { 'from': 'userE', @@ -64,3 +70,11 @@ def test_connect_exception_during_event_send(self, mock_socket): ex = global_sender.last_error self.assertEqual(ex.args, EXCEPTION_MSG) global_sender.clear_last_error() + + +class TestEvent_TCP(BaseTestEvent, unittest.TestCase): + ADDR = 'tcp://localhost' + + +class TestEvent_UDP(BaseTestEvent, unittest.TestCase): + ADDR = 'udp://localhost' diff --git a/tests/test_handler.py b/tests/test_handler.py index 9180231..de44bb5 100644 --- a/tests/test_handler.py +++ b/tests/test_handler.py @@ -1,24 +1,31 @@ -# -*- coding: utf-8 -*- +# -*- coding: utf-8 -*- import logging import unittest import fluent.handler -from tests import mockserver +from tests.mockserver import create_server -class TestHandler(unittest.TestCase): +class BaseTestHandler(object): + ADDR = "" + def setUp(self): - super(TestHandler, self).setUp() - self._server = mockserver.MockRecvServer('localhost') - self._port = self._server.port + super(BaseTestHandler, self).setUp() + self._server = create_server(self.ADDR) + + def tearDown(self): + self._server.close() - def get_data(self): - return self._server.get_recieved() + def create_handler(self, tag): + return fluent.handler.FluentHandler(tag, host=self._server.addr()) + + def get_messages(self, qty=1): + return self._server.recv(qty) def test_simple(self): - handler = fluent.handler.FluentHandler('app.follow', port=self._port) + handler = self.create_handler('app.follow') logging.basicConfig(level=logging.INFO) log = logging.getLogger('fluent.test') @@ -30,7 +37,7 @@ def test_simple(self): }) handler.close() - data = self.get_data() + data = self.get_messages(1) eq = self.assertEqual eq(1, len(data)) eq(3, len(data[0])) @@ -41,7 +48,7 @@ def test_simple(self): self.assertTrue(isinstance(data[0][1], int)) def test_custom_fmt(self): - handler = fluent.handler.FluentHandler('app.follow', port=self._port) + handler = self.create_handler('app.follow') logging.basicConfig(level=logging.INFO) log = logging.getLogger('fluent.test') @@ -56,14 +63,14 @@ def test_custom_fmt(self): log.info({'sample': 'value'}) handler.close() - data = self.get_data() + data = self.get_messages() self.assertTrue('name' in data[0][2]) self.assertEqual('fluent.test', data[0][2]['name']) self.assertTrue('lineno' in data[0][2]) self.assertTrue('emitted_at' in data[0][2]) def test_custom_field_raise_exception(self): - handler = fluent.handler.FluentHandler('app.follow', port=self._port) + handler = self.create_handler('app.follow') logging.basicConfig(level=logging.INFO) log = logging.getLogger('fluent.test') @@ -80,7 +87,7 @@ def test_custom_field_raise_exception(self): handler.close() def test_custom_field_fill_missing_fmt_key_is_true(self): - handler = fluent.handler.FluentHandler('app.follow', port=self._port) + handler = self.create_handler('app.follow') logging.basicConfig(level=logging.INFO) log = logging.getLogger('fluent.test') @@ -97,7 +104,7 @@ def test_custom_field_fill_missing_fmt_key_is_true(self): log.removeHandler(handler) handler.close() - data = self.get_data() + data = self.get_messages() self.assertTrue('name' in data[0][2]) self.assertEqual('fluent.test', data[0][2]['name']) self.assertTrue('custom_field' in data[0][2]) @@ -105,7 +112,7 @@ def test_custom_field_fill_missing_fmt_key_is_true(self): self.assertIsNone(data[0][2]['custom_field']) def test_json_encoded_message(self): - handler = fluent.handler.FluentHandler('app.follow', port=self._port) + handler = self.create_handler('app.follow') logging.basicConfig(level=logging.INFO) log = logging.getLogger('fluent.test') @@ -114,12 +121,12 @@ def test_json_encoded_message(self): log.info('{"key": "hello world!", "param": "value"}') handler.close() - data = self.get_data() + data = self.get_messages() self.assertTrue('key' in data[0][2]) self.assertEqual('hello world!', data[0][2]['key']) def test_unstructured_message(self): - handler = fluent.handler.FluentHandler('app.follow', port=self._port) + handler = self.create_handler('app.follow') logging.basicConfig(level=logging.INFO) log = logging.getLogger('fluent.test') @@ -128,12 +135,12 @@ def test_unstructured_message(self): log.info('hello %s', 'world') handler.close() - data = self.get_data() + data = self.get_messages() self.assertTrue('message' in data[0][2]) self.assertEqual('hello world', data[0][2]['message']) def test_unstructured_formatted_message(self): - handler = fluent.handler.FluentHandler('app.follow', port=self._port) + handler = self.create_handler('app.follow') logging.basicConfig(level=logging.INFO) log = logging.getLogger('fluent.test') @@ -142,12 +149,12 @@ def test_unstructured_formatted_message(self): log.info('hello world, %s', 'you!') handler.close() - data = self.get_data() + data = self.get_messages() self.assertTrue('message' in data[0][2]) self.assertEqual('hello world, you!', data[0][2]['message']) def test_number_string_simple_message(self): - handler = fluent.handler.FluentHandler('app.follow', port=self._port) + handler = self.create_handler('app.follow') logging.basicConfig(level=logging.INFO) log = logging.getLogger('fluent.test') @@ -156,11 +163,11 @@ def test_number_string_simple_message(self): log.info("1") handler.close() - data = self.get_data() + data = self.get_messages() self.assertTrue('message' in data[0][2]) def test_non_string_simple_message(self): - handler = fluent.handler.FluentHandler('app.follow', port=self._port) + handler = self.create_handler('app.follow') logging.basicConfig(level=logging.INFO) log = logging.getLogger('fluent.test') @@ -169,11 +176,11 @@ def test_non_string_simple_message(self): log.info(42) handler.close() - data = self.get_data() + data = self.get_messages() self.assertTrue('message' in data[0][2]) def test_non_string_dict_message(self): - handler = fluent.handler.FluentHandler('app.follow', port=self._port) + handler = self.create_handler('app.follow') logging.basicConfig(level=logging.INFO) log = logging.getLogger('fluent.test') @@ -182,6 +189,14 @@ def test_non_string_dict_message(self): log.info({42: 'root'}) handler.close() - data = self.get_data() + data = self.get_messages() # For some reason, non-string keys are ignored self.assertFalse(42 in data[0][2]) + + +class TestHandler_TCP(BaseTestHandler, unittest.TestCase): + ADDR = 'tcp://localhost' + + +class TestHandler_UDP(BaseTestHandler, unittest.TestCase): + ADDR = 'udp://localhost' diff --git a/tests/test_sender.py b/tests/test_sender.py index c4c9520..d7c6b52 100644 --- a/tests/test_sender.py +++ b/tests/test_sender.py @@ -5,7 +5,7 @@ import socket import fluent.sender -from tests import mockserver +from tests.mockserver import create_server class TestSetup(unittest.TestCase): @@ -38,24 +38,27 @@ def test_tolerant(self): self.assertEqual(actual.timeout, 1.0) -class TestSender(unittest.TestCase): +class BaseTestSender(object): + ADDR = "" + def setUp(self): - super(TestSender, self).setUp() - self._server = mockserver.MockRecvServer('localhost') - self._sender = fluent.sender.FluentSender(tag='test', - port=self._server.port) + super(BaseTestSender, self).setUp() + self._server = create_server(self.ADDR) + self._sender = fluent.sender.FluentSender( + tag='test', host=self._server.addr(), + ) def tearDown(self): self._sender.close() + self._server.close() - def get_data(self): - return self._server.get_recieved() + def get_messages(self, qty=1): + return self._server.recv(qty) def test_simple(self): - sender = self._sender - sender.emit('foo', {'bar': 'baz'}) - sender._close() - data = self.get_data() + self._sender.emit('foo', {'bar': 'baz'}) + + data = self.get_messages(1) eq = self.assertEqual eq(1, len(data)) eq(3, len(data[0])) @@ -67,7 +70,7 @@ def test_simple(self): def test_no_last_error_on_successful_emit(self): sender = self._sender sender.emit('foo', {'bar': 'baz'}) - sender._close() + sender.transport.close() self.assertEqual(sender.last_error, None) @@ -85,7 +88,6 @@ def test_clear_last_error(self): self.assertEqual(self._sender.last_error, None) @unittest.skip("This test failed with 'TypeError: catching classes that do not inherit from BaseException is not allowed' so skipped") - #@patch('fluent.sender.socket') def test_connect_exception_during_sender_init(self, mock_socket): # Make the socket.socket().connect() call raise a custom exception mock_connect = mock_socket.socket.return_value.connect @@ -93,3 +95,11 @@ def test_connect_exception_during_sender_init(self, mock_socket): mock_connect.side_effect = socket.error(EXCEPTION_MSG) self.assertEqual(self._sender.last_error.args[0], EXCEPTION_MSG) + + +class TestSender_TCP(BaseTestSender, unittest.TestCase): + ADDR = 'tcp://localhost' + + +class TestSender_UDP(BaseTestSender, unittest.TestCase): + ADDR = 'udp://localhost'