"""``kilostar.utils.request_context``:双层 ID 的 contextvars 与 logger 集成。 覆盖: - ``request_id`` / ``trace_id`` 默认空、bind 后可读、reset 后还原 - ``request_id_scope`` / ``trace_id_scope`` 上下文管理器 - ``snapshot`` / ``apply_snapshot`` 跨边界透传 - logger 切面:``contextvars`` 中的值会自动写入 ``record["extra"]`` """ from __future__ import annotations import asyncio import pytest from kilostar.utils import request_context as rc def test_default_values_are_empty(): assert rc.get_request_id() == "" assert rc.get_trace_id() == "" def test_bind_and_reset_request_id(): token = rc.bind_request_id("req-abc") try: assert rc.get_request_id() == "req-abc" finally: rc.reset_request_id(token) assert rc.get_request_id() == "" def test_bind_and_reset_trace_id(): token = rc.bind_trace_id("trace-xyz") try: assert rc.get_trace_id() == "trace-xyz" finally: rc.reset_trace_id(token) assert rc.get_trace_id() == "" def test_request_id_scope(): with rc.request_id_scope("req-1") as rid: assert rid == "req-1" assert rc.get_request_id() == "req-1" assert rc.get_request_id() == "" def test_trace_id_scope_nested(): with rc.trace_id_scope("outer"): assert rc.get_trace_id() == "outer" with rc.trace_id_scope("inner"): assert rc.get_trace_id() == "inner" assert rc.get_trace_id() == "outer" assert rc.get_trace_id() == "" def test_snapshot_returns_current_ids(): with rc.request_id_scope("r1"), rc.trace_id_scope("t1"): snap = rc.snapshot() assert snap == {"request_id": "r1", "trace_id": "t1"} def test_apply_snapshot_restores_after_exit(): snap = {"request_id": "r2", "trace_id": "t2"} with rc.apply_snapshot(snap): assert rc.get_request_id() == "r2" assert rc.get_trace_id() == "t2" assert rc.get_request_id() == "" assert rc.get_trace_id() == "" def test_apply_snapshot_handles_none(): """传 None 应是 no-op,不报错。""" with rc.apply_snapshot(None): assert rc.get_request_id() == "" @pytest.mark.asyncio async def test_contextvars_isolated_between_concurrent_tasks(): """两个并发的 asyncio task 各自的 trace_id 不应互相串味。""" results: dict[str, str] = {} async def worker(name: str, trace_id: str) -> None: with rc.trace_id_scope(trace_id): await asyncio.sleep(0) results[name] = rc.get_trace_id() await asyncio.gather( worker("a", "trace-a"), worker("b", "trace-b"), ) assert results == {"a": "trace-a", "b": "trace-b"} def test_new_request_id_format(): rid = rc.new_request_id() assert rid.startswith("req-") assert len(rid) > len("req-") def test_new_request_id_custom_prefix(): rid = rc.new_request_id("ws") assert rid.startswith("ws-")