Files
NyaHome/src/nyahome/router/chatroom_router.py
T

302 lines
10 KiB
Python

import json
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from sqlalchemy.exc import NoResultFound
from sqlmodel import Session, select
from nyahome.database import (
Chatroom,
ChatroomChat,
ChatroomChatAccept,
ChatroomChatDelete,
ChatroomChatEdit,
ChatroomPublic,
ChatScript,
ModelUser,
get_session,
)
from nyahome.service.chat_service import (
apply_chat,
s_append_chatroom_content,
s_delete_chatroom_content,
s_edit_chatroom_content,
s_start_async_streaming_chat,
)
from .auth import verify_token
from .response_model import ReturnDto
chatroom_router = APIRouter(tags=["Chatroom"], prefix="/chatroom")
@chatroom_router.get("/{id_}/", name="获取指定聊天室")
async def get_chatroom(
id_: int, user: Annotated[ModelUser, Depends(verify_token)], session: Annotated[Session, Depends(get_session)]
) -> ReturnDto:
"""
根据 ID 获取聊天室。这里获取到的是完整的聊天室信息。
Returns:
聊天室对象。
Raises:
HTTPException: 404 未找到指定 ID 的聊天室。
"""
try:
cr: Chatroom = session.exec(
select(Chatroom).where(Chatroom.id == id_).where(Chatroom.creator_id == user.id)
).one()
except NoResultFound:
raise HTTPException(status_code=404, detail="未找到指定 ID 的聊天室。") from None
else:
return ReturnDto(result=cr.model_dump())
@chatroom_router.get("/", name="获取聊天室列表")
async def get_all_chatroom(
user: Annotated[ModelUser, Depends(verify_token)], session: Annotated[Session, Depends(get_session)]
) -> ReturnDto:
"""
获取全部聊天室。这里获取到的是 public 简略聊天室信息,不包含 content 和 script 字段。
Returns:
包含全部聊天室的列表。
"""
crs = session.exec(select(Chatroom).where(Chatroom.creator_id == user.id)).all()
return ReturnDto(result=[cr.model_dump(exclude={"content", "script"}) for cr in crs])
@chatroom_router.post("/", name="创建聊天室")
async def create_chatroom(
chatroom: ChatroomPublic,
user: Annotated[ModelUser, Depends(verify_token)],
session: Annotated[Session, Depends(get_session)],
) -> ReturnDto:
"""
创建聊天室。
在请求体中提供聊天室信息,详请参阅 ChatroomPublic。注意,提供的 id 会被忽略。
Returns:
创建的聊天室对象,包含由数据库分配的 id。
"""
cr = Chatroom(
name=chatroom.name,
description=chatroom.description,
content="[]",
script="{}",
feature_image=chatroom.feature_image if chatroom.feature_image != "" else None,
script_template_id=chatroom.script_template_id,
creator_id=user.id,
)
session.add(cr)
session.commit()
session.refresh(cr)
return ReturnDto(result=cr.model_dump())
@chatroom_router.post("/{id_}/", name="修改指定聊天室")
async def edit_chatroom(
id_: int,
chatroom: ChatroomPublic,
user: Annotated[ModelUser, Depends(verify_token)],
session: Annotated[Session, Depends(get_session)],
) -> ReturnDto:
"""
修改聊天室的基本信息。
content 和 script 需要从各自的独立端点请求修改,不包含在本端点的负责范围内。
Args:
id_: 聊天室 ID
chatroom: 聊天室基本信息,类型为 ChatroomPublic。注意 id 不可更改,如提供则会被忽略
user: 用户
session: 数据库连接对象
Raises:
HTTPException: 404 表示未找到聊天室
Returns:
修改过的聊天室对象,供前端更新。
"""
try:
cr: Chatroom = session.exec(
select(Chatroom).where(Chatroom.id == id_).where(Chatroom.creator_id == user.id)
).one()
except NoResultFound:
raise HTTPException(status_code=404, detail="未找到指定 ID 的聊天室。") from None
cr.name = chatroom.name
cr.description = chatroom.description
if chatroom.feature_image != "":
cr.feature_image = chatroom.feature_image
cr.script_template_id = chatroom.script_template_id
cr.script_template_version = chatroom.script_template_version
cr.default_model_id = chatroom.default_model_id
session.add(cr)
session.commit()
session.refresh(cr)
return ReturnDto(result=cr.model_dump())
@chatroom_router.post("/{id_}/script/", name="修改聊天室脚本")
async def update_chatroom_script(
id_: int,
script: ChatScript,
user: Annotated[ModelUser, Depends(verify_token)],
session: Annotated[Session, Depends(get_session)],
) -> ReturnDto:
"""
更新聊天室的剧本(提示词与世界书)。
Args:
id_: 聊天室 ID
script: 剧本
user: 用户
session: 数据库连接对象
Raises:
HTTPException: 404 表示未找到聊天室
Returns:
result 字段包含最新的剧本。
"""
try:
cr: Chatroom = session.exec(
select(Chatroom).where(Chatroom.id == id_).where(Chatroom.creator_id == user.id)
).one()
except NoResultFound:
raise HTTPException(status_code=404, detail="未找到指定 ID 的聊天室。") from None
cr.script = json.dumps(script.model_dump(), ensure_ascii=False)
session.add(cr)
session.commit()
session.refresh(cr)
return ReturnDto(result=script.model_dump())
@chatroom_router.post("/{id_}/chat/", name="聊天室发起模型创作")
async def post_chatroom_chat(
id_: int,
chat: ChatroomChat,
user: Annotated[ModelUser, Depends(verify_token)],
session: Annotated[Session, Depends(get_session)],
) -> StreamingResponse:
"""
在聊天室中发送新的用户消息,流式返回 AI 调用结果。
即:调用模型发起创作。
Args:
id_: (路径参数)聊天室 ID
chat: 用户消息
user: 用户
session: 数据库连接对象
Raises:
HTTPException: 404 表示聊天室未找到,444 表示模型未找到。
Returns:
SSE 流式输出,实质上相当于转发 AI 的流式输出结果。
"""
try:
return StreamingResponse(
s_start_async_streaming_chat(**apply_chat(id_, user.id, chat, session)),
media_type="text/event-stream",
)
except HTTPException as e:
raise e
@chatroom_router.post("/{id_}/chat/accept/", name="聊天室保存模型创作")
async def accept_chatroom_chat(
id_: int,
accept: ChatroomChatAccept,
user: Annotated[ModelUser, Depends(verify_token)],
session: Annotated[Session, Depends(get_session)],
) -> ReturnDto:
"""
此端点不负责调用 AI 生成输出,而是用于保存一对用户消息和 AI 输出到聊天室 content 的最后。
需要提供用户消息、AI 消息和创作模式。
Raises:
HTTPException: 404 表明未找到聊天室。
Returns:
ReturnDto,其中 result 字段是该聊天室的最新 content,以供前端刷新。
"""
try:
cr: Chatroom = session.exec(
select(Chatroom).where(Chatroom.id == id_).where(Chatroom.creator_id == user.id)
).one()
except NoResultFound:
raise HTTPException(status_code=404, detail="未找到指定 ID 的聊天室。") from None
cr.content = s_append_chatroom_content(cr.content, accept)
session.add(cr)
session.commit()
session.refresh(cr)
return ReturnDto(result=cr.model_dump())
@chatroom_router.post("/{id_}/chat/edit/", name="聊天室编辑消息")
async def edit_chatroom_chat(
id_: int,
edit: ChatroomChatEdit,
user: Annotated[ModelUser, Depends(verify_token)],
session: Annotated[Session, Depends(get_session)],
) -> ReturnDto:
"""
此端点不负责调用 AI 生成输出,而是用于修改一条已经保存在聊天记录中的消息。
需要提供消息类型(用户/AI)、旧消息和新消息,以便进行替换。
Raises:
HTTPException: 404 表明未找到聊天室,400 表明聊天记录匹配失败,未更新。
Returns:
ReturnDto,其中 result 字段是该聊天室的最新 content,以供前端刷新。
"""
try:
cr: Chatroom = session.exec(
select(Chatroom).where(Chatroom.id == id_).where(Chatroom.creator_id == user.id)
).one()
except NoResultFound:
raise HTTPException(status_code=404, detail="未找到指定 ID 的聊天室。") from None
try:
cr.content = s_edit_chatroom_content(cr.content, edit)
session.add(cr)
session.commit()
session.refresh(cr)
return ReturnDto(result=cr.model_dump())
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e)) from e
@chatroom_router.post("/{id_}/chat/delete/", name="聊天室删除消息")
async def delete_chatroom_chat(
id_: int,
delete: ChatroomChatDelete,
user: Annotated[ModelUser, Depends(verify_token)],
session: Annotated[Session, Depends(get_session)],
) -> ReturnDto:
"""
此端点不负责调用 AI 生成输出,而是用于删除一条已经保存在聊天记录中的消息。关联的 user 或 aii 消息会一并删除。
需要提供消息和消息类型(用户/AI)。用户消息和 AI 消息是一对一成对的,所以总是会删除关联的一对(两条)消息。
Raises:
HTTPException: 404 表明未找到聊天室,400 表明聊天记录匹配失败,未更新。
Returns:
ReturnDto,其中 result 字段是该聊天室的最新 content,以供前端刷新。
"""
try:
cr: Chatroom = session.exec(
select(Chatroom).where(Chatroom.id == id_).where(Chatroom.creator_id == user.id)
).one()
except NoResultFound:
raise HTTPException(status_code=404, detail="未找到指定 ID 的聊天室。") from None
try:
cr.content = s_delete_chatroom_content(cr.content, delete)
session.add(cr)
session.commit()
session.refresh(cr)
return ReturnDto(result=cr.model_dump())
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e)) from e