Python 后端（FastAPI + FastMCP + ...）的初始版本号设定为 0.1.0，这是 uv 在 pyproject.toml 里自动设置的，我觉得有道理。
209 lines
6.6 KiB
Python
209 lines
6.6 KiB
Python
import asyncio
import time
import traceback
from contextlib import asynccontextmanager, suppress
from pathlib import Path
from typing import AsyncGenerator, Callable

from fastapi import FastAPI, Request, Response, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from fastmcp.utilities.lifespan import combine_lifespans
from watchfiles import awatch

from njupt_api.baselib import (
    LogRecord,
    config,
    log_buffer,
    log_record_serialize,
    logger,
)
from njupt_api.zhengfang.zhengfang import ZhengFang
from router import __version__
from router.admin_router import admin_router
from router.api_router import api_router
from router.enhance.lib import ReturnDto
from router.enhance.model import create_db_and_tables
from router.mcp_router import mcp_app
from router.webui_router import webui_router
|
|
|
|
# Runtime data directory: "<current working directory>/data" (resolved once at import time).
DATA_DIR = Path.cwd().joinpath("data")
|
|
|
|
|
|
async def _reload_config() -> None:
    """Call ``config.load_json()``, logging (not raising) any failure so the watcher survives."""
    try:
        await config.load_json()
    except Exception as exc:
        # This runs inside a background task that nobody awaits: letting the error
        # escape would end the task silently, and no later edit to config.json would
        # ever be reloaded. Log it and wait for the next change instead.
        logger.error(f"加载配置文件失败,等待下一次修改 : {exc!r}")


async def toml_watcher() -> None:
    """Watch ``DATA_DIR / "config.json"`` and reload the config whenever it changes.

    Despite the name, the watched file is the JSON config. The config is loaded once
    up front and then reloaded via ``config.load_json()`` for every batch of changes
    reported by ``watchfiles.awatch``. A failed load is logged and watching
    continues. Runs until the surrounding task is cancelled.
    """
    await _reload_config()
    async for change in awatch(DATA_DIR / "config.json"):
        logger.info(f"配置文件更新,重新加载 | {change=}")
        await _reload_config()
|
|
|
|
|
|
@asynccontextmanager
async def life_span(_: FastAPI) -> AsyncGenerator[None, None]:
    """Application lifespan: set up the database and the config watcher, tear down on exit.

    Startup creates the SQLite tables and starts ``toml_watcher`` as a background
    task. On shutdown the task is cancelled and awaited, so it has really stopped
    before "finished" is logged. If it had died with an error, that error is logged
    instead of being lost.
    """
    logger.info("初始化 SQLite 数据库中...")
    create_db_and_tables()
    logger.info("启动配置文件监听任务...")
    watcher_task = asyncio.create_task(toml_watcher(), name="toml_watcher")
    logger.success("🌟 NJUPT API Suan 已经启动。")
    try:
        yield
    finally:
        logger.info("🌙 NJUPT API Suan 正在关闭。")
        watcher_task.cancel()
        # asyncio.wait never raises the task's own exception; it just waits for completion.
        await asyncio.wait({watcher_task})
        # A task that finished before cancel() may hold an error nobody retrieved.
        if not watcher_task.cancelled() and watcher_task.exception() is not None:
            logger.error(f"配置文件监听任务异常退出 : {watcher_task.exception()!r}")
        logger.info("配置文件监听任务已结束。")
|
|
|
|
|
|
@asynccontextmanager
async def jwxt(username: str, password: str) -> AsyncGenerator[ZhengFang, None]:
    """Yield a started, logged-in ``ZhengFang`` client.

    Once ``start()`` has succeeded, ``close()`` is always called: after the ``with``
    body finishes, when the body raises, or when ``login()`` itself fails.

    Args:
        username: Login account passed to ``ZhengFang.login``.
        password: Login password passed to ``ZhengFang.login``.

    Yields:
        The ``ZhengFang`` client after ``start()`` and ``login()``.
    """
    zf = ZhengFang()
    await zf.start()
    try:
        await zf.login(username, password)
        yield zf
    finally:
        # Previously close() sat after an unprotected yield and was skipped on any error.
        await zf.close()
|
|
|
|
|
|
# The FastAPI application. Its lifespan runs this module's life_span together with
# the MCP sub-app's own lifespan.
app = FastAPI(
    lifespan=combine_lifespans(life_span, mcp_app.lifespan),
)
|
|
|
|
|
|
@app.middleware("http")
async def log_requests(request: Request, call_next: Callable) -> Response:
    """Log each HTTP request and its outcome, and tag the response with its duration.

    Requests to ``/mcp`` or ``/mcp/`` are forwarded without logging. Requests under
    ``/assets/`` are also forwarded unlogged unless the ``log.log_assets_request``
    config flag is set. Otherwise the request is logged on arrival (with headers,
    params and body when ``log.log_api_request_details`` is set). Status and elapsed
    time are logged on completion, and an ``X-Process-Time`` header (milliseconds)
    is added.

    Args:
        request: The incoming HTTP request.
        call_next: Continuation that passes the request on to the rest of the app.

    Returns:
        The downstream response.

    Raises:
        Exception: Anything raised downstream is logged and re-raised unchanged.
    """
    path = request.url.path

    # MCP traffic is never logged.
    if path in ("/mcp", "/mcp/"):
        return await call_next(request)

    # Static asset requests are logged only when explicitly enabled.
    if not config.get("log", "log_assets_request", False) and path.startswith("/assets/"):
        return await call_next(request)

    # perf_counter() is monotonic, so wall-clock adjustments cannot distort durations.
    start_time = time.perf_counter()

    client_host = request.client.host if request.client else "unknown"
    method = request.method
    # Read once so the request-side and response-side detail logs agree.
    log_details = config.get("log", "log_api_request_details", False)

    logger.debug(f"访问 [{client_host}] {method} {path}")
    if log_details:
        # NOTE(review): headers and body can contain credentials or cookies, and these
        # log records may also be streamed to the WebUI via /ws/logs; keep this off in production.
        logger.debug(f" - {request.headers=}")
        logger.debug(f" - {request.path_params=}")
        logger.debug(f" - {request.query_params=}")
        logger.debug(f" - request.body={await request.body()}")

    try:
        response = await call_next(request)
    except Exception as exc:
        process_time = (time.perf_counter() - start_time) * 1000
        logger.error(
            f"错误 [{client_host}] {method} {path} | Error: {exc} | Time: {process_time:.2f}ms",
        )
        raise

    process_time = (time.perf_counter() - start_time) * 1000

    if response.status_code < 400:
        logger.info(
            f"成功 [{client_host}] {method} {path} | Status: {response.status_code} | Time: {process_time:.2f}ms",
        )
        if log_details:
            logger.debug(f" - {response}")
    else:
        logger.warning(
            f"警告 [{client_host}] {method} {path} | Status: {response.status_code} | Time: {process_time:.2f}ms",
        )

    # Expose the server-side processing time (milliseconds) to the client.
    response.headers["X-Process-Time"] = str(process_time)
    return response
|
|
|
|
|
|
@app.websocket("/ws/logs")
async def ws_logs(websocket: WebSocket) -> None:
    """Push Suan API log records from ``log_buffer`` to the WebUI over a websocket.

    Every 0.5 s, records whose ``id`` is at least the last sent position are sent as
    ``{"logs": [...], "latest_id": LogRecord.log_counter}``. The loop stops when
    the client disconnects, a send takes longer than 2 s, or an unexpected error
    occurs. The socket is closed on the way out.

    Raises:
        asyncio.CancelledError: Re-raised after logging so cancellation is not swallowed.
    """  # noqa: DOC501
    await websocket.accept()
    logger.debug("日志 websocket 建立连接,日志将被推送到 WebUI。")
    last_sent_id = 0

    try:
        while True:
            # Take the counter snapshot before collecting records (no await in between).
            # Advancing to this snapshot after the send, rather than re-reading the counter
            # afterwards, means records logged while send_json is awaiting are sent on the
            # next pass instead of being skipped.
            snapshot_id = LogRecord.log_counter
            new_logs = [log for log in log_buffer if log.id >= last_sent_id]
            if new_logs:
                try:
                    await asyncio.wait_for(
                        websocket.send_json(
                            {
                                "logs": [log_record_serialize(log) for log in new_logs],  # new records
                                "latest_id": LogRecord.log_counter,
                            },
                        ),
                        timeout=2.0,
                    )
                except asyncio.TimeoutError:
                    logger.debug("日志 websocket 发送超时,已自动退出。")
                    break
                last_sent_id = snapshot_id

            await asyncio.sleep(0.5)
    except WebSocketDisconnect:
        logger.debug("日志 websocket 断开连接。")
    except asyncio.CancelledError:
        logger.debug("日志 websocket 任务取消。")
        # Propagate cancellation; swallowing it would hide the cancel request from the caller.
        raise
    except Exception as exc:
        logger.error(f"向 WebUI 传递日志时遇到异常 : {exc}")
    finally:
        # The peer may already be gone; closing is best-effort.
        with suppress(Exception):
            await websocket.close()
|
|
|
|
|
|
@app.get("/version")
async def get_version() -> ReturnDto:
    """Return the backend version string (``router.__version__``) wrapped in a ``ReturnDto``."""
    return ReturnDto(result={"version": __version__}, success=True)
|
|
|
|
|
|
# Register routers, sub-apps and static files. Order is kept deliberately:
# routes and mounts are matched in the order they are added.
app.include_router(router=api_router)
app.include_router(router=admin_router)
app.include_router(router=webui_router)

# MCP sub-application.
app.mount(path="/mcp", app=mcp_app)

# Built WebUI assets, served from ./webui/dist/assets relative to the working directory.
app.mount(
    path="/assets",
    app=StaticFiles(directory=Path.cwd().joinpath("webui", "dist", "assets")),
    name="assets",
)

# NOTE(review): wildcard origins combined with allow_credentials=True lets any site make
# credentialed cross-origin requests (Starlette reflects the caller's Origin in that case).
# Consider listing the WebUI origin(s) explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
    # Important: browsers hide non-safelisted response headers from scripts unless they
    # are exposed explicitly, so this header must be listed here.
    expose_headers=["mcp-session-id"],
)
|
|
|
|
|
|
@app.exception_handler(Exception)
def general_exception_handler(_: Request, exc: Exception) -> JSONResponse:
    """Log an unhandled exception with its traceback and return a generic 500 response.

    Args:
        _: The request that failed (unused).
        exc: The unhandled exception.

    Returns:
        A 500 ``JSONResponse`` with body ``{"code": 500, "message": ..., "data": None}``.
        No exception details are sent to the client.
    """
    # Build the traceback from the exception object itself. Starlette calls this sync
    # handler outside the `except` block that caught `exc` (via its threadpool), where
    # traceback.format_exc() sees no active exception and yields only "NoneType: None".
    formatted_tb = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
    logger.error(
        f"未捕获的异常!这可能表示某些路由、端点已经不可用!\n{exc!s}\n{formatted_tb}",
    )

    return JSONResponse(
        status_code=500,
        content={
            "code": 500,
            "message": "服务器内部错误",
            "data": None,  # never expose error details to clients in production
        },
    )
|