101 lines
2.9 KiB
Python
101 lines
2.9 KiB
Python
from typing import Annotated, Sequence
|
||
|
||
import typer
|
||
from dotenv import load_dotenv
|
||
from rich.table import Table
|
||
|
||
from .cli import ENV_PATH, console
|
||
|
||
# Typer sub-application; the AI provider/model commands in this module are
# attached to it through the ``@aii_app.command`` decorator.
aii_app: typer.Typer = typer.Typer()
|
||
|
||
|
||
@aii_app.command(name="list")
def list_all_provider() -> None:
    """
    List every configured AI provider together with the models registered under it.

    Environment variables are loaded from ``ENV_PATH`` before the database modules
    are imported. One table row is rendered per ``AiiProvider``; the models column
    lists each related model as ``<model_name> (<max_context_length>k)``.
    """
    load_dotenv(ENV_PATH)

    # Imported only after load_dotenv — presumably so the database engine sees the
    # freshly loaded environment. NOTE(review): confirm this ordering requirement.
    from rich.markup import escape
    from sqlmodel import Session, select

    from nyahome.database import AiiProvider, engine

    table = Table(title="AI 模型提供商与已录入模型")
    table.add_column("ID", style="cyan", no_wrap=True)
    table.add_column("提供商名称", style="white", no_wrap=True)
    table.add_column("Base URL", style="white", no_wrap=True)
    table.add_column("录入的模型", style="bright_black")

    with Session(engine) as session:
        aps: Sequence[AiiProvider] = session.exec(select(AiiProvider)).all()
        for ap in aps:
            # Build the cell text while the session is still open: ``aii_models`` is
            # a related collection that may be lazy-loaded on first access.
            # Plain str cells are parsed as Rich markup, so user-supplied text is
            # escaped; otherwise e.g. "[deepseek]" would be swallowed as a style tag.
            models = ", ".join(
                escape(f"{m.model_name} ({m.max_context_length}k)") for m in ap.aii_models
            )
            table.add_row(str(ap.id), escape(ap.name), escape(ap.base_url), models)

    console.print(table)
|
||
|
||
|
||
@aii_app.command()
def add_provider(
    name: Annotated[str, typer.Argument(help="提供商名称")],
    base_url: Annotated[str, typer.Argument(help="提供商 Base URL(OpenAI 兼容端点)")],
    api_key: Annotated[
        str,
        typer.Option(
            "--api-key",
            "-k",
            help="提供商 API Key",
            prompt=True,
            hide_input=True,
        ),
    ],
) -> None:
    """
    Add an AI provider. Requires the provider name, its base URL and an API key.

    If ``--api-key``/``-k`` is not given, the key is prompted for with hidden input.
    The new row is committed and refreshed so its assigned ``id`` can be reported.
    """
    load_dotenv(ENV_PATH)

    from rich.markup import escape
    from sqlmodel import Session

    from nyahome.database import AiiProvider, engine

    # User-supplied values must be escaped before being embedded in Rich markup:
    # unescaped, a lowercase name like "deepseek" forms the tag "[deepseek]" and
    # silently disappears, and a name starting with "/" raises a MarkupError.
    label = escape(f"[{name}]({base_url})")
    console.print(f"[cyan]正在添加模型提供商 {label}[/cyan]")

    with Session(engine) as session:
        ap = AiiProvider(name=name, base_url=base_url, api_key=api_key)
        session.add(ap)
        session.commit()
        # Reload so the database-assigned ``id`` is available for the message below.
        session.refresh(ap)

    label = escape(f"[{ap.id}][{ap.name}]({ap.base_url})")
    console.print(f"[cyan]添加完成 {label}[/cyan]")
|
||
|
||
|
||
@aii_app.command()
def add_model(
    model_name: Annotated[str, typer.Argument(help="模型名称(需准确填写)")],
    max_context_length: Annotated[int, typer.Argument(help="最大上下文长度(单位为 k)")],
    provider_id: Annotated[
        int,
        typer.Option(
            "--provider-id",
            "-p",
            help="该模型所属于的模型提供商 ID",
        ),
    ],
) -> None:
    """
    Add an AI model. The model's provider must have been added beforehand.

    ``max_context_length`` is expressed in thousands of tokens ("k").
    If no provider with ``provider_id`` exists, an error is printed and the
    command exits with code 1 without writing anything.
    """
    load_dotenv(ENV_PATH)

    from rich.markup import escape
    from sqlmodel import Session

    from nyahome.database import AiiModel, AiiProvider, engine

    with Session(engine) as session:
        # Enforce the documented prerequisite explicitly rather than relying on the
        # foreign key, which may not be enforced by the backend.
        # NOTE(review): depends on engine setup (e.g. SQLite needs PRAGMA foreign_keys).
        if session.get(AiiProvider, provider_id) is None:
            console.print(f"[red]未找到 ID 为 {provider_id} 的模型提供商,请先添加该提供商[/red]")
            raise typer.Exit(code=1)

        am = AiiModel(
            model_name=model_name,
            max_context_length=max_context_length,
            aii_provider_id=provider_id,
        )
        session.add(am)
        session.commit()
        # Reload so the database-assigned ``id`` is available for the message below.
        session.refresh(am)

    # Escape before embedding in Rich markup: model names such as "gpt-4o" would
    # otherwise be parsed as a "[gpt-4o]" style tag and vanish from the output.
    label = escape(f"[{am.id}][{am.model_name}]({am.aii_provider_id})")
    console.print(f"[cyan]已添加模型 {label}[/cyan]")
|