diff --git a/main.py b/main.py new file mode 100644 index 0000000..e95ccfa --- /dev/null +++ b/main.py @@ -0,0 +1,86 @@ +from pathlib import Path +from secrets import token_urlsafe + +from njupt_api.baselib import config, logger +from router import __version__ + +DATA_DIR = Path.cwd() / "data" +TEMP_DIR = Path.cwd() / "temp" + + +if __name__ == "__main__": + try: + with open(file=Path.cwd() / "njupt_api" / "art.txt", mode="r", encoding="utf-8") as f: + print(f.read().format(__version__)) # noqa:T201 + except FileNotFoundError: + pass + + import uvicorn + + logger.success("Ciallo~(∠·ω< )⌒★") + + # 创建需要的工作目录 + if DATA_DIR.exists() and DATA_DIR.is_dir(): + logger.debug(f"工作目录 {DATA_DIR=!s} 已存在。") + else: + DATA_DIR.mkdir(parents=True, exist_ok=True) + logger.debug(f"工作目录 {DATA_DIR=!s} 已创建。") + + if TEMP_DIR.exists() and TEMP_DIR.is_dir(): + logger.debug(f"工作目录 {TEMP_DIR=!s} 已存在。") + # 清空 temp 工作目录中的已有文件 + c = 0 + for item in TEMP_DIR.iterdir(): + item.unlink() + c += 1 + logger.debug(f"清理了 temp 工作目录中的 {c} 个已有文件。") + else: + TEMP_DIR.mkdir(parents=True, exist_ok=True) + logger.debug(f"工作目录 {TEMP_DIR=!s} 已创建。") + + # 如果没有 toml 配置文件就创建 + try: + config.sync_load_json() + except FileNotFoundError: + config.init_config() + config.sync_create_json() + + lines = [ + "", + "🌐 WebUI 管理面板将运行在 /webui 端点下,登录需要令牌。", + "", + "============================================================", + "🔐 管理后端令牌", + "============================================================", + "", + ] + try: + with open(file=DATA_DIR / "token.txt", mode="r", encoding="utf-8") as f: + token = f.readline().strip() + lines.insert(-2, f"🔐 使用已经存在的管理后端令牌 | {token}") + except FileNotFoundError: + token = token_urlsafe(32) + with open(file=DATA_DIR / "token.txt", mode="w", encoding="utf-8") as f: + f.write(token) + lines.insert(-2, f"🔐 新的管理后端令牌已生成 | {token}") + lines.insert(-2, "🔐 你需要此令牌以登录管理员面板并可视化地配置 NJUPT Suan API。") + logger.info("\n".join(lines)) + + logger.info("🥭 准备 uvicorn run ...") + host = config.get("system", "host", "0.0.0.0") + port = config.get("system", "port", 8000) + reload = config.get("system", "reload", True) + logger.debug(f"启动参数 - {host=} | {port=} | {reload=}") + logger.debug("这些参数无法自动热重载。如果你修改了他们,请 Ctrl + C 关闭并重新启动 Suan API。") + uvicorn.run( + "server:app", + host=host, + port=port, + reload=reload, + reload_dirs=["njupt_api", "router"], + access_log=False, + log_level="critical", + timeout_graceful_shutdown=2, + ) + logger.debug("退出时未清理 temp 工作目录,将在 Suan API 下次启动时清理。") + logger.info("🥭 uvicorn run 已结束。") diff --git a/njupt_api/__init__.py b/njupt_api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/njupt_api/art.txt b/njupt_api/art.txt new file mode 100644 index 0000000..21a8612 --- /dev/null +++ b/njupt_api/art.txt @@ -0,0 +1,10 @@ + + ________ ___ ___ ___ ________ _________ ________ ___ ___ ________ ________ ________ ________ ___ +|\ ___ \ |\ \|\ \|\ \|\ __ \|\___ ___\ |\ ____\|\ \|\ \|\ __ \|\ ___ \ |\ __ \|\ __ \|\ \ +\ \ \\ \ \ \ \ \ \ \\\ \ \ \|\ \|___ \ \_| \ \ \___|\ \ \\\ \ \ \|\ \ \ \\ \ \ \ \ \|\ \ \ \|\ \ \ \ + \ \ \\ \ \ __ \ \ \ \ \\\ \ \ ____\ \ \ \ \ \_____ \ \ \\\ \ \ __ \ \ \\ \ \ \ \ __ \ \ ____\ \ \ + \ \ \\ \ \|\ \\_\ \ \ \\\ \ \ \___| \ \ \ \|____|\ \ \ \\\ \ \ \ \ \ \ \\ \ \ \ \ \ \ \ \ \___|\ \ \ + \ \__\\ \__\ \________\ \_______\ \__\ \ \__\ ____\_\ \ \_______\ \__\ \__\ \__\\ \__\ \ \__\ \__\ \__\ \ \__\ + \|__| \|__|\|________|\|_______|\|__| \|__| |\_________\|_______|\|__|\|__|\|__| \|__| \|__|\|__|\|__| \|__| + \|_________| + => NJUPT Suan API (v.{}) | Made with 💗 by MangoFanFanw | Powered by FastAPI, FastMCP, Vue and more ~ diff --git a/njupt_api/baselib/__init__.py b/njupt_api/baselib/__init__.py new file mode 100644 index 0000000..e9d2aec --- /dev/null +++ b/njupt_api/baselib/__init__.py @@ -0,0 +1,14 @@ +from .config import config +from .logger import LogRecord, log_buffer, log_record_serialize, logger +from .mcploggingmiddleware import LoggingMiddleware +from .playcontextmanager import PlayContextManager + +__all__ = [ + config, + LogRecord, + log_buffer, + log_record_serialize, + logger, + LoggingMiddleware, + PlayContextManager, +] diff --git a/njupt_api/baselib/config.py b/njupt_api/baselib/config.py new file mode 100644 index 0000000..5b5daa2 --- /dev/null +++ b/njupt_api/baselib/config.py @@ -0,0 +1,115 @@ +from json import dumps, loads +from pathlib import Path +from typing import TypeVar + +import aiofiles + +from .logger import logger + +CONFIG_PATH = Path.cwd() / "data" / "config.json" + + +T = TypeVar("T") + + +class Config: + def __init__(self) -> None: + self._doc = {} + + async def load_json(self) -> None: + """ + 从 Toml 配置文件中读取配置。 + """ + logger.debug("异步读取配置文件。") + async with aiofiles.open(file=CONFIG_PATH, mode="r") as f: + self._doc = loads(await f.read()) + + def sync_load_json(self) -> None: + """ + 同步读取配置文件,仅限于 main.py 中启动时。 + + Raises: + FileNotFoundError: 配置文件不存在。 + """ + logger.debug("同步读取配置文件。") + try: + with open(file=CONFIG_PATH, mode="r") as f: + self._doc = loads(f.read()) + except FileNotFoundError: + logger.warning("FileNotFoundError - 配置文件不存在。") + raise + + def sync_create_json(self) -> None: + """ + 同步创建配置文件。 + """ + logger.debug("同步创建配置文件。") + with open(file=CONFIG_PATH, mode="w") as f: + f.write(dumps(self._doc)) + + def init_config(self) -> None: + """ + 重新初始化 Toml 配置文件。这会重置所有配置。 + """ + logger.warning("初始化配置文件,这会重置所有配置。") + self._doc.clear() + doc_system = {} + doc_schedule = {} + doc_log = {} + + doc_system["host"] = "0.0.0.0" + doc_system["port"] = 8000 + doc_system["reload"] = True + + doc_schedule["jwxt_login_method"] = "sso" + doc_schedule["semester_start_date"] = "2026-03-02" + doc_schedule["schedule_title_template"] = "芒果酸的第 {title} 周课程表" + doc_schedule["schedule_subtitle_template"] = "我也要上吗?" + + doc_log["log_api_request_details"] = False + doc_log["log_mcp_request_details"] = False + doc_log["log_assets_request"] = False + + self._doc["system"] = doc_system + self._doc["schedule"] = doc_schedule + self._doc["log"] = doc_log + + async def save_json(self) -> None: + """ + 异步保存 Toml 配置文件。 + """ + logger.debug("异步保存配置文件。") + async with aiofiles.open(file=CONFIG_PATH, mode="w") as f: + await f.write(dumps(self._doc, indent=4)) + + def get(self, group: str, option: str, default: T) -> T: + """ + 获取配置项的值。 + + Args: + group: Table + option: Key + default: 默认值 + + Returns: + Any,与 default 参数类型相同。 + """ + try: + return self._doc.get(group).get(option) + except AttributeError: + return default + + def to_dict(self) -> dict: + return self._doc + + def from_dict(self, data: dict) -> None: + self._doc.clear() + for key, value in data.items(): + if isinstance(value, dict): + t_table = {} + for k, v in value.items(): + t_table[k] = v + self._doc[key] = t_table + + +config = Config() diff --git a/njupt_api/baselib/logger.py b/njupt_api/baselib/logger.py new file mode 100644 index 0000000..c41b665 --- /dev/null +++ b/njupt_api/baselib/logger.py @@ -0,0 +1,42 @@ +import sys +from collections import deque + +from loguru import logger + +logger.remove() +logger.add( + sys.stdout, + level="DEBUG", + colorize=True, +) +logger.add("data/app.log", rotation="10 MB", retention="7 days") # 文件日志 + +log_buffer = deque(maxlen=1000) + + +class LogRecord: + log_counter = 0 + + def __init__(self, message: str) -> None: + self.id = LogRecord.log_counter + self.message = message + + LogRecord.log_counter += 1 + + +def log_record_serialize(record: LogRecord) -> dict: + return { + "id": record.id, + "message": record.message, + } + + +def memory_sink(message: str) -> None: + """向自定义缓冲区写入日志,供 WebUI 获取 + :param message: 'loguru._handler.Message' + """ + log_entry = LogRecord(message=message) + log_buffer.append(log_entry) + + +logger.add(sink=memory_sink, level="DEBUG", colorize=True) diff --git a/njupt_api/baselib/mcploggingmiddleware.py b/njupt_api/baselib/mcploggingmiddleware.py new file mode 100644 index 0000000..ce58c06 --- /dev/null +++ b/njupt_api/baselib/mcploggingmiddleware.py @@ -0,0 +1,36 @@ +import time + +from fastmcp.server.middleware import CallNext, MiddlewareContext +from fastmcp.server.middleware.middleware import Middleware +from fastmcp.tools import ToolResult +from mcp import types as mt + +from . import config +from .logger import logger + + +class LoggingMiddleware(Middleware): + async def on_call_tool( + self, + context: MiddlewareContext[mt.CallToolRequestParams], + call_next: CallNext[mt.CallToolRequestParams, ToolResult], + ) -> ToolResult: + tool_name = context.message.name + args = context.message.arguments + + start_time = time.time() + logger.debug(f"MCP → 调用工具: {tool_name}") + if config.get("log", "log_mcp_request_details", False): + logger.debug(f"调用参数 - {args=}") + + try: + result = await call_next(context) + elapsed = time.time() - start_time + logger.info(f"MCP ← 工具 {tool_name} 完成, 耗时: {elapsed:.3f}s") + return result + except Exception as e: + elapsed = time.time() - start_time + logger.error( + f"MCP ✗ 工具 {tool_name} 失败, 耗时: {elapsed:.3f}s, 错误: {e}", + ) + raise diff --git a/njupt_api/baselib/playcontextmanager.py b/njupt_api/baselib/playcontextmanager.py new file mode 100644 index 0000000..a5a4e15 --- /dev/null +++ b/njupt_api/baselib/playcontextmanager.py @@ -0,0 +1,56 @@ +from playwright.async_api import ( + Browser, + BrowserContext, + Page, + Playwright, + async_playwright, +) + + +class PlayContextManager: + def __init__( + self, + playwright: Playwright = None, + browser: Browser = None, + context: BrowserContext = None, + page: Page = None, + ) -> None: + self.playwright = playwright + self.browser = browser + self.context = context + self.page = page + + self.isLogin = False + + async def start(self) -> None: + """手动启动""" + self.playwright = await async_playwright().start() # 不是 __enter__ + self.browser = await self.playwright.chromium.launch( + headless=False, + args=[ + "--disable-blink-features=AutomationControlled", + "--no-sandbox", + "--disable-setuid-sandbox", + "--disable-dev-shm-usage", + "--disable-gpu", + "--no-proxy-server", + ], + ) + self.context = await self.browser.new_context() + self.page = await self.context.new_page() + + async def __aenter__(self) -> "PlayContextManager": + await self.start() + return self + + async def close(self) -> None: + """手动关闭""" + if self.context: + await self.context.close() + if self.browser: + await self.browser.close() + if self.playwright: + await self.playwright.stop() # 不是 __exit__ + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: # noqa: ANN001 + await self.close() diff --git a/njupt_api/zhengfang/__init__.py b/njupt_api/zhengfang/__init__.py new file mode 100644 index 0000000..166a3f0 --- /dev/null +++ b/njupt_api/zhengfang/__init__.py @@ -0,0 +1,13 @@ +from .createcourse import create_course_schedule +from .sso import SSO +from .types import Course, course_dict_serializer, course_list_serializer +from .zhengfang import ZhengFang + +__all__ = [ + create_course_schedule, + SSO, + Course, + course_dict_serializer, + course_list_serializer, + ZhengFang, +] diff --git a/njupt_api/zhengfang/createcourse.py b/njupt_api/zhengfang/createcourse.py new file mode 100644 index 0000000..1d2b0de --- /dev/null +++ b/njupt_api/zhengfang/createcourse.py @@ -0,0 +1,315 @@ +import re + +from bs4 import BeautifulSoup + +from .types import Course + + +def normalize_course_str(course_str: str) -> str: + """ + 规范化课程字符串,确保 create_course 能正确解析。 + + Returns: + 字符串。 + """ + parts = course_str.split("
") + while parts and parts[0] == "": + parts.pop(0) + while len(parts) < 4: + parts.append(" ") + for i in range(2, 4): + if parts[i] == "": + parts[i] = " " + return "
".join(parts) + + +def create_course_schedule(html: str) -> list[Course]: + """解析给定 HTML 字符串,返回包含数个 Course 对象的列表。 + Args: + html: HTML 字符串。应该有且只有一个 标签,其中是课程表数据。 + + Returns: + list[Course] + + """ + soup = BeautifulSoup(html, "html.parser") + table = soup.find("table") + rows = table.find_all("tr") + + courses: list[Course] = [] + rowspan_map: dict[int, int] = {} + + # 解析第一行表头,建立列索引到星期几的映射 + # 表头格式:第1列是"时间"(colspan=2),然后是 星期一 到 星期日 + day_map: dict[int, int] = {} # col_idx -> day (1-7) + if rows: + header_cells = rows[0].find_all(["td", "th"]) + col_idx = 0 + for cell in header_cells: + text = cell.get_text(strip=True) + colspan = int(cell.get("colspan", 1)) + + # 跳过"时间"单元格 + if text != "时间": + # 映射星期几到数字 + day_mapping = { + "星期一": 1, + "星期二": 2, + "星期三": 3, + "星期四": 4, + "星期五": 5, + "星期六": 6, + "星期日": 7, + "星期天": 7, + } + day = day_mapping.get(text) + if day is not None: + for c in range(col_idx, col_idx + colspan): + day_map[c] = day + + col_idx += colspan + + for row_idx, row in enumerate(rows): + if row_idx == 0: + continue + + cells = row.find_all(["td", "th"]) + col_idx = 0 + class_start: int | None = None + + for cell in cells: + while col_idx in rowspan_map and rowspan_map[col_idx] > 0: + rowspan_map[col_idx] -= 1 + if rowspan_map[col_idx] == 0: + del rowspan_map[col_idx] + col_idx += 1 + + text = cell.get_text(strip=True) + colspan = int(cell.get("colspan", 1)) + rowspan = int(cell.get("rowspan", 1)) + + if text.startswith("第") and text.endswith("节"): + class_start = int(text[1:-1]) + if rowspan > 1: + for c in range(col_idx, col_idx + colspan): + rowspan_map[c] = rowspan - 1 + col_idx += colspan + continue + + if text in ("早晨", "上午", "下午", "晚上"): + if rowspan > 1: + for c in range(col_idx, col_idx + colspan): + rowspan_map[c] = rowspan - 1 + col_idx += colspan + continue + + td_str = str(cell) + start = td_str.find(">") + 1 + end = td_str.rfind("") + inner_html = td_str[start:end] + + if " " not in inner_html and inner_html.strip(): + inner_html = re.sub(r"", "
", inner_html) + course_strs = [ + s.strip() for s in re.split(r"(?:
){2,}", inner_html) if s.strip() and " " not in s + ] + # 获取当前列对应的星期几 + day = day_map.get(col_idx, 1) # 默认为1(星期一) + for course_str in course_strs: + course_str = normalize_course_str(course_str) + courses.append( + create_course( + course_str, + day, + default_classes_start=class_start, + ), + ) + + if rowspan > 1: + for c in range(col_idx, col_idx + colspan): + rowspan_map[c] = rowspan - 1 + + col_idx += colspan + + return courses + + +def create_course( + raw: str, + day: int, + default_classes_start: int | None = None, +) -> Course: + """根据从 HTML 中提取出的原字符串解析课程信息 + Args: + raw: 原字符串,以
作为换行符 + day: 周内的星期几 + default_classes_start: 如果没有解析出课程的 classes,则使用此参数。 + 此参数应当从表格的行标题解析。 + + Returns: + Course + + """ + # 0 1 2 3 4 + # ['概率论与数理统计', '1-17单(1,2)', '王雪红', '教3-520', ''] + raw_list = raw.split("
") + + # 首先去除列表头部的所有空字符串 + while True: + if raw_list[0] == "": + raw_list.pop(0) + else: + break + + # 对于大部分课程,raw_list[1] 都是形如以下格式 + # 1-17(3,4) + # 1-17单(1,2) *(也可能是双) + # 2节/周 + # 2节/单周 *(也可能是双) + # 周三第3,4节{第1-17周} + # 周五第3,4节{第2-16周|双周} + raw_time = raw_list[1] + weeks = [] + classes = [] + single = False # 内部变量 + double = False # 内部变量 + # 处理前两种形式 + if "-" in raw_time and "第" not in raw_time: + # 也可能是 '1-17单' + t = raw_time.split("(") # ['1-17', '3-4)'] + # 也可能是 '17单' + start, end = t[0].split("-") # ['1', '17'] + if end.endswith("单"): + end = end[:-1] + single = True + elif end.endswith("双"): + end = end[:-1] + double = True + for i in range(int(start), int(end) + 1): + if single and i % 2 == 0: + continue + if double and i % 2 == 1: + continue + weeks.append(i) + raw_classes = t[1].removesuffix(")") + classes = [int(i) for i in raw_classes.split(",")] + # 处理中两种形式 + elif "/" in raw_time: + # 默认学期 1-16 周 + if "/单周" in raw_time: + single = True + elif "/双周" in raw_time: + double = True + for i in range(1, 17): + if single and i % 2 == 0: + continue + if double and i % 2 == 1: + continue + weeks.append(i) + + # 获取多少节课 + t_num = int(raw_time.split("节")[0]) + for i in range(0, t_num): + classes.append(default_classes_start + i) + # 处理后两种形式 + elif "第" in raw_time: + # '周三', '3,4节{', '1-17周}' + # '周五', '3,4节{', '2-16周|双周}' + u = raw_time.split("第") + classes = [int(u_c) for u_c in u[1].split("节")[0].split(",")] + + # '1-17', '}' + # '2-16', '|双', '}' + u_w = u[2].split("周") + if "单" in u_w[1]: + single = True + elif "双" in u_w[1]: + double = True + u_start, u_end = u_w[0].split("-") + for i in range(int(u_start), int(u_end) + 1): + if single and i % 2 == 0: + continue + if double and i % 2 == 1: + continue + weeks.append(i) + + teacher = raw_list[2] if raw_list[2] != " " else None + classroom = raw_list[3] if raw_list[3] != " " else None + + return Course(raw_list[0], weeks, day, classes, teacher, classroom) + + +def convert_dict_schedule_to_tuple(schedule: list[dict]) -> list[tuple]: + """将字典格式的课表转换为压缩的元组格式。 + + Args: + schedule: list[dict],标准格式的课程数据 + + Returns: + list[tuple]: 压缩后的元组格式 (name, teacher, classroom, weeks_str, day, classes) + 其中 weeks 尽量压缩为字符串格式(如 "1-17") + + """ + result = [] + for course in schedule: + name = course.get("name", "") + teacher = course.get("teacher") + classroom = course.get("classroom") + weeks = course.get("weeks", []) + day = course.get("day", 1) + classes = course.get("classes", []) + + # 压缩 weeks 为字符串 + weeks_str = compress_weeks_to_string(weeks) if weeks else "" + + result.append((name, teacher, classroom, weeks_str, day, classes)) + + return result + + +def compress_weeks_to_string(weeks: list[int]) -> str: + """将周数列表压缩为最短的字符串表示。 + + 例如: + [1,2,3,4,5] -> "1-5" + [1,3,5,7] -> "1,3,5,7" + [1,2,3,5,6,7,8] -> "1-3,5-8" + [1] -> "1" + + Args: + weeks: 周数列表 + + Returns: + str: 压缩后的周数字符串 + + """ + if not weeks: + return "" + + # 去重并排序 + weeks = sorted({int(w) for w in weeks}) + + ranges = [] + start = end = weeks[0] + + for w in weeks[1:]: + if w == end + 1: + # 连续,扩展当前范围 + end = w + else: + # 不连续,保存当前范围,开始新范围 + ranges.append((start, end)) + start = end = w + + # 保存最后一个范围 + ranges.append((start, end)) + + # 格式化为字符串 + parts = [] + for start, end in ranges: + if start == end: + parts.append(str(start)) + else: + parts.append(f"{start}-{end}") + + return ",".join(parts) diff --git a/njupt_api/zhengfang/sso.py b/njupt_api/zhengfang/sso.py new file mode 100644 index 0000000..3f5cf25 --- /dev/null +++ b/njupt_api/zhengfang/sso.py @@ -0,0 +1,38 @@ +from njupt_api.baselib import PlayContextManager, logger + + +class SSO(PlayContextManager): + def __init__(self) -> None: + super().__init__() + + async def login(self, username: str, password: str) -> bool: + """使用用户名和密码实现登录南邮统一身份验证。 + + Parameters: + username: 用户名,学号,一般为一位大写字母+八位数字 + password: 密码 + + Returns: + bool,表明判登录是否成功。 + """ + await self.page.goto("http://i.njupt.edu.cn/") + + await self.page.fill('input[name="username"]', username) + await self.page.fill('input[type="password"]', password) + await self.page.click('button[type="button"]') + + await self.page.wait_for_load_state("networkidle") + if "user-login" in self.page.url: + logger.error(f"{username} | 登录失败,请检查学号和密码是否正确。") + return False + + logger.info(f"{username} | 登录南邮统一身份认证成功。") + self.isLogin = True + return True + + async def goto_zf(self) -> None: + sub_frame = self.page.frame_locator('iframe[name="iframe0"]') + async with self.context.expect_event("page") as new_page_event: + await sub_frame.locator('a[title="教务系统"]').click() + self.page = await new_page_event.value + return diff --git a/njupt_api/zhengfang/types.py b/njupt_api/zhengfang/types.py new file mode 100644 index 0000000..d0b93e6 --- /dev/null +++ b/njupt_api/zhengfang/types.py @@ -0,0 +1,42 @@ +from dataclasses import dataclass + + +@dataclass +class Course: + """Course 是对课程表中的 **某一节课** 的抽象。 + + Examples: + 1-17周,星期一,1-2节,数据结构,是一个 Course 对象; + + 1-17周,星期三,3-4节,数据结构,是另一个 Course 对象; + + 1-17周中的单周,星期四,3-4节,英语,是一个 Course 对象; + + 1-17周中的双周,星期四,3-4节,物理,是另一个 Course 对象。 + + """ + + name: str + weeks: list[int] + day: int + classes: list[int] + teacher: str | None + classroom: str | None + + +def course_dict_serializer(course: Course) -> dict[str, str | list[int] | int | None]: + return { + "name": course.name, + "weeks": course.weeks, + "day": course.day, + "classes": course.classes, + "teacher": course.teacher, + "classroom": course.classroom, + } + + +def course_list_serializer(course_list: list[Course]) -> list[dict]: + final_list = [] + for course in course_list: + final_list.append(course_dict_serializer(course)) + return final_list diff --git a/njupt_api/zhengfang/zhengfang.py b/njupt_api/zhengfang/zhengfang.py new file mode 100644 index 0000000..143af8e --- /dev/null +++ b/njupt_api/zhengfang/zhengfang.py @@ -0,0 +1,79 @@ +from ddddocr import DdddOcr +from playwright.async_api import Browser, BrowserContext, Page, Playwright + +from njupt_api.baselib import PlayContextManager, logger +from njupt_api.zhengfang import Course +from njupt_api.zhengfang.createcourse import create_course_schedule +from njupt_api.zhengfang.sso import SSO + + +class ZhengFang(PlayContextManager): + def __init__( + self, + playwright: Playwright = None, + browser: Browser = None, + context: BrowserContext = None, + page: Page = None, + ) -> None: + super().__init__(playwright, browser, context, page) + + @classmethod + async def init_from_sso(cls, sso: SSO) -> "ZhengFang": + await sso.goto_zf() + logger.info("从 SSO 进入正方教务系统。") + return cls(sso.playwright, sso.browser, sso.context, sso.page) + + async def login(self, username: str, password: str) -> bool: + """ + 使用用户名和密码实现教务系统登录。 + + Returns: + bool,表明登录是否成功。 + """ + await self.page.goto("http://jwxt.njupt.edu.cn") + + # 填充用户名和密码 + await self.page.fill("input#txtUserName", username) + await self.page.fill("input#TextBox2", password) + + # 处理验证码 + captcha_img = self.page.locator("img#icode") + captcha_bytes = await captcha_img.screenshot() + ocr = DdddOcr(show_ad=False) + captcha_code = str(ocr.classification(captcha_bytes)) + logger.debug(f"识别到的验证码为: {captcha_code}") + await self.page.fill("input#txtSecretCode", captcha_code) + + async with self.page.expect_event("dialog", timeout=3000) as dialog_info: + await self.page.click("input#Button1") + dialog = await dialog_info.value + if dialog.message == "请到信息维护中完善个人联系方式": + await dialog.accept() + logger.info(f"{username} | 登录正方教务系统成功。") + self.isLogin = True + return True + if "验证码" in dialog.message: + await dialog.accept() + logger.warning(f"{username} | 验证码错误,自动重试...") + return await self.login(username, password) + await dialog.accept() + logger.error(f"{username} | 登录失败,教务系统提示信息为: {dialog.message}") + return False + + async def get_class_schedule(self) -> list[Course]: + await self.page.locator("a.top_link:has-text('公用信息')").click() + await self.page.locator("a:has-text('班级课表查询')").click() + sub_frame = self.page.frame_locator("iframe[name='zhuti']") + logger.debug("获取班级课表。") + return create_course_schedule( + f"
{await sub_frame.locator('table#Table6').inner_html()}
", + ) + + async def get_student_schedule(self) -> list[Course]: + await self.page.locator("a.top_link:has-text('信息查询')").click() + await self.page.locator("a:has-text('学生个人课表')").click() + sub_frame = self.page.frame_locator("iframe[name='zhuti']") + logger.debug("获取个人课表。") + return create_course_schedule( + f"{await sub_frame.locator('table#Table1').inner_html()}
", + ) diff --git a/router/__init__.py b/router/__init__.py new file mode 100644 index 0000000..9226fe7 --- /dev/null +++ b/router/__init__.py @@ -0,0 +1 @@ +from .__version__ import __version__ diff --git a/router/__version__.py b/router/__version__.py new file mode 100644 index 0000000..478e063 --- /dev/null +++ b/router/__version__.py @@ -0,0 +1,6 @@ +from importlib.metadata import PackageNotFoundError, version + +try: + __version__ = version("njupt-suan-api") +except PackageNotFoundError: + __version__ = "dev" diff --git a/router/admin_router.py b/router/admin_router.py new file mode 100644 index 0000000..caf2a2d --- /dev/null +++ b/router/admin_router.py @@ -0,0 +1,123 @@ +from pathlib import Path +from typing import Annotated, Sequence + +import aiofiles +from fastapi import APIRouter, Depends +from pydantic import BaseModel +from sqlmodel import Session, delete, select + +from njupt_api.baselib import config, logger +from njupt_api.zhengfang import ZhengFang, course_list_serializer +from router.enhance.auth import verify_token +from router.enhance.lib import AliasDto, ReturnDto, TestDto, get_session +from router.enhance.model import Alias, Course + + +class ValidateTokenDto(BaseModel): + token: str + + +admin_router = APIRouter(prefix="/admin", tags=["admin"]) + + +@admin_router.post("/validateToken") +async def validate_token(vtd: ValidateTokenDto) -> ReturnDto: + """ + 验证 Token 是否正确,以此判断是否允许登录 WebUI。 + 验证时无需使用 HTTP Bearer,直接作为 body 传入即可。 + + Returns: + ReturnDto,以 success 字段表明是否有效。 + """ + async with aiofiles.open(file=Path.cwd() / "data/token.txt", mode="r") as f: + if (await f.readline()).strip() == vtd.token: + return ReturnDto(success=True) + return ReturnDto(success=False) + + +@admin_router.post("/schedule/test", dependencies=[Depends(verify_token)]) +async def post_schedule_test(test: TestDto, session: Annotated[Session, Depends(get_session)]) -> ReturnDto: + async with ZhengFang() as zf: + if await zf.login(test.username, test.password): + if test.scheduleType == "class": + final_course_list = course_list_serializer( + await zf.get_class_schedule(), + ) + session.exec(delete(Course)) + for course in final_course_list: + session.add(Course(**course)) + session.commit() + + logger.success( + f"{test.username} | 获取 {test.scheduleType} 课表成功,已保存到数据库。", + ) + return ReturnDto(success=True, result=final_course_list) + if test.scheduleType == "student": + final_course_list = course_list_serializer( + await zf.get_student_schedule(), + ) + logger.success( + f"{test.username} | 获取 {test.scheduleType} 课表成功。个人课表不保存。", + ) + return ReturnDto(success=True, result=final_course_list) + logger.error( + f"{test.username} | scheduleType 参数错误。给定的 schedule={test.scheduleType}", + ) + return ReturnDto( + success=False, + message="参数错误,请检查 scheduleType 参数。", + ) + logger.error( + f"{test.username} | 获取课程表失败,请检查账号密码是否正确后再试。", + ) + return ReturnDto( + success=False, + message="获取课程表失败,请检查账号密码是否正确后再试。", + ) + + +@admin_router.get("/schedule/test", dependencies=[Depends(verify_token)]) +async def get_schedule_test(session: Annotated[Session, Depends(get_session)]) -> ReturnDto: + course_dtos: Sequence[Course] = session.exec(select(Course)).all() + return ReturnDto( + success=True, + result=[course.model_dump() for course in course_dtos], + ) + + +@admin_router.post("/schedule/alias", dependencies=[Depends(verify_token)]) +async def post_schedule_alias(alias: AliasDto, session: Annotated[Session, Depends(get_session)]) -> ReturnDto: + for alia in session.exec(select(Alias)).all(): + if alias.originalName == alia.originalName: + logger.error( + f"课程 {alia.originalName} 已经在数据库中存在,不允许重复添加。", + ) + return ReturnDto( + success=False, + message=f"课程 {alia.originalName} 已经在数据库中存在,不允许重复添加。", + ) + + session.add(Alias(originalName=alias.originalName, aliasName=alias.aliasName)) + session.commit() + logger.success(f"已添加课程别名 | {alias.originalName} => {alias.aliasName}") + return ReturnDto(success=True) + + +@admin_router.get("/schedule/alias", dependencies=[Depends(verify_token)]) +async def get_schedule_alias(session: Annotated[Session, Depends(get_session)]) -> ReturnDto: + aliases: Sequence[Alias] = session.exec(select(Alias)).all() + return ReturnDto(success=True, result=[alias.model_dump() for alias in aliases]) + + +@admin_router.post("/config", dependencies=[Depends(verify_token)]) +async def post_config(data: dict) -> ReturnDto: + data_ = data.get("data") + logger.debug(f"接收到配置字典 - {data_}") + config.from_dict(data_) + await config.save_json() + return ReturnDto(success=True, result=config.to_dict()) + + +@admin_router.get("/config", dependencies=[Depends(verify_token)]) +async def get_config() -> ReturnDto: + return ReturnDto(success=True, result=config.to_dict()) diff --git a/router/api_router.py b/router/api_router.py new file mode 100644 index 0000000..65a7849 --- /dev/null +++ b/router/api_router.py @@ -0,0 +1,100 @@ +from pathlib import Path +from typing import Annotated + +from fastapi import APIRouter, Depends, HTTPException, status +from fastapi.responses import FileResponse +from sqlmodel import Session, select + +from njupt_api.baselib import logger +from njupt_api.zhengfang import ( + ZhengFang, + course_dict_serializer, + course_list_serializer, +) +from router.enhance.lib import ReturnDto, ScheduleQueryDto, apply_enhance, get_session +from router.enhance.model import Course + +TEMP_DIR = Path.cwd() / "temp" + + +api_router = APIRouter(prefix="/api", tags=["API"]) + + +@api_router.post("/schedule/class") +async def post_schedule_class( + student: ScheduleQueryDto, + session: Annotated[Session, Depends(get_session)], +) -> ReturnDto: + if student.username is None and student.password is None: + logger.debug("未提供学号和密码参数,尝试从数据库中返回一次性存储的班级课表。") + course_dtos = session.exec(select(Course)).all() + course_list: list[dict] = [course_dict_serializer(course) for course in course_dtos] + logger.success(f"{student.week=} 从数据库中返回一次性存储的班级课表。") + return await apply_enhance(course_list, student.week, student.img) + if student.username and student.password: + async with ZhengFang() as zf: + if await zf.login(student.username, student.password): + course_list = course_list_serializer(await zf.get_class_schedule()) + logger.success( + f"{student.username} | {student.week=} 获取指定学生的班级课表成功。", + ) + return await apply_enhance(course_list, student.week, student.img) + logger.error( + f"{student.username} | 获取课程表失败,请检查账号密码是否正确后再试。", + ) + return ReturnDto( + success=False, + message="获取课程表失败,请检查账号密码是否正确后再试。", + ) + else: + logger.error( + f"参数错误,请同时携带或同时不携带学号和密码参数: {student.username=} | {student.password=}", + ) + return ReturnDto( + success=False, + message="参数错误,请同时携带或同时不携带学号和密码参数。", + ) + + +@api_router.post("/schedule/student") +async def post_schedule_student(student: ScheduleQueryDto) -> ReturnDto: + if student.username is None or student.password is None: + logger.error("查询学生课表需要同时提供学号和密码参数。") + return ReturnDto( + success=False, + message="查询学生课表需要同时提供学号和密码参数。", + ) + + async with ZhengFang() as zf: + if await zf.login(student.username, student.password): + course_list = course_list_serializer(await zf.get_student_schedule()) + logger.success(f"{student.username} | 获取学生个人课表成功。") + return await apply_enhance(course_list, student.week, student.img) + logger.error( + f"{student.username} | 获取课程表失败,请检查账号密码是否正确后再试。", + ) + return ReturnDto( + success=False, + message="获取课程表失败,请检查账号密码是否正确后再试。", + ) + + +@api_router.get("/schedule/img/{name}") +async def get_schedule_img(name: str) -> FileResponse: + """ + 从 temp 工作目录中读取指定图片并返回。如果图片不存在则报 404。 + + Returns: + FileResponse: 图片。 + + Raises: + HTTPException: 404 - 查找的图片不存在。 + """ + image_file = TEMP_DIR / name + logger.debug(f"尝试获取 {image_file!s}") + if image_file.exists(): + return FileResponse(path=str(image_file), media_type="image/png") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Name wrong or too late.", + ) diff --git a/router/enhance/__init__.py b/router/enhance/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/router/enhance/alias.py b/router/enhance/alias.py new file mode 100644 index 0000000..1eab5b1 --- /dev/null +++ b/router/enhance/alias.py @@ -0,0 +1,41 @@ +"""为课程提供一个简短的别名,便于在空间有限的课程表图片中辨认。 + +别名是单独保存的,在最终输出阶段才会被装饰在原有的课表输出上。 +""" + +from typing import Sequence + +from sqlmodel import Session, select + +from njupt_api.baselib import logger +from router.enhance.model import Alias, engine + + +def apply_alias(courses: list[dict]) -> list[dict]: + with Session(engine) as session: + aliases: Sequence[Alias] = session.exec(select(Alias)).all() + + # 否则不做任何更改 + if len(aliases) == 0: + return courses + + alias_count = 0 + apply_count = 0 + + alias_dict = {} + for alias in aliases: + m = alias.model_dump() + alias_dict[m["originalName"]] = m["aliasName"] + alias_count += 1 + + for course in courses: + if course["name"] in alias_dict: + course["alias"] = alias_dict[course["name"]] + apply_count += 1 + else: + course["alias"] = None + + logger.debug( + f"课程别名 | 将 {alias_count} 个别名应用在了 {apply_count} 门输出的课程上。", + ) + return courses diff --git a/router/enhance/auth.py b/router/enhance/auth.py new file mode 100644 index 0000000..b77fb15 --- /dev/null +++ b/router/enhance/auth.py @@ -0,0 +1,21 @@ +from pathlib import Path +from typing import Annotated + +import aiofiles +from fastapi import Depends, HTTPException, status +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer + +security = HTTPBearer() + + +async def verify_token(credentials: Annotated[HTTPAuthorizationCredentials, Depends(security)]) -> str: + token = credentials.credentials + + async with aiofiles.open(file=Path.cwd() / "data" / "token.txt", mode="r") as f: + if (await f.readline()).strip() == token: + return token + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid or missing token. (Suan API WebUI)", + headers={"WWW-Authenticate": "Bearer"}, + ) diff --git a/router/enhance/lib.py b/router/enhance/lib.py new file mode 100644 index 0000000..e0c510c --- /dev/null +++ b/router/enhance/lib.py @@ -0,0 +1,77 @@ +from datetime import date, timedelta +from typing import Any, Generator, Literal + +from pydantic import BaseModel +from sqlmodel import Session + +from njupt_api.baselib import config + +from .alias import apply_alias +from .model import engine +from .screenshot import generate_img + + +class ScheduleQueryDto(BaseModel): + username: str | None = None + password: str | None = None + week: int = 0 + img: bool = False + + +class TestDto(BaseModel): + username: str + password: str + scheduleType: Literal["class", "student"] # noqa: N815 + + +class AliasDto(BaseModel): + originalName: str # noqa: N815 + aliasName: str | None # noqa: N815 + + +class ReturnDto(BaseModel): + success: bool + message: str | None = None + result: Any | None = None + img_url: str | None = None + + +def get_session() -> Generator[Session, None, None]: + with Session(engine) as session: + yield session + + +async def apply_enhance(course_list: list[dict], week: int, img: bool) -> ReturnDto: + """ + 在一个方法中集成了 应用别名 和 生成课表图片 功能。此为异步方法,需要 await。 + + Example: + return await apply_enhance(course_list, week, img) + + Returns: + 返回应用别名和图片完毕的 ReturnDto。 + """ + final_course_list = [course for course in course_list if week in course["weeks"]] if week > 0 else course_list + + final_course_list = apply_alias(final_course_list) + + # 获取课表图片设置 + title_template = config.get("schedule", "schedule_title_template", "芒果酸的课程表") + subtitle_template = config.get("schedule", "schedule_subtitle_template", "") + semester_start_date = date.fromisoformat(config.get("schedule", "semester_start_date", "2026-03-02")) + + # 可用变量 + week_start_day = semester_start_date + timedelta(weeks=week) + vars_ = { + "week": week, + "week_start_day": week_start_day.isoformat(), + "week_end_day": (week_start_day + timedelta(days=6)).isoformat(), + } + + img_url = None + if img: + img_url = f"http://172.28.143.24:8000/api/schedule/img/{ + await generate_img(final_course_list, title_template.format(**vars_), subtitle_template.format(**vars_)) + }" + + return ReturnDto(success=True, result=final_course_list, img_url=img_url) diff --git a/router/enhance/model.py b/router/enhance/model.py new file mode 100644 index 0000000..3c66fa6 --- /dev/null +++ b/router/enhance/model.py @@ -0,0 +1,29 @@ +from typing import Optional + +from sqlalchemy import JSON, Column +from sqlmodel import Field, SQLModel, create_engine + +sqlite_file_name = "data/njupt_api.db" +sqlite_url = f"sqlite:///{sqlite_file_name}" + +engine = create_engine(sqlite_url, connect_args={"check_same_thread": False}) + + +class Course(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str + teacher: Optional[str] = Field(default=None, nullable=True) + classroom: Optional[str] = Field(default=None, nullable=True) + weeks: list[int] = Field(default=[], sa_column=Column(JSON)) + day: int + classes: list[int] = Field(default=[], sa_column=Column(JSON)) + + +class Alias(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + originalName: Optional[str] = Field(default=None, nullable=True) # noqa: N815 + aliasName: Optional[str] = Field(default=None, nullable=True) # noqa: N815 + + +def create_db_and_tables() -> None: + SQLModel.metadata.create_all(engine) diff --git a/router/enhance/screenshot.py b/router/enhance/screenshot.py new file mode 100644 index 0000000..8ee40f0 --- /dev/null +++ b/router/enhance/screenshot.py @@ -0,0 +1,49 @@ +"""使用 Playwright 截图""" + +from json import dumps +from pathlib import Path +from urllib.parse import urlencode +from uuid import uuid4 + +from playwright.async_api import ViewportSize + +from njupt_api.baselib import PlayContextManager, logger + +TEMP_DIR = Path.cwd() / "temp" + + +class ScreenShot(PlayContextManager): + def __init__(self) -> None: + super().__init__() + + async def goto(self, url: str) -> bool: + await self.page.set_viewport_size(ViewportSize(width=1200, height=900)) + res = await self.page.goto(url) + logger.debug(f"截图 | {res.ok=} - {url=:.{50}}...") + return res.ok + + async def shot(self, save_path: str) -> None: + await self.page.mouse.move(0, 0) + await self.page.wait_for_load_state("networkidle") + await self.page.screenshot(path=save_path) + logger.debug(f"截图 | 截图已经保存在 {save_path=}") + return + + +async def generate_img(courses: list[dict], title: str, subtitle: str) -> str: + """ + 方法将生成课程表图片并保存在临时目录中,返回图片的完整名称。图片位于 temp 工作目录。 + + Returns: + 字符串,表明生成图片的文件名,格式为 `schedule-{uuid4()}.png` + """ + t_name = f"schedule-{uuid4()}.png" + async with ScreenShot() as ss: + await ss.goto( + f"127.0.0.1:8000/webui/schedule#/?{ + urlencode({'data': dumps(courses), 'title': title, 'subtitle': subtitle}) + }", + ) + await ss.shot(str(TEMP_DIR / t_name)) + logger.debug(f"截图 | 生成临时图片 - {t_name}") + return t_name diff --git a/router/enhance/week.py b/router/enhance/week.py new file mode 100644 index 0000000..944fbe0 --- /dev/null +++ b/router/enhance/week.py @@ -0,0 +1,29 @@ +"""学期-星期计算""" + +from datetime import date, timedelta + + +def get_semester_week_info(start: date, target: date) -> tuple[int, int]: + """ + 给定学期开始日期(第一周周一)和另一指定日期,计算指定的日期是第几周的星期几。 + Args: + start: 学期开始的日期,date + target: 指定日期,date + + Returns: + 包含两个数字的元组,分别为第几周和星期几。 + """ + return (target - start).days // 7 + 1, target.isoweekday() + + +def get_week_day_info(target: date) -> tuple[date, date]: + """ + 给定一个指定日期,获取该日所在的星期的周一和周日的日期。 + Args: + target: 指定日期,date + + Returns: + 包含两个 date 的元组,分别为周一和周日。 + """ + weekday_int = target.weekday() + return target - timedelta(days=weekday_int), target + timedelta(days=6 - weekday_int) diff --git a/router/mcp_router.py b/router/mcp_router.py new file mode 100644 index 0000000..a2dc9c0 --- /dev/null +++ b/router/mcp_router.py @@ -0,0 +1,139 @@ +from pathlib import Path +from typing import Annotated + +from fastmcp import FastMCP +from fastmcp.utilities.types import Image +from mcp.types import ToolAnnotations +from pydantic import Field +from sqlmodel import Session, select + +from njupt_api.baselib import LoggingMiddleware, logger +from njupt_api.zhengfang import ( + ZhengFang, + course_dict_serializer, + course_list_serializer, +) +from router.enhance.lib import ReturnDto, apply_enhance +from router.enhance.model import Course, engine + +mcp = FastMCP("NJUPT API Suan") + +mcp.add_middleware(LoggingMiddleware()) + +mcp_app = mcp.http_app("/") + + +# 统一参数文档 +USERNAME_TYPE = Annotated[str, Field(description="用户名,也即学号,一般是一位字母接八位数字,字母需要大写。")] +PASSWORD_TYPE = Annotated[str, Field(description="密码,字符串。")] +WEEK_TYPE = Annotated[int, Field(description="获取第几周的课表,默认为 0 即获取全部。")] +IMG_TYPE = Annotated[ + bool, + Field( + description="是否需要同时生成图片,默认为 False。如果为 True,图片链接将在 img_url 中提供,链接有效时间为两小时。", # noqa: E501 + ), +] + + +@mcp.tool( + name="tool_schedule_class", + title="获取默认班级课表", + description="获取存储在酸 API 中的默认班级课表,返回值包含 success result message 和 img_url 四个字段。", + annotations=ToolAnnotations( + title="获取默认课表", + readOnlyHint=True, + destructiveHint=False, + idempotentHint=True, + openWorldHint=False, + ), +) +async def tool_schedule_class( + week: WEEK_TYPE = 0, + img: IMG_TYPE = False, +) -> ReturnDto: + with Session(engine) as session: + course_dtos = session.exec(select(Course)).all() + logger.success("从数据库中返回一次性存储的班级课表。") + course_list: list[dict] = [course_dict_serializer(course) for course in course_dtos] + return await apply_enhance(course_list, week, img) + + +@mcp.tool( + name="tool_schedule_class_special", + title="获取指定学生的班级课表", + description="获取指定学生的班级课表。需要提供学号和密码。返回值包含 success result message 和 img_url 四个字段。", + annotations=ToolAnnotations( + title="获取指定学生的班级课表", + readOnlyHint=True, + destructiveHint=False, + idempotentHint=True, + openWorldHint=False, + ), +) +async def tool_schedule_class_special( + username: USERNAME_TYPE, + password: PASSWORD_TYPE, + week: WEEK_TYPE = 0, + img: IMG_TYPE = False, +) -> ReturnDto: + async with ZhengFang() as zf: + if await zf.login(username, password): + final_course_list = course_list_serializer(await zf.get_class_schedule()) + logger.success(f"{username} | 获取指定学生的班级课表成功。") + return await apply_enhance(final_course_list, week, img) + logger.error(f"{username} | 获取课程表失败,请检查账号密码是否正确后再试。") + return ReturnDto( + success=False, + message="获取课程表失败,请检查账号密码是否正确后再试。", + ) + + +@mcp.tool( + name="tool_schedule_student_special", + title="获取指定学生的个人课表", + description="获取指定学生的个人课表。需要提供学号和密码。返回值包含 success result message 和 img_url 四个字段。", + annotations=ToolAnnotations( + title="获取指定学生的个人课表", + readOnlyHint=True, + destructiveHint=False, + idempotentHint=True, + openWorldHint=False, + ), +) +async def tool_schedule_student_special( + username: USERNAME_TYPE, + password: PASSWORD_TYPE, + week: WEEK_TYPE = 0, + img: IMG_TYPE = False, +) -> ReturnDto: + async with ZhengFang() as zf: + if await zf.login(username, password): + final_course_list = course_list_serializer(await zf.get_student_schedule()) + logger.success(f"{username} | 获取指定学生的个人课表成功。") + return await apply_enhance(final_course_list, week, img) + logger.error(f"{username} | 获取课程表失败,请检查账号密码是否正确后再试。") + return ReturnDto( + success=False, + message="获取课程表失败,请检查账号密码是否正确后再试。", + ) + + +@mcp.tool( + name="tool_schedule_image", + title="直接获取课表图片", + description="接收使用其他课表工具得到的图片的文件名,返回 ImageContent。", + annotations=ToolAnnotations( + title="直接获取课表图片", + readOnlyHint=True, + destructiveHint=False, + idempotentHint=True, + openWorldHint=False, + ), +) +async def tool_schedule_image( + img_name: Annotated[str, Field(description="课表图片的文件名,形如 schedule-{uuid4}.png")], +) -> Image | ReturnDto: + img_path = Path.cwd() / "temp" / img_name + if img_path.exists(): + return Image(path=img_path) + return ReturnDto(success=False, message=f"未找到指定的课表图片 {img_name}") diff --git a/router/webui_router.py b/router/webui_router.py new file mode 100644 index 0000000..a88a3c1 --- /dev/null +++ b/router/webui_router.py @@ -0,0 +1,22 @@ +from pathlib import Path + +import aiofiles +from fastapi import APIRouter +from fastapi.responses import HTMLResponse + +WEBUI_INDEX = Path.cwd() / "webui" / "dist" / "index.html" +SCHEDULE_INDEX = Path.cwd() / "webui" / "dist" / "index-schedule.html" + +webui_router = APIRouter(prefix="/webui") + + +@webui_router.get("/", response_class=HTMLResponse) +async def get_webui() -> HTMLResponse: + async with aiofiles.open(file=WEBUI_INDEX, mode="r", encoding="utf-8") as f: + return HTMLResponse(content=await f.read(), status_code=200) + + +@webui_router.get("/schedule", response_class=HTMLResponse) +async def get_webui_schedule() -> HTMLResponse: + async with aiofiles.open(file=SCHEDULE_INDEX, mode="r", encoding="utf-8") as f: + return HTMLResponse(content=await f.read(), status_code=200) diff --git a/server.py b/server.py new file mode 100644 index 0000000..fc5d90b --- /dev/null +++ b/server.py @@ -0,0 +1,208 @@ +import asyncio +import time +import traceback +from contextlib import asynccontextmanager, suppress +from pathlib import Path +from typing import AsyncGenerator, Callable + +from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse +from fastapi.staticfiles import StaticFiles +from fastmcp.utilities.lifespan import combine_lifespans +from watchfiles import awatch + +from njupt_api.baselib import ( + LogRecord, + config, + log_buffer, + log_record_serialize, + logger, +) +from njupt_api.zhengfang.zhengfang import ZhengFang +from router import __version__ +from router.admin_router import admin_router +from router.api_router import api_router +from router.enhance.lib import ReturnDto +from router.enhance.model import create_db_and_tables +from router.mcp_router import mcp_app +from router.webui_router import webui_router + +DATA_DIR = Path.cwd() / "data" + + +async def toml_watcher() -> None: + """配置文件监听器""" + await config.load_json() + async for change in awatch(DATA_DIR / "config.json"): + logger.info(f"配置文件更新,重新加载 | {change=}") + await config.load_json() + + +@asynccontextmanager +async def life_span(_: FastAPI) -> AsyncGenerator[None, None]: + logger.info("初始化 SQLite 数据库中...") + create_db_and_tables() + logger.info("启动配置文件监听任务...") + watcher_task = asyncio.create_task(toml_watcher(), name="toml_watcher") + logger.success("🌟 NJUPT API Suan 已经启动。") + try: + yield + finally: + logger.info("🌙 NJUPT API Suan 正在关闭。") + watcher_task.cancel() + logger.info("配置文件监听任务已结束。") + + +@asynccontextmanager +async def jwxt(username: str, password: str) -> AsyncGenerator[ZhengFang, None]: + zf = ZhengFang() + await zf.start() + await zf.login(username, password) + yield zf + await zf.close() + return + + +app = FastAPI(lifespan=combine_lifespans(life_span, mcp_app.lifespan)) + + +@app.middleware("http") +async def log_requests(request: Request, call_next: Callable) -> None: + # 忽略对路径 /mcp 好 /mcp/ 的日志记录 + if request.url.path == "/mcp" or request.url.path == "/mcp/": + return await call_next(request) + + # 如有需要,忽略对路径 /assets/ 的日志记录 + if not config.get("log", "log_assets_request", False) and request.url.path.startswith("/assets/"): + return await call_next(request) + + # 请求开始时间 + start_time = time.time() + + # 获取请求信息 + client_host = request.client.host if request.client else "unknown" + method = request.method + path = request.url.path + + # 记录请求开始 + logger.debug(f"访问 [{client_host}] {method} {path}") + if config.get("log", "log_api_request_details", False): + logger.debug(f" - {request.headers=}") + logger.debug(f" - {request.path_params=}") + logger.debug(f" - {request.query_params=}") + logger.debug(f" - request.body={await request.body()}") + + # 处理请求 + try: + response = await call_next(request) + + # 计算耗时 + process_time = (time.time() - start_time) * 1000 + + # 记录响应 + if response.status_code < 400: + logger.info( + f"成功 [{client_host}] {method} {path} | Status: {response.status_code} | Time: {process_time:.2f}ms", + ) + if config.get("log", "log_api_request_details", False): + logger.debug(f" - {response}") + else: + logger.warning( + f"警告 [{client_host}] {method} {path} | Status: {response.status_code} | Time: {process_time:.2f}ms", + ) + + # 可以添加响应头(可选) + response.headers["X-Process-Time"] = str(process_time) + return response + + except Exception as exc: + process_time = (time.time() - start_time) * 1000 + logger.error( + f"错误 [{client_host}] {method} {path} | Error: {exc} | Time: {process_time:.2f}ms", + ) + raise + + +@app.websocket("/ws/logs") +async def ws_logs(websocket: WebSocket) -> None: + """向 WebUI 传递 Suan API 日志""" # noqa: DOC501 + await websocket.accept() + logger.debug("日志 websocket 建立连接,日志将被推送到 WebUI。") + last_sent_id = 0 + + try: + while True: + new_logs = [log for log in log_buffer if log.id >= last_sent_id] + if new_logs: + try: + await asyncio.wait_for( + websocket.send_json( + { + "logs": [log_record_serialize(log) for log in new_logs], # 新日志 + "latest_id": LogRecord.log_counter, + }, + ), + timeout=2.0, + ) + except asyncio.TimeoutError: + logger.debug("日志 websocket 发送超时,已自动退出。") + break + last_sent_id = LogRecord.log_counter + + try: + await asyncio.sleep(0.5) + except asyncio.CancelledError: + raise + except WebSocketDisconnect: + logger.debug("日志 websocket 断开连接。") + except asyncio.CancelledError: + logger.debug("日志 websocket 任务取消。") + except Exception as exc: + logger.error(f"向 WebUI 传递日志时遇到异常 : {exc}") + finally: + with suppress(Exception): + await websocket.close() + + +@app.get("/version") +async def get_version() -> ReturnDto: + ver = {"version": __version__} + return ReturnDto(success=True, result=ver) + + +app.include_router(api_router) +app.include_router(admin_router) +app.include_router(webui_router) +app.mount("/mcp", mcp_app) +app.mount( + "/assets", + StaticFiles(directory=Path.cwd() / "webui" / "dist" / "assets"), + name="assets", +) + +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + expose_headers=["mcp-session-id"], # 关键:必须显式暴露此 Header +) + + +@app.exception_handler(Exception) +def general_exception_handler(_: Request, exc: Exception) -> JSONResponse: + # 记录错误堆栈 + logger.error( + f"未捕获的异常!这可能表示某些路由、端点已经不可用!\n{exc!s}\n{traceback.format_exc()}", + ) + + return JSONResponse( + status_code=500, + content={ + "code": 500, + "message": "服务器内部错误", + "data": None, # 生产环境不要暴露详细错误信息 + }, + )