Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_app_middleware_argument():
    """Build an app with the ``middleware=`` argument and check the response header.

    A GET on ``/`` must come back with ``Custom-Header: Example``.
    """

    def homepage(request):
        return PlainTextResponse("Homepage")

    routes = [Route("/", homepage)]
    middleware = [Middleware(CustomMiddleware)]
    app = Starlette(routes=routes, middleware=middleware)

    response = TestClient(app).get("/")
    assert response.headers["Custom-Header"] == "Example"
def test_middleware_repr():
    """``repr()`` of a ``Middleware`` entry must read ``Middleware(CustomMiddleware)``."""
    entry = Middleware(CustomMiddleware)
    expected = "Middleware(CustomMiddleware)"
    assert repr(entry) == expected
async def test_duplicate_caching() -> None:
    """Requesting a route wrapped by two ``CacheMiddleware`` layers must raise.

    One ``CacheMiddleware`` wraps the endpoint directly (with ``special_cache``)
    and a second one is installed application-wide (with ``cache``); the GET is
    expected to raise ``DuplicateCaching``.
    """
    cache = Cache("locmem://default")
    special_cache = Cache("locmem://special")

    class DuplicateCache(HTTPEndpoint):
        pass

    # Endpoint-level caching layer.
    cached_endpoint = CacheMiddleware(DuplicateCache, cache=special_cache)
    duplicate_route = Route("/duplicate_cache", cached_endpoint)

    # Application-level caching layer on top of the endpoint-level one.
    app_middleware = [Middleware(CacheMiddleware, cache=cache)]
    app = Starlette(routes=[duplicate_route], middleware=app_middleware)

    client = httpx.AsyncClient(app=app, base_url="http://testserver")

    # Entered in the same order as the original single ``async with`` statement.
    async with cache:
        async with special_cache:
            async with client:
                with pytest.raises(DuplicateCaching):
                    await client.get("/duplicate_cache")
# NOTE(review): this snippet is an excerpt. Its enclosing ``def`` is not shown
# and the final ``Middleware(`` call below is cut off before it is closed.
# Split the registered handlers: the one keyed by ``500`` or ``Exception`` is
# kept aside as ``error_handler``; every other entry is copied into
# ``exception_handlers``.
for key, value in self.exception_handlers.items():
if key in (500, Exception):
error_handler = value
else:
exception_handlers[key] = value
# Middleware entries are collected in this list in the order they are appended.
middlewares = []
# Host checking is only added when ALLOWED_HOSTS is something other than ["*"].
if config.ALLOWED_HOSTS != ["*"]:
middlewares.append(
Middleware(TrustedHostMiddleware, allowed_hosts=config.ALLOWED_HOSTS)
)
# ServerErrorMiddleware receives the handler set aside in the loop above.
middlewares.append(
Middleware(ServerErrorMiddleware, handler=error_handler, debug=config.DEBUG)
)
# CORS is only added when an origin regex is set or at least one allowed
# origin is configured.
if (
config.CORS_ALLOW_ORIGIN_REGEX is not None
or len(config.CORS_ALLOW_ORIGINS) > 0
):
middlewares.append(
Middleware(
CORSMiddleware,
allow_origins=config.CORS_ALLOW_ORIGINS,
allow_methods=config.CORS_ALLOW_METHODS,
allow_headers=config.CORS_ALLOW_HEADERS,
allow_credentials=config.CORS_ALLOW_CREDENTIALS,
allow_origin_regex=config.CORS_ALLOW_ORIGIN_REGEX,
expose_headers=config.CORS_EXPOSE_HEADERS,
max_age=config.CORS_MAX_AGE,
# Assemble the list of middleware entries from ``self.config`` and
# ``self.exception_handlers``.
# NOTE(review): this snippet is truncated inside the CORS ``Middleware(`` call;
# the rest of the function (including its return) is not visible here.
def build_app(self) -> ASGIApp:
config = self.config
error_handler = None
exception_handlers = {}
# The handler keyed by ``500`` or ``Exception`` is kept aside as
# ``error_handler``; every other entry is copied into ``exception_handlers``.
for key, value in self.exception_handlers.items():
if key in (500, Exception):
error_handler = value
else:
exception_handlers[key] = value
middlewares = []
# Host checking is only added when ALLOWED_HOSTS is something other than ["*"].
if config.ALLOWED_HOSTS != ["*"]:
middlewares.append(
Middleware(TrustedHostMiddleware, allowed_hosts=config.ALLOWED_HOSTS)
)
# ServerErrorMiddleware receives the handler set aside above (``None`` when
# no handler is registered under ``500`` or ``Exception``).
middlewares.append(
Middleware(ServerErrorMiddleware, handler=error_handler, debug=config.DEBUG)
)
# CORS is only added when an origin regex is set or at least one allowed
# origin is configured.
if (
config.CORS_ALLOW_ORIGIN_REGEX is not None
or len(config.CORS_ALLOW_ORIGINS) > 0
):
middlewares.append(
Middleware(
CORSMiddleware,
allow_origins=config.CORS_ALLOW_ORIGINS,
allow_methods=config.CORS_ALLOW_METHODS,
allow_headers=config.CORS_ALLOW_HEADERS,
def build_middleware_stack(self) -> ASGIApp:
    """Wrap ``self.router`` in every middleware layer and return the outermost app.

    Layer order, outermost first: ``ServerErrorMiddleware``, then the entries
    of ``self.user_middleware`` in list order, then ``ExceptionMiddleware``
    directly around the router.
    """
    debug = self.debug

    # The handler registered under ``500`` or ``Exception`` goes to the
    # server-error layer; all remaining handlers go to the exception layer.
    error_handler = None
    exception_handlers = {}
    for handler_key, handler in self.exception_handlers.items():
        if handler_key not in (500, Exception):
            exception_handlers[handler_key] = handler
            continue
        error_handler = handler

    outer_layers = [
        Middleware(ServerErrorMiddleware, handler=error_handler, debug=debug)
    ]
    user_layers = self.user_middleware
    inner_layers = [
        Middleware(ExceptionMiddleware, handlers=exception_handlers, debug=debug)
    ]
    layers = outer_layers + user_layers + inner_layers

    # Build from the inside out so the first entry ends up outermost.
    wrapped = self.router
    for layer in reversed(layers):
        factory, kwargs = layer
        wrapped = factory(app=wrapped, **kwargs)
    return wrapped
def build_middleware_stack(self) -> ASGIApp:
    """Return ``self.router`` wrapped in the full middleware stack.

    ``ServerErrorMiddleware`` is the outermost layer, ``ExceptionMiddleware``
    is the innermost one, and the entries of ``self.user_middleware`` sit
    between them in list order.
    """
    debug = self.debug
    registered = self.exception_handlers

    # Handler for ``500`` / ``Exception`` (the last matching entry wins).
    error_handler = None
    for key, value in registered.items():
        if key in (500, Exception):
            error_handler = value

    # Every other handler is served by ExceptionMiddleware.
    exception_handlers = {
        key: value
        for key, value in registered.items()
        if key not in (500, Exception)
    }

    server_error_entry = Middleware(
        ServerErrorMiddleware, handler=error_handler, debug=debug
    )
    middleware = [server_error_entry] + self.user_middleware
    middleware = middleware + [
        Middleware(ExceptionMiddleware, handlers=exception_handlers, debug=debug)
    ]

    # Apply the layers last-to-first so the first entry wraps everything else.
    app = self.router
    for cls, options in middleware[::-1]:
        app = cls(app=app, **options)
    return app
# NOTE(review): this snippet is the tail of a function whose beginning is not
# shown; ``middlewares``, ``config`` and ``exception_handlers`` are defined
# in the part that is not visible here.
# Append the CORS layer, configured entirely from the CORS_* settings.
middlewares.append(
Middleware(
CORSMiddleware,
allow_origins=config.CORS_ALLOW_ORIGINS,
allow_methods=config.CORS_ALLOW_METHODS,
allow_headers=config.CORS_ALLOW_HEADERS,
allow_credentials=config.CORS_ALLOW_CREDENTIALS,
allow_origin_regex=config.CORS_ALLOW_ORIGIN_REGEX,
expose_headers=config.CORS_EXPOSE_HEADERS,
max_age=config.CORS_MAX_AGE,
)
)
# User-supplied middleware follows whatever was appended before it.
middlewares += self.user_middlewares
# ExceptionMiddleware is appended last, so it wraps ``self.app`` directly.
middlewares.append(Middleware(ExceptionMiddleware, handlers=exception_handlers))
app = self.app
# Wrap from the last entry to the first so the first entry ends up outermost.
for cls, options in reversed(middlewares):
app = cls(app=app, **options)
return app