Files
skill-template/llm-manager/scripts/db.py
2026-04-04 10:35:02 +08:00

293 lines
8.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
llm-manager local database:
- llm_keys: API key records
- llm_web_accounts: web-mode account link records (the account master data is still managed by account-manager)
"""
import os
import sqlite3
import time
from typing import Optional
from providers import get_data_root, get_user_id
# Slug of this skill; used as the per-user data directory name and the DB file stem.
SKILL_SLUG = "llm-manager"
# SQLite has no dedicated DATETIME type; all timestamps are stored as INTEGER Unix seconds (UTC).
# DDL for the API key table. Per the inline SQL comments: provider is the platform
# slug, label is a user-defined note, api_key is the raw key text, default_model is
# the default model, is_active is 0 (disabled) / 1 (enabled), and last_used_at is
# NULL until the key has been used.
LLM_KEYS_TABLE_SQL = """
CREATE TABLE llm_keys (
id INTEGER PRIMARY KEY AUTOINCREMENT, -- 主键(自增)
provider TEXT NOT NULL, -- 平台 slugdoubao/deepseek/qianwen/kimi/yiyan/yuanbao
label TEXT NOT NULL DEFAULT '', -- 用户自定义备注如「公司Key」
api_key TEXT NOT NULL, -- API Key 原文
default_model TEXT, -- 默认模型doubao 须填 ep-xxx
is_active INTEGER NOT NULL DEFAULT 1, -- 是否启用0 停用 1 启用
last_used_at INTEGER, -- 最近调用时间Unix 秒;从未用过为 NULL
created_at INTEGER NOT NULL,
updated_at INTEGER NOT NULL
);
"""
# DDL for the web-account link table. Each (provider, account_id) pair is unique,
# which is what the upsert in upsert_web_account() conflicts on.
LLM_WEB_ACCOUNTS_TABLE_SQL = """
CREATE TABLE llm_web_accounts (
id INTEGER PRIMARY KEY AUTOINCREMENT,
provider TEXT NOT NULL,
account_id INTEGER NOT NULL,
account_name TEXT NOT NULL DEFAULT '',
login_status INTEGER NOT NULL DEFAULT 1,
created_at INTEGER NOT NULL,
updated_at INTEGER NOT NULL,
UNIQUE(provider, account_id)
);
"""
def _now_unix() -> int:
    """Return the current time as whole Unix seconds."""
    seconds = time.time()
    return int(seconds)
def get_skill_data_dir() -> str:
    """Return this skill's per-user data directory, creating it if missing.

    The directory is ``<data root>/<user id>/<SKILL_SLUG>``; the data root and
    user id are obtained from ``get_data_root()`` and ``get_user_id()``.
    """
    data_root = get_data_root()
    user_id = get_user_id()
    skill_dir = os.path.join(data_root, user_id, SKILL_SLUG)
    os.makedirs(skill_dir, exist_ok=True)
    return skill_dir
def get_db_path() -> str:
    """Return the path of the SQLite file inside the skill data directory."""
    db_filename = SKILL_SLUG + ".db"
    return os.path.join(get_skill_data_dir(), db_filename)
def get_conn():
    """Open and return a new SQLite connection to the skill database file.

    The caller is responsible for closing the returned connection.
    """
    db_file = get_db_path()
    return sqlite3.connect(db_file)
def init_db():
    """Create the ``llm_keys`` and ``llm_web_accounts`` tables if absent.

    Each table is looked up in ``sqlite_master`` first and its DDL script is
    only executed when no such table exists, so calling this repeatedly is
    harmless for an already-initialised database.
    """
    # (existence probe, DDL to run when the probe finds nothing), in creation order.
    schema_steps = (
        (
            "SELECT name FROM sqlite_master WHERE type='table' AND name='llm_keys'",
            LLM_KEYS_TABLE_SQL,
        ),
        (
            "SELECT name FROM sqlite_master WHERE type='table' AND name='llm_web_accounts'",
            LLM_WEB_ACCOUNTS_TABLE_SQL,
        ),
    )
    conn = get_conn()
    try:
        cursor = conn.cursor()
        for probe_sql, ddl_script in schema_steps:
            cursor.execute(probe_sql)
            if cursor.fetchone() is None:
                cursor.executescript(ddl_script)
        conn.commit()
    finally:
        conn.close()
# ---------------------------------------------------------------------------
# CRUD
# ---------------------------------------------------------------------------
def add_key(provider: str, api_key: str, model: Optional[str] = None, label: str = "") -> int:
    """Insert a new API key row and return its row id.

    Args:
        provider: Platform slug the key belongs to.
        api_key: The key text, stored as given.
        model: Optional default model; stored as NULL when omitted.
        label: Optional note; a falsy value is stored as an empty string.

    Returns:
        The ``lastrowid`` reported for the inserted row.

    The row is inserted with ``is_active`` set to 1 and with ``created_at`` /
    ``updated_at`` both set to the current Unix time.
    """
    init_db()
    timestamp = _now_unix()
    row_values = (provider, label or "", api_key, model, timestamp, timestamp)
    conn = get_conn()
    try:
        cursor = conn.execute(
            """
            INSERT INTO llm_keys (provider, label, api_key, default_model, is_active, created_at, updated_at)
            VALUES (?, ?, ?, ?, 1, ?, ?)
            """,
            row_values,
        )
        inserted_id = cursor.lastrowid
        conn.commit()
    finally:
        conn.close()
    return inserted_id
def upsert_web_account(provider: str, account_id: int, account_name: str = "", login_status: int = 1) -> int:
    """Insert or update the link row for ``(provider, account_id)``.

    Args:
        provider: Platform slug.
        account_id: Account identifier; coerced with ``int()``.
        account_name: Display name; a falsy value is stored as an empty string.
        login_status: Status flag; a falsy value is stored as 0, otherwise ``int()`` of it.

    Returns:
        The ``id`` of the row for this pair, or 0 if it cannot be read back.

    On conflict the existing row keeps its ``created_at`` while the name,
    login status and ``updated_at`` are overwritten.
    """
    init_db()
    timestamp = _now_unix()
    conn = get_conn()
    try:
        account_key = int(account_id)
        status_value = int(login_status or 0)
        conn.execute(
            """
            INSERT INTO llm_web_accounts (provider, account_id, account_name, login_status, created_at, updated_at)
            VALUES (?, ?, ?, ?, ?, ?)
            ON CONFLICT(provider, account_id) DO UPDATE SET
            account_name = excluded.account_name,
            login_status = excluded.login_status,
            updated_at = excluded.updated_at
            """,
            (provider, account_key, account_name or "", status_value, timestamp, timestamp),
        )
        conn.commit()
        # Read the id back explicitly: it covers both the insert and the update path.
        found = conn.execute(
            "SELECT id FROM llm_web_accounts WHERE provider = ? AND account_id = ?",
            (provider, account_key),
        ).fetchone()
        if not found:
            return 0
        return int(found[0])
    finally:
        conn.close()
def list_keys(provider: Optional[str] = None, limit: int = 10) -> list:
    """Return stored API keys, newest first, as a list of dicts.

    Args:
        provider: When truthy, only keys of this platform are returned.
        limit: Maximum number of rows; anything that is not a positive int
            falls back to 10.

    Returns:
        Dicts with ``id``, ``provider``, ``label``, ``api_key``,
        ``default_model``, ``is_active``, ``last_used_at`` and ``created_at``.
        ``label`` and ``default_model`` are normalised to ``""`` when empty/NULL.
    """
    init_db()
    if not (isinstance(limit, int) and limit > 0):
        limit = 10

    if provider:
        query = (
            "SELECT id, provider, label, api_key, default_model, is_active, last_used_at, created_at "
            "FROM llm_keys WHERE provider = ? ORDER BY created_at DESC, id DESC LIMIT ?"
        )
        args = (provider, limit)
    else:
        query = (
            "SELECT id, provider, label, api_key, default_model, is_active, last_used_at, created_at "
            "FROM llm_keys ORDER BY created_at DESC, id DESC LIMIT ?"
        )
        args = (limit,)

    conn = get_conn()
    try:
        records = conn.execute(query, args).fetchall()
    finally:
        conn.close()

    return [
        {
            "id": key_id,
            "provider": platform,
            "label": label or "",
            "api_key": api_key,
            "default_model": default_model or "",
            "is_active": is_active,
            "last_used_at": last_used_at,
            "created_at": created_at,
        }
        for key_id, platform, label, api_key, default_model, is_active, last_used_at, created_at in records
    ]
def list_web_accounts(provider: Optional[str] = None, limit: int = 10) -> list:
    """Return web-account link rows, newest first, as a list of dicts.

    Args:
        provider: When truthy, only rows of this platform are returned.
        limit: Maximum number of rows; anything that is not a positive int
            falls back to 10.

    Returns:
        Dicts with ``id``, ``provider``, ``account_id``, ``account_name``,
        ``login_status``, ``created_at`` and ``updated_at``. ``account_name``
        is normalised to ``""`` and ``login_status`` to an int (0 when falsy).
    """
    init_db()
    if not (isinstance(limit, int) and limit > 0):
        limit = 10

    if provider:
        query = (
            "SELECT id, provider, account_id, account_name, login_status, created_at, updated_at "
            "FROM llm_web_accounts WHERE provider = ? ORDER BY created_at DESC, id DESC LIMIT ?"
        )
        args = (provider, limit)
    else:
        query = (
            "SELECT id, provider, account_id, account_name, login_status, created_at, updated_at "
            "FROM llm_web_accounts ORDER BY created_at DESC, id DESC LIMIT ?"
        )
        args = (limit,)

    conn = get_conn()
    try:
        records = conn.execute(query, args).fetchall()
    finally:
        conn.close()

    return [
        {
            "id": row_id,
            "provider": platform,
            "account_id": account_id,
            "account_name": account_name or "",
            "login_status": int(login_status or 0),
            "created_at": created_at,
            "updated_at": updated_at,
        }
        for row_id, platform, account_id, account_name, login_status, created_at, updated_at in records
    ]
def get_key_by_id(key_id: int) -> Optional[dict]:
    """Fetch a single API key row by primary key.

    Returns:
        A dict with ``id``, ``provider``, ``label``, ``api_key``,
        ``default_model``, ``is_active``, ``last_used_at`` and ``created_at``,
        or ``None`` when no row has that id. ``label`` and ``default_model``
        are normalised to ``""`` when empty/NULL.
    """
    init_db()
    conn = get_conn()
    try:
        record = conn.execute(
            "SELECT id, provider, label, api_key, default_model, is_active, last_used_at, created_at "
            "FROM llm_keys WHERE id = ?",
            (key_id,),
        ).fetchone()
    finally:
        conn.close()

    if not record:
        return None

    row_id, platform, label, api_key, default_model, is_active, last_used_at, created_at = record
    return {
        "id": row_id,
        "provider": platform,
        "label": label or "",
        "api_key": api_key,
        "default_model": default_model or "",
        "is_active": is_active,
        "last_used_at": last_used_at,
        "created_at": created_at,
    }
def delete_key(key_id: int) -> bool:
    """Delete the API key row with the given id.

    Args:
        key_id: Primary key of the ``llm_keys`` row to remove.

    Returns:
        ``True`` if a row was deleted, ``False`` if no row had that id.
    """
    init_db()
    conn = get_conn()
    try:
        # A single DELETE replaces the former SELECT-then-DELETE pair: it saves
        # a round-trip and cannot be invalidated by another writer in between.
        cur = conn.execute("DELETE FROM llm_keys WHERE id = ?", (key_id,))
        conn.commit()
        # rowcount is the number of rows the DELETE actually removed.
        return cur.rowcount > 0
    finally:
        conn.close()
def find_active_key(provider: str) -> Optional[dict]:
    """Find the first key with is_active=1 for the platform (ordered by id ascending).

    Returns:
        A dict with ``id``, ``provider``, ``label``, ``api_key``,
        ``default_model``, ``is_active`` and ``last_used_at``, or ``None``
        when the platform has no active key. ``label`` and ``default_model``
        are normalised to ``""`` when empty/NULL.
    """
    init_db()
    conn = get_conn()
    try:
        record = conn.execute(
            "SELECT id, provider, label, api_key, default_model, is_active, last_used_at "
            "FROM llm_keys WHERE provider = ? AND is_active = 1 ORDER BY id LIMIT 1",
            (provider,),
        ).fetchone()
    finally:
        conn.close()

    if not record:
        return None

    row_id, platform, label, api_key, default_model, is_active, last_used_at = record
    return {
        "id": row_id,
        "provider": platform,
        "label": label or "",
        "api_key": api_key,
        "default_model": default_model or "",
        "is_active": is_active,
        "last_used_at": last_used_at,
    }
def mark_key_used(key_id: int):
    """Stamp a key as just used.

    Sets both ``last_used_at`` and ``updated_at`` of the ``llm_keys`` row with
    the given id to the current Unix time. An id that matches no row is a
    silent no-op.

    Args:
        key_id: Primary key of the ``llm_keys`` row to update.
    """
    # Every other accessor in this module ensures the schema exists first;
    # without this, calling on a fresh database fails with "no such table".
    init_db()
    now = _now_unix()
    conn = get_conn()
    try:
        cur = conn.cursor()
        cur.execute(
            "UPDATE llm_keys SET last_used_at = ?, updated_at = ? WHERE id = ?",
            (now, now, key_id),
        )
        conn.commit()
    finally:
        conn.close()
def _mask_key(api_key: str) -> str:
    """Mask a key for display.

    Keys longer than 8 characters are shown as first 4 + ``...`` + last 4.
    Shorter keys (including empty/``None``) are shown as at most their first
    2 characters followed by ``****``.
    """
    text = api_key if api_key else ""
    if len(text) > 8:
        return f"{text[:4]}...{text[-4:]}"
    return f"{text[:2]}****"