feat(cli,database): 更新数据库引擎创建模式
This commit is contained in:
@@ -7,10 +7,3 @@ console = Console()
|
|||||||
DATA_DIR = Path.cwd() / ".nyahome"
|
DATA_DIR = Path.cwd() / ".nyahome"
|
||||||
ENV_PATH = DATA_DIR / ".env"
|
ENV_PATH = DATA_DIR / ".env"
|
||||||
LOGGING_YAML = DATA_DIR / "logging.yaml"
|
LOGGING_YAML = DATA_DIR / "logging.yaml"
|
||||||
|
|
||||||
db_driver_available = {
|
|
||||||
"sqlite": ["sqlite3"],
|
|
||||||
"mysql": ["pymysql"],
|
|
||||||
"postgresql": ["psycopg"],
|
|
||||||
}
|
|
||||||
db_type_allowlist = ["sqlite", "mysql", "postgresql"]
|
|
||||||
|
|||||||
@@ -81,6 +81,14 @@ def add_model(
|
|||||||
help="该模型所属于的模型提供商 ID",
|
help="该模型所属于的模型提供商 ID",
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
|
reasonable: Annotated[
|
||||||
|
bool,
|
||||||
|
typer.Option(
|
||||||
|
"--reasonable",
|
||||||
|
"-r",
|
||||||
|
help="支持思考",
|
||||||
|
),
|
||||||
|
] = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
添加 AI 模型。在此之前需要先添加该模型的提供商。
|
添加 AI 模型。在此之前需要先添加该模型的提供商。
|
||||||
@@ -92,7 +100,12 @@ def add_model(
|
|||||||
from nyahome.database import AiiModel, engine
|
from nyahome.database import AiiModel, engine
|
||||||
|
|
||||||
with Session(engine) as session:
|
with Session(engine) as session:
|
||||||
am = AiiModel(model_name=model_name, max_context_length=max_context_length, aii_provider_id=provider_id)
|
am = AiiModel(
|
||||||
|
model_name=model_name,
|
||||||
|
max_context_length=max_context_length,
|
||||||
|
aii_provider_id=provider_id,
|
||||||
|
reasonable=reasonable,
|
||||||
|
)
|
||||||
session.add(am)
|
session.add(am)
|
||||||
session.commit()
|
session.commit()
|
||||||
session.refresh(am)
|
session.refresh(am)
|
||||||
|
|||||||
@@ -4,7 +4,8 @@ from importlib.util import find_spec
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Mapping
|
from typing import Mapping
|
||||||
|
|
||||||
from nyahome.cli.cli import console, db_driver_available, db_type_allowlist
|
from nyahome.cli.cli import console
|
||||||
|
from nyahome.database.engine import db_driver_available, db_type_allowlist
|
||||||
|
|
||||||
|
|
||||||
class CliWarning:
|
class CliWarning:
|
||||||
@@ -73,7 +74,7 @@ def check_database_type(environ: Mapping[str, str | None]) -> None:
|
|||||||
if not db_host:
|
if not db_host:
|
||||||
cw.warning("NYAHOME_DB_HOST 未设置,将使用 [cyan]localhost[/cyan] 作为默认值。")
|
cw.warning("NYAHOME_DB_HOST 未设置,将使用 [cyan]localhost[/cyan] 作为默认值。")
|
||||||
if not db_port:
|
if not db_port:
|
||||||
cw.warning("NYAHOME_DB_PORT 未设置,将使用 [cyan]3006[/cyan] 作为默认值。")
|
cw.warning("NYAHOME_DB_PORT 未设置,将使用 [cyan]3306[/cyan] 作为默认值。")
|
||||||
cw.info("自检未检查数据库状态是否可用。")
|
cw.info("自检未检查数据库状态是否可用。")
|
||||||
else:
|
else:
|
||||||
cw.info("使用 sqlite 数据库,跳过数据库凭证检查。")
|
cw.info("使用 sqlite 数据库,跳过数据库凭证检查。")
|
||||||
|
|||||||
@@ -42,8 +42,11 @@ def set_env(
|
|||||||
|
|
||||||
保存在 .nyahome 内的 .env 文件。
|
保存在 .nyahome 内的 .env 文件。
|
||||||
"""
|
"""
|
||||||
set_key(ENV_PATH, f"NYAHOME_{key.upper()}", value)
|
key = key.upper()
|
||||||
console.print(f"[cyan]已设置环境变量 NYAHOME_{key}。[/cyan]")
|
if not key.startswith("NYAHOME_"):
|
||||||
|
key = f"NYAHOME_{key}"
|
||||||
|
set_key(ENV_PATH, key, value)
|
||||||
|
console.print(f"[cyan]已设置环境变量 {key}。[/cyan]")
|
||||||
|
|
||||||
|
|
||||||
@env_app.command(name="unset")
|
@env_app.command(name="unset")
|
||||||
|
|||||||
@@ -1,5 +1,37 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
|
from sqlalchemy import Engine
|
||||||
from sqlmodel import create_engine
|
from sqlmodel import create_engine
|
||||||
|
|
||||||
engine = create_engine(os.environ["NYAHOME_SQL_URL"])
|
db_driver_available = {
|
||||||
|
"sqlite": ["sqlite3"],
|
||||||
|
"mysql": ["pymysql"],
|
||||||
|
"postgresql": ["psycopg"],
|
||||||
|
}
|
||||||
|
db_type_allowlist = ["sqlite", "mysql", "postgresql"]
|
||||||
|
|
||||||
|
|
||||||
|
def build_engine() -> Engine:
|
||||||
|
from logging import getLogger
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
db_type = os.environ.get("NYAHOME_DB_TYPE", "sqlite")
|
||||||
|
db_driver = os.environ.get("NYAHOME_DB_DRIVER")
|
||||||
|
if db_type not in db_type_allowlist:
|
||||||
|
logger.warning(f"数据库类型 {db_type} 不受 NyaHome 官方支持,建议改用受支持的数据库:{db_type_allowlist}")
|
||||||
|
else:
|
||||||
|
if db_driver is None:
|
||||||
|
db_driver = db_driver_available[db_type][0]
|
||||||
|
|
||||||
|
if db_type == "sqlite":
|
||||||
|
return create_engine(f"sqlite+{db_driver}:///.nyahome/nyahome.db")
|
||||||
|
db_name = os.environ.get("NYAHOME_DB_NAME", "nyahome")
|
||||||
|
db_user = os.environ.get("NYAHOME_DB_USER", "nyahome")
|
||||||
|
db_password = os.environ.get("NYAHOME_DB_PASSWORD", "nyahome")
|
||||||
|
db_host = os.environ.get("NYAHOME_DB_HOST", "localhost")
|
||||||
|
db_port = os.environ.get("NYAHOME_DB_PORT", "3306")
|
||||||
|
return create_engine(f"{db_type}+{db_driver}://{db_user}:{db_password}@{db_host}:{db_port}/{db_name}")
|
||||||
|
|
||||||
|
|
||||||
|
engine = build_engine()
|
||||||
|
|||||||
@@ -89,8 +89,9 @@ def init() -> None:
|
|||||||
from dotenv import set_key
|
from dotenv import set_key
|
||||||
from rich.prompt import Confirm, IntPrompt, Prompt
|
from rich.prompt import Confirm, IntPrompt, Prompt
|
||||||
|
|
||||||
from nyahome.cli.cli import DATA_DIR, ENV_PATH, LOGGING_YAML, db_driver_available, db_type_allowlist
|
from nyahome.cli.cli import DATA_DIR, ENV_PATH, LOGGING_YAML
|
||||||
from nyahome.cli.cli_check import LOGGING_YAML_CONTENT
|
from nyahome.cli.cli_check import LOGGING_YAML_CONTENT
|
||||||
|
from nyahome.database.engine import db_driver_available, db_type_allowlist
|
||||||
|
|
||||||
console.print("\n准备初始化 NyaHome。")
|
console.print("\n准备初始化 NyaHome。")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user