import json from typing import Annotated from fastapi import APIRouter, Depends, HTTPException from fastapi.responses import StreamingResponse from sqlalchemy.exc import NoResultFound from sqlmodel import Session, select from nyahome.database import ( Chatroom, ChatroomChat, ChatroomChatAccept, ChatroomChatDelete, ChatroomChatEdit, ChatroomPublic, ChatScript, ModelUser, get_session, ) from nyahome.service.chat_service import ( apply_chat, s_append_chatroom_content, s_delete_chatroom_content, s_edit_chatroom_content, s_start_async_streaming_chat, ) from .auth import verify_token from .response_model import ReturnDto chatroom_router = APIRouter(tags=["Chatroom"], prefix="/chatroom") @chatroom_router.get("/{id_}/", name="获取指定聊天室") async def get_chatroom( id_: int, user: Annotated[ModelUser, Depends(verify_token)], session: Annotated[Session, Depends(get_session)] ) -> ReturnDto: """ 根据 ID 获取聊天室。这里获取到的是完整的聊天室信息。 Returns: 聊天室对象。 Raises: HTTPException: 404 未找到指定 ID 的聊天室。 """ try: cr: Chatroom = session.exec( select(Chatroom).where(Chatroom.id == id_).where(Chatroom.creator_id == user.id) ).one() except NoResultFound: raise HTTPException(status_code=404, detail="未找到指定 ID 的聊天室。") from None else: return ReturnDto(result=cr.model_dump()) @chatroom_router.get("/", name="获取聊天室列表") async def get_all_chatroom( user: Annotated[ModelUser, Depends(verify_token)], session: Annotated[Session, Depends(get_session)] ) -> ReturnDto: """ 获取全部聊天室。这里获取到的是 public 简略聊天室信息,不包含 content 和 script 字段。 Returns: 包含全部聊天室的列表。 """ crs = session.exec(select(Chatroom).where(Chatroom.creator_id == user.id)).all() return ReturnDto(result=[cr.model_dump(exclude={"content", "script"}) for cr in crs]) @chatroom_router.post("/", name="创建聊天室") async def create_chatroom( chatroom: ChatroomPublic, user: Annotated[ModelUser, Depends(verify_token)], session: Annotated[Session, Depends(get_session)], ) -> ReturnDto: """ 创建聊天室。 在请求体中提供聊天室信息,详请参阅 ChatroomPublic。注意,提供的 id 会被忽略。 Returns: 创建的聊天室对象,包含由数据库分配的 id。 """ cr = Chatroom( name=chatroom.name, description=chatroom.description, content="[]", script="{}", feature_image=chatroom.feature_image if chatroom.feature_image != "" else None, script_template_id=chatroom.script_template_id, creator_id=user.id, ) session.add(cr) session.commit() session.refresh(cr) return ReturnDto(result=cr.model_dump()) @chatroom_router.post("/{id_}/", name="修改指定聊天室") async def edit_chatroom( id_: int, chatroom: ChatroomPublic, user: Annotated[ModelUser, Depends(verify_token)], session: Annotated[Session, Depends(get_session)], ) -> ReturnDto: """ 修改聊天室的基本信息。 content 和 script 需要从各自的独立端点请求修改,不包含在本端点的负责范围内。 Args: id_: 聊天室 ID chatroom: 聊天室基本信息,类型为 ChatroomPublic。注意 id 不可更改,如提供则会被忽略 user: 用户 session: 数据库连接对象 Raises: HTTPException: 404 表示未找到聊天室 Returns: 修改过的聊天室对象,供前端更新。 """ try: cr: Chatroom = session.exec( select(Chatroom).where(Chatroom.id == id_).where(Chatroom.creator_id == user.id) ).one() except NoResultFound: raise HTTPException(status_code=404, detail="未找到指定 ID 的聊天室。") from None cr.name = chatroom.name cr.description = chatroom.description if chatroom.feature_image != "": cr.feature_image = chatroom.feature_image cr.script_template_id = chatroom.script_template_id cr.script_template_version = chatroom.script_template_version cr.default_model_id = chatroom.default_model_id session.add(cr) session.commit() session.refresh(cr) return ReturnDto(result=cr.model_dump()) @chatroom_router.post("/{id_}/script/", name="修改聊天室脚本") async def update_chatroom_script( id_: int, script: ChatScript, user: Annotated[ModelUser, Depends(verify_token)], session: Annotated[Session, Depends(get_session)], ) -> ReturnDto: """ 更新聊天室的剧本(提示词与世界书)。 Args: id_: 聊天室 ID script: 剧本 user: 用户 session: 数据库连接对象 Raises: HTTPException: 404 表示未找到聊天室 Returns: result 字段包含最新的剧本。 """ try: cr: Chatroom = session.exec( select(Chatroom).where(Chatroom.id == id_).where(Chatroom.creator_id == user.id) ).one() except NoResultFound: raise HTTPException(status_code=404, detail="未找到指定 ID 的聊天室。") from None cr.script = json.dumps(script.model_dump(), ensure_ascii=False) session.add(cr) session.commit() session.refresh(cr) return ReturnDto(result=script.model_dump()) @chatroom_router.post("/{id_}/chat/", name="聊天室发起模型创作") async def post_chatroom_chat( id_: int, chat: ChatroomChat, user: Annotated[ModelUser, Depends(verify_token)], session: Annotated[Session, Depends(get_session)], ) -> StreamingResponse: """ 在聊天室中发送新的用户消息,流式返回 AI 调用结果。 即:调用模型发起创作。 Args: id_: (路径参数)聊天室 ID chat: 用户消息 user: 用户 session: 数据库连接对象 Raises: HTTPException: 404 表示聊天室未找到,444 表示模型未找到。 Returns: SSE 流式输出,实质上相当于转发 AI 的流式输出结果。 """ try: return StreamingResponse( s_start_async_streaming_chat(**apply_chat(id_, user.id, chat, session)), media_type="text/event-stream", ) except HTTPException as e: raise e @chatroom_router.post("/{id_}/chat/accept/", name="聊天室保存模型创作") async def accept_chatroom_chat( id_: int, accept: ChatroomChatAccept, user: Annotated[ModelUser, Depends(verify_token)], session: Annotated[Session, Depends(get_session)], ) -> ReturnDto: """ 此端点不负责调用 AI 生成输出,而是用于保存一对用户消息和 AI 输出到聊天室 content 的最后。 需要提供用户消息、AI 消息和创作模式。 Raises: HTTPException: 404 表明未找到聊天室。 Returns: ReturnDto,其中 result 字段是该聊天室的最新 content,以供前端刷新。 """ try: cr: Chatroom = session.exec( select(Chatroom).where(Chatroom.id == id_).where(Chatroom.creator_id == user.id) ).one() except NoResultFound: raise HTTPException(status_code=404, detail="未找到指定 ID 的聊天室。") from None cr.content = s_append_chatroom_content(cr.content, accept) session.add(cr) session.commit() session.refresh(cr) return ReturnDto(result=cr.model_dump()) @chatroom_router.post("/{id_}/chat/edit/", name="聊天室编辑消息") async def edit_chatroom_chat( id_: int, edit: ChatroomChatEdit, user: Annotated[ModelUser, Depends(verify_token)], session: Annotated[Session, Depends(get_session)], ) -> ReturnDto: """ 此端点不负责调用 AI 生成输出,而是用于修改一条已经保存在聊天记录中的消息。 需要提供消息类型(用户/AI)、旧消息和新消息,以便进行替换。 Raises: HTTPException: 404 表明未找到聊天室,400 表明聊天记录匹配失败,未更新。 Returns: ReturnDto,其中 result 字段是该聊天室的最新 content,以供前端刷新。 """ try: cr: Chatroom = session.exec( select(Chatroom).where(Chatroom.id == id_).where(Chatroom.creator_id == user.id) ).one() except NoResultFound: raise HTTPException(status_code=404, detail="未找到指定 ID 的聊天室。") from None try: cr.content = s_edit_chatroom_content(cr.content, edit) session.add(cr) session.commit() session.refresh(cr) return ReturnDto(result=cr.model_dump()) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) from e @chatroom_router.post("/{id_}/chat/delete/", name="聊天室删除消息") async def delete_chatroom_chat( id_: int, delete: ChatroomChatDelete, user: Annotated[ModelUser, Depends(verify_token)], session: Annotated[Session, Depends(get_session)], ) -> ReturnDto: """ 此端点不负责调用 AI 生成输出,而是用于删除一条已经保存在聊天记录中的消息。关联的 user 或 aii 消息会一并删除。 需要提供消息和消息类型(用户/AI)。用户消息和 AI 消息是一对一成对的,所以总是会删除关联的一对(两条)消息。 Raises: HTTPException: 404 表明未找到聊天室,400 表明聊天记录匹配失败,未更新。 Returns: ReturnDto,其中 result 字段是该聊天室的最新 content,以供前端刷新。 """ try: cr: Chatroom = session.exec( select(Chatroom).where(Chatroom.id == id_).where(Chatroom.creator_id == user.id) ).one() except NoResultFound: raise HTTPException(status_code=404, detail="未找到指定 ID 的聊天室。") from None try: cr.content = s_delete_chatroom_content(cr.content, delete) session.add(cr) session.commit() session.refresh(cr) return ReturnDto(result=cr.model_dump()) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) from e