Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions tests/test_tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1263,6 +1263,95 @@ class _TestSSL(tb.SSLTestCase):
PAYLOAD_SIZE = 1024 * 100
TIMEOUT = 60

def test_start_tls_buffer_transfer(self):
if self.implementation == 'asyncio':
raise unittest.SkipTest()

HELLO_MSG = b'1' * self.PAYLOAD_SIZE
BUFFERED_MSG = b'buffered data before TLS'

server_context = self._create_server_ssl_context(
self.ONLYCERT, self.ONLYKEY)
client_context = self._create_client_ssl_context()

async def handle_client(reader, writer):
# Send data before TLS upgrade
writer.write(BUFFERED_MSG)
await writer.drain()
await asyncio.sleep(0.2)

# Read pre-TLS data
data = await reader.readexactly(len(HELLO_MSG))
self.assertEqual(len(data), len(HELLO_MSG))

# Upgrade to TLS (server side)
try:
# We need the wait_for because the broken version hangs here
await asyncio.wait_for(writer.start_tls(server_context),
timeout=2
)
self.assertIsNotNone(writer.get_extra_info('sslcontext'))
except asyncio.TimeoutError:
self.assertIsNotNone(writer.get_extra_info('sslcontext'))

# Send/receive over TLS
writer.write(b'OK')
await writer.drain()

data = await reader.readexactly(len(HELLO_MSG))
self.assertEqual(len(data), len(HELLO_MSG))

writer.close()
await self.wait_closed(writer)

async def client(addr):
# Use open_connection for StreamReader/StreamWriter
reader, writer = await asyncio.open_connection(*addr)

# Read buffered data before TLS
buffered = await reader.readexactly(len(BUFFERED_MSG))
self.assertEqual(buffered, BUFFERED_MSG,
"Client didn't receive buffered data before TLS upgrade")

# Write before TLS upgrade
writer.write(HELLO_MSG)
await writer.drain()

# Upgrade to TLS
try:
# We need the wait_for because the broken version hangs here
await asyncio.wait_for(writer.start_tls(client_context),
timeout=2
)
self.assertIsNotNone(writer.get_extra_info('sslcontext'))
except asyncio.TimeoutError:
self.assertIsNotNone(writer.get_extra_info('sslcontext'))

# Verify communication over TLS
tls_data = await reader.readexactly(2)
self.assertEqual(tls_data, b'OK',
"Client didn't receive TLS response correctly")

# Continue over TLS
writer.write(HELLO_MSG)
await writer.drain()

writer.close()
await self.wait_closed(writer)

async def run_test():
srv = await asyncio.start_server(
handle_client, '127.0.0.1', 0, family=socket.AF_INET)

addr = srv.sockets[0].getsockname()

await asyncio.wait_for(client(addr), timeout=10)

srv.close()
await srv.wait_closed()

self.loop.run_until_complete(run_test())

def test_create_server_ssl_1(self):
CNT = 0 # number of clients that were successful
TOTAL_CNT = 25 # total number of clients that test will create
Expand Down
11 changes: 11 additions & 0 deletions uvloop/loop.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1616,6 +1616,17 @@ cdef class Loop:
ssl_shutdown_timeout=ssl_shutdown_timeout,
call_connection_made=False)

# Transfer buffered data from the old protocol to the new one.
stream_buff = None
if hasattr(protocol, '_stream_reader'):
stream_reader = protocol._stream_reader
if stream_reader is not None:
stream_buff = getattr(stream_reader, '_buffer', None)

if stream_buff is not None:
ssl_protocol._incoming.write(stream_buff)
stream_buff.clear()

# Pause early so that "ssl_protocol.data_received()" doesn't
# have a chance to get called before "ssl_protocol.connection_made()".
transport.pause_reading()
Expand Down
Loading