115 lines
3.3 KiB
Python
115 lines
3.3 KiB
Python
"""提示词模板:表内数据访问与种子。"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import random
|
|
import sqlite3
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
from content_manager.constants import PROMPT_TEMPLATE_SEEDS, PUBLISH_PLATFORM_CN
|
|
from content_manager.util.timeutil import now_unix
|
|
|
|
|
|
def seed_prompt_templates_if_empty(cur: sqlite3.Cursor) -> None:
    """Insert the built-in seed templates for each platform that has none yet.

    For every ``platform -> templates`` entry in ``PROMPT_TEMPLATE_SEEDS``:

    * if ``prompt_templates`` already holds at least one row for that
      platform, the platform is left untouched;
    * otherwise each template text is inserted as an active row
      (``is_active = 1``) named ``"<display name>模板<n>"``, with ``n``
      counting from 1. The display name comes from ``PUBLISH_PLATFORM_CN``
      and falls back to the raw platform key.

    All rows seeded in one call share a single ``created_at``/``updated_at``
    timestamp. Statements are only executed on ``cur``. Nothing is committed
    here, so committing is left to the caller.
    """
    ts = now_unix()  # one timestamp shared by every row seeded in this call
    for platform, templates in PROMPT_TEMPLATE_SEEDS.items():
        # Reuse the shared counting helper instead of duplicating its query.
        if count_by_platform(cur, platform) > 0:
            continue
        # Loop-invariant for this platform: compute the display label once.
        label = PUBLISH_PLATFORM_CN.get(platform, platform)
        for idx, tpl in enumerate(templates, start=1):
            # Per-row execute (not executemany) keeps cur.lastrowid/rowcount
            # exactly as the previous implementation left them.
            cur.execute(
                """
                INSERT INTO prompt_templates (platform, name, template_text, is_active, created_at, updated_at)
                VALUES (?, ?, ?, 1, ?, ?)
                """,
                (platform, f"{label}模板{idx}", tpl, ts, ts),
            )
|
|
|
|
|
|
def count_by_platform(cur: sqlite3.Cursor, platform: str) -> int:
    """Return the number of ``prompt_templates`` rows whose platform is *platform*.

    The query runs on the supplied cursor. A NULL/empty aggregate is reported
    as ``0``.
    """
    row = cur.execute(
        "SELECT COUNT(*) FROM prompt_templates WHERE platform = ?", (platform,)
    ).fetchone()
    total = row[0]
    return int(total) if total else 0
|
|
|
|
|
|
def fetch_active_templates(conn: sqlite3.Connection, platform: str) -> List[Tuple[Any, ...]]:
    """Return the active templates for *platform* as raw rows.

    Each row is ``(id, platform, name, template_text)``. Only rows with
    ``is_active = 1`` are included, and rows are sorted by ascending ``id``.
    An empty list is returned when nothing matches.
    """
    query = """
        SELECT id, platform, name, template_text
        FROM prompt_templates
        WHERE platform = ? AND is_active = 1
        ORDER BY id ASC
    """
    result = conn.execute(query, (platform,))
    return list(result.fetchall())
|
|
|
|
|
|
def fetch_common_fallback(conn: sqlite3.Connection) -> List[Tuple[Any, ...]]:
    """Return the active templates registered under the ``'common'`` platform.

    Each row is ``(id, platform, name, template_text)``, sorted by ascending
    ``id``. Only rows with ``is_active = 1`` are included.

    This delegates to ``fetch_active_templates`` rather than repeating its
    query, so the column list, active filter and ordering stay identical for
    both lookups.
    """
    return fetch_active_templates(conn, "common")
|
|
|
|
|
|
def pick_random_template(rows: List[Tuple[Any, ...]]) -> Optional[Dict[str, Any]]:
    """Choose one row uniformly at random and return it as a dict.

    *rows* are 4-tuples ``(id, platform, name, template_text)``. Any other
    arity raises ``ValueError`` when the row is unpacked. The returned
    ``"id"`` is coerced with ``int()``. An empty (or otherwise falsy) *rows*
    yields ``None``.
    """
    if not rows:
        return None
    template_id, platform, name, template_text = random.choice(rows)
    return dict(
        id=int(template_id),
        platform=platform,
        name=name,
        template_text=template_text,
    )
|
|
|
|
|
|
def insert_usage(
    conn: sqlite3.Connection,
    template_id: int,
    llm_target: str,
    platform: str,
    topic: str,
    article_id: Optional[int],
) -> None:
    """Append one row to ``prompt_template_usage`` describing a template use.

    ``created_at`` is stamped with ``now_unix()``, and ``article_id`` may be
    ``None``. The row is inserted on *conn* but not committed here.
    """
    record = (template_id, llm_target, platform, topic, article_id, now_unix())
    conn.execute(
        """
        INSERT INTO prompt_template_usage (template_id, llm_target, platform, topic, article_id, created_at)
        VALUES (?, ?, ?, ?, ?, ?)
        """,
        record,
    )
|
|
|
|
|
|
def list_templates(
    conn: sqlite3.Connection,
    platform: Optional[str],
    limit: int,
) -> List[Tuple[Any, ...]]:
    """List template summaries, highest ``id`` first, capped at *limit* rows.

    Each row is ``(id, platform, name, is_active, updated_at)``. When
    *platform* is truthy, only that platform's rows are returned. ``None``
    and ``""`` both mean "all platforms". *limit* is coerced with ``int()``.
    """
    if not platform:
        sql = """
            SELECT id, platform, name, is_active, updated_at
            FROM prompt_templates
            ORDER BY id DESC
            LIMIT ?
        """
        params: Tuple[Any, ...] = (int(limit),)
    else:
        sql = """
            SELECT id, platform, name, is_active, updated_at
            FROM prompt_templates
            WHERE platform = ?
            ORDER BY id DESC
            LIMIT ?
        """
        params = (platform, int(limit))
    return list(conn.execute(sql, params).fetchall())
|