From 567c146fb85d40f010cfbaad24ce90ddd94b9f67 Mon Sep 17 00:00:00 2001 From: MangoFanFanw Date: Mon, 1 Jun 2026 20:45:45 +0800 Subject: [PATCH] =?UTF-8?q?feat(nyahome):=20=E6=94=AF=E6=8C=81=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B=E7=9A=84=E6=80=9D=E8=80=83=E6=A8=A1=E5=BC=8F=EF=BC=88?= =?UTF-8?q?DS=EF=BC=89=E4=B8=8E=E7=BC=96=E8=BE=91=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 增加了控制模型是否支持思考以及是否在调用时启用思考的开关,目前为 DeepSeek 适配。 WebUI 进行了同步的更新。 --- src/nyahome/database/__init__.py | 5 +- src/nyahome/database/model_aii.py | 41 +++++----- src/nyahome/database/model_story.py | 1 + src/nyahome/router/aii_router.py | 117 ++++++++++++++++++++++++---- src/nyahome/service/aii_service.py | 9 ++- src/nyahome/service/chat_service.py | 5 +- 6 files changed, 135 insertions(+), 43 deletions(-) diff --git a/src/nyahome/database/__init__.py b/src/nyahome/database/__init__.py index be564d2..f873fe0 100644 --- a/src/nyahome/database/__init__.py +++ b/src/nyahome/database/__init__.py @@ -1,7 +1,7 @@ from sqlmodel import SQLModel from .engine import engine -from .model_aii import AiiModel, AiiModelPublic, AiiProvider, AiiProviderPublic, z_aii_model, z_aii_provider +from .model_aii import AiiModel, AiiModelPublic, AiiProvider, AiiProviderPublic, AiiProviderPublicWithoutKey from .model_story import ( Chatroom, ChatroomChat, @@ -29,6 +29,7 @@ __all__ = [ AiiModelPublic, AiiProvider, AiiProviderPublic, + AiiProviderPublicWithoutKey, ChatScript, Chatroom, ChatroomChat, @@ -42,6 +43,4 @@ __all__ = [ async_get_session, create_db, get_session, - z_aii_model, - z_aii_provider, ] diff --git a/src/nyahome/database/model_aii.py b/src/nyahome/database/model_aii.py index 0515401..48b62f0 100644 --- a/src/nyahome/database/model_aii.py +++ b/src/nyahome/database/model_aii.py @@ -1,6 +1,6 @@ -from typing import Optional +from typing import Any, Optional -from pydantic import BaseModel +from pydantic import BaseModel, SerializerFunctionWrapHandler, model_serializer from sqlmodel import Field, Relationship, SQLModel @@ -16,6 +16,18 @@ class AiiProvider(SQLModel, table=True): aii_models: list["AiiModel"] = Relationship(back_populates="aii_provider") + @model_serializer(mode="wrap") + def serialize_provider(self, handler: SerializerFunctionWrapHandler) -> dict[str, Any]: + data: dict = handler(self) + data.pop("api_key", None) + return data + + +class AiiProviderPublicWithoutKey(BaseModel): + id: Optional[int] = None + name: str + base_url: str + class AiiProviderPublic(BaseModel): id: Optional[int] = None @@ -32,36 +44,27 @@ class AiiModel(SQLModel, table=True): id: Optional[int] = Field(default=None, primary_key=True) model_name: str max_context_length: int + reasonable: Optional[bool] = Field(default=None, nullable=True, description="模型是否具备思考能力") aii_provider_id: int = Field(default=None, foreign_key="aiiprovider.id") aii_provider: AiiProvider = Relationship(back_populates="aii_models") chatrooms: list["Chatroom"] = Relationship(back_populates="default_model") + @model_serializer(mode="wrap") + def serialize_model(self, handler: SerializerFunctionWrapHandler) -> dict[str, Any]: + data: dict = handler(self) + data["reasonable"] = bool(data.get("reasonable")) + return data + class AiiModelPublic(BaseModel): id: Optional[int] = None model_name: str max_context_length: int + reasonable: bool aii_provider_id: int -def z_aii_model(am: AiiModel) -> dict: - return { - "id": am.id, - "model_name": am.model_name, - "max_context_length": am.max_context_length, - "aii_provider_id": am.aii_provider_id, - } - - -def z_aii_provider(ap: AiiProvider) -> dict: - return { - "id": ap.id, - "name": ap.name, - "base_url": ap.base_url, - } - - from .model_story import Chatroom # noqa: E402 diff --git a/src/nyahome/database/model_story.py b/src/nyahome/database/model_story.py index 02c5972..ea89df4 100644 --- a/src/nyahome/database/model_story.py +++ b/src/nyahome/database/model_story.py @@ -92,6 +92,7 @@ class ChatroomChat(BaseModel): prefix: str mode: Literal["continue", "expand"] model_id: int + enable_thinking: bool class ChatroomChatAccept(BaseModel): diff --git a/src/nyahome/router/aii_router.py b/src/nyahome/router/aii_router.py index 59bd572..b5a24e2 100644 --- a/src/nyahome/router/aii_router.py +++ b/src/nyahome/router/aii_router.py @@ -1,4 +1,4 @@ -from typing import Annotated +from typing import Annotated, Sequence from fastapi import APIRouter, Depends, HTTPException from sqlalchemy.exc import NoResultFound @@ -9,10 +9,9 @@ from nyahome.database import ( AiiModelPublic, AiiProvider, AiiProviderPublic, + AiiProviderPublicWithoutKey, ModelUser, get_session, - z_aii_model, - z_aii_provider, ) from nyahome.service.aii_service import apply_get_models, s_check_remote_model, s_list_remote_provider_models @@ -23,16 +22,15 @@ aii_router = APIRouter(tags=["Aii"], prefix="/aii") @aii_router.get("/model/", name="获取模型列表") -async def get_all_model(session: Annotated[Session, Depends(get_session)]) -> ReturnDto: +async def get_all_model(session: Annotated[Session, Depends(get_session)]) -> list[dict]: """ 获取 AI 模型列表。 此接口无需用户登录即可访问。 Returns: - 被 ReturnDto 包裹的 AiiModel 列表 + AiiModel 列表 """ - final_model_list = apply_get_models(session) - return ReturnDto(result=final_model_list) + return apply_get_models(session) @aii_router.post("/model/", name="添加模型") @@ -40,18 +38,18 @@ async def add_model( model: AiiModelPublic, user: Annotated[ModelUser, Depends(verify_token)], session: Annotated[Session, Depends(get_session)], -) -> ReturnDto: +) -> AiiModel: """ 添加新的 AI 模型。需要基于已添加的模型提供商。 此接口需要管理员访问。 - 添加模型时不会进行可用性检查,因此 WebUI 在前端实现了检查按钮。此端点不会负责检查。 + 不会进行可用性检查,因此 WebUI 在前端实现了检查按钮。此端点不会负责检查。 Raises: HTTPException: 401 用户无权限管理模型(未登录或非管理员) HTTPException: 404 模型提供商不存在 Returns: - 被 ReturnDto 包裹的、添加的 AiiModel + AiiModel """ if not user.is_admin: raise HTTPException(status_code=401, detail="用户无权限管理模型。") from None @@ -69,11 +67,57 @@ async def add_model( session.add(am) session.commit() session.refresh(am) - return ReturnDto(result=z_aii_model(am)) + return am + + +@aii_router.post("/model/{id_}", name="修改模型") +async def edit_model( + id_: int, + model: AiiModelPublic, + user: Annotated[ModelUser, Depends(verify_token)], + session: Annotated[Session, Depends(get_session)], +) -> AiiModel: + """ + 修改已添加的 AI 模型。 + 此接口需要管理员访问。 + 不会进行可用性检查,因此 WebUI 在前端实现了检查按钮。此端点不会负责检查。 + **只允许修改模型的名称、最大上下文长度和是否支持思考。** + + Raises: + HTTPException: 400 模型提供商 ID 不匹配 + HTTPException: 401 用户无权限管理模型(未登录或非管理员) + HTTPException: 404 模型提供商不存在 + + Returns: + AiiModel + """ + if not user.is_admin: + raise HTTPException(status_code=401, detail="用户无权限管理模型。") from None + + try: + ap: AiiProvider = session.exec(select(AiiProvider).where(AiiProvider.id == model.aii_provider_id)).one() + except NoResultFound: + raise HTTPException(status_code=404, detail="Provider 不存在。") from None + + try: + am: AiiModel = session.exec(select(AiiModel).where(AiiModel.id == id_)).one() + except NoResultFound: + raise HTTPException(status_code=404, detail="模型不存在。") from None + + if ap.id != am.aii_provider_id: + raise HTTPException(status_code=400, detail="模型提供商 ID 不匹配。") from None + + am.model_name = model.model_name + am.max_context_length = model.max_context_length + am.reasonable = model.reasonable + session.add(am) + session.commit() + session.refresh(am) + return am @aii_router.get("/provider/", name="获取提供商列表") -async def get_all_provider(session: Annotated[Session, Depends(get_session)]) -> ReturnDto: +async def get_all_provider(session: Annotated[Session, Depends(get_session)]) -> Sequence[AiiProvider]: """ 获取 AI 模型提供商列表。 此接口无需用户登录即可访问。 @@ -81,8 +125,7 @@ async def get_all_provider(session: Annotated[Session, Depends(get_session)]) -> Returns: 被 ReturnDto 包裹的 AiiProvider 列表 """ - aii_providers = session.exec(select(AiiProvider)).all() - return ReturnDto(result=[z_aii_provider(ap) for ap in aii_providers]) + return session.exec(select(AiiProvider)).all() @aii_router.post("/provider/", name="添加提供商") @@ -90,11 +133,11 @@ async def add_provider( provider: AiiProviderPublic, user: Annotated[ModelUser, Depends(verify_token)], session: Annotated[Session, Depends(get_session)], -) -> ReturnDto: +) -> AiiProvider: """ 添加新的 AI 模型提供商。 此接口需要管理员才能访问。 - 添加提供商时不会进行可用性检查,因此 WebUI 在前端实现了检查按钮。此端点不会负责检查。 + 不会进行可用性检查,因此 WebUI 在前端实现了检查按钮。此端点不会负责检查。 Raises: HTTPException: 401 表示用户未登录或非管理员。 @@ -108,7 +151,47 @@ async def add_provider( session.add(ap) session.commit() session.refresh(ap) - return ReturnDto(result=z_aii_provider(ap)) + return ap + + +@aii_router.post("/provider/{id_}/", name="修改提供商") +async def edit_provider( + id_: int, + provider: AiiProviderPublicWithoutKey, + user: Annotated[ModelUser, Depends(verify_token)], + session: Annotated[Session, Depends(get_session)], +) -> AiiProvider: + """ + 修改 AI 模型提供商。 + 此接口需要管理员才能访问。 + 不会进行可用性检查,因此 WebUI 在前端实现了检查按钮。此端点不会负责检查。 + **只允许修改模型提供商的名称和 Base URL。** + + Raises: + HTTPException: 400 模型提供商 ID 不匹配。 + HTTPException: 401 表示用户未登录或非管理员。 + HTTPException: 404 提供商不存在。 + + Returns: + 被 ReturnDto 包裹的、添加的 AiiProvider + """ + if not user.is_admin: + raise HTTPException(status_code=401, detail="用户无权限管理模型。") from None + + if provider.id != id_: + raise HTTPException(status_code=400, detail="模型提供商 ID 不匹配。") from None + + try: + ap: AiiProvider = session.exec(select(AiiProvider).where(AiiProvider.id == id_)).one() + except NoResultFound: + raise HTTPException(status_code=404, detail="提供商不存在。") from None + + ap.name = provider.name + ap.base_url = provider.base_url + session.add(ap) + session.commit() + session.refresh(ap) + return ap @aii_router.get("/provider/{id_}/remote/models/", name="获取提供商远端模型") diff --git a/src/nyahome/service/aii_service.py b/src/nyahome/service/aii_service.py index 7277ed8..9970da3 100644 --- a/src/nyahome/service/aii_service.py +++ b/src/nyahome/service/aii_service.py @@ -1,3 +1,5 @@ +from typing import Sequence + import openai from openai import AsyncOpenAI from sqlalchemy.orm import joinedload @@ -16,17 +18,18 @@ def apply_get_models(session: Session) -> list[dict]: Returns: """ - aii_models = session.exec(select(AiiModel).options(joinedload(AiiModel.aii_provider))).all() # type: ignore[arg-type] + aii_models: Sequence[AiiModel] = session.exec(select(AiiModel).options(joinedload(AiiModel.aii_provider))).all() # type: ignore[arg-type] final_model_list = [] for aii_model in aii_models: final_model_list.append({ "id": aii_model.id, "model_name": aii_model.model_name, - "max_content_length": aii_model.max_context_length, - "provider_id": aii_model.id, + "max_context_length": aii_model.max_context_length, + "provider_id": aii_model.aii_provider_id, "provider_name": aii_model.aii_provider.name, "base_url": aii_model.aii_provider.base_url, + "reasonable": bool(aii_model.reasonable), # 数据库中的 reasonable 字段可能为 None,在这里归一为 False }) return final_model_list diff --git a/src/nyahome/service/chat_service.py b/src/nyahome/service/chat_service.py index 377842d..488970a 100644 --- a/src/nyahome/service/chat_service.py +++ b/src/nyahome/service/chat_service.py @@ -105,11 +105,12 @@ def apply_chat(id_: int, user_id: int, chat: ChatroomChat, session: Session) -> "api_key": model.aii_provider.api_key, "model_name": model.model_name, "messages": messages, + "enable_thinking": chat.enable_thinking, } async def s_start_async_streaming_chat( - base_url: str, api_key: str, model_name: str, messages: list + base_url: str, api_key: str, model_name: str, messages: list, enable_thinking: bool ) -> AsyncGenerator[str, None]: client = AsyncOpenAI(base_url=base_url, api_key=api_key) stream = await client.chat.completions.create( @@ -117,6 +118,7 @@ async def s_start_async_streaming_chat( model=model_name, stream=True, reasoning_effort="high", + extra_body={"thinking": {"type": "enabled" if enable_thinking else "disabled"}}, ) # AI 说 SSE 好喵,推荐我用 SSE 喵,我不知道喵 @@ -135,6 +137,7 @@ async def s_start_async_streaming_chat( yield f"data: {json.dumps({'text': cd, 'type': 'output'}, ensure_ascii=False)}\n\n" logger.info(f"AI 完成输出 : {aii_message}") try: + # noinspection PyUnboundLocalVariable yield f"data: {json.dumps({'type': 'usage', **chuck.usage.model_dump()})}\n\n" # type: ignore[union-attr] finally: yield "data: [DONE]\n\n"