Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def __init__(self) -> None:
    """Populate the mock websocket with fixed values for the tests.

    The mock starts in the ``HANDSHAKE`` state with a ``GET`` scope, a start
    time of ``0.0`` and an empty ``sent_events`` list.
    """
    # Peer/local address tuples of the form (host, port).
    self.server = ("remote", 5000)
    self.client = ("127.0.0.1", 5000)

    self.app = empty_framework
    self.config = Config()
    self.scope = {"method": "GET"}

    self.state, self.start_time = ASGIWebsocketState.HANDSHAKE, 0.0
    self.sent_events: List[Event] = []
(ASGIWebsocketState.HANDSHAKE, "websocket.send"),
(ASGIWebsocketState.RESPONSE, "websocket.accept"),
(ASGIWebsocketState.RESPONSE, "websocket.send"),
(ASGIWebsocketState.CONNECTED, "websocket.http.response.start"),
(ASGIWebsocketState.CONNECTED, "websocket.http.response.body"),
(ASGIWebsocketState.CLOSED, "websocket.send"),
(ASGIWebsocketState.CLOSED, "websocket.http.response.start"),
(ASGIWebsocketState.CLOSED, "websocket.http.response.body"),
],
)
async def test_asgi_send_invalid_message_given_state(
    state: ASGIWebsocketState, message_type: str
) -> None:
    """Sending ``message_type`` while in ``state`` must raise UnexpectedMessage."""
    websocket = MockWebsocket()
    websocket.state = state
    bad_message = {"type": message_type}
    with pytest.raises(UnexpectedMessage):
        await websocket.asgi_send(bad_message)
async def test_asgi_send_invalid_http_message(status: Any, headers: Any, body: Any) -> None:
    """Malformed status/headers/body in the HTTP response messages must be rejected.

    Either of the two ``asgi_send`` calls may raise; the first one to raise
    ends the ``pytest.raises`` block.
    """
    mixin = WebsocketMixin()
    mixin.scope = {"method": "GET"}
    mixin.state = ASGIWebsocketState.HANDSHAKE
    mixin.start_time = 0.0
    mixin.config = Config()

    start_message = {"type": "websocket.http.response.start", "headers": headers, "status": status}
    body_message = {"type": "websocket.http.response.body", "body": body}
    with pytest.raises((TypeError, ValueError)):
        await mixin.asgi_send(start_message)
        await mixin.asgi_send(body_message)
def __init__(self) -> None:
    """Populate the mock HTTP/2 websocket stream with fixed test values.

    The stream starts in the ``HANDSHAKE`` state with an empty header list in
    its scope and an empty ``sent_events`` list.
    """
    self.state = ASGIWebsocketState.HANDSHAKE
    self.start_time = 0.0
    self.sent_events: list = []

    self.scope = {"headers": []}
    self.config = Config()
    self.app = empty_framework
(ASGIWebsocketState.HANDSHAKE, "websocket.send"),
(ASGIWebsocketState.RESPONSE, "websocket.accept"),
(ASGIWebsocketState.RESPONSE, "websocket.send"),
(ASGIWebsocketState.CONNECTED, "websocket.http.response.start"),
(ASGIWebsocketState.CONNECTED, "websocket.http.response.body"),
(ASGIWebsocketState.CLOSED, "websocket.send"),
(ASGIWebsocketState.CLOSED, "websocket.http.response.start"),
(ASGIWebsocketState.CLOSED, "websocket.http.response.body"),
],
)
async def test_websocket_asgi_send_invalid_message_given_state(
    state: ASGIWebsocketState, message_type: str
) -> None:
    """The H2 stream must reject ``message_type`` when it arrives in ``state``."""
    mock_stream = MockH2WebsocketStream()
    mock_stream.state = state
    bad_message = {"type": message_type}
    with pytest.raises(UnexpectedMessage):
        await mock_stream.asgi_send(bad_message)
def __init__(
    self,
    app: ASGIFramework,
    config: Config,
    stream: trio.abc.Stream,
    *,
    upgrade_request: Optional[h11.Request] = None,
) -> None:
    """Set up a wsproto-based websocket server over a trio stream.

    :param app: the ASGI application to serve.
    :param config: server configuration; ``websocket_max_message_size``
        sizes the message buffer.
    :param stream: the underlying trio stream, handed to the base class.
    :param upgrade_request: if given, the h11 request whose headers and
        target are used to initiate the websocket upgrade immediately.
    """
    super().__init__(stream, "wsproto")
    self.app = app
    self.config = config

    # Connection lifecycle state; no scope or response has been seen yet.
    self.state = ASGIWebsocketState.HANDSHAKE
    self.scope: Optional[dict] = None
    self.response: Optional[dict] = None

    self.send_lock = trio.Lock()
    # Memory channel with a buffer of 10 items.
    send_channel, receive_channel = trio.open_memory_channel(10)
    self.app_send_channel = send_channel
    self.app_receive_channel = receive_channel

    self.buffer = WebsocketBuffer(self.config.websocket_max_message_size)
    self.connection = WSConnection(ConnectionType.SERVER)

    if upgrade_request is None:
        return
    self.connection.initiate_upgrade_connection(
        upgrade_request.headers, upgrade_request.target
    )
app: ASGIFramework,
loop: asyncio.AbstractEventLoop,
config: Config,
transport: asyncio.BaseTransport,
*,
upgrade_request: Optional[h11.Request] = None,
) -> None:
super().__init__(loop, config, transport, "wsproto")
self.stop_keep_alive_timeout()
self.app = app
self.connection = WSConnection(ConnectionType.SERVER)
self.app_queue: asyncio.Queue = asyncio.Queue()
self.response: Optional[dict] = None
self.scope: Optional[dict] = None
self.state = ASGIWebsocketState.HANDSHAKE
self.task: Optional[asyncio.Future] = None
self.buffer = WebsocketBuffer(self.config.websocket_max_message_size)
if upgrade_request is not None:
self.connection.initiate_upgrade_connection(
upgrade_request.headers, upgrade_request.target
)
self.handle_events()
async def asgi_send(self, message: dict) -> None:
    """Called by the ASGI instance to send a message.

    Dispatches on ``message["type"]`` and the current ``self.state``:

    * ``websocket.accept`` in HANDSHAKE: accept the upgrade (offering
      per-message deflate) and move to CONNECTED.
    * ``websocket.http.response.start`` in HANDSHAKE: store the message in
      ``self.response`` and log it.
    * ``websocket.http.response.body`` in HANDSHAKE or RESPONSE: forward to
      ``self._asgi_send_rejection``.
    * ``websocket.send`` in CONNECTED: send a bytes or text frame.
    * ``websocket.close`` in HANDSHAKE: reply with HTTP 403 and move to
      HTTPCLOSED. In any other state, send a close frame and move to CLOSED.

    Raises:
        TypeError: a ``websocket.send`` message has no ``bytes`` and its
            ``text`` is not a ``str``.
        UnexpectedMessage: any other type/state combination.
    """
    if message["type"] == "websocket.accept" and self.state == ASGIWebsocketState.HANDSHAKE:
        headers = build_and_validate_headers(message.get("headers", []))
        raise_if_subprotocol_present(headers)
        headers.extend(self.response_headers())
        await self.asend(
            AcceptConnection(
                extensions=[PerMessageDeflate()],
                extra_headers=headers,
                subprotocol=message.get("subprotocol"),
            )
        )
        self.state = ASGIWebsocketState.CONNECTED
        self.config.access_logger.access(
            self.scope, {"status": 101, "headers": []}, time() - self.start_time
        )
    elif (
        message["type"] == "websocket.http.response.start"
        and self.state == ASGIWebsocketState.HANDSHAKE
    ):
        # Only valid before the handshake is accepted; in any other state this
        # falls through to UnexpectedMessage. The stored response is presumably
        # read by _asgi_send_rejection -- confirm.
        self.response = message
        self.config.access_logger.access(self.scope, self.response, time() - self.start_time)
    elif message["type"] == "websocket.http.response.body" and self.state in {
        ASGIWebsocketState.HANDSHAKE,
        ASGIWebsocketState.RESPONSE,
    }:
        await self._asgi_send_rejection(message)
    elif message["type"] == "websocket.send" and self.state == ASGIWebsocketState.CONNECTED:
        event: wsproto.events.Event
        if message.get("bytes") is not None:
            event = wsproto.events.BytesMessage(data=bytes(message["bytes"]))
        elif not isinstance(message["text"], str):
            raise TypeError(f"{message['text']} should be a str")
        else:
            event = wsproto.events.TextMessage(data=message["text"])
        await self.asend(Data(self.connection.send(event)))
    elif message["type"] == "websocket.close" and self.state == ASGIWebsocketState.HANDSHAKE:
        # Closing before accepting means the handshake is refused with HTTP 403.
        await self.send_http_error(403)
        self.state = ASGIWebsocketState.HTTPCLOSED
    elif message["type"] == "websocket.close":
        # The ASGI spec makes "code" optional, defaulting to 1000 (normal closure).
        code = int(message.get("code", 1000))
        data = self.connection.send(wsproto.events.CloseConnection(code=code))
        await self.asend(Data(data))
        self.state = ASGIWebsocketState.CLOSED
    else:
        raise UnexpectedMessage(self.state, message["type"])
headers.append((b"sec-websocket-extensions", accepts))
await self.asend(Response(headers))
self.connection = wsproto.connection.Connection(
wsproto.connection.ConnectionType.SERVER, supported_extensions
)
self.config.access_logger.access(
self.scope, {"status": 200, "headers": []}, time() - self.start_time
)
elif (
message["type"] == "websocket.http.response.start"
and self.state == ASGIWebsocketState.HANDSHAKE
):
self.response = message
self.config.access_logger.access(self.scope, self.response, time() - self.start_time)
elif message["type"] == "websocket.http.response.body" and self.state in {
ASGIWebsocketState.HANDSHAKE,
ASGIWebsocketState.RESPONSE,
}:
await self._asgi_send_rejection(message)
elif message["type"] == "websocket.send" and self.state == ASGIWebsocketState.CONNECTED:
event: wsproto.events.Event
if message.get("bytes") is not None:
event = wsproto.events.BytesMessage(data=bytes(message["bytes"]))
elif not isinstance(message["text"], str):
raise TypeError(f"{message['text']} should be a str")
else:
event = wsproto.events.TextMessage(data=message["text"])
await self.asend(Data(self.connection.send(event)))
elif message["type"] == "websocket.close" and self.state == ASGIWebsocketState.HANDSHAKE:
await self.send_http_error(403)
self.state = ASGIWebsocketState.HTTPCLOSED
elif message["type"] == "websocket.close":