Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions kafka/net/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def writelines(self, list_of_data):

async def _write_to_sock(self):
try:
while self._write and not self._closed and self._write_buffer:
while self._write_buffer:
await self._net.wait_write(self._sock)
total_bytes, err = self._sock_send()
if err:
Expand All @@ -197,11 +197,18 @@ async def _write_to_sock(self):
self._protocol._sensors.bytes_sent.record(total_bytes)
finally:
self._writing = False
if not self._write:
self._sock.shutdown(socket.SHUT_WR)
if self._closed:
self._close()
elif not self._write:
try:
self._sock.shutdown(socket.SHUT_WR)
except OSError:
pass

def _sock_send(self):
total_bytes = 0
if self._sock is None:
return total_bytes, Errors.KafkaConnectionError('Connection closed during send')
while self._write_buffer:
next_chunk = self._write_buffer.popleft()
# Wrap in memoryview so partial-send slicing is O(1) instead of
Expand Down Expand Up @@ -260,6 +267,10 @@ def _close(self, error=None):
self._net.unregister_event(sock, selectors.EVENT_READ | selectors.EVENT_WRITE)
except (KeyError, ValueError):
pass
try:
sock.shutdown(socket.SHUT_RDWR)
except OSError:
pass
sock.close()
proto = self._protocol
self._protocol = None
Expand Down Expand Up @@ -398,6 +409,8 @@ def _sock_recv(self):
def _sock_send(self):
total_bytes = 0
err = None
if self._sock is None:
return total_bytes, Errors.KafkaConnectionError('Connection closed during send')
while self._write_buffer:
next_chunk = self._write_buffer.popleft()
while next_chunk:
Expand Down
119 changes: 119 additions & 0 deletions test/net/test_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,28 @@ def fake_sock_send():
assert t._sock is None
proto.connection_lost.assert_called_once_with(err)

def test_sock_send_when_sock_closed_returns_clean_error(self, net):
"""If the socket was closed (set to None by _close/abort) while data is
still buffered, _sock_send must short-circuit with a clean
KafkaConnectionError('Connection closed during send') rather than
dereferencing None -> AttributeError.
"""
sock = _make_mock_sock()
t = KafkaTCPTransport(net, sock)
t.write(b'data') # leave bytes buffered
t._sock = None # socket torn down out from under the pending send

total_bytes, err = t._sock_send()

assert total_bytes == 0
assert isinstance(err, Errors.KafkaConnectionError)
# Exactly the clean message -- not a wrapped AttributeError.
assert err.args == ('Connection closed during send',)
assert not isinstance(err.args[0], BaseException)
# The buffered chunk must not be consumed by the short-circuit.
assert list(t._write_buffer) == [b'data']
sock.send.assert_not_called()

def test_sock_recv_error_closes_transport(self, net, socketpair):
"""If _sock_recv returns an error, _read_from_sock must close the
transport and propagate the error via protocol.connection_lost.
Expand Down Expand Up @@ -227,6 +249,103 @@ async def reader():
net.poll(timeout_ms=1000, future=f)
assert received == [b'hello world']

def test_write_eof_empty_buffer_shuts_down_immediately(self, net):
# Fast path: nothing buffered, so write_eof half-closes right away.
sock = _make_mock_sock()
t = KafkaTCPTransport(net, sock)
t.write_eof()
sock.shutdown.assert_called_once_with(socket.SHUT_WR)

def test_write_eof_defers_shutdown_until_buffer_flushed(self, net):
# With data still buffered, write_eof must NOT shut down the write side
# yet -- doing so would discard the unflushed bytes (the latent bug).
sock = _make_mock_sock()
t = KafkaTCPTransport(net, sock)
t.write(b'pending')
t.write_eof()
assert not t._write
assert t._write_buffer # still buffered
sock.shutdown.assert_not_called()

def test_write_eof_flushes_buffered_data_then_shuts_write(self, net, socketpair):
# Regression: data buffered before write_eof() must be fully delivered
# to the peer, and only then is the write side shut down (peer sees EOF).
rsock, wsock = socketpair
t = KafkaTCPTransport(net, wsock)
t.write(b'hello world') # buffered; _write_to_sock scheduled
t.write_eof() # half-close requested -- flush must win
assert t._write_buffer

f = Future()
received = []
async def reader():
while True:
await net.wait_read(rsock)
data = rsock.recv(1024)
received.append(data)
if data == b'': # EOF => peer shut down its write side
break
f.success(True)
net.call_soon(reader)
net.poll(timeout_ms=1000, future=f)

assert b''.join(received) == b'hello world'
assert received[-1] == b'' # shutdown(SHUT_WR) happened after the flush
assert not t._write_buffer

def test_close_empty_buffer_closes_immediately(self, net):
# Fast path: nothing buffered, so close() tears down synchronously.
sock = _make_mock_sock()
proto = MagicMock()
t = KafkaTCPTransport(net, sock)
t.set_protocol(proto)
t.close()
assert t._sock is None
proto.connection_lost.assert_called_once_with(None)

def test_close_defers_teardown_until_buffer_flushed(self, net):
# With data still buffered, close() must defer the actual socket
# teardown to _write_to_sock so the bytes are not dropped.
sock = _make_mock_sock()
proto = MagicMock()
t = KafkaTCPTransport(net, sock)
t.set_protocol(proto)
t.write(b'pending')
t.close()
assert t.is_closing()
assert t._write_buffer # not yet flushed
assert t._sock is not None # not yet closed
proto.connection_lost.assert_not_called()

def test_close_flushes_buffered_data_then_closes(self, net, socketpair):
# Regression: data buffered before close() must reach the peer before
# the transport tears the socket down.
rsock, wsock = socketpair
t = KafkaTCPTransport(net, wsock)
proto = MagicMock()
t.set_protocol(proto)
t.write(b'goodbye world') # buffered; _write_to_sock scheduled
t.close() # closed flag set, teardown deferred
assert t.is_closing()
assert t._write_buffer

f = Future()
received = []
async def reader():
while True:
await net.wait_read(rsock)
data = rsock.recv(1024)
received.append(data)
if data == b'':
break
f.success(True)
net.call_soon(reader)
net.poll(timeout_ms=1000, future=f)

assert b''.join(received) == b'goodbye world'
assert t._sock is None # _write_to_sock flushed then _close()d
proto.connection_lost.assert_called_once_with(None)

def test_writeSequence(self, net):
sock = _make_mock_sock()
t = KafkaTCPTransport(net, sock)
Expand Down
Loading