feat(nyahome): 支持模型的思考模式(DS)与编辑模型
增加了控制模型是否支持思考以及是否在调用时启用思考的开关,目前为 DeepSeek 适配。 WebUI 进行了同步的更新。
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
from .engine import engine
|
||||
from .model_aii import AiiModel, AiiModelPublic, AiiProvider, AiiProviderPublic, z_aii_model, z_aii_provider
|
||||
from .model_aii import AiiModel, AiiModelPublic, AiiProvider, AiiProviderPublic, AiiProviderPublicWithoutKey
|
||||
from .model_story import (
|
||||
Chatroom,
|
||||
ChatroomChat,
|
||||
@@ -29,6 +29,7 @@ __all__ = [
|
||||
AiiModelPublic,
|
||||
AiiProvider,
|
||||
AiiProviderPublic,
|
||||
AiiProviderPublicWithoutKey,
|
||||
ChatScript,
|
||||
Chatroom,
|
||||
ChatroomChat,
|
||||
@@ -42,6 +43,4 @@ __all__ = [
|
||||
async_get_session,
|
||||
create_db,
|
||||
get_session,
|
||||
z_aii_model,
|
||||
z_aii_provider,
|
||||
]
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, SerializerFunctionWrapHandler, model_serializer
|
||||
from sqlmodel import Field, Relationship, SQLModel
|
||||
|
||||
|
||||
@@ -16,6 +16,18 @@ class AiiProvider(SQLModel, table=True):
|
||||
|
||||
aii_models: list["AiiModel"] = Relationship(back_populates="aii_provider")
|
||||
|
||||
@model_serializer(mode="wrap")
|
||||
def serialize_provider(self, handler: SerializerFunctionWrapHandler) -> dict[str, Any]:
|
||||
data: dict = handler(self)
|
||||
data.pop("api_key", None)
|
||||
return data
|
||||
|
||||
|
||||
class AiiProviderPublicWithoutKey(BaseModel):
|
||||
id: Optional[int] = None
|
||||
name: str
|
||||
base_url: str
|
||||
|
||||
|
||||
class AiiProviderPublic(BaseModel):
|
||||
id: Optional[int] = None
|
||||
@@ -32,36 +44,27 @@ class AiiModel(SQLModel, table=True):
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
model_name: str
|
||||
max_context_length: int
|
||||
reasonable: Optional[bool] = Field(default=None, nullable=True, description="模型是否具备思考能力")
|
||||
|
||||
aii_provider_id: int = Field(default=None, foreign_key="aiiprovider.id")
|
||||
aii_provider: AiiProvider = Relationship(back_populates="aii_models")
|
||||
|
||||
chatrooms: list["Chatroom"] = Relationship(back_populates="default_model")
|
||||
|
||||
@model_serializer(mode="wrap")
|
||||
def serialize_model(self, handler: SerializerFunctionWrapHandler) -> dict[str, Any]:
|
||||
data: dict = handler(self)
|
||||
data["reasonable"] = bool(data.get("reasonable"))
|
||||
return data
|
||||
|
||||
|
||||
class AiiModelPublic(BaseModel):
|
||||
id: Optional[int] = None
|
||||
model_name: str
|
||||
max_context_length: int
|
||||
reasonable: bool
|
||||
|
||||
aii_provider_id: int
|
||||
|
||||
|
||||
def z_aii_model(am: AiiModel) -> dict:
|
||||
return {
|
||||
"id": am.id,
|
||||
"model_name": am.model_name,
|
||||
"max_context_length": am.max_context_length,
|
||||
"aii_provider_id": am.aii_provider_id,
|
||||
}
|
||||
|
||||
|
||||
def z_aii_provider(ap: AiiProvider) -> dict:
|
||||
return {
|
||||
"id": ap.id,
|
||||
"name": ap.name,
|
||||
"base_url": ap.base_url,
|
||||
}
|
||||
|
||||
|
||||
from .model_story import Chatroom # noqa: E402
|
||||
|
||||
@@ -92,6 +92,7 @@ class ChatroomChat(BaseModel):
|
||||
prefix: str
|
||||
mode: Literal["continue", "expand"]
|
||||
model_id: int
|
||||
enable_thinking: bool
|
||||
|
||||
|
||||
class ChatroomChatAccept(BaseModel):
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Annotated
|
||||
from typing import Annotated, Sequence
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.exc import NoResultFound
|
||||
@@ -9,10 +9,9 @@ from nyahome.database import (
|
||||
AiiModelPublic,
|
||||
AiiProvider,
|
||||
AiiProviderPublic,
|
||||
AiiProviderPublicWithoutKey,
|
||||
ModelUser,
|
||||
get_session,
|
||||
z_aii_model,
|
||||
z_aii_provider,
|
||||
)
|
||||
from nyahome.service.aii_service import apply_get_models, s_check_remote_model, s_list_remote_provider_models
|
||||
|
||||
@@ -23,16 +22,15 @@ aii_router = APIRouter(tags=["Aii"], prefix="/aii")
|
||||
|
||||
|
||||
@aii_router.get("/model/", name="获取模型列表")
|
||||
async def get_all_model(session: Annotated[Session, Depends(get_session)]) -> ReturnDto:
|
||||
async def get_all_model(session: Annotated[Session, Depends(get_session)]) -> list[dict]:
|
||||
"""
|
||||
获取 AI 模型列表。
|
||||
此接口无需用户登录即可访问。
|
||||
|
||||
Returns:
|
||||
被 ReturnDto 包裹的 AiiModel 列表
|
||||
AiiModel 列表
|
||||
"""
|
||||
final_model_list = apply_get_models(session)
|
||||
return ReturnDto(result=final_model_list)
|
||||
return apply_get_models(session)
|
||||
|
||||
|
||||
@aii_router.post("/model/", name="添加模型")
|
||||
@@ -40,18 +38,18 @@ async def add_model(
|
||||
model: AiiModelPublic,
|
||||
user: Annotated[ModelUser, Depends(verify_token)],
|
||||
session: Annotated[Session, Depends(get_session)],
|
||||
) -> ReturnDto:
|
||||
) -> AiiModel:
|
||||
"""
|
||||
添加新的 AI 模型。需要基于已添加的模型提供商。
|
||||
此接口需要管理员访问。
|
||||
添加模型时不会进行可用性检查,因此 WebUI 在前端实现了检查按钮。此端点不会负责检查。
|
||||
不会进行可用性检查,因此 WebUI 在前端实现了检查按钮。此端点不会负责检查。
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 用户无权限管理模型(未登录或非管理员)
|
||||
HTTPException: 404 模型提供商不存在
|
||||
|
||||
Returns:
|
||||
被 ReturnDto 包裹的、添加的 AiiModel
|
||||
AiiModel
|
||||
"""
|
||||
if not user.is_admin:
|
||||
raise HTTPException(status_code=401, detail="用户无权限管理模型。") from None
|
||||
@@ -69,11 +67,57 @@ async def add_model(
|
||||
session.add(am)
|
||||
session.commit()
|
||||
session.refresh(am)
|
||||
return ReturnDto(result=z_aii_model(am))
|
||||
return am
|
||||
|
||||
|
||||
@aii_router.post("/model/{id_}", name="修改模型")
|
||||
async def edit_model(
|
||||
id_: int,
|
||||
model: AiiModelPublic,
|
||||
user: Annotated[ModelUser, Depends(verify_token)],
|
||||
session: Annotated[Session, Depends(get_session)],
|
||||
) -> AiiModel:
|
||||
"""
|
||||
修改已添加的 AI 模型。
|
||||
此接口需要管理员访问。
|
||||
不会进行可用性检查,因此 WebUI 在前端实现了检查按钮。此端点不会负责检查。
|
||||
**只允许修改模型的名称、最大上下文长度和是否支持思考。**
|
||||
|
||||
Raises:
|
||||
HTTPException: 400 模型提供商 ID 不匹配
|
||||
HTTPException: 401 用户无权限管理模型(未登录或非管理员)
|
||||
HTTPException: 404 模型提供商不存在
|
||||
|
||||
Returns:
|
||||
AiiModel
|
||||
"""
|
||||
if not user.is_admin:
|
||||
raise HTTPException(status_code=401, detail="用户无权限管理模型。") from None
|
||||
|
||||
try:
|
||||
ap: AiiProvider = session.exec(select(AiiProvider).where(AiiProvider.id == model.aii_provider_id)).one()
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail="Provider 不存在。") from None
|
||||
|
||||
try:
|
||||
am: AiiModel = session.exec(select(AiiModel).where(AiiModel.id == id_)).one()
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail="模型不存在。") from None
|
||||
|
||||
if ap.id != am.aii_provider_id:
|
||||
raise HTTPException(status_code=400, detail="模型提供商 ID 不匹配。") from None
|
||||
|
||||
am.model_name = model.model_name
|
||||
am.max_context_length = model.max_context_length
|
||||
am.reasonable = model.reasonable
|
||||
session.add(am)
|
||||
session.commit()
|
||||
session.refresh(am)
|
||||
return am
|
||||
|
||||
|
||||
@aii_router.get("/provider/", name="获取提供商列表")
|
||||
async def get_all_provider(session: Annotated[Session, Depends(get_session)]) -> ReturnDto:
|
||||
async def get_all_provider(session: Annotated[Session, Depends(get_session)]) -> Sequence[AiiProvider]:
|
||||
"""
|
||||
获取 AI 模型提供商列表。
|
||||
此接口无需用户登录即可访问。
|
||||
@@ -81,8 +125,7 @@ async def get_all_provider(session: Annotated[Session, Depends(get_session)]) ->
|
||||
Returns:
|
||||
被 ReturnDto 包裹的 AiiProvider 列表
|
||||
"""
|
||||
aii_providers = session.exec(select(AiiProvider)).all()
|
||||
return ReturnDto(result=[z_aii_provider(ap) for ap in aii_providers])
|
||||
return session.exec(select(AiiProvider)).all()
|
||||
|
||||
|
||||
@aii_router.post("/provider/", name="添加提供商")
|
||||
@@ -90,11 +133,11 @@ async def add_provider(
|
||||
provider: AiiProviderPublic,
|
||||
user: Annotated[ModelUser, Depends(verify_token)],
|
||||
session: Annotated[Session, Depends(get_session)],
|
||||
) -> ReturnDto:
|
||||
) -> AiiProvider:
|
||||
"""
|
||||
添加新的 AI 模型提供商。
|
||||
此接口需要管理员才能访问。
|
||||
添加提供商时不会进行可用性检查,因此 WebUI 在前端实现了检查按钮。此端点不会负责检查。
|
||||
不会进行可用性检查,因此 WebUI 在前端实现了检查按钮。此端点不会负责检查。
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 表示用户未登录或非管理员。
|
||||
@@ -108,7 +151,47 @@ async def add_provider(
|
||||
session.add(ap)
|
||||
session.commit()
|
||||
session.refresh(ap)
|
||||
return ReturnDto(result=z_aii_provider(ap))
|
||||
return ap
|
||||
|
||||
|
||||
@aii_router.post("/provider/{id_}/", name="修改提供商")
|
||||
async def edit_provider(
|
||||
id_: int,
|
||||
provider: AiiProviderPublicWithoutKey,
|
||||
user: Annotated[ModelUser, Depends(verify_token)],
|
||||
session: Annotated[Session, Depends(get_session)],
|
||||
) -> AiiProvider:
|
||||
"""
|
||||
修改 AI 模型提供商。
|
||||
此接口需要管理员才能访问。
|
||||
不会进行可用性检查,因此 WebUI 在前端实现了检查按钮。此端点不会负责检查。
|
||||
**只允许修改模型提供商的名称和 Base URL。**
|
||||
|
||||
Raises:
|
||||
HTTPException: 400 模型提供商 ID 不匹配。
|
||||
HTTPException: 401 表示用户未登录或非管理员。
|
||||
HTTPException: 404 提供商不存在。
|
||||
|
||||
Returns:
|
||||
被 ReturnDto 包裹的、添加的 AiiProvider
|
||||
"""
|
||||
if not user.is_admin:
|
||||
raise HTTPException(status_code=401, detail="用户无权限管理模型。") from None
|
||||
|
||||
if provider.id != id_:
|
||||
raise HTTPException(status_code=400, detail="模型提供商 ID 不匹配。") from None
|
||||
|
||||
try:
|
||||
ap: AiiProvider = session.exec(select(AiiProvider).where(AiiProvider.id == id_)).one()
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail="提供商不存在。") from None
|
||||
|
||||
ap.name = provider.name
|
||||
ap.base_url = provider.base_url
|
||||
session.add(ap)
|
||||
session.commit()
|
||||
session.refresh(ap)
|
||||
return ap
|
||||
|
||||
|
||||
@aii_router.get("/provider/{id_}/remote/models/", name="获取提供商远端模型")
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Sequence
|
||||
|
||||
import openai
|
||||
from openai import AsyncOpenAI
|
||||
from sqlalchemy.orm import joinedload
|
||||
@@ -16,17 +18,18 @@ def apply_get_models(session: Session) -> list[dict]:
|
||||
Returns:
|
||||
|
||||
"""
|
||||
aii_models = session.exec(select(AiiModel).options(joinedload(AiiModel.aii_provider))).all() # type: ignore[arg-type]
|
||||
aii_models: Sequence[AiiModel] = session.exec(select(AiiModel).options(joinedload(AiiModel.aii_provider))).all() # type: ignore[arg-type]
|
||||
|
||||
final_model_list = []
|
||||
for aii_model in aii_models:
|
||||
final_model_list.append({
|
||||
"id": aii_model.id,
|
||||
"model_name": aii_model.model_name,
|
||||
"max_content_length": aii_model.max_context_length,
|
||||
"provider_id": aii_model.id,
|
||||
"max_context_length": aii_model.max_context_length,
|
||||
"provider_id": aii_model.aii_provider_id,
|
||||
"provider_name": aii_model.aii_provider.name,
|
||||
"base_url": aii_model.aii_provider.base_url,
|
||||
"reasonable": bool(aii_model.reasonable), # 数据库中的 reasonable 字段可能为 None,在这里归一为 False
|
||||
})
|
||||
|
||||
return final_model_list
|
||||
|
||||
@@ -105,11 +105,12 @@ def apply_chat(id_: int, user_id: int, chat: ChatroomChat, session: Session) ->
|
||||
"api_key": model.aii_provider.api_key,
|
||||
"model_name": model.model_name,
|
||||
"messages": messages,
|
||||
"enable_thinking": chat.enable_thinking,
|
||||
}
|
||||
|
||||
|
||||
async def s_start_async_streaming_chat(
|
||||
base_url: str, api_key: str, model_name: str, messages: list
|
||||
base_url: str, api_key: str, model_name: str, messages: list, enable_thinking: bool
|
||||
) -> AsyncGenerator[str, None]:
|
||||
client = AsyncOpenAI(base_url=base_url, api_key=api_key)
|
||||
stream = await client.chat.completions.create(
|
||||
@@ -117,6 +118,7 @@ async def s_start_async_streaming_chat(
|
||||
model=model_name,
|
||||
stream=True,
|
||||
reasoning_effort="high",
|
||||
extra_body={"thinking": {"type": "enabled" if enable_thinking else "disabled"}},
|
||||
)
|
||||
|
||||
# AI 说 SSE 好喵,推荐我用 SSE 喵,我不知道喵
|
||||
@@ -135,6 +137,7 @@ async def s_start_async_streaming_chat(
|
||||
yield f"data: {json.dumps({'text': cd, 'type': 'output'}, ensure_ascii=False)}\n\n"
|
||||
logger.info(f"AI 完成输出 : {aii_message}")
|
||||
try:
|
||||
# noinspection PyUnboundLocalVariable
|
||||
yield f"data: {json.dumps({'type': 'usage', **chuck.usage.model_dump()})}\n\n" # type: ignore[union-attr]
|
||||
finally:
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
Reference in New Issue
Block a user