diff --git a/src/nyahome/cli/cli.py b/src/nyahome/cli/cli.py index 5f8e6bd..fa7285f 100644 --- a/src/nyahome/cli/cli.py +++ b/src/nyahome/cli/cli.py @@ -1,3 +1,7 @@ +from pathlib import Path + from rich.console import Console -console = Console() \ No newline at end of file +console = Console() + +ENV_PATH = Path.cwd() / ".nyahome" / ".env" diff --git a/src/nyahome/cli/cli_aii.py b/src/nyahome/cli/cli_aii.py new file mode 100644 index 0000000..99025ba --- /dev/null +++ b/src/nyahome/cli/cli_aii.py @@ -0,0 +1,100 @@ +from typing import Annotated, Sequence + +import typer +from dotenv import load_dotenv +from rich.table import Table + +from .cli import ENV_PATH, console + +aii_app = typer.Typer() + + +@aii_app.command(name="list") +def list_all_provider() -> None: + """ + 列出已设置的所有提供商和模型。 + """ + load_dotenv(ENV_PATH) + + from sqlmodel import Session, select + + from nyahome.database import AiiProvider, engine + + table = Table(title="AI 模型提供商与已录入模型") + table.add_column("ID", style="cyan", no_wrap=True) + table.add_column("提供商名称", style="white", no_wrap=True) + table.add_column("Base URL", style="white", no_wrap=True) + table.add_column("录入的模型", style="bright_black") + + with Session(engine) as session: + aps: Sequence[AiiProvider] = session.exec(select(AiiProvider)).all() + for ap in aps: + table.add_row(str(ap.id), ap.name, ap.base_url, str(ap.aii_models)) + + console.print(table) + + +@aii_app.command() +def add_provider( + name: Annotated[str, typer.Argument(help="提供商名称")], + base_url: Annotated[str, typer.Argument(help="提供商 Base URL(OpenAI 兼容端点)")], + api_key: Annotated[ + str, + typer.Option( + "--api-key", + "-k", + help="提供商 API Key", + prompt=True, + hide_input=True, + ), + ], +) -> None: + """ + 添加 AI 提供商。需要提供商名称、Base URL 和 API Key。 + """ + load_dotenv(ENV_PATH) + + from sqlmodel import Session + + from nyahome.database import AiiProvider, engine + + console.print(f"[cyan]正在添加模型提供商 [{name}]({base_url})[/cyan]") + + with Session(engine) as session: + ap = AiiProvider(name=name, base_url=base_url, api_key=api_key) + session.add(ap) + session.commit() + session.refresh(ap) + + console.print(f"[cyan]添加完成 [{ap.id}][{ap.name}]({ap.base_url})[/cyan]") + + +@aii_app.command() +def add_model( + model_name: Annotated[str, typer.Argument(help="模型名称(需准确填写)")], + max_context_length: Annotated[int, typer.Argument(help="最大上下文长度(单位为 k)")], + provider_id: Annotated[ + int, + typer.Option( + "--provider-id", + "-p", + help="该模型所属于的模型提供商 ID", + ), + ], +) -> None: + """ + 添加 AI 模型。在此之前需要先添加该模型的提供商。 + """ + load_dotenv(ENV_PATH) + + from sqlmodel import Session + + from nyahome.database import AiiModel, engine + + with Session(engine) as session: + am = AiiModel(model_name=model_name, max_context_length=max_context_length, aii_provider_id=provider_id) + session.add(am) + session.commit() + session.refresh(am) + + console.print(f"[cyan]已添加模型 [{am.id}][{am.model_name}]({am.aii_provider_id})[/cyan]") diff --git a/src/nyahome/cli/cli_env.py b/src/nyahome/cli/cli_env.py index 43b4c14..7659dcd 100644 --- a/src/nyahome/cli/cli_env.py +++ b/src/nyahome/cli/cli_env.py @@ -1,19 +1,15 @@ import os -from pathlib import Path from typing import Annotated import typer from dotenv import load_dotenv, set_key, unset_key from rich.table import Table -from .cli import console +from .cli import ENV_PATH, console env_app = typer.Typer() -ENV_PATH = Path.cwd() / ".nyahome" / ".env" - - @env_app.command(name="list") def list_all_envs() -> None: """ diff --git a/src/nyahome/manage.py b/src/nyahome/manage.py index 1e0ec3e..29c4300 100644 --- a/src/nyahome/manage.py +++ b/src/nyahome/manage.py @@ -9,6 +9,7 @@ import typer from nyahome import __version__ from nyahome.cli.cli import console +from nyahome.cli.cli_aii import aii_app from nyahome.cli.cli_env import ENV_PATH, env_app app = typer.Typer( @@ -75,6 +76,7 @@ def openapi( app.add_typer(env_app, name="env", no_args_is_help=True, help="设置 NyaHome 应用的环境变量。") +app.add_typer(aii_app, name="aii", no_args_is_help=True, help="添加、设置、修改 AI 提供商和模型。") if __name__ == "__main__":