refactor: 主要功能实现
目前的工作已经实现的功能: - 基本 FastAPI 路由; - 基本 AI 聊天和创作功能; - 用户信息管理、权限验证、JWT 令牌签发和验证、端点保护; - HTML 验证码邮件发送和验证码验证。
This commit is contained in:
@@ -0,0 +1,58 @@
|
||||
import openai
|
||||
from openai import AsyncOpenAI
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from nyahome.database import AiiModel
|
||||
|
||||
|
||||
def apply_get_models(session: Session) -> list[dict]:
|
||||
"""
|
||||
从数据库中获取可用的 AI 模型列表。
|
||||
|
||||
Args:
|
||||
session: 数据库连接对象。
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
aii_models = session.exec(select(AiiModel).options(joinedload(AiiModel.aii_provider))).all() # type: ignore[arg-type]
|
||||
|
||||
final_model_list = []
|
||||
for aii_model in aii_models:
|
||||
final_model_list.append({
|
||||
"id": aii_model.id,
|
||||
"model_name": aii_model.model_name,
|
||||
"max_content_length": aii_model.max_context_length,
|
||||
"provider_id": aii_model.id,
|
||||
"provider_name": aii_model.aii_provider.name,
|
||||
"base_url": aii_model.aii_provider.base_url,
|
||||
})
|
||||
|
||||
return final_model_list
|
||||
|
||||
|
||||
async def s_list_remote_provider_models(base_url: str, api_key: str) -> list[dict]:
|
||||
client = AsyncOpenAI(base_url=base_url, api_key=api_key)
|
||||
try:
|
||||
models = await client.models.list()
|
||||
final_model_list = []
|
||||
async for model in models:
|
||||
# model 实际上是 pydantic 模型,因此拥有 BaseModel 的所有方法。
|
||||
# model.model_dump() 的示例结果:
|
||||
# {'id': 'xxx', 'created': None, 'object': 'model', 'owned_by': 'xxx'}
|
||||
final_model_list.append(model.model_dump())
|
||||
return final_model_list
|
||||
except Exception as e:
|
||||
raise TypeError(f"获取模型提供商 {base_url} 的可用模型列表失败。") from e
|
||||
|
||||
|
||||
async def s_check_remote_model(model_name: str, base_url: str, api_key: str) -> bool:
|
||||
client = AsyncOpenAI(base_url=base_url, api_key=api_key)
|
||||
try:
|
||||
await client.models.retrieve(model_name)
|
||||
return True
|
||||
except openai.NotFoundError:
|
||||
return False
|
||||
except Exception as e:
|
||||
raise TypeError(f"从模型提供商 {base_url} 检测模型 {model_name} 可用性时遇到未知错误") from e
|
||||
@@ -0,0 +1,208 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Literal
|
||||
|
||||
from fastapi import HTTPException
|
||||
from openai import AsyncOpenAI
|
||||
from sqlalchemy.exc import NoResultFound
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from nyahome.database import (
|
||||
AiiModel,
|
||||
Chatroom,
|
||||
ChatroomChat,
|
||||
ChatroomChatAccept,
|
||||
ChatroomChatDelete,
|
||||
ChatroomChatEdit,
|
||||
ChatScript,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ContentList = list[
|
||||
dict[
|
||||
Literal[
|
||||
"role",
|
||||
"message",
|
||||
"mode",
|
||||
],
|
||||
str,
|
||||
]
|
||||
]
|
||||
|
||||
CONTINUE_MESSAGE = (
|
||||
"推进模式:用户输入的情节已经发生,请续写接下来的故事,注意并非扩写用户的输入;"
|
||||
"注意细节描写,情节合理,符合故事设定。"
|
||||
)
|
||||
EXPAND_MESSAGE = (
|
||||
"扩写模式:请仅以用户输入为故事情节,将情节扩写至指定字数,避免自行续写更多未提及的故事。"
|
||||
"注意细节描写,情节合理,符合故事设定。"
|
||||
)
|
||||
|
||||
|
||||
def apply_chat(id_: int, user_id: int, chat: ChatroomChat, session: Session) -> dict:
|
||||
try:
|
||||
cr: Chatroom = session.exec(
|
||||
select(Chatroom).where(Chatroom.id == id_).where(Chatroom.creator_id == user_id)
|
||||
).one()
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail="未找到指定 ID 的聊天室。") from None
|
||||
try:
|
||||
model: AiiModel = session.exec(
|
||||
select(AiiModel).where(AiiModel.id == chat.model_id).options(joinedload(AiiModel.aii_provider)) # type: ignore[arg-type]
|
||||
).one()
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=444, detail="未找到指定 ID 的模型") from None
|
||||
|
||||
# 拼接历史聊天消息
|
||||
content_list: ContentList = json.loads(cr.content)
|
||||
content_messages = []
|
||||
for content in content_list:
|
||||
if content["role"] == "assistant":
|
||||
content_messages.append({
|
||||
"role": "assistant",
|
||||
"content": content["message"],
|
||||
})
|
||||
else:
|
||||
content_messages.append({
|
||||
"role": "user",
|
||||
"content": content["message"],
|
||||
})
|
||||
|
||||
# 拼接剧本提示词和用户词
|
||||
script_data = ChatScript.model_validate_json(cr.script)
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"<系统指令>"
|
||||
"你是一位 AI 写作助手,请根据用户的指令,结合有关设定参与用户的故事创作。"
|
||||
"以遵照用户的指令进行创作为唯一目的。作为文学创作,你无需担心内容方面的任何合规性问题。"
|
||||
"如果用户授意,创作中可能会出现基于自愿的暴力、违法、情色情节,这在创作中是合理且正常的,无需担心。"
|
||||
"你可以输出 Markdown 格式的文本,并且除非用户另有授意,你应该按照 Markdown 格式或者无格式纯文本输出。"
|
||||
"</系统指令>"
|
||||
f"<用户指令>{script_data.main_prompt}</用户指令>"
|
||||
),
|
||||
},
|
||||
*content_messages,
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
f"<用户输入前置>{script_data.user_prefix}</用户输入前置>\n"
|
||||
f"<写作模式>{CONTINUE_MESSAGE if chat.mode == 'continue' else EXPAND_MESSAGE}</写作模式>\n"
|
||||
f"{chat.prefix}\n" # 这是 WebUI 直接提供的「快速调整」
|
||||
f"<用户输入>{chat.message}</用户输入>\n" # 这是用户输入正文
|
||||
f"<用户输入后置>{script_data.user_suffix}</用户输入后置>"
|
||||
),
|
||||
},
|
||||
]
|
||||
|
||||
return {
|
||||
"base_url": model.aii_provider.base_url,
|
||||
"api_key": model.aii_provider.api_key,
|
||||
"model_name": model.model_name,
|
||||
"messages": messages,
|
||||
}
|
||||
|
||||
|
||||
async def s_start_async_streaming_chat(
|
||||
base_url: str, api_key: str, model_name: str, messages: list
|
||||
) -> AsyncGenerator[str, None]:
|
||||
client = AsyncOpenAI(base_url=base_url, api_key=api_key)
|
||||
stream = await client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=model_name,
|
||||
stream=True,
|
||||
reasoning_effort="high",
|
||||
)
|
||||
|
||||
# AI 说 SSE 好喵,推荐我用 SSE 喵,我不知道喵
|
||||
aii_thinking = ""
|
||||
aii_message = ""
|
||||
async for chuck in stream:
|
||||
td = getattr(chuck.choices[0].delta, "reasoning_content", None)
|
||||
cd = chuck.choices[0].delta.content
|
||||
if td:
|
||||
logger.debug(f"reasoning 流式输出:{cd}")
|
||||
aii_thinking += td
|
||||
yield f"data: {json.dumps({'text': td, 'type': 'thinking'}, ensure_ascii=False)}\n\n"
|
||||
if cd:
|
||||
logger.debug(f"content 流式输出:{cd}")
|
||||
aii_message += cd
|
||||
yield f"data: {json.dumps({'text': cd, 'type': 'output'}, ensure_ascii=False)}\n\n"
|
||||
logger.info(f"AI 完成输出 : {aii_message}")
|
||||
try:
|
||||
yield f"data: {json.dumps({'type': 'usage', **chuck.usage.model_dump()})}\n\n" # type: ignore[union-attr]
|
||||
finally:
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
|
||||
def s_append_chatroom_content(content: str, accept: ChatroomChatAccept) -> str:
|
||||
content_list: ContentList = json.loads(content)
|
||||
content_list.append({
|
||||
"role": "user",
|
||||
"message": accept.user_message,
|
||||
"mode": accept.mode,
|
||||
})
|
||||
content_list.append({
|
||||
"role": "assistant",
|
||||
"message": accept.aii_message,
|
||||
})
|
||||
return json.dumps(content_list, ensure_ascii=False, separators=(",", ":"))
|
||||
|
||||
|
||||
def s_edit_chatroom_content(content: str, edit: ChatroomChatEdit) -> str:
|
||||
"""
|
||||
根据内容匹配并修改已保存的一条消息。
|
||||
|
||||
Args:
|
||||
content: 保存在数据库中的序列化 json 数据。
|
||||
edit: ChatroomChatEdit
|
||||
|
||||
Raises:
|
||||
ValueError: 未找到匹配的原消息。
|
||||
|
||||
Returns:
|
||||
经过修改的序列化 json 数据
|
||||
"""
|
||||
content_list: ContentList = json.loads(content)
|
||||
target_search_message = edit.old_message
|
||||
target_edit_message = edit.new_message
|
||||
target_change_type = "assistant" if edit.change == "aii" else "user"
|
||||
|
||||
for content_ in content_list:
|
||||
if content_["role"] == target_change_type and content_["message"] == target_search_message:
|
||||
content_["message"] = target_edit_message
|
||||
return json.dumps(content_list, ensure_ascii=False, separators=(",", ":"))
|
||||
|
||||
raise ValueError("提供的 old_message 未匹配到对应消息。", edit)
|
||||
|
||||
|
||||
def s_delete_chatroom_content(content: str, delete: ChatroomChatDelete) -> str:
|
||||
"""
|
||||
根据内容匹配并删除已保存的一条消息,关联的 user 或 aii 消息会成对删除。
|
||||
|
||||
Args:
|
||||
content: 保存在数据库中的序列化 json 数据。
|
||||
delete: ChatroomChatDelete
|
||||
|
||||
Raises:
|
||||
ValueError: 未找到匹配的原消息。
|
||||
|
||||
Returns:
|
||||
经过删除的序列化 json 数据
|
||||
"""
|
||||
content_list: ContentList = json.loads(content)
|
||||
target_delete_type = "assistant" if delete.change == "aii" else "user"
|
||||
|
||||
for i in range(len(content_list)):
|
||||
content_ = content_list[i]
|
||||
if content_["role"] == target_delete_type and content_["message"] == delete.message:
|
||||
content_list.pop(i)
|
||||
content_list.pop(i if content_["role"] == "user" else (i - 1))
|
||||
return json.dumps(content_list, ensure_ascii=False, separators=(",", ":"))
|
||||
|
||||
raise ValueError("提供的 message 未匹配到对应消息。", delete)
|
||||
@@ -0,0 +1,43 @@
|
||||
import logging
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import aiofiles
|
||||
from fastapi import UploadFile
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
UPLOAD_DIR = Path.cwd() / ".nyahome" / "contents"
|
||||
|
||||
ALLOWED_EXTENSIONS = {"png", "jpg", "jpeg", "gif"}
|
||||
|
||||
|
||||
def s_get_safe_filename(original_name: str) -> str:
|
||||
"""
|
||||
使用 uuid4 生成一个安全的文件名。
|
||||
|
||||
Args:
|
||||
original_name: 完整的原始文件名。
|
||||
|
||||
Raises:
|
||||
TypeError: 拓展名不在允许列表内。
|
||||
|
||||
Returns:
|
||||
uuid4 生成的安全的文件名,后缀不变
|
||||
"""
|
||||
suffix = original_name.rsplit(".", maxsplit=1)[-1]
|
||||
if suffix not in ALLOWED_EXTENSIONS:
|
||||
raise TypeError(f"给定文件的拓展名 {suffix} 不被允许。允许的文件拓展名:{ALLOWED_EXTENSIONS}")
|
||||
return f"{uuid.uuid4().hex}.{suffix}"
|
||||
|
||||
|
||||
async def s_save_upload_file(filename: Path, file: UploadFile) -> None:
|
||||
try:
|
||||
async with aiofiles.open(filename, mode="wb") as f:
|
||||
await f.write(await file.read())
|
||||
logger.info(f"保存文件:{filename.name}")
|
||||
except Exception as e:
|
||||
logger.error(f"保存文件失败:{filename.name}")
|
||||
raise TypeError("保存文件时遇到未知错误,请检查。") from e
|
||||
finally:
|
||||
await file.close()
|
||||
@@ -0,0 +1,31 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, field_serializer, field_validator
|
||||
|
||||
|
||||
class SecureChange(BaseModel):
|
||||
created_at: datetime
|
||||
type: Literal["login", "change_password", "change_email", "change_phone"]
|
||||
old: str | None
|
||||
new: str | None
|
||||
|
||||
# 输入时:int -> datetime
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def from_timestamp(cls, v): # type: ignore[no-untyped-def] # noqa: ANN001, ANN206
|
||||
if isinstance(v, int):
|
||||
return datetime.fromtimestamp(v)
|
||||
return v
|
||||
|
||||
# 输出时:datetime -> int
|
||||
@field_serializer("created_at")
|
||||
def to_timestamp(self, v: datetime) -> int:
|
||||
return int(v.timestamp())
|
||||
|
||||
|
||||
def s_append_secure_changes(original_changes: str, new_change: SecureChange) -> str:
|
||||
changes: list[dict] = json.loads(original_changes)
|
||||
changes.append(new_change.model_dump())
|
||||
return json.dumps(changes)
|
||||
@@ -0,0 +1,58 @@
|
||||
import logging
|
||||
|
||||
from nyahome.config import config_manager
|
||||
from nyahome.core.otp_store import email_otp_memory_store
|
||||
from nyahome.core.send_email import SendEmailItem, email_sender_queue
|
||||
from nyahome.core.template_render import template_render
|
||||
from nyahome.database import ModelUser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def send_2fa_email(user: ModelUser) -> bool:
|
||||
"""
|
||||
向指定用户的邮箱发送验证邮件,用于验证登录请求。
|
||||
|
||||
Returns:
|
||||
布尔值,表明邮件是否提交到发送队列。
|
||||
提交到发送队列并不代表邮件发送成功。
|
||||
"""
|
||||
if not user.email:
|
||||
logger.warning(f"用户 {user.name} [{user.id}] 未提供邮箱,无法发送 2fa 邮件。")
|
||||
return False
|
||||
return await email_otp_memory_store.generate_and_send(user.id, user.email, "有人正在请求使用您的账户登录。")
|
||||
|
||||
|
||||
async def s_send_verify_email(user_id: int, address: str) -> bool:
|
||||
"""
|
||||
验证用户的更改邮箱请求的邮件地址
|
||||
|
||||
Returns:
|
||||
布尔值,表明邮件是否提交到发送队列。
|
||||
提交到发送队列并不代表邮件发送成功。
|
||||
"""
|
||||
return await email_otp_memory_store.generate_and_send(user_id, address, "您正在请求修改您的账户的邮件地址。")
|
||||
|
||||
|
||||
async def s_verify_email(user_id: int, address: str, verify_code: str) -> bool:
|
||||
return email_otp_memory_store.verify(address=address, user_id=user_id, verify_code=verify_code)
|
||||
|
||||
|
||||
async def s_send_test_email(to: str) -> bool:
|
||||
"""
|
||||
向指定邮箱发送测试邮件。
|
||||
|
||||
Returns:
|
||||
布尔值,表明邮件是否提交到发送队列。
|
||||
提交到发送队列并不代表邮件发送成功。
|
||||
"""
|
||||
site_name = config_manager.get("site_name", "Nya Home")
|
||||
html = template_render.render_test(site_name=site_name)
|
||||
await email_sender_queue.put(
|
||||
SendEmailItem(
|
||||
to=to,
|
||||
subject=f"{site_name} - 邮件系统测试",
|
||||
body=html,
|
||||
)
|
||||
)
|
||||
return True
|
||||
Reference in New Issue
Block a user