# fs-lawrisk/lawrisk/services/lawrisk_service.py
#
# 569 lines
# 21 KiB
# Python

"""
LawRisk embedding retrieval service.
Responsibilities:
- DB connection helpers (PostgreSQL via pg8000)
- Schema management (tables law_sub, law_sub_per and law_permit in the fs_law_risk database)
- Embedding client (Aliyun DashScope OpenAI-compatible embeddings API)
- Chat client for LLM-based selection (Qwen via OpenAI-compatible /chat/completions)
- Search logic: embedding cosine or LLM subject selection
Env vars used:
- PG_HOST, PG_PORT, PG_USER, PG_PASSWORD (PostgreSQL credentials)
- PG_DATABASE (defaults to fs_law_risk)
- PG_ADMIN_DB (defaults to postgres; used for CREATE DATABASE)
- DASHSCOPE_API_KEY (embedding API key)
- DASHSCOPE_BASE_URL (defaults to https://dashscope.aliyuncs.com/compatible-mode/v1)
- DASHSCOPE_EMBED_MODEL (defaults to text-embedding-v4)
- DASHSCOPE_EMBED_DIM (defaults to 1024)
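- DASHSCOPE_CHAT_MODEL (defaults to qwen-plus-latest)
- DASHSCOPE_MAX_BATCH (defaults to 10; texts per embedding request, clamped to at least 1)
- LAWRISK_RETURN_IF_GE (defaults to 0.7; subjects scoring at or above this are returned)
- LAWRISK_FALLBACK_GT (defaults to 0.4; the best subject is returned alone if its score exceeds this)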
"""
from __future__ import annotations
import json
import math
import os
import ssl
import urllib.request
import urllib.error
from typing import Any, Dict, Iterable, List, Optional, Tuple
import pg8000.dbapi as pg
# Database name used when PG_DATABASE is not set.
DEFAULT_DB = "fs_law_risk"
# Base URL used when DASHSCOPE_BASE_URL is not set.
DEFAULT_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"

# Embedding request settings (env configurable).
EMBED_MODEL = os.getenv("DASHSCOPE_EMBED_MODEL", "text-embedding-v4")
EMBED_DIM = int(os.getenv("DASHSCOPE_EMBED_DIM", "1024"))
EMBED_MAX_BATCH = max(1, int(os.getenv("DASHSCOPE_MAX_BATCH", "10")))  # DashScope limit <=10

# Chat model used for LLM-based selection (env configurable).
CHAT_MODEL = os.getenv("DASHSCOPE_CHAT_MODEL", "qwen-plus-latest")

# Similarity thresholds (env configurable)
RETURN_IF_GE = float(os.getenv("LAWRISK_RETURN_IF_GE", "0.7"))
FALLBACK_GT = float(os.getenv("LAWRISK_FALLBACK_GT", "0.4"))
def _pg_conn(database: Optional[str] = None, autocommit: bool = False) -> pg.Connection:
    """Open a pg8000 DB-API connection configured from the PG_* environment variables.

    Args:
        database: Database to connect to. When falsy, PG_DATABASE is used, and
            DEFAULT_DB if that is unset as well.
        autocommit: Value assigned to the connection's ``autocommit`` attribute.

    Returns:
        The connection returned by ``pg.connect``.
    """
    # NOTE(review): the PG_HOST and PG_PASSWORD fallbacks below embed a concrete
    # server address and password in source control; consider requiring these
    # variables to be set instead of shipping defaults.
    target_db = database if database else os.getenv("PG_DATABASE", DEFAULT_DB)
    connection = pg.connect(
        host=os.getenv("PG_HOST", "8.138.196.105"),
        port=int(os.getenv("PG_PORT", "5432")),
        user=os.getenv("PG_USER", "postgres"),
        password=os.getenv("PG_PASSWORD", "difyai123456"),
        database=target_db,
    )
    connection.autocommit = autocommit
    return connection
def ensure_database(dbname: str = DEFAULT_DB) -> None:
    """Create the database ``dbname`` unless ``pg_database`` already lists it.

    The check and the CREATE run on a connection to the admin database
    (PG_ADMIN_DB, default ``postgres``) opened with autocommit enabled.

    Args:
        dbname: Exact name of the database to create.
    """
    admin_db = os.getenv("PG_ADMIN_DB", "postgres")
    # CREATE DATABASE takes an identifier, which cannot be sent as a bind
    # parameter. Quote it (doubling embedded quotes) so an arbitrary name cannot
    # inject SQL and so the created name matches the exact-string lookup below
    # instead of being case-folded.
    quoted_name = '"' + dbname.replace('"', '""') + '"'
    with _pg_conn(database=admin_db, autocommit=True) as c:
        cur = c.cursor()
        cur.execute("SELECT 1 FROM pg_database WHERE datname=%s", (dbname,))
        if cur.fetchone() is None:
            cur.execute(f"CREATE DATABASE {quoted_name}")
def ensure_schema() -> None:
    """Create the law_sub, law_sub_per and law_permit tables if they are missing.

    All three statements run on one connection from ``_pg_conn()`` and are
    committed together at the end.
    """
    # Store vectors and permit ids as JSONB for portability.
    # Each statement is spelled line by line so the text sent to the server is explicit.
    ddl_statements = (
        "\n"
        "CREATE TABLE IF NOT EXISTS law_sub (\n"
        "id TEXT PRIMARY KEY,\n"
        "name TEXT NOT NULL,\n"
        "vector JSONB NOT NULL\n"
        ")\n",
        "\n"
        "CREATE TABLE IF NOT EXISTS law_sub_per (\n"
        "subject_id TEXT PRIMARY KEY,\n"
        "permit_ids JSONB NOT NULL\n"
        ")\n",
        "\n"
        "CREATE TABLE IF NOT EXISTS law_permit (\n"
        "id TEXT PRIMARY KEY,\n"
        "name TEXT NOT NULL\n"
        ")\n",
    )
    with _pg_conn() as conn:
        cursor = conn.cursor()
        for ddl in ddl_statements:
            cursor.execute(ddl)
        conn.commit()
class EmbeddingClient:
    """Minimal client for an OpenAI-compatible ``/embeddings`` endpoint.

    The API key and base URL come from the constructor arguments or, when those
    are falsy, from DASHSCOPE_API_KEY / DASHSCOPE_BASE_URL (the latter falling
    back to DEFAULT_BASE_URL).
    """

    def __init__(self, api_key: Optional[str] = None, base_url: Optional[str] = None):
        """Store credentials; raise RuntimeError when no API key is available."""
        self.api_key = api_key or os.getenv("DASHSCOPE_API_KEY")
        self.base_url = base_url or os.getenv("DASHSCOPE_BASE_URL", DEFAULT_BASE_URL)
        if not self.api_key:
            raise RuntimeError("DASHSCOPE_API_KEY is not set")

    def embed_texts(self, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` and return one vector per usable input, in order.

        Entries that are not strings, or that are blank after stripping, are
        dropped before any request is made. The result therefore lines up with
        the filtered inputs and may be shorter than ``texts``.

        Raises:
            ValueError: No usable input remains after filtering.
            RuntimeError: The API returned a different number of vectors than
                the number of usable inputs, or answered with an HTTP error.
        """
        # sanitize inputs
        clean_inputs = [str(t) for t in texts if isinstance(t, str) and str(t).strip()]
        if not clean_inputs:
            raise ValueError("No valid input texts for embeddings")
        # chunk by provider batch limit and concatenate results to preserve order
        out: List[List[float]] = []
        for i in range(0, len(clean_inputs), EMBED_MAX_BATCH):
            chunk = clean_inputs[i : i + EMBED_MAX_BATCH]
            out.extend(self._embed_batch(chunk))
        if len(out) != len(clean_inputs):
            raise RuntimeError(
                f"Embedding API returned unexpected result count: got {len(out)}, want {len(clean_inputs)}"
            )
        return out

    def _embed_batch(self, texts: List[str]) -> List[List[float]]:
        """POST one batch to ``<base_url>/embeddings`` and return its vectors.

        Raises:
            RuntimeError: The server answered with an HTTP error status. The
                message carries the response body (or reason) and the first
                500 characters of the request body.
        """
        url = self.base_url.rstrip("/") + "/embeddings"
        body = {
            "model": EMBED_MODEL,
            "input": texts,
            "dimensions": EMBED_DIM,
            "encoding_format": "float",
        }
        data = json.dumps(body).encode("utf-8")
        req = urllib.request.Request(url, data=data, method="POST")
        req.add_header("Authorization", f"Bearer {self.api_key}")
        req.add_header("Content-Type", "application/json")
        ctx = ssl.create_default_context()
        try:
            with urllib.request.urlopen(req, context=ctx, timeout=30) as resp:
                raw = resp.read().decode("utf-8", errors="replace")
        except urllib.error.HTTPError as e:
            err_body = e.read().decode("utf-8", errors="replace") if hasattr(e, 'read') else ""
            raise RuntimeError(
                f"Embedding API error {e.code}: {err_body or e.reason} | sent={json.dumps(body, ensure_ascii=False)[:500]}"
            ) from e
        return self._parse_embeddings(json.loads(raw))

    @staticmethod
    def _parse_embeddings(payload: Any) -> List[List[float]]:
        """Extract the embedding vectors from a decoded response body.

        Items are ordered by their integer ``index`` field when present, so the
        vectors line up with the request inputs even if the server lists them
        out of order; items without one keep their position in the list. Items
        that are not objects or lack a list ``embedding`` are skipped, and a
        payload that is not an object yields an empty list.
        """
        if not isinstance(payload, dict):
            return []
        ordered = []
        for position, item in enumerate(payload.get("data") or []):
            if not isinstance(item, dict):
                continue
            emb = item.get("embedding")
            if not isinstance(emb, list):
                continue
            index = item.get("index")
            sort_key = index if isinstance(index, int) else position
            # position breaks ties so the sort stays stable and never compares vectors
            ordered.append((sort_key, position, emb))
        ordered.sort(key=lambda entry: (entry[0], entry[1]))
        return [[float(x) for x in emb] for _key, _position, emb in ordered]

    def embed_one(self, text: str) -> List[float]:
        """Embed a single text and return its vector."""
        return self.embed_texts([text])[0]
class ChatClient:
    """Minimal client for an OpenAI-compatible ``/chat/completions`` endpoint.

    The API key and base URL come from the constructor arguments or, when those
    are falsy, from DASHSCOPE_API_KEY / DASHSCOPE_BASE_URL (the latter falling
    back to DEFAULT_BASE_URL).
    """

    def __init__(self, api_key: Optional[str] = None, base_url: Optional[str] = None):
        """Store credentials; raise RuntimeError when no API key is available."""
        self.api_key = api_key or os.getenv("DASHSCOPE_API_KEY")
        self.base_url = base_url or os.getenv("DASHSCOPE_BASE_URL", DEFAULT_BASE_URL)
        if not self.api_key:
            raise RuntimeError("DASHSCOPE_API_KEY is not set")

    def chat(self, messages: List[Dict[str, str]], model: Optional[str] = None, temperature: float = 0.2) -> str:
        """Send ``messages`` and return the text of the first choice.

        Args:
            messages: Chat messages, sent as-is in the request body.
            model: Model name; CHAT_MODEL is used when falsy.
            temperature: Sampling temperature forwarded to the API.

        Returns:
            The first choice's message content as a string; an empty string
            when the message or its content is missing or null.

        Raises:
            RuntimeError: The server answered with an HTTP error status, or the
                response contains no choices.
        """
        url = self.base_url.rstrip("/") + "/chat/completions"
        body = {
            "model": model or CHAT_MODEL,
            "messages": messages,
            "temperature": temperature,
        }
        data = json.dumps(body, ensure_ascii=False).encode("utf-8")
        req = urllib.request.Request(url, data=data, method="POST")
        req.add_header("Authorization", f"Bearer {self.api_key}")
        req.add_header("Content-Type", "application/json")
        ctx = ssl.create_default_context()
        try:
            with urllib.request.urlopen(req, context=ctx, timeout=60) as resp:
                raw = resp.read().decode("utf-8", errors="replace")
        except urllib.error.HTTPError as e:
            err_body = e.read().decode("utf-8", errors="replace") if hasattr(e, 'read') else ""
            raise RuntimeError(
                f"Chat API error {e.code}: {err_body or e.reason}"
            ) from e
        return self._extract_content(json.loads(raw))

    @staticmethod
    def _extract_content(payload: Any) -> str:
        """Return the first choice's message content from a decoded response body.

        A null or missing message/content yields ``""`` rather than the text
        ``"None"`` that ``str(None)`` would produce.

        Raises:
            RuntimeError: The payload has no (or an empty) ``choices`` list.
        """
        choices = payload.get("choices") if isinstance(payload, dict) else None
        if not choices:
            raise RuntimeError("Chat API returned no choices")
        first = choices[0]
        message = first.get("message") if isinstance(first, dict) else None
        content = message.get("content") if isinstance(message, dict) else None
        return "" if content is None else str(content)
def generate_question_suggestions(query: str, max_q: int = 5) -> List[str]:
    """Ask the chat model for recommended questions about ``query``.

    Legacy LLM-based suggestion generator (kept for reference, not used by default).

    The reply is expected to contain a JSON array; the text between the first
    ``[`` and the last ``]`` is parsed when both are present, otherwise the whole
    reply is parsed. Non-empty string entries are stripped and at most ``max_q``
    of them are returned. Any failure along the way yields an empty list.
    """
    try:
        reply = ChatClient().chat(
            [
                {"role": "system", "content": "你是政务事项问答助手。请输出与主题事项高度相关的精简推荐问题,仅输出 JSON 数组。"},
                {"role": "user", "content": f"请针对: {query} 给出不超过 {max_q} 条中文推荐问题,仅输出 JSON 数组。"},
            ],
            model=CHAT_MODEL,
            temperature=0.3,
        )
        text = reply.strip()
        open_at = text.find("[")
        close_at = text.rfind("]")
        has_array = open_at != -1 and close_at != -1 and close_at > open_at
        parsed = json.loads(text[open_at : close_at + 1] if has_array else text)
        suggestions: List[str] = []
        if isinstance(parsed, list):
            suggestions = [entry.strip() for entry in parsed if isinstance(entry, str) and entry.strip()]
        return suggestions[:max_q]
    except Exception:
        # Best-effort helper: callers get an empty list instead of an error.
        return []
def _normalize_text(s: str) -> str:
return "".join(ch for ch in str(s) if ch.isalnum())
def shortlist_subjects(query: str, k: int = 5) -> List[Tuple[str, str]]:
    """Return up to ``k`` ``(id, name)`` subjects ranked by character overlap with ``query``.

    A subject's score is the number of distinct characters it shares with the
    query divided by the number of distinct characters in its name. Both sides
    are reduced to alphanumeric characters first, falling back to the raw text
    when that leaves nothing. Subjects sharing no character are left out.
    """
    query_chars = set(_normalize_text(query)) or set(query)

    with _pg_conn() as conn:
        cursor = conn.cursor()
        cursor.execute("SELECT id, name FROM law_sub")
        catalog = [(str(row_id), str(row_name)) for row_id, row_name in cursor.fetchall()]

    ranked: List[Tuple[float, Tuple[str, str]]] = []
    for entry in catalog:
        name_chars = set(_normalize_text(entry[1])) or set(entry[1])
        shared = len(query_chars & name_chars)
        if shared > 0:
            ranked.append((shared / max(1, len(name_chars)), entry))

    # Sort on the score only; equal scores keep their catalog order.
    ranked.sort(key=lambda pair: pair[0], reverse=True)
    return [entry for _score, entry in ranked[:k]]
def suggest_questions_from_subjects(subject_names: List[str], max_q: int = 5) -> List[str]:
    """Return subject names directly (no extra wording).

    Names are stripped, blank or ``None`` entries are skipped, duplicates are
    dropped (first occurrence wins) and at most ``max_q`` names are returned.
    A non-positive ``max_q`` yields an empty list.
    """
    if max_q <= 0:
        # The limit is otherwise only checked after appending, which would
        # let one name through even when none were requested.
        return []
    out: List[str] = []
    seen = set()
    for nm in subject_names:
        nm = (nm or "").strip()
        if not nm or nm in seen:
            continue
        seen.add(nm)
        out.append(nm)
        if len(out) >= max_q:
            break
    return out
def suggest_questions_embed(query: str, max_q: int = 5) -> List[str]:
    """Return the names of the ``max_q`` subjects most similar to ``query``.

    The query is embedded and compared by cosine similarity against every
    subject in law_sub that has a non-empty stored vector. When the embedding
    cannot be obtained, or no subject has a usable vector, the result comes from
    ``shortlist_subjects`` (lexical overlap) instead.
    """

    def lexical_fallback() -> List[str]:
        matches = shortlist_subjects(query, max(1, max_q))
        return [subject_name for _subject_id, subject_name in matches][:max_q]

    try:
        query_vector = EmbeddingClient().embed_one(query)
    except Exception:
        # Embedding not available; fall back to the lexical shortlist.
        return lexical_fallback()

    candidates: List[Tuple[str, str, List[float]]] = []
    with _pg_conn() as conn:
        cursor = conn.cursor()
        cursor.execute("SELECT id, name, vector FROM law_sub")
        for subject_id, subject_name, stored in cursor.fetchall():
            vector = stored
            if isinstance(stored, str):
                # The column may come back as JSON text; unparsable text counts as no vector.
                try:
                    vector = json.loads(stored)
                except Exception:
                    vector = []
            if isinstance(vector, list) and vector:
                candidates.append((str(subject_id), str(subject_name), [float(x) for x in vector]))

    if not candidates:
        return lexical_fallback()

    ranked = [(_cosine(query_vector, vector), subject_name) for _subject_id, subject_name, vector in candidates]
    # Sort on the score only; equal scores keep their table order.
    ranked.sort(key=lambda pair: pair[0], reverse=True)
    return [subject_name for _score, subject_name in ranked[:max_q]]
def _cosine(a: List[float], b: List[float]) -> float:
if not a or not b or len(a) != len(b):
return 0.0
dot = 0.0
na = 0.0
nb = 0.0
for x, y in zip(a, b):
dot += x * y
na += x * x
nb += y * y
if na == 0.0 or nb == 0.0:
return 0.0
return dot / math.sqrt(na * nb)
def upsert_subjects(rows: Iterable[Tuple[str, str, List[float]]]) -> None:
    """Insert or update subjects in law_sub.

    Args:
        rows: ``(id, name, vector)`` tuples. The vector is serialized with
            ``json.dumps`` and stored as JSONB; an existing id has its name and
            vector overwritten.

    All rows are written on one connection and committed once at the end.
    """
    statement = (
        "\n"
        "INSERT INTO law_sub (id, name, vector)\n"
        "VALUES (%s, %s, %s::jsonb)\n"
        "ON CONFLICT (id) DO UPDATE SET name=EXCLUDED.name, vector=EXCLUDED.vector\n"
    )
    with _pg_conn() as conn:
        cursor = conn.cursor()
        for subject_id, subject_name, embedding in rows:
            cursor.execute(statement, (subject_id, subject_name, json.dumps(embedding)))
        conn.commit()
def upsert_subject_permits(rows: Iterable[Tuple[str, List[str]]]) -> None:
    """Insert or update the permit-id list of each subject in law_sub_per.

    Args:
        rows: ``(subject_id, permit_ids)`` tuples. The id list is serialized
            with ``json.dumps`` and stored as JSONB; an existing subject_id has
            its list overwritten.

    All rows are written on one connection and committed once at the end.
    """
    statement = (
        "\n"
        "INSERT INTO law_sub_per (subject_id, permit_ids)\n"
        "VALUES (%s, %s::jsonb)\n"
        "ON CONFLICT (subject_id) DO UPDATE SET permit_ids=EXCLUDED.permit_ids\n"
    )
    with _pg_conn() as conn:
        cursor = conn.cursor()
        for subject_id, subject_permits in rows:
            cursor.execute(statement, (subject_id, json.dumps(subject_permits)))
        conn.commit()
def upsert_permits(rows: Iterable[Tuple[str, str]]) -> None:
    """Insert or update the permit catalog in law_permit (id -> name).

    Args:
        rows: ``(id, name)`` tuples; an existing id has its name overwritten.

    All rows are written on one connection and committed once at the end.
    """
    statement = (
        "\n"
        "INSERT INTO law_permit (id, name)\n"
        "VALUES (%s, %s)\n"
        "ON CONFLICT (id) DO UPDATE SET name=EXCLUDED.name\n"
    )
    with _pg_conn() as conn:
        cursor = conn.cursor()
        for permit_id, permit_label in rows:
            cursor.execute(statement, (permit_id, permit_label))
        conn.commit()
def search_subjects(query: str, return_debug: bool = False, top_k_debug: int = 5) -> Dict[str, Any]:
    """Search subjects by embedding cosine similarity to ``query``.

    Selection rules (thresholds are the module constants RETURN_IF_GE and
    FALLBACK_GT, read from LAWRISK_RETURN_IF_GE / LAWRISK_FALLBACK_GT):
    - every subject with score >= RETURN_IF_GE is returned, best first
    - if none qualifies but the best score is > FALLBACK_GT, that single best
      subject is returned
    - otherwise the result list is empty

    Args:
        query: Text to embed and match against the stored subject vectors.
        return_debug: When true, add a ``score`` to each result and a ``debug``
            section to the response.
        top_k_debug: Number of top candidates listed in the debug section;
            values <= 0 fall back to 5.

    Returns:
        ``{"risk_subject": [{"id", "name", "permit"}, ...]}`` where ``permit``
        maps permit name -> permit id for permits that have a known name, plus
        ``"debug"`` when requested.
    """
    client = EmbeddingClient()
    qvec = client.embed_one(query)

    with _pg_conn() as c:
        cur = c.cursor()

        # load all subjects
        cur.execute("SELECT id, name, vector FROM law_sub")
        subs: List[Tuple[str, str, List[float]]] = []
        for sid, name, vec_json in cur.fetchall():
            # vec_json may come back as Python list or JSON string depending on driver version
            if isinstance(vec_json, str):
                try:
                    vec = json.loads(vec_json)
                except Exception:
                    vec = []
            else:
                vec = vec_json
            subs.append((str(sid), str(name), list(vec) if isinstance(vec, list) else []))

        # Build permit lookup: subject id -> list of permit ids
        cur.execute("SELECT subject_id, permit_ids FROM law_sub_per")
        per_map: Dict[str, List[str]] = {}
        for sid, pids in cur.fetchall():
            # pids may be list or JSON string
            if isinstance(pids, str):
                try:
                    p_list = json.loads(pids)
                except Exception:
                    p_list = []
            else:
                p_list = list(pids) if isinstance(pids, list) else []
            per_map[str(sid)] = [str(x) for x in p_list]

    scored: List[Tuple[float, Tuple[str, str, List[float]]]] = []
    for row in subs:
        score = _cosine(qvec, row[2])
        scored.append((score, row))
    scored.sort(key=lambda x: x[0], reverse=True)

    # Build permit name lookup
    permit_name: Dict[str, str] = {}
    try:
        with _pg_conn() as c2:
            cur2 = c2.cursor()
            cur2.execute("SELECT id, name FROM law_permit")
            for pid, pname in cur2.fetchall():
                permit_name[str(pid)] = str(pname)
    except Exception:
        # If table missing or query fails, leave map empty; upstream should seed via ingest
        permit_name = {}

    def build_item(sid: str, name: str, score: float) -> Dict[str, Any]:
        # Permit map is name -> id; permits without a known name are omitted.
        item: Dict[str, Any] = {
            "id": sid,
            "name": name,
            "permit": {permit_name.get(pid, ""): pid for pid in per_map.get(sid, []) if permit_name.get(pid)},
        }
        if return_debug:
            item["score"] = round(float(score), 6)
        return item

    results: List[Dict[str, Any]] = [
        build_item(sid, name, score)
        for score, (sid, name, _vec) in scored
        if score >= RETURN_IF_GE
    ]

    # Remember which rule produced the result so the debug decision is accurate.
    used_fallback = False
    if not results and scored and scored[0][0] > FALLBACK_GT:
        best_score, (sid, name, _vec) = scored[0]
        results = [build_item(sid, name, best_score)]
        used_fallback = True

    out: Dict[str, Any] = {"risk_subject": results}
    if return_debug:
        if used_fallback:
            decision = "returned_top_fallback"
        elif results:
            decision = "returned_ge_threshold"
        else:
            decision = "no_match_below_fallback"
        top_list = []
        for s, (sid, name, _v) in scored[: max(0, top_k_debug) or 5]:
            top_list.append({"id": sid, "name": name, "score": round(float(s), 6)})
        out["debug"] = {
            "query": query,
            "qvec_dim": len(qvec),
            "thresholds": {"return_if_ge": RETURN_IF_GE, "fallback_gt": FALLBACK_GT},
            "num_subjects": len(subs),
            "max_score": round(float(scored[0][0]), 6) if scored else 0.0,
            "top_candidates": top_list,
            "decision": decision,
        }
    return out
def search_subjects_llm(query: str, return_debug: bool = False, top_k_debug: int = 5) -> Dict[str, Any]:
    """Let the chat model pick matching subject ids from the full catalog.

    Steps:
    - load every ``(id, name)`` from law_sub and the permit ids per subject
    - send the whole catalog (one ``id<TAB>name`` line per subject) plus the
      query to the chat model, asking for a JSON array of ids
    - keep the returned ids that exist in the catalog, without duplicates and
      in the order given; an empty selection is a valid outcome
    - attach the name -> id permit map to each selected subject

    Args:
        query: The user's question.
        return_debug: When true, add a ``debug`` section to the response.
        top_k_debug: Not used by this function.

    Returns:
        ``{"risk_subject": [{"id", "name", "permit"}, ...]}`` plus ``"debug"``
        when requested.
    """
    # Load catalog
    permits_by_subject: Dict[str, List[str]] = {}
    with _pg_conn() as conn:
        cursor = conn.cursor()
        cursor.execute("SELECT id, name FROM law_sub")
        subjects = [(str(row_id), str(row_name)) for row_id, row_name in cursor.fetchall()]
        cursor.execute("SELECT subject_id, permit_ids FROM law_sub_per")
        for subject_id, raw_permits in cursor.fetchall():
            # The JSONB column may arrive as a list or as JSON text.
            if isinstance(raw_permits, str):
                try:
                    permit_list = json.loads(raw_permits)
                except Exception:
                    permit_list = []
            elif isinstance(raw_permits, list):
                permit_list = list(raw_permits)
            else:
                permit_list = []
            permits_by_subject[str(subject_id)] = [str(permit) for permit in permit_list]

    # Subject list block: one "id<TAB>name" line per subject.
    subjects_block = "\n".join(f"{subject_id}\t{subject_name}" for subject_id, subject_name in subjects)
    system_msg = (
        "你是政务事项检索助手。根据用户的中文查询,从给定的主题事项清单中选择最相关的主题事项。"
        "只允许从清单中选择,不能编造。若没有足够相关的主题,请返回空数组 []."
        "始终以 JSON 数组返回所选主题事项的 id 列表,例如: [\"id1\", \"id2\"]."
    )
    user_msg = (
        f"用户问题: {query}\n\n"
        f"主题事项清单(格式: id<TAB>名称):\n{subjects_block}\n\n"
        "请仅输出 JSON 数组 (仅数组本身)。若无匹配请输出 []."
    )
    reply = ChatClient().chat(
        [
            {"role": "system", "content": system_msg},
            {"role": "user", "content": user_msg},
        ],
        model=CHAT_MODEL,
        temperature=0.2,
    )

    # Parse a JSON array out of the reply, tolerating text wrapped around it.
    try:
        text = reply.strip()
        open_at = text.find("[")
        close_at = text.rfind("]")
        has_array = open_at != -1 and close_at != -1 and close_at > open_at
        parsed = json.loads(text[open_at : close_at + 1] if has_array else text)
    except Exception:
        parsed = None

    # Entries may be plain id strings or objects carrying a string "id".
    picked: List[str] = []
    if isinstance(parsed, list):
        for entry in parsed:
            if isinstance(entry, str):
                picked.append(entry)
            elif isinstance(entry, dict) and isinstance(entry.get("id"), str):
                picked.append(entry["id"])

    # Deduplicate (first occurrence wins) and keep only ids that exist.
    names_by_id = dict(subjects)
    chosen = list(dict.fromkeys(candidate for candidate in picked if candidate in names_by_id))

    # Load permit names; on any failure the permit maps simply come out empty.
    permit_name: Dict[str, str] = {}
    try:
        with _pg_conn() as permit_conn:
            permit_cursor = permit_conn.cursor()
            permit_cursor.execute("SELECT id, name FROM law_permit")
            for permit_id, permit_label in permit_cursor.fetchall():
                permit_name[str(permit_id)] = str(permit_label)
    except Exception:
        permit_name = {}

    results = [
        {
            "id": subject_id,
            "name": names_by_id.get(subject_id, ""),
            "permit": {
                permit_name.get(pid, ""): pid
                for pid in permits_by_subject.get(subject_id, [])
                if permit_name.get(pid)
            },
        }
        for subject_id in chosen
    ]

    out: Dict[str, Any] = {"risk_subject": results}
    if return_debug:
        out["debug"] = {
            "model": CHAT_MODEL,
            "num_subjects": len(subjects),
            "selected_ids": chosen,
            "allow_empty": True,
        }
    return out