diff --git a/kafka/net/transport.py b/kafka/net/transport.py index 4d923cc39..1ac3309e2 100644 --- a/kafka/net/transport.py +++ b/kafka/net/transport.py @@ -186,7 +186,7 @@ def writelines(self, list_of_data): async def _write_to_sock(self): try: - while self._write and not self._closed and self._write_buffer: + while self._write_buffer: await self._net.wait_write(self._sock) total_bytes, err = self._sock_send() if err: @@ -197,11 +197,18 @@ async def _write_to_sock(self): self._protocol._sensors.bytes_sent.record(total_bytes) finally: self._writing = False - if not self._write: - self._sock.shutdown(socket.SHUT_WR) + if self._closed: + self._close() + elif not self._write: + try: + self._sock.shutdown(socket.SHUT_WR) + except OSError: + pass def _sock_send(self): total_bytes = 0 + if self._sock is None: + return total_bytes, Errors.KafkaConnectionError('Connection closed during send') while self._write_buffer: next_chunk = self._write_buffer.popleft() # Wrap in memoryview so partial-send slicing is O(1) instead of @@ -260,6 +267,10 @@ def _close(self, error=None): self._net.unregister_event(sock, selectors.EVENT_READ | selectors.EVENT_WRITE) except (KeyError, ValueError): pass + try: + sock.shutdown(socket.SHUT_RDWR) + except OSError: + pass sock.close() proto = self._protocol self._protocol = None @@ -398,6 +409,8 @@ def _sock_recv(self): def _sock_send(self): total_bytes = 0 err = None + if self._sock is None: + return total_bytes, Errors.KafkaConnectionError('Connection closed during send') while self._write_buffer: next_chunk = self._write_buffer.popleft() while next_chunk: diff --git a/test/net/test_transport.py b/test/net/test_transport.py index 4aa27ccd7..ad72c82a6 100644 --- a/test/net/test_transport.py +++ b/test/net/test_transport.py @@ -173,6 +173,28 @@ def fake_sock_send(): assert t._sock is None proto.connection_lost.assert_called_once_with(err) + def test_sock_send_when_sock_closed_returns_clean_error(self, net): + """If the socket was closed (set to None by _close/abort) while data is + still buffered, _sock_send must short-circuit with a clean + KafkaConnectionError('Connection closed during send') rather than + dereferencing None -> AttributeError. + """ + sock = _make_mock_sock() + t = KafkaTCPTransport(net, sock) + t.write(b'data') # leave bytes buffered + t._sock = None # socket torn down out from under the pending send + + total_bytes, err = t._sock_send() + + assert total_bytes == 0 + assert isinstance(err, Errors.KafkaConnectionError) + # Exactly the clean message -- not a wrapped AttributeError. + assert err.args == ('Connection closed during send',) + assert not isinstance(err.args[0], BaseException) + # The buffered chunk must not be consumed by the short-circuit. + assert list(t._write_buffer) == [b'data'] + sock.send.assert_not_called() + def test_sock_recv_error_closes_transport(self, net, socketpair): """If _sock_recv returns an error, _read_from_sock must close the transport and propagate the error via protocol.connection_lost. @@ -227,6 +249,103 @@ async def reader(): net.poll(timeout_ms=1000, future=f) assert received == [b'hello world'] + def test_write_eof_empty_buffer_shuts_down_immediately(self, net): + # Fast path: nothing buffered, so write_eof half-closes right away. + sock = _make_mock_sock() + t = KafkaTCPTransport(net, sock) + t.write_eof() + sock.shutdown.assert_called_once_with(socket.SHUT_WR) + + def test_write_eof_defers_shutdown_until_buffer_flushed(self, net): + # With data still buffered, write_eof must NOT shut down the write side + # yet -- doing so would discard the unflushed bytes (the latent bug). + sock = _make_mock_sock() + t = KafkaTCPTransport(net, sock) + t.write(b'pending') + t.write_eof() + assert not t._write + assert t._write_buffer # still buffered + sock.shutdown.assert_not_called() + + def test_write_eof_flushes_buffered_data_then_shuts_write(self, net, socketpair): + # Regression: data buffered before write_eof() must be fully delivered + # to the peer, and only then is the write side shut down (peer sees EOF). + rsock, wsock = socketpair + t = KafkaTCPTransport(net, wsock) + t.write(b'hello world') # buffered; _write_to_sock scheduled + t.write_eof() # half-close requested -- flush must win + assert t._write_buffer + + f = Future() + received = [] + async def reader(): + while True: + await net.wait_read(rsock) + data = rsock.recv(1024) + received.append(data) + if data == b'': # EOF => peer shut down its write side + break + f.success(True) + net.call_soon(reader) + net.poll(timeout_ms=1000, future=f) + + assert b''.join(received) == b'hello world' + assert received[-1] == b'' # shutdown(SHUT_WR) happened after the flush + assert not t._write_buffer + + def test_close_empty_buffer_closes_immediately(self, net): + # Fast path: nothing buffered, so close() tears down synchronously. + sock = _make_mock_sock() + proto = MagicMock() + t = KafkaTCPTransport(net, sock) + t.set_protocol(proto) + t.close() + assert t._sock is None + proto.connection_lost.assert_called_once_with(None) + + def test_close_defers_teardown_until_buffer_flushed(self, net): + # With data still buffered, close() must defer the actual socket + # teardown to _write_to_sock so the bytes are not dropped. + sock = _make_mock_sock() + proto = MagicMock() + t = KafkaTCPTransport(net, sock) + t.set_protocol(proto) + t.write(b'pending') + t.close() + assert t.is_closing() + assert t._write_buffer # not yet flushed + assert t._sock is not None # not yet closed + proto.connection_lost.assert_not_called() + + def test_close_flushes_buffered_data_then_closes(self, net, socketpair): + # Regression: data buffered before close() must reach the peer before + # the transport tears the socket down. + rsock, wsock = socketpair + t = KafkaTCPTransport(net, wsock) + proto = MagicMock() + t.set_protocol(proto) + t.write(b'goodbye world') # buffered; _write_to_sock scheduled + t.close() # closed flag set, teardown deferred + assert t.is_closing() + assert t._write_buffer + + f = Future() + received = [] + async def reader(): + while True: + await net.wait_read(rsock) + data = rsock.recv(1024) + received.append(data) + if data == b'': + break + f.success(True) + net.call_soon(reader) + net.poll(timeout_ms=1000, future=f) + + assert b''.join(received) == b'goodbye world' + assert t._sock is None # _write_to_sock flushed then _close()d + proto.connection_lost.assert_called_once_with(None) + def test_writeSequence(self, net): sock = _make_mock_sock() t = KafkaTCPTransport(net, sock)