""" LawRisk embedding retrieval service. Responsibilities: - DB connection helpers (PostgreSQL via pg8000) - Schema management (fs_law_risk.law_sub, fs_law_risk.law_sub_per) - Embedding client (Aliyun DashScope OpenAI-compatible embeddings API) - Chat client for LLM-based selection (Qwen via OpenAI-compatible /chat/completions) - Search logic: embedding cosine or LLM subject selection Env vars used: - PG_HOST, PG_PORT, PG_USER, PG_PASSWORD (PostgreSQL credentials) - PG_DATABASE (defaults to fs_law_risk) - PG_ADMIN_DB (defaults to postgres; used for CREATE DATABASE) - DASHSCOPE_API_KEY (embedding API key) - DASHSCOPE_BASE_URL (defaults to https://dashscope.aliyuncs.com/compatible-mode/v1) - DASHSCOPE_EMBED_MODEL (defaults to text-embedding-v4) - DASHSCOPE_EMBED_DIM (defaults to 1024) - DASHSCOPE_CHAT_MODEL (defaults to qwen-plus-latest) """ from __future__ import annotations import json import math import os import ssl import urllib.request import urllib.error from typing import Any, Dict, Iterable, List, Optional, Tuple import pg8000.dbapi as pg DEFAULT_DB = "fs_law_risk" DEFAULT_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1" EMBED_MODEL = os.getenv("DASHSCOPE_EMBED_MODEL", "text-embedding-v4") EMBED_DIM = int(os.getenv("DASHSCOPE_EMBED_DIM", "1024")) EMBED_MAX_BATCH = max(1, int(os.getenv("DASHSCOPE_MAX_BATCH", "10"))) # DashScope limit <=10 CHAT_MODEL = os.getenv("DASHSCOPE_CHAT_MODEL", "qwen-plus-latest") # Similarity thresholds (env configurable) RETURN_IF_GE = float(os.getenv("LAWRISK_RETURN_IF_GE", "0.7")) FALLBACK_GT = float(os.getenv("LAWRISK_FALLBACK_GT", "0.4")) # Similarity thresholds (env configurable) RETURN_IF_GE = float(os.getenv("LAWRISK_RETURN_IF_GE", "0.7")) FALLBACK_GT = float(os.getenv("LAWRISK_FALLBACK_GT", "0.4")) def _pg_conn(database: Optional[str] = None, autocommit: bool = False) -> pg.Connection: host = os.getenv("PG_HOST", "8.138.196.105") port = int(os.getenv("PG_PORT", "5432")) user = os.getenv("PG_USER", "postgres") password = os.getenv("PG_PASSWORD", "difyai123456") dbname = database or os.getenv("PG_DATABASE", DEFAULT_DB) conn = pg.connect(host=host, port=port, user=user, password=password, database=dbname) conn.autocommit = autocommit return conn def ensure_database(dbname: str = DEFAULT_DB) -> None: # Create database if not exists by connecting to postgres admin_db = os.getenv("PG_ADMIN_DB", "postgres") with _pg_conn(database=admin_db, autocommit=True) as c: cur = c.cursor() cur.execute("SELECT 1 FROM pg_database WHERE datname=%s", (dbname,)) if cur.fetchone() is None: cur.execute(f"CREATE DATABASE {dbname}") def ensure_schema() -> None: with _pg_conn() as c: cur = c.cursor() # Store vectors and permit ids as JSONB for portability cur.execute( """ CREATE TABLE IF NOT EXISTS law_sub ( id TEXT PRIMARY KEY, name TEXT NOT NULL, vector JSONB NOT NULL ) """ ) cur.execute( """ CREATE TABLE IF NOT EXISTS law_sub_per ( subject_id TEXT PRIMARY KEY, permit_ids JSONB NOT NULL ) """ ) cur.execute( """ CREATE TABLE IF NOT EXISTS law_permit ( id TEXT PRIMARY KEY, name TEXT NOT NULL ) """ ) c.commit() class EmbeddingClient: def __init__(self, api_key: Optional[str] = None, base_url: Optional[str] = None): self.api_key = api_key or os.getenv("DASHSCOPE_API_KEY") self.base_url = base_url or os.getenv("DASHSCOPE_BASE_URL", DEFAULT_BASE_URL) if not self.api_key: raise RuntimeError("DASHSCOPE_API_KEY is not set") def embed_texts(self, texts: List[str]) -> List[List[float]]: # sanitize inputs clean_inputs = [str(t) for t in texts if isinstance(t, str) and str(t).strip()] if not clean_inputs: raise ValueError("No valid input texts for embeddings") # chunk by provider batch limit and concatenate results to preserve order out: List[List[float]] = [] for i in range(0, len(clean_inputs), EMBED_MAX_BATCH): chunk = clean_inputs[i : i + EMBED_MAX_BATCH] out.extend(self._embed_batch(chunk)) if len(out) != len(clean_inputs): raise RuntimeError( f"Embedding API returned unexpected result count: got {len(out)}, want {len(clean_inputs)}" ) return out def _embed_batch(self, texts: List[str]) -> List[List[float]]: url = self.base_url.rstrip("/") + "/embeddings" body = { "model": EMBED_MODEL, "input": texts, "dimensions": EMBED_DIM, "encoding_format": "float", } data = json.dumps(body).encode("utf-8") req = urllib.request.Request(url, data=data, method="POST") req.add_header("Authorization", f"Bearer {self.api_key}") req.add_header("Content-Type", "application/json") ctx = ssl.create_default_context() try: with urllib.request.urlopen(req, context=ctx, timeout=30) as resp: raw = resp.read().decode("utf-8", errors="replace") except urllib.error.HTTPError as e: err_body = e.read().decode("utf-8", errors="replace") if hasattr(e, 'read') else "" raise RuntimeError( f"Embedding API error {e.code}: {err_body or e.reason} | sent={json.dumps(body, ensure_ascii=False)[:500]}" ) from e payload = json.loads(raw) out: List[List[float]] = [] for item in payload.get("data", []): emb = item.get("embedding") if isinstance(emb, list): out.append([float(x) for x in emb]) return out def embed_one(self, text: str) -> List[float]: return self.embed_texts([text])[0] class ChatClient: def __init__(self, api_key: Optional[str] = None, base_url: Optional[str] = None): self.api_key = api_key or os.getenv("DASHSCOPE_API_KEY") self.base_url = base_url or os.getenv("DASHSCOPE_BASE_URL", DEFAULT_BASE_URL) if not self.api_key: raise RuntimeError("DASHSCOPE_API_KEY is not set") def chat(self, messages: List[Dict[str, str]], model: Optional[str] = None, temperature: float = 0.2) -> str: url = self.base_url.rstrip("/") + "/chat/completions" body = { "model": model or CHAT_MODEL, "messages": messages, "temperature": temperature, } data = json.dumps(body, ensure_ascii=False).encode("utf-8") req = urllib.request.Request(url, data=data, method="POST") req.add_header("Authorization", f"Bearer {self.api_key}") req.add_header("Content-Type", "application/json") ctx = ssl.create_default_context() try: with urllib.request.urlopen(req, context=ctx, timeout=60) as resp: raw = resp.read().decode("utf-8", errors="replace") except urllib.error.HTTPError as e: err_body = e.read().decode("utf-8", errors="replace") if hasattr(e, 'read') else "" raise RuntimeError( f"Chat API error {e.code}: {err_body or e.reason}" ) from e payload = json.loads(raw) choices = payload.get("choices", []) if not choices: raise RuntimeError("Chat API returned no choices") msg = choices[0].get("message", {}) content = msg.get("content", "") return str(content) def generate_question_suggestions(query: str, max_q: int = 5) -> List[str]: """Legacy LLM-based suggestion generator (kept for reference). Not used by default.""" try: chat = ChatClient() content = chat.chat([ {"role": "system", "content": "你是政务事项问答助手。请输出与主题事项高度相关的精简推荐问题,仅输出 JSON 数组。"}, {"role": "user", "content": f"请针对: {query} 给出不超过 {max_q} 条中文推荐问题,仅输出 JSON 数组。"}, ], model=CHAT_MODEL, temperature=0.3) txt = content.strip() start = txt.find("[") end = txt.rfind("]") arr = json.loads(txt[start : end + 1] if start != -1 and end != -1 and end > start else txt) out: List[str] = [] if isinstance(arr, list): for x in arr: if isinstance(x, str) and x.strip(): out.append(x.strip()) return out[:max_q] except Exception: return [] def _normalize_text(s: str) -> str: return "".join(ch for ch in str(s) if ch.isalnum()) def shortlist_subjects(query: str, k: int = 5) -> List[Tuple[str, str]]: """Return up to k subjects with highest lexical overlap to query. Simple char-level overlap score to keep it deterministic and fast. """ q = set(_normalize_text(query)) if not q: q = set(query) with _pg_conn() as c: cur = c.cursor() cur.execute("SELECT id, name FROM law_sub") subs = [(str(sid), str(name)) for sid, name in cur.fetchall()] scored: List[Tuple[float, Tuple[str, str]]] = [] for sid, name in subs: n = set(_normalize_text(name)) or set(name) inter = len(q & n) denom = max(1, len(n)) score = inter / denom if inter > 0: scored.append((score, (sid, name))) scored.sort(key=lambda x: x[0], reverse=True) return [row for _s, row in scored[:k]] def suggest_questions_from_subjects(subject_names: List[str], max_q: int = 5) -> List[str]: """Return subject names directly (no extra wording).""" out: List[str] = [] for nm in subject_names: nm = (nm or "").strip() if nm and nm not in out: out.append(nm) if len(out) >= max_q: break return out def suggest_questions_embed(query: str, max_q: int = 5) -> List[str]: """Use embeddings to pick top-N subject names (no added text).""" try: client = EmbeddingClient() qvec = client.embed_one(query) except Exception: # Embedding not available; fallback to lexical shortlist subs = shortlist_subjects(query, max(1, max_q)) return [name for _sid, name in subs][:max_q] # Load subjects with vectors with _pg_conn() as c: cur = c.cursor() cur.execute("SELECT id, name, vector FROM law_sub") subjects: List[Tuple[str, str, List[float]]] = [] for sid, name, vec_json in cur.fetchall(): if isinstance(vec_json, str): try: vec = json.loads(vec_json) except Exception: vec = [] else: vec = vec_json if isinstance(vec, list) and vec: subjects.append((str(sid), str(name), [float(x) for x in vec])) if not subjects: subs = shortlist_subjects(query, max(1, max_q)) return [name for _sid, name in subs][:max_q] scored: List[Tuple[float, Tuple[str, str]]] = [] for sid, name, vec in subjects: s = _cosine(qvec, vec) scored.append((s, (sid, name))) scored.sort(key=lambda x: x[0], reverse=True) # Take top subjects (more than max_q to allow templating to fill up to max_q) top_subjects = [nm for _score, (_sid, nm) in scored[: max_q]] return top_subjects def _cosine(a: List[float], b: List[float]) -> float: if not a or not b or len(a) != len(b): return 0.0 dot = 0.0 na = 0.0 nb = 0.0 for x, y in zip(a, b): dot += x * y na += x * x nb += y * y if na == 0.0 or nb == 0.0: return 0.0 return dot / math.sqrt(na * nb) def upsert_subjects( rows: Iterable[Tuple[str, str, List[float]]] ) -> None: """Upsert subjects into law_sub.""" with _pg_conn() as c: cur = c.cursor() for sid, name, vec in rows: cur.execute( """ INSERT INTO law_sub (id, name, vector) VALUES (%s, %s, %s::jsonb) ON CONFLICT (id) DO UPDATE SET name=EXCLUDED.name, vector=EXCLUDED.vector """, (sid, name, json.dumps(vec)), ) c.commit() def upsert_subject_permits(rows: Iterable[Tuple[str, List[str]]]) -> None: with _pg_conn() as c: cur = c.cursor() for sid, permit_ids in rows: cur.execute( """ INSERT INTO law_sub_per (subject_id, permit_ids) VALUES (%s, %s::jsonb) ON CONFLICT (subject_id) DO UPDATE SET permit_ids=EXCLUDED.permit_ids """, (sid, json.dumps(permit_ids)), ) c.commit() def upsert_permits(rows: Iterable[Tuple[str, str]]) -> None: """Upsert permit catalog into law_permit (id -> name).""" with _pg_conn() as c: cur = c.cursor() for pid, name in rows: cur.execute( """ INSERT INTO law_permit (id, name) VALUES (%s, %s) ON CONFLICT (id) DO UPDATE SET name=EXCLUDED.name """, (pid, name), ) c.commit() def search_subjects(query: str, return_debug: bool = False, top_k_debug: int = 5) -> Dict[str, Any]: """Search by embedding similarity, return JSON object compliant with PRD. Thresholds: - return all with score >= 0.5 - if none >= 0.5 but max > 0.4, return the single best one """ client = EmbeddingClient() qvec = client.embed_one(query) # load all subjects with _pg_conn() as c: cur = c.cursor() cur.execute("SELECT id, name, vector FROM law_sub") subs: List[Tuple[str, str, List[float]]] = [] for sid, name, vec_json in cur.fetchall(): # vec_json may come back as Python list or JSON string depending on driver version if isinstance(vec_json, str): try: vec = json.loads(vec_json) except Exception: vec = [] else: vec = vec_json subs.append((str(sid), str(name), list(vec) if isinstance(vec, list) else [])) # Build permit lookup cur.execute("SELECT subject_id, permit_ids FROM law_sub_per") per_map: Dict[str, List[str]] = {} for sid, pids in cur.fetchall(): # pids may be list or JSON string if isinstance(pids, str): try: p_list = json.loads(pids) except Exception: p_list = [] else: p_list = list(pids) if isinstance(pids, list) else [] per_map[str(sid)] = [str(x) for x in p_list] scored: List[Tuple[float, Tuple[str, str, List[float]]]] = [] for row in subs: score = _cosine(qvec, row[2]) scored.append((score, row)) scored.sort(key=lambda x: x[0], reverse=True) # Build permit name lookup permit_name: Dict[str, str] = {} try: with _pg_conn() as c2: cur2 = c2.cursor() cur2.execute("SELECT id, name FROM law_permit") for pid, pname in cur2.fetchall(): permit_name[str(pid)] = str(pname) except Exception: # If table missing or query fails, leave map empty; upstream should seed via ingest permit_name = {} results: List[Dict[str, Any]] = [] for score, (sid, name, _vec) in scored: if score >= RETURN_IF_GE: item = { "id": sid, "name": name, # Build permit map: name -> id "permit": {permit_name.get(pid, ""): pid for pid in per_map.get(sid, []) if permit_name.get(pid)}, } if return_debug: item["score"] = round(float(score), 6) results.append(item) if not results and scored and scored[0][0] > FALLBACK_GT: sid, name, _ = scored[0][1] best_score = scored[0][0] item = { "id": sid, "name": name, "permit": {permit_name.get(pid, ""): pid for pid in per_map.get(sid, []) if permit_name.get(pid)}, } if return_debug: item["score"] = round(float(best_score), 6) results = [item] out: Dict[str, Any] = {"risk_subject": results} if return_debug: decision = ( "returned_ge_threshold" if results else "returned_top_fallback" if (scored and scored[0][0] > FALLBACK_GT) else "no_match_below_fallback" ) top_list = [] for s, (sid, name, _v) in scored[: max(0, top_k_debug) or 5]: top_list.append({"id": sid, "name": name, "score": round(float(s), 6)}) out["debug"] = { "query": query, "qvec_dim": len(qvec), "thresholds": {"return_if_ge": RETURN_IF_GE, "fallback_gt": FALLBACK_GT}, "num_subjects": len(subs), "max_score": round(float(scored[0][0]), 6) if scored else 0.0, "top_candidates": top_list, "decision": decision, } return out def search_subjects_llm(query: str, return_debug: bool = False, top_k_debug: int = 5) -> Dict[str, Any]: """Use LLM to pick one or more subject IDs from the catalog by instruction. Steps: - Load subject id+name list from DB - Ask LLM (Qwen) to select at least one subject id from the list that best matches the user query - Map selected ids to full entries (name+permit_ids) and return """ # Load catalog with _pg_conn() as c: cur = c.cursor() cur.execute("SELECT id, name FROM law_sub") subjects = [(str(sid), str(name)) for sid, name in cur.fetchall()] cur.execute("SELECT subject_id, permit_ids FROM law_sub_per") per_map: Dict[str, List[str]] = {} for sid, pids in cur.fetchall(): if isinstance(pids, str): try: p_list = json.loads(pids) except Exception: p_list = [] else: p_list = list(pids) if isinstance(pids, list) else [] per_map[str(sid)] = [str(x) for x in p_list] # Build concise subject list block: id | name per line # Keep within reasonable token limits; if too long, truncate and rely on LLM suggestion quality. lines = [f"{sid}\t{name}" for sid, name in subjects] subjects_block = "\n".join(lines) system_msg = ( "你是政务事项检索助手。根据用户的中文查询,从给定的主题事项清单中选择最相关的主题事项。" "只允许从清单中选择,不能编造。若没有足够相关的主题,请返回空数组 []." "始终以 JSON 数组返回所选主题事项的 id 列表,例如: [\"id1\", \"id2\"]." ) user_msg = ( f"用户问题: {query}\n\n" f"主题事项清单(格式: id名称):\n{subjects_block}\n\n" "请仅输出 JSON 数组 (仅数组本身)。若无匹配请输出 []." ) chat = ChatClient() content = chat.chat([ {"role": "system", "content": system_msg}, {"role": "user", "content": user_msg}, ], model=CHAT_MODEL, temperature=0.2) # Try parsing as JSON array of strings; robustly extract if wrapped text exists selected_ids: List[str] = [] try: txt = content.strip() start = txt.find("[") end = txt.rfind("]") if start != -1 and end != -1 and end > start: arr = json.loads(txt[start : end + 1]) else: arr = json.loads(txt) if isinstance(arr, list): for x in arr: if isinstance(x, str): selected_ids.append(x) elif isinstance(x, dict) and "id" in x and isinstance(x["id"], str): selected_ids.append(x["id"]) except Exception: selected_ids = [] # Deduplicate and keep only ids that exist id_set = {sid for sid, _ in subjects} chosen = [] for sid in selected_ids: if sid in id_set and sid not in chosen: chosen.append(sid) # Allow empty result when nothing is relevant # Load permit names permit_name: Dict[str, str] = {} try: with _pg_conn() as c2: cur2 = c2.cursor() cur2.execute("SELECT id, name FROM law_permit") for pid, pname in cur2.fetchall(): permit_name[str(pid)] = str(pname) except Exception: permit_name = {} results = [] name_map = {sid: name for sid, name in subjects} for sid in chosen: results.append({ "id": sid, "name": name_map.get(sid, ""), "permit": {permit_name.get(pid, ""): pid for pid in per_map.get(sid, []) if permit_name.get(pid)}, }) out: Dict[str, Any] = {"risk_subject": results} if return_debug: out["debug"] = { "model": CHAT_MODEL, "num_subjects": len(subjects), "selected_ids": chosen, "allow_empty": True, } return out