chore: initial commit
This commit is contained in:
commit
d6d92fd966
|
|
@ -0,0 +1,10 @@
|
|||
.env
|
||||
.venv/
|
||||
__pycache__/
|
||||
*.pyc
|
||||
*.pyo
|
||||
*.pyd
|
||||
*.log
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
.env.*
|
||||
|
|
@ -0,0 +1,44 @@
|
|||
# Repository Guidelines
|
||||
|
||||
## Project Structure & Module Organization
|
||||
- Root scripts: smart_cors_middleware.py (Flask CORS add-on), export_risk_json.py (PostgreSQL export).
|
||||
- Data/outputs: risk_tables_export.json (generated by export script).
|
||||
- Docs: PRD.md.
|
||||
- Python 3.10+ is required (uses PEP 604 unions like str | None).
|
||||
|
||||
## Build, Test, and Development Commands
|
||||
- Create venv (Windows):
|
||||
~~~powershell
|
||||
python -m venv .venv; .venv\Scripts\activate; pip install Flask pg8000 black ruff pytest
|
||||
~~~
|
||||
- Run DB export (writes risk_tables_export.json):
|
||||
~~~bash
|
||||
python export_risk_json.py
|
||||
~~~
|
||||
- Verify CORS middleware in your Flask app (diagnosis endpoint):
|
||||
~~~bash
|
||||
curl -i http://localhost:5000/api/cors-diagnosis
|
||||
~~~
|
||||
- Lint/format (optional tools): ruff . and black .
|
||||
- Tests (when added): pytest -q
|
||||
|
||||
## Coding Style & Naming Conventions
|
||||
- Python: 4-space indents, UTF-8 files, snake_case for functions/vars, SCREAMING_SNAKE_CASE for constants.
|
||||
- Prefer type hints; keep functions small and side-effect free.
|
||||
- Formatting: black (line length 100). Linting: ruff (default rules).
|
||||
- Filenames: modules like smart_cors_middleware.py; tests as test_*.py under tests/.
|
||||
|
||||
## Testing Guidelines
|
||||
- Framework: pytest with Flask test client for middleware behavior.
|
||||
- Target cases: origin matching (wildcard, exact, subdomains), preflight handling, X-CORS-Decision header, NGINX_CORS_MODE behavior.
|
||||
- Coverage: prioritize core branches in _origin_matches, preflight (OPTIONS), and after_request logic.
|
||||
|
||||
## Commit & Pull Request Guidelines
|
||||
- No Git history found here; use Conventional Commits (e.g., feat: add CORS diagnosis endpoint).
|
||||
- PRs should include: purpose, concise summary, screenshots or curl examples for HTTP changes, and any config/env notes.
|
||||
- Link related issues; keep PRs focused and under ~300 changed lines when possible.
|
||||
|
||||
## Security & Configuration Tips
|
||||
- Do NOT hardcode secrets. Move DB credentials in export_risk_json.py to env vars and load via os.getenv() or a .env file.
|
||||
- CORS env vars supported by middleware: ALLOWED_ORIGINS, CORS_STRICT, CORS_DEBUG, NGINX_CORS_MODE, CORS_MAX_AGE, CORS_EXPOSE_HEADERS.
|
||||
- Validate inputs from the DB export; avoid writing outside the repo.
|
||||
|
|
@ -0,0 +1,86 @@
|
|||
# LawRisk 检索接口文档
|
||||
|
||||
- Base URL: https://YOUR_HOST
|
||||
- 路径: /fs-ai-asistant/api/workflow/lawrisk
|
||||
- 方法: POST(推荐), GET(便捷调试)
|
||||
- 鉴权: 无(如需可在网关/反向代理层添加)
|
||||
- CORS: 已启用(复用 smart_cors_middleware.py)
|
||||
|
||||
## 请求格式
|
||||
- Content-Type: application/x-www-form-urlencoded
|
||||
- 表单字段
|
||||
- query | q | text (string,必填): 用户输入的中文问题
|
||||
- mode (string,可选): llm(默认) 或 embed
|
||||
- debug (boolean-like,可选): 1/true/yes/on 视为开启调试
|
||||
- top (int,可选): 调试时返回候选数量(默认 5)
|
||||
|
||||
提示: GET 模式下同名参数通过查询串传递;POST 模式优先解析 x-www-form-urlencoded,若未提供则回退 JSON(application/json)。
|
||||
|
||||
## 响应
|
||||
- 成功 (200)
|
||||
- risk_subject: 数组,每项包含
|
||||
- id (string)
|
||||
- name (string)
|
||||
- permit_ids (string[])
|
||||
- score (number,可选,仅在 debug=1 且 embed 模式或回退时出现)
|
||||
- debug (object,可选,debug=1 时返回)
|
||||
- model (string): 使用模型(如 qwen-plus-latest)
|
||||
- num_subjects (number): 参与检索的主题数量
|
||||
- selected_ids (string[], 仅 llm 模式): LLM 选择的主题 ID 列表
|
||||
- thresholds (object, 仅 embed 模式): 相似度阈值
|
||||
- top_candidates (array, 仅 embed 模式): 前 N 候选及分数
|
||||
- allow_empty (boolean): LLM 允许返回空结果
|
||||
- 失败
|
||||
- 400: { "error": "query is required" }
|
||||
- 500: { "error": "<错误信息>" }
|
||||
|
||||
## 示例
|
||||
POST(推荐,LLM 模式 + 调试)
|
||||
|
||||
curl -s -X POST "http://www.chinaweal.com.cn:8090/fs-ai-asistant/api/workflow/lawrisk" -H "Content-Type: application/x-www-form-urlencoded" -d "query=我要办一家电影院&mode=llm&debug=1&top=5"
|
||||
|
||||
示例响应(命中)
|
||||
|
||||
{
|
||||
"risk_subject": [
|
||||
{"id":"384a...05e7","name":"开办电影院","permit_ids":["04bf...","509b...","..."]}
|
||||
],
|
||||
"debug": {
|
||||
"model": "qwen-plus-latest",
|
||||
"num_subjects": 123,
|
||||
"selected_ids": ["384a...05e7"],
|
||||
"allow_empty": true
|
||||
}
|
||||
}
|
||||
|
||||
示例响应(无匹配,允许空)
|
||||
|
||||
{
|
||||
"risk_subject": [],
|
||||
"debug": {
|
||||
"model": "qwen-plus-latest",
|
||||
"num_subjects": 123,
|
||||
"selected_ids": [],
|
||||
"allow_empty": true
|
||||
}
|
||||
}
|
||||
|
||||
GET 便捷调试
|
||||
|
||||
curl -s "http://www.chinaweal.com.cn:8090T/fs-ai-asistant/api/workflow/lawrisk?query=%E6%88%91%E8%A6%81%E5%8A%9E%E4%B8%80%E5%AE%B6%E7%94%B5%E5%BD%B1%E9%99%A2&mode=llm&debug=1&top=5"
|
||||
|
||||
前端调用示例(fetch)
|
||||
|
||||
fetch("http://www.chinaweal.com.cn:8090/fs-ai-asistant/api/workflow/lawrisk", {
|
||||
method: "POST",
|
||||
headers: {"Content-Type": "application/x-www-form-urlencoded"},
|
||||
body: new URLSearchParams({ query: "我要办一家电影院", mode: "llm", debug: "1", top: "5" })
|
||||
}).then(r => r.json()).then(console.log)
|
||||
|
||||
## 模式说明
|
||||
- llm(默认):将主题清单(id 与名称)传给 Qwen(qwen-plus-latest),由 LLM 选择最相关的一个或多个主题 ID;若判断无匹配,返回空数组。
|
||||
- embed(可选):基于向量相似度检索;阈值可通过环境变量配置(LAWRISK_RETURN_IF_GE、LAWRISK_FALLBACK_GT)。
|
||||
|
||||
## 兼容与跨域
|
||||
- 服务端已启用 CORS,可在 .env 中配置:ALLOWED_ORIGINS、CORS_STRICT、CORS_DEBUG、NGINX_CORS_MODE 等。
|
||||
- 如需鉴权(例如加 Token),建议在网关或反代层统一处理。
|
||||
|
|
@ -0,0 +1,73 @@
|
|||
# Database Schema & Query Guide
|
||||
|
||||
## Overview
|
||||
The `licensing_risks` PostgreSQL database stores municipal licensing risk prompts parsed from Excel workbooks. Each record links regions, themes, permits, and risk narratives so downstream systems can query compliance obligations quickly.
|
||||
|
||||
## Tables
|
||||
| Table | Purpose | Key Columns |
|
||||
| --- | --- | --- |
|
||||
| `regions` | Administrative areas (市级、禅城区等) | `id` (PK), `name` (unique) |
|
||||
| `business_scopes` | Scoped经营范围条目 | `id` (PK), `description` |
|
||||
| `region_scopes` | Region-to-scope mapping | `region_id` → `regions.id`, `scope_id` → `business_scopes.id` |
|
||||
| `themes` | “一照通行”主题事项 | `id` (PK), `name` |
|
||||
| `region_themes` | Region-to-theme mapping | `region_id`, `theme_id` |
|
||||
| `permits` | 许可(备案)事项 | `id` (PK), `name` |
|
||||
| `region_theme_permits` | Region + theme + permit linkage | `region_id`, `theme_id`, `permit_id` |
|
||||
| `risks` | 风险提示主体信息 | `id` (PK), `risk_content`, `legal_basis`, `document_no`, `summary` |
|
||||
| `region_permit_risks` | Region + permit + risk linkage | `region_id`, `permit_id`, `risk_id` |
|
||||
|
||||
All primary keys are integer sequences; unique indexes and `ON CONFLICT DO NOTHING` logic make repeated imports idempotent. Foreign keys should be enforced in the target schema to prevent orphan rows.
|
||||
|
||||
## Query Cheatsheet
|
||||
### 列出所有主题事项(总表)
|
||||
```sql
|
||||
SELECT t.id,
|
||||
t.name AS theme_name,
|
||||
r.name AS region_name
|
||||
FROM themes t
|
||||
JOIN region_themes rt ON rt.theme_id = t.id
|
||||
JOIN regions r ON r.id = rt.region_id
|
||||
ORDER BY r.name, t.name;
|
||||
```
|
||||
|
||||
### 根据主题事项获取许可事项列表
|
||||
Replace `%主题关键词%` with the desired theme name or keyword.
|
||||
```sql
|
||||
SELECT DISTINCT p.id,
|
||||
p.name AS permit_name,
|
||||
r.name AS region_name
|
||||
FROM permits p
|
||||
JOIN region_theme_permits rtp
|
||||
ON rtp.permit_id = p.id
|
||||
JOIN region_themes rt
|
||||
ON rt.region_id = rtp.region_id
|
||||
AND rt.theme_id = rtp.theme_id
|
||||
JOIN themes t ON t.id = rt.theme_id
|
||||
JOIN regions r ON r.id = rt.region_id
|
||||
WHERE t.name ILIKE '%主题关键词%'
|
||||
ORDER BY r.name, permit_name;
|
||||
```
|
||||
|
||||
### 根据许可事项检索风险条目
|
||||
Substitute `'具体许可名称'` with the permit you care about.
|
||||
```sql
|
||||
SELECT r.name AS region_name,
|
||||
p.name AS permit_name,
|
||||
rk.risk_content,
|
||||
rk.legal_basis,
|
||||
rk.document_no,
|
||||
rk.summary
|
||||
FROM region_permit_risks rpr
|
||||
JOIN regions r ON r.id = rpr.region_id
|
||||
JOIN permits p ON p.id = rpr.permit_id
|
||||
JOIN risks rk ON rk.id = rpr.risk_id
|
||||
WHERE p.name = '具体许可名称'
|
||||
ORDER BY r.name, rk.risk_content;
|
||||
```
|
||||
For fuzzy lookups, switch to `WHERE p.name ILIKE '%关键词%'`.
|
||||
|
||||
## Execution Tips
|
||||
- Connect via `psql -h 172.24.240.1 -U postgres -d licensing_risks`.
|
||||
- Export query results with `\copy (SELECT …) TO '/tmp/export.csv' WITH CSV HEADER;`.
|
||||
- Run queries after imports commit; the loaders already wrap operations in transactions.
|
||||
|
||||
|
|
@ -0,0 +1,72 @@
|
|||
我需要你帮我构建一个检索系统,用户会输入问题(中文),期望匹配到对应的事项,输出事项ID 事项名称 许可事项列表,例如:
|
||||
|
||||
用户输入:我要办一家电影院
|
||||
|
||||
输出:
|
||||
|
||||
"risk\_subject": \[
|
||||
|
||||
{
|
||||
|
||||
"id": "384aeb24a23e913268aad33354f705e7",
|
||||
|
||||
"name": "开办电影院",
|
||||
|
||||
"permit\_ids": \[
|
||||
|
||||
"04bfa019634ca1aa0b9f7c783fd85dce",
|
||||
|
||||
"509b2872fc7c38c08f252a2b426fd49f",
|
||||
|
||||
"54a79077-bd72-4ea9-8bb1-35afc69e2973",
|
||||
|
||||
"709b4718d72229311066e529650b8abf",
|
||||
|
||||
"8d49de002f24d37fcf3663574723e693",
|
||||
|
||||
"8f7c8c613adfbd815a78c1e60ec4330e",
|
||||
|
||||
"a0572119839422e1d11ee8801d6c58b7",
|
||||
|
||||
"fa2f3e05c92297be096b63e25d30bfbe"
|
||||
|
||||
]
|
||||
|
||||
}]
|
||||
|
||||
我希望你能用embedding模型来处理
|
||||
|
||||
首先先把事项名称从risk\_tables\_export.json 中提取出来,然后建立一个fs\_law\_risk数据库,建立表law\_sub用来存放事项向量,以json文件中的ID为主键,保存名称和向量到数据库
|
||||
|
||||
再建立一个表,名为law\_sub\_per,保存主题事项与许可事项的映射关系,需要有主题事项id,许可事项id列表
|
||||
|
||||
设置embedding相似度阈值0.5,大于阈值以上的事项全部返回
|
||||
|
||||
如果检索结果都小于0.5,但大于0.4,返回第一个
|
||||
|
||||
暴露接口/fs-ai-asistant/api/workflow/lawrisk
|
||||
|
||||
跨域问题处理请复用:smart\_cors\_middleware.py,你可以把这个文件移动到合适的目录
|
||||
|
||||
* 你可以使用的postgreSQL:
|
||||
|
||||
- IP :8.138.196.105
|
||||
|
||||
- port:5432
|
||||
|
||||
- user:postgres
|
||||
|
||||
- password:difyai123456
|
||||
|
||||
* API 以及doc参考
|
||||
|
||||
我们应该只需要用同步接口
|
||||
|
||||
- 通用文本向量同步接口API详情:https://help.aliyun.com/zh/model-studio/text-embedding-synchronous-api?spm=a2c4g.11186623.help-menu-2400256.d\_2\_7\_0.693e48233phHX8
|
||||
|
||||
- 通用文本批处理接口API详情:https://help.aliyun.com/zh/model-studio/text-embedding-batch-api?spm=a2c4g.11186623.help-menu-2400256.d\_2\_7\_1.59233560WBHuRz
|
||||
|
||||
- API key:sk-288824ef003e4e02bb963b8b3024b06a
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,134 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from flask import Flask, jsonify, request
|
||||
|
||||
from env_loader import load_env
|
||||
from smart_cors_middleware import init_smart_cors
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from lawrisk_service import (
|
||||
ensure_database,
|
||||
ensure_schema,
|
||||
search_subjects,
|
||||
search_subjects_llm,
|
||||
shortlist_subjects,
|
||||
suggest_questions_from_subjects,
|
||||
suggest_questions_embed,
|
||||
)
|
||||
|
||||
|
||||
def create_app() -> Flask:
|
||||
# Load .env before creating app to make CORS/env configs available
|
||||
load_env()
|
||||
# Ensure DB and schema exist before serving
|
||||
try:
|
||||
ensure_database()
|
||||
ensure_schema()
|
||||
except Exception:
|
||||
# Do not block app start; errors will surface on first request
|
||||
pass
|
||||
app = Flask(__name__)
|
||||
# Enable CORS using existing middleware
|
||||
init_smart_cors(app)
|
||||
|
||||
@app.route("/fs-ai-asistant/api/workflow/lawrisk", methods=["POST", "GET"])
|
||||
def lawrisk_search():
|
||||
if request.method == "GET":
|
||||
query = request.args.get("query") or request.args.get("q") or request.args.get("text")
|
||||
debug_flag = request.args.get("debug") in {"1", "true", "yes", "on"}
|
||||
top_k = request.args.get("top")
|
||||
try:
|
||||
top_k_int = int(top_k) if top_k else 5
|
||||
except Exception:
|
||||
top_k_int = 5
|
||||
mode = (request.args.get("mode") or "llm").lower()
|
||||
else:
|
||||
# Prefer x-www-form-urlencoded; fallback to JSON if provided
|
||||
if request.is_json:
|
||||
payload = request.get_json(silent=True) or {}
|
||||
else:
|
||||
payload = request.form.to_dict(flat=True) if request.form else {}
|
||||
|
||||
query = payload.get("query") or payload.get("q") or payload.get("text")
|
||||
debug_flag = str(payload.get("debug", "")).strip().lower() in {"1", "true", "yes", "on"}
|
||||
try:
|
||||
top_k_int = int(payload.get("top", 5))
|
||||
except Exception:
|
||||
top_k_int = 5
|
||||
mode = str(payload.get("mode", "llm")).lower()
|
||||
|
||||
if not query or not isinstance(query, str):
|
||||
return jsonify({"error": "query is required"}), 400
|
||||
try:
|
||||
t0 = time.time()
|
||||
with ThreadPoolExecutor(max_workers=3) as ex:
|
||||
fut_ret = ex.submit(
|
||||
search_subjects if mode == "embed" else search_subjects_llm,
|
||||
query,
|
||||
debug_flag,
|
||||
top_k_int,
|
||||
)
|
||||
# Use embedding-based question suggestion (falls back internally if not available)
|
||||
fut_qs = ex.submit(suggest_questions_embed, query, max(1, top_k_int))
|
||||
|
||||
result = fut_ret.result()
|
||||
rec_questions = fut_qs.result() or []
|
||||
|
||||
# If debug requested, still log to backend for visibility
|
||||
if debug_flag and isinstance(result, dict) and "debug" in result:
|
||||
dbg = result["debug"]
|
||||
model = dbg.get("model") or "embed"
|
||||
app.logger.info("[LAWRISK-DEBUG] mode=%s", model)
|
||||
|
||||
# Extract risk_subject and optional debug
|
||||
risk_subject = []
|
||||
dbg = {}
|
||||
if isinstance(result, dict):
|
||||
risk_subject = result.get("risk_subject", [])
|
||||
if debug_flag:
|
||||
dbg = result.get("debug", {})
|
||||
|
||||
found = bool(risk_subject)
|
||||
llm_resp = "" if found else "抱歉,无法检索到相关答案"
|
||||
exec_time = int((time.time() - t0) * 1000)
|
||||
|
||||
# rec_questions 已由 embedding 建议生成(内部包含兜底)
|
||||
|
||||
data = {
|
||||
"llmRespond": llm_resp,
|
||||
"lawRisk": "",
|
||||
"questionExtend": rec_questions,
|
||||
"conversationId": "",
|
||||
"messageId": "",
|
||||
"roundNumber": 0,
|
||||
"conversationInfo": {},
|
||||
"knowledgeSources": [],
|
||||
"totalKnowledgeSources": 0,
|
||||
"executionTime": exec_time,
|
||||
"workflowStatus": "ok" if found else "no_match",
|
||||
"executionSteps": [],
|
||||
"costStatistics": {},
|
||||
"workflowTrackingId": "",
|
||||
# extra fields requested
|
||||
"risk_subject": risk_subject,
|
||||
"debug": dbg if debug_flag else {},
|
||||
}
|
||||
resp = {"success": True, "message": "OK", "data": data}
|
||||
return jsonify(resp)
|
||||
except Exception as e:
|
||||
app.logger.exception("lawrisk_search error")
|
||||
return jsonify({"success": False, "message": str(e), "data": {}}), 500
|
||||
|
||||
# Basic health check
|
||||
@app.get("/healthz")
|
||||
def healthz():
|
||||
return jsonify({"status": "ok"})
|
||||
|
||||
return app
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
port = int(os.getenv("PORT", "8000"))
|
||||
app = create_app()
|
||||
app.run(host="0.0.0.0", port=port)
|
||||
|
|
@ -0,0 +1,36 @@
|
|||
"""Minimal .env loader to populate os.environ from a .env file.
|
||||
|
||||
Supports lines of the form KEY=VALUE, optional quotes, and comments (# ...).
|
||||
By default, does not override existing environment variables.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def load_env(path: str = ".env", override: bool = False) -> None:
|
||||
if not os.path.exists(path):
|
||||
return
|
||||
try:
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
for raw in f:
|
||||
line = raw.strip()
|
||||
if not line or line.startswith("#"):
|
||||
continue
|
||||
if "=" not in line:
|
||||
continue
|
||||
key, val = line.split("=", 1)
|
||||
key = key.strip()
|
||||
val = val.strip()
|
||||
# Strip quotes if wrapped
|
||||
if (val.startswith("\"") and val.endswith("\"")) or (
|
||||
val.startswith("'") and val.endswith("'")
|
||||
):
|
||||
val = val[1:-1]
|
||||
if key and (override or key not in os.environ):
|
||||
os.environ[key] = val
|
||||
except Exception:
|
||||
# Fail silently; callers still can rely on existing env
|
||||
return
|
||||
|
||||
|
|
@ -0,0 +1,75 @@
|
|||
import json
|
||||
import os
|
||||
import sys
|
||||
import pg8000
|
||||
from env_loader import load_env
|
||||
|
||||
# Read DB config from environment; provide sensible defaults
|
||||
CONFIG = {
|
||||
'host': os.getenv('PG_HOST', '8.138.196.105'),
|
||||
'port': int(os.getenv('PG_PORT', '5432')),
|
||||
'database': os.getenv('PG_DATABASE', 'fs_law_risk'),
|
||||
'user': os.getenv('PG_USER', 'postgres'),
|
||||
'password': os.getenv('PG_PASSWORD', 'difyai123456'),
|
||||
}
|
||||
|
||||
# Export file path can be overridden via env
|
||||
OUTPUT = os.getenv('RISK_EXPORT_OUTPUT', 'risk_tables_export.json')
|
||||
|
||||
SQLS = {
|
||||
'risk_subject': "SELECT sub_id AS id, sub_name AS name FROM public.risk_subject ORDER BY sub_name;",
|
||||
'risk_permit': "SELECT per_id AS id, per_name AS name FROM public.risk_permit ORDER BY per_name;",
|
||||
'risk_sub_per': "SELECT sub_id, per_id FROM public.risk_sub_per;",
|
||||
}
|
||||
|
||||
def fetch_all(cursor):
|
||||
cols = [d[0] for d in cursor.description]
|
||||
return [dict(zip(cols, row)) for row in cursor.fetchall()]
|
||||
|
||||
def main():
|
||||
load_env()
|
||||
try:
|
||||
conn = pg8000.connect(**CONFIG)
|
||||
except Exception as e:
|
||||
print('DB connect error:', e, file=sys.stderr)
|
||||
sys.exit(2)
|
||||
|
||||
result = {}
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(SQLS['risk_subject'])
|
||||
subjects = fetch_all(cur)
|
||||
cur.execute(SQLS['risk_permit'])
|
||||
permits = fetch_all(cur)
|
||||
cur.execute(SQLS['risk_sub_per'])
|
||||
rels = fetch_all(cur)
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
# Build mapping: subject -> list of permit_ids
|
||||
sub_to_permit_ids = {}
|
||||
for r in rels:
|
||||
sub_id = r['sub_id']
|
||||
per_id = r['per_id']
|
||||
sub_to_permit_ids.setdefault(sub_id, set()).add(per_id)
|
||||
|
||||
# Subjects with aggregated permit_ids
|
||||
subjects_out = []
|
||||
for s in subjects:
|
||||
subjects_out.append({
|
||||
'id': s['id'],
|
||||
'name': s['name'],
|
||||
'permit_ids': sorted(list(sub_to_permit_ids.get(s['id'], [])))
|
||||
})
|
||||
|
||||
# Final JSON: keep full permit catalog (id+name), and subjects contain aggregated permit_ids
|
||||
result['risk_subject'] = subjects_out
|
||||
result['risk_permit'] = permits
|
||||
|
||||
with open(OUTPUT, 'w', encoding='utf-8') as f:
|
||||
json.dump(result, f, ensure_ascii=False, indent=2)
|
||||
print('Exported to', OUTPUT)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
@ -0,0 +1,100 @@
|
|||
"""Ingest subjects and permit mappings into PostgreSQL.
|
||||
|
||||
Reads risk_tables_export.json from repo root, embeds each subject name using
|
||||
Aliyun DashScope embeddings (OpenAI-compatible), and stores into:
|
||||
- law_sub(id TEXT PK, name TEXT, vector JSONB)
|
||||
- law_sub_per(subject_id TEXT PK, permit_ids JSONB)
|
||||
|
||||
Usage:
|
||||
python ingest_lawrisk.py
|
||||
|
||||
Ensure env var DASHSCOPE_API_KEY is set.
|
||||
Optionally set PG_* vars (PG_HOST, PG_PORT, PG_USER, PG_PASSWORD, PG_DATABASE).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
from env_loader import load_env
|
||||
from lawrisk_service import (
|
||||
ensure_database,
|
||||
ensure_schema,
|
||||
EmbeddingClient,
|
||||
upsert_subjects,
|
||||
upsert_subject_permits,
|
||||
upsert_permits,
|
||||
)
|
||||
|
||||
REPO_JSON = os.getenv("LAWRISK_JSON", "risk_tables_export.json")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
# Load .env first so service reads correct env
|
||||
load_env()
|
||||
ensure_database()
|
||||
ensure_schema()
|
||||
|
||||
with open(REPO_JSON, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
subjects = data.get("risk_subject", [])
|
||||
permits = data.get("risk_permit", [])
|
||||
# Prepare embeddings in small batches to avoid large payloads
|
||||
client = EmbeddingClient()
|
||||
|
||||
batched_rows = []
|
||||
BATCH = 32
|
||||
names: List[str] = []
|
||||
metas = [] # (id, name)
|
||||
for s in subjects:
|
||||
sid = s.get("id")
|
||||
name = s.get("name")
|
||||
if not sid or not name:
|
||||
continue
|
||||
names.append(name)
|
||||
metas.append((sid, name))
|
||||
if len(names) >= BATCH:
|
||||
vecs = client.embed_texts(names)
|
||||
for (sid, name), vec in zip(metas, vecs):
|
||||
batched_rows.append((sid, name, vec))
|
||||
names.clear()
|
||||
metas.clear()
|
||||
|
||||
if names:
|
||||
vecs = client.embed_texts(names)
|
||||
for (sid, name), vec in zip(metas, vecs):
|
||||
batched_rows.append((sid, name, vec))
|
||||
|
||||
upsert_subjects(batched_rows)
|
||||
|
||||
# Build subject->permit_ids mapping
|
||||
per_rows = []
|
||||
for s in subjects:
|
||||
sid = s.get("id")
|
||||
pids = s.get("permit_ids", [])
|
||||
if sid and isinstance(pids, list):
|
||||
# ensure strings
|
||||
per_rows.append((sid, [str(x) for x in pids]))
|
||||
|
||||
if per_rows:
|
||||
upsert_subject_permits(per_rows)
|
||||
|
||||
# Upsert permit catalog (id -> name)
|
||||
per_catalog = []
|
||||
for p in permits:
|
||||
pid = p.get("id")
|
||||
pname = p.get("name")
|
||||
if pid and pname:
|
||||
per_catalog.append((pid, pname))
|
||||
if per_catalog:
|
||||
upsert_permits(per_catalog)
|
||||
|
||||
print(
|
||||
f"Ingested {len(batched_rows)} subjects, {len(per_rows)} subject-permit mappings, and {len(per_catalog)} permits into PostgreSQL."
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,568 @@
|
|||
"""
|
||||
LawRisk embedding retrieval service.
|
||||
|
||||
Responsibilities:
|
||||
- DB connection helpers (PostgreSQL via pg8000)
|
||||
- Schema management (fs_law_risk.law_sub, fs_law_risk.law_sub_per)
|
||||
- Embedding client (Aliyun DashScope OpenAI-compatible embeddings API)
|
||||
- Chat client for LLM-based selection (Qwen via OpenAI-compatible /chat/completions)
|
||||
- Search logic: embedding cosine or LLM subject selection
|
||||
|
||||
Env vars used:
|
||||
- PG_HOST, PG_PORT, PG_USER, PG_PASSWORD (PostgreSQL credentials)
|
||||
- PG_DATABASE (defaults to fs_law_risk)
|
||||
- PG_ADMIN_DB (defaults to postgres; used for CREATE DATABASE)
|
||||
- DASHSCOPE_API_KEY (embedding API key)
|
||||
- DASHSCOPE_BASE_URL (defaults to https://dashscope.aliyuncs.com/compatible-mode/v1)
|
||||
- DASHSCOPE_EMBED_MODEL (defaults to text-embedding-v4)
|
||||
- DASHSCOPE_EMBED_DIM (defaults to 1024)
|
||||
- DASHSCOPE_CHAT_MODEL (defaults to qwen-plus-latest)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import ssl
|
||||
import urllib.request
|
||||
import urllib.error
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
import pg8000.dbapi as pg
|
||||
|
||||
|
||||
DEFAULT_DB = "fs_law_risk"
|
||||
DEFAULT_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
EMBED_MODEL = os.getenv("DASHSCOPE_EMBED_MODEL", "text-embedding-v4")
|
||||
EMBED_DIM = int(os.getenv("DASHSCOPE_EMBED_DIM", "1024"))
|
||||
EMBED_MAX_BATCH = max(1, int(os.getenv("DASHSCOPE_MAX_BATCH", "10"))) # DashScope limit <=10
|
||||
CHAT_MODEL = os.getenv("DASHSCOPE_CHAT_MODEL", "qwen-plus-latest")
|
||||
|
||||
# Similarity thresholds (env configurable)
|
||||
RETURN_IF_GE = float(os.getenv("LAWRISK_RETURN_IF_GE", "0.7"))
|
||||
FALLBACK_GT = float(os.getenv("LAWRISK_FALLBACK_GT", "0.4"))
|
||||
# Similarity thresholds (env configurable)
|
||||
RETURN_IF_GE = float(os.getenv("LAWRISK_RETURN_IF_GE", "0.7"))
|
||||
FALLBACK_GT = float(os.getenv("LAWRISK_FALLBACK_GT", "0.4"))
|
||||
|
||||
|
||||
def _pg_conn(database: Optional[str] = None, autocommit: bool = False) -> pg.Connection:
|
||||
host = os.getenv("PG_HOST", "8.138.196.105")
|
||||
port = int(os.getenv("PG_PORT", "5432"))
|
||||
user = os.getenv("PG_USER", "postgres")
|
||||
password = os.getenv("PG_PASSWORD", "difyai123456")
|
||||
dbname = database or os.getenv("PG_DATABASE", DEFAULT_DB)
|
||||
conn = pg.connect(host=host, port=port, user=user, password=password, database=dbname)
|
||||
conn.autocommit = autocommit
|
||||
return conn
|
||||
|
||||
|
||||
def ensure_database(dbname: str = DEFAULT_DB) -> None:
|
||||
# Create database if not exists by connecting to postgres
|
||||
admin_db = os.getenv("PG_ADMIN_DB", "postgres")
|
||||
with _pg_conn(database=admin_db, autocommit=True) as c:
|
||||
cur = c.cursor()
|
||||
cur.execute("SELECT 1 FROM pg_database WHERE datname=%s", (dbname,))
|
||||
if cur.fetchone() is None:
|
||||
cur.execute(f"CREATE DATABASE {dbname}")
|
||||
|
||||
|
||||
def ensure_schema() -> None:
|
||||
with _pg_conn() as c:
|
||||
cur = c.cursor()
|
||||
# Store vectors and permit ids as JSONB for portability
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS law_sub (
|
||||
id TEXT PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
vector JSONB NOT NULL
|
||||
)
|
||||
"""
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS law_sub_per (
|
||||
subject_id TEXT PRIMARY KEY,
|
||||
permit_ids JSONB NOT NULL
|
||||
)
|
||||
"""
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS law_permit (
|
||||
id TEXT PRIMARY KEY,
|
||||
name TEXT NOT NULL
|
||||
)
|
||||
"""
|
||||
)
|
||||
c.commit()
|
||||
|
||||
|
||||
class EmbeddingClient:
|
||||
def __init__(self, api_key: Optional[str] = None, base_url: Optional[str] = None):
|
||||
self.api_key = api_key or os.getenv("DASHSCOPE_API_KEY")
|
||||
self.base_url = base_url or os.getenv("DASHSCOPE_BASE_URL", DEFAULT_BASE_URL)
|
||||
if not self.api_key:
|
||||
raise RuntimeError("DASHSCOPE_API_KEY is not set")
|
||||
|
||||
def embed_texts(self, texts: List[str]) -> List[List[float]]:
|
||||
# sanitize inputs
|
||||
clean_inputs = [str(t) for t in texts if isinstance(t, str) and str(t).strip()]
|
||||
if not clean_inputs:
|
||||
raise ValueError("No valid input texts for embeddings")
|
||||
|
||||
# chunk by provider batch limit and concatenate results to preserve order
|
||||
out: List[List[float]] = []
|
||||
for i in range(0, len(clean_inputs), EMBED_MAX_BATCH):
|
||||
chunk = clean_inputs[i : i + EMBED_MAX_BATCH]
|
||||
out.extend(self._embed_batch(chunk))
|
||||
if len(out) != len(clean_inputs):
|
||||
raise RuntimeError(
|
||||
f"Embedding API returned unexpected result count: got {len(out)}, want {len(clean_inputs)}"
|
||||
)
|
||||
return out
|
||||
|
||||
def _embed_batch(self, texts: List[str]) -> List[List[float]]:
|
||||
url = self.base_url.rstrip("/") + "/embeddings"
|
||||
body = {
|
||||
"model": EMBED_MODEL,
|
||||
"input": texts,
|
||||
"dimensions": EMBED_DIM,
|
||||
"encoding_format": "float",
|
||||
}
|
||||
data = json.dumps(body).encode("utf-8")
|
||||
req = urllib.request.Request(url, data=data, method="POST")
|
||||
req.add_header("Authorization", f"Bearer {self.api_key}")
|
||||
req.add_header("Content-Type", "application/json")
|
||||
ctx = ssl.create_default_context()
|
||||
try:
|
||||
with urllib.request.urlopen(req, context=ctx, timeout=30) as resp:
|
||||
raw = resp.read().decode("utf-8", errors="replace")
|
||||
except urllib.error.HTTPError as e:
|
||||
err_body = e.read().decode("utf-8", errors="replace") if hasattr(e, 'read') else ""
|
||||
raise RuntimeError(
|
||||
f"Embedding API error {e.code}: {err_body or e.reason} | sent={json.dumps(body, ensure_ascii=False)[:500]}"
|
||||
) from e
|
||||
payload = json.loads(raw)
|
||||
out: List[List[float]] = []
|
||||
for item in payload.get("data", []):
|
||||
emb = item.get("embedding")
|
||||
if isinstance(emb, list):
|
||||
out.append([float(x) for x in emb])
|
||||
return out
|
||||
|
||||
def embed_one(self, text: str) -> List[float]:
|
||||
return self.embed_texts([text])[0]
|
||||
|
||||
|
||||
class ChatClient:
|
||||
def __init__(self, api_key: Optional[str] = None, base_url: Optional[str] = None):
|
||||
self.api_key = api_key or os.getenv("DASHSCOPE_API_KEY")
|
||||
self.base_url = base_url or os.getenv("DASHSCOPE_BASE_URL", DEFAULT_BASE_URL)
|
||||
if not self.api_key:
|
||||
raise RuntimeError("DASHSCOPE_API_KEY is not set")
|
||||
|
||||
def chat(self, messages: List[Dict[str, str]], model: Optional[str] = None, temperature: float = 0.2) -> str:
|
||||
url = self.base_url.rstrip("/") + "/chat/completions"
|
||||
body = {
|
||||
"model": model or CHAT_MODEL,
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
}
|
||||
data = json.dumps(body, ensure_ascii=False).encode("utf-8")
|
||||
req = urllib.request.Request(url, data=data, method="POST")
|
||||
req.add_header("Authorization", f"Bearer {self.api_key}")
|
||||
req.add_header("Content-Type", "application/json")
|
||||
ctx = ssl.create_default_context()
|
||||
try:
|
||||
with urllib.request.urlopen(req, context=ctx, timeout=60) as resp:
|
||||
raw = resp.read().decode("utf-8", errors="replace")
|
||||
except urllib.error.HTTPError as e:
|
||||
err_body = e.read().decode("utf-8", errors="replace") if hasattr(e, 'read') else ""
|
||||
raise RuntimeError(
|
||||
f"Chat API error {e.code}: {err_body or e.reason}"
|
||||
) from e
|
||||
payload = json.loads(raw)
|
||||
choices = payload.get("choices", [])
|
||||
if not choices:
|
||||
raise RuntimeError("Chat API returned no choices")
|
||||
msg = choices[0].get("message", {})
|
||||
content = msg.get("content", "")
|
||||
return str(content)
|
||||
|
||||
|
||||
def generate_question_suggestions(query: str, max_q: int = 5) -> List[str]:
|
||||
"""Legacy LLM-based suggestion generator (kept for reference). Not used by default."""
|
||||
try:
|
||||
chat = ChatClient()
|
||||
content = chat.chat([
|
||||
{"role": "system", "content": "你是政务事项问答助手。请输出与主题事项高度相关的精简推荐问题,仅输出 JSON 数组。"},
|
||||
{"role": "user", "content": f"请针对: {query} 给出不超过 {max_q} 条中文推荐问题,仅输出 JSON 数组。"},
|
||||
], model=CHAT_MODEL, temperature=0.3)
|
||||
txt = content.strip()
|
||||
start = txt.find("[")
|
||||
end = txt.rfind("]")
|
||||
arr = json.loads(txt[start : end + 1] if start != -1 and end != -1 and end > start else txt)
|
||||
out: List[str] = []
|
||||
if isinstance(arr, list):
|
||||
for x in arr:
|
||||
if isinstance(x, str) and x.strip():
|
||||
out.append(x.strip())
|
||||
return out[:max_q]
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
|
||||
def _normalize_text(s: str) -> str:
|
||||
return "".join(ch for ch in str(s) if ch.isalnum())
|
||||
|
||||
|
||||
def shortlist_subjects(query: str, k: int = 5) -> List[Tuple[str, str]]:
|
||||
"""Return up to k subjects with highest lexical overlap to query.
|
||||
|
||||
Simple char-level overlap score to keep it deterministic and fast.
|
||||
"""
|
||||
q = set(_normalize_text(query))
|
||||
if not q:
|
||||
q = set(query)
|
||||
with _pg_conn() as c:
|
||||
cur = c.cursor()
|
||||
cur.execute("SELECT id, name FROM law_sub")
|
||||
subs = [(str(sid), str(name)) for sid, name in cur.fetchall()]
|
||||
scored: List[Tuple[float, Tuple[str, str]]] = []
|
||||
for sid, name in subs:
|
||||
n = set(_normalize_text(name)) or set(name)
|
||||
inter = len(q & n)
|
||||
denom = max(1, len(n))
|
||||
score = inter / denom
|
||||
if inter > 0:
|
||||
scored.append((score, (sid, name)))
|
||||
scored.sort(key=lambda x: x[0], reverse=True)
|
||||
return [row for _s, row in scored[:k]]
|
||||
|
||||
|
||||
def suggest_questions_from_subjects(subject_names: List[str], max_q: int = 5) -> List[str]:
|
||||
"""Return subject names directly (no extra wording)."""
|
||||
out: List[str] = []
|
||||
for nm in subject_names:
|
||||
nm = (nm or "").strip()
|
||||
if nm and nm not in out:
|
||||
out.append(nm)
|
||||
if len(out) >= max_q:
|
||||
break
|
||||
return out
|
||||
|
||||
|
||||
def suggest_questions_embed(query: str, max_q: int = 5) -> List[str]:
|
||||
"""Use embeddings to pick top-N subject names (no added text)."""
|
||||
try:
|
||||
client = EmbeddingClient()
|
||||
qvec = client.embed_one(query)
|
||||
except Exception:
|
||||
# Embedding not available; fallback to lexical shortlist
|
||||
subs = shortlist_subjects(query, max(1, max_q))
|
||||
return [name for _sid, name in subs][:max_q]
|
||||
|
||||
# Load subjects with vectors
|
||||
with _pg_conn() as c:
|
||||
cur = c.cursor()
|
||||
cur.execute("SELECT id, name, vector FROM law_sub")
|
||||
subjects: List[Tuple[str, str, List[float]]] = []
|
||||
for sid, name, vec_json in cur.fetchall():
|
||||
if isinstance(vec_json, str):
|
||||
try:
|
||||
vec = json.loads(vec_json)
|
||||
except Exception:
|
||||
vec = []
|
||||
else:
|
||||
vec = vec_json
|
||||
if isinstance(vec, list) and vec:
|
||||
subjects.append((str(sid), str(name), [float(x) for x in vec]))
|
||||
|
||||
if not subjects:
|
||||
subs = shortlist_subjects(query, max(1, max_q))
|
||||
return [name for _sid, name in subs][:max_q]
|
||||
|
||||
scored: List[Tuple[float, Tuple[str, str]]] = []
|
||||
for sid, name, vec in subjects:
|
||||
s = _cosine(qvec, vec)
|
||||
scored.append((s, (sid, name)))
|
||||
scored.sort(key=lambda x: x[0], reverse=True)
|
||||
|
||||
# Take top subjects (more than max_q to allow templating to fill up to max_q)
|
||||
top_subjects = [nm for _score, (_sid, nm) in scored[: max_q]]
|
||||
return top_subjects
|
||||
|
||||
def _cosine(a: List[float], b: List[float]) -> float:
|
||||
if not a or not b or len(a) != len(b):
|
||||
return 0.0
|
||||
dot = 0.0
|
||||
na = 0.0
|
||||
nb = 0.0
|
||||
for x, y in zip(a, b):
|
||||
dot += x * y
|
||||
na += x * x
|
||||
nb += y * y
|
||||
if na == 0.0 or nb == 0.0:
|
||||
return 0.0
|
||||
return dot / math.sqrt(na * nb)
|
||||
|
||||
|
||||
def upsert_subjects(
|
||||
rows: Iterable[Tuple[str, str, List[float]]]
|
||||
) -> None:
|
||||
"""Upsert subjects into law_sub."""
|
||||
with _pg_conn() as c:
|
||||
cur = c.cursor()
|
||||
for sid, name, vec in rows:
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO law_sub (id, name, vector)
|
||||
VALUES (%s, %s, %s::jsonb)
|
||||
ON CONFLICT (id) DO UPDATE SET name=EXCLUDED.name, vector=EXCLUDED.vector
|
||||
""",
|
||||
(sid, name, json.dumps(vec)),
|
||||
)
|
||||
c.commit()
|
||||
|
||||
|
||||
def upsert_subject_permits(rows: Iterable[Tuple[str, List[str]]]) -> None:
|
||||
with _pg_conn() as c:
|
||||
cur = c.cursor()
|
||||
for sid, permit_ids in rows:
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO law_sub_per (subject_id, permit_ids)
|
||||
VALUES (%s, %s::jsonb)
|
||||
ON CONFLICT (subject_id) DO UPDATE SET permit_ids=EXCLUDED.permit_ids
|
||||
""",
|
||||
(sid, json.dumps(permit_ids)),
|
||||
)
|
||||
c.commit()
|
||||
|
||||
|
||||
def upsert_permits(rows: Iterable[Tuple[str, str]]) -> None:
|
||||
"""Upsert permit catalog into law_permit (id -> name)."""
|
||||
with _pg_conn() as c:
|
||||
cur = c.cursor()
|
||||
for pid, name in rows:
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO law_permit (id, name)
|
||||
VALUES (%s, %s)
|
||||
ON CONFLICT (id) DO UPDATE SET name=EXCLUDED.name
|
||||
""",
|
||||
(pid, name),
|
||||
)
|
||||
c.commit()
|
||||
|
||||
|
||||
def search_subjects(query: str, return_debug: bool = False, top_k_debug: int = 5) -> Dict[str, Any]:
|
||||
"""Search by embedding similarity, return JSON object compliant with PRD.
|
||||
|
||||
Thresholds:
|
||||
- return all with score >= 0.5
|
||||
- if none >= 0.5 but max > 0.4, return the single best one
|
||||
"""
|
||||
client = EmbeddingClient()
|
||||
qvec = client.embed_one(query)
|
||||
|
||||
# load all subjects
|
||||
with _pg_conn() as c:
|
||||
cur = c.cursor()
|
||||
cur.execute("SELECT id, name, vector FROM law_sub")
|
||||
subs: List[Tuple[str, str, List[float]]] = []
|
||||
for sid, name, vec_json in cur.fetchall():
|
||||
# vec_json may come back as Python list or JSON string depending on driver version
|
||||
if isinstance(vec_json, str):
|
||||
try:
|
||||
vec = json.loads(vec_json)
|
||||
except Exception:
|
||||
vec = []
|
||||
else:
|
||||
vec = vec_json
|
||||
subs.append((str(sid), str(name), list(vec) if isinstance(vec, list) else []))
|
||||
|
||||
# Build permit lookup
|
||||
cur.execute("SELECT subject_id, permit_ids FROM law_sub_per")
|
||||
per_map: Dict[str, List[str]] = {}
|
||||
for sid, pids in cur.fetchall():
|
||||
# pids may be list or JSON string
|
||||
if isinstance(pids, str):
|
||||
try:
|
||||
p_list = json.loads(pids)
|
||||
except Exception:
|
||||
p_list = []
|
||||
else:
|
||||
p_list = list(pids) if isinstance(pids, list) else []
|
||||
per_map[str(sid)] = [str(x) for x in p_list]
|
||||
|
||||
scored: List[Tuple[float, Tuple[str, str, List[float]]]] = []
|
||||
for row in subs:
|
||||
score = _cosine(qvec, row[2])
|
||||
scored.append((score, row))
|
||||
scored.sort(key=lambda x: x[0], reverse=True)
|
||||
|
||||
# Build permit name lookup
|
||||
permit_name: Dict[str, str] = {}
|
||||
try:
|
||||
with _pg_conn() as c2:
|
||||
cur2 = c2.cursor()
|
||||
cur2.execute("SELECT id, name FROM law_permit")
|
||||
for pid, pname in cur2.fetchall():
|
||||
permit_name[str(pid)] = str(pname)
|
||||
except Exception:
|
||||
# If table missing or query fails, leave map empty; upstream should seed via ingest
|
||||
permit_name = {}
|
||||
|
||||
results: List[Dict[str, Any]] = []
|
||||
for score, (sid, name, _vec) in scored:
|
||||
if score >= RETURN_IF_GE:
|
||||
item = {
|
||||
"id": sid,
|
||||
"name": name,
|
||||
# Build permit map: name -> id
|
||||
"permit": {permit_name.get(pid, ""): pid for pid in per_map.get(sid, []) if permit_name.get(pid)},
|
||||
}
|
||||
if return_debug:
|
||||
item["score"] = round(float(score), 6)
|
||||
results.append(item)
|
||||
|
||||
if not results and scored and scored[0][0] > FALLBACK_GT:
|
||||
sid, name, _ = scored[0][1]
|
||||
best_score = scored[0][0]
|
||||
item = {
|
||||
"id": sid,
|
||||
"name": name,
|
||||
"permit": {permit_name.get(pid, ""): pid for pid in per_map.get(sid, []) if permit_name.get(pid)},
|
||||
}
|
||||
if return_debug:
|
||||
item["score"] = round(float(best_score), 6)
|
||||
results = [item]
|
||||
|
||||
out: Dict[str, Any] = {"risk_subject": results}
|
||||
if return_debug:
|
||||
decision = (
|
||||
"returned_ge_threshold" if results
|
||||
else "returned_top_fallback" if (scored and scored[0][0] > FALLBACK_GT)
|
||||
else "no_match_below_fallback"
|
||||
)
|
||||
top_list = []
|
||||
for s, (sid, name, _v) in scored[: max(0, top_k_debug) or 5]:
|
||||
top_list.append({"id": sid, "name": name, "score": round(float(s), 6)})
|
||||
out["debug"] = {
|
||||
"query": query,
|
||||
"qvec_dim": len(qvec),
|
||||
"thresholds": {"return_if_ge": RETURN_IF_GE, "fallback_gt": FALLBACK_GT},
|
||||
"num_subjects": len(subs),
|
||||
"max_score": round(float(scored[0][0]), 6) if scored else 0.0,
|
||||
"top_candidates": top_list,
|
||||
"decision": decision,
|
||||
}
|
||||
return out
|
||||
|
||||
|
||||
def search_subjects_llm(query: str, return_debug: bool = False, top_k_debug: int = 5) -> Dict[str, Any]:
|
||||
"""Use LLM to pick one or more subject IDs from the catalog by instruction.
|
||||
|
||||
Steps:
|
||||
- Load subject id+name list from DB
|
||||
- Ask LLM (Qwen) to select at least one subject id from the list that best matches the user query
|
||||
- Map selected ids to full entries (name+permit_ids) and return
|
||||
"""
|
||||
# Load catalog
|
||||
with _pg_conn() as c:
|
||||
cur = c.cursor()
|
||||
cur.execute("SELECT id, name FROM law_sub")
|
||||
subjects = [(str(sid), str(name)) for sid, name in cur.fetchall()]
|
||||
cur.execute("SELECT subject_id, permit_ids FROM law_sub_per")
|
||||
per_map: Dict[str, List[str]] = {}
|
||||
for sid, pids in cur.fetchall():
|
||||
if isinstance(pids, str):
|
||||
try:
|
||||
p_list = json.loads(pids)
|
||||
except Exception:
|
||||
p_list = []
|
||||
else:
|
||||
p_list = list(pids) if isinstance(pids, list) else []
|
||||
per_map[str(sid)] = [str(x) for x in p_list]
|
||||
|
||||
# Build concise subject list block: id | name per line
|
||||
# Keep within reasonable token limits; if too long, truncate and rely on LLM suggestion quality.
|
||||
lines = [f"{sid}\t{name}" for sid, name in subjects]
|
||||
subjects_block = "\n".join(lines)
|
||||
|
||||
system_msg = (
|
||||
"你是政务事项检索助手。根据用户的中文查询,从给定的主题事项清单中选择最相关的主题事项。"
|
||||
"只允许从清单中选择,不能编造。若没有足够相关的主题,请返回空数组 []."
|
||||
"始终以 JSON 数组返回所选主题事项的 id 列表,例如: [\"id1\", \"id2\"]."
|
||||
)
|
||||
user_msg = (
|
||||
f"用户问题: {query}\n\n"
|
||||
f"主题事项清单(格式: id<TAB>名称):\n{subjects_block}\n\n"
|
||||
"请仅输出 JSON 数组 (仅数组本身)。若无匹配请输出 []."
|
||||
)
|
||||
|
||||
chat = ChatClient()
|
||||
content = chat.chat([
|
||||
{"role": "system", "content": system_msg},
|
||||
{"role": "user", "content": user_msg},
|
||||
], model=CHAT_MODEL, temperature=0.2)
|
||||
|
||||
# Try parsing as JSON array of strings; robustly extract if wrapped text exists
|
||||
selected_ids: List[str] = []
|
||||
try:
|
||||
txt = content.strip()
|
||||
start = txt.find("[")
|
||||
end = txt.rfind("]")
|
||||
if start != -1 and end != -1 and end > start:
|
||||
arr = json.loads(txt[start : end + 1])
|
||||
else:
|
||||
arr = json.loads(txt)
|
||||
if isinstance(arr, list):
|
||||
for x in arr:
|
||||
if isinstance(x, str):
|
||||
selected_ids.append(x)
|
||||
elif isinstance(x, dict) and "id" in x and isinstance(x["id"], str):
|
||||
selected_ids.append(x["id"])
|
||||
except Exception:
|
||||
selected_ids = []
|
||||
|
||||
# Deduplicate and keep only ids that exist
|
||||
id_set = {sid for sid, _ in subjects}
|
||||
chosen = []
|
||||
for sid in selected_ids:
|
||||
if sid in id_set and sid not in chosen:
|
||||
chosen.append(sid)
|
||||
# Allow empty result when nothing is relevant
|
||||
|
||||
# Load permit names
|
||||
permit_name: Dict[str, str] = {}
|
||||
try:
|
||||
with _pg_conn() as c2:
|
||||
cur2 = c2.cursor()
|
||||
cur2.execute("SELECT id, name FROM law_permit")
|
||||
for pid, pname in cur2.fetchall():
|
||||
permit_name[str(pid)] = str(pname)
|
||||
except Exception:
|
||||
permit_name = {}
|
||||
|
||||
results = []
|
||||
name_map = {sid: name for sid, name in subjects}
|
||||
for sid in chosen:
|
||||
results.append({
|
||||
"id": sid,
|
||||
"name": name_map.get(sid, ""),
|
||||
"permit": {permit_name.get(pid, ""): pid for pid in per_map.get(sid, []) if permit_name.get(pid)},
|
||||
})
|
||||
|
||||
out: Dict[str, Any] = {"risk_subject": results}
|
||||
if return_debug:
|
||||
out["debug"] = {
|
||||
"model": CHAT_MODEL,
|
||||
"num_subjects": len(subjects),
|
||||
"selected_ids": chosen,
|
||||
"allow_empty": True,
|
||||
}
|
||||
return out
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,231 @@
|
|||
import os
|
||||
from typing import Iterable, Tuple
|
||||
|
||||
from flask import Flask, request, Response, jsonify
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
def _get_bool(name: str, default: bool = False) -> bool:
|
||||
v = os.getenv(name)
|
||||
if v is None:
|
||||
return default
|
||||
return str(v).strip().lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
def _parse_allowed(origins_raw: str | None) -> list[str]:
|
||||
if not origins_raw:
|
||||
return ["*"]
|
||||
parts = [p.strip() for p in origins_raw.split(",") if p.strip()]
|
||||
return parts or ["*"]
|
||||
|
||||
|
||||
def _origin_matches(origin: str, allowed: Iterable[str], strict: bool) -> tuple[bool, str | None]:
|
||||
# strict=True: only exact match items that are not wildcards
|
||||
# strict=False: allow simple suffix matches: .example.com or *.example.com
|
||||
for pat in allowed:
|
||||
if pat == "*":
|
||||
if strict:
|
||||
# strict mode ignores wildcard
|
||||
continue
|
||||
return True, "*"
|
||||
if origin == pat:
|
||||
return True, pat
|
||||
if not strict:
|
||||
if pat.startswith("*."):
|
||||
suf = pat[1:] # .example.com
|
||||
if origin.endswith(suf):
|
||||
return True, pat
|
||||
if pat.startswith("."):
|
||||
if origin.endswith(pat):
|
||||
return True, pat
|
||||
return False, None
|
||||
|
||||
|
||||
def init_smart_cors(app: Flask) -> None:
|
||||
"""Attach Smart CORS handlers to a Flask app.
|
||||
|
||||
Env vars:
|
||||
- ALLOWED_ORIGINS: comma-separated, e.g. https://a.com,https://b.com or *
|
||||
- CORS_STRICT: true/false (default false)
|
||||
- CORS_DEBUG: true/false
|
||||
- NGINX_CORS_MODE: true/false (when Nginx sets Allow-Origin; app only supplements others)
|
||||
- CORS_MAX_AGE: seconds for preflight caching (default 86400)
|
||||
- CORS_EXPOSE_HEADERS: override exposed headers list
|
||||
"""
|
||||
|
||||
allowed = _parse_allowed(os.getenv("ALLOWED_ORIGINS"))
|
||||
|
||||
|
||||
# 添加前端实际使用的域名
|
||||
frontend_origins = [
|
||||
"http://chinaweal.com.cn:8090",
|
||||
"http://www.chinaweal.com.cn:8090",
|
||||
"https://chinaweal.com.cn",
|
||||
"https://www.chinaweal.com.cn",
|
||||
"http://172.22.80.130:8000", # 从日志中看到的referer
|
||||
]
|
||||
|
||||
# 添加默认的本地开发源
|
||||
default_origins = [
|
||||
"http://localhost:3000",
|
||||
"http://localhost:8000",
|
||||
"http://localhost:8081",
|
||||
"http://127.0.0.1:3000",
|
||||
"http://127.0.0.1:8000",
|
||||
"http://127.0.0.1:8081"
|
||||
]
|
||||
|
||||
# 合并并去重
|
||||
all_origins = list(set(allowed + frontend_origins + default_origins))
|
||||
strict = _get_bool("CORS_STRICT", False)
|
||||
debug = _get_bool("CORS_DEBUG", False)
|
||||
nginx_mode = _get_bool("NGINX_CORS_MODE", False)
|
||||
max_age = os.getenv("CORS_MAX_AGE", "86400")
|
||||
expose_headers = os.getenv(
|
||||
"CORS_EXPOSE_HEADERS",
|
||||
"Content-Length, Content-Type, Authorization, X-Request-Id",
|
||||
)
|
||||
|
||||
def _is_proxy_request() -> bool:
|
||||
hdrs = request.headers
|
||||
return any(h in hdrs for h in ("X-Forwarded-For", "X-Forwarded-Proto", "X-Real-IP"))
|
||||
|
||||
def _log(level: str, msg: str, **fields):
|
||||
if not debug and level == "DEBUG":
|
||||
return
|
||||
logger = app.logger
|
||||
kv = " ".join(f"{k}={v}" for k, v in fields.items()) if fields else ""
|
||||
logger.log(
|
||||
20 if level == "INFO" else 10 if level == "DEBUG" else 30,
|
||||
f"[CORS] {msg} {kv}".rstrip(),
|
||||
)
|
||||
|
||||
@app.before_request
|
||||
def _handle_preflight(): # type: ignore
|
||||
if request.method != "OPTIONS":
|
||||
return None
|
||||
origin = request.headers.get("Origin", "")
|
||||
acrm = request.headers.get("Access-Control-Request-Method", "")
|
||||
acrh = request.headers.get("Access-Control-Request-Headers", "")
|
||||
|
||||
allowed_ok, matched = _origin_matches(origin, all_origins, strict) if origin else (False, None)
|
||||
|
||||
# Build a minimal preflight response
|
||||
resp = Response(status=204)
|
||||
|
||||
# When Nginx handles Allow-Origin, do not override it
|
||||
if not nginx_mode:
|
||||
if allowed_ok:
|
||||
if matched == "*":
|
||||
# If behind proxy (likely Nginx will add AO), skip to avoid duplicates
|
||||
if not _is_proxy_request():
|
||||
resp.headers["Access-Control-Allow-Origin"] = "*"
|
||||
else:
|
||||
_log("DEBUG", "proxy-skip-allow-origin(preflight)")
|
||||
else:
|
||||
resp.headers["Access-Control-Allow-Origin"] = origin
|
||||
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
||||
resp.headers.add("Vary", "Origin")
|
||||
# Always supplement the rest
|
||||
resp.headers["Access-Control-Allow-Methods"] = acrm or "GET, POST, OPTIONS"
|
||||
if acrh:
|
||||
resp.headers["Access-Control-Allow-Headers"] = acrh
|
||||
else:
|
||||
resp.headers["Access-Control-Allow-Headers"] = (
|
||||
"Content-Type, Authorization, X-Requested-With, X-Request-Id"
|
||||
)
|
||||
resp.headers["Access-Control-Max-Age"] = max_age
|
||||
resp.headers["Access-Control-Expose-Headers"] = expose_headers
|
||||
|
||||
decision = "deny"
|
||||
if not origin:
|
||||
decision = "no-origin"
|
||||
elif allowed_ok:
|
||||
decision = "allow-*" if matched == "*" else "allow-origin"
|
||||
resp.headers["X-CORS-Decision"] = f"preflight; {decision}; nginx_mode={nginx_mode}"
|
||||
|
||||
_log(
|
||||
"INFO",
|
||||
"preflight",
|
||||
path=request.path,
|
||||
origin=origin,
|
||||
request_method=acrm or "",
|
||||
request_headers=acrh or "",
|
||||
allowed=allowed_ok,
|
||||
matched=matched or "",
|
||||
nginx_mode=nginx_mode,
|
||||
decision=decision,
|
||||
)
|
||||
return resp
|
||||
|
||||
@app.after_request
|
||||
def _add_cors_headers(response: Response): # type: ignore
|
||||
origin = request.headers.get("Origin", "")
|
||||
if not origin:
|
||||
return response
|
||||
|
||||
allowed_ok, matched = _origin_matches(origin, all_origins, strict)
|
||||
|
||||
# When Nginx sets Allow-Origin, we only supplement others to avoid duplicates
|
||||
if not nginx_mode:
|
||||
if allowed_ok:
|
||||
if matched == "*":
|
||||
# Do not set credentials when wildcard is used; if proxied, skip to avoid dup
|
||||
if not _is_proxy_request():
|
||||
response.headers.setdefault("Access-Control-Allow-Origin", "*")
|
||||
else:
|
||||
_log("DEBUG", "proxy-skip-allow-origin(response)")
|
||||
else:
|
||||
response.headers.setdefault("Access-Control-Allow-Origin", origin)
|
||||
response.headers.setdefault("Access-Control-Allow-Credentials", "true")
|
||||
# Ensure caches vary by Origin when dynamic
|
||||
vary_val = response.headers.get("Vary")
|
||||
if not vary_val:
|
||||
response.headers["Vary"] = "Origin"
|
||||
elif "Origin" not in vary_val:
|
||||
response.headers["Vary"] = f"{vary_val}, Origin"
|
||||
|
||||
# Expose common headers (safe to always include)
|
||||
response.headers.setdefault("Access-Control-Expose-Headers", expose_headers)
|
||||
|
||||
# Add decision header and log
|
||||
decision = "deny"
|
||||
if allowed_ok:
|
||||
decision = "allow-*" if matched == "*" else "allow-origin"
|
||||
response.headers.setdefault(
|
||||
"X-CORS-Decision",
|
||||
f"response; {decision}; nginx_mode={nginx_mode}",
|
||||
)
|
||||
_log(
|
||||
"INFO",
|
||||
"response",
|
||||
path=request.path,
|
||||
method=request.method,
|
||||
origin=origin,
|
||||
allowed=allowed_ok,
|
||||
matched=matched or "",
|
||||
nginx_mode=nginx_mode,
|
||||
decision=decision,
|
||||
)
|
||||
return response
|
||||
|
||||
# Diagnosis endpoint to aid verification and debugging
|
||||
@app.get("/api/cors-diagnosis")
|
||||
def cors_diagnosis(): # type: ignore
|
||||
origin = request.headers.get("Origin", "")
|
||||
ok, matched = _origin_matches(origin, all_origins, strict) if origin else (False, None)
|
||||
data = {
|
||||
"nginx_mode": nginx_mode,
|
||||
"detected_proxy": _is_proxy_request(),
|
||||
"allowed_origins": all_origins,
|
||||
"strict": strict,
|
||||
"request_origin": origin,
|
||||
"matched_rule": matched,
|
||||
"origin_allowed": ok,
|
||||
"notes": (
|
||||
"Nginx handles Allow-Origin; app supplements other headers"
|
||||
if nginx_mode
|
||||
else "App sets CORS headers end-to-end"
|
||||
),
|
||||
}
|
||||
return jsonify(data)
|
||||
Loading…
Reference in New Issue