Files
NyaHome/src/nyahome/cli/cli_aii.py
T

101 lines
2.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from typing import Annotated, Sequence
import typer
from dotenv import load_dotenv
from rich.table import Table
from .cli import ENV_PATH, console
# Typer sub-application that groups the AI provider/model management commands
# registered below via @aii_app.command(); presumably mounted onto the main CLI
# elsewhere (not visible in this file) — confirm against the CLI entry point.
aii_app = typer.Typer()
# `list` command: print a Rich table of every stored AiiProvider, ordered by id,
# with the models registered under each provider.
#
# The model column shows "name(<context>k)" entries joined by ", " rather than
# the raw repr of the ORM relationship list.
# The docstring below is Typer's `--help` text for this command and stays in the
# CLI's user-facing language.
@aii_app.command(name="list")
def list_all_provider() -> None:
    """
    列出已设置的所有提供商和模型。
    """
    load_dotenv(ENV_PATH)
    from rich.markup import escape
    from sqlmodel import Session, col, select

    from nyahome.database import AiiProvider, engine

    table = Table(title="AI 模型提供商与已录入模型")
    table.add_column("ID", style="cyan", no_wrap=True)
    table.add_column("提供商名称", style="white", no_wrap=True)
    table.add_column("Base URL", style="white", no_wrap=True)
    table.add_column("录入的模型", style="bright_black")

    with Session(engine) as session:
        # Explicit ORDER BY so the listing is stable between runs.
        providers: Sequence[AiiProvider] = session.exec(
            select(AiiProvider).order_by(col(AiiProvider.id))
        ).all()
        for provider in providers:
            # Read the relationship while the session is still open. Assumes
            # aii_models yields AiiModel rows (fields as used by add_model).
            models = ", ".join(
                f"{model.model_name}({model.max_context_length}k)"
                for model in provider.aii_models
            )
            # Table cells are parsed as Rich markup; escape the free-form model
            # names so a "[...]" inside one is shown literally.
            table.add_row(str(provider.id), provider.name, provider.base_url, escape(models))

    console.print(table)
# Register a new AI provider (OpenAI-compatible endpoint).
#
# NAME and BASE_URL are positional arguments. The API key comes from
# --api-key/-k, or is prompted for with hidden input when omitted.
# Side effect: inserts one AiiProvider row and commits it, then echoes the new
# id, name and base URL.
# The docstring below is Typer's `--help` text for this command and stays in the
# CLI's user-facing language.
@aii_app.command()
def add_provider(
    name: Annotated[str, typer.Argument(help="提供商名称")],
    base_url: Annotated[str, typer.Argument(help="提供商 Base URL(OpenAI 兼容端点)")],
    api_key: Annotated[
        str,
        typer.Option(
            "--api-key",
            "-k",
            help="提供商 API Key",
            prompt=True,
            hide_input=True,
        ),
    ],
) -> None:
    """
    添加 AI 提供商。需要提供商名称、Base URL 和 API Key。
    """
    load_dotenv(ENV_PATH)
    from sqlmodel import Session

    from nyahome.database import AiiProvider, engine

    # markup=False: user-supplied values are printed verbatim. With markup on, a
    # lowercase name like "openai" inside "[...]" is parsed as a Rich style tag
    # and silently disappears from the output.
    console.print(f"正在添加模型提供商 [{name}]({base_url})", style="cyan", markup=False)

    with Session(engine) as session:
        provider = AiiProvider(name=name, base_url=base_url, api_key=api_key)
        session.add(provider)
        session.commit()
        # Reload so the database-assigned id is available for the message.
        session.refresh(provider)
        console.print(
            f"添加完成 [{provider.id}][{provider.name}]({provider.base_url})",
            style="cyan",
            markup=False,
        )
# Register a new AI model under an existing provider.
#
# MODEL_NAME and MAX_CONTEXT_LENGTH (in units of k, per the help text) are
# positional arguments. --provider-id/-p is a required option naming the owning
# AiiProvider.
# Exits with code 1, writing nothing, when that provider does not exist.
# Otherwise inserts one AiiModel row, commits it, and echoes its id, name and
# provider id.
# The docstring below is Typer's `--help` text for this command and stays in the
# CLI's user-facing language.
@aii_app.command()
def add_model(
    model_name: Annotated[str, typer.Argument(help="模型名称(需准确填写)")],
    max_context_length: Annotated[int, typer.Argument(help="最大上下文长度(单位为 k)")],
    provider_id: Annotated[
        int,
        typer.Option(
            "--provider-id",
            "-p",
            help="该模型所属于的模型提供商 ID",
        ),
    ],
) -> None:
    """
    添加 AI 模型。在此之前需要先添加该模型的提供商。
    """
    load_dotenv(ENV_PATH)
    from sqlmodel import Session

    from nyahome.database import AiiModel, AiiProvider, engine

    with Session(engine) as session:
        # The command requires the provider to exist first. Check explicitly so
        # a mistyped --provider-id gets a readable error. Without this check the
        # outcome depends on whether the database enforces the foreign key: a
        # raw integrity error, or a silently dangling reference.
        if session.get(AiiProvider, provider_id) is None:
            console.print(
                f"未找到 ID 为 {provider_id} 的模型提供商，请先添加该提供商",
                style="red",
                markup=False,
            )
            raise typer.Exit(code=1)

        model = AiiModel(
            model_name=model_name,
            max_context_length=max_context_length,
            aii_provider_id=provider_id,
        )
        session.add(model)
        session.commit()
        # Reload so the database-assigned id is available for the message.
        session.refresh(model)
        # markup=False: a model name such as "gpt-4o" inside "[...]" would
        # otherwise be parsed as a Rich style tag and vanish from the output.
        console.print(
            f"已添加模型 [{model.id}][{model.model_name}]({model.aii_provider_id})",
            style="cyan",
            markup=False,
        )