"""``request_id_middleware``:请求级 ID 入口生成/继承 + 响应头回写。 覆盖: - 没传 X-Request-Id 时 middleware 生成新 ID 并写回响应头 - 传了 X-Request-Id 时被尊重并回写 - 路由处理器内可以从 contextvars 读到当前 request_id - 异常路径下 contextvars 也能被正确 reset(不会泄漏到下一请求) """ from __future__ import annotations import pytest from fastapi import FastAPI, Request from httpx import AsyncClient, ASGITransport from kilostar.utils import request_context as rc from kilostar.utils.request_context import ( bind_request_id, new_request_id, reset_request_id, ) def _build_app() -> FastAPI: app = FastAPI() @app.middleware("http") async def request_id_middleware(request: Request, call_next): incoming = request.headers.get("X-Request-Id", "").strip() request_id = incoming or new_request_id() token = bind_request_id(request_id) try: response = await call_next(request) finally: reset_request_id(token) response.headers["X-Request-Id"] = request_id return response @app.get("/whoami") async def whoami(): return {"request_id": rc.get_request_id()} return app @pytest.mark.asyncio async def test_generates_request_id_when_header_absent(): transport = ASGITransport(app=_build_app(), raise_app_exceptions=False) async with AsyncClient(transport=transport, base_url="http://test") as client: resp = await client.get("/whoami") assert resp.status_code == 200 rid = resp.headers.get("X-Request-Id") assert rid and rid.startswith("req-") assert resp.json()["request_id"] == rid @pytest.mark.asyncio async def test_inherits_request_id_from_header(): transport = ASGITransport(app=_build_app(), raise_app_exceptions=False) async with AsyncClient(transport=transport, base_url="http://test") as client: resp = await client.get( "/whoami", headers={"X-Request-Id": "client-supplied-123"} ) assert resp.headers.get("X-Request-Id") == "client-supplied-123" assert resp.json()["request_id"] == "client-supplied-123" @pytest.mark.asyncio async def test_request_id_reset_after_request(): """两次请求的 request_id 不应互相串味(contextvars 在 finally 里被 reset)。""" transport = ASGITransport(app=_build_app(), raise_app_exceptions=False) async with AsyncClient(transport=transport, base_url="http://test") as client: r1 = await client.get("/whoami") r2 = await client.get("/whoami") assert r1.headers["X-Request-Id"] != r2.headers["X-Request-Id"] def test_logger_picks_up_contextvars(): """logger 切面:contextvars 中的值会被 patcher 注入到 record.extra。""" from kilostar.utils.logger import get_logger captured: list[dict] = [] def _format(record): # loguru 的 format 函数收到 record dict 本身,可以直接读 extra captured.append(dict(record["extra"])) return "{message}\n" from loguru import logger as _global handler_id = _global.add(lambda _msg: None, format=_format, level="DEBUG") try: with rc.trace_id_scope("trace-from-ctx"), rc.request_id_scope("req-from-ctx"): log = get_logger("test_actor") log.info("hello") finally: _global.remove(handler_id) assert captured, "应至少捕获一条日志" # 找到我们的那条(避免被并发中其他 logger 干扰) matched = [c for c in captured if c.get("actor_name") == "test_actor"] assert matched, f"未捕获到来自 test_actor 的日志,全部 captured={captured}" last = matched[-1] assert last.get("trace_id") == "trace-from-ctx" assert last.get("request_id") == "req-from-ctx"