From 03928c6c590ff958575192fe6e35475dcccb13c8 Mon Sep 17 00:00:00 2001 From: MangoFanFanw Date: Thu, 4 Jun 2026 18:51:09 +0800 Subject: [PATCH] =?UTF-8?q?feat(cli,database):=20=E6=9B=B4=E6=96=B0?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E5=BA=93=E5=BC=95=E6=93=8E=E5=88=9B=E5=BB=BA?= =?UTF-8?q?=E6=A8=A1=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/nyahome/cli/cli.py | 7 ------- src/nyahome/cli/cli_aii.py | 15 ++++++++++++++- src/nyahome/cli/cli_check.py | 5 +++-- src/nyahome/cli/cli_env.py | 7 +++++-- src/nyahome/database/engine.py | 34 +++++++++++++++++++++++++++++++++- src/nyahome/manage.py | 3 ++- 6 files changed, 57 insertions(+), 14 deletions(-) diff --git a/src/nyahome/cli/cli.py b/src/nyahome/cli/cli.py index c7de33e..487443a 100644 --- a/src/nyahome/cli/cli.py +++ b/src/nyahome/cli/cli.py @@ -7,10 +7,3 @@ console = Console() DATA_DIR = Path.cwd() / ".nyahome" ENV_PATH = DATA_DIR / ".env" LOGGING_YAML = DATA_DIR / "logging.yaml" - -db_driver_available = { - "sqlite": ["sqlite3"], - "mysql": ["pymysql"], - "postgresql": ["psycopg"], -} -db_type_allowlist = ["sqlite", "mysql", "postgresql"] diff --git a/src/nyahome/cli/cli_aii.py b/src/nyahome/cli/cli_aii.py index 5ff629e..8a3f5e4 100644 --- a/src/nyahome/cli/cli_aii.py +++ b/src/nyahome/cli/cli_aii.py @@ -81,6 +81,14 @@ def add_model( help="该模型所属于的模型提供商 ID", ), ], + reasonable: Annotated[ + bool, + typer.Option( + "--reasonable", + "-r", + help="支持思考", + ), + ] = False, ) -> None: """ 添加 AI 模型。在此之前需要先添加该模型的提供商。 @@ -92,7 +100,12 @@ def add_model( from nyahome.database import AiiModel, engine with Session(engine) as session: - am = AiiModel(model_name=model_name, max_context_length=max_context_length, aii_provider_id=provider_id) + am = AiiModel( + model_name=model_name, + max_context_length=max_context_length, + aii_provider_id=provider_id, + reasonable=reasonable, + ) session.add(am) session.commit() session.refresh(am) diff --git a/src/nyahome/cli/cli_check.py b/src/nyahome/cli/cli_check.py index 56fcf6b..1373c3d 100644 --- a/src/nyahome/cli/cli_check.py +++ b/src/nyahome/cli/cli_check.py @@ -4,7 +4,8 @@ from importlib.util import find_spec from pathlib import Path from typing import Mapping -from nyahome.cli.cli import console, db_driver_available, db_type_allowlist +from nyahome.cli.cli import console +from nyahome.database.engine import db_driver_available, db_type_allowlist class CliWarning: @@ -73,7 +74,7 @@ def check_database_type(environ: Mapping[str, str | None]) -> None: if not db_host: cw.warning("NYAHOME_DB_HOST 未设置,将使用 [cyan]localhost[/cyan] 作为默认值。") if not db_port: - cw.warning("NYAHOME_DB_PORT 未设置,将使用 [cyan]3006[/cyan] 作为默认值。") + cw.warning("NYAHOME_DB_PORT 未设置,将使用 [cyan]3306[/cyan] 作为默认值。") cw.info("自检未检查数据库状态是否可用。") else: cw.info("使用 sqlite 数据库,跳过数据库凭证检查。") diff --git a/src/nyahome/cli/cli_env.py b/src/nyahome/cli/cli_env.py index 7659dcd..c7fc584 100644 --- a/src/nyahome/cli/cli_env.py +++ b/src/nyahome/cli/cli_env.py @@ -42,8 +42,11 @@ def set_env( 保存在 .nyahome 内的 .env 文件。 """ - set_key(ENV_PATH, f"NYAHOME_{key.upper()}", value) - console.print(f"[cyan]已设置环境变量 NYAHOME_{key}。[/cyan]") + key = key.upper() + if not key.startswith("NYAHOME_"): + key = f"NYAHOME_{key}" + set_key(ENV_PATH, key, value) + console.print(f"[cyan]已设置环境变量 {key}。[/cyan]") @env_app.command(name="unset") diff --git a/src/nyahome/database/engine.py b/src/nyahome/database/engine.py index 8201959..8a9346f 100644 --- a/src/nyahome/database/engine.py +++ b/src/nyahome/database/engine.py @@ -1,5 +1,37 @@ import os +from sqlalchemy import Engine from sqlmodel import create_engine -engine = create_engine(os.environ["NYAHOME_SQL_URL"]) +db_driver_available = { + "sqlite": ["sqlite3"], + "mysql": ["pymysql"], + "postgresql": ["psycopg"], +} +db_type_allowlist = ["sqlite", "mysql", "postgresql"] + + +def build_engine() -> Engine: + from logging import getLogger + + logger = getLogger(__name__) + + db_type = os.environ.get("NYAHOME_DB_TYPE", "sqlite") + db_driver = os.environ.get("NYAHOME_DB_DRIVER") + if db_type not in db_type_allowlist: + logger.warning(f"数据库类型 {db_type} 不受 NyaHome 官方支持,建议改用受支持的数据库:{db_type_allowlist}") + else: + if db_driver is None: + db_driver = db_driver_available[db_type][0] + + if db_type == "sqlite": + return create_engine(f"sqlite+{db_driver}:///.nyahome/nyahome.db") + db_name = os.environ.get("NYAHOME_DB_NAME", "nyahome") + db_user = os.environ.get("NYAHOME_DB_USER", "nyahome") + db_password = os.environ.get("NYAHOME_DB_PASSWORD", "nyahome") + db_host = os.environ.get("NYAHOME_DB_HOST", "localhost") + db_port = os.environ.get("NYAHOME_DB_PORT", "3306") + return create_engine(f"{db_type}+{db_driver}://{db_user}:{db_password}@{db_host}:{db_port}/{db_name}") + + +engine = build_engine() diff --git a/src/nyahome/manage.py b/src/nyahome/manage.py index 5b5ef7b..a847bac 100644 --- a/src/nyahome/manage.py +++ b/src/nyahome/manage.py @@ -89,8 +89,9 @@ def init() -> None: from dotenv import set_key from rich.prompt import Confirm, IntPrompt, Prompt - from nyahome.cli.cli import DATA_DIR, ENV_PATH, LOGGING_YAML, db_driver_available, db_type_allowlist + from nyahome.cli.cli import DATA_DIR, ENV_PATH, LOGGING_YAML from nyahome.cli.cli_check import LOGGING_YAML_CONTENT + from nyahome.database.engine import db_driver_available, db_type_allowlist console.print("\n准备初始化 NyaHome。")