Compare commits
3 Commits
1f1ac5f87a
...
884cea53a1
| Author | SHA1 | Date | |
|---|---|---|---|
|
884cea53a1
|
|||
|
52f6904bef
|
|||
|
a7140ea5c1
|
@@ -36,8 +36,29 @@
|
||||
scrollbar-color: rgba(255, 255, 255, 0.15) transparent;
|
||||
}
|
||||
|
||||
/* vitepress-openapi */
|
||||
/* ===== vitepress-oepnapi ===== */
|
||||
|
||||
/* Details */
|
||||
details {
|
||||
padding: 10px 6px;
|
||||
border-width: 1px;
|
||||
border-radius: 6px;
|
||||
border-color: transparent;
|
||||
transition: all 0.3s ease;
|
||||
}
|
||||
|
||||
details[open] {
|
||||
border-color: rgba(255, 255, 255, 0.3)
|
||||
}
|
||||
|
||||
details>summary {
|
||||
text-align: center;
|
||||
list-style: none;
|
||||
}
|
||||
|
||||
/* 底部 */
|
||||
div.vitepress-openapi {
|
||||
background: linear-gradient(45deg, hsla(58, 100%, 92%, 0.6), hsla(128, 100%, 75%, 0.5));
|
||||
border: 1px solid #64ffc4;
|
||||
border-radius: 6px;
|
||||
margin: 48px auto 0;
|
||||
@@ -48,4 +69,4 @@ div.vitepress-openapi {
|
||||
|
||||
div.vitepress-openapi p {
|
||||
line-height: 8px;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,7 +23,7 @@ const operationId = route.data.params.operationId
|
||||
<template #branding>
|
||||
<div class="vitepress-openapi">
|
||||
<p>API 文档是基于最新代码自动生成的</p>
|
||||
<p>由 VitePress OpenAPI 提供文档支持</p>
|
||||
<p>由 <a href="https://vitepress-openapi.vercel.app/" target="_blank">VitePress OpenAPI</a> 提供文档支持</p>
|
||||
</div>
|
||||
</template>
|
||||
</OAOperation>
|
||||
</OAOperation>
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
from pathlib import Path
|
||||
|
||||
from rich.console import Console
|
||||
|
||||
console = Console()
|
||||
console = Console()
|
||||
|
||||
ENV_PATH = Path.cwd() / ".nyahome" / ".env"
|
||||
|
||||
@@ -0,0 +1,100 @@
|
||||
from typing import Annotated, Sequence
|
||||
|
||||
import typer
|
||||
from dotenv import load_dotenv
|
||||
from rich.table import Table
|
||||
|
||||
from .cli import ENV_PATH, console
|
||||
|
||||
aii_app = typer.Typer()
|
||||
|
||||
|
||||
@aii_app.command(name="list")
|
||||
def list_all_provider() -> None:
|
||||
"""
|
||||
列出已设置的所有提供商和模型。
|
||||
"""
|
||||
load_dotenv(ENV_PATH)
|
||||
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from nyahome.database import AiiProvider, engine
|
||||
|
||||
table = Table(title="AI 模型提供商与已录入模型")
|
||||
table.add_column("ID", style="cyan", no_wrap=True)
|
||||
table.add_column("提供商名称", style="white", no_wrap=True)
|
||||
table.add_column("Base URL", style="white", no_wrap=True)
|
||||
table.add_column("录入的模型", style="bright_black")
|
||||
|
||||
with Session(engine) as session:
|
||||
aps: Sequence[AiiProvider] = session.exec(select(AiiProvider)).all()
|
||||
for ap in aps:
|
||||
table.add_row(str(ap.id), ap.name, ap.base_url, str(ap.aii_models))
|
||||
|
||||
console.print(table)
|
||||
|
||||
|
||||
@aii_app.command()
|
||||
def add_provider(
|
||||
name: Annotated[str, typer.Argument(help="提供商名称")],
|
||||
base_url: Annotated[str, typer.Argument(help="提供商 Base URL(OpenAI 兼容端点)")],
|
||||
api_key: Annotated[
|
||||
str,
|
||||
typer.Option(
|
||||
"--api-key",
|
||||
"-k",
|
||||
help="提供商 API Key",
|
||||
prompt=True,
|
||||
hide_input=True,
|
||||
),
|
||||
],
|
||||
) -> None:
|
||||
"""
|
||||
添加 AI 提供商。需要提供商名称、Base URL 和 API Key。
|
||||
"""
|
||||
load_dotenv(ENV_PATH)
|
||||
|
||||
from sqlmodel import Session
|
||||
|
||||
from nyahome.database import AiiProvider, engine
|
||||
|
||||
console.print(f"[cyan]正在添加模型提供商 [{name}]({base_url})[/cyan]")
|
||||
|
||||
with Session(engine) as session:
|
||||
ap = AiiProvider(name=name, base_url=base_url, api_key=api_key)
|
||||
session.add(ap)
|
||||
session.commit()
|
||||
session.refresh(ap)
|
||||
|
||||
console.print(f"[cyan]添加完成 [{ap.id}][{ap.name}]({ap.base_url})[/cyan]")
|
||||
|
||||
|
||||
@aii_app.command()
|
||||
def add_model(
|
||||
model_name: Annotated[str, typer.Argument(help="模型名称(需准确填写)")],
|
||||
max_context_length: Annotated[int, typer.Argument(help="最大上下文长度(单位为 k)")],
|
||||
provider_id: Annotated[
|
||||
int,
|
||||
typer.Option(
|
||||
"--provider-id",
|
||||
"-p",
|
||||
help="该模型所属于的模型提供商 ID",
|
||||
),
|
||||
],
|
||||
) -> None:
|
||||
"""
|
||||
添加 AI 模型。在此之前需要先添加该模型的提供商。
|
||||
"""
|
||||
load_dotenv(ENV_PATH)
|
||||
|
||||
from sqlmodel import Session
|
||||
|
||||
from nyahome.database import AiiModel, engine
|
||||
|
||||
with Session(engine) as session:
|
||||
am = AiiModel(model_name=model_name, max_context_length=max_context_length, aii_provider_id=provider_id)
|
||||
session.add(am)
|
||||
session.commit()
|
||||
session.refresh(am)
|
||||
|
||||
console.print(f"[cyan]已添加模型 [{am.id}][{am.model_name}]({am.aii_provider_id})[/cyan]")
|
||||
@@ -1,19 +1,15 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
|
||||
import typer
|
||||
from dotenv import load_dotenv, set_key, unset_key
|
||||
from rich.table import Table
|
||||
|
||||
from .cli import console
|
||||
from .cli import ENV_PATH, console
|
||||
|
||||
env_app = typer.Typer()
|
||||
|
||||
|
||||
ENV_PATH = Path.cwd() / ".nyahome" / ".env"
|
||||
|
||||
|
||||
@env_app.command(name="list")
|
||||
def list_all_envs() -> None:
|
||||
"""
|
||||
|
||||
@@ -9,6 +9,7 @@ import typer
|
||||
|
||||
from nyahome import __version__
|
||||
from nyahome.cli.cli import console
|
||||
from nyahome.cli.cli_aii import aii_app
|
||||
from nyahome.cli.cli_env import ENV_PATH, env_app
|
||||
|
||||
app = typer.Typer(
|
||||
@@ -68,6 +69,10 @@ def openapi(
|
||||
"""
|
||||
根据代码导出 NyaHome 的 openapi.json 。
|
||||
"""
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv(ENV_PATH)
|
||||
|
||||
from nyahome.server import save_openapi_json
|
||||
|
||||
save_openapi_json(path)
|
||||
@@ -75,6 +80,7 @@ def openapi(
|
||||
|
||||
|
||||
app.add_typer(env_app, name="env", no_args_is_help=True, help="设置 NyaHome 应用的环境变量。")
|
||||
app.add_typer(aii_app, name="aii", no_args_is_help=True, help="添加、设置、修改 AI 提供商和模型。")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -49,8 +49,17 @@ class VerifyEmail(BaseModel):
|
||||
verify_code: str
|
||||
|
||||
|
||||
@admin_router.post("/login/name/")
|
||||
@admin_router.post("/login/name/", name="用户登录")
|
||||
async def nyahome_login_name(user: UserLogin, session: Annotated[Session, Depends(get_session)]) -> ReturnDto:
|
||||
"""
|
||||
使用用户名密码登录。
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 表示尝试登录的用户不存在。
|
||||
|
||||
Returns:
|
||||
ReturnDto,其中 result 字段包含 `user_id` 和 `access_token` 两个字段。
|
||||
"""
|
||||
try:
|
||||
u: ModelUser = session.exec(select(ModelUser).where(ModelUser.name == user.username)).one()
|
||||
except NoResultFound:
|
||||
@@ -75,15 +84,28 @@ async def nyahome_login_name(user: UserLogin, session: Annotated[Session, Depend
|
||||
raise HTTPException(status_code=401, detail="验证失败,请检查用户名和密码是否正确")
|
||||
|
||||
|
||||
@admin_router.get("/me/")
|
||||
@admin_router.get("/me/", name="获取登录用户信息")
|
||||
async def nyahome_get_me(user: Annotated[ModelUser, Depends(verify_token)]) -> ModelUser:
|
||||
"""
|
||||
获取当前登录的用户的详细信息。
|
||||
|
||||
Returns:
|
||||
ModelUser
|
||||
"""
|
||||
return user
|
||||
|
||||
|
||||
@admin_router.post("/me/")
|
||||
@admin_router.post("/me/", name="修改登录用户信息")
|
||||
async def nyahome_post_me(
|
||||
info: UserInfo, user: Annotated[ModelUser, Depends(verify_token)], session: Annotated[Session, Depends(get_session)]
|
||||
) -> ModelUser:
|
||||
"""
|
||||
修改当前登录的用户的详细信息。
|
||||
此端点可以修改除了用户密码、邮箱、手机号之外的大部分用户信息。
|
||||
|
||||
Returns:
|
||||
ModelUser
|
||||
"""
|
||||
user.name = info.name
|
||||
user.display_name = info.display_name
|
||||
user.avatar_url = info.avatar_url
|
||||
@@ -95,12 +117,21 @@ async def nyahome_post_me(
|
||||
return user
|
||||
|
||||
|
||||
@admin_router.post("/me/password/")
|
||||
@admin_router.post("/me/password/", name="修改用户密码")
|
||||
async def nyahome_change_password(
|
||||
change: ChangePassword,
|
||||
user: Annotated[ModelUser, Depends(verify_token)],
|
||||
session: Annotated[Session, Depends(get_session)],
|
||||
) -> ReturnDto:
|
||||
"""
|
||||
修改用户密码。
|
||||
|
||||
Raises:
|
||||
HTTPException: 400 提供的旧密码错误。
|
||||
|
||||
Returns:
|
||||
不重要的 ReturnDto,无异常本身即表示修改成功。
|
||||
"""
|
||||
if verify_password(change.old_password, user.password):
|
||||
user.password = save_password(change.new_password)
|
||||
change_ = SecureChange(
|
||||
@@ -116,12 +147,20 @@ async def nyahome_change_password(
|
||||
raise HTTPException(status_code=400, detail="修改密码需要提供旧的密码,但提供的旧密码错误。") from None
|
||||
|
||||
|
||||
@admin_router.post("/me/email-verify/")
|
||||
@admin_router.post("/me/email-verify/", name="验证并修改用户邮箱")
|
||||
async def nyahome_verify_email(
|
||||
to: VerifyEmail,
|
||||
user: Annotated[ModelUser, Depends(verify_token)],
|
||||
session: Annotated[Session, Depends(get_session)],
|
||||
) -> ReturnDto:
|
||||
"""
|
||||
验证用户提供的邮箱以及验证码。
|
||||
需要先通过 `/me/email-verify/send/` 发送验证码。验证码有五分钟有效期,为六位数字,需以字符串形式提供。
|
||||
|
||||
Returns:
|
||||
ReturnDto,其中 success 字段表明是否成功。如果成功,则用户邮箱已被修改。
|
||||
不返回完整的 ModelUser,WebUI 自行负责前端用户信息更新。
|
||||
"""
|
||||
success = await s_verify_email(user_id=user.id, address=to.to, verify_code=to.verify_code)
|
||||
if success:
|
||||
old_email = user.email
|
||||
@@ -141,23 +180,37 @@ async def nyahome_verify_email(
|
||||
return ReturnDto(success=success)
|
||||
|
||||
|
||||
@admin_router.post("/me/email-verify/send/")
|
||||
@admin_router.post("/me/email-verify/send/", name="发送修改邮箱验证码")
|
||||
async def nyahome_verify_email_send(to: SendEmail, user: Annotated[ModelUser, Depends(verify_token)]) -> ReturnDto:
|
||||
"""
|
||||
请求对新的邮箱发送验证码。验证码有五分钟有效期,为六位数字。
|
||||
|
||||
Returns:
|
||||
ReturnDto,其中 success 字段表明是否成功。
|
||||
"""
|
||||
success = await s_send_verify_email(user.id, to.to)
|
||||
return ReturnDto(success=success)
|
||||
|
||||
|
||||
@admin_router.get("/me/secure_changes/")
|
||||
@admin_router.get("/me/secure_changes/", name="获取用户安全变更记录")
|
||||
async def nyahome_get_secure_changes(
|
||||
user: Annotated[ModelUser, Depends(verify_token)],
|
||||
) -> list[SecureChange]:
|
||||
"""
|
||||
获取用户的安全变更记录。
|
||||
安全变更记录包括:登录、修改密码、修改邮箱、修改手机号。
|
||||
|
||||
Returns:
|
||||
SecureChange 列表。
|
||||
"""
|
||||
return json.loads(user.secure_changes) # type: ignore[no-any-return]
|
||||
|
||||
|
||||
@admin_router.get("/site_config/")
|
||||
@admin_router.get("/site_config/", name="获取 NyaHome 设置")
|
||||
async def get_site_config(user: Annotated[ModelUser, Depends(verify_token)]) -> dict[str, Any]:
|
||||
"""
|
||||
获取 NyaHome 的设置。
|
||||
需要管理员权限才能访问。
|
||||
|
||||
Raises:
|
||||
HTTPException: 403 表示请求用户非管理员。
|
||||
@@ -170,13 +223,14 @@ async def get_site_config(user: Annotated[ModelUser, Depends(verify_token)]) ->
|
||||
return config_manager.get_config()
|
||||
|
||||
|
||||
@admin_router.post("/site_config/")
|
||||
@admin_router.post("/site_config/", name="修改 NyaHome 设置")
|
||||
async def set_site_config(
|
||||
user: Annotated[ModelUser, Depends(verify_token)],
|
||||
config_: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
设置 NyaHome 的设置。
|
||||
需要管理员权限才能访问。
|
||||
|
||||
Raises:
|
||||
HTTPException: 403 表示请求用户非管理员。
|
||||
@@ -191,7 +245,7 @@ async def set_site_config(
|
||||
return final_config
|
||||
|
||||
|
||||
@admin_router.post("/email-test/")
|
||||
@admin_router.post("/email-test/", name="测试邮件发送")
|
||||
async def nyahome_test_email(to: SendEmail, user: Annotated[ModelUser, Depends(verify_token)]) -> ReturnDto:
|
||||
"""
|
||||
NyaHome 管理员面板中的测试邮件端点。
|
||||
|
||||
@@ -22,18 +22,37 @@ from .response_model import ReturnDto
|
||||
aii_router = APIRouter(tags=["Aii"], prefix="/aii")
|
||||
|
||||
|
||||
@aii_router.get("/model/")
|
||||
@aii_router.get("/model/", name="获取模型列表")
|
||||
async def get_all_model(session: Annotated[Session, Depends(get_session)]) -> ReturnDto:
|
||||
"""
|
||||
获取 AI 模型列表。
|
||||
此接口无需用户登录即可访问。
|
||||
|
||||
Returns:
|
||||
被 ReturnDto 包裹的 AiiModel 列表
|
||||
"""
|
||||
final_model_list = apply_get_models(session)
|
||||
return ReturnDto(result=final_model_list)
|
||||
|
||||
|
||||
@aii_router.post("/model/")
|
||||
@aii_router.post("/model/", name="添加模型")
|
||||
async def add_model(
|
||||
model: AiiModelPublic,
|
||||
user: Annotated[ModelUser, Depends(verify_token)],
|
||||
session: Annotated[Session, Depends(get_session)],
|
||||
) -> ReturnDto:
|
||||
"""
|
||||
添加新的 AI 模型。需要基于已添加的模型提供商。
|
||||
此接口需要管理员访问。
|
||||
添加模型时不会进行可用性检查,因此 WebUI 在前端实现了检查按钮。此端点不会负责检查。
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 用户无权限管理模型(未登录或非管理员)
|
||||
HTTPException: 404 模型提供商不存在
|
||||
|
||||
Returns:
|
||||
被 ReturnDto 包裹的、添加的 AiiModel
|
||||
"""
|
||||
if not user.is_admin:
|
||||
raise HTTPException(status_code=401, detail="用户无权限管理模型。") from None
|
||||
|
||||
@@ -53,18 +72,36 @@ async def add_model(
|
||||
return ReturnDto(result=z_aii_model(am))
|
||||
|
||||
|
||||
@aii_router.get("/provider/")
|
||||
@aii_router.get("/provider/", name="获取提供商列表")
|
||||
async def get_all_provider(session: Annotated[Session, Depends(get_session)]) -> ReturnDto:
|
||||
"""
|
||||
获取 AI 模型提供商列表。
|
||||
此接口无需用户登录即可访问。
|
||||
|
||||
Returns:
|
||||
被 ReturnDto 包裹的 AiiProvider 列表
|
||||
"""
|
||||
aii_providers = session.exec(select(AiiProvider)).all()
|
||||
return ReturnDto(result=[z_aii_provider(ap) for ap in aii_providers])
|
||||
|
||||
|
||||
@aii_router.post("/provider/")
|
||||
@aii_router.post("/provider/", name="添加提供商")
|
||||
async def add_provider(
|
||||
provider: AiiProviderPublic,
|
||||
user: Annotated[ModelUser, Depends(verify_token)],
|
||||
session: Annotated[Session, Depends(get_session)],
|
||||
) -> ReturnDto:
|
||||
"""
|
||||
添加新的 AI 模型提供商。
|
||||
此接口需要管理员才能访问。
|
||||
添加提供商时不会进行可用性检查,因此 WebUI 在前端实现了检查按钮。此端点不会负责检查。
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 表示用户未登录或非管理员。
|
||||
|
||||
Returns:
|
||||
被 ReturnDto 包裹的、添加的 AiiProvider
|
||||
"""
|
||||
if not user.is_admin:
|
||||
raise HTTPException(status_code=401, detail="用户无权限管理模型。") from None
|
||||
ap = AiiProvider(name=provider.name, base_url=provider.base_url, api_key=provider.api_key)
|
||||
@@ -74,10 +111,20 @@ async def add_provider(
|
||||
return ReturnDto(result=z_aii_provider(ap))
|
||||
|
||||
|
||||
@aii_router.get("/provider/{id_}/remote/models/")
|
||||
@aii_router.get("/provider/{id_}/remote/models/", name="获取提供商远端模型")
|
||||
async def get_provider_remote_models(
|
||||
id_: int, user: Annotated[ModelUser, Depends(verify_token)], session: Annotated[Session, Depends(get_session)]
|
||||
) -> ReturnDto:
|
||||
"""
|
||||
查看指定模型提供商提供的远端模型列表。并非添加到 NyaHome 的模型列表。
|
||||
此接口需要管理员才能访问。
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 表示用户未登录或非管理员。
|
||||
|
||||
Returns:
|
||||
被 ReturnDto 包裹的、模型名称字符串列表
|
||||
"""
|
||||
if not user.is_admin:
|
||||
raise HTTPException(status_code=401, detail="用户无权限管理模型。") from None
|
||||
try:
|
||||
@@ -89,16 +136,12 @@ async def get_provider_remote_models(
|
||||
return ReturnDto(result=[m["id"] for m in models])
|
||||
|
||||
|
||||
@aii_router.get("/provider/{id_}/remote/model/{model_name}/")
|
||||
@aii_router.get("/provider/{id_}/remote/model/{model_name}/", name="检查指定远端模型可用性")
|
||||
async def check_remote_provider_model(
|
||||
id_: int, model_name: str, session: Annotated[Session, Depends(get_session)]
|
||||
) -> ReturnDto:
|
||||
"""
|
||||
检测指定提供商的指定名称模型是否可用。
|
||||
Args:
|
||||
id_: 模型提供商 ID。
|
||||
model_name: 模型名称。
|
||||
session: 数据库连接对象。
|
||||
检测指定提供商的指定名称远端模型是否可用。
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 表明提供商 ID 未找到。
|
||||
@@ -113,8 +156,14 @@ async def check_remote_provider_model(
|
||||
return ReturnDto(result=await s_check_remote_model(model_name, ap.base_url, ap.api_key))
|
||||
|
||||
|
||||
@aii_router.post("/remote/provider/check/")
|
||||
@aii_router.post("/remote/provider/check/", name="检查指定提供商可用性")
|
||||
async def check_remote_provider(provider: AiiProviderPublic) -> ReturnDto:
|
||||
"""
|
||||
检查指定提供商是否可用。会返回提供商提供的模型数量作为测试。
|
||||
|
||||
Returns:
|
||||
ReturnDto,其中 success 字段为布尔值,表明可用状态;如果为真,result 字段是整型模型数量。
|
||||
"""
|
||||
try:
|
||||
count = len(await s_list_remote_provider_models(provider.base_url, provider.api_key))
|
||||
return ReturnDto(result=count)
|
||||
|
||||
@@ -31,7 +31,7 @@ from .response_model import ReturnDto
|
||||
chatroom_router = APIRouter(tags=["Chatroom"], prefix="/chatroom")
|
||||
|
||||
|
||||
@chatroom_router.get("/{id_}/")
|
||||
@chatroom_router.get("/{id_}/", name="获取指定聊天室")
|
||||
async def get_chatroom(
|
||||
id_: int, user: Annotated[ModelUser, Depends(verify_token)], session: Annotated[Session, Depends(get_session)]
|
||||
) -> ReturnDto:
|
||||
@@ -54,7 +54,7 @@ async def get_chatroom(
|
||||
return ReturnDto(result=cr.model_dump())
|
||||
|
||||
|
||||
@chatroom_router.get("/")
|
||||
@chatroom_router.get("/", name="获取聊天室列表")
|
||||
async def get_all_chatroom(
|
||||
user: Annotated[ModelUser, Depends(verify_token)], session: Annotated[Session, Depends(get_session)]
|
||||
) -> ReturnDto:
|
||||
@@ -68,7 +68,7 @@ async def get_all_chatroom(
|
||||
return ReturnDto(result=[cr.model_dump(exclude={"content", "script"}) for cr in crs])
|
||||
|
||||
|
||||
@chatroom_router.post("/")
|
||||
@chatroom_router.post("/", name="创建聊天室")
|
||||
async def create_chatroom(
|
||||
chatroom: ChatroomPublic,
|
||||
user: Annotated[ModelUser, Depends(verify_token)],
|
||||
@@ -96,7 +96,7 @@ async def create_chatroom(
|
||||
return ReturnDto(result=cr.model_dump())
|
||||
|
||||
|
||||
@chatroom_router.post("/{id_}/")
|
||||
@chatroom_router.post("/{id_}/", name="修改指定聊天室")
|
||||
async def edit_chatroom(
|
||||
id_: int,
|
||||
chatroom: ChatroomPublic,
|
||||
@@ -131,13 +131,14 @@ async def edit_chatroom(
|
||||
cr.feature_image = chatroom.feature_image
|
||||
cr.script_template_id = chatroom.script_template_id
|
||||
cr.script_template_version = chatroom.script_template_version
|
||||
cr.default_model_id = chatroom.default_model_id
|
||||
session.add(cr)
|
||||
session.commit()
|
||||
session.refresh(cr)
|
||||
return ReturnDto(result=cr.model_dump())
|
||||
|
||||
|
||||
@chatroom_router.post("/{id_}/script/")
|
||||
@chatroom_router.post("/{id_}/script/", name="修改聊天室脚本")
|
||||
async def update_chatroom_script(
|
||||
id_: int,
|
||||
script: ChatScript,
|
||||
@@ -172,7 +173,7 @@ async def update_chatroom_script(
|
||||
return ReturnDto(result=script.model_dump())
|
||||
|
||||
|
||||
@chatroom_router.post("/{id_}/chat/")
|
||||
@chatroom_router.post("/{id_}/chat/", name="聊天室发起模型创作")
|
||||
async def post_chatroom_chat(
|
||||
id_: int,
|
||||
chat: ChatroomChat,
|
||||
@@ -181,6 +182,7 @@ async def post_chatroom_chat(
|
||||
) -> StreamingResponse:
|
||||
"""
|
||||
在聊天室中发送新的用户消息,流式返回 AI 调用结果。
|
||||
即:调用模型发起创作。
|
||||
|
||||
Args:
|
||||
id_: (路径参数)聊天室 ID
|
||||
@@ -203,7 +205,7 @@ async def post_chatroom_chat(
|
||||
raise e
|
||||
|
||||
|
||||
@chatroom_router.post("/{id_}/chat/accept/")
|
||||
@chatroom_router.post("/{id_}/chat/accept/", name="聊天室保存模型创作")
|
||||
async def accept_chatroom_chat(
|
||||
id_: int,
|
||||
accept: ChatroomChatAccept,
|
||||
@@ -212,6 +214,7 @@ async def accept_chatroom_chat(
|
||||
) -> ReturnDto:
|
||||
"""
|
||||
此端点不负责调用 AI 生成输出,而是用于保存一对用户消息和 AI 输出到聊天室 content 的最后。
|
||||
需要提供用户消息、AI 消息和创作模式。
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 表明未找到聊天室。
|
||||
@@ -232,7 +235,7 @@ async def accept_chatroom_chat(
|
||||
return ReturnDto(result=cr.model_dump())
|
||||
|
||||
|
||||
@chatroom_router.post("/{id_}/chat/edit/")
|
||||
@chatroom_router.post("/{id_}/chat/edit/", name="聊天室编辑消息")
|
||||
async def edit_chatroom_chat(
|
||||
id_: int,
|
||||
edit: ChatroomChatEdit,
|
||||
@@ -241,6 +244,7 @@ async def edit_chatroom_chat(
|
||||
) -> ReturnDto:
|
||||
"""
|
||||
此端点不负责调用 AI 生成输出,而是用于修改一条已经保存在聊天记录中的消息。
|
||||
需要提供消息类型(用户/AI)、旧消息和新消息,以便进行替换。
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 表明未找到聊天室,400 表明聊天记录匹配失败,未更新。
|
||||
@@ -264,7 +268,7 @@ async def edit_chatroom_chat(
|
||||
raise HTTPException(status_code=400, detail=str(e)) from e
|
||||
|
||||
|
||||
@chatroom_router.post("/{id_}/chat/delete/")
|
||||
@chatroom_router.post("/{id_}/chat/delete/", name="聊天室删除消息")
|
||||
async def delete_chatroom_chat(
|
||||
id_: int,
|
||||
delete: ChatroomChatDelete,
|
||||
@@ -273,6 +277,7 @@ async def delete_chatroom_chat(
|
||||
) -> ReturnDto:
|
||||
"""
|
||||
此端点不负责调用 AI 生成输出,而是用于删除一条已经保存在聊天记录中的消息。关联的 user 或 aii 消息会一并删除。
|
||||
需要提供消息和消息类型(用户/AI)。用户消息和 AI 消息是一对一成对的,所以总是会删除关联的一对(两条)消息。
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 表明未找到聊天室,400 表明聊天记录匹配失败,未更新。
|
||||
|
||||
@@ -13,11 +13,17 @@ from .auth import verify_token
|
||||
file_router = APIRouter(tags=["File"], prefix="/file")
|
||||
|
||||
|
||||
@file_router.get("/")
|
||||
@file_router.get("/", name="获取文件列表")
|
||||
async def get_files(
|
||||
user: Annotated[ModelUser, Depends(verify_token)],
|
||||
session: Annotated[Session, Depends(get_session)],
|
||||
) -> Sequence[ModelUploadFile]:
|
||||
"""
|
||||
获取用户上传的文件列表。
|
||||
|
||||
Returns:
|
||||
ModelUploadFile 列表。
|
||||
"""
|
||||
files: Sequence[ModelUploadFile] = session.exec(
|
||||
select(ModelUploadFile).where(ModelUploadFile.uploader_id == user.id)
|
||||
).all()
|
||||
@@ -25,12 +31,29 @@ async def get_files(
|
||||
return files
|
||||
|
||||
|
||||
@file_router.post("/upload/")
|
||||
@file_router.post("/upload/", name="上传文件")
|
||||
async def file_upload(
|
||||
file: Annotated[UploadFile, File()],
|
||||
user: Annotated[ModelUser, Depends(verify_token)],
|
||||
session: Annotated[Session, Depends(get_session)],
|
||||
) -> ModelUploadFile:
|
||||
"""
|
||||
仅允许单文件上传。
|
||||
文件存储在 `.nyahome/contents` 目录下,由 uuid4 重命名,保留原拓展名。
|
||||
允许上传的文件拓展名由 NyaHome 设置 `allow_upload_file_extensions` 约束。
|
||||
对于不允许上传的文件类型,将抛出 400 错误。
|
||||
|
||||
Args:
|
||||
file: 文件对象
|
||||
user: 经验证的用户
|
||||
session: 数据库连接对象
|
||||
|
||||
Raises:
|
||||
HTTPException: 400 表示上传的文件类型不允许。文件类型仅由拓展名判断,不检查 MIME。
|
||||
|
||||
Returns:
|
||||
ModelUploadFile
|
||||
"""
|
||||
try:
|
||||
safe_name = s_get_safe_filename(file.filename) # type: ignore[arg-type]
|
||||
dest_path = UPLOAD_DIR / safe_name
|
||||
|
||||
@@ -5,16 +5,10 @@ from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
from nyahome.config import config_manager
|
||||
from nyahome.core.otp_store import email_otp_memory_store
|
||||
from nyahome.core.send_email import email_sender_queue
|
||||
from nyahome.core.task import init_admin_user
|
||||
from nyahome.database import create_db
|
||||
from nyahome.router import admin_router, aii_router, chatroom_router, file_router, webui_router
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -22,8 +16,15 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app_: FastAPI) -> AsyncGenerator[None, Any]:
|
||||
load_dotenv(Path.cwd() / ".nyahome" / ".env")
|
||||
# 在生命周期函数内先加载环境变量,再局部导入 nyahome 核心模块
|
||||
logger.info("🚀 服务启动中...")
|
||||
|
||||
from nyahome.config import config_manager
|
||||
from nyahome.core.otp_store import email_otp_memory_store
|
||||
from nyahome.core.send_email import email_sender_queue
|
||||
from nyahome.core.task import init_admin_user
|
||||
from nyahome.database import create_db
|
||||
|
||||
create_db()
|
||||
await asyncio.gather(init_admin_user(), config_manager.async_load_config())
|
||||
email_sender_queue.start()
|
||||
|
||||
Reference in New Issue
Block a user