536 lines
19 KiB
Python
536 lines
19 KiB
Python
from __future__ import annotations
|
||
|
||
import json
|
||
import os
|
||
import re
|
||
from collections import OrderedDict
|
||
from datetime import datetime
|
||
from typing import Any, Dict, List, Optional, Tuple
|
||
|
||
import pg8000.dbapi as pg
|
||
|
||
# Separate configuration so legacy fs_law_risk integration keeps using PG_*
LIC_DEFAULT_DB = "licensing_risks"


# Patterns used by _format_summary_markdown (applied in the order listed there).
#
# Clause separators and enumeration brackets accept both the ASCII and the
# full-width form of each character. The full-width characters are spelled as
# escapes so they cannot be silently folded to ASCII by an editor or by Unicode
# normalisation; that would turn the literal full-width bracket into a bare
# regex group and make the enumeration rules fire on any numeral that follows
# punctuation.
#   \uff1b = full-width semicolon   \uff1a = full-width colon
#   \uff0c = full-width comma       \uff08 / \uff09 = full-width parentheses
ARTICLE_HEADING_RE = re.compile(r"(?m)^(第[一二三四五六七八九十百零0-9]+条)")
ARTICLE_TOKEN_RE = re.compile(r"(?<!\*)(第[一二三四五六七八九十百零0-9]+条)(?!\*)")
ARTICLE_NEWLINE_RE = re.compile(r"(?<!^)(?<!\n)(\*\*第[一二三四五六七八九十百零0-9]+条\*\*)")
# group(1): separator, group(2): Chinese numeral inside the brackets
CN_ENUM_INLINE_RE = re.compile(
    r"([;\uff1b:\uff1a。.])[ \t]*[(\uff08]([一二三四五六七八九十百零]+)[)\uff09]"
)
# group(1): Chinese numeral inside brackets at the start of a line
CN_ENUM_LINE_RE = re.compile(r"(?m)^\s*[(\uff08]([一二三四五六七八九十百零]+)[)\uff09]")
# group(1): separator, group(2): "N." item marker
ARABIC_ENUM_INLINE_RE = re.compile(r"([;\uff1b:\uff1a。.,\uff0c])[ \t]*(\d+\.)")
# group(1): digits of an "N." item at the start of a line
ARABIC_ENUM_LINE_RE = re.compile(r"(?m)^\s*(\d+)\.")
# group(1): separator, group(2): digits inside the brackets
NESTED_ENUM_INLINE_RE = re.compile(r"([;\uff1b:\uff1a。.])[ \t]*[(\uff08](\d+)[)\uff09]")
# group(1): digits inside brackets at the start of a line
NESTED_ENUM_LINE_RE = re.compile(r"(?m)^\s*[(\uff08](\d+)[)\uff09]")
COLON_NEWLINE_RE = re.compile(r":\s*\n")
TRAILING_SPACE_RE = re.compile(r"[ \t]+\n")
EXTRA_NEWLINES_RE = re.compile(r"\n{3,}")


def _format_summary_markdown(summary: str) -> str:
    """Render a Chinese legal excerpt as Markdown-friendly text.

    Transformations, in order:

    * ``第N条`` article references at the start of a line are bolded.
    * Bracketed Chinese-numeral enumerations (full-width or ASCII brackets)
      become ``- (一)`` bullets, ``N.`` items that follow a separator start a
      new indented line, and bracketed Arabic-numeral items become indented
      ``- (1)`` bullets.
    * Remaining article references are bolded and moved onto their own line.
    * CRLF and lone CR line endings become LF, trailing spaces/tabs are removed
      and runs of blank lines collapse to a single blank line.

    Returns ``""`` for falsy (e.g. empty) or whitespace-only input.
    """
    if not summary:
        return ""

    # Normalise CRLF first, then any remaining lone CR, so every rule below only
    # has to handle LF line endings.
    text = summary.replace("\r\n", "\n").replace("\r", "\n").strip()
    if not text:
        return ""

    # Bold articles at line starts first; ARTICLE_TOKEN_RE further down skips
    # tokens already wrapped in "**", so nothing is bolded twice.
    text = ARTICLE_HEADING_RE.sub(lambda m: f"**{m.group(1)}**", text)

    # Enumerations: each inline rule (item after a separator) runs before its
    # line-start counterpart.
    text = CN_ENUM_INLINE_RE.sub(lambda m: f"{m.group(1)}\n- ({m.group(2)}) ", text)
    text = CN_ENUM_LINE_RE.sub(lambda m: f"- ({m.group(1)}) ", text)
    text = ARABIC_ENUM_INLINE_RE.sub(lambda m: f"{m.group(1)}\n {m.group(2)}", text)
    text = ARABIC_ENUM_LINE_RE.sub(lambda m: f" {m.group(1)}.", text)
    text = NESTED_ENUM_INLINE_RE.sub(lambda m: f"{m.group(1)}\n - ({m.group(2)})", text)
    text = NESTED_ENUM_LINE_RE.sub(lambda m: f" - ({m.group(1)})", text)

    # Bold the remaining (mid-line) article references, then break the line
    # before each bolded reference that is not already at a line start.
    text = ARTICLE_TOKEN_RE.sub(lambda m: f"**{m.group(1)}**", text)
    text = ARTICLE_NEWLINE_RE.sub(lambda m: f"\n{m.group(1)}", text)

    # Whitespace clean-up.
    text = COLON_NEWLINE_RE.sub(":\n", text)
    text = EXTRA_NEWLINES_RE.sub("\n\n", text)
    text = TRAILING_SPACE_RE.sub("\n", text)
    text = re.sub(r"\n\s+\n", "\n\n", text)
    return text.strip()
|
||
|
||
|
||
def _lic_pg_conn(autocommit: bool = False) -> pg.Connection:
    """Open a pg8000 connection to the licensing-risk database.

    Settings are read from ``LIC_PG_*`` environment variables. Port and user
    fall back to the shared ``PG_PORT`` / ``PG_USER`` variables before their
    built-in defaults. Host, password and database only have built-in defaults;
    the database default is ``LIC_DEFAULT_DB``.

    Args:
        autocommit: Value assigned to the new connection's ``autocommit`` attribute.

    Returns:
        The open connection.
    """
    settings = {
        "host": os.getenv("LIC_PG_HOST", "172.24.240.1"),
        "port": int(os.getenv("LIC_PG_PORT", os.getenv("PG_PORT", "5432"))),
        "user": os.getenv("LIC_PG_USER", os.getenv("PG_USER", "postgres")),
        "password": os.getenv("LIC_PG_PASSWORD", ""),
        "database": os.getenv("LIC_PG_DATABASE", LIC_DEFAULT_DB),
    }
    connection = pg.connect(**settings)
    connection.autocommit = autocommit
    return connection
|
||
|
||
|
||
def list_region_theme_options() -> List[Dict[str, str]]:
    """List every region/theme pairing as an option an LLM can choose from.

    Each option carries the region and theme ids (as strings) and names, an
    ``option_id`` of the form ``"<region_id>:<theme_id>"`` and a
    ``"<region> · <theme>"`` display label. Options are ordered by region name,
    then theme name.
    """
    query = """
        SELECT
            rt.region_id,
            r.name AS region_name,
            rt.theme_id,
            t.name AS theme_name
        FROM region_themes rt
        JOIN regions r ON r.id = rt.region_id
        JOIN themes t ON t.id = rt.theme_id
        ORDER BY r.name, t.name
    """
    with _lic_pg_conn() as conn:
        cursor = conn.cursor()
        cursor.execute(query)
        rows = cursor.fetchall()

    options: List[Dict[str, str]] = []
    for region_key, region_label, theme_key, theme_label in rows:
        rid, tid = str(region_key), str(theme_key)
        options.append(
            {
                "option_id": f"{rid}:{tid}",
                "region_id": rid,
                "region_name": str(region_label),
                "theme_id": tid,
                "theme_name": str(theme_label),
                "display_name": f"{region_label} · {theme_label}",
            }
        )
    return options
|
||
|
||
|
||
def load_business_scopes(region_id: str) -> List[Dict[str, str]]:
    """Return the business scopes linked to *region_id*.

    Each entry is ``{"id": ..., "description": ...}`` (both as strings), and the
    entries are ordered by description.
    """
    query = """
        SELECT bs.id, bs.description
        FROM region_scopes rs
        JOIN business_scopes bs ON bs.id = rs.scope_id
        WHERE rs.region_id = %s
        ORDER BY bs.description
    """
    with _lic_pg_conn() as conn:
        cursor = conn.cursor()
        cursor.execute(query, (region_id,))
        rows = cursor.fetchall()
    return [
        {"id": str(scope_key), "description": str(scope_text)}
        for scope_key, scope_text in rows
    ]
|
||
|
||
|
||
def list_permits_for_region(region: str) -> List[Dict[str, str]]:
    """Return the distinct permits offered in a region, ordered by permit name.

    *region* may be the region id (compared as text) or the region name
    (compared case-insensitively). Each entry is ``{"id": ..., "name": ...}``.
    """
    query = """
        SELECT DISTINCT p.id, p.name
        FROM region_theme_permits rtp
        JOIN permits p ON p.id = rtp.permit_id
        JOIN regions r ON r.id = rtp.region_id
        WHERE rtp.region_id::text = %s OR LOWER(r.name) = LOWER(%s)
        ORDER BY p.name
    """
    with _lic_pg_conn() as conn:
        cursor = conn.cursor()
        # The same value is bound twice: once for the id match, once for the name match.
        cursor.execute(query, (region, region))
        rows = cursor.fetchall()
    return [{"id": str(pid), "name": str(label)} for pid, label in rows]
|
||
|
||
|
||
def _load_permit_scopes_for_region(
    conn: pg.Connection, region_id: str, permit_ids: List[str]
) -> Dict[str, List[Dict[str, str]]]:
    """Map each permit id to its business scopes within one region.

    Args:
        conn: Open connection; only read queries are issued on it.
        region_id: Region whose ``region_permit_scopes`` rows are read.
        permit_ids: Permits to report. Every id appears as a key of the result
            (with ``[]`` when it has no scopes); scope rows for other permits
            are ignored.

    Returns:
        ``{permit_id: [{"id": scope_id, "description": text}, ...]}``, with each
        list ordered by scope description. If the ``region_permit_scopes`` table
        does not exist yet (migration not applied), every list is empty.
    """
    scope_map: Dict[str, List[Dict[str, str]]] = {pid: [] for pid in permit_ids}
    if not permit_ids:
        return scope_map

    cur = conn.cursor()
    # Probe for the table rather than catching the undefined-table error.
    # A failed statement aborts the caller's open transaction on this shared
    # connection. pg8000 also reports the SQLSTATE inside the exception args
    # rather than as an ``exc.sqlstate`` attribute. to_regclass() returns NULL,
    # instead of raising, when the name does not resolve.
    cur.execute("SELECT to_regclass('region_permit_scopes')")
    probe = cur.fetchone()
    if probe is None or probe[0] is None:
        return scope_map

    sql = """
        SELECT rps.permit_id, bs.id, bs.description
        FROM region_permit_scopes rps
        JOIN business_scopes bs ON bs.id = rps.scope_id
        WHERE rps.region_id = %s
        ORDER BY rps.permit_id, bs.description
    """
    cur.execute(sql, (region_id,))
    for permit_id, scope_id, description in cur.fetchall():
        pid = str(permit_id)
        if pid in scope_map:
            scope_map[pid].append({"id": str(scope_id), "description": str(description)})
    return scope_map
|
||
|
||
|
||
def load_permits_and_risks(
    region_id: str, theme_id: str, permit_id: Optional[str] = None
) -> List[Dict[str, object]]:
    """Return the permits of a region-theme pair with their risks, details and scopes.

    Args:
        region_id: Region to look up.
        theme_id: Theme to look up within the region.
        permit_id: If given, only this permit is returned.

    Returns:
        One dict per permit, in the order permits first appear in the query
        (ordered by permit name, then risk content). Keys:
        ``id``, ``name``, ``business_scopes`` (from
        ``_load_permit_scopes_for_region``), ``risks``, and the detail fields
        ``permit_status``, ``subitem_summary``, ``responsible_contact`` and
        ``jurisdiction_scope``.

        Each detail field holds the first value seen for that permit that is
        non-blank once stripped, or ``None`` if there is none. Each risk has
        ``id``, ``risk_content``, ``legal_basis``, ``document_no`` and a
        Markdown-formatted ``summary``. Permits without risks get an empty
        ``risks`` list.
    """
    sql = """
        SELECT
            p.id AS permit_id,
            p.name AS permit_name,
            rk.id AS risk_id,
            rk.risk_content,
            rk.legal_basis,
            rk.document_no,
            rk.summary,
            rpd.permit_status,
            rpd.subitem_summary,
            rpd.responsible_contact,
            rpd.jurisdiction_scope
        FROM region_theme_permits rtp
        JOIN permits p ON p.id = rtp.permit_id
        LEFT JOIN region_permit_risks rpr
            ON rpr.region_id = rtp.region_id
            AND rpr.permit_id = rtp.permit_id
        LEFT JOIN risks rk ON rk.id = rpr.risk_id
        LEFT JOIN region_permit_details rpd
            ON rpd.region_id = rtp.region_id
            AND rpd.permit_id = rtp.permit_id
        WHERE rtp.region_id = %s AND rtp.theme_id = %s
    """
    params: List[Any] = [region_id, theme_id]
    if permit_id is not None:
        sql += " AND rtp.permit_id = %s"
        params.append(permit_id)
    sql += """
        ORDER BY p.name, rk.risk_content
    """

    detail_fields = (
        "permit_status",
        "subitem_summary",
        "responsible_contact",
        "jurisdiction_scope",
    )
    by_permit: Dict[str, Dict[str, Any]] = {}

    with _lic_pg_conn() as conn:
        cursor = conn.cursor()
        cursor.execute(sql, tuple(params))
        for row in cursor.fetchall():
            # row_* names keep the ``permit_id`` argument from being shadowed.
            (
                row_permit_id,
                row_permit_name,
                risk_id,
                risk_content,
                legal_basis,
                document_no,
                summary,
                permit_status,
                subitem_summary,
                responsible_contact,
                jurisdiction_scope,
            ) = row
            key = str(row_permit_id)
            entry = by_permit.get(key)
            if entry is None:
                entry = {
                    "id": key,
                    "name": str(row_permit_name),
                    "business_scopes": [],
                    "risks": [],
                    **dict.fromkeys(detail_fields),
                }
                by_permit[key] = entry

            # A permit may span several rows (one per risk); keep the first
            # non-blank value of each detail column.
            raw_details = (permit_status, subitem_summary, responsible_contact, jurisdiction_scope)
            for field, raw in zip(detail_fields, raw_details):
                if entry[field] is None and raw:
                    entry[field] = raw.strip() or None

            # LEFT JOIN rows without a matching risk only contribute details.
            if risk_id is None:
                continue
            entry["risks"].append(
                {
                    "id": str(risk_id),
                    "risk_content": risk_content or "",
                    "legal_basis": legal_basis or "",
                    "document_no": document_no or "",
                    "summary": _format_summary_markdown(summary or ""),
                }
            )

        scope_map = _load_permit_scopes_for_region(conn, region_id, list(by_permit))

    for key, entry in by_permit.items():
        entry["business_scopes"] = scope_map.get(key, [])
    return list(by_permit.values())
|
||
|
||
|
||
def find_permit_contexts_by_name(permit_name: str) -> List[Dict[str, str]]:
    """Find the region/theme contexts of permits whose name equals *permit_name*.

    Returns one dict per (region, permit) pair with keys ``region_id``,
    ``region_name``, ``theme_id``, ``theme_name``, ``permit_id`` and
    ``permit_name``, ordered by region name then theme name. When a permit is
    attached to several themes in the same region, only the first theme in that
    order is reported. A falsy *permit_name* returns ``[]`` without querying.
    """
    if not permit_name:
        return []

    query = """
        SELECT
            rtp.region_id,
            r.name AS region_name,
            rtp.theme_id,
            t.name AS theme_name,
            p.id AS permit_id,
            p.name AS permit_name
        FROM region_theme_permits rtp
        JOIN permits p ON p.id = rtp.permit_id
        JOIN regions r ON r.id = rtp.region_id
        JOIN themes t ON t.id = rtp.theme_id
        WHERE p.name = %s
        ORDER BY r.name, t.name
    """
    with _lic_pg_conn() as conn:
        cursor = conn.cursor()
        cursor.execute(query, (permit_name,))
        rows = cursor.fetchall()

    contexts: OrderedDict[Tuple[str, str], Dict[str, str]] = OrderedDict()
    for region_key, region_label, theme_key, theme_label, permit_key, canonical_name in rows:
        rid = str(region_key)
        pid = str(permit_key)
        if (rid, pid) not in contexts:
            contexts[(rid, pid)] = {
                "region_id": rid,
                "region_name": str(region_label),
                "theme_id": str(theme_key),
                "theme_name": str(theme_label),
                "permit_id": pid,
                "permit_name": str(canonical_name),
            }
    return list(contexts.values())
|
||
|
||
|
||
def load_theme_payload(region_id: str, theme_id: str) -> Dict[str, object]:
    """Bundle a region-theme selection together with all of its permits.

    Returns:
        ``{"region": {"id", "name"}, "theme": {"id", "name"}, "permits": [...]}``,
        where ``permits`` is the result of ``load_permits_and_risks``.

    Raises:
        ValueError: The region and theme are not linked in ``region_themes``.
    """
    lookup_sql = """
        SELECT r.id, r.name, t.id, t.name
        FROM regions r
        JOIN region_themes rt ON rt.region_id = r.id
        JOIN themes t ON t.id = rt.theme_id
        WHERE r.id = %s AND t.id = %s
        LIMIT 1
    """
    with _lic_pg_conn() as conn:
        cursor = conn.cursor()
        cursor.execute(lookup_sql, (region_id, theme_id))
        match = cursor.fetchone()
    if not match:
        raise ValueError("Region/theme combination not found")

    region_key, region_label, theme_key, theme_label = match
    permit_list = load_permits_and_risks(region_id, theme_id)
    return {
        "region": {"id": str(region_key), "name": str(region_label)},
        "theme": {"id": str(theme_key), "name": str(theme_label)},
        "permits": permit_list,
    }
|
||
|
||
|
||
def _get_checkpoints_dir() -> str:
    """Return the ``data/checkpoints`` directory, creating it if needed.

    The base directory is found by stripping three path components from this
    module's ``__file__``, i.e. the grandparent of the directory containing it.
    """
    project_root = __file__
    for _ in range(3):
        project_root = os.path.dirname(project_root)
    checkpoints_dir = os.path.join(project_root, "data", "checkpoints")
    os.makedirs(checkpoints_dir, exist_ok=True)
    return checkpoints_dir
|
||
|
||
|
||
def _get_all_tables() -> List[str]:
    """Return the names of all base tables in the ``public`` schema, sorted by name."""
    query = """
        SELECT table_name
        FROM information_schema.tables
        WHERE table_schema = 'public'
            AND table_type = 'BASE TABLE'
        ORDER BY table_name
    """
    with _lic_pg_conn() as conn:
        cursor = conn.cursor()
        cursor.execute(query)
        rows = cursor.fetchall()
    return [table for (table,) in rows]
|
||
|
||
|
||
def _backup_table(conn: pg.Connection, table_name: str) -> Tuple[List[Dict[str, Any]], int]:
    """Read every row of *table_name* into JSON-friendly dicts.

    Every row dict contains every column. SQL NULL is kept as ``None`` (JSON
    ``null``) rather than dropped, so a restore that derives its column list
    from the rows sees all columns. UUIDs, Decimals and date/time values
    (anything with ``isoformat``) are stored as their ``str()`` form; other
    values are passed through unchanged.

    Args:
        conn: Open connection to read from.
        table_name: Table to dump; it is double-quoted, so it must be the exact
            stored name (as returned by ``information_schema``).

    Returns:
        ``(rows, row_count)``.
    """
    import uuid
    from decimal import Decimal

    quoted_table = '"' + table_name.replace('"', '""') + '"'
    cur = conn.cursor()
    cur.execute(f"SELECT * FROM {quoted_table}")
    rows = cur.fetchall()
    colnames = [desc[0] for desc in cur.description]

    def to_jsonable(value: Any) -> Any:
        # UUID and Decimal have no isoformat(); json cannot encode any of these.
        if isinstance(value, (uuid.UUID, Decimal)) or hasattr(value, "isoformat"):
            return str(value)
        return value

    data = [{col: to_jsonable(val) for col, val in zip(colnames, row)} for row in rows]
    return data, len(data)
|
||
|
||
|
||
def _restore_table(conn: pg.Connection, table_name: str, data: List[Dict[str, Any]]) -> int:
    """Replace the contents of *table_name* with the backed-up rows in *data*.

    Existing rows are always deleted, even when *data* is empty, so a table that
    was empty when the checkpoint was taken is emptied again. Every row in
    *data* is then re-inserted. The INSERT column list is the union of the keys
    of all rows, in first-seen order; a row lacking a column inserts NULL for
    it. This also handles checkpoints whose rows omit NULL columns.

    Transactions:
        * ``conn.autocommit`` False: nothing is committed here. The caller owns
          the transaction, so a multi-table restore can commit or roll back as
          a unit.
        * ``conn.autocommit`` True: the delete and inserts run in their own
          transaction, committed on success and rolled back on error, and
          autocommit is switched back on afterwards.

    Table and column names are double-quoted because they come from a
    checkpoint file on disk.

    Returns:
        Number of rows inserted.
    """

    def quote_ident(name: str) -> str:
        return '"' + name.replace('"', '""') + '"'

    rows = data or []
    owns_transaction = bool(conn.autocommit)
    if owns_transaction:
        conn.autocommit = False
    try:
        cur = conn.cursor()
        target = quote_ident(table_name)
        cur.execute(f"DELETE FROM {target}")
        if rows:
            columns = list(dict.fromkeys(col for row in rows for col in row))
            if columns:
                column_sql = ", ".join(quote_ident(col) for col in columns)
                placeholders = ", ".join(["%s"] * len(columns))
                insert_sql = f"INSERT INTO {target} ({column_sql}) VALUES ({placeholders})"
                for row in rows:
                    cur.execute(insert_sql, [row.get(col) for col in columns])
            else:
                # Rows with no stored columns: "INSERT ... () VALUES ()" is invalid SQL.
                for _ in rows:
                    cur.execute(f"INSERT INTO {target} DEFAULT VALUES")
        if owns_transaction:
            conn.commit()
    except Exception:
        if owns_transaction:
            conn.rollback()
        raise
    finally:
        if owns_transaction:
            conn.autocommit = True
    return len(rows)
|
||
|
||
|
||
def create_checkpoint(description: str = "") -> Dict[str, Any]:
    """Back up every table in the ``public`` schema to a JSON checkpoint file.

    The file ``checkpoint_<YYYYmmdd_HHMMSS>.json`` in the checkpoints directory
    holds the id, timestamp and description, the rows of each table, and the
    per-table and total row counts. Two checkpoints created within the same
    second share an id, and the later one replaces the earlier file.

    Returns:
        The checkpoint's metadata: ``checkpoint_id``, ``timestamp``,
        ``description``, ``total_rows`` and ``table_counts``.
    """
    import uuid
    from decimal import Decimal

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    checkpoint_id = f"checkpoint_{timestamp}"
    tables = _get_all_tables()

    table_rows: Dict[str, List[Dict[str, Any]]] = {}
    table_counts: Dict[str, int] = {}
    total_rows = 0
    with _lic_pg_conn() as conn:
        # Read all tables from a single snapshot so the checkpoint is
        # self-consistent even if other sessions write while it is taken.
        # This must be the first statement of the transaction.
        # NOTE(review): relies on pg8000 implicitly opening a transaction on a
        # manual-commit connection. If it does not, PostgreSQL only warns, and
        # each SELECT falls back to its own snapshot.
        conn.cursor().execute("SET TRANSACTION ISOLATION LEVEL REPEATABLE READ, READ ONLY")
        for table in tables:
            data, row_count = _backup_table(conn, table)
            table_rows[table] = data
            table_counts[table] = row_count
            total_rows += row_count

    checkpoint_data = {
        "checkpoint_id": checkpoint_id,
        "timestamp": timestamp,
        "description": description,
        "tables": table_rows,
        "table_counts": table_counts,
        "total_rows": total_rows,
    }

    def json_serializer(obj: Any) -> str:
        """Stringify values json cannot encode natively (UUID, Decimal, date/time)."""
        if isinstance(obj, (uuid.UUID, Decimal)) or hasattr(obj, "isoformat"):
            return str(obj)
        raise TypeError(f"Object of type {type(obj)} is not JSON serializable")

    checkpoint_file = os.path.join(_get_checkpoints_dir(), f"{checkpoint_id}.json")
    # Write to a side file and rename it into place, so a crash mid-write never
    # leaves a truncated "<id>.json". The side file's name does not end in
    # ".json".
    tmp_file = checkpoint_file + ".tmp"
    try:
        with open(tmp_file, "w", encoding="utf-8") as f:
            json.dump(checkpoint_data, f, ensure_ascii=False, indent=2, default=json_serializer)
        os.replace(tmp_file, checkpoint_file)
    except BaseException:
        try:
            os.remove(tmp_file)
        except OSError:
            pass
        raise

    return {
        "checkpoint_id": checkpoint_id,
        "timestamp": timestamp,
        "description": description,
        "total_rows": total_rows,
        "table_counts": table_counts,
    }
|
||
|
||
|
||
def list_checkpoints() -> List[Dict[str, Any]]:
    """Summarise every checkpoint file, newest first.

    Each ``*.json`` file in the checkpoints directory contributes its id,
    timestamp, description, total_rows, table_counts and filename (not the
    table rows). Files that cannot be read or lack the required keys are
    reported on stdout and skipped.
    """
    folder = _get_checkpoints_dir()
    found: List[Dict[str, Any]] = []
    if not os.path.exists(folder):
        return found

    json_names = [name for name in os.listdir(folder) if name.endswith(".json")]
    for name in json_names:
        try:
            with open(os.path.join(folder, name), "r", encoding="utf-8") as handle:
                payload = json.load(handle)
            found.append(
                {
                    "checkpoint_id": payload["checkpoint_id"],
                    "timestamp": payload["timestamp"],
                    "description": payload.get("description", ""),
                    "total_rows": payload.get("total_rows", 0),
                    "table_counts": payload.get("table_counts", {}),
                    "filename": name,
                }
            )
        except Exception as exc:
            print(f"Error reading checkpoint {name}: {exc}")

    # Timestamps are "YYYYmmdd_HHMMSS" strings, so lexical order is chronological.
    found.sort(key=lambda item: item["timestamp"], reverse=True)
    return found
|
||
|
||
|
||
def restore_checkpoint(checkpoint_id: str) -> Dict[str, Any]:
    """Overwrite database tables with the data saved in a checkpoint.

    Destructive: each table in the checkpoint file is passed to
    ``_restore_table``. Everything runs on one manual-commit connection. It is
    committed once at the end, or rolled back with the error re-raised if any
    table fails. Whether earlier tables survive such a rollback depends on
    whether ``_restore_table`` commits on its own.

    Returns:
        ``checkpoint_id``, ``tables_restored``, ``total_rows_restored`` and
        ``table_details`` (rows restored per table).

    Raises:
        ValueError: *checkpoint_id* is empty or contains a path component, or
            no checkpoint file with that id exists.
    """
    # Reject ids such as "../x" so only files inside the checkpoint directory are read.
    if not checkpoint_id or os.path.basename(checkpoint_id) != checkpoint_id:
        raise ValueError(f"Invalid checkpoint id: {checkpoint_id!r}")

    checkpoint_file = os.path.join(_get_checkpoints_dir(), f"{checkpoint_id}.json")
    if not os.path.exists(checkpoint_file):
        raise ValueError(f"Checkpoint {checkpoint_id} not found")

    with open(checkpoint_file, "r", encoding="utf-8") as f:
        checkpoint_data = json.load(f)

    tables = checkpoint_data.get("tables", {})
    restore_summary: Dict[str, Any] = {
        "checkpoint_id": checkpoint_id,
        "tables_restored": 0,
        "total_rows_restored": 0,
        "table_details": {},
    }

    with _lic_pg_conn(autocommit=False) as conn:
        try:
            for table_name, data in tables.items():
                rows_restored = _restore_table(conn, table_name, data)
                restore_summary["tables_restored"] += 1
                restore_summary["total_rows_restored"] += rows_restored
                restore_summary["table_details"][table_name] = rows_restored
            conn.commit()
        except Exception:
            conn.rollback()
            raise

    return restore_summary
|
||
|
||
|
||
def delete_checkpoint(checkpoint_id: str) -> bool:
    """Delete a checkpoint's JSON file.

    Returns:
        True if the file existed and was removed. False if there is no such
        checkpoint, or if *checkpoint_id* is empty or contains a path component
        (e.g. ``../x``). As a result, only files directly inside the checkpoint
        directory can ever be removed.
    """
    if not checkpoint_id or os.path.basename(checkpoint_id) != checkpoint_id:
        return False

    checkpoint_file = os.path.join(_get_checkpoints_dir(), f"{checkpoint_id}.json")
    # Remove directly instead of checking existence first, so a concurrent
    # delete cannot turn into an unexpected FileNotFoundError.
    try:
        os.remove(checkpoint_file)
    except FileNotFoundError:
        return False
    return True
|