302 lines
10 KiB
Python
302 lines
10 KiB
Python
import json
|
|
from typing import Annotated
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException
|
|
from fastapi.responses import StreamingResponse
|
|
from sqlalchemy.exc import NoResultFound
|
|
from sqlmodel import Session, select
|
|
|
|
from nyahome.database import (
|
|
Chatroom,
|
|
ChatroomChat,
|
|
ChatroomChatAccept,
|
|
ChatroomChatDelete,
|
|
ChatroomChatEdit,
|
|
ChatroomPublic,
|
|
ChatScript,
|
|
ModelUser,
|
|
get_session,
|
|
)
|
|
from nyahome.service.chat_service import (
|
|
apply_chat,
|
|
s_append_chatroom_content,
|
|
s_delete_chatroom_content,
|
|
s_edit_chatroom_content,
|
|
s_start_async_streaming_chat,
|
|
)
|
|
|
|
from .auth import verify_token
|
|
from .response_model import ReturnDto
|
|
|
|
chatroom_router = APIRouter(tags=["Chatroom"], prefix="/chatroom")
|
|
|
|
|
|
@chatroom_router.get("/{id_}/", name="获取指定聊天室")
|
|
async def get_chatroom(
|
|
id_: int, user: Annotated[ModelUser, Depends(verify_token)], session: Annotated[Session, Depends(get_session)]
|
|
) -> ReturnDto:
|
|
"""
|
|
根据 ID 获取聊天室。这里获取到的是完整的聊天室信息。
|
|
|
|
Returns:
|
|
聊天室对象。
|
|
|
|
Raises:
|
|
HTTPException: 404 未找到指定 ID 的聊天室。
|
|
"""
|
|
try:
|
|
cr: Chatroom = session.exec(
|
|
select(Chatroom).where(Chatroom.id == id_).where(Chatroom.creator_id == user.id)
|
|
).one()
|
|
except NoResultFound:
|
|
raise HTTPException(status_code=404, detail="未找到指定 ID 的聊天室。") from None
|
|
else:
|
|
return ReturnDto(result=cr.model_dump())
|
|
|
|
|
|
@chatroom_router.get("/", name="获取聊天室列表")
|
|
async def get_all_chatroom(
|
|
user: Annotated[ModelUser, Depends(verify_token)], session: Annotated[Session, Depends(get_session)]
|
|
) -> ReturnDto:
|
|
"""
|
|
获取全部聊天室。这里获取到的是 public 简略聊天室信息,不包含 content 和 script 字段。
|
|
|
|
Returns:
|
|
包含全部聊天室的列表。
|
|
"""
|
|
crs = session.exec(select(Chatroom).where(Chatroom.creator_id == user.id)).all()
|
|
return ReturnDto(result=[cr.model_dump(exclude={"content", "script"}) for cr in crs])
|
|
|
|
|
|
@chatroom_router.post("/", name="创建聊天室")
|
|
async def create_chatroom(
|
|
chatroom: ChatroomPublic,
|
|
user: Annotated[ModelUser, Depends(verify_token)],
|
|
session: Annotated[Session, Depends(get_session)],
|
|
) -> ReturnDto:
|
|
"""
|
|
创建聊天室。
|
|
在请求体中提供聊天室信息,详请参阅 ChatroomPublic。注意,提供的 id 会被忽略。
|
|
|
|
Returns:
|
|
创建的聊天室对象,包含由数据库分配的 id。
|
|
"""
|
|
cr = Chatroom(
|
|
name=chatroom.name,
|
|
description=chatroom.description,
|
|
content="[]",
|
|
script="{}",
|
|
feature_image=chatroom.feature_image if chatroom.feature_image != "" else None,
|
|
script_template_id=chatroom.script_template_id,
|
|
creator_id=user.id,
|
|
)
|
|
session.add(cr)
|
|
session.commit()
|
|
session.refresh(cr)
|
|
return ReturnDto(result=cr.model_dump())
|
|
|
|
|
|
@chatroom_router.post("/{id_}/", name="修改指定聊天室")
|
|
async def edit_chatroom(
|
|
id_: int,
|
|
chatroom: ChatroomPublic,
|
|
user: Annotated[ModelUser, Depends(verify_token)],
|
|
session: Annotated[Session, Depends(get_session)],
|
|
) -> ReturnDto:
|
|
"""
|
|
修改聊天室的基本信息。
|
|
content 和 script 需要从各自的独立端点请求修改,不包含在本端点的负责范围内。
|
|
|
|
Args:
|
|
id_: 聊天室 ID
|
|
chatroom: 聊天室基本信息,类型为 ChatroomPublic。注意 id 不可更改,如提供则会被忽略
|
|
user: 用户
|
|
session: 数据库连接对象
|
|
|
|
Raises:
|
|
HTTPException: 404 表示未找到聊天室
|
|
|
|
Returns:
|
|
修改过的聊天室对象,供前端更新。
|
|
"""
|
|
try:
|
|
cr: Chatroom = session.exec(
|
|
select(Chatroom).where(Chatroom.id == id_).where(Chatroom.creator_id == user.id)
|
|
).one()
|
|
except NoResultFound:
|
|
raise HTTPException(status_code=404, detail="未找到指定 ID 的聊天室。") from None
|
|
cr.name = chatroom.name
|
|
cr.description = chatroom.description
|
|
if chatroom.feature_image != "":
|
|
cr.feature_image = chatroom.feature_image
|
|
cr.script_template_id = chatroom.script_template_id
|
|
cr.script_template_version = chatroom.script_template_version
|
|
cr.default_model_id = chatroom.default_model_id
|
|
session.add(cr)
|
|
session.commit()
|
|
session.refresh(cr)
|
|
return ReturnDto(result=cr.model_dump())
|
|
|
|
|
|
@chatroom_router.post("/{id_}/script/", name="修改聊天室脚本")
|
|
async def update_chatroom_script(
|
|
id_: int,
|
|
script: ChatScript,
|
|
user: Annotated[ModelUser, Depends(verify_token)],
|
|
session: Annotated[Session, Depends(get_session)],
|
|
) -> ReturnDto:
|
|
"""
|
|
更新聊天室的剧本(提示词与世界书)。
|
|
|
|
Args:
|
|
id_: 聊天室 ID
|
|
script: 剧本
|
|
user: 用户
|
|
session: 数据库连接对象
|
|
|
|
Raises:
|
|
HTTPException: 404 表示未找到聊天室
|
|
|
|
Returns:
|
|
result 字段包含最新的剧本。
|
|
"""
|
|
try:
|
|
cr: Chatroom = session.exec(
|
|
select(Chatroom).where(Chatroom.id == id_).where(Chatroom.creator_id == user.id)
|
|
).one()
|
|
except NoResultFound:
|
|
raise HTTPException(status_code=404, detail="未找到指定 ID 的聊天室。") from None
|
|
cr.script = json.dumps(script.model_dump(), ensure_ascii=False)
|
|
session.add(cr)
|
|
session.commit()
|
|
session.refresh(cr)
|
|
return ReturnDto(result=script.model_dump())
|
|
|
|
|
|
@chatroom_router.post("/{id_}/chat/", name="聊天室发起模型创作")
|
|
async def post_chatroom_chat(
|
|
id_: int,
|
|
chat: ChatroomChat,
|
|
user: Annotated[ModelUser, Depends(verify_token)],
|
|
session: Annotated[Session, Depends(get_session)],
|
|
) -> StreamingResponse:
|
|
"""
|
|
在聊天室中发送新的用户消息,流式返回 AI 调用结果。
|
|
即:调用模型发起创作。
|
|
|
|
Args:
|
|
id_: (路径参数)聊天室 ID
|
|
chat: 用户消息
|
|
user: 用户
|
|
session: 数据库连接对象
|
|
|
|
Raises:
|
|
HTTPException: 404 表示聊天室未找到,444 表示模型未找到。
|
|
|
|
Returns:
|
|
SSE 流式输出,实质上相当于转发 AI 的流式输出结果。
|
|
"""
|
|
try:
|
|
return StreamingResponse(
|
|
s_start_async_streaming_chat(**apply_chat(id_, user.id, chat, session)),
|
|
media_type="text/event-stream",
|
|
)
|
|
except HTTPException as e:
|
|
raise e
|
|
|
|
|
|
@chatroom_router.post("/{id_}/chat/accept/", name="聊天室保存模型创作")
|
|
async def accept_chatroom_chat(
|
|
id_: int,
|
|
accept: ChatroomChatAccept,
|
|
user: Annotated[ModelUser, Depends(verify_token)],
|
|
session: Annotated[Session, Depends(get_session)],
|
|
) -> ReturnDto:
|
|
"""
|
|
此端点不负责调用 AI 生成输出,而是用于保存一对用户消息和 AI 输出到聊天室 content 的最后。
|
|
需要提供用户消息、AI 消息和创作模式。
|
|
|
|
Raises:
|
|
HTTPException: 404 表明未找到聊天室。
|
|
|
|
Returns:
|
|
ReturnDto,其中 result 字段是该聊天室的最新 content,以供前端刷新。
|
|
"""
|
|
try:
|
|
cr: Chatroom = session.exec(
|
|
select(Chatroom).where(Chatroom.id == id_).where(Chatroom.creator_id == user.id)
|
|
).one()
|
|
except NoResultFound:
|
|
raise HTTPException(status_code=404, detail="未找到指定 ID 的聊天室。") from None
|
|
cr.content = s_append_chatroom_content(cr.content, accept)
|
|
session.add(cr)
|
|
session.commit()
|
|
session.refresh(cr)
|
|
return ReturnDto(result=cr.model_dump())
|
|
|
|
|
|
@chatroom_router.post("/{id_}/chat/edit/", name="聊天室编辑消息")
|
|
async def edit_chatroom_chat(
|
|
id_: int,
|
|
edit: ChatroomChatEdit,
|
|
user: Annotated[ModelUser, Depends(verify_token)],
|
|
session: Annotated[Session, Depends(get_session)],
|
|
) -> ReturnDto:
|
|
"""
|
|
此端点不负责调用 AI 生成输出,而是用于修改一条已经保存在聊天记录中的消息。
|
|
需要提供消息类型(用户/AI)、旧消息和新消息,以便进行替换。
|
|
|
|
Raises:
|
|
HTTPException: 404 表明未找到聊天室,400 表明聊天记录匹配失败,未更新。
|
|
|
|
Returns:
|
|
ReturnDto,其中 result 字段是该聊天室的最新 content,以供前端刷新。
|
|
"""
|
|
try:
|
|
cr: Chatroom = session.exec(
|
|
select(Chatroom).where(Chatroom.id == id_).where(Chatroom.creator_id == user.id)
|
|
).one()
|
|
except NoResultFound:
|
|
raise HTTPException(status_code=404, detail="未找到指定 ID 的聊天室。") from None
|
|
try:
|
|
cr.content = s_edit_chatroom_content(cr.content, edit)
|
|
session.add(cr)
|
|
session.commit()
|
|
session.refresh(cr)
|
|
return ReturnDto(result=cr.model_dump())
|
|
except ValueError as e:
|
|
raise HTTPException(status_code=400, detail=str(e)) from e
|
|
|
|
|
|
@chatroom_router.post("/{id_}/chat/delete/", name="聊天室删除消息")
|
|
async def delete_chatroom_chat(
|
|
id_: int,
|
|
delete: ChatroomChatDelete,
|
|
user: Annotated[ModelUser, Depends(verify_token)],
|
|
session: Annotated[Session, Depends(get_session)],
|
|
) -> ReturnDto:
|
|
"""
|
|
此端点不负责调用 AI 生成输出,而是用于删除一条已经保存在聊天记录中的消息。关联的 user 或 aii 消息会一并删除。
|
|
需要提供消息和消息类型(用户/AI)。用户消息和 AI 消息是一对一成对的,所以总是会删除关联的一对(两条)消息。
|
|
|
|
Raises:
|
|
HTTPException: 404 表明未找到聊天室,400 表明聊天记录匹配失败,未更新。
|
|
|
|
Returns:
|
|
ReturnDto,其中 result 字段是该聊天室的最新 content,以供前端刷新。
|
|
"""
|
|
try:
|
|
cr: Chatroom = session.exec(
|
|
select(Chatroom).where(Chatroom.id == id_).where(Chatroom.creator_id == user.id)
|
|
).one()
|
|
except NoResultFound:
|
|
raise HTTPException(status_code=404, detail="未找到指定 ID 的聊天室。") from None
|
|
try:
|
|
cr.content = s_delete_chatroom_content(cr.content, delete)
|
|
session.add(cr)
|
|
session.commit()
|
|
session.refresh(cr)
|
|
return ReturnDto(result=cr.model_dump())
|
|
except ValueError as e:
|
|
raise HTTPException(status_code=400, detail=str(e)) from e
|