refactor: 主要功能实现
目前的工作已经实现的功能: - 基本 FastAPI 路由; - 基本 AI 聊天和创作功能; - 用户信息管理、权限验证、JWT 令牌签发和验证、端点保护; - HTML 验证码邮件发送和验证码验证。
This commit is contained in:
@@ -0,0 +1,296 @@
|
||||
import json
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.exc import NoResultFound
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from nyahome.database import (
|
||||
Chatroom,
|
||||
ChatroomChat,
|
||||
ChatroomChatAccept,
|
||||
ChatroomChatDelete,
|
||||
ChatroomChatEdit,
|
||||
ChatroomPublic,
|
||||
ChatScript,
|
||||
ModelUser,
|
||||
get_session,
|
||||
)
|
||||
from nyahome.service.chat_service import (
|
||||
apply_chat,
|
||||
s_append_chatroom_content,
|
||||
s_delete_chatroom_content,
|
||||
s_edit_chatroom_content,
|
||||
s_start_async_streaming_chat,
|
||||
)
|
||||
|
||||
from .auth import verify_token
|
||||
from .response_model import ReturnDto
|
||||
|
||||
chatroom_router = APIRouter(tags=["Chatroom"], prefix="/chatroom")
|
||||
|
||||
|
||||
@chatroom_router.get("/{id_}/")
|
||||
async def get_chatroom(
|
||||
id_: int, user: Annotated[ModelUser, Depends(verify_token)], session: Annotated[Session, Depends(get_session)]
|
||||
) -> ReturnDto:
|
||||
"""
|
||||
根据 ID 获取聊天室。这里获取到的是完整的聊天室信息。
|
||||
|
||||
Returns:
|
||||
聊天室对象。
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 未找到指定 ID 的聊天室。
|
||||
"""
|
||||
try:
|
||||
cr: Chatroom = session.exec(
|
||||
select(Chatroom).where(Chatroom.id == id_).where(Chatroom.creator_id == user.id)
|
||||
).one()
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail="未找到指定 ID 的聊天室。") from None
|
||||
else:
|
||||
return ReturnDto(result=cr.model_dump())
|
||||
|
||||
|
||||
@chatroom_router.get("/")
|
||||
async def get_all_chatroom(
|
||||
user: Annotated[ModelUser, Depends(verify_token)], session: Annotated[Session, Depends(get_session)]
|
||||
) -> ReturnDto:
|
||||
"""
|
||||
获取全部聊天室。这里获取到的是 public 简略聊天室信息,不包含 content 和 script 字段。
|
||||
|
||||
Returns:
|
||||
包含全部聊天室的列表。
|
||||
"""
|
||||
crs = session.exec(select(Chatroom).where(Chatroom.creator_id == user.id)).all()
|
||||
return ReturnDto(result=[cr.model_dump(exclude={"content", "script"}) for cr in crs])
|
||||
|
||||
|
||||
@chatroom_router.post("/")
|
||||
async def create_chatroom(
|
||||
chatroom: ChatroomPublic,
|
||||
user: Annotated[ModelUser, Depends(verify_token)],
|
||||
session: Annotated[Session, Depends(get_session)],
|
||||
) -> ReturnDto:
|
||||
"""
|
||||
创建聊天室。
|
||||
在请求体中提供聊天室信息,详请参阅 ChatroomPublic。注意,提供的 id 会被忽略。
|
||||
|
||||
Returns:
|
||||
创建的聊天室对象,包含由数据库分配的 id。
|
||||
"""
|
||||
cr = Chatroom(
|
||||
name=chatroom.name,
|
||||
description=chatroom.description,
|
||||
content="[]",
|
||||
script="{}",
|
||||
feature_image=chatroom.feature_image if chatroom.feature_image != "" else None,
|
||||
script_template_id=chatroom.script_template_id,
|
||||
creator_id=user.id,
|
||||
)
|
||||
session.add(cr)
|
||||
session.commit()
|
||||
session.refresh(cr)
|
||||
return ReturnDto(result=cr.model_dump())
|
||||
|
||||
|
||||
@chatroom_router.post("/{id_}/")
|
||||
async def edit_chatroom(
|
||||
id_: int,
|
||||
chatroom: ChatroomPublic,
|
||||
user: Annotated[ModelUser, Depends(verify_token)],
|
||||
session: Annotated[Session, Depends(get_session)],
|
||||
) -> ReturnDto:
|
||||
"""
|
||||
修改聊天室的基本信息。
|
||||
content 和 script 需要从各自的独立端点请求修改,不包含在本端点的负责范围内。
|
||||
|
||||
Args:
|
||||
id_: 聊天室 ID
|
||||
chatroom: 聊天室基本信息,类型为 ChatroomPublic。注意 id 不可更改,如提供则会被忽略
|
||||
user: 用户
|
||||
session: 数据库连接对象
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 表示未找到聊天室
|
||||
|
||||
Returns:
|
||||
修改过的聊天室对象,供前端更新。
|
||||
"""
|
||||
try:
|
||||
cr: Chatroom = session.exec(
|
||||
select(Chatroom).where(Chatroom.id == id_).where(Chatroom.creator_id == user.id)
|
||||
).one()
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail="未找到指定 ID 的聊天室。") from None
|
||||
cr.name = chatroom.name
|
||||
cr.description = chatroom.description
|
||||
if chatroom.feature_image != "":
|
||||
cr.feature_image = chatroom.feature_image
|
||||
cr.script_template_id = chatroom.script_template_id
|
||||
cr.script_template_version = chatroom.script_template_version
|
||||
session.add(cr)
|
||||
session.commit()
|
||||
session.refresh(cr)
|
||||
return ReturnDto(result=cr.model_dump())
|
||||
|
||||
|
||||
@chatroom_router.post("/{id_}/script/")
|
||||
async def update_chatroom_script(
|
||||
id_: int,
|
||||
script: ChatScript,
|
||||
user: Annotated[ModelUser, Depends(verify_token)],
|
||||
session: Annotated[Session, Depends(get_session)],
|
||||
) -> ReturnDto:
|
||||
"""
|
||||
更新聊天室的剧本(提示词与世界书)。
|
||||
|
||||
Args:
|
||||
id_: 聊天室 ID
|
||||
script: 剧本
|
||||
user: 用户
|
||||
session: 数据库连接对象
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 表示未找到聊天室
|
||||
|
||||
Returns:
|
||||
result 字段包含最新的剧本。
|
||||
"""
|
||||
try:
|
||||
cr: Chatroom = session.exec(
|
||||
select(Chatroom).where(Chatroom.id == id_).where(Chatroom.creator_id == user.id)
|
||||
).one()
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail="未找到指定 ID 的聊天室。") from None
|
||||
cr.script = json.dumps(script.model_dump(), ensure_ascii=False)
|
||||
session.add(cr)
|
||||
session.commit()
|
||||
session.refresh(cr)
|
||||
return ReturnDto(result=script.model_dump())
|
||||
|
||||
|
||||
@chatroom_router.post("/{id_}/chat/")
|
||||
async def post_chatroom_chat(
|
||||
id_: int,
|
||||
chat: ChatroomChat,
|
||||
user: Annotated[ModelUser, Depends(verify_token)],
|
||||
session: Annotated[Session, Depends(get_session)],
|
||||
) -> StreamingResponse:
|
||||
"""
|
||||
在聊天室中发送新的用户消息,流式返回 AI 调用结果。
|
||||
|
||||
Args:
|
||||
id_: (路径参数)聊天室 ID
|
||||
chat: 用户消息
|
||||
user: 用户
|
||||
session: 数据库连接对象
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 表示聊天室未找到,444 表示模型未找到。
|
||||
|
||||
Returns:
|
||||
SSE 流式输出,实质上相当于转发 AI 的流式输出结果。
|
||||
"""
|
||||
try:
|
||||
return StreamingResponse(
|
||||
s_start_async_streaming_chat(**apply_chat(id_, user.id, chat, session)),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
except HTTPException as e:
|
||||
raise e
|
||||
|
||||
|
||||
@chatroom_router.post("/{id_}/chat/accept/")
|
||||
async def accept_chatroom_chat(
|
||||
id_: int,
|
||||
accept: ChatroomChatAccept,
|
||||
user: Annotated[ModelUser, Depends(verify_token)],
|
||||
session: Annotated[Session, Depends(get_session)],
|
||||
) -> ReturnDto:
|
||||
"""
|
||||
此端点不负责调用 AI 生成输出,而是用于保存一对用户消息和 AI 输出到聊天室 content 的最后。
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 表明未找到聊天室。
|
||||
|
||||
Returns:
|
||||
ReturnDto,其中 result 字段是该聊天室的最新 content,以供前端刷新。
|
||||
"""
|
||||
try:
|
||||
cr: Chatroom = session.exec(
|
||||
select(Chatroom).where(Chatroom.id == id_).where(Chatroom.creator_id == user.id)
|
||||
).one()
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail="未找到指定 ID 的聊天室。") from None
|
||||
cr.content = s_append_chatroom_content(cr.content, accept)
|
||||
session.add(cr)
|
||||
session.commit()
|
||||
session.refresh(cr)
|
||||
return ReturnDto(result=cr.model_dump())
|
||||
|
||||
|
||||
@chatroom_router.post("/{id_}/chat/edit/")
|
||||
async def edit_chatroom_chat(
|
||||
id_: int,
|
||||
edit: ChatroomChatEdit,
|
||||
user: Annotated[ModelUser, Depends(verify_token)],
|
||||
session: Annotated[Session, Depends(get_session)],
|
||||
) -> ReturnDto:
|
||||
"""
|
||||
此端点不负责调用 AI 生成输出,而是用于修改一条已经保存在聊天记录中的消息。
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 表明未找到聊天室,400 表明聊天记录匹配失败,未更新。
|
||||
|
||||
Returns:
|
||||
ReturnDto,其中 result 字段是该聊天室的最新 content,以供前端刷新。
|
||||
"""
|
||||
try:
|
||||
cr: Chatroom = session.exec(
|
||||
select(Chatroom).where(Chatroom.id == id_).where(Chatroom.creator_id == user.id)
|
||||
).one()
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail="未找到指定 ID 的聊天室。") from None
|
||||
try:
|
||||
cr.content = s_edit_chatroom_content(cr.content, edit)
|
||||
session.add(cr)
|
||||
session.commit()
|
||||
session.refresh(cr)
|
||||
return ReturnDto(result=cr.model_dump())
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e)) from e
|
||||
|
||||
|
||||
@chatroom_router.post("/{id_}/chat/delete/")
|
||||
async def delete_chatroom_chat(
|
||||
id_: int,
|
||||
delete: ChatroomChatDelete,
|
||||
user: Annotated[ModelUser, Depends(verify_token)],
|
||||
session: Annotated[Session, Depends(get_session)],
|
||||
) -> ReturnDto:
|
||||
"""
|
||||
此端点不负责调用 AI 生成输出,而是用于删除一条已经保存在聊天记录中的消息。关联的 user 或 aii 消息会一并删除。
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 表明未找到聊天室,400 表明聊天记录匹配失败,未更新。
|
||||
|
||||
Returns:
|
||||
ReturnDto,其中 result 字段是该聊天室的最新 content,以供前端刷新。
|
||||
"""
|
||||
try:
|
||||
cr: Chatroom = session.exec(
|
||||
select(Chatroom).where(Chatroom.id == id_).where(Chatroom.creator_id == user.id)
|
||||
).one()
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail="未找到指定 ID 的聊天室。") from None
|
||||
try:
|
||||
cr.content = s_delete_chatroom_content(cr.content, delete)
|
||||
session.add(cr)
|
||||
session.commit()
|
||||
session.refresh(cr)
|
||||
return ReturnDto(result=cr.model_dump())
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e)) from e
|
||||
Reference in New Issue
Block a user