Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def __init__(self) -> None:
    """Create the ``asyncio.Event`` this object wraps and keep it on ``_event``."""
    event = asyncio.Event()
    self._event = event
def loop(self) -> asyncio.AbstractEventLoop:
    """Return the event loop cached on this instance, resolving it on first use.

    The first call stores the result of ``asyncio.get_event_loop()`` on
    ``self._loop``. If that call raises ``RuntimeError``, a brand new loop
    from ``asyncio.new_event_loop()`` is stored instead. Later calls return
    the stored loop unchanged.
    """
    if hasattr(self, "_loop"):
        return self._loop

    try:
        self._loop = asyncio.get_event_loop()
    except RuntimeError:
        # No usable current loop: fall back to creating one.
        self._loop = asyncio.new_event_loop()
    return self._loop
def __init__(
    self, stream_reader: asyncio.StreamReader, stream_writer: asyncio.StreamWriter,
):
    """Store an asyncio reader/writer pair and set up per-instance state.

    Args:
        stream_reader: Reader half of the connection.
        stream_writer: Writer half of the connection.
    """
    # Starts out unset; nothing in this excerpt assigns it afterwards.
    self._inner: typing.Optional[SocketStream] = None
    # Fresh lock created per instance.
    self.read_lock = asyncio.Lock()
    self.stream_reader, self.stream_writer = stream_reader, stream_writer
async def open_uds_stream(
    self,
    path: str,
    hostname: typing.Optional[str],
    ssl_context: typing.Optional[ssl.SSLContext],
    timeout: Timeout,
) -> SocketStream:
    """Open a connection over the Unix domain socket at ``path``.

    Args:
        path: Filesystem path of the Unix socket.
        hostname: Passed as ``server_hostname`` to asyncio, but only when an
            ``ssl_context`` is supplied; ignored for plain connections.
        ssl_context: TLS context, or ``None`` for an unencrypted connection.
        timeout: Its ``connect_timeout`` bounds the connection attempt.

    Returns:
        A ``SocketStream`` wrapping the connected reader/writer pair.

    Raises:
        ConnectTimeout: If the connection is not established in time.
    """
    # asyncio rejects a server_hostname without ssl, so drop it for plain sockets.
    server_hostname = hostname if ssl_context else None

    try:
        stream_reader, stream_writer = await asyncio.wait_for(  # type: ignore
            asyncio.open_unix_connection(
                path, ssl=ssl_context, server_hostname=server_hostname
            ),
            timeout.connect_timeout,
        )
    except asyncio.TimeoutError:
        # Hide asyncio's internal timeout from the traceback, matching how
        # read timeouts are reported.
        raise ConnectTimeout() from None

    return SocketStream(stream_reader=stream_reader, stream_writer=stream_writer)
def semaphore(self) -> asyncio.BoundedSemaphore:
    """Return this instance's bounded semaphore, creating it on first access.

    The semaphore is built with ``self.max_value`` as its initial value and
    cached on ``self._semaphore``; later calls return the cached object.
    """
    if hasattr(self, "_semaphore"):
        return self._semaphore

    self._semaphore = asyncio.BoundedSemaphore(value=self.max_value)
    return self._semaphore
async def start_tls(
self, hostname: str, ssl_context: ssl.SSLContext, timeout: Timeout
) -> "SocketStream":
# Upgrade the existing connection's transport to TLS.
#
# hostname is forwarded as server_hostname, ssl_context as sslcontext, and
# timeout.connect_timeout bounds the whole upgrade via asyncio.wait_for.
#
# NOTE(review): this excerpt ends after set_transport() without returning
# the annotated "SocketStream"; the remainder of the method is presumably
# not shown here -- confirm against the full file.
# NOTE(review): no asyncio.TimeoutError handling is visible in this excerpt,
# so a timeout from wait_for() would propagate as-is -- confirm.
loop = asyncio.get_event_loop()
# A fresh reader/protocol pair is created for the upgraded connection.
stream_reader = asyncio.StreamReader()
protocol = asyncio.StreamReaderProtocol(stream_reader)
# Reuse the transport that backs the current writer.
transport = self.stream_writer.transport
# Use the loop's own start_tls() when it has one, otherwise the
# backport_start_tls helper.
loop_start_tls = getattr(loop, "start_tls", backport_start_tls)
transport = await asyncio.wait_for(
loop_start_tls(
transport=transport,
protocol=protocol,
sslcontext=ssl_context,
server_hostname=hostname,
),
timeout=timeout.connect_timeout,
)
# Attach the transport returned by the upgrade to the new reader.
stream_reader.set_transport(transport)
async def read(self, n: int, timeout: Timeout) -> bytes:
    """Read up to ``n`` bytes from the stream reader.

    The read runs while holding ``self.read_lock``. Only the read itself is
    bounded by ``timeout.read_timeout``; waiting for the lock is not.

    Raises:
        ReadTimeout: If the read does not complete within the timeout. The
            underlying ``asyncio.TimeoutError`` is suppressed from the chain.
    """
    try:
        async with self.read_lock:
            pending_read = self.stream_reader.read(n)
            received = await asyncio.wait_for(pending_read, timeout.read_timeout)
            return received
    except asyncio.TimeoutError:
        raise ReadTimeout() from None
*,
server_side: bool = False,
server_hostname: str = None,
ssl_handshake_timeout: float = None,
) -> asyncio.Transport: # pragma: nocover (Since it's not used on all Python versions.)
"""
Python 3.6 asyncio doesn't have a start_tls() method on the loop
so we use this function in place of the loop's start_tls() method.
Adapted from this comment:
https://github.com/urllib3/urllib3/issues/1323#issuecomment-362494839
"""
import asyncio.sslproto
loop = asyncio.get_event_loop()
waiter = loop.create_future()
ssl_protocol = asyncio.sslproto.SSLProtocol(
loop,
protocol,
sslcontext,
waiter,
server_side=False,
server_hostname=server_hostname,
call_connection_made=False,
)
transport.set_protocol(ssl_protocol)
loop.call_soon(ssl_protocol.connection_made, transport)
loop.call_soon(transport.resume_reading) # type: ignore
await waiter
async def start_tls(
self, hostname: str, ssl_context: ssl.SSLContext, timeout: Timeout
) -> "SocketStream":
loop = asyncio.get_event_loop()
stream_reader = asyncio.StreamReader()
protocol = asyncio.StreamReaderProtocol(stream_reader)
transport = self.stream_writer.transport
loop_start_tls = getattr(loop, "start_tls", backport_start_tls)
transport = await asyncio.wait_for(
loop_start_tls(
transport=transport,
protocol=protocol,
sslcontext=ssl_context,
server_hostname=hostname,
),
timeout=timeout.connect_timeout,
)
stream_reader.set_transport(transport)
stream_writer = asyncio.StreamWriter(
async def open_tcp_stream(
    self,
    hostname: str,
    port: int,
    ssl_context: typing.Optional[ssl.SSLContext],
    timeout: Timeout,
) -> SocketStream:
    """Open a TCP connection to ``hostname:port``.

    Args:
        hostname: Host to connect to.
        port: TCP port to connect to.
        ssl_context: TLS context, or ``None`` for an unencrypted connection.
        timeout: Its ``connect_timeout`` bounds the connection attempt.

    Returns:
        A ``SocketStream`` wrapping the connected reader/writer pair.

    Raises:
        ConnectTimeout: If the connection is not established in time.
    """
    try:
        stream_reader, stream_writer = await asyncio.wait_for(  # type: ignore
            asyncio.open_connection(hostname, port, ssl=ssl_context),
            timeout.connect_timeout,
        )
    except asyncio.TimeoutError:
        # Hide asyncio's internal timeout from the traceback, matching how
        # read timeouts are reported.
        raise ConnectTimeout() from None

    return SocketStream(stream_reader=stream_reader, stream_writer=stream_writer)