diff --git a/main.py b/main.py
new file mode 100644
index 0000000..e95ccfa
--- /dev/null
+++ b/main.py
@@ -0,0 +1,86 @@
+from pathlib import Path
+from secrets import token_urlsafe
+
+from njupt_api.baselib import config, logger
+from router import __version__
+
+DATA_DIR = Path.cwd() / "data"
+TEMP_DIR = Path.cwd() / "temp"
+
+
+if __name__ == "__main__":
+ try:
+ with open(file=Path.cwd() / "njupt_api" / "art.txt", mode="r", encoding="utf-8") as f:
+ print(f.read().format(__version__)) # noqa:T201
+ except FileNotFoundError:
+ pass
+
+ import uvicorn
+
+ logger.success("Ciallo~(∠·ω< )⌒★")
+
+ # 创建需要的工作目录
+ if DATA_DIR.exists() and DATA_DIR.is_dir():
+ logger.debug(f"工作目录 {DATA_DIR=!s} 已存在。")
+ else:
+ DATA_DIR.mkdir(parents=True, exist_ok=True)
+ logger.debug(f"工作目录 {DATA_DIR=!s} 已创建。")
+
+ if TEMP_DIR.exists() and TEMP_DIR.is_dir():
+ logger.debug(f"工作目录 {TEMP_DIR=!s} 已存在。")
+ # 清空 temp 工作目录中的已有文件
+ c = 0
+ for item in TEMP_DIR.iterdir():
+ item.unlink()
+ c += 1
+ logger.debug(f"清理了 temp 工作目录中的 {c} 个已有文件。")
+ else:
+ TEMP_DIR.mkdir(parents=True, exist_ok=True)
+ logger.debug(f"工作目录 {TEMP_DIR=!s} 已创建。")
+
+ # 如果没有 toml 配置文件就创建
+ try:
+ config.sync_load_json()
+ except FileNotFoundError:
+ config.init_config()
+ config.sync_create_json()
+
+ lines = [
+ "",
+ "🌐 WebUI 管理面板将运行在 /webui 端点下,登录需要令牌。",
+ "",
+ "============================================================",
+ "🔐 管理后端令牌",
+ "============================================================",
+ "",
+ ]
+ try:
+ with open(file=DATA_DIR / "token.txt", mode="r", encoding="utf-8") as f:
+ token = f.readline().strip()
+ lines.insert(-2, f"🔐 使用已经存在的管理后端令牌 | {token}")
+ except FileNotFoundError:
+ token = token_urlsafe(32)
+ with open(file=DATA_DIR / "token.txt", mode="w", encoding="utf-8") as f:
+ f.write(token)
+ lines.insert(-2, f"🔐 新的管理后端令牌已生成 | {token}")
+ lines.insert(-2, "🔐 你需要此令牌以登录管理员面板并可视化地配置 NJUPT Suan API。")
+ logger.info("\n".join(lines))
+
+ logger.info("🥭 准备 uvicorn run ...")
+ host = config.get("system", "host", "0.0.0.0")
+ port = config.get("system", "port", 8000)
+ reload = config.get("system", "reload", True)
+ logger.debug(f"启动参数 - {host=} | {port=} | {reload=}")
+ logger.debug("这些参数无法自动热重载。如果你修改了他们,请 Ctrl + C 关闭并重新启动 Suan API。")
+ uvicorn.run(
+ "server:app",
+ host=host,
+ port=port,
+ reload=reload,
+ reload_dirs=["njupt_api", "router"],
+ access_log=False,
+ log_level="critical",
+ timeout_graceful_shutdown=2,
+ )
+ logger.debug("退出时未清理 temp 工作目录,将在 Suan API 下次启动时清理。")
+ logger.info("🥭 uvicorn run 已结束。")
diff --git a/njupt_api/__init__.py b/njupt_api/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/njupt_api/art.txt b/njupt_api/art.txt
new file mode 100644
index 0000000..21a8612
--- /dev/null
+++ b/njupt_api/art.txt
@@ -0,0 +1,10 @@
+
+ ________ ___ ___ ___ ________ _________ ________ ___ ___ ________ ________ ________ ________ ___
+|\ ___ \ |\ \|\ \|\ \|\ __ \|\___ ___\ |\ ____\|\ \|\ \|\ __ \|\ ___ \ |\ __ \|\ __ \|\ \
+\ \ \\ \ \ \ \ \ \ \\\ \ \ \|\ \|___ \ \_| \ \ \___|\ \ \\\ \ \ \|\ \ \ \\ \ \ \ \ \|\ \ \ \|\ \ \ \
+ \ \ \\ \ \ __ \ \ \ \ \\\ \ \ ____\ \ \ \ \ \_____ \ \ \\\ \ \ __ \ \ \\ \ \ \ \ __ \ \ ____\ \ \
+ \ \ \\ \ \|\ \\_\ \ \ \\\ \ \ \___| \ \ \ \|____|\ \ \ \\\ \ \ \ \ \ \ \\ \ \ \ \ \ \ \ \ \___|\ \ \
+ \ \__\\ \__\ \________\ \_______\ \__\ \ \__\ ____\_\ \ \_______\ \__\ \__\ \__\\ \__\ \ \__\ \__\ \__\ \ \__\
+ \|__| \|__|\|________|\|_______|\|__| \|__| |\_________\|_______|\|__|\|__|\|__| \|__| \|__|\|__|\|__| \|__|
+ \|_________|
+ => NJUPT Suan API (v.{}) | Made with 💗 by MangoFanFanw | Powered by FastAPI, FastMCP, Vue and more ~
diff --git a/njupt_api/baselib/__init__.py b/njupt_api/baselib/__init__.py
new file mode 100644
index 0000000..e9d2aec
--- /dev/null
+++ b/njupt_api/baselib/__init__.py
@@ -0,0 +1,14 @@
+from .config import config
+from .logger import LogRecord, log_buffer, log_record_serialize, logger
+from .mcploggingmiddleware import LoggingMiddleware
+from .playcontextmanager import PlayContextManager
+
+__all__ = [
+ config,
+ LogRecord,
+ log_buffer,
+ log_record_serialize,
+ logger,
+ LoggingMiddleware,
+ PlayContextManager,
+]
diff --git a/njupt_api/baselib/config.py b/njupt_api/baselib/config.py
new file mode 100644
index 0000000..5b5daa2
--- /dev/null
+++ b/njupt_api/baselib/config.py
@@ -0,0 +1,115 @@
+from json import dumps, loads
+from pathlib import Path
+from typing import TypeVar
+
+import aiofiles
+
+from .logger import logger
+
+CONFIG_PATH = Path.cwd() / "data" / "config.json"
+
+
+T = TypeVar("T")
+
+
+class Config:
+ def __init__(self) -> None:
+ self._doc = {}
+
+ async def load_json(self) -> None:
+ """
+ 从 Toml 配置文件中读取配置。
+ """
+ logger.debug("异步读取配置文件。")
+ async with aiofiles.open(file=CONFIG_PATH, mode="r") as f:
+ self._doc = loads(await f.read())
+
+ def sync_load_json(self) -> None:
+ """
+ 同步读取配置文件,仅限于 main.py 中启动时。
+
+ Raises:
+ FileNotFoundError: 配置文件不存在。
+ """
+ logger.debug("同步读取配置文件。")
+ try:
+ with open(file=CONFIG_PATH, mode="r") as f:
+ self._doc = loads(f.read())
+ except FileNotFoundError:
+ logger.warning("FileNotFoundError - 配置文件不存在。")
+ raise
+
+ def sync_create_json(self) -> None:
+ """
+ 同步创建配置文件。
+ """
+ logger.debug("同步创建配置文件。")
+ with open(file=CONFIG_PATH, mode="w") as f:
+ f.write(dumps(self._doc))
+
+ def init_config(self) -> None:
+ """
+ 重新初始化 Toml 配置文件。这会重置所有配置。
+ """
+ logger.warning("初始化配置文件,这会重置所有配置。")
+ self._doc.clear()
+ doc_system = {}
+ doc_schedule = {}
+ doc_log = {}
+
+ doc_system["host"] = "0.0.0.0"
+ doc_system["port"] = 8000
+ doc_system["reload"] = True
+
+ doc_schedule["jwxt_login_method"] = "sso"
+ doc_schedule["semester_start_date"] = "2026-03-02"
+ doc_schedule["schedule_title_template"] = "芒果酸的第 {title} 周课程表"
+ doc_schedule["schedule_subtitle_template"] = "我也要上吗?"
+
+ doc_log["log_api_request_details"] = False
+ doc_log["log_mcp_request_details"] = False
+ doc_log["log_assets_request"] = False
+
+ self._doc["system"] = doc_system
+ self._doc["schedule"] = doc_schedule
+ self._doc["log"] = doc_log
+
+ async def save_json(self) -> None:
+ """
+ 异步保存 Toml 配置文件。
+ """
+ logger.debug("异步保存配置文件。")
+ async with aiofiles.open(file=CONFIG_PATH, mode="w") as f:
+ await f.write(dumps(self._doc, indent=4))
+
+ def get(self, group: str, option: str, default: T) -> T:
+ """
+ 获取配置项的值。
+
+ Args:
+ group: Table
+ option: Key
+ default: 默认值
+
+ Returns:
+ Any,与 default 参数类型相同。
+ """
+ try:
+ return self._doc.get(group).get(option)
+ except AttributeError:
+ return default
+
+ def to_dict(self) -> dict:
+ return self._doc
+
+ def from_dict(self, data: dict) -> None:
+ self._doc.clear()
+ for key, value in data.items():
+ if isinstance(value, dict):
+ t_table = {}
+ for k, v in value.items():
+ t_table[k] = v
+ self._doc[key] = t_table
+
+
+config = Config()
diff --git a/njupt_api/baselib/logger.py b/njupt_api/baselib/logger.py
new file mode 100644
index 0000000..c41b665
--- /dev/null
+++ b/njupt_api/baselib/logger.py
@@ -0,0 +1,42 @@
+import sys
+from collections import deque
+
+from loguru import logger
+
+logger.remove()
+logger.add(
+ sys.stdout,
+ level="DEBUG",
+ colorize=True,
+)
+logger.add("data/app.log", rotation="10 MB", retention="7 days") # 文件日志
+
+log_buffer = deque(maxlen=1000)
+
+
+class LogRecord:
+ log_counter = 0
+
+ def __init__(self, message: str) -> None:
+ self.id = LogRecord.log_counter
+ self.message = message
+
+ LogRecord.log_counter += 1
+
+
+def log_record_serialize(record: LogRecord) -> dict:
+ return {
+ "id": record.id,
+ "message": record.message,
+ }
+
+
+def memory_sink(message: str) -> None:
+ """向自定义缓冲区写入日志,供 WebUI 获取
+ :param message: 'loguru._handler.Message'
+ """
+ log_entry = LogRecord(message=message)
+ log_buffer.append(log_entry)
+
+
+logger.add(sink=memory_sink, level="DEBUG", colorize=True)
diff --git a/njupt_api/baselib/mcploggingmiddleware.py b/njupt_api/baselib/mcploggingmiddleware.py
new file mode 100644
index 0000000..ce58c06
--- /dev/null
+++ b/njupt_api/baselib/mcploggingmiddleware.py
@@ -0,0 +1,36 @@
+import time
+
+from fastmcp.server.middleware import CallNext, MiddlewareContext
+from fastmcp.server.middleware.middleware import Middleware
+from fastmcp.tools import ToolResult
+from mcp import types as mt
+
+from . import config
+from .logger import logger
+
+
+class LoggingMiddleware(Middleware):
+ async def on_call_tool(
+ self,
+ context: MiddlewareContext[mt.CallToolRequestParams],
+ call_next: CallNext[mt.CallToolRequestParams, ToolResult],
+ ) -> ToolResult:
+ tool_name = context.message.name
+ args = context.message.arguments
+
+ start_time = time.time()
+ logger.debug(f"MCP → 调用工具: {tool_name}")
+ if config.get("log", "log_mcp_request_details", False):
+ logger.debug(f"调用参数 - {args=}")
+
+ try:
+ result = await call_next(context)
+ elapsed = time.time() - start_time
+ logger.info(f"MCP ← 工具 {tool_name} 完成, 耗时: {elapsed:.3f}s")
+ return result
+ except Exception as e:
+ elapsed = time.time() - start_time
+ logger.error(
+ f"MCP ✗ 工具 {tool_name} 失败, 耗时: {elapsed:.3f}s, 错误: {e}",
+ )
+ raise
diff --git a/njupt_api/baselib/playcontextmanager.py b/njupt_api/baselib/playcontextmanager.py
new file mode 100644
index 0000000..a5a4e15
--- /dev/null
+++ b/njupt_api/baselib/playcontextmanager.py
@@ -0,0 +1,56 @@
+from playwright.async_api import (
+ Browser,
+ BrowserContext,
+ Page,
+ Playwright,
+ async_playwright,
+)
+
+
+class PlayContextManager:
+ def __init__(
+ self,
+ playwright: Playwright = None,
+ browser: Browser = None,
+ context: BrowserContext = None,
+ page: Page = None,
+ ) -> None:
+ self.playwright = playwright
+ self.browser = browser
+ self.context = context
+ self.page = page
+
+ self.isLogin = False
+
+ async def start(self) -> None:
+ """手动启动"""
+ self.playwright = await async_playwright().start() # 不是 __enter__
+ self.browser = await self.playwright.chromium.launch(
+ headless=False,
+ args=[
+ "--disable-blink-features=AutomationControlled",
+ "--no-sandbox",
+ "--disable-setuid-sandbox",
+ "--disable-dev-shm-usage",
+ "--disable-gpu",
+ "--no-proxy-server",
+ ],
+ )
+ self.context = await self.browser.new_context()
+ self.page = await self.context.new_page()
+
+ async def __aenter__(self) -> "PlayContextManager":
+ await self.start()
+ return self
+
+ async def close(self) -> None:
+ """手动关闭"""
+ if self.context:
+ await self.context.close()
+ if self.browser:
+ await self.browser.close()
+ if self.playwright:
+ await self.playwright.stop() # 不是 __exit__
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: # noqa: ANN001
+ await self.close()
diff --git a/njupt_api/zhengfang/__init__.py b/njupt_api/zhengfang/__init__.py
new file mode 100644
index 0000000..166a3f0
--- /dev/null
+++ b/njupt_api/zhengfang/__init__.py
@@ -0,0 +1,13 @@
+from .createcourse import create_course_schedule
+from .sso import SSO
+from .types import Course, course_dict_serializer, course_list_serializer
+from .zhengfang import ZhengFang
+
+__all__ = [
+ create_course_schedule,
+ SSO,
+ Course,
+ course_dict_serializer,
+ course_list_serializer,
+ ZhengFang,
+]
diff --git a/njupt_api/zhengfang/createcourse.py b/njupt_api/zhengfang/createcourse.py
new file mode 100644
index 0000000..1d2b0de
--- /dev/null
+++ b/njupt_api/zhengfang/createcourse.py
@@ -0,0 +1,315 @@
+import re
+
+from bs4 import BeautifulSoup
+
+from .types import Course
+
+
+def normalize_course_str(course_str: str) -> str:
+ """
+ 规范化课程字符串,确保 create_course 能正确解析。
+
+ Returns:
+ 字符串。
+ """
+ parts = course_str.split("
")
+ while parts and parts[0] == "":
+ parts.pop(0)
+ while len(parts) < 4:
+ parts.append(" ")
+ for i in range(2, 4):
+ if parts[i] == "":
+ parts[i] = " "
+ return "
".join(parts)
+
+
+def create_course_schedule(html: str) -> list[Course]:
+ """解析给定 HTML 字符串,返回包含数个 Course 对象的列表。
+ Args:
+ html: HTML 字符串。应该有且只有一个
标签,其中是课程表数据。
+
+ Returns:
+ list[Course]
+
+ """
+ soup = BeautifulSoup(html, "html.parser")
+ table = soup.find("table")
+ rows = table.find_all("tr")
+
+ courses: list[Course] = []
+ rowspan_map: dict[int, int] = {}
+
+ # 解析第一行表头,建立列索引到星期几的映射
+ # 表头格式:第1列是"时间"(colspan=2),然后是 星期一 到 星期日
+ day_map: dict[int, int] = {} # col_idx -> day (1-7)
+ if rows:
+ header_cells = rows[0].find_all(["td", "th"])
+ col_idx = 0
+ for cell in header_cells:
+ text = cell.get_text(strip=True)
+ colspan = int(cell.get("colspan", 1))
+
+ # 跳过"时间"单元格
+ if text != "时间":
+ # 映射星期几到数字
+ day_mapping = {
+ "星期一": 1,
+ "星期二": 2,
+ "星期三": 3,
+ "星期四": 4,
+ "星期五": 5,
+ "星期六": 6,
+ "星期日": 7,
+ "星期天": 7,
+ }
+ day = day_mapping.get(text)
+ if day is not None:
+ for c in range(col_idx, col_idx + colspan):
+ day_map[c] = day
+
+ col_idx += colspan
+
+ for row_idx, row in enumerate(rows):
+ if row_idx == 0:
+ continue
+
+ cells = row.find_all(["td", "th"])
+ col_idx = 0
+ class_start: int | None = None
+
+ for cell in cells:
+ while col_idx in rowspan_map and rowspan_map[col_idx] > 0:
+ rowspan_map[col_idx] -= 1
+ if rowspan_map[col_idx] == 0:
+ del rowspan_map[col_idx]
+ col_idx += 1
+
+ text = cell.get_text(strip=True)
+ colspan = int(cell.get("colspan", 1))
+ rowspan = int(cell.get("rowspan", 1))
+
+ if text.startswith("第") and text.endswith("节"):
+ class_start = int(text[1:-1])
+ if rowspan > 1:
+ for c in range(col_idx, col_idx + colspan):
+ rowspan_map[c] = rowspan - 1
+ col_idx += colspan
+ continue
+
+ if text in ("早晨", "上午", "下午", "晚上"):
+ if rowspan > 1:
+ for c in range(col_idx, col_idx + colspan):
+ rowspan_map[c] = rowspan - 1
+ col_idx += colspan
+ continue
+
+ td_str = str(cell)
+ start = td_str.find(">") + 1
+ end = td_str.rfind("")
+ inner_html = td_str[start:end]
+
+ if " " not in inner_html and inner_html.strip():
+ inner_html = re.sub(r"
", "
", inner_html)
+ course_strs = [
+ s.strip() for s in re.split(r"(?:
){2,}", inner_html) if s.strip() and " " not in s
+ ]
+ # 获取当前列对应的星期几
+ day = day_map.get(col_idx, 1) # 默认为1(星期一)
+ for course_str in course_strs:
+ course_str = normalize_course_str(course_str)
+ courses.append(
+ create_course(
+ course_str,
+ day,
+ default_classes_start=class_start,
+ ),
+ )
+
+ if rowspan > 1:
+ for c in range(col_idx, col_idx + colspan):
+ rowspan_map[c] = rowspan - 1
+
+ col_idx += colspan
+
+ return courses
+
+
+def create_course(
+ raw: str,
+ day: int,
+ default_classes_start: int | None = None,
+) -> Course:
+ """根据从 HTML 中提取出的原字符串解析课程信息
+ Args:
+ raw: 原字符串,以
作为换行符
+ day: 周内的星期几
+ default_classes_start: 如果没有解析出课程的 classes,则使用此参数。
+ 此参数应当从表格的行标题解析。
+
+ Returns:
+ Course
+
+ """
+ # 0 1 2 3 4
+ # ['概率论与数理统计', '1-17单(1,2)', '王雪红', '教3-520', '']
+ raw_list = raw.split("
")
+
+ # 首先去除列表头部的所有空字符串
+ while True:
+ if raw_list[0] == "":
+ raw_list.pop(0)
+ else:
+ break
+
+ # 对于大部分课程,raw_list[1] 都是形如以下格式
+ # 1-17(3,4)
+ # 1-17单(1,2) *(也可能是双)
+ # 2节/周
+ # 2节/单周 *(也可能是双)
+ # 周三第3,4节{第1-17周}
+ # 周五第3,4节{第2-16周|双周}
+ raw_time = raw_list[1]
+ weeks = []
+ classes = []
+ single = False # 内部变量
+ double = False # 内部变量
+ # 处理前两种形式
+ if "-" in raw_time and "第" not in raw_time:
+ # 也可能是 '1-17单'
+ t = raw_time.split("(") # ['1-17', '3-4)']
+ # 也可能是 '17单'
+ start, end = t[0].split("-") # ['1', '17']
+ if end.endswith("单"):
+ end = end[:-1]
+ single = True
+ elif end.endswith("双"):
+ end = end[:-1]
+ double = True
+ for i in range(int(start), int(end) + 1):
+ if single and i % 2 == 0:
+ continue
+ if double and i % 2 == 1:
+ continue
+ weeks.append(i)
+ raw_classes = t[1].removesuffix(")")
+ classes = [int(i) for i in raw_classes.split(",")]
+ # 处理中两种形式
+ elif "/" in raw_time:
+ # 默认学期 1-16 周
+ if "/单周" in raw_time:
+ single = True
+ elif "/双周" in raw_time:
+ double = True
+ for i in range(1, 17):
+ if single and i % 2 == 0:
+ continue
+ if double and i % 2 == 1:
+ continue
+ weeks.append(i)
+
+ # 获取多少节课
+ t_num = int(raw_time.split("节")[0])
+ for i in range(0, t_num):
+ classes.append(default_classes_start + i)
+ # 处理后两种形式
+ elif "第" in raw_time:
+ # '周三', '3,4节{', '1-17周}'
+ # '周五', '3,4节{', '2-16周|双周}'
+ u = raw_time.split("第")
+ classes = [int(u_c) for u_c in u[1].split("节")[0].split(",")]
+
+ # '1-17', '}'
+ # '2-16', '|双', '}'
+ u_w = u[2].split("周")
+ if "单" in u_w[1]:
+ single = True
+ elif "双" in u_w[1]:
+ double = True
+ u_start, u_end = u_w[0].split("-")
+ for i in range(int(u_start), int(u_end) + 1):
+ if single and i % 2 == 0:
+ continue
+ if double and i % 2 == 1:
+ continue
+ weeks.append(i)
+
+ teacher = raw_list[2] if raw_list[2] != " " else None
+ classroom = raw_list[3] if raw_list[3] != " " else None
+
+ return Course(raw_list[0], weeks, day, classes, teacher, classroom)
+
+
+def convert_dict_schedule_to_tuple(schedule: list[dict]) -> list[tuple]:
+ """将字典格式的课表转换为压缩的元组格式。
+
+ Args:
+ schedule: list[dict],标准格式的课程数据
+
+ Returns:
+ list[tuple]: 压缩后的元组格式 (name, teacher, classroom, weeks_str, day, classes)
+ 其中 weeks 尽量压缩为字符串格式(如 "1-17")
+
+ """
+ result = []
+ for course in schedule:
+ name = course.get("name", "")
+ teacher = course.get("teacher")
+ classroom = course.get("classroom")
+ weeks = course.get("weeks", [])
+ day = course.get("day", 1)
+ classes = course.get("classes", [])
+
+ # 压缩 weeks 为字符串
+ weeks_str = compress_weeks_to_string(weeks) if weeks else ""
+
+ result.append((name, teacher, classroom, weeks_str, day, classes))
+
+ return result
+
+
+def compress_weeks_to_string(weeks: list[int]) -> str:
+ """将周数列表压缩为最短的字符串表示。
+
+ 例如:
+ [1,2,3,4,5] -> "1-5"
+ [1,3,5,7] -> "1,3,5,7"
+ [1,2,3,5,6,7,8] -> "1-3,5-8"
+ [1] -> "1"
+
+ Args:
+ weeks: 周数列表
+
+ Returns:
+ str: 压缩后的周数字符串
+
+ """
+ if not weeks:
+ return ""
+
+ # 去重并排序
+ weeks = sorted({int(w) for w in weeks})
+
+ ranges = []
+ start = end = weeks[0]
+
+ for w in weeks[1:]:
+ if w == end + 1:
+ # 连续,扩展当前范围
+ end = w
+ else:
+ # 不连续,保存当前范围,开始新范围
+ ranges.append((start, end))
+ start = end = w
+
+ # 保存最后一个范围
+ ranges.append((start, end))
+
+ # 格式化为字符串
+ parts = []
+ for start, end in ranges:
+ if start == end:
+ parts.append(str(start))
+ else:
+ parts.append(f"{start}-{end}")
+
+ return ",".join(parts)
diff --git a/njupt_api/zhengfang/sso.py b/njupt_api/zhengfang/sso.py
new file mode 100644
index 0000000..3f5cf25
--- /dev/null
+++ b/njupt_api/zhengfang/sso.py
@@ -0,0 +1,38 @@
+from njupt_api.baselib import PlayContextManager, logger
+
+
+class SSO(PlayContextManager):
+ def __init__(self) -> None:
+ super().__init__()
+
+ async def login(self, username: str, password: str) -> bool:
+ """使用用户名和密码实现登录南邮统一身份验证。
+
+ Parameters:
+ username: 用户名,学号,一般为一位大写字母+八位数字
+ password: 密码
+
+ Returns:
+ bool,表明判登录是否成功。
+ """
+ await self.page.goto("http://i.njupt.edu.cn/")
+
+ await self.page.fill('input[name="username"]', username)
+ await self.page.fill('input[type="password"]', password)
+ await self.page.click('button[type="button"]')
+
+ await self.page.wait_for_load_state("networkidle")
+ if "user-login" in self.page.url:
+ logger.error(f"{username} | 登录失败,请检查学号和密码是否正确。")
+ return False
+
+ logger.info(f"{username} | 登录南邮统一身份认证成功。")
+ self.isLogin = True
+ return True
+
+ async def goto_zf(self) -> None:
+ sub_frame = self.page.frame_locator('iframe[name="iframe0"]')
+ async with self.context.expect_event("page") as new_page_event:
+ await sub_frame.locator('a[title="教务系统"]').click()
+ self.page = await new_page_event.value
+ return
diff --git a/njupt_api/zhengfang/types.py b/njupt_api/zhengfang/types.py
new file mode 100644
index 0000000..d0b93e6
--- /dev/null
+++ b/njupt_api/zhengfang/types.py
@@ -0,0 +1,42 @@
+from dataclasses import dataclass
+
+
+@dataclass
+class Course:
+ """Course 是对课程表中的 **某一节课** 的抽象。
+
+ Examples:
+ 1-17周,星期一,1-2节,数据结构,是一个 Course 对象;
+
+ 1-17周,星期三,3-4节,数据结构,是另一个 Course 对象;
+
+ 1-17周中的单周,星期四,3-4节,英语,是一个 Course 对象;
+
+ 1-17周中的双周,星期四,3-4节,物理,是另一个 Course 对象。
+
+ """
+
+ name: str
+ weeks: list[int]
+ day: int
+ classes: list[int]
+ teacher: str | None
+ classroom: str | None
+
+
+def course_dict_serializer(course: Course) -> dict[str, str | list[int] | int | None]:
+ return {
+ "name": course.name,
+ "weeks": course.weeks,
+ "day": course.day,
+ "classes": course.classes,
+ "teacher": course.teacher,
+ "classroom": course.classroom,
+ }
+
+
+def course_list_serializer(course_list: list[Course]) -> list[dict]:
+ final_list = []
+ for course in course_list:
+ final_list.append(course_dict_serializer(course))
+ return final_list
diff --git a/njupt_api/zhengfang/zhengfang.py b/njupt_api/zhengfang/zhengfang.py
new file mode 100644
index 0000000..143af8e
--- /dev/null
+++ b/njupt_api/zhengfang/zhengfang.py
@@ -0,0 +1,79 @@
+from ddddocr import DdddOcr
+from playwright.async_api import Browser, BrowserContext, Page, Playwright
+
+from njupt_api.baselib import PlayContextManager, logger
+from njupt_api.zhengfang import Course
+from njupt_api.zhengfang.createcourse import create_course_schedule
+from njupt_api.zhengfang.sso import SSO
+
+
+class ZhengFang(PlayContextManager):
+ def __init__(
+ self,
+ playwright: Playwright = None,
+ browser: Browser = None,
+ context: BrowserContext = None,
+ page: Page = None,
+ ) -> None:
+ super().__init__(playwright, browser, context, page)
+
+ @classmethod
+ async def init_from_sso(cls, sso: SSO) -> "ZhengFang":
+ await sso.goto_zf()
+ logger.info("从 SSO 进入正方教务系统。")
+ return cls(sso.playwright, sso.browser, sso.context, sso.page)
+
+ async def login(self, username: str, password: str) -> bool:
+ """
+ 使用用户名和密码实现教务系统登录。
+
+ Returns:
+ bool,表明登录是否成功。
+ """
+ await self.page.goto("http://jwxt.njupt.edu.cn")
+
+ # 填充用户名和密码
+ await self.page.fill("input#txtUserName", username)
+ await self.page.fill("input#TextBox2", password)
+
+ # 处理验证码
+ captcha_img = self.page.locator("img#icode")
+ captcha_bytes = await captcha_img.screenshot()
+ ocr = DdddOcr(show_ad=False)
+ captcha_code = str(ocr.classification(captcha_bytes))
+ logger.debug(f"识别到的验证码为: {captcha_code}")
+ await self.page.fill("input#txtSecretCode", captcha_code)
+
+ async with self.page.expect_event("dialog", timeout=3000) as dialog_info:
+ await self.page.click("input#Button1")
+ dialog = await dialog_info.value
+ if dialog.message == "请到信息维护中完善个人联系方式":
+ await dialog.accept()
+ logger.info(f"{username} | 登录正方教务系统成功。")
+ self.isLogin = True
+ return True
+ if "验证码" in dialog.message:
+ await dialog.accept()
+ logger.warning(f"{username} | 验证码错误,自动重试...")
+ return await self.login(username, password)
+ await dialog.accept()
+ logger.error(f"{username} | 登录失败,教务系统提示信息为: {dialog.message}")
+ return False
+
+ async def get_class_schedule(self) -> list[Course]:
+ await self.page.locator("a.top_link:has-text('公用信息')").click()
+ await self.page.locator("a:has-text('班级课表查询')").click()
+ sub_frame = self.page.frame_locator("iframe[name='zhuti']")
+ logger.debug("获取班级课表。")
+ return create_course_schedule(
+ f"{await sub_frame.locator('table#Table6').inner_html()}
",
+ )
+
+ async def get_student_schedule(self) -> list[Course]:
+ await self.page.locator("a.top_link:has-text('信息查询')").click()
+ await self.page.locator("a:has-text('学生个人课表')").click()
+ sub_frame = self.page.frame_locator("iframe[name='zhuti']")
+ logger.debug("获取个人课表。")
+ return create_course_schedule(
+ f"{await sub_frame.locator('table#Table1').inner_html()}
",
+ )
diff --git a/router/__init__.py b/router/__init__.py
new file mode 100644
index 0000000..9226fe7
--- /dev/null
+++ b/router/__init__.py
@@ -0,0 +1 @@
+from .__version__ import __version__
diff --git a/router/__version__.py b/router/__version__.py
new file mode 100644
index 0000000..478e063
--- /dev/null
+++ b/router/__version__.py
@@ -0,0 +1,6 @@
+from importlib.metadata import PackageNotFoundError, version
+
+try:
+ __version__ = version("njupt-suan-api")
+except PackageNotFoundError:
+ __version__ = "dev"
diff --git a/router/admin_router.py b/router/admin_router.py
new file mode 100644
index 0000000..caf2a2d
--- /dev/null
+++ b/router/admin_router.py
@@ -0,0 +1,123 @@
+from pathlib import Path
+from typing import Annotated, Sequence
+
+import aiofiles
+from fastapi import APIRouter, Depends
+from pydantic import BaseModel
+from sqlmodel import Session, delete, select
+
+from njupt_api.baselib import config, logger
+from njupt_api.zhengfang import ZhengFang, course_list_serializer
+from router.enhance.auth import verify_token
+from router.enhance.lib import AliasDto, ReturnDto, TestDto, get_session
+from router.enhance.model import Alias, Course
+
+
+class ValidateTokenDto(BaseModel):
+ token: str
+
+
+admin_router = APIRouter(prefix="/admin", tags=["admin"])
+
+
+@admin_router.post("/validateToken")
+async def validate_token(vtd: ValidateTokenDto) -> ReturnDto:
+ """
+ 验证 Token 是否正确,以此判断是否允许登录 WebUI。
+ 验证时无需使用 HTTP Bearer,直接作为 body 传入即可。
+
+ Returns:
+ ReturnDto,以 success 字段表明是否有效。
+ """
+ async with aiofiles.open(file=Path.cwd() / "data/token.txt", mode="r") as f:
+ if (await f.readline()).strip() == vtd.token:
+ return ReturnDto(success=True)
+ return ReturnDto(success=False)
+
+
+@admin_router.post("/schedule/test", dependencies=[Depends(verify_token)])
+async def post_schedule_test(test: TestDto, session: Annotated[Session, Depends(get_session)]) -> ReturnDto:
+ async with ZhengFang() as zf:
+ if await zf.login(test.username, test.password):
+ if test.scheduleType == "class":
+ final_course_list = course_list_serializer(
+ await zf.get_class_schedule(),
+ )
+ session.exec(delete(Course))
+ for course in final_course_list:
+ session.add(Course(**course))
+ session.commit()
+
+ logger.success(
+ f"{test.username} | 获取 {test.scheduleType} 课表成功,已保存到数据库。",
+ )
+ return ReturnDto(success=True, result=final_course_list)
+ if test.scheduleType == "student":
+ final_course_list = course_list_serializer(
+ await zf.get_student_schedule(),
+ )
+ logger.success(
+ f"{test.username} | 获取 {test.scheduleType} 课表成功。个人课表不保存。",
+ )
+ return ReturnDto(success=True, result=final_course_list)
+ logger.error(
+ f"{test.username} | scheduleType 参数错误。给定的 schedule={test.scheduleType}",
+ )
+ return ReturnDto(
+ success=False,
+ message="参数错误,请检查 scheduleType 参数。",
+ )
+ logger.error(
+ f"{test.username} | 获取课程表失败,请检查账号密码是否正确后再试。",
+ )
+ return ReturnDto(
+ success=False,
+ message="获取课程表失败,请检查账号密码是否正确后再试。",
+ )
+
+
+@admin_router.get("/schedule/test", dependencies=[Depends(verify_token)])
+async def get_schedule_test(session: Annotated[Session, Depends(get_session)]) -> ReturnDto:
+ course_dtos: Sequence[Course] = session.exec(select(Course)).all()
+ return ReturnDto(
+ success=True,
+ result=[course.model_dump() for course in course_dtos],
+ )
+
+
+@admin_router.post("/schedule/alias", dependencies=[Depends(verify_token)])
+async def post_schedule_alias(alias: AliasDto, session: Annotated[Session, Depends(get_session)]) -> ReturnDto:
+ for alia in session.exec(select(Alias)).all():
+ if alias.originalName == alia.originalName:
+ logger.error(
+ f"课程 {alia.originalName} 已经在数据库中存在,不允许重复添加。",
+ )
+ return ReturnDto(
+ success=False,
+ message=f"课程 {alia.originalName} 已经在数据库中存在,不允许重复添加。",
+ )
+
+ session.add(Alias(originalName=alias.originalName, aliasName=alias.aliasName))
+ session.commit()
+ logger.success(f"已添加课程别名 | {alias.originalName} => {alias.aliasName}")
+ return ReturnDto(success=True)
+
+
+@admin_router.get("/schedule/alias", dependencies=[Depends(verify_token)])
+async def get_schedule_alias(session: Annotated[Session, Depends(get_session)]) -> ReturnDto:
+ aliases: Sequence[Alias] = session.exec(select(Alias)).all()
+ return ReturnDto(success=True, result=[alias.model_dump() for alias in aliases])
+
+
+@admin_router.post("/config", dependencies=[Depends(verify_token)])
+async def post_config(data: dict) -> ReturnDto:
+ data_ = data.get("data")
+ logger.debug(f"接收到配置字典 - {data_}")
+ config.from_dict(data_)
+ await config.save_json()
+ return ReturnDto(success=True, result=config.to_dict())
+
+
+@admin_router.get("/config", dependencies=[Depends(verify_token)])
+async def get_config() -> ReturnDto:
+ return ReturnDto(success=True, result=config.to_dict())
diff --git a/router/api_router.py b/router/api_router.py
new file mode 100644
index 0000000..65a7849
--- /dev/null
+++ b/router/api_router.py
@@ -0,0 +1,100 @@
+from pathlib import Path
+from typing import Annotated
+
+from fastapi import APIRouter, Depends, HTTPException, status
+from fastapi.responses import FileResponse
+from sqlmodel import Session, select
+
+from njupt_api.baselib import logger
+from njupt_api.zhengfang import (
+ ZhengFang,
+ course_dict_serializer,
+ course_list_serializer,
+)
+from router.enhance.lib import ReturnDto, ScheduleQueryDto, apply_enhance, get_session
+from router.enhance.model import Course
+
+TEMP_DIR = Path.cwd() / "temp"
+
+
+api_router = APIRouter(prefix="/api", tags=["API"])
+
+
+@api_router.post("/schedule/class")
+async def post_schedule_class(
+ student: ScheduleQueryDto,
+ session: Annotated[Session, Depends(get_session)],
+) -> ReturnDto:
+ if student.username is None and student.password is None:
+ logger.debug("未提供学号和密码参数,尝试从数据库中返回一次性存储的班级课表。")
+ course_dtos = session.exec(select(Course)).all()
+ course_list: list[dict] = [course_dict_serializer(course) for course in course_dtos]
+ logger.success(f"{student.week=} 从数据库中返回一次性存储的班级课表。")
+ return await apply_enhance(course_list, student.week, student.img)
+ if student.username and student.password:
+ async with ZhengFang() as zf:
+ if await zf.login(student.username, student.password):
+ course_list = course_list_serializer(await zf.get_class_schedule())
+ logger.success(
+ f"{student.username} | {student.week=} 获取指定学生的班级课表成功。",
+ )
+ return await apply_enhance(course_list, student.week, student.img)
+ logger.error(
+ f"{student.username} | 获取课程表失败,请检查账号密码是否正确后再试。",
+ )
+ return ReturnDto(
+ success=False,
+ message="获取课程表失败,请检查账号密码是否正确后再试。",
+ )
+ else:
+ logger.error(
+ f"参数错误,请同时携带或同时不携带学号和密码参数: {student.username=} | {student.password=}",
+ )
+ return ReturnDto(
+ success=False,
+ message="参数错误,请同时携带或同时不携带学号和密码参数。",
+ )
+
+
+@api_router.post("/schedule/student")
+async def post_schedule_student(student: ScheduleQueryDto) -> ReturnDto:
+ if student.username is None or student.password is None:
+ logger.error("查询学生课表需要同时提供学号和密码参数。")
+ return ReturnDto(
+ success=False,
+ message="查询学生课表需要同时提供学号和密码参数。",
+ )
+
+ async with ZhengFang() as zf:
+ if await zf.login(student.username, student.password):
+ course_list = course_list_serializer(await zf.get_student_schedule())
+ logger.success(f"{student.username} | 获取学生个人课表成功。")
+ return await apply_enhance(course_list, student.week, student.img)
+ logger.error(
+ f"{student.username} | 获取课程表失败,请检查账号密码是否正确后再试。",
+ )
+ return ReturnDto(
+ success=False,
+ message="获取课程表失败,请检查账号密码是否正确后再试。",
+ )
+
+
+@api_router.get("/schedule/img/{name}")
+async def get_schedule_img(name: str) -> FileResponse:
+ """
+ 从 temp 工作目录中读取指定图片并返回。如果图片不存在则报 404。
+
+ Returns:
+ FileResponse: 图片。
+
+ Raises:
+ HTTPException: 404 - 查找的图片不存在。
+ """
+ image_file = TEMP_DIR / name
+ logger.debug(f"尝试获取 {image_file!s}")
+ if image_file.exists():
+ return FileResponse(path=str(image_file), media_type="image/png")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="Name wrong or too late.",
+ )
diff --git a/router/enhance/__init__.py b/router/enhance/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/router/enhance/alias.py b/router/enhance/alias.py
new file mode 100644
index 0000000..1eab5b1
--- /dev/null
+++ b/router/enhance/alias.py
@@ -0,0 +1,41 @@
+"""为课程提供一个简短的别名,便于在空间有限的课程表图片中辨认。
+
+别名是单独保存的,在最终输出阶段才会被装饰在原有的课表输出上。
+"""
+
+from typing import Sequence
+
+from sqlmodel import Session, select
+
+from njupt_api.baselib import logger
+from router.enhance.model import Alias, engine
+
+
+def apply_alias(courses: list[dict]) -> list[dict]:
+ with Session(engine) as session:
+ aliases: Sequence[Alias] = session.exec(select(Alias)).all()
+
+ # 否则不做任何更改
+ if len(aliases) == 0:
+ return courses
+
+ alias_count = 0
+ apply_count = 0
+
+ alias_dict = {}
+ for alias in aliases:
+ m = alias.model_dump()
+ alias_dict[m["originalName"]] = m["aliasName"]
+ alias_count += 1
+
+ for course in courses:
+ if course["name"] in alias_dict:
+ course["alias"] = alias_dict[course["name"]]
+ apply_count += 1
+ else:
+ course["alias"] = None
+
+ logger.debug(
+ f"课程别名 | 将 {alias_count} 个别名应用在了 {apply_count} 门输出的课程上。",
+ )
+ return courses
diff --git a/router/enhance/auth.py b/router/enhance/auth.py
new file mode 100644
index 0000000..b77fb15
--- /dev/null
+++ b/router/enhance/auth.py
@@ -0,0 +1,21 @@
+from pathlib import Path
+from typing import Annotated
+
+import aiofiles
+from fastapi import Depends, HTTPException, status
+from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
+
+security = HTTPBearer()
+
+
+async def verify_token(credentials: Annotated[HTTPAuthorizationCredentials, Depends(security)]) -> str:
+ token = credentials.credentials
+
+ async with aiofiles.open(file=Path.cwd() / "data" / "token.txt", mode="r") as f:
+ if (await f.readline()).strip() == token:
+ return token
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail="Invalid or missing token. (Suan API WebUI)",
+ headers={"WWW-Authenticate": "Bearer"},
+ )
diff --git a/router/enhance/lib.py b/router/enhance/lib.py
new file mode 100644
index 0000000..e0c510c
--- /dev/null
+++ b/router/enhance/lib.py
@@ -0,0 +1,77 @@
+from datetime import date, timedelta
+from typing import Any, Generator, Literal
+
+from pydantic import BaseModel
+from sqlmodel import Session
+
+from njupt_api.baselib import config
+
+from .alias import apply_alias
+from .model import engine
+from .screenshot import generate_img
+
+
+class ScheduleQueryDto(BaseModel):
+ username: str | None = None
+ password: str | None = None
+ week: int = 0
+ img: bool = False
+
+
+class TestDto(BaseModel):
+ username: str
+ password: str
+ scheduleType: Literal["class", "student"] # noqa: N815
+
+
+class AliasDto(BaseModel):
+ originalName: str # noqa: N815
+ aliasName: str | None # noqa: N815
+
+
+class ReturnDto(BaseModel):
+ success: bool
+ message: str | None = None
+ result: Any | None = None
+ img_url: str | None = None
+
+
+def get_session() -> Generator[Session, None, None]:
+ with Session(engine) as session:
+ yield session
+
+
+async def apply_enhance(course_list: list[dict], week: int, img: bool) -> ReturnDto:
+ """
+ 在一个方法中集成了 应用别名 和 生成课表图片 功能。此为异步方法,需要 await。
+
+ Example:
+ return await apply_enhance(course_list, week, img)
+
+ Returns:
+ 返回应用别名和图片完毕的 ReturnDto。
+ """
+ final_course_list = [course for course in course_list if week in course["weeks"]] if week > 0 else course_list
+
+ final_course_list = apply_alias(final_course_list)
+
+ # 获取课表图片设置
+ title_template = config.get("schedule", "schedule_title_template", "芒果酸的课程表")
+ subtitle_template = config.get("schedule", "schedule_subtitle_template", "")
+ semester_start_date = date.fromisoformat(config.get("schedule", "semester_start_date", "2026-03-02"))
+
+ # 可用变量
+ week_start_day = semester_start_date + timedelta(weeks=week)
+ vars_ = {
+ "week": week,
+ "week_start_day": week_start_day.isoformat(),
+ "week_end_day": (week_start_day + timedelta(days=6)).isoformat(),
+ }
+
+ img_url = None
+ if img:
+ img_url = f"http://172.28.143.24:8000/api/schedule/img/{
+ await generate_img(final_course_list, title_template.format(**vars_), subtitle_template.format(**vars_))
+ }"
+
+ return ReturnDto(success=True, result=final_course_list, img_url=img_url)
diff --git a/router/enhance/model.py b/router/enhance/model.py
new file mode 100644
index 0000000..3c66fa6
--- /dev/null
+++ b/router/enhance/model.py
@@ -0,0 +1,29 @@
+from typing import Optional
+
+from sqlalchemy import JSON, Column
+from sqlmodel import Field, SQLModel, create_engine
+
+sqlite_file_name = "data/njupt_api.db"
+sqlite_url = f"sqlite:///{sqlite_file_name}"
+
+engine = create_engine(sqlite_url, connect_args={"check_same_thread": False})
+
+
+class Course(SQLModel, table=True):
+ id: Optional[int] = Field(default=None, primary_key=True)
+ name: str
+ teacher: Optional[str] = Field(default=None, nullable=True)
+ classroom: Optional[str] = Field(default=None, nullable=True)
+ weeks: list[int] = Field(default=[], sa_column=Column(JSON))
+ day: int
+ classes: list[int] = Field(default=[], sa_column=Column(JSON))
+
+
+class Alias(SQLModel, table=True):
+ id: Optional[int] = Field(default=None, primary_key=True)
+ originalName: Optional[str] = Field(default=None, nullable=True) # noqa: N815
+ aliasName: Optional[str] = Field(default=None, nullable=True) # noqa: N815
+
+
+def create_db_and_tables() -> None:
+ SQLModel.metadata.create_all(engine)
diff --git a/router/enhance/screenshot.py b/router/enhance/screenshot.py
new file mode 100644
index 0000000..8ee40f0
--- /dev/null
+++ b/router/enhance/screenshot.py
@@ -0,0 +1,49 @@
+"""使用 Playwright 截图"""
+
+from json import dumps
+from pathlib import Path
+from urllib.parse import urlencode
+from uuid import uuid4
+
+from playwright.async_api import ViewportSize
+
+from njupt_api.baselib import PlayContextManager, logger
+
+TEMP_DIR = Path.cwd() / "temp"
+
+
+class ScreenShot(PlayContextManager):
+ def __init__(self) -> None:
+ super().__init__()
+
+ async def goto(self, url: str) -> bool:
+ await self.page.set_viewport_size(ViewportSize(width=1200, height=900))
+ res = await self.page.goto(url)
+ logger.debug(f"截图 | {res.ok=} - {url=:.{50}}...")
+ return res.ok
+
+ async def shot(self, save_path: str) -> None:
+ await self.page.mouse.move(0, 0)
+ await self.page.wait_for_load_state("networkidle")
+ await self.page.screenshot(path=save_path)
+ logger.debug(f"截图 | 截图已经保存在 {save_path=}")
+ return
+
+
+async def generate_img(courses: list[dict], title: str, subtitle: str) -> str:
+ """
+ 方法将生成课程表图片并保存在临时目录中,返回图片的完整名称。图片位于 temp 工作目录。
+
+ Returns:
+ 字符串,表明生成图片的文件名,格式为 `schedule-{uuid4()}.png`
+ """
+ t_name = f"schedule-{uuid4()}.png"
+ async with ScreenShot() as ss:
+ await ss.goto(
+ f"127.0.0.1:8000/webui/schedule#/?{
+ urlencode({'data': dumps(courses), 'title': title, 'subtitle': subtitle})
+ }",
+ )
+ await ss.shot(str(TEMP_DIR / t_name))
+ logger.debug(f"截图 | 生成临时图片 - {t_name}")
+ return t_name
diff --git a/router/enhance/week.py b/router/enhance/week.py
new file mode 100644
index 0000000..944fbe0
--- /dev/null
+++ b/router/enhance/week.py
@@ -0,0 +1,29 @@
+"""学期-星期计算"""
+
+from datetime import date, timedelta
+
+
+def get_semester_week_info(start: date, target: date) -> tuple[int, int]:
+ """
+ 给定学期开始日期(第一周周一)和另一指定日期,计算指定的日期是第几周的星期几。
+ Args:
+ start: 学期开始的日期,date
+ target: 指定日期,date
+
+ Returns:
+ 包含两个数字的元组,分别为第几周和星期几。
+ """
+ return (target - start).days // 7 + 1, target.isoweekday()
+
+
+def get_week_day_info(target: date) -> tuple[date, date]:
+ """
+ 给定一个指定日期,获取该日所在的星期的周一和周日的日期。
+ Args:
+ target: 指定日期,date
+
+ Returns:
+ 包含两个 date 的元组,分别为周一和周日。
+ """
+ weekday_int = target.weekday()
+ return target - timedelta(days=weekday_int), target + timedelta(days=6 - weekday_int)
diff --git a/router/mcp_router.py b/router/mcp_router.py
new file mode 100644
index 0000000..a2dc9c0
--- /dev/null
+++ b/router/mcp_router.py
@@ -0,0 +1,139 @@
+from pathlib import Path
+from typing import Annotated
+
+from fastmcp import FastMCP
+from fastmcp.utilities.types import Image
+from mcp.types import ToolAnnotations
+from pydantic import Field
+from sqlmodel import Session, select
+
+from njupt_api.baselib import LoggingMiddleware, logger
+from njupt_api.zhengfang import (
+ ZhengFang,
+ course_dict_serializer,
+ course_list_serializer,
+)
+from router.enhance.lib import ReturnDto, apply_enhance
+from router.enhance.model import Course, engine
+
+mcp = FastMCP("NJUPT API Suan")
+
+mcp.add_middleware(LoggingMiddleware())
+
+mcp_app = mcp.http_app("/")
+
+
+# 统一参数文档
+USERNAME_TYPE = Annotated[str, Field(description="用户名,也即学号,一般是一位字母接八位数字,字母需要大写。")]
+PASSWORD_TYPE = Annotated[str, Field(description="密码,字符串。")]
+WEEK_TYPE = Annotated[int, Field(description="获取第几周的课表,默认为 0 即获取全部。")]
+IMG_TYPE = Annotated[
+ bool,
+ Field(
+ description="是否需要同时生成图片,默认为 False。如果为 True,图片链接将在 img_url 中提供,链接有效时间为两小时。", # noqa: E501
+ ),
+]
+
+
+@mcp.tool(
+ name="tool_schedule_class",
+ title="获取默认班级课表",
+ description="获取存储在酸 API 中的默认班级课表,返回值包含 success result message 和 img_url 四个字段。",
+ annotations=ToolAnnotations(
+ title="获取默认课表",
+ readOnlyHint=True,
+ destructiveHint=False,
+ idempotentHint=True,
+ openWorldHint=False,
+ ),
+)
+async def tool_schedule_class(
+ week: WEEK_TYPE = 0,
+ img: IMG_TYPE = False,
+) -> ReturnDto:
+ with Session(engine) as session:
+ course_dtos = session.exec(select(Course)).all()
+ logger.success("从数据库中返回一次性存储的班级课表。")
+ course_list: list[dict] = [course_dict_serializer(course) for course in course_dtos]
+ return await apply_enhance(course_list, week, img)
+
+
+@mcp.tool(
+ name="tool_schedule_class_special",
+ title="获取指定学生的班级课表",
+ description="获取指定学生的班级课表。需要提供学号和密码。返回值包含 success result message 和 img_url 四个字段。",
+ annotations=ToolAnnotations(
+ title="获取指定学生的班级课表",
+ readOnlyHint=True,
+ destructiveHint=False,
+ idempotentHint=True,
+ openWorldHint=False,
+ ),
+)
+async def tool_schedule_class_special(
+ username: USERNAME_TYPE,
+ password: PASSWORD_TYPE,
+ week: WEEK_TYPE = 0,
+ img: IMG_TYPE = False,
+) -> ReturnDto:
+ async with ZhengFang() as zf:
+ if await zf.login(username, password):
+ final_course_list = course_list_serializer(await zf.get_class_schedule())
+ logger.success(f"{username} | 获取指定学生的班级课表成功。")
+ return await apply_enhance(final_course_list, week, img)
+ logger.error(f"{username} | 获取课程表失败,请检查账号密码是否正确后再试。")
+ return ReturnDto(
+ success=False,
+ message="获取课程表失败,请检查账号密码是否正确后再试。",
+ )
+
+
+@mcp.tool(
+ name="tool_schedule_student_special",
+ title="获取指定学生的个人课表",
+ description="获取指定学生的个人课表。需要提供学号和密码。返回值包含 success result message 和 img_url 四个字段。",
+ annotations=ToolAnnotations(
+ title="获取指定学生的个人课表",
+ readOnlyHint=True,
+ destructiveHint=False,
+ idempotentHint=True,
+ openWorldHint=False,
+ ),
+)
+async def tool_schedule_student_special(
+ username: USERNAME_TYPE,
+ password: PASSWORD_TYPE,
+ week: WEEK_TYPE = 0,
+ img: IMG_TYPE = False,
+) -> ReturnDto:
+ async with ZhengFang() as zf:
+ if await zf.login(username, password):
+ final_course_list = course_list_serializer(await zf.get_student_schedule())
+ logger.success(f"{username} | 获取指定学生的个人课表成功。")
+ return await apply_enhance(final_course_list, week, img)
+ logger.error(f"{username} | 获取课程表失败,请检查账号密码是否正确后再试。")
+ return ReturnDto(
+ success=False,
+ message="获取课程表失败,请检查账号密码是否正确后再试。",
+ )
+
+
+@mcp.tool(
+ name="tool_schedule_image",
+ title="直接获取课表图片",
+ description="接收使用其他课表工具得到的图片的文件名,返回 ImageContent。",
+ annotations=ToolAnnotations(
+ title="直接获取课表图片",
+ readOnlyHint=True,
+ destructiveHint=False,
+ idempotentHint=True,
+ openWorldHint=False,
+ ),
+)
+async def tool_schedule_image(
+ img_name: Annotated[str, Field(description="课表图片的文件名,形如 schedule-{uuid4}.png")],
+) -> Image | ReturnDto:
+ img_path = Path.cwd() / "temp" / img_name
+ if img_path.exists():
+ return Image(path=img_path)
+ return ReturnDto(success=False, message=f"未找到指定的课表图片 {img_name}")
diff --git a/router/webui_router.py b/router/webui_router.py
new file mode 100644
index 0000000..a88a3c1
--- /dev/null
+++ b/router/webui_router.py
@@ -0,0 +1,22 @@
+from pathlib import Path
+
+import aiofiles
+from fastapi import APIRouter
+from fastapi.responses import HTMLResponse
+
+WEBUI_INDEX = Path.cwd() / "webui" / "dist" / "index.html"
+SCHEDULE_INDEX = Path.cwd() / "webui" / "dist" / "index-schedule.html"
+
+webui_router = APIRouter(prefix="/webui")
+
+
+@webui_router.get("/", response_class=HTMLResponse)
+async def get_webui() -> HTMLResponse:
+ async with aiofiles.open(file=WEBUI_INDEX, mode="r", encoding="utf-8") as f:
+ return HTMLResponse(content=await f.read(), status_code=200)
+
+
+@webui_router.get("/schedule", response_class=HTMLResponse)
+async def get_webui_schedule() -> HTMLResponse:
+ async with aiofiles.open(file=SCHEDULE_INDEX, mode="r", encoding="utf-8") as f:
+ return HTMLResponse(content=await f.read(), status_code=200)
diff --git a/server.py b/server.py
new file mode 100644
index 0000000..fc5d90b
--- /dev/null
+++ b/server.py
@@ -0,0 +1,208 @@
+import asyncio
+import time
+import traceback
+from contextlib import asynccontextmanager, suppress
+from pathlib import Path
+from typing import AsyncGenerator, Callable
+
+from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import JSONResponse
+from fastapi.staticfiles import StaticFiles
+from fastmcp.utilities.lifespan import combine_lifespans
+from watchfiles import awatch
+
+from njupt_api.baselib import (
+ LogRecord,
+ config,
+ log_buffer,
+ log_record_serialize,
+ logger,
+)
+from njupt_api.zhengfang.zhengfang import ZhengFang
+from router import __version__
+from router.admin_router import admin_router
+from router.api_router import api_router
+from router.enhance.lib import ReturnDto
+from router.enhance.model import create_db_and_tables
+from router.mcp_router import mcp_app
+from router.webui_router import webui_router
+
+DATA_DIR = Path.cwd() / "data"
+
+
+async def toml_watcher() -> None:
+ """配置文件监听器"""
+ await config.load_json()
+ async for change in awatch(DATA_DIR / "config.json"):
+ logger.info(f"配置文件更新,重新加载 | {change=}")
+ await config.load_json()
+
+
+@asynccontextmanager
+async def life_span(_: FastAPI) -> AsyncGenerator[None, None]:
+ logger.info("初始化 SQLite 数据库中...")
+ create_db_and_tables()
+ logger.info("启动配置文件监听任务...")
+ watcher_task = asyncio.create_task(toml_watcher(), name="toml_watcher")
+ logger.success("🌟 NJUPT API Suan 已经启动。")
+ try:
+ yield
+ finally:
+ logger.info("🌙 NJUPT API Suan 正在关闭。")
+ watcher_task.cancel()
+ logger.info("配置文件监听任务已结束。")
+
+
+@asynccontextmanager
+async def jwxt(username: str, password: str) -> AsyncGenerator[ZhengFang, None]:
+ zf = ZhengFang()
+ await zf.start()
+ await zf.login(username, password)
+ yield zf
+ await zf.close()
+ return
+
+
+app = FastAPI(lifespan=combine_lifespans(life_span, mcp_app.lifespan))
+
+
+@app.middleware("http")
+async def log_requests(request: Request, call_next: Callable) -> None:
+ # 忽略对路径 /mcp 好 /mcp/ 的日志记录
+ if request.url.path == "/mcp" or request.url.path == "/mcp/":
+ return await call_next(request)
+
+ # 如有需要,忽略对路径 /assets/ 的日志记录
+ if not config.get("log", "log_assets_request", False) and request.url.path.startswith("/assets/"):
+ return await call_next(request)
+
+ # 请求开始时间
+ start_time = time.time()
+
+ # 获取请求信息
+ client_host = request.client.host if request.client else "unknown"
+ method = request.method
+ path = request.url.path
+
+ # 记录请求开始
+ logger.debug(f"访问 [{client_host}] {method} {path}")
+ if config.get("log", "log_api_request_details", False):
+ logger.debug(f" - {request.headers=}")
+ logger.debug(f" - {request.path_params=}")
+ logger.debug(f" - {request.query_params=}")
+ logger.debug(f" - request.body={await request.body()}")
+
+ # 处理请求
+ try:
+ response = await call_next(request)
+
+ # 计算耗时
+ process_time = (time.time() - start_time) * 1000
+
+ # 记录响应
+ if response.status_code < 400:
+ logger.info(
+ f"成功 [{client_host}] {method} {path} | Status: {response.status_code} | Time: {process_time:.2f}ms",
+ )
+ if config.get("log", "log_api_request_details", False):
+ logger.debug(f" - {response}")
+ else:
+ logger.warning(
+ f"警告 [{client_host}] {method} {path} | Status: {response.status_code} | Time: {process_time:.2f}ms",
+ )
+
+ # 可以添加响应头(可选)
+ response.headers["X-Process-Time"] = str(process_time)
+ return response
+
+ except Exception as exc:
+ process_time = (time.time() - start_time) * 1000
+ logger.error(
+ f"错误 [{client_host}] {method} {path} | Error: {exc} | Time: {process_time:.2f}ms",
+ )
+ raise
+
+
+@app.websocket("/ws/logs")
+async def ws_logs(websocket: WebSocket) -> None:
+ """向 WebUI 传递 Suan API 日志""" # noqa: DOC501
+ await websocket.accept()
+ logger.debug("日志 websocket 建立连接,日志将被推送到 WebUI。")
+ last_sent_id = 0
+
+ try:
+ while True:
+ new_logs = [log for log in log_buffer if log.id >= last_sent_id]
+ if new_logs:
+ try:
+ await asyncio.wait_for(
+ websocket.send_json(
+ {
+ "logs": [log_record_serialize(log) for log in new_logs], # 新日志
+ "latest_id": LogRecord.log_counter,
+ },
+ ),
+ timeout=2.0,
+ )
+ except asyncio.TimeoutError:
+ logger.debug("日志 websocket 发送超时,已自动退出。")
+ break
+ last_sent_id = LogRecord.log_counter
+
+ try:
+ await asyncio.sleep(0.5)
+ except asyncio.CancelledError:
+ raise
+ except WebSocketDisconnect:
+ logger.debug("日志 websocket 断开连接。")
+ except asyncio.CancelledError:
+ logger.debug("日志 websocket 任务取消。")
+ except Exception as exc:
+ logger.error(f"向 WebUI 传递日志时遇到异常 : {exc}")
+ finally:
+ with suppress(Exception):
+ await websocket.close()
+
+
+@app.get("/version")
+async def get_version() -> ReturnDto:
+ ver = {"version": __version__}
+ return ReturnDto(success=True, result=ver)
+
+
+app.include_router(api_router)
+app.include_router(admin_router)
+app.include_router(webui_router)
+app.mount("/mcp", mcp_app)
+app.mount(
+ "/assets",
+ StaticFiles(directory=Path.cwd() / "webui" / "dist" / "assets"),
+ name="assets",
+)
+
+app.add_middleware(
+ CORSMiddleware,
+ allow_origins=["*"],
+ allow_credentials=True,
+ allow_methods=["*"],
+ allow_headers=["*"],
+ expose_headers=["mcp-session-id"], # 关键:必须显式暴露此 Header
+)
+
+
+@app.exception_handler(Exception)
+def general_exception_handler(_: Request, exc: Exception) -> JSONResponse:
+ # 记录错误堆栈
+ logger.error(
+ f"未捕获的异常!这可能表示某些路由、端点已经不可用!\n{exc!s}\n{traceback.format_exc()}",
+ )
+
+ return JSONResponse(
+ status_code=500,
+ content={
+ "code": 500,
+ "message": "服务器内部错误",
+ "data": None, # 生产环境不要暴露详细错误信息
+ },
+ )