fs-lawrisk/lawrisk/services/licensing_repo.py

3839 lines
136 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from __future__ import annotations
import json
import logging
import os
import re
from collections import OrderedDict, defaultdict
import hashlib
from datetime import datetime, date
from decimal import Decimal
from io import BytesIO
import threading
import time
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
import uuid
import pg8000.dbapi as pg
from openpyxl import load_workbook
from openpyxl.utils.exceptions import InvalidFileException
# Configure the module-level logger.
# A stream handler is attached only when the logger has none yet, so the
# handler is not added twice if this setup code runs again.
logger = logging.getLogger(__name__)
if not logger.handlers:
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter("[%(levelname)s] %(name)s: %(message)s"))
    logger.addHandler(handler)
# NOTE(review): the original indentation of the next two statements was lost;
# confirm whether they belong inside the ``if`` above.
logger.setLevel(logging.INFO)
# Records are not passed on to ancestor loggers' handlers.
logger.propagate = False
# Separate configuration so legacy fs_law_risk integration keeps using PG_*
LIC_DEFAULT_DB = "licensing_risks"
# Unique sentinel object; presumably marks "argument not supplied" where None
# is a legitimate value -- its users are outside this part of the file.
_UNSET = object()

# --- Patterns used by _format_summary_markdown -----------------------------
# NOTE(review): several character classes below list ":" twice (e.g. "[;::。.]"),
# which suggests full-width punctuation was lost in transcription -- confirm
# against the repository copy.
# Article marker ("第…条") at the start of a line.
ARTICLE_HEADING_RE = re.compile(r"(?m)^(第[一二三四五六七八九十百零0-9]+条)")
# Article marker that is not directly adjacent to "*" (i.e. not already bolded).
ARTICLE_TOKEN_RE = re.compile(r"(?<!\*)(第[一二三四五六七八九十百零0-9]+条)(?!\*)")
# Bolded article marker that does not directly follow a newline.
ARTICLE_NEWLINE_RE = re.compile(r"(?<!^)(?<!\n)(\*\*第[一二三四五六七八九十百零0-9]+条\*\*)")
# Chinese numeral following sentence punctuation on the same line.
CN_ENUM_INLINE_RE = re.compile(r"([;::。.])[ \t]*([一二三四五六七八九十百零]+)")
# Chinese numeral at the start of a line (leading whitespace allowed).
CN_ENUM_LINE_RE = re.compile(r"(?m)^\s*([一二三四五六七八九十百零]+)")
# "N." item following punctuation on the same line.
ARABIC_ENUM_INLINE_RE = re.compile(r"([;::。.,,])[ \t]*(\d+\.)")
# "N." item at the start of a line.
ARABIC_ENUM_LINE_RE = re.compile(r"(?m)^\s*(\d+)\.")
# Bare digits following punctuation on the same line.
NESTED_ENUM_INLINE_RE = re.compile(r"([;::。.])[ \t]*(\d+)")
# Bare digits at the start of a line.
NESTED_ENUM_LINE_RE = re.compile(r"(?m)^\s*(\d+)")
# Whitespace run ending in a newline.
# NOTE(review): despite the name, the pattern contains no colon -- verify.
COLON_NEWLINE_RE = re.compile(r"\s*\n")
# Spaces/tabs immediately before a newline.
TRAILING_SPACE_RE = re.compile(r"[ \t]+\n")
# Three or more consecutive newlines.
EXTRA_NEWLINES_RE = re.compile(r"\n{3,}")

# --- Multi-value cell splitting --------------------------------------------
# Separators: ASCII and full-width comma/semicolon, plus line breaks.
TEXT_SPLIT_PATTERN = re.compile(r"[,\uff0c;\uff1b\n\r]+")
# Same separators plus the Chinese enumeration comma "、".
TEXT_SPLIT_PATTERN_WITH_DUNHAO = re.compile(r"[,\uff0c;\uff1b、\n\r]+")

# --- Import sessions -------------------------------------------------------
# Age (seconds) after which an in-memory import session is considered expired.
PERMIT_IMPORT_TTL_SECONDS = 1800
MAX_PERMIT_FILE_SIZE_BYTES = 500 * 1024  # 500 KB limit for uploaded Excel files
# In-memory store of import sessions keyed by session id; accessed under
# _PERMIT_IMPORT_LOCK by the functions in this module.
_PERMIT_IMPORT_SESSIONS: Dict[str, Dict[str, Any]] = {}
_PERMIT_IMPORT_LOCK = threading.Lock()
# Exact-match header aliases: canonical field name -> set of accepted header
# labels. _resolve_import_header checks these first, against the normalised
# (whitespace-stripped, lower-cased) cell text.
_IMPORT_HEADER_ALIASES: Dict[str, Set[str]] = {
    "theme_names": {
        "主题",
        "主题事项",
        "主题名称",
        "所属主题",
        "事项主题",
        "主题分类",
    },
    "permit_name": {
        "许可事项",
        "许可名称",
        "事项名称",
        "事项",
        "事项全称",
        "事项标题",
    },
    "risk_content": {
        "风险提示",
        "风险内容",
        "风险点",
        "风险描述",
        "风险信息",
        "风险要点",
    },
    "legal_basis": {
        "法律依据",
        "主要法律依据",
        "依据内容",
        "法规依据",
    },
    "document_no": {
        "依据文号",
        "文号",
        "法规文号",
        "编号",
    },
    "summary": {
        "风险说明",
        "摘要",
        "备注",
        "风险摘要",
        "补充说明",
    },
    "permit_status": {
        "许可状态",
        "事项状态",
        "审批状态",
        "状态",
        "许可情况",
        "事项情况",
    },
    "subitem_summary": {
        "子项说明",
        "子项摘要",
        "风险小项",
        "子项概述",
        "子项备注",
        "许可(备案)事项子项",
    },
    "responsible_contact": {
        "责任部门",
        "责任单位",
        "责任主体",
        "主管部门",
        "负责部门",
        "负责部门及联系方式",
    },
    "jurisdiction_scope": {
        "适用范围",
        "管辖范围",
        "权限划分",
        "适用区域",
        "适用地区",
    },
    "scope_text": {
        "经营范围",
        "业务范围",
        "范围说明",
        "经营项目",
    },
    "subitem_text": {
        "子项",
        "办理子项",
        "事项子项",
        "细化子项",
    },
}
# Substring fallback used when no exact alias matches. The list is scanned in
# order and the first entry with any keyword contained in the label wins, so
# earlier entries take precedence (e.g. "permit_status" before "permit_name").
_IMPORT_HEADER_KEYWORDS: List[Tuple[str, Tuple[str, ...]]] = [
    ("theme_names", ("主题",)),
    ("permit_status", ("情况", "状态")),
    ("permit_name", ("许可", "事项")),
    ("risk_content", ("风险",)),
    ("legal_basis", ("依据",)),
    ("document_no", ("文号", "编号")),
    ("summary", ("备注", "摘要")),
    ("responsible_contact", ("责任", "主管")),
    ("jurisdiction_scope", ("范围", "区域")),
]
# Cached result of the permit_sources table check: None = not checked yet,
# True/False = known state. Written under _PERMIT_SOURCES_TABLE_LOCK.
_PERMIT_SOURCES_TABLE_PRESENT: Optional[bool] = None
_PERMIT_SOURCES_TABLE_LOCK = threading.Lock()
# Flag and lock for the permit file schema; presumably used the same way as the
# permit_sources pair above -- the code using them is outside this view.
_PERMIT_FILE_SCHEMA_READY: Optional[bool] = None
_PERMIT_FILE_SCHEMA_LOCK = threading.Lock()
# Canonical region label -> substrings that identify it. Used by
# _canonicalize_region_label; dict order decides which label wins when several
# keywords occur in the same text.
_CANONICAL_REGION_KEYWORDS: Dict[str, Tuple[str, ...]] = {
    "市级": ("市级", "全市", "佛山市本级", "佛山市市级"),
    "禅城区": ("禅城区", "禅城"),
    "南海区": ("南海区", "南海"),
    "顺德区": ("顺德区", "顺德"),
    "三水区": ("三水区", "三水"),
    "高明区": ("高明区", "高明"),
}
def _format_summary_markdown(summary: str) -> str:
    """Reformat a Chinese legal excerpt so it renders well as Markdown.

    Article markers ("第…条") are wrapped in ``**``, enumerations are moved
    onto their own lines, and surplus blank lines and trailing blanks are
    collapsed. Falsy or whitespace-only input yields ``""``.
    """
    if not summary:
        return ""
    body = summary.replace("\r\n", "\n").strip()
    if not body:
        return ""
    # Each pass operates on the output of the previous one, so the order of
    # this table is significant. Templates use \g<n> group references.
    passes = (
        (ARTICLE_HEADING_RE, "**\\g<1>**"),
        (CN_ENUM_INLINE_RE, "\\g<1>\n- \\g<2> "),
        (CN_ENUM_LINE_RE, "- \\g<1> "),
        (ARABIC_ENUM_INLINE_RE, "\\g<1>\n \\g<2>"),
        (ARABIC_ENUM_LINE_RE, " \\g<1>."),
        (NESTED_ENUM_INLINE_RE, "\\g<1>\n - \\g<2>"),
        (NESTED_ENUM_LINE_RE, " - \\g<1>"),
        (ARTICLE_TOKEN_RE, "**\\g<1>**"),
        (ARTICLE_NEWLINE_RE, "\n\\g<1>"),
        (COLON_NEWLINE_RE, "\n"),
        (EXTRA_NEWLINES_RE, "\n\n"),
        (TRAILING_SPACE_RE, "\n"),
        (re.compile(r"\n\s+\n"), "\n\n"),
    )
    for pattern, template in passes:
        body = pattern.sub(template, body)
    return body.strip()
def _clean_text(value: Any) -> str:
    """Coerce a cell value to text and trim surrounding whitespace.

    ``None`` becomes ``""``; non-string values go through ``str()`` first.
    """
    if value is None:
        return ""
    text = value if isinstance(value, str) else str(value)
    return text.strip()
def _normalize_header_label(value: Any) -> str:
    """Return the header text with all whitespace removed and lower-cased.

    Empty or ``None`` input yields ``""``.
    """
    label = _clean_text(value)
    if not label:
        return ""
    # Drop every whitespace run (including the ideographic space U+3000).
    return re.sub(r"[\s\u3000]+", "", label).lower()
def _resolve_import_header(value: Any) -> Optional[str]:
    """Translate an Excel header cell into a canonical field name.

    Exact aliases are tried first; if none matches, the keyword table is
    scanned in order and the first field with a keyword contained in the
    label is returned. Returns ``None`` when nothing matches.
    """
    label = _normalize_header_label(value)
    if not label:
        return None
    exact_match = next(
        (field for field, aliases in _IMPORT_HEADER_ALIASES.items() if label in aliases),
        None,
    )
    if exact_match is not None:
        return exact_match
    return next(
        (
            field
            for field, keywords in _IMPORT_HEADER_KEYWORDS
            if any(keyword in label for keyword in keywords)
        ),
        None,
    )
def _score_import_header(canonical: str, cell_text: str, col_idx: int) -> float:
    """Score a header cell so the best one can be picked among duplicates.

    The base score is the text length; field-specific tokens add a bonus, and
    the column index contributes ``0.1`` per column.
    """
    # field -> ((alternative tokens, bonus), ...); a bonus applies once when
    # any of its alternative tokens occurs in the cell text.
    bonus_table = {
        "risk_content": ((("内容",), 10), (("提示",), 4), (("风险",), 2)),
        "permit_name": ((("事项",), 6), (("名称",), 3), (("许可",), 2)),
        "theme_names": ((("主题",), 4),),
        "permit_status": ((("情况", "状态"), 3),),
        "summary": ((("摘要",), 3),),
    }
    total = float(len(cell_text))
    for tokens, bonus in bonus_table.get(canonical, ()):
        if any(token in cell_text for token in tokens):
            total += bonus
    return total + col_idx * 0.1
def _split_multi_value(value: Any, *, allow_dunhao: bool = False) -> List[str]:
    """Split a multi-value cell on common punctuation separators.

    By default the Chinese enumeration comma ("、") is NOT treated as a
    separator, so legitimate permit names that contain it are not broken
    apart. Callers handling fields that really are "、"-separated (themes,
    business scopes, ...) pass ``allow_dunhao=True``.
    """
    text = _clean_text(value)
    if not text:
        return []
    splitter = TEXT_SPLIT_PATTERN
    if allow_dunhao:
        splitter = TEXT_SPLIT_PATTERN_WITH_DUNHAO
    pieces = (piece.strip() for piece in splitter.split(text))
    return [piece for piece in pieces if piece]
def _clean_empty(value: Any) -> Optional[str]:
    """Return the trimmed text of ``value``, or ``None`` when it is empty."""
    return _clean_text(value) or None
def _canonicalize_region_label(label: str) -> str:
    """Map a free-form label onto a canonical region name when possible.

    The first canonical region (in table order) having a keyword contained in
    the label wins; otherwise the trimmed label itself is returned.
    """
    cleaned = _clean_text(label)
    if not cleaned:
        return ""
    for region, aliases in _CANONICAL_REGION_KEYWORDS.items():
        if any(alias and alias in cleaned for alias in aliases):
            return region
    return cleaned
def _normalize_import_row(
    raw_row: Dict[str, Any],
    sheet_name: str,
    sheet_defaults: Optional[Dict[str, Any]] = None,
) -> Optional[Dict[str, Any]]:
    """Turn one raw Excel row into the canonical import record.

    Values missing from the row fall back to ``sheet_defaults`` (except
    ``risk_content``, ``legal_basis``, ``document_no`` and ``summary``, which
    are row-only). Returns ``None`` when the permit name or the risk content
    is empty.
    """
    defaults = sheet_defaults or {}

    def pick(field: str) -> Any:
        # Row value first; the sheet-level default only when the row is falsy.
        return raw_row.get(field) or defaults.get(field)

    permit_name = _clean_text(pick("permit_name"))
    risk_content = _clean_text(raw_row.get("risk_content"))
    if not (permit_name and risk_content):
        return None

    row_index = raw_row.get("row_index")
    if isinstance(row_index, int):
        row_index = int(row_index)

    record: Dict[str, Any] = {
        "row_index": row_index,
        "sheet_name": sheet_name,
        "permit_name": permit_name,
        "theme_names": _split_multi_value(pick("theme_names"), allow_dunhao=True),
        "risk_content": risk_content,
        "legal_basis": _clean_empty(raw_row.get("legal_basis")),
        "document_no": _clean_empty(raw_row.get("document_no")),
        "summary": _clean_empty(raw_row.get("summary")),
        "permit_status": _clean_empty(pick("permit_status")),
        "subitem_summary": _clean_empty(pick("subitem_summary")),
        "responsible_contact": _clean_empty(pick("responsible_contact")),
        "jurisdiction_scope": _clean_empty(pick("jurisdiction_scope")),
        "scope_descriptions": _split_multi_value(pick("scope_text"), allow_dunhao=True),
        "subitem_names": _split_multi_value(pick("subitem_text"), allow_dunhao=True),
    }
    return record
def _collect_sheet_defaults(
    metadata_rows: List[Tuple[int, Dict[int, Tuple[str, str]], Tuple[Any, ...]]],
    before_row: int,
) -> Dict[str, Any]:
    """Collect label/value defaults from scanned rows above ``before_row``.

    Each metadata row carries the labels recognised in it (1-based column ->
    (canonical field, label text)) together with the raw cell values. For
    every label the value is taken from the cell to its right, or the one
    after that when the first is ``None``. The first non-empty value found
    for a field wins.
    """
    defaults: Dict[str, Any] = {}
    for row_idx, candidate_map, row_values in metadata_rows:
        if row_idx >= before_row:
            continue
        for col_idx, (canonical, _label_text) in candidate_map.items():
            value = None
            # col_idx is 1-based while row_values is 0-based, so
            # row_values[col_idx] is the cell to the right of the label.
            if col_idx < len(row_values):
                value = row_values[col_idx]
            if value is None and (col_idx + 1) < len(row_values):
                value = row_values[col_idx + 1]
            if _clean_text(value) and canonical not in defaults:
                defaults[canonical] = value
    return defaults


def _parse_import_workbook(file_bytes: bytes, filename: str) -> Dict[str, Any]:
    """Parse the uploaded Excel workbook into structured sheet data.

    For every worksheet the first rows (at most 120) are scanned for a header
    row. Rows above the header are treated as label/value metadata that
    provide sheet-level defaults; rows below are normalised into import
    records. Sheets without a usable header or without valid rows are skipped.

    Returns:
        ``{"filename": <basename>, "sheets": {name: {"sheet_name", "rows"}},
        "total_rows": <int>}``.

    Raises:
        ValueError: if the payload is empty, cannot be read as a workbook, or
            contains no importable data.
    """
    if not file_bytes:
        raise ValueError("上传文件为空")
    try:
        workbook = load_workbook(BytesIO(file_bytes), data_only=True)
    except InvalidFileException as exc:
        raise ValueError(f"Excel 文件格式无法识别:{exc}") from exc
    except Exception as exc:
        raise ValueError(f"Excel 解析失败:{exc}") from exc
    sheets: Dict[str, Dict[str, Any]] = {}
    total_rows = 0
    for worksheet in workbook.worksheets:
        sheet_title = worksheet.title or ""
        sheet_name = _clean_text(sheet_title) or f"Sheet{len(sheets) + 1}"
        header_row_index: Optional[int] = None
        header_values: List[Any] = []
        header_map: Dict[int, str] = {}
        resolved_by_name: Dict[str, Tuple[int, int, str]] = {}
        metadata_rows: List[Tuple[int, Dict[int, Tuple[str, str]], Tuple[Any, ...]]] = []
        # Only the first 120 rows are considered when looking for the header.
        max_header_row = min(worksheet.max_row or 0, 120) or 120
        for row_idx, row_values in enumerate(
            worksheet.iter_rows(min_row=1, max_row=max_header_row, values_only=True),
            start=1,
        ):
            if not row_values:
                continue
            if not any(_clean_text(cell) for cell in row_values):
                continue
            first_cell_text = _clean_text(row_values[0]) if len(row_values) else ""
            # A first cell consisting solely of Chinese numerals marks a
            # section row; some labels in its later columns are ignored below.
            is_section_row = bool(first_cell_text) and bool(
                re.fullmatch(r"[一二三四五六七八九十百零]+", first_cell_text)
            )
            row_candidate_map: Dict[str, Tuple[int, str, float]] = {}
            for col_idx, header_cell in enumerate(row_values, start=1):
                cell_text = _clean_text(header_cell)
                if not cell_text:
                    continue
                if len(cell_text) > 60:
                    continue
                canonical = _resolve_import_header(cell_text)
                if not canonical:
                    continue
                if (
                    is_section_row
                    and col_idx >= 3
                    and canonical in {"permit_name", "permit_status", "responsible_contact", "theme_names", "subitem_summary"}
                ):
                    continue
                if len(cell_text) > 35 and canonical in {"permit_name", "risk_content", "summary"}:
                    continue
                score = _score_import_header(canonical, cell_text, col_idx)
                previous = row_candidate_map.get(canonical)
                # Keep the best-scoring cell per field; ties go to the leftmost.
                if not previous or score > previous[2] or (score == previous[2] and col_idx < previous[0]):
                    row_candidate_map[canonical] = (col_idx, cell_text, score)
            candidate_map: Dict[int, Tuple[str, str]] = {
                col_idx: (canonical, cell_text)
                for canonical, (col_idx, cell_text, _score) in row_candidate_map.items()
            }
            row_canonicals: Set[str] = {canonical for canonical, _ in candidate_map.values()}
            metadata_rows.append((row_idx, candidate_map, tuple(row_values)))
            if candidate_map:
                display_entries = []
                for idx, (name, text) in sorted(candidate_map.items()):
                    # Mark truncated previews so the log does not look complete.
                    preview = text[:30] + ("..." if len(text) > 30 else "")
                    display_entries.append(f"{idx}->{name}({preview})")
                logger.info(
                    "[PERMIT-IMPORT] Sheet %s candidate row %d resolved: %s",
                    sheet_name,
                    row_idx,
                    ", ".join(display_entries),
                )
            # Update resolved columns, allowing later rows to refine mapping.
            for col_idx, (canonical, cell_text) in candidate_map.items():
                previous = resolved_by_name.get(canonical)
                if previous and previous[0] != col_idx:
                    header_map.pop(previous[0], None)
                resolved_by_name[canonical] = (col_idx, row_idx, cell_text)
                header_map[col_idx] = canonical
            # A row is the header when it names the risk column together with
            # a basis/number/summary column, or with a permit column while a
            # permit default is already known from the rows above. The defaults
            # are derived from this sheet's own preceding rows, so the check
            # never reads an unassigned or previous-sheet value.
            if "risk_content" in row_canonicals and (
                {"legal_basis", "document_no", "summary"}.intersection(row_canonicals)
                or (
                    "permit_name" in row_canonicals
                    and _collect_sheet_defaults(metadata_rows, row_idx).get("permit_name")
                )
            ):
                header_row_index = row_idx
                header_values = list(row_values)
                break
        if header_row_index is None or not header_map:
            logger.warning(
                "[PERMIT-IMPORT] Sheet %s skipped: 未找到包含许可名称与风险内容的表头",
                sheet_name,
            )
            continue
        sheet_defaults: Dict[str, Any] = _collect_sheet_defaults(metadata_rows, header_row_index)
        # Keep only headers that were identified on the selected header row.
        effective_header_map: Dict[int, str] = {}
        for canonical, (col_idx, row_idx, _cell_text) in resolved_by_name.items():
            if row_idx == header_row_index:
                effective_header_map[col_idx] = canonical
        header_map = dict(sorted(effective_header_map.items()))
        has_risk_column = "risk_content" in header_map.values()
        has_permit_column = "permit_name" in header_map.values()
        has_permit_default = bool(_clean_text(sheet_defaults.get("permit_name")))
        if not has_risk_column or (not has_permit_column and not has_permit_default):
            logger.warning(
                "[PERMIT-IMPORT] Sheet %s skipped: 表头缺少许可名称或风险内容列(行 %d | header_map=%s | permit_default=%s",
                sheet_name,
                header_row_index,
                ", ".join(f"{idx}:{name}" for idx, name in sorted(header_map.items())),
                _clean_text(sheet_defaults.get("permit_name") or ""),
            )
            continue
        logger.info(
            "[PERMIT-IMPORT] Sheet %s header row %d candidates: %s",
            sheet_name,
            header_row_index,
            ", ".join(_clean_text(cell) or "<空>" for cell in header_values),
        )
        logger.info(
            "[PERMIT-IMPORT] Sheet %s resolved headers: %s",
            sheet_name,
            ", ".join(f"{idx}->{name}" for idx, name in sorted(header_map.items())),
        )
        normalized_rows: List[Dict[str, Any]] = []
        for row_idx, row_values in enumerate(
            worksheet.iter_rows(min_row=header_row_index + 1, values_only=True),
            start=header_row_index + 1,
        ):
            raw_row: Dict[str, Any] = {"row_index": row_idx}
            has_data = False
            for col_idx, cell_value in enumerate(row_values, start=1):
                canonical = header_map.get(col_idx)
                if not canonical:
                    continue
                if cell_value is None or (isinstance(cell_value, str) and not cell_value.strip()):
                    continue
                has_data = True
                raw_row[canonical] = cell_value
            if not has_data:
                logger.debug(
                    "[PERMIT-IMPORT] Sheet %s row %d ignored: 空行",
                    sheet_name,
                    row_idx,
                )
                continue
            normalized = _normalize_import_row(raw_row, sheet_name, sheet_defaults)
            if normalized:
                normalized_rows.append(normalized)
            else:
                logger.debug(
                    "[PERMIT-IMPORT] Sheet %s row %d ignored: 缺少许可名称或风险内容",
                    sheet_name,
                    row_idx,
                )
        if not normalized_rows:
            logger.warning(
                "[PERMIT-IMPORT] Sheet %s skipped: 没有有效数据行",
                sheet_name,
            )
            continue
        sheets[sheet_name] = {
            "sheet_name": sheet_name,
            "rows": normalized_rows,
        }
        total_rows += len(normalized_rows)
    if not sheets:
        logger.error(
            "[PERMIT-IMPORT] Workbook %s has no importable sheets (headers missing or rows空)",
            filename,
        )
        raise ValueError("Excel 中未找到可导入的数据")
    return {
        "filename": os.path.basename(filename or ""),
        "sheets": sheets,
        "total_rows": total_rows,
    }
def _cleanup_expired_import_sessions() -> None:
    """Drop import sessions older than ``PERMIT_IMPORT_TTL_SECONDS``.

    A session without a ``created_at`` timestamp is treated as created "now"
    and is therefore kept.
    """
    reference_time = time.time()
    with _PERMIT_IMPORT_LOCK:
        stale_ids = [
            sid
            for sid, payload in _PERMIT_IMPORT_SESSIONS.items()
            if reference_time - payload.get("created_at", reference_time) > PERMIT_IMPORT_TTL_SECONDS
        ]
        for sid in stale_ids:
            _PERMIT_IMPORT_SESSIONS.pop(sid, None)
def start_permit_import_session(
    file_bytes: bytes,
    filename: str,
    *,
    content_type: Optional[str] = None,
    uploaded_by: Optional[str] = None,
) -> Dict[str, Any]:
    """Parse an uploaded workbook and register an in-memory import session.

    Sheets whose titles canonicalise to the same region label are merged. For
    each resulting sheet the matching region (case-insensitive name match) and
    the permits already linked to it are looked up, so the preview can tell
    new permits from duplicates. Nothing is written to the database here.

    Returns:
        A preview dict with the session id, per-sheet summaries, total row
        count, time-to-live and file metadata.

    Raises:
        ValueError: for an empty or oversized upload, or when parsing fails.
    """
    if not file_bytes:
        raise ValueError("上传的文件为空")
    if len(file_bytes) > MAX_PERMIT_FILE_SIZE_BYTES:
        raise ValueError("上传的文件超过 500KB 限制,请拆分或压缩内容后重试")

    parsed = _parse_import_workbook(file_bytes, filename)
    workbook_filename = parsed.get("filename") or os.path.basename(filename or "")

    # Merge parsed sheets under their canonical region label.
    merged_sheets: Dict[str, Dict[str, Any]] = {}
    canonical_by_original: Dict[str, str] = {}
    for raw_name, raw_sheet in parsed.get("sheets", {}).items():
        region_label = _canonicalize_region_label(raw_name) or raw_name
        canonical_by_original[raw_name] = region_label
        if region_label not in merged_sheets:
            merged_sheets[region_label] = {
                "canonical_name": region_label,
                "original_names": [],
                "rows": [],
            }
        merged = merged_sheets[region_label]
        merged["original_names"].append(raw_name)
        merged["rows"].extend(
            {**row, "sheet_name": region_label} for row in raw_sheet.get("rows", [])
        )

    sheet_names = list(merged_sheets)
    sheet_tokens = sorted({name.lower() for name in sheet_names if name})
    logger.info(
        "[PERMIT-IMPORT] Workbook %s parsed sheets: %s",
        workbook_filename or filename,
        ", ".join(sheet_names),
    )
    for raw_name, region_label in canonical_by_original.items():
        if raw_name == region_label:
            continue
        logger.info(
            "[PERMIT-IMPORT] Sheet alias: %s -> %s",
            raw_name,
            region_label,
        )

    # region_lookup: lower-cased region name -> region info.
    # permit_lookup: region id (str) -> {permit name: permit id}.
    region_lookup: Dict[str, Dict[str, Any]] = {}
    permit_lookup: Dict[str, Dict[str, str]] = {}
    if sheet_tokens:
        with _lic_pg_conn() as conn:
            cur = conn.cursor()
            cur.execute(
                """
                SELECT id, name
                FROM regions
                WHERE LOWER(name) = ANY(%s)
                """,
                (sheet_tokens,),
            )
            for db_region_id, db_region_name in cur.fetchall():
                region_lookup[str(db_region_name).lower()] = {
                    "id": str(db_region_id),
                    "uuid": db_region_id,
                    "name": str(db_region_name),
                }
            if region_lookup:
                logger.info(
                    "[PERMIT-IMPORT] Resolved regions: %s",
                    ", ".join(f"{entry['name']}({entry['id']})" for entry in region_lookup.values()),
                )
            region_ids = [entry["uuid"] for entry in region_lookup.values()]
            if region_ids:
                cur.execute(
                    """
                    SELECT DISTINCT rtp.region_id, p.id, p.name
                    FROM region_theme_permits rtp
                    JOIN permits p ON p.id = rtp.permit_id
                    WHERE rtp.region_id = ANY(%s)
                    """,
                    (region_ids,),
                )
                for db_region_id, db_permit_id, db_permit_name in cur.fetchall():
                    per_region = permit_lookup.setdefault(str(db_region_id), {})
                    per_region[str(db_permit_name)] = str(db_permit_id)

    sheet_summaries: List[Dict[str, Any]] = []
    session_sheets: Dict[str, Dict[str, Any]] = {}
    for sheet_name, sheet_data in merged_sheets.items():
        rows: List[Dict[str, Any]] = sheet_data.get("rows", [])
        permit_groups: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
        for row in rows:
            permit_groups[row["permit_name"]].append(row)

        region_info = region_lookup.get(sheet_name.lower())
        region_id_str = region_info["id"] if region_info else None
        existing_permits = permit_lookup.get(region_id_str or "", {})
        duplicate_permits = sorted(name for name in permit_groups if name in existing_permits)
        new_permits = sorted(name for name in permit_groups if name not in existing_permits)
        region_missing = region_info is None
        original_names = sheet_data.get("original_names", [sheet_name])

        logger.info(
            "[PERMIT-IMPORT] Sheet %s summary: permits=%d, duplicates=%d, new=%d, missing_region=%s",
            sheet_name,
            len(permit_groups),
            len(duplicate_permits),
            len(new_permits),
            region_missing,
        )
        session_sheets[sheet_name] = {
            "sheet_name": sheet_name,
            "region_name": sheet_name,
            "region_id": region_id_str,
            "rows": rows,
            "permit_groups": dict(permit_groups),
            "existing_permits": dict(existing_permits),
            "duplicate_permits": duplicate_permits,
            "new_permits": new_permits,
            "original_sheet_names": original_names,
        }
        sheet_summaries.append(
            {
                "sheet_name": sheet_name,
                "region_name": region_info["name"] if region_info else sheet_name,
                "region_id": region_id_str or "",
                "row_count": len(rows),
                "permit_count": len(permit_groups),
                "risk_count": len(rows),
                "duplicate_permits": duplicate_permits,
                "new_permits": new_permits,
                "missing_region": region_missing,
                "original_sheet_names": original_names,
            }
        )

    _cleanup_expired_import_sessions()
    session_id = str(uuid.uuid4())
    effective_content_type = content_type or "application/octet-stream"
    upload_size = len(file_bytes)
    session_payload = {
        "id": session_id,
        "filename": workbook_filename,
        "created_at": time.time(),
        "sheets": session_sheets,
        "file_bytes": bytes(file_bytes),
        "file_size": upload_size,
        "content_type": effective_content_type,
        "uploaded_by": uploaded_by,
    }
    with _PERMIT_IMPORT_LOCK:
        _PERMIT_IMPORT_SESSIONS[session_id] = session_payload
    return {
        "session_id": session_id,
        "filename": workbook_filename,
        "sheet_summaries": sheet_summaries,
        "total_rows": parsed.get("total_rows", 0),
        "expires_in": PERMIT_IMPORT_TTL_SECONDS,
        "file_size": upload_size,
        "content_type": effective_content_type,
    }
def _ensure_region(conn: pg.Connection, region_name: str) -> str:
    """Upsert a region by name and return its id as a string.

    Raises:
        ValueError: if the name is empty after trimming.
    """
    cleaned_name = _clean_text(region_name)
    if not cleaned_name:
        raise ValueError("地区名称不能为空")
    # The no-op DO UPDATE makes RETURNING yield the id for existing rows too.
    upsert_sql = """
        INSERT INTO regions (name)
        VALUES (%s)
        ON CONFLICT (name)
        DO UPDATE SET name = EXCLUDED.name
        RETURNING id
        """
    cursor = conn.cursor()
    cursor.execute(upsert_sql, (cleaned_name,))
    returned_row = cursor.fetchone()
    return str(returned_row[0])
def _ensure_theme(conn: pg.Connection, theme_name: str) -> str:
    """Upsert a theme by name and return its id as a string.

    An empty name falls back to the placeholder theme "不涉及".
    """
    effective_name = _clean_text(theme_name) or "不涉及"
    # The no-op DO UPDATE makes RETURNING yield the id for existing rows too.
    upsert_sql = """
        INSERT INTO themes (name)
        VALUES (%s)
        ON CONFLICT (name)
        DO UPDATE SET name = EXCLUDED.name
        RETURNING id
        """
    cursor = conn.cursor()
    cursor.execute(upsert_sql, (effective_name,))
    returned_row = cursor.fetchone()
    return str(returned_row[0])
def _ensure_permit(conn: pg.Connection, permit_name: str) -> str:
    """Upsert a permit by name and return its id as a string.

    Raises:
        ValueError: if the name is empty after trimming.
    """
    cleaned_name = _clean_text(permit_name)
    if not cleaned_name:
        raise ValueError("许可名称不能为空")
    # The no-op DO UPDATE makes RETURNING yield the id for existing rows too.
    upsert_sql = """
        INSERT INTO permits (name)
        VALUES (%s)
        ON CONFLICT (name)
        DO UPDATE SET name = EXCLUDED.name
        RETURNING id
        """
    cursor = conn.cursor()
    cursor.execute(upsert_sql, (cleaned_name,))
    returned_row = cursor.fetchone()
    return str(returned_row[0])
def _ensure_business_scope(conn: pg.Connection, description: str) -> Optional[str]:
    """Upsert a business scope by description and return its id as a string.

    Returns ``None`` without touching the database when the description is
    empty after trimming.
    """
    cleaned_description = _clean_text(description)
    if not cleaned_description:
        return None
    # The no-op DO UPDATE makes RETURNING yield the id for existing rows too.
    upsert_sql = """
        INSERT INTO business_scopes (description)
        VALUES (%s)
        ON CONFLICT (description)
        DO UPDATE SET description = EXCLUDED.description
        RETURNING id
        """
    cursor = conn.cursor()
    cursor.execute(upsert_sql, (cleaned_description,))
    returned_row = cursor.fetchone()
    return str(returned_row[0])
def _ensure_permit_subitem(conn: pg.Connection, description: str) -> Optional[str]:
    """Upsert a permit sub-item by description and return its id as a string.

    Returns ``None`` without touching the database when the description is
    empty after trimming.
    """
    cleaned_description = _clean_text(description)
    if not cleaned_description:
        return None
    # The no-op DO UPDATE makes RETURNING yield the id for existing rows too.
    upsert_sql = """
        INSERT INTO permit_subitems (description)
        VALUES (%s)
        ON CONFLICT (description)
        DO UPDATE SET description = EXCLUDED.description
        RETURNING id
        """
    cursor = conn.cursor()
    cursor.execute(upsert_sql, (cleaned_description,))
    returned_row = cursor.fetchone()
    return str(returned_row[0])
def _ensure_risk(
    conn: pg.Connection,
    *,
    risk_content: str,
    legal_basis: Optional[str],
    document_no: Optional[str],
    summary: Optional[str],
) -> str:
    """Upsert a risk row keyed on all four text columns and return its id.

    NOTE(review): PostgreSQL unique constraints treat NULLs as distinct unless
    declared NULLS NOT DISTINCT, so rows with a NULL legal_basis, document_no
    or summary may never reach the conflict branch -- confirm against the
    schema of ``risks``.
    """
    upsert_sql = """
        INSERT INTO risks (risk_content, legal_basis, document_no, summary)
        VALUES (%s, %s, %s, %s)
        ON CONFLICT (risk_content, legal_basis, document_no, summary)
        DO UPDATE SET risk_content = EXCLUDED.risk_content
        RETURNING id
        """
    params = (risk_content, legal_basis, document_no, summary)
    cursor = conn.cursor()
    cursor.execute(upsert_sql, params)
    returned_row = cursor.fetchone()
    return str(returned_row[0])
def _fetch_region_permit_name_map(conn: pg.Connection, region_id: str) -> Dict[str, str]:
    """Return ``{permit name: permit id}`` for permits linked to the region.

    Links are read from ``region_theme_permits``; names and ids are converted
    to strings.
    """
    query = """
        SELECT DISTINCT p.name, p.id
        FROM region_theme_permits rtp
        JOIN permits p ON p.id = rtp.permit_id
        WHERE rtp.region_id = %s
        """
    cursor = conn.cursor()
    cursor.execute(query, (region_id,))
    name_to_id: Dict[str, str] = {}
    for permit_name, permit_id in cursor.fetchall():
        name_to_id[str(permit_name)] = str(permit_id)
    return name_to_id
def _backup_permit_before_import(
    conn: pg.Connection,
    *,
    region_id: str,
    permit_id: str,
    region_name: str,
    permit_name: str,
    filename: str,
    sheet_name: str,
    edited_by: Optional[str],
    change_summary: Optional[str],
) -> Dict[str, Any]:
    """Snapshot every risk linked to a region/permit pair before overwriting.

    The linked rows are selected ``FOR UPDATE`` and one snapshot per risk is
    created under a shared batch id.

    Returns:
        ``{"snapshot_count": <int>, "batch_id": <str>}``; the batch id is
        ``""`` when the permit has no linked risks.
    """
    cursor = conn.cursor()
    cursor.execute(
        """
        SELECT risk_id
        FROM region_permit_risks
        WHERE region_id = %s AND permit_id = %s
        ORDER BY risk_id
        FOR UPDATE
        """,
        (region_id, permit_id),
    )
    linked_risk_ids = [str(record[0]) for record in cursor.fetchall()]
    if not linked_risk_ids:
        return {"snapshot_count": 0, "batch_id": ""}

    batch_id = str(uuid.uuid4())
    if change_summary:
        summary_prefix = change_summary
    else:
        summary_prefix = f"Excel导入前快照{filename} {sheet_name} {permit_name}"
    total = len(linked_risk_ids)
    for position, risk_id in enumerate(linked_risk_ids, start=1):
        _create_snapshot_with_connection(
            conn,
            region_id,
            permit_id,
            risk_id,
            edited_by=edited_by,
            change_summary=f"{summary_prefix} - 风险 {position}/{total}",
            batch_id=batch_id,
        )
    logger.info(
        "[PERMIT-IMPORT] Captured %d snapshots before overwriting permit %s (%s) in region %s (%s)",
        total,
        permit_id,
        permit_name,
        region_id,
        region_name,
    )
    return {"snapshot_count": total, "batch_id": batch_id}
def _purge_region_permit_relations(conn: pg.Connection, region_id: str, permit_id: str) -> Dict[str, int]:
    """Delete all rows linking a permit to a region.

    Returns:
        The number of deleted rows per table name. ``permit_sources`` is
        included only when that table is reported as available.
    """
    # Executed in this order; keys double as the result-dict keys.
    delete_sql_by_table = {
        "region_permit_risks": "DELETE FROM region_permit_risks WHERE region_id = %s AND permit_id = %s",
        "region_permit_scopes": "DELETE FROM region_permit_scopes WHERE region_id = %s AND permit_id = %s",
        "region_permit_subitems": "DELETE FROM region_permit_subitems WHERE region_id = %s AND permit_id = %s",
        "region_permit_details": "DELETE FROM region_permit_details WHERE region_id = %s AND permit_id = %s",
        "region_theme_permits": "DELETE FROM region_theme_permits WHERE region_id = %s AND permit_id = %s",
    }
    params = (region_id, permit_id)
    cursor = conn.cursor()
    removed: Dict[str, int] = {}
    for table, sql in delete_sql_by_table.items():
        cursor.execute(sql, params)
        removed[table] = int(cursor.rowcount or 0)
    if _permit_sources_available(conn):
        cursor.execute(
            """
            DELETE FROM permit_sources
            WHERE region_id = %s AND permit_id = %s
            """,
            params,
        )
        removed["permit_sources"] = int(cursor.rowcount or 0)
    return removed
def commit_permit_import_session(
    session_id: str,
    sheet_names: Iterable[str],
    *,
    overrides: Optional[Dict[str, Iterable[str]]] = None,
    edited_by: Optional[str] = None,
    change_summary: Optional[str] = None,
) -> Dict[str, Any]:
    """Write the selected sheets of an import session to the database.

    Everything runs in a single transaction. For each permit of each selected
    sheet:

    * a permit already linked to the region is skipped unless it is listed in
      ``overrides`` for that sheet;
    * an overridden permit is snapshotted, its region relations are purged and
      then re-created from the sheet;
    * a new permit is created and linked.

    The session is removed from the in-memory store only after a successful
    commit, so a failed commit can be retried with the same session id.

    Args:
        session_id: Id returned by ``start_permit_import_session``.
        sheet_names: Session sheet names to import (blanks/duplicates ignored).
        overrides: Sheet name -> permit names that may be overwritten.
        edited_by: User recorded on snapshots and, as a fallback, file links.
        change_summary: Snapshot summary; a default is derived when omitted.

    Returns:
        A result dict with per-sheet details, created/overwritten/skipped
        permits, snapshot and risk counts, and the stored file metadata.

    Raises:
        ValueError: for a missing/expired session, an empty sheet selection,
            or a selected sheet that is not part of the session.
    """
    if not session_id:
        raise ValueError("导入会话无效")
    selected_sheets: List[str] = []
    seen_sheet_names: Set[str] = set()
    for sheet_name in sheet_names or []:
        name = _clean_text(sheet_name)
        if not name or name in seen_sheet_names:
            continue
        seen_sheet_names.add(name)
        selected_sheets.append(name)
    if not selected_sheets:
        raise ValueError("请选择至少一个Sheet进行导入")
    with _PERMIT_IMPORT_LOCK:
        session_payload = _PERMIT_IMPORT_SESSIONS.get(session_id)
    if not session_payload:
        raise ValueError("导入会话不存在或已过期请重新上传Excel文件")
    session_file_bytes: Optional[bytes] = session_payload.get("file_bytes")
    session_file_content_type: str = session_payload.get("content_type") or "application/octet-stream"
    session_uploaded_by: Optional[str] = session_payload.get("uploaded_by")
    session_sheets: Dict[str, Dict[str, Any]] = session_payload.get("sheets", {})
    workbook_filename = session_payload.get("filename") or ""
    overrides_map: Dict[str, Set[str]] = {}
    if overrides:
        for sheet_key, permit_names in overrides.items():
            sheet_token = _clean_text(sheet_key)
            if not sheet_token:
                continue
            overrides_map[sheet_token] = {
                _clean_text(name) for name in (permit_names or []) if _clean_text(name)
            }
    default_change_summary = change_summary or (f"Excel导入{workbook_filename}" if workbook_filename else "Excel导入")
    result: Dict[str, Any] = {
        "session_id": session_id,
        "filename": workbook_filename,
        "processed_sheets": [],
        "created_permits": [],
        "overwritten_permits": [],
        "skipped_permits": [],
        "snapshot_count": 0,
        "risk_count": 0,
    }
    logger.info(
        "[PERMIT-IMPORT] Committing session %s with %d sheet(s): %s",
        session_id,
        len(selected_sheets),
        ", ".join(selected_sheets),
    )
    stored_file_meta: Optional[Dict[str, Any]] = None
    stored_file_id: Optional[str] = None
    with _lic_pg_conn(autocommit=False) as conn:
        try:
            _ensure_permit_sources_table(conn)
            if session_file_bytes:
                _ensure_permit_file_schema(conn)
                stored_file_meta = _insert_permit_file_record(
                    conn,
                    file_bytes=session_file_bytes,
                    filename=workbook_filename or "许可导入.xlsx",
                    content_type=session_file_content_type,
                    uploaded_by=session_uploaded_by,
                )
                stored_file_id = stored_file_meta.get("file_id")
            cur = conn.cursor()
            for sheet_name in selected_sheets:
                sheet_data = session_sheets.get(sheet_name)
                if not sheet_data:
                    raise ValueError(f"导入会话中未找到名为 {sheet_name} 的Sheet")
                region_name = sheet_data.get("region_name") or sheet_name
                # Transaction-scoped state stays in locals: writing the region
                # id back into the session would leave a dangling id behind if
                # this transaction is rolled back and the session is retried.
                region_id = sheet_data.get("region_id") or _ensure_region(conn, region_name)
                # Always re-read the linked permits; the preview data may be
                # stale, and the region may have been created in the meantime.
                existing_permits = _fetch_region_permit_name_map(conn, region_id)
                override_set = overrides_map.get(sheet_name, set())
                permit_groups: Dict[str, List[Dict[str, Any]]] = sheet_data.get("permit_groups", {})
                sheet_snapshot_count = 0
                sheet_risk_count = 0
                sheet_created: List[str] = []
                sheet_overwritten: List[str] = []
                sheet_skipped: List[str] = []
                for permit_name, permit_rows in permit_groups.items():
                    canonical_permit_name = _clean_text(permit_name)
                    if not canonical_permit_name:
                        continue
                    permit_id = existing_permits.get(canonical_permit_name)
                    should_override = canonical_permit_name in override_set
                    permit_modified = False
                    if permit_id and not should_override:
                        sheet_skipped.append(canonical_permit_name)
                        result["skipped_permits"].append(
                            {
                                "sheet": sheet_name,
                                "permit_name": canonical_permit_name,
                                "region_id": region_id,
                                "reason": "exists",
                            }
                        )
                        continue
                    if permit_id:
                        # Overwrite: snapshot the current risks, then drop all
                        # region relations so they can be rebuilt below.
                        backup_info = _backup_permit_before_import(
                            conn,
                            region_id=region_id,
                            permit_id=permit_id,
                            region_name=region_name,
                            permit_name=canonical_permit_name,
                            filename=workbook_filename,
                            sheet_name=sheet_name,
                            edited_by=edited_by,
                            change_summary=default_change_summary,
                        )
                        sheet_snapshot_count += backup_info["snapshot_count"]
                        result["snapshot_count"] += backup_info["snapshot_count"]
                        _purge_region_permit_relations(conn, region_id, permit_id)
                        sheet_overwritten.append(canonical_permit_name)
                        result["overwritten_permits"].append(
                            {
                                "sheet": sheet_name,
                                "permit_name": canonical_permit_name,
                                "region_id": region_id,
                                "snapshot_batch_id": backup_info.get("batch_id", ""),
                            }
                        )
                        permit_modified = True
                    else:
                        permit_id = _ensure_permit(conn, canonical_permit_name)
                        existing_permits[canonical_permit_name] = permit_id
                        sheet_created.append(canonical_permit_name)
                        result["created_permits"].append(
                            {
                                "sheet": sheet_name,
                                "permit_name": canonical_permit_name,
                                "region_id": region_id,
                            }
                        )
                        permit_modified = True
                    # Aggregate permit-level attributes over all of its rows;
                    # for scalar fields the first non-empty value wins.
                    theme_names: Set[str] = set()
                    scope_descriptions: Set[str] = set()
                    subitem_names: Set[str] = set()
                    permit_status_val: Optional[str] = None
                    subitem_summary_val: Optional[str] = None
                    responsible_contact_val: Optional[str] = None
                    jurisdiction_scope_val: Optional[str] = None
                    for row in permit_rows:
                        theme_names.update(row.get("theme_names") or [])
                        scope_descriptions.update(row.get("scope_descriptions") or [])
                        subitem_names.update(row.get("subitem_names") or [])
                        if not permit_status_val and row.get("permit_status"):
                            permit_status_val = row.get("permit_status")
                        if not subitem_summary_val and row.get("subitem_summary"):
                            subitem_summary_val = row.get("subitem_summary")
                        if not responsible_contact_val and row.get("responsible_contact"):
                            responsible_contact_val = row.get("responsible_contact")
                        if not jurisdiction_scope_val and row.get("jurisdiction_scope"):
                            jurisdiction_scope_val = row.get("jurisdiction_scope")
                    if not theme_names:
                        theme_names.add("不涉及")
                    theme_ids: List[str] = []
                    for theme_name in sorted(theme_names):
                        theme_id = _ensure_theme(conn, theme_name)
                        theme_ids.append(theme_id)
                        cur.execute(
                            """
                            INSERT INTO region_themes (region_id, theme_id)
                            VALUES (%s, %s)
                            ON CONFLICT DO NOTHING
                            """,
                            (region_id, theme_id),
                        )
                        cur.execute(
                            """
                            INSERT INTO region_theme_permits (region_id, theme_id, permit_id)
                            VALUES (%s, %s, %s)
                            ON CONFLICT DO NOTHING
                            """,
                            (region_id, theme_id, permit_id),
                        )
                    for scope_desc in sorted(scope_descriptions):
                        scope_id = _ensure_business_scope(conn, scope_desc)
                        if not scope_id:
                            continue
                        cur.execute(
                            """
                            INSERT INTO region_permit_scopes (region_id, permit_id, scope_id)
                            VALUES (%s, %s, %s)
                            ON CONFLICT DO NOTHING
                            """,
                            (region_id, permit_id, scope_id),
                        )
                    for subitem_desc in sorted(subitem_names):
                        subitem_id = _ensure_permit_subitem(conn, subitem_desc)
                        if not subitem_id:
                            continue
                        cur.execute(
                            """
                            INSERT INTO region_permit_subitems (region_id, permit_id, subitem_id)
                            VALUES (%s, %s, %s)
                            ON CONFLICT DO NOTHING
                            """,
                            (region_id, permit_id, subitem_id),
                        )
                    cur.execute(
                        """
                        INSERT INTO region_permit_details (
                        region_id,
                        permit_id,
                        permit_status,
                        subitem_summary,
                        responsible_contact,
                        jurisdiction_scope
                        )
                        VALUES (%s, %s, %s, %s, %s, %s)
                        ON CONFLICT (region_id, permit_id)
                        DO UPDATE SET
                        permit_status = EXCLUDED.permit_status,
                        subitem_summary = EXCLUDED.subitem_summary,
                        responsible_contact = EXCLUDED.responsible_contact,
                        jurisdiction_scope = EXCLUDED.jurisdiction_scope,
                        updated_at = now()
                        """,
                        (
                            region_id,
                            permit_id,
                            permit_status_val,
                            subitem_summary_val,
                            responsible_contact_val,
                            jurisdiction_scope_val,
                        ),
                    )
                    for row in permit_rows:
                        risk_id = _ensure_risk(
                            conn,
                            risk_content=row.get("risk_content", ""),
                            legal_basis=row.get("legal_basis"),
                            document_no=row.get("document_no"),
                            summary=row.get("summary"),
                        )
                        cur.execute(
                            """
                            INSERT INTO region_permit_risks (region_id, permit_id, risk_id)
                            VALUES (%s, %s, %s)
                            ON CONFLICT DO NOTHING
                            """,
                            (region_id, permit_id, risk_id),
                        )
                    sheet_risk_count += len(permit_rows)
                    result["risk_count"] += len(permit_rows)
                    source_detail_payload = {
                        "sheet": sheet_name,
                        "permit": canonical_permit_name,
                        "risk_rows": len(permit_rows),
                        "imported_at": datetime.utcnow().isoformat(),
                    }
                    cur.execute(
                        """
                        INSERT INTO permit_sources (
                        region_id,
                        permit_id,
                        source_type,
                        source_name,
                        source_detail,
                        created_at,
                        updated_at
                        )
                        VALUES (%s, %s, %s, %s, %s, now(), now())
                        ON CONFLICT (region_id, permit_id)
                        DO UPDATE SET
                        source_type = EXCLUDED.source_type,
                        source_name = EXCLUDED.source_name,
                        source_detail = EXCLUDED.source_detail,
                        updated_at = now()
                        """,
                        (
                            region_id,
                            permit_id,
                            "excel",
                            workbook_filename or sheet_name,
                            json.dumps(source_detail_payload, ensure_ascii=False),
                        ),
                    )
                    if stored_file_id and permit_modified:
                        _link_file_to_permit(
                            conn,
                            file_id=stored_file_id,
                            region_id=region_id,
                            permit_id=permit_id,
                            created_by=session_uploaded_by or edited_by,
                        )
                result["processed_sheets"].append(
                    {
                        "sheet_name": sheet_name,
                        "region_id": region_id,
                        "region_name": region_name,
                        "created_permits": sheet_created,
                        "overwritten_permits": sheet_overwritten,
                        "skipped_permits": sheet_skipped,
                        "snapshot_count": sheet_snapshot_count,
                        "risk_count": sheet_risk_count,
                    }
                )
                logger.info(
                    "[PERMIT-IMPORT] Sheet %s processed (region=%s) -> created=%d, overwritten=%d, skipped=%d, risks=%d, snapshots=%d",
                    sheet_name,
                    region_id,
                    len(sheet_created),
                    len(sheet_overwritten),
                    len(sheet_skipped),
                    sheet_risk_count,
                    sheet_snapshot_count,
                )
            conn.commit()
        except Exception:
            conn.rollback()
            raise
    result["file_attachment"] = stored_file_meta
    # Only a successfully committed session is discarded.
    with _PERMIT_IMPORT_LOCK:
        _PERMIT_IMPORT_SESSIONS.pop(session_id, None)
    logger.info(
        "[PERMIT-IMPORT] Completed session %s: created=%d overwritten=%d skipped=%d snapshots=%d risks=%d",
        session_id,
        len(result["created_permits"]),
        len(result["overwritten_permits"]),
        len(result["skipped_permits"]),
        result["snapshot_count"],
        result["risk_count"],
    )
    return result
def _permit_sources_available(conn: pg.Connection) -> bool:
    """Report whether the ``permit_sources`` table exists.

    The answer is memoised in the module-level
    ``_PERMIT_SOURCES_TABLE_PRESENT`` flag; the catalog is only queried while
    that flag is still ``None``.
    """
    global _PERMIT_SOURCES_TABLE_PRESENT

    cached = _PERMIT_SOURCES_TABLE_PRESENT
    if isinstance(cached, bool):
        # Fast path: a definite answer has already been recorded.
        return cached

    with _PERMIT_SOURCES_TABLE_LOCK:
        # Re-read under the lock in case another thread filled the cache.
        cached = _PERMIT_SOURCES_TABLE_PRESENT
        if cached is not None:
            return bool(cached)

        probe = conn.cursor()
        probe.execute(
            """
            SELECT 1
            FROM information_schema.tables
            WHERE table_schema = 'public'
            AND table_name = 'permit_sources'
            LIMIT 1
            """
        )
        found = probe.fetchone() is not None
        _PERMIT_SOURCES_TABLE_PRESENT = found
        return found
def _create_permit_sources_schema(conn: pg.Connection) -> None:
    """Issue the idempotent DDL for ``permit_sources`` and its index."""
    ddl_statements = (
        # One source record per (region, permit) pair.
        """
        CREATE TABLE IF NOT EXISTS permit_sources (
        region_id uuid NOT NULL REFERENCES regions(id) ON DELETE CASCADE,
        permit_id uuid NOT NULL REFERENCES permits(id) ON DELETE CASCADE,
        source_type text NOT NULL,
        source_name text NOT NULL,
        source_detail text,
        created_at timestamptz NOT NULL DEFAULT now(),
        updated_at timestamptz NOT NULL DEFAULT now(),
        PRIMARY KEY (region_id, permit_id)
        )
        """,
        """
        CREATE INDEX IF NOT EXISTS permit_sources_source_name_idx
        ON permit_sources (source_name)
        """,
    )
    cursor = conn.cursor()
    for statement in ddl_statements:
        cursor.execute(statement)
def _ensure_permit_sources_table(conn: Optional[pg.Connection] = None) -> None:
    """Create ``permit_sources`` once per process and remember that it exists.

    When ``conn`` is supplied the DDL runs on it with autocommit temporarily
    switched on; otherwise a dedicated autocommit connection is opened.
    """
    global _PERMIT_SOURCES_TABLE_PRESENT

    if _PERMIT_SOURCES_TABLE_PRESENT is True:
        return

    with _PERMIT_SOURCES_TABLE_LOCK:
        # Second check: another thread may have finished while we waited.
        if _PERMIT_SOURCES_TABLE_PRESENT is True:
            return

        if conn is None:
            with _lic_pg_conn(autocommit=True) as ensure_conn:
                _create_permit_sources_schema(ensure_conn)
        else:
            saved_autocommit = conn.autocommit
            try:
                conn.autocommit = True
                _create_permit_sources_schema(conn)
            finally:
                # Hand the connection back in the mode the caller had it.
                conn.autocommit = saved_autocommit

        _PERMIT_SOURCES_TABLE_PRESENT = True
def _create_permit_file_schema(conn: pg.Connection) -> None:
    """Issue the idempotent DDL for permit file storage and link tables."""
    ddl_statements = (
        # Raw uploaded files.
        """
        CREATE TABLE IF NOT EXISTS permit_files (
        id uuid PRIMARY KEY,
        filename text NOT NULL,
        content_type text,
        file_size integer NOT NULL,
        file_data bytea NOT NULL,
        checksum text,
        uploaded_by text,
        created_at timestamptz NOT NULL DEFAULT now()
        )
        """,
        # Association between a region-permit pair and a stored file.
        """
        CREATE TABLE IF NOT EXISTS permit_file_links (
        id uuid PRIMARY KEY,
        region_id uuid NOT NULL REFERENCES regions(id) ON DELETE CASCADE,
        permit_id uuid NOT NULL REFERENCES permits(id) ON DELETE CASCADE,
        file_id uuid NOT NULL REFERENCES permit_files(id) ON DELETE CASCADE,
        created_by text,
        created_at timestamptz NOT NULL DEFAULT now()
        )
        """,
        # At most one linked file per (region, permit).
        """
        CREATE UNIQUE INDEX IF NOT EXISTS permit_file_links_region_permit_idx
        ON permit_file_links (region_id, permit_id)
        """,
        """
        CREATE INDEX IF NOT EXISTS permit_file_links_permit_idx
        ON permit_file_links (permit_id)
        """,
    )
    cursor = conn.cursor()
    for statement in ddl_statements:
        cursor.execute(statement)
def _ensure_permit_file_schema(conn: Optional[pg.Connection] = None) -> None:
    """Lazily create the permit file tables and cache that they are ready.

    When ``conn`` is supplied the DDL runs on it with autocommit temporarily
    switched on; otherwise a dedicated autocommit connection is opened.
    """
    global _PERMIT_FILE_SCHEMA_READY

    if _PERMIT_FILE_SCHEMA_READY:
        return

    with _PERMIT_FILE_SCHEMA_LOCK:
        # Second check: another thread may have finished while we waited.
        if _PERMIT_FILE_SCHEMA_READY:
            return

        if conn is None:
            with _lic_pg_conn(autocommit=True) as ensure_conn:
                _create_permit_file_schema(ensure_conn)
        else:
            saved_autocommit = conn.autocommit
            try:
                conn.autocommit = True
                _create_permit_file_schema(conn)
            finally:
                # Hand the connection back in the mode the caller had it.
                conn.autocommit = saved_autocommit

        _PERMIT_FILE_SCHEMA_READY = True
def _insert_permit_file_record(
    conn: pg.Connection,
    *,
    file_bytes: bytes,
    filename: str,
    content_type: Optional[str],
    uploaded_by: Optional[str],
) -> Dict[str, Any]:
    """Store an uploaded file in ``permit_files`` and describe what was stored.

    A missing filename or content type is replaced by a default, and a SHA-256
    checksum of the payload is saved alongside it. The returned dict carries
    the new file id (as a string) plus the stored metadata, without the bytes.
    """
    stored_name = filename or "许可导入.xlsx"
    stored_type = content_type or "application/octet-stream"
    new_file_id = uuid.uuid4()
    digest = hashlib.sha256(file_bytes).hexdigest()
    size = len(file_bytes)

    insert_sql = """
        INSERT INTO permit_files (
        id,
        filename,
        content_type,
        file_size,
        file_data,
        checksum,
        uploaded_by
        )
        VALUES (%s, %s, %s, %s, %s, %s, %s)
        """
    row_values = (
        new_file_id,
        stored_name,
        stored_type,
        size,
        file_bytes,
        digest,
        uploaded_by,
    )
    conn.cursor().execute(insert_sql, row_values)

    metadata: Dict[str, Any] = {
        "file_id": str(new_file_id),
        "filename": stored_name,
        "content_type": stored_type,
        "file_size": size,
        "checksum": digest,
        "uploaded_by": uploaded_by,
    }
    return metadata
def _link_file_to_permit(
    conn: pg.Connection,
    *,
    file_id: str,
    region_id: str,
    permit_id: str,
    created_by: Optional[str],
) -> None:
    """Point a region-permit pair at a stored file.

    Does nothing when any of the three ids is falsy. An existing link for the
    same (region, permit) is overwritten with the new file, author and time.
    """
    if not file_id or not region_id or not permit_id:
        return

    upsert_sql = """
        INSERT INTO permit_file_links (id, region_id, permit_id, file_id, created_by)
        VALUES (%s, %s, %s, %s, %s)
        ON CONFLICT (region_id, permit_id)
        DO UPDATE SET
        file_id = EXCLUDED.file_id,
        created_by = EXCLUDED.created_by,
        created_at = now()
        """
    link_values = (uuid.uuid4(), str(region_id), str(permit_id), file_id, created_by)
    cursor = conn.cursor()
    cursor.execute(upsert_sql, link_values)
def _load_permit_file_metadata(
    conn: pg.Connection,
    region_id: str,
    permit_ids: Iterable[str],
) -> Dict[str, Dict[str, Any]]:
    """Fetch linked-file metadata for several permits of one region.

    Returns a mapping of permit id (string) to a metadata dict; permits
    without a linked file are simply absent. Falsy ids are ignored.
    """
    wanted_ids = [str(candidate) for candidate in permit_ids if candidate]
    if not wanted_ids:
        return {}

    _ensure_permit_file_schema(conn)

    metadata_by_permit: Dict[str, Dict[str, Any]] = {}
    for (
        permit_id,
        file_id,
        filename,
        content_type,
        file_size,
        created_at,
        uploaded_by,
    ) in _select_permit_files(conn, region_id, wanted_ids):
        created_iso = created_at.isoformat() if created_at else None
        metadata_by_permit[str(permit_id)] = {
            "file_id": str(file_id),
            "filename": filename or "",
            "content_type": content_type or "",
            "file_size": int(file_size or 0),
            "created_at": created_iso,
            "uploaded_by": uploaded_by or "",
        }
    return metadata_by_permit
def _select_permit_files(conn: pg.Connection, region_id: str, permit_ids: Iterable[str]):
    """Run the permit file metadata query, recreating the tables if missing.

    Args:
        conn: Open connection the query (and any schema repair) runs on.
        region_id: Region whose file links are read.
        permit_ids: Permit ids to look up; materialised once so a retry
            sees the same values even if a one-shot iterable was passed.

    Returns:
        The fetched rows, or ``[]`` when the tables are still missing after
        one recreate attempt.

    Raises:
        pg.DatabaseError: for any database error other than undefined_table
            (SQLSTATE 42P01).
    """
    global _PERMIT_FILE_SCHEMA_READY

    sql = """
        SELECT
        pfl.permit_id,
        pf.id,
        pf.filename,
        pf.content_type,
        pf.file_size,
        pf.created_at,
        pf.uploaded_by
        FROM permit_file_links pfl
        JOIN permit_files pf ON pf.id = pfl.file_id
        WHERE pfl.region_id = %s
        AND pfl.permit_id = ANY(%s)
        """
    id_list = list(permit_ids)
    attempts = 0
    while attempts < 2:
        try:
            cur = conn.cursor()
            cur.execute(sql, (region_id, id_list))
            return cur.fetchall()
        except pg.DatabaseError as exc:  # type: ignore[attr-defined]
            sqlstate = getattr(exc, "sqlstate", "") or ""
            detail = exc.args[0] if exc.args else None
            if not sqlstate and isinstance(detail, dict):
                # pg8000 reports server error fields as a dict; "C" is the SQLSTATE.
                sqlstate = detail.get("C") or ""
            if sqlstate != "42P01":
                raise
            attempts += 1
            # The failed statement aborted the transaction; clear it before DDL.
            try:
                conn.rollback()
            except Exception as rollback_exc:
                logger.debug(
                    "[PERMIT-FILES] rollback after missing-table error failed: %s",
                    rollback_exc,
                )
            # Invalidate the cached "schema ready" flag so the DDL really runs.
            _PERMIT_FILE_SCHEMA_READY = None
            _ensure_permit_file_schema(conn)
            if attempts >= 2:
                logger.warning("[PERMIT-FILES] permit_file_links table missing after recreate attempt, skipping metadata fetch")
                return []
    return []
def _lic_pg_conn(autocommit: bool = False) -> pg.Connection:
    """Open a connection to the licensing database using ``LIC_PG_*`` settings.

    Port and user fall back to the ``PG_PORT`` / ``PG_USER`` variables; the
    other settings fall back to built-in defaults. The connection's
    autocommit mode is set from ``autocommit`` before it is returned.
    """
    env = os.getenv
    settings = {
        "host": env("LIC_PG_HOST", "172.24.240.1"),
        "port": int(env("LIC_PG_PORT", env("PG_PORT", "5432"))),
        "user": env("LIC_PG_USER", env("PG_USER", "postgres")),
        "password": env("LIC_PG_PASSWORD", ""),
        "database": env("LIC_PG_DATABASE", LIC_DEFAULT_DB),
    }
    connection = pg.connect(**settings)
    connection.autocommit = autocommit
    return connection
def list_region_theme_options() -> List[Dict[str, str]]:
    """List every region-theme pairing, ordered by region then theme name.

    Each entry carries both ids, both names, a combined ``option_id`` of the
    form ``"<region_id>:<theme_id>"`` and a human-readable ``display_name``.
    """
    query = """
        SELECT
        rt.region_id,
        r.name AS region_name,
        rt.theme_id,
        t.name AS theme_name
        FROM region_themes rt
        JOIN regions r ON r.id = rt.region_id
        JOIN themes t ON t.id = rt.theme_id
        ORDER BY r.name, t.name
        """
    with _lic_pg_conn() as conn:
        cursor = conn.cursor()
        cursor.execute(query)
        records = cursor.fetchall()

    options: List[Dict[str, str]] = []
    for region_id, region_name, theme_id, theme_name in records:
        region_key = str(region_id)
        theme_key = str(theme_id)
        option = {
            "option_id": f"{region_key}:{theme_key}",
            "region_id": region_key,
            "region_name": str(region_name),
            "theme_id": theme_key,
            "theme_name": str(theme_name),
            "display_name": f"{region_name} · {theme_name}",
        }
        options.append(option)
    return options
def load_business_scopes(region_id: str) -> List[Dict[str, str]]:
    """Return the business scopes attached to a region, sorted by description."""
    query = """
        SELECT bs.id, bs.description
        FROM region_scopes rs
        JOIN business_scopes bs ON bs.id = rs.scope_id
        WHERE rs.region_id = %s
        ORDER BY bs.description
        """
    with _lic_pg_conn() as conn:
        cursor = conn.cursor()
        cursor.execute(query, (region_id,))
        records = cursor.fetchall()
    return [
        {"id": str(scope_id), "description": str(description)}
        for scope_id, description in records
    ]
def list_permits_for_region(region: str) -> List[Dict[str, str]]:
    """List the distinct permits of a region, sorted by permit name.

    ``region`` is matched either against the region id (as text) or,
    case-insensitively, against the region name.
    """
    query = """
        SELECT DISTINCT p.id, p.name
        FROM region_theme_permits rtp
        JOIN permits p ON p.id = rtp.permit_id
        JOIN regions r ON r.id = rtp.region_id
        WHERE rtp.region_id::text = %s OR LOWER(r.name) = LOWER(%s)
        ORDER BY p.name
        """
    with _lic_pg_conn() as conn:
        cursor = conn.cursor()
        # The same value feeds both the id comparison and the name comparison.
        cursor.execute(query, (region, region))
        records = cursor.fetchall()
    return [
        {"id": str(permit_id), "name": str(permit_name)}
        for permit_id, permit_name in records
    ]
def list_region_permit_catalog(region_id: str) -> List[Dict[str, Any]]:
    """List a region's permits with their owning theme and number of risks.

    One entry is produced per (permit, theme) combination, ordered
    case-insensitively by permit name and then theme name.
    """
    query = """
        SELECT
        rtp.permit_id,
        p.name AS permit_name,
        rtp.theme_id,
        COALESCE(t.name, '') AS theme_name,
        COUNT(rpr.risk_id) AS risk_count
        FROM region_theme_permits rtp
        JOIN permits p ON p.id = rtp.permit_id
        LEFT JOIN themes t ON t.id = rtp.theme_id
        LEFT JOIN region_permit_risks rpr
        ON rpr.region_id = rtp.region_id
        AND rpr.permit_id = rtp.permit_id
        WHERE rtp.region_id = %s
        GROUP BY rtp.permit_id, p.name, rtp.theme_id, t.name
        ORDER BY LOWER(p.name), LOWER(COALESCE(t.name, ''))
        """
    with _lic_pg_conn() as conn:
        cursor = conn.cursor()
        cursor.execute(query, (region_id,))
        records = cursor.fetchall()

    entries: List[Dict[str, Any]] = []
    for permit_id, permit_name, theme_id, theme_name, risk_count in records:
        theme_info = {
            "id": str(theme_id) if theme_id else "",
            "name": str(theme_name) if theme_name else "",
        }
        entries.append(
            {
                "id": str(permit_id),
                "name": str(permit_name),
                "theme": theme_info,
                "risk_count": int(risk_count or 0),
            }
        )
    return entries
def resolve_region_permit_theme(region_id: str, permit_id: str) -> Optional[Dict[str, str]]:
    """Look up one theme for a region-permit pair.

    When several themes match, the one whose name sorts first wins (rows
    without a theme name sort last). Returns ``None`` if the pair is unknown.
    """
    query = """
        SELECT rtp.theme_id, t.name
        FROM region_theme_permits rtp
        LEFT JOIN themes t ON t.id = rtp.theme_id
        WHERE rtp.region_id = %s
        AND rtp.permit_id = %s
        ORDER BY t.name NULLS LAST
        LIMIT 1
        """
    with _lic_pg_conn() as conn:
        cursor = conn.cursor()
        cursor.execute(query, (region_id, permit_id))
        match = cursor.fetchone()

    if not match:
        return None

    theme_id, theme_name = match
    resolved_id = str(theme_id) if theme_id else ""
    resolved_name = str(theme_name) if theme_name else ""
    return {"id": resolved_id, "name": resolved_name}
def _load_permit_scopes_for_region(
    conn: pg.Connection, region_id: str, permit_ids: List[str]
) -> Dict[str, List[Dict[str, str]]]:
    """Return mapping of permit_id -> business scopes for that permit within region.

    Args:
        conn: Open connection to query on.
        region_id: Region whose permit scopes are read.
        permit_ids: Permit ids to report on; every id gets a key in the
            result, with an empty list when it has no scopes.

    Returns:
        Mapping of permit id to a list of ``{"id", "description"}`` dicts.
        All lists are empty when the ``region_permit_scopes`` table does not
        exist yet.

    Raises:
        pg.ProgrammingError: for any error other than undefined_table
            (SQLSTATE 42P01).
    """
    scope_map: Dict[str, List[Dict[str, str]]] = {pid: [] for pid in permit_ids}
    if not permit_ids:
        return scope_map

    sql = """
        SELECT rps.permit_id, bs.id, bs.description
        FROM region_permit_scopes rps
        JOIN business_scopes bs ON bs.id = rps.scope_id
        WHERE rps.region_id = %s
        ORDER BY rps.permit_id, bs.description
        """
    cur = conn.cursor()
    try:
        cur.execute(sql, (region_id,))
    except pg.ProgrammingError as exc:
        sqlstate = getattr(exc, "sqlstate", "") or ""
        detail = exc.args[0] if exc.args else None
        if not sqlstate and isinstance(detail, dict):
            # pg8000 reports server error fields as a dict; "C" is the SQLSTATE.
            sqlstate = detail.get("C") or ""
        # 42P01 => undefined_table; allow fallback when migration not yet applied.
        if sqlstate != "42P01":
            raise
        # The failed statement left the transaction aborted; roll back so the
        # caller can keep using this connection for its next query.
        conn.rollback()
        return scope_map

    for permit_id, scope_id, description in cur.fetchall():
        bucket = scope_map.get(str(permit_id))
        if bucket is None:
            # Scope belongs to a permit the caller did not ask about.
            continue
        bucket.append({"id": str(scope_id), "description": str(description)})
    return scope_map
def _load_permit_sources_for_region(
    conn: pg.Connection, region_id: str, permit_ids: Iterable[str]
) -> Dict[str, Dict[str, Any]]:
    """Fetch source metadata for the given permits of a region.

    Returns an empty mapping when no ids are given or when the
    ``permit_sources`` table is not available; otherwise maps permit id
    (string) to its source type, name, detail and last update time.
    """
    wanted_ids = [str(candidate) for candidate in permit_ids]
    if not wanted_ids or not _permit_sources_available(conn):
        return {}

    query = """
        SELECT permit_id, source_type, source_name, source_detail, updated_at
        FROM permit_sources
        WHERE region_id = %s AND permit_id = ANY(%s)
        """
    cursor = conn.cursor()
    cursor.execute(query, (region_id, wanted_ids))

    return {
        str(permit_id): {
            "source_type": source_type or "",
            "source_name": source_name or "",
            "source_detail": source_detail or "",
            "updated_at": _convert_snapshot_value(updated_at),
        }
        for permit_id, source_type, source_name, source_detail, updated_at in cursor.fetchall()
    }
def load_permits_and_risks(
    region_id: str, theme_id: Optional[str] = None, permit_id: Optional[str] = None
) -> List[Dict[str, object]]:
    """Return permits with attached risk entries for a region.

    Args:
        region_id: Region to load permits for.
        theme_id: When truthy, restrict to permits of this theme.
        permit_id: When not ``None``, restrict to this single permit.

    Returns:
        One dict per permit (in first-seen query order) with its name, detail
        fields, primary ``theme``, all ``themes``, ``risks``,
        ``business_scopes``, ``permit_source`` and ``permit_file``.

    Raises:
        pg.DatabaseError: for database errors other than a missing permit
            file table while loading file metadata.
    """
    global _PERMIT_FILE_SCHEMA_READY

    # Ensure optional permit file tables exist before running user queries.
    try:
        _ensure_permit_file_schema()
    except Exception as exc:
        logger.warning("[PERMIT-FILES] Failed to ensure permit file schema before loading permits: %s", exc)

    sql = """
        SELECT
        rtp.theme_id,
        t.name AS theme_name,
        p.id AS permit_id,
        p.name AS permit_name,
        rk.id AS risk_id,
        rk.risk_content,
        rk.legal_basis,
        rk.document_no,
        rk.summary,
        rpd.permit_status,
        rpd.subitem_summary,
        rpd.responsible_contact,
        rpd.jurisdiction_scope
        FROM region_theme_permits rtp
        JOIN permits p ON p.id = rtp.permit_id
        LEFT JOIN themes t ON t.id = rtp.theme_id
        LEFT JOIN region_permit_risks rpr
        ON rpr.region_id = rtp.region_id
        AND rpr.permit_id = rtp.permit_id
        LEFT JOIN risks rk ON rk.id = rpr.risk_id
        LEFT JOIN region_permit_details rpd
        ON rpd.region_id = rtp.region_id
        AND rpd.permit_id = rtp.permit_id
        WHERE rtp.region_id = %s
        """
    params: List[Any] = [region_id]
    if theme_id:
        sql += " AND rtp.theme_id = %s"
        params.append(theme_id)
    if permit_id is not None:
        sql += " AND rtp.permit_id = %s"
        params.append(permit_id)
    sql += """
        ORDER BY p.name, rk.risk_content
        """

    permits: Dict[str, Dict[str, Any]] = {}
    # A permit that sits under several themes yields one row per
    # (theme, risk) pair, so remember which risks each permit already has.
    seen_risk_ids: Dict[str, Set[str]] = {}

    with _lic_pg_conn() as conn:
        cur = conn.cursor()
        cur.execute(sql, tuple(params))
        for row in cur.fetchall():
            (
                row_theme_id,
                row_theme_name,
                row_permit_id,
                permit_name,
                risk_id,
                risk_content,
                legal_basis,
                document_no,
                summary,
                permit_status,
                subitem_summary,
                responsible_contact,
                jurisdiction_scope,
            ) = row
            pid = str(row_permit_id)
            theme_id_value = str(row_theme_id) if row_theme_id else ""
            theme_name_value = str(row_theme_name) if row_theme_name else ""
            entry = permits.setdefault(
                pid,
                {
                    "id": pid,
                    "name": str(permit_name),
                    "business_scopes": [],
                    "risks": [],
                    "permit_status": None,
                    "subitem_summary": None,
                    "responsible_contact": None,
                    "jurisdiction_scope": None,
                    "theme": {
                        "id": theme_id_value,
                        "name": theme_name_value,
                    },
                    "themes": [],
                },
            )

            # Fill in the primary theme if the first row for this permit lacked it.
            primary_theme = entry["theme"]
            if theme_id_value and not primary_theme.get("id"):
                primary_theme["id"] = theme_id_value
            if theme_name_value and not primary_theme.get("name"):
                primary_theme["name"] = theme_name_value

            if theme_id_value or theme_name_value:
                theme_list = entry.get("themes") or []
                if theme_id_value:
                    duplicate = any(
                        known.get("id") == theme_id_value for known in theme_list
                    )
                else:
                    # Themes without an id are de-duplicated by name.
                    duplicate = any(
                        not known.get("id") and known.get("name") == theme_name_value
                        for known in theme_list
                    )
                if not duplicate:
                    theme_list.append({"id": theme_id_value, "name": theme_name_value})
                entry["themes"] = theme_list

            # Detail fields: the first non-blank value seen wins.
            detail_values = (
                ("permit_status", permit_status),
                ("subitem_summary", subitem_summary),
                ("responsible_contact", responsible_contact),
                ("jurisdiction_scope", jurisdiction_scope),
            )
            for field_name, raw_value in detail_values:
                if entry[field_name] is None and raw_value:
                    entry[field_name] = raw_value.strip() or None

            if risk_id is not None:
                risk_key = str(risk_id)
                known_risks = seen_risk_ids.setdefault(pid, set())
                if risk_key not in known_risks:
                    known_risks.add(risk_key)
                    entry["risks"].append(
                        {
                            "id": risk_key,
                            "risk_content": risk_content or "",
                            "legal_basis": legal_basis or "",
                            "document_no": document_no or "",
                            "summary": _format_summary_markdown(summary or ""),
                        }
                    )

        permit_ids = list(permits.keys())
        scope_map = _load_permit_scopes_for_region(conn, region_id, permit_ids)
        source_map = _load_permit_sources_for_region(conn, region_id, permit_ids)
        try:
            file_meta_map = _load_permit_file_metadata(conn, region_id, permit_ids)
        except pg.DatabaseError as exc:  # type: ignore[attr-defined]
            sqlstate = getattr(exc, "sqlstate", "") or ""
            detail = exc.args[0] if exc.args else None
            if not sqlstate and isinstance(detail, dict):
                # pg8000 reports server error fields as a dict; "C" is the SQLSTATE.
                sqlstate = detail.get("C") or ""
            if sqlstate != "42P01":
                raise
            logger.warning("[PERMIT-FILES] permit_file_links missing while loading permits, recreating schema lazily")
            _PERMIT_FILE_SCHEMA_READY = None
            _ensure_permit_file_schema()
            file_meta_map = {}

    for pid in permit_ids:
        permit_entry = permits[pid]
        permit_entry["business_scopes"] = scope_map.get(pid, [])
        if pid in source_map:
            permit_entry["permit_source"] = source_map[pid]
        else:
            permit_entry["permit_source"] = {
                "source_type": "",
                "source_name": "",
                "source_detail": "",
                "updated_at": None,
            }
        if pid in file_meta_map:
            permit_entry["permit_file"] = file_meta_map[pid]
        else:
            permit_entry["permit_file"] = {
                "file_id": "",
                "filename": "",
                "content_type": "",
                "file_size": 0,
                "created_at": None,
                "uploaded_by": "",
            }
        if permit_entry.get("themes") is None:
            permit_entry["themes"] = []
    return list(permits.values())
def fetch_permit_file(region_id: str, permit_id: str) -> Optional[Dict[str, Any]]:
    """Return file payload for a region-permit pair if available.

    Args:
        region_id: Region of the permit.
        permit_id: Permit whose linked file is wanted.

    Returns:
        A dict with the file's metadata and its bytes under ``file_data``, or
        ``None`` when either id is falsy, no file is linked, or the file
        tables were missing and had to be recreated.

    Raises:
        pg.DatabaseError: for any database error other than undefined_table
            (SQLSTATE 42P01).
    """
    global _PERMIT_FILE_SCHEMA_READY

    if not region_id or not permit_id:
        return None

    with _lic_pg_conn() as conn:
        _ensure_permit_file_schema(conn)
        try:
            row = _select_permit_file_blob(conn, region_id, permit_id)
        except pg.DatabaseError as exc:  # type: ignore[attr-defined]
            sqlstate = getattr(exc, "sqlstate", "") or ""
            detail = exc.args[0] if exc.args else None
            if not sqlstate and isinstance(detail, dict):
                # pg8000 reports server error fields as a dict; "C" is the SQLSTATE.
                sqlstate = detail.get("C") or ""
            if sqlstate != "42P01":
                raise
            logger.warning(
                "[PERMIT-FILES] permit_file_links missing when downloading file (region=%s permit=%s); recreating schema",
                region_id,
                permit_id,
            )
            # Drop the cached flag so the DDL actually runs again.
            _PERMIT_FILE_SCHEMA_READY = None
            _ensure_permit_file_schema()
            return None

    if not row:
        return None

    (
        file_id,
        filename,
        content_type,
        file_size,
        file_data,
        created_at,
        uploaded_by,
    ) = row
    return {
        "file_id": str(file_id),
        "filename": filename or "",
        "content_type": content_type or "application/octet-stream",
        "file_size": int(file_size or 0),
        "file_data": bytes(file_data) if file_data is not None else b"",
        "created_at": created_at.isoformat() if created_at else None,
        "uploaded_by": uploaded_by or "",
    }
def _select_permit_file_blob(conn: pg.Connection, region_id: str, permit_id: str):
    """Fetch a single permit file with binary content, recreating tables if needed.

    Args:
        conn: Open connection the query (and any schema repair) runs on.
        region_id: Region of the permit.
        permit_id: Permit whose linked file is wanted.

    Returns:
        The matching row, or ``None`` when no file is linked or the tables
        are still missing after one recreate attempt.

    Raises:
        pg.DatabaseError: for any database error other than undefined_table
            (SQLSTATE 42P01).
    """
    global _PERMIT_FILE_SCHEMA_READY

    sql = """
        SELECT
        pf.id,
        pf.filename,
        pf.content_type,
        pf.file_size,
        pf.file_data,
        pf.created_at,
        pf.uploaded_by
        FROM permit_file_links pfl
        JOIN permit_files pf ON pf.id = pfl.file_id
        WHERE pfl.region_id = %s
        AND pfl.permit_id = %s
        LIMIT 1
        """
    attempts = 0
    while attempts < 2:
        try:
            cur = conn.cursor()
            cur.execute(sql, (region_id, permit_id))
            return cur.fetchone()
        except pg.DatabaseError as exc:  # type: ignore[attr-defined]
            sqlstate = getattr(exc, "sqlstate", "") or ""
            detail = exc.args[0] if exc.args else None
            if not sqlstate and isinstance(detail, dict):
                # pg8000 reports server error fields as a dict; "C" is the SQLSTATE.
                sqlstate = detail.get("C") or ""
            if sqlstate != "42P01":
                raise
            attempts += 1
            # The failed statement aborted the transaction; clear it before DDL.
            try:
                conn.rollback()
            except Exception as rollback_exc:
                logger.debug(
                    "[PERMIT-FILES] rollback after missing-table error failed: %s",
                    rollback_exc,
                )
            # Invalidate the cached "schema ready" flag so the DDL really runs.
            _PERMIT_FILE_SCHEMA_READY = None
            _ensure_permit_file_schema(conn)
            if attempts >= 2:
                logger.warning(
                    "[PERMIT-FILES] permit_file_links table missing after recreate attempt when downloading file (region=%s permit=%s)",
                    region_id,
                    permit_id,
                )
                return None
    return None
def find_permit_contexts_by_name(permit_name: str) -> List[Dict[str, str]]:
    """Find where a permit with exactly this name is used.

    Returns one context per (region, permit) pair, ordered by region name;
    when a pair appears under several themes only the first theme (by name)
    is reported. An empty name yields an empty list.
    """
    if not permit_name:
        return []

    query = """
        SELECT
        rtp.region_id,
        r.name AS region_name,
        rtp.theme_id,
        t.name AS theme_name,
        p.id AS permit_id,
        p.name AS permit_name
        FROM region_theme_permits rtp
        JOIN permits p ON p.id = rtp.permit_id
        JOIN regions r ON r.id = rtp.region_id
        JOIN themes t ON t.id = rtp.theme_id
        WHERE p.name = %s
        ORDER BY r.name, t.name
        """
    with _lic_pg_conn() as conn:
        cursor = conn.cursor()
        cursor.execute(query, (permit_name,))
        records = cursor.fetchall()

    contexts: OrderedDict[Tuple[str, str], Dict[str, str]] = OrderedDict()
    for region_id, region_name, theme_id, theme_name, permit_id, canonical_name in records:
        region_key = str(region_id)
        permit_key = str(permit_id)
        # Only the first row per (region, permit) is kept.
        if (region_key, permit_key) not in contexts:
            contexts[(region_key, permit_key)] = {
                "region_id": region_key,
                "region_name": str(region_name),
                "theme_id": str(theme_id),
                "theme_name": str(theme_name),
                "permit_id": permit_key,
                "permit_name": str(canonical_name),
            }
    return list(contexts.values())
def load_theme_payload(region_id: str, theme_id: str) -> Dict[str, object]:
    """Build the region / theme / permits bundle for one region-theme pair.

    Raises:
        ValueError: if the region and theme are not linked to each other.
    """
    lookup_sql = """
        SELECT r.id, r.name, t.id, t.name
        FROM regions r
        JOIN region_themes rt ON rt.region_id = r.id
        JOIN themes t ON t.id = rt.theme_id
        WHERE r.id = %s AND t.id = %s
        LIMIT 1
        """
    with _lic_pg_conn() as conn:
        cursor = conn.cursor()
        cursor.execute(lookup_sql, (region_id, theme_id))
        match = cursor.fetchone()
        if not match:
            raise ValueError("Region/theme combination not found")
        region_uuid, region_name, theme_uuid, theme_name = match

    region_info = {"id": str(region_uuid), "name": str(region_name)}
    theme_info = {"id": str(theme_uuid), "name": str(theme_name)}
    return {
        "region": region_info,
        "theme": theme_info,
        "permits": load_permits_and_risks(region_id, theme_id),
    }
def _get_checkpoints_dir() -> str:
    """Return the ``data/checkpoints`` directory, creating it when absent.

    The directory lives three levels above this source file.
    """
    root = __file__
    for _ in range(3):
        root = os.path.dirname(root)
    target = os.path.join(root, "data", "checkpoints")
    os.makedirs(target, exist_ok=True)
    return target
def _get_all_tables() -> List[str]:
    """Return the names of all base tables in the ``public`` schema, sorted."""
    query = """
        SELECT table_name
        FROM information_schema.tables
        WHERE table_schema = 'public'
        AND table_type = 'BASE TABLE'
        ORDER BY table_name
        """
    with _lic_pg_conn() as conn:
        cursor = conn.cursor()
        cursor.execute(query)
        records = cursor.fetchall()
    names: List[str] = []
    for record in records:
        names.append(record[0])
    return names
def _get_table_dependencies(conn: pg.Connection) -> Dict[str, List[str]]:
    """Build the foreign-key dependency graph of the ``public`` schema.

    Returns:
        ``{referenced table: [tables that reference it]}``, for example
        ``{'regions': ['region_themes', 'region_scopes'], 'themes': ['region_themes']}``.
        Each dependent table is listed at most once per referenced table.
    """
    query = """
        SELECT
        ccu.table_name AS referenced_table,
        tc.table_name AS dependent_table
        FROM information_schema.table_constraints tc
        JOIN information_schema.key_column_usage kcu
        ON tc.constraint_name = kcu.constraint_name
        AND tc.table_schema = kcu.table_schema
        JOIN information_schema.constraint_column_usage ccu
        ON tc.constraint_name = ccu.constraint_name
        AND tc.table_schema = ccu.table_schema
        WHERE tc.constraint_type = 'FOREIGN KEY'
        AND tc.table_schema = 'public'
        ORDER BY ccu.table_name, tc.table_name
        """
    cursor = conn.cursor()
    cursor.execute(query)

    graph = {}
    for referenced_table, dependent_table in cursor.fetchall():
        dependents = graph.setdefault(referenced_table, [])
        # Multi-column keys produce repeated rows; keep each dependent once.
        if dependent_table not in dependents:
            dependents.append(dependent_table)
    return graph
def _topological_sort_tables(all_tables: List[str], dependencies: Dict[str, List[str]]) -> List[str]:
"""
使用拓扑排序确定表恢复顺序。
先返回父表(无外键依赖),再返回子表(引用其他表)。
这确保恢复时不会违反外键约束。
"""
from collections import deque
# 计算每个表的入度(被多少表引用/依赖)
in_degree = {table: 0 for table in all_tables}
for parent_table, children in dependencies.items():
for child in children:
if child in all_tables:
in_degree[child] += 1
# 入度为0的表是父表可以先恢复
queue = deque([table for table in all_tables if in_degree[table] == 0])
result = []
while queue:
table = queue.popleft()
result.append(table)
# 减少依赖该表的子表的入度
for parent, children in dependencies.items():
if parent == table:
for child in children:
if child in all_tables and in_degree[child] > 0:
in_degree[child] -= 1
if in_degree[child] == 0:
queue.append(child)
# 如果有循环依赖或孤立节点,将剩余表添加到末尾
for table in all_tables:
if table not in result:
result.append(table)
return result
def _backup_table(conn: pg.Connection, table_name: str) -> Tuple[List[Dict[str, Any]], int]:
    """Backup a single table and return its data and row count.

    Args:
        conn: Open connection to read from.
        table_name: Name of the table to dump.

    Returns:
        ``(rows, row_count)`` where each row is a dict keyed by column name.
        Every column is present in every row, with NULL stored as ``None``,
        so the column set can be recovered from any row on restore.
    """
    logger.info(f"[CHECKPOINT] Backing up table: {table_name}")

    # Quote the identifier so mixed-case or reserved-word names still work.
    quoted_table = '"' + table_name.replace('"', '""') + '"'
    cur = conn.cursor()
    cur.execute(f"SELECT * FROM {quoted_table}")
    rows = cur.fetchall()
    colnames = [desc[0] for desc in cur.description]

    data = []
    for row in rows:
        record = {}
        for col, value in zip(colnames, row):
            if value is not None and hasattr(value, "isoformat"):
                # Date/time values are stored as their string form.
                value = str(value)
            # NOTE(review): other non-JSON-native values (e.g. bytes from a
            # bytea column) are passed through unchanged and must be handled
            # by whoever serializes the result — confirm the serializer copes.
            record[col] = value
        data.append(record)

    row_count = len(data)
    logger.info(f"[CHECKPOINT] Backup complete: {table_name} - {row_count} rows, {len(colnames)} columns")
    return data, row_count
def _restore_table(conn: pg.Connection, table_name: str, data: List[Dict[str, Any]], batch_size: int = 1000) -> int:
    """Restore a table from backup data. Returns number of rows restored.

    The table is truncated (``CASCADE``) and refilled inside one transaction
    on ``conn``, which is committed on success and rolled back on failure.

    Args:
        conn: Open connection; its autocommit mode is switched off.
        table_name: Table to overwrite.
        data: Row dicts keyed by column name.
        batch_size: Maximum rows per ``executemany`` call; values below 1
            fall back to 1000.

    Returns:
        Number of rows inserted; 0 when ``data`` is empty, in which case the
        table is left untouched.

    Raises:
        Exception: whatever the database driver raised, after rollback.
    """
    if not data:
        logger.info(f"[CHECKPOINT] Skipping empty table: {table_name}")
        return 0

    if batch_size is None or batch_size < 1:
        batch_size = 1000

    def _quote(identifier: str) -> str:
        """Quote an SQL identifier, doubling embedded double quotes."""
        return '"' + str(identifier).replace('"', '""') + '"'

    conn.autocommit = False
    try:
        cur = conn.cursor()
        logger.info(f"[CHECKPOINT] Truncating table: {table_name}")
        # Use TRUNCATE with CASCADE to handle foreign key dependencies
        # This will automatically remove dependent records
        cur.execute(f"TRUNCATE TABLE {_quote(table_name)} CASCADE")

        # Take the union of keys over all rows (first-seen order): a row may
        # lack a key, and relying on the first row alone would silently drop
        # that column for every row.
        columns = list(dict.fromkeys(col for row in data for col in row))
        placeholders = ", ".join(["%s"] * len(columns))
        column_sql = ", ".join(_quote(col) for col in columns)
        insert_sql = f"INSERT INTO {_quote(table_name)} ({column_sql}) VALUES ({placeholders})"
        logger.info(f"[CHECKPOINT] Restoring {len(data)} rows into {table_name} (columns: {', '.join(columns)})")

        total_rows = len(data)
        if total_rows <= batch_size:
            # Small table: a single bulk insert.
            values_list = [[row.get(col) for col in columns] for row in data]
            cur.executemany(insert_sql, values_list)
            logger.info(f"[CHECKPOINT] Bulk insert complete: {table_name} - {len(data)} rows inserted")
        else:
            # Large table: insert in batches and report progress.
            for start in range(0, total_rows, batch_size):
                batch_end = min(start + batch_size, total_rows)
                values_list = [[row.get(col) for col in columns] for row in data[start:batch_end]]
                cur.executemany(insert_sql, values_list)
                logger.info(f"[CHECKPOINT] Progress: {table_name} - {batch_end}/{total_rows} rows inserted")

        conn.commit()
        logger.info(f"[CHECKPOINT] Restore complete: {table_name} - {len(data)} rows successfully inserted")
        return len(data)
    except Exception as e:
        conn.rollback()
        logger.error(f"[CHECKPOINT] Restore FAILED: {table_name} - {str(e)}")
        raise
    finally:
        conn.autocommit = False
def create_checkpoint(description: str = "") -> Dict[str, Any]:
    """Create a checkpoint: dump every table to one JSON file.

    All tables are read on a single connection inside one transaction. The
    file is written to a temporary name and renamed into place only after
    serialization succeeds, so a failed dump never leaves a truncated
    ``.json`` behind in the checkpoints directory.

    Args:
        description: Free-text description stored with the checkpoint.

    Returns:
        Dict with ``checkpoint_id``, ``timestamp``, ``description``,
        ``total_rows`` and per-table ``table_counts``.

    Raises:
        Exception: database errors while reading tables, or
            ``TypeError`` / ``OSError`` while writing the file.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    checkpoint_id = f"checkpoint_{timestamp}"
    logger.info("=" * 80)
    logger.info(f"[CHECKPOINT] Starting checkpoint creation: {checkpoint_id}")
    if description:
        logger.info(f"[CHECKPOINT] Description: {description}")
    logger.info("=" * 80)

    tables = _get_all_tables()
    logger.info(f"[CHECKPOINT] Found {len(tables)} tables to backup: {', '.join(tables)}")

    checkpoint_data = {
        "checkpoint_id": checkpoint_id,
        "timestamp": timestamp,
        "description": description,
        "tables": {}
    }
    total_rows = 0
    table_counts = {}

    with _lic_pg_conn(autocommit=False) as conn:
        try:
            for i, table in enumerate(tables, 1):
                logger.info(f"[CHECKPOINT] [{i}/{len(tables)}] Processing table: {table}")
                data, row_count = _backup_table(conn, table)
                checkpoint_data["tables"][table] = data
                table_counts[table] = row_count
                total_rows += row_count
                logger.info(f"[CHECKPOINT] [{i}/{len(tables)}] Table {table} backed up: {row_count} rows")
            # Commit only after every table was read successfully.
            logger.info("[CHECKPOINT] All tables backed up successfully, committing transaction...")
            conn.commit()
            logger.info("[CHECKPOINT] Transaction committed")
        except Exception as e:
            # Any failure rolls the transaction back.
            logger.error(f"[CHECKPOINT] ERROR during backup, rolling back: {str(e)}")
            conn.rollback()
            raise
        finally:
            conn.autocommit = False

    checkpoint_data["table_counts"] = table_counts
    checkpoint_data["total_rows"] = total_rows

    checkpoints_dir = _get_checkpoints_dir()
    checkpoint_file = os.path.join(checkpoints_dir, f"{checkpoint_id}.json")
    logger.info(f"[CHECKPOINT] Saving checkpoint file: {checkpoint_file}")

    def json_serializer(obj):
        """Convert non-JSON serializable objects to strings."""
        if isinstance(obj, uuid.UUID) or hasattr(obj, 'isoformat'):
            return str(obj)
        raise TypeError(f"Object of type {type(obj)} is not JSON serializable")

    # The ".tmp" suffix keeps the partial file out of list_checkpoints(),
    # which only looks at names ending in ".json".
    temp_file = f"{checkpoint_file}.tmp"
    try:
        with open(temp_file, "w", encoding="utf-8") as f:
            json.dump(checkpoint_data, f, ensure_ascii=False, indent=2, default=json_serializer)
        os.replace(temp_file, checkpoint_file)
    except Exception:
        try:
            os.remove(temp_file)
        except OSError:
            # Best-effort cleanup; the original error is what matters.
            pass
        raise

    logger.info("=" * 80)
    logger.info(f"[CHECKPOINT] Checkpoint creation COMPLETED: {checkpoint_id}")
    logger.info(f"[CHECKPOINT] Total rows backed up: {total_rows}")
    logger.info(f"[CHECKPOINT] File: {checkpoint_file}")
    logger.info("=" * 80)

    return {
        "checkpoint_id": checkpoint_id,
        "timestamp": timestamp,
        "description": description,
        "total_rows": total_rows,
        "table_counts": table_counts
    }
def list_checkpoints() -> List[Dict[str, Any]]:
    """List all available checkpoints, newest first.

    Every ``*.json`` file in the checkpoints directory is read; files that
    cannot be read or lack the required fields are logged and skipped.

    Returns:
        One summary dict per checkpoint (id, timestamp, description, row
        totals, per-table counts and filename), sorted by timestamp
        descending.
    """
    checkpoints_dir = _get_checkpoints_dir()
    checkpoints: List[Dict[str, Any]] = []
    if not os.path.exists(checkpoints_dir):
        return checkpoints

    for filename in os.listdir(checkpoints_dir):
        if not filename.endswith(".json"):
            continue
        filepath = os.path.join(checkpoints_dir, filename)
        try:
            with open(filepath, "r", encoding="utf-8") as f:
                data = json.load(f)
            checkpoints.append({
                "checkpoint_id": data["checkpoint_id"],
                "timestamp": data["timestamp"],
                "description": data.get("description", ""),
                "total_rows": data.get("total_rows", 0),
                "table_counts": data.get("table_counts", {}),
                "filename": filename
            })
        except Exception as e:
            # One unreadable file must not hide the other checkpoints.
            logger.warning("[CHECKPOINT] Error reading checkpoint %s: %s", filename, e)

    return sorted(checkpoints, key=lambda x: x["timestamp"], reverse=True)
def restore_checkpoint(
    checkpoint_id: str,
    create_auto_backup: bool = True,
    batch_size: int = 1000
) -> Dict[str, Any]:
    """
    Restore the database from a checkpoint file inside a single transaction.

    WARNING: dangerous operation - existing data is overwritten.

    Args:
        checkpoint_id: ID of the checkpoint to restore.
        create_auto_backup: Whether to create a checkpoint of the current
            state before restoring. A failure to create it is logged and the
            restore continues without it.
        batch_size: Batch size handed to ``_restore_table`` (default 1000).

    Returns:
        On success: ``{"status": "success", "message": ..., "summary": ...}``.
        On failure during the restore transaction: the transaction is rolled
        back and ``{"status": "error", "message": ..., "summary": ...,
        "auto_backup_available": ...}`` is returned, plus a
        ``recovery_suggestion`` when an auto-backup was created.

    Raises:
        ValueError: If no checkpoint file exists for ``checkpoint_id``.
    """
    checkpoints_dir = _get_checkpoints_dir()
    checkpoint_file = os.path.join(checkpoints_dir, f"{checkpoint_id}.json")
    logger.warning("=" * 80)
    logger.warning(f"[CHECKPOINT] WARNING: Starting restore operation: {checkpoint_id}")
    logger.warning("[CHECKPOINT] This will OVERWRITE all existing data in the database!")
    logger.warning("=" * 80)
    if not os.path.exists(checkpoint_file):
        error_msg = f"Checkpoint {checkpoint_id} not found"
        logger.error(f"[CHECKPOINT] {error_msg}")
        raise ValueError(error_msg)
    with open(checkpoint_file, "r", encoding="utf-8") as f:
        checkpoint_data = json.load(f)

    # Back up the current state first (optional but recommended).
    auto_backup_info = None
    if create_auto_backup:
        logger.info("[CHECKPOINT] Creating auto-backup before restore... (THIS MAY TAKE TIME)")
        try:
            start_time = time.time()
            auto_backup_info = create_checkpoint(f"auto_backup_before_restore_{checkpoint_id}")
            elapsed = time.time() - start_time
            logger.info(f"[CHECKPOINT] Auto-backup created in {elapsed:.2f}s: {auto_backup_info['checkpoint_id']}")
            logger.info(f"[CHECKPOINT] Auto-backup contains {auto_backup_info['total_rows']} rows")
        except Exception as e:
            # Deliberate best effort: the restore proceeds without a backup.
            logger.error(f"[CHECKPOINT] WARNING: Failed to create auto-backup: {e}")
            logger.warning("[CHECKPOINT] Continuing without auto-backup...")
    else:
        logger.info("[CHECKPOINT] Auto-backup DISABLED by user")

    tables = checkpoint_data.get("tables", {})
    total_rows_in_checkpoint = sum(len(data) for data in tables.values())
    if auto_backup_info:
        backup_label = "Yes (" + auto_backup_info["checkpoint_id"] + ")"
    else:
        backup_label = "No"
    logger.info("=" * 80)
    logger.info("[CHECKPOINT] Checkpoint details:")
    logger.info(f"[CHECKPOINT] ID: {checkpoint_id}")
    logger.info(f"[CHECKPOINT] Tables: {len(tables)}")
    logger.info(f"[CHECKPOINT] Total rows: {total_rows_in_checkpoint}")
    logger.info(f"[CHECKPOINT] Auto-backup: {backup_label}")
    logger.info("=" * 80)
    restore_summary = {
        "checkpoint_id": checkpoint_id,
        "auto_backup": auto_backup_info.get("checkpoint_id") if auto_backup_info else None,
        "tables_restored": 0,
        "total_rows_restored": 0,
        "table_details": {},
        "errors": []
    }
    with _lic_pg_conn(autocommit=False) as conn:
        try:
            # 1. Build the table dependency graph.
            logger.info("[CHECKPOINT] Building table dependency graph...")
            dependencies = _get_table_dependencies(conn)
            all_tables = list(tables.keys())
            logger.info(f"[CHECKPOINT] Found {len(dependencies)} table dependencies")

            # 2. Topologically sort the tables to obtain the restore order.
            logger.info("[CHECKPOINT] Calculating restore order...")
            restore_order = _topological_sort_tables(all_tables, dependencies)
            logger.info(f"[CHECKPOINT] Restore order: {' -> '.join(restore_order)}")

            # 3. Lock every table to block concurrent writes.
            logger.info("[CHECKPOINT] Acquiring exclusive locks on all tables...")
            cur = conn.cursor()
            for table in restore_order:
                # NOTE(review): the table name comes from the checkpoint file
                # and is interpolated into SQL unquoted; checkpoint files must
                # therefore be trusted - confirm before exposing this path.
                cur.execute(f"LOCK TABLE {table} IN EXCLUSIVE MODE")
            logger.info("[CHECKPOINT] All tables locked exclusively")

            # 4. Restore the tables in dependency order.
            logger.info("=" * 80)
            logger.info("[CHECKPOINT] Starting restore process...")
            logger.info("=" * 80)
            restore_start_time = time.time()
            for i, table_name in enumerate(restore_order, 1):
                data = tables.get(table_name, [])
                table_start_time = time.time()
                logger.info(f"[CHECKPOINT] [{i}/{len(restore_order)}] Preparing to restore table: {table_name}")
                try:
                    rows_restored = _restore_table(conn, table_name, data, batch_size=batch_size)
                    table_elapsed = time.time() - table_start_time
                    restore_summary["tables_restored"] += 1
                    restore_summary["total_rows_restored"] += rows_restored
                    restore_summary["table_details"][table_name] = rows_restored
                    logger.info(f"[CHECKPOINT] [{i}/{len(restore_order)}] Table {table_name} restored: {rows_restored} rows in {table_elapsed:.2f}s")
                except Exception as e:
                    error_msg = f"Failed to restore table {table_name}: {str(e)}"
                    logger.error(f"[CHECKPOINT] ERROR: {error_msg}")
                    restore_summary["errors"].append(error_msg)
                    # Re-raise so the outer handler rolls the transaction back.
                    raise
            restore_elapsed = time.time() - restore_start_time
            total_tables = len(restore_order)
            logger.info(f"[CHECKPOINT] All {total_tables} tables restored in {restore_elapsed:.2f}s")

            # 5. Commit the transaction.
            logger.info("=" * 80)
            logger.info("[CHECKPOINT] All tables restored successfully, committing transaction...")
            conn.commit()
            logger.info("[CHECKPOINT] Transaction committed successfully")
            logger.info("=" * 80)
            logger.warning("=" * 80)
            logger.warning(f"[CHECKPOINT] RESTORE COMPLETED SUCCESSFULLY: {checkpoint_id}")
            logger.warning(f"[CHECKPOINT] Tables restored: {restore_summary['tables_restored']}/{len(restore_order)}")
            logger.warning(f"[CHECKPOINT] Total rows restored: {restore_summary['total_rows_restored']}")
            if auto_backup_info:
                logger.warning(f"[CHECKPOINT] Auto-backup available: {auto_backup_info['checkpoint_id']}")
            logger.warning("=" * 80)
            return {
                "status": "success",
                "message": f"Successfully restored {restore_summary['tables_restored']} tables, "
                           f"{restore_summary['total_rows_restored']} total rows",
                "summary": restore_summary
            }
        except Exception as e:
            # Roll the transaction back and report the failure to the caller
            # as a result dict rather than an exception.
            logger.error(f"[CHECKPOINT] ERROR during restore, rolling back: {str(e)}")
            conn.rollback()
            logger.warning("[CHECKPOINT] Transaction rolled back, changes reverted")
            error_info = {
                "status": "error",
                "message": f"Restore failed: {str(e)}",
                "summary": restore_summary,
                "auto_backup_available": bool(auto_backup_info)
            }
            # When an auto-backup exists, tell the caller how to use it.
            if auto_backup_info:
                error_info["recovery_suggestion"] = (
                    f"Use auto-backup to restore current state: "
                    f"restore_checkpoint('{auto_backup_info['checkpoint_id']}')"
                )
                logger.warning(f"[CHECKPOINT] Recovery suggestion: Use auto-backup '{auto_backup_info['checkpoint_id']}'")
            logger.error("=" * 80)
            logger.error(f"[CHECKPOINT] RESTORE FAILED: {str(e)}")
            if auto_backup_info:
                logger.error(f"[CHECKPOINT] You can restore from auto-backup: {auto_backup_info['checkpoint_id']}")
            logger.error("=" * 80)
            return error_info
def delete_checkpoint(checkpoint_id: str) -> bool:
    """Delete a checkpoint file.

    Args:
        checkpoint_id: ID of the checkpoint; the file removed is
            ``<checkpoints_dir>/<checkpoint_id>.json``.

    Returns:
        True when the file was removed, False when it did not exist.
    """
    checkpoint_file = os.path.join(_get_checkpoints_dir(), f"{checkpoint_id}.json")
    try:
        # Remove directly instead of checking existence first: a concurrent
        # delete between the check and the removal would otherwise surface as
        # an unexpected FileNotFoundError.
        os.remove(checkpoint_file)
    except (FileNotFoundError, NotADirectoryError):
        return False
    return True
# ---------------------------------------------------------------------------
# Permit risk snapshot helpers (region_permit_risk_vw)
# ---------------------------------------------------------------------------
def _convert_snapshot_value(value: Any) -> Any:
    """Convert database values into JSON-serialisable primitives.

    Conversions:
        * ``datetime`` / ``date`` -> ISO-8601 string
        * ``Decimal`` -> ``float``
        * ``uuid.UUID`` -> ``str``
        * ``memoryview`` / ``bytes`` / ``bytearray`` -> UTF-8 text, with
          undecodable bytes replaced
        * ``list`` / ``tuple`` -> list with every element converted
        * ``dict`` -> new dict with every value converted (keys untouched)

    Any other value is returned unchanged.
    """
    if isinstance(value, (datetime, date)):
        return value.isoformat()
    if isinstance(value, Decimal):
        return float(value)
    if isinstance(value, uuid.UUID):
        return str(value)
    if isinstance(value, memoryview):
        return bytes(value).decode("utf-8", errors="replace")
    if isinstance(value, (bytes, bytearray)):
        return value.decode("utf-8", errors="replace")
    if isinstance(value, (list, tuple)):
        return [_convert_snapshot_value(item) for item in value]
    if isinstance(value, dict):
        # Recurse into mappings as well so nested values (e.g. a Decimal or
        # UUID inside a nested object) do not break json.dumps later on.
        return {key: _convert_snapshot_value(item) for key, item in value.items()}
    return value
def _normalize_snapshot_payload(record: Dict[str, Any]) -> Dict[str, Any]:
    """Build a JSON-safe copy of a permit risk view record.

    Every value is passed through ``_convert_snapshot_value``; keys are kept
    as they are and the input mapping is not modified.
    """
    json_safe: Dict[str, Any] = {}
    for field_name, raw_value in record.items():
        json_safe[field_name] = _convert_snapshot_value(raw_value)
    return json_safe
def _fetch_permit_risk_row(
    conn: pg.Connection, region_id: str, permit_id: str, risk_id: str
) -> Dict[str, Any]:
    """Load one region/permit/risk row from ``region_permit_risk_vw``.

    When ``_permit_sources_available(conn)`` is truthy the view row is
    left-joined with ``permit_sources`` so the result also carries the
    ``permit_source_*`` columns; otherwise only the view columns are selected.

    Returns:
        A dict mapping the selected column names to the row values.

    Raises:
        ValueError: If the view has no row for the given identifiers.
    """
    query_with_sources = """
        SELECT
            vw.region_id,
            vw.region_name,
            vw.permit_id,
            vw.permit_name,
            vw.risk_id,
            vw.risk_content,
            vw.legal_basis,
            vw.document_no,
            vw.summary,
            vw.theme_ids,
            vw.theme_names,
            vw.scope_ids,
            vw.scope_descriptions,
            vw.subitem_ids,
            vw.permit_status,
            vw.subitem_summary,
            vw.responsible_contact,
            vw.jurisdiction_scope,
            vw.permit_detail_updated_at,
            vw.permit_risk_key,
            ps.source_type AS permit_source_type,
            ps.source_name AS permit_source_name,
            ps.source_detail AS permit_source_detail,
            ps.updated_at AS permit_source_updated_at
        FROM region_permit_risk_vw vw
        LEFT JOIN permit_sources ps
            ON ps.region_id = vw.region_id
            AND ps.permit_id = vw.permit_id
        WHERE vw.region_id = %s AND vw.permit_id = %s AND vw.risk_id = %s
        LIMIT 1
    """
    query_view_only = """
        SELECT
            region_id,
            region_name,
            permit_id,
            permit_name,
            risk_id,
            risk_content,
            legal_basis,
            document_no,
            summary,
            theme_ids,
            theme_names,
            scope_ids,
            scope_descriptions,
            subitem_ids,
            permit_status,
            subitem_summary,
            responsible_contact,
            jurisdiction_scope,
            permit_detail_updated_at,
            permit_risk_key
        FROM region_permit_risk_vw
        WHERE region_id = %s AND permit_id = %s AND risk_id = %s
        LIMIT 1
    """
    query = query_with_sources if _permit_sources_available(conn) else query_view_only
    cursor = conn.cursor()
    cursor.execute(query, (region_id, permit_id, risk_id))
    record = cursor.fetchone()
    if not record:
        raise ValueError("Permit risk combination not found in consolidation view")
    # cursor.description holds one entry per selected column; item 0 is its name.
    column_names = [column[0] for column in cursor.description]
    return dict(zip(column_names, record))
def _insert_permit_risk_snapshot(
    conn: pg.Connection,
    payload: Dict[str, Any],
    *,
    edited_by: Optional[str],
    change_summary: Optional[str],
    batch_id: str,
) -> Dict[str, Any]:
    """Write one row to ``permit_risk_snapshots`` and describe it.

    The new version is one above the highest stored version for the payload's
    ``permit_risk_key`` (1 when none exists). ``batch_id`` is stored inside
    the JSON payload under ``snapshot_batch_id``. Nothing is committed here.

    Returns:
        Metadata of the inserted snapshot, including the stored payload.
    """
    risk_key = str(payload["permit_risk_key"])
    cursor = conn.cursor()
    # Look up the latest version for this key, locking that row.
    cursor.execute(
        """
        SELECT version
        FROM permit_risk_snapshots
        WHERE permit_risk_key = %s
        ORDER BY version DESC
        LIMIT 1
        FOR UPDATE
        """,
        (risk_key,),
    )
    latest = cursor.fetchone()
    if latest:
        version_number = int(latest[0]) + 1
    else:
        version_number = 1

    stored_payload = {**payload, "snapshot_batch_id": batch_id}
    serialized = json.dumps(stored_payload, ensure_ascii=False)
    insert_args = (
        payload["region_id"],
        payload["permit_id"],
        payload["risk_id"],
        risk_key,
        version_number,
        serialized,
        edited_by,
        change_summary,
    )
    cursor.execute(
        """
        INSERT INTO permit_risk_snapshots (
            region_id,
            permit_id,
            risk_id,
            permit_risk_key,
            version,
            payload,
            edited_by,
            change_summary
        )
        VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
        RETURNING snapshot_id, created_at
        """,
        insert_args,
    )
    new_id, inserted_at = cursor.fetchone()
    return {
        "snapshot_id": str(new_id),
        "region_id": payload["region_id"],
        "permit_id": payload["permit_id"],
        "risk_id": payload["risk_id"],
        "permit_risk_key": risk_key,
        "version": version_number,
        "created_at": _convert_snapshot_value(inserted_at),
        "edited_by": edited_by,
        "change_summary": change_summary or "",
        "snapshot_batch_id": batch_id,
        "payload": stored_payload,
    }
def _create_snapshot_with_connection(
    conn: pg.Connection,
    region_id: str,
    permit_id: str,
    risk_id: str,
    *,
    edited_by: Optional[str],
    change_summary: Optional[str],
    batch_id: Optional[str] = None,
) -> Dict[str, Any]:
    """Create a snapshot using an existing DB connection (no commit).

    Fetches the current view row, normalises it to JSON-safe values, inserts
    it as a new snapshot version and logs a short description of it.

    Args:
        conn: Open connection; the caller owns commit/rollback.
        region_id, permit_id, risk_id: Identify the view row to capture.
        edited_by: Stored with the snapshot.
        change_summary: Stored with the snapshot.
        batch_id: Batch identifier to store in the payload; a fresh UUID4
            string is generated when omitted or empty.

    Returns:
        The metadata dict produced by ``_insert_permit_risk_snapshot``.
    """
    view_record = _fetch_permit_risk_row(conn, region_id, permit_id, risk_id)
    payload = _normalize_snapshot_payload(view_record)
    snapshot_batch_id = batch_id or str(uuid.uuid4())
    metadata = _insert_permit_risk_snapshot(
        conn,
        payload,
        edited_by=edited_by,
        change_summary=change_summary,
        batch_id=snapshot_batch_id,
    )

    theme_names = payload.get("theme_names") or []
    scopes = payload.get("scope_ids") or []
    subitems = payload.get("subitem_ids") or []
    permit_status = payload.get("permit_status") or ""

    # Collapse whitespace and cap the risk text so the log line stays short.
    preview_text = str(payload.get("risk_content") or "").strip()
    preview_flat = re.sub(r"\s+", " ", preview_text)
    if len(preview_flat) > 120:
        preview_flat = f"{preview_flat[:117]}..."

    logger.info(
        "[CHECKPOINT] Snapshot created: %s version %s",
        metadata["permit_risk_key"],
        metadata["version"],
    )
    logger.info(
        "[CHECKPOINT] Snapshot context: region=%s(%s) permit=%s(%s) risk=%s | themes=%s | scopes=%d | subitems=%d | status=%s",
        payload.get("region_id"),
        payload.get("region_name"),
        payload.get("permit_id"),
        payload.get("permit_name"),
        payload.get("risk_id"),
        # Separate the names; joining on an empty string ran them together.
        ", ".join(str(name) for name in theme_names),
        len(scopes),
        len(subitems),
        permit_status,
    )
    source_name = payload.get("permit_source_name")
    if source_name:
        logger.info(
            "[CHECKPOINT] Snapshot permit source: %s",
            source_name,
        )
    if preview_flat:
        logger.info(
            "[CHECKPOINT] Snapshot risk preview: %s",
            preview_flat,
        )
    return metadata
def create_permit_risk_snapshot(
    region_id: str,
    permit_id: str,
    risk_id: str,
    *,
    edited_by: Optional[str] = None,
    change_summary: Optional[str] = None,
    batch_id: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Store the current region/permit/risk record as a new versioned snapshot.

    Opens its own connection, delegates to ``_create_snapshot_with_connection``
    and commits. Any exception triggers a rollback and is re-raised.

    Returns:
        Metadata of the created snapshot, including its version number.
    """
    snapshot_options = {
        "edited_by": edited_by,
        "change_summary": change_summary,
        "batch_id": batch_id,
    }
    with _lic_pg_conn(autocommit=False) as conn:
        try:
            created = _create_snapshot_with_connection(
                conn, region_id, permit_id, risk_id, **snapshot_options
            )
            conn.commit()
        except Exception:
            conn.rollback()
            raise
        return created
def list_permit_risk_snapshots(
    region_id: Optional[str] = None,
    permit_id: Optional[str] = None,
    risk_id: Optional[str] = None,
    *,
    permit_risk_key: Optional[str] = None,
    limit: int = 20,
    offset: int = 0,
) -> List[Dict[str, Any]]:
    """
    Return snapshot metadata ordered by version, highest first.

    ``permit_risk_key`` takes precedence: when it is given it is the only
    filter. Otherwise every non-empty value among ``region_id``,
    ``permit_id`` and ``risk_id`` becomes an equality filter.

    Raises:
        ValueError: If no identifier at all is provided.
    """
    if permit_risk_key:
        criteria = [("permit_risk_key", permit_risk_key)]
    else:
        candidates = (
            ("region_id", region_id),
            ("permit_id", permit_id),
            ("risk_id", risk_id),
        )
        criteria = [(column, wanted) for column, wanted in candidates if wanted]
    if not criteria:
        raise ValueError("At least one identifier must be provided to list snapshots")

    filters_clause = " AND ".join(f"{column} = %s" for column, _ in criteria)
    query_args = [wanted for _, wanted in criteria]
    query_args += [limit, offset]
    sql = f"""
        SELECT snapshot_id, version, permit_risk_key, edited_by, change_summary, created_at
        FROM permit_risk_snapshots
        WHERE {filters_clause}
        ORDER BY version DESC
        LIMIT %s OFFSET %s
    """
    with _lic_pg_conn() as conn:
        cursor = conn.cursor()
        cursor.execute(sql, tuple(query_args))
        fetched = cursor.fetchall()
    return [
        {
            "snapshot_id": str(snap_id),
            "permit_risk_key": risk_key,
            "version": int(version_no),
            "created_at": _convert_snapshot_value(created),
            "edited_by": author,
            "change_summary": note or "",
        }
        for snap_id, version_no, risk_key, author, note, created in fetched
    ]
def get_permit_risk_snapshot(snapshot_id: str) -> Optional[Dict[str, Any]]:
    """Fetch a snapshot, including its decoded payload, by its identifier.

    Returns:
        A dict with the snapshot columns (ids as strings, ``version`` as int,
        ``created_at`` converted to a JSON-safe value, ``payload`` as the
        decoded JSON object), or None when no snapshot has this id.
    """
    sql = """
        SELECT
            snapshot_id,
            region_id,
            permit_id,
            risk_id,
            permit_risk_key,
            version,
            payload,
            edited_by,
            change_summary,
            created_at
        FROM permit_risk_snapshots
        WHERE snapshot_id = %s
    """
    with _lic_pg_conn() as conn:
        cur = conn.cursor()
        cur.execute(sql, (snapshot_id,))
        row = cur.fetchone()
    if not row:
        return None
    (
        snap_id,
        region_id,
        permit_id,
        risk_id,
        permit_risk_key,
        version,
        payload,
        edited_by,
        change_summary,
        created_at,
    ) = row
    if isinstance(payload, dict):
        # The driver already decoded the JSON column.
        payload_obj = payload
    else:
        if isinstance(payload, memoryview):
            # json.loads accepts str, bytes and bytearray but not memoryview,
            # so materialise the buffer first.
            payload = bytes(payload)
        payload_obj = json.loads(payload)
    return {
        "snapshot_id": str(snap_id),
        "region_id": str(region_id),
        "permit_id": str(permit_id),
        "risk_id": str(risk_id),
        "permit_risk_key": permit_risk_key,
        "version": int(version),
        "created_at": _convert_snapshot_value(created_at),
        "edited_by": edited_by,
        "change_summary": change_summary or "",
        "payload": payload_obj,
    }
def list_permit_risk_snapshot_summaries(
    *,
    region_id: Optional[str] = None,
    permit_id: Optional[str] = None,
    edited_by: Optional[str] = None,
    limit: int = 20,
    offset: int = 0,
) -> List[Dict[str, Any]]:
    """
    List snapshot summaries, newest first, for checkpoint history views.

    Besides the snapshot columns, selected text fields are projected out of
    the stored JSON payload (the flattened view row) for display. Every
    non-empty filter argument becomes an equality condition; with no filters
    all snapshots are eligible.
    """
    criteria = [
        (column, wanted)
        for column, wanted in (
            ("region_id", region_id),
            ("permit_id", permit_id),
            ("edited_by", edited_by),
        )
        if wanted
    ]
    if criteria:
        where_clause = f"WHERE {' AND '.join(f'{column} = %s' for column, _ in criteria)}"
    else:
        where_clause = ""
    query_args = [wanted for _, wanted in criteria] + [limit, offset]
    sql = f"""
        SELECT
            snapshot_id,
            region_id,
            permit_id,
            risk_id,
            permit_risk_key,
            version,
            edited_by,
            change_summary,
            created_at,
            payload ->> 'region_name' AS region_name,
            payload ->> 'permit_name' AS permit_name,
            payload ->> 'risk_content' AS risk_content,
            payload ->> 'legal_basis' AS legal_basis,
            payload ->> 'document_no' AS document_no,
            payload ->> 'permit_status' AS permit_status,
            payload ->> 'snapshot_batch_id' AS snapshot_batch_id,
            payload ->> 'permit_source_name' AS permit_source_name,
            payload ->> 'permit_source_type' AS permit_source_type
        FROM permit_risk_snapshots
        {where_clause}
        ORDER BY created_at DESC
        LIMIT %s OFFSET %s
    """
    with _lic_pg_conn() as conn:
        cursor = conn.cursor()
        cursor.execute(sql, tuple(query_args))
        fetched = cursor.fetchall()

    # Payload projections, in the order they are selected above; a missing
    # (NULL) projection is reported as an empty string.
    projected_fields = (
        "region_name",
        "permit_name",
        "risk_content",
        "legal_basis",
        "document_no",
        "permit_status",
        "snapshot_batch_id",
        "permit_source_name",
        "permit_source_type",
    )
    summaries: List[Dict[str, Any]] = []
    for record in fetched:
        (
            snap_id,
            region_uuid,
            permit_uuid,
            risk_uuid,
            risk_key,
            version_no,
            author,
            note,
            created,
            *projected_values,
        ) = record
        entry: Dict[str, Any] = {
            "snapshot_id": str(snap_id),
            "region_id": str(region_uuid),
            "permit_id": str(permit_uuid),
            "risk_id": str(risk_uuid),
            "permit_risk_key": risk_key,
            "version": int(version_no),
            "created_at": _convert_snapshot_value(created),
            "edited_by": author,
            "change_summary": note or "",
        }
        for field_name, text in zip(projected_fields, projected_values):
            entry[field_name] = text or ""
        summaries.append(entry)
    return summaries
def count_permit_risk_snapshots(
    *,
    region_id: Optional[str] = None,
    permit_id: Optional[str] = None,
    edited_by: Optional[str] = None,
) -> int:
    """Count the snapshots that match the optional equality filters.

    Empty or missing filter arguments are ignored; with no filters every
    snapshot is counted.
    """
    criteria = [
        (column, wanted)
        for column, wanted in (
            ("region_id", region_id),
            ("permit_id", permit_id),
            ("edited_by", edited_by),
        )
        if wanted
    ]
    if criteria:
        where_clause = f"WHERE {' AND '.join(f'{column} = %s' for column, _ in criteria)}"
    else:
        where_clause = ""
    sql = f"SELECT COUNT(*) FROM permit_risk_snapshots {where_clause}"
    with _lic_pg_conn() as conn:
        cursor = conn.cursor()
        cursor.execute(sql, tuple(wanted for _, wanted in criteria))
        result = cursor.fetchone()
    if not result:
        return 0
    return int(result[0])
def update_permit_risk_record(
    region_id: str,
    permit_id: str,
    risk_id: str,
    *,
    risk_content: Any = _UNSET,
    legal_basis: Any = _UNSET,
    document_no: Any = _UNSET,
    summary: Any = _UNSET,
    permit_status: Any = _UNSET,
    subitem_summary: Any = _UNSET,
    responsible_contact: Any = _UNSET,
    jurisdiction_scope: Any = _UNSET,
    edited_by: Optional[str] = None,
    change_summary: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Snapshot a permit risk record, then apply the requested field changes.

    A field is written only when its argument is not the ``_UNSET`` sentinel,
    so ``None`` can be passed to store NULL. ``risk_content``, ``legal_basis``,
    ``document_no`` and ``summary`` are written to ``risks``; the remaining
    fields are upserted into ``region_permit_details``. Snapshot and updates
    run in one transaction, rolled back on any exception.

    Returns:
        ``snapshot`` (pre-change snapshot metadata), ``current`` (the view row
        re-read after the changes) and the ``risk_updated`` /
        ``details_updated`` flags.

    Raises:
        ValueError: If no field is provided, or the region/permit/risk
            combination does not exist in ``region_permit_risks``.
    """
    requested_risk_fields = (
        ("risk_content", risk_content),
        ("legal_basis", legal_basis),
        ("document_no", document_no),
        ("summary", summary),
    )
    requested_detail_fields = (
        ("permit_status", permit_status),
        ("subitem_summary", subitem_summary),
        ("responsible_contact", responsible_contact),
        ("jurisdiction_scope", jurisdiction_scope),
    )
    # Keep only the fields the caller actually supplied (insertion-ordered).
    risk_changes = {
        column: value for column, value in requested_risk_fields if value is not _UNSET
    }
    detail_changes = {
        column: value for column, value in requested_detail_fields if value is not _UNSET
    }
    if not risk_changes and not detail_changes:
        raise ValueError("No fields provided to update.")

    with _lic_pg_conn(autocommit=False) as conn:
        try:
            cursor = conn.cursor()
            # Lock the mapping row so the snapshot and the update see the
            # same state.
            cursor.execute(
                """
                SELECT 1
                FROM region_permit_risks
                WHERE region_id = %s AND permit_id = %s AND risk_id = %s
                FOR UPDATE
                """,
                (region_id, permit_id, risk_id),
            )
            if cursor.fetchone() is None:
                raise ValueError("Permit risk combination not found.")

            snapshot_meta = _create_snapshot_with_connection(
                conn,
                region_id,
                permit_id,
                risk_id,
                edited_by=edited_by,
                change_summary=change_summary,
            )

            if risk_changes:
                set_clause = ", ".join(f"{column} = %s" for column in risk_changes)
                cursor.execute(
                    f"UPDATE risks SET {set_clause} WHERE id = %s",
                    (*risk_changes.values(), risk_id),
                )

            if detail_changes:
                detail_columns = list(detail_changes)
                insert_cols = ", ".join(["region_id", "permit_id"] + detail_columns)
                insert_placeholders = ", ".join(["%s"] * (2 + len(detail_columns)))
                update_assignments = ", ".join(
                    f"{column} = EXCLUDED.{column}" for column in detail_columns
                )
                sql = f"""
                    INSERT INTO region_permit_details ({insert_cols})
                    VALUES ({insert_placeholders})
                    ON CONFLICT (region_id, permit_id)
                    DO UPDATE SET
                        {update_assignments},
                        updated_at = now()
                """
                cursor.execute(
                    sql,
                    (region_id, permit_id, *detail_changes.values()),
                )

            # Re-read the view inside the transaction to report the new state.
            refreshed = _normalize_snapshot_payload(
                _fetch_permit_risk_row(conn, region_id, permit_id, risk_id)
            )
            conn.commit()
            logger.info(
                "[CHECKPOINT] Permit risk updated: %s version %s -> new snapshot ready",
                snapshot_meta["permit_risk_key"],
                snapshot_meta["version"],
            )
            return {
                "snapshot": snapshot_meta,
                "current": refreshed,
                "risk_updated": bool(risk_changes),
                "details_updated": bool(detail_changes),
            }
        except Exception:
            conn.rollback()
            raise
def delete_region_permit(
    region_id: str,
    theme_id: str,
    permit_id: str,
    *,
    edited_by: Optional[str] = None,
    change_summary: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Delete a permit from a region, snapshotting every linked risk first.

    All snapshots taken here share one batch id. After the snapshots the
    region/permit rows are removed from ``region_permit_risks``,
    ``region_permit_subitems``, ``region_permit_scopes``,
    ``region_permit_details``, ``permit_sources`` (when available) and
    ``region_theme_permits``. When the theme has no permits left in the
    region, its ``region_themes`` row is removed as well. Everything runs in
    one transaction that is rolled back on any exception.

    Args:
        region_id, theme_id, permit_id: The region/theme/permit link to delete.
        edited_by: Stored with each snapshot.
        change_summary: Summary stored with each snapshot; a default naming
            the permit and region is used when it is empty.

    Returns:
        A deletion summary: names and ids, the snapshotted ``risk_ids``, the
        snapshot list and batch id, per-table ``deleted_rows``, whether the
        theme was detached and how many permits the theme retains.

    Raises:
        ValueError: If the region/theme/permit combination does not exist.
    """
    with _lic_pg_conn(autocommit=False) as conn:
        cur = conn.cursor()
        try:
            # Verify the link exists (and lock it) while fetching display names.
            cur.execute(
                """
                SELECT r.name, t.name, p.name
                FROM region_theme_permits rtp
                JOIN regions r ON r.id = rtp.region_id
                JOIN themes t ON t.id = rtp.theme_id
                JOIN permits p ON p.id = rtp.permit_id
                WHERE rtp.region_id = %s
                    AND rtp.theme_id = %s
                    AND rtp.permit_id = %s
                FOR UPDATE
                """,
                (region_id, theme_id, permit_id),
            )
            row = cur.fetchone()
            if not row:
                raise ValueError("地区-主题-许可组合不存在,无法删除")
            region_name, theme_name, permit_name = (str(val) for val in row)

            cur.execute(
                """
                SELECT risk_id
                FROM region_permit_risks
                WHERE region_id = %s
                    AND permit_id = %s
                ORDER BY risk_id
                FOR UPDATE
                """,
                (region_id, permit_id),
            )
            risk_ids = [str(risk_id) for (risk_id,) in cur.fetchall()]

            summary_base = (change_summary or "").strip()
            if not summary_base:
                # Default summary; the parenthesis is now closed.
                summary_base = f"删除许可 {permit_name}(地区:{region_name})"
            snapshot_batch_id = str(uuid.uuid4())

            # Snapshot every linked risk before anything is deleted.
            snapshots: List[Dict[str, Any]] = []
            for idx, risk_id in enumerate(risk_ids, start=1):
                detail_summary = summary_base
                if len(risk_ids) > 1:
                    # Keep the counter and the risk id visually separate.
                    detail_summary = (
                        f"{summary_base} - 风险 {idx}/{len(risk_ids)} (ID: {risk_id})"
                    )
                snapshot_meta = _create_snapshot_with_connection(
                    conn,
                    region_id,
                    permit_id,
                    risk_id,
                    edited_by=edited_by,
                    change_summary=detail_summary,
                    batch_id=snapshot_batch_id,
                )
                snapshots.append(
                    {
                        "snapshot_id": snapshot_meta["snapshot_id"],
                        "permit_risk_key": snapshot_meta["permit_risk_key"],
                        "version": snapshot_meta["version"],
                        "risk_id": snapshot_meta["risk_id"],
                        "created_at": snapshot_meta["created_at"],
                        "snapshot_batch_id": snapshot_meta.get("snapshot_batch_id"),
                        "change_summary": snapshot_meta.get("change_summary", ""),
                    }
                )
            total_snapshots = len(snapshots)
            if total_snapshots:
                logger.info(
                    "[PERMIT-DELETE] Captured %d snapshots before deleting permit %s (%s) in region %s (%s)",
                    total_snapshots,
                    permit_id,
                    permit_name,
                    region_id,
                    region_name,
                )
            else:
                logger.info(
                    "[PERMIT-DELETE] No risk snapshots required for permit %s (%s) in region %s (%s)",
                    permit_id,
                    permit_name,
                    region_id,
                    region_name,
                )

            # Remove dependent rows first, then the theme/permit link itself.
            delete_counts: Dict[str, int] = {}
            cur.execute(
                """
                DELETE FROM region_permit_risks
                WHERE region_id = %s AND permit_id = %s
                """,
                (region_id, permit_id),
            )
            delete_counts["region_permit_risks"] = int(cur.rowcount or 0)
            cur.execute(
                """
                DELETE FROM region_permit_subitems
                WHERE region_id = %s AND permit_id = %s
                """,
                (region_id, permit_id),
            )
            delete_counts["region_permit_subitems"] = int(cur.rowcount or 0)
            cur.execute(
                """
                DELETE FROM region_permit_scopes
                WHERE region_id = %s AND permit_id = %s
                """,
                (region_id, permit_id),
            )
            delete_counts["region_permit_scopes"] = int(cur.rowcount or 0)
            cur.execute(
                """
                DELETE FROM region_permit_details
                WHERE region_id = %s AND permit_id = %s
                """,
                (region_id, permit_id),
            )
            delete_counts["region_permit_details"] = int(cur.rowcount or 0)
            if _permit_sources_available(conn):
                cur.execute(
                    """
                    DELETE FROM permit_sources
                    WHERE region_id = %s AND permit_id = %s
                    """,
                    (region_id, permit_id),
                )
                delete_counts["permit_sources"] = int(cur.rowcount or 0)
            cur.execute(
                """
                DELETE FROM region_theme_permits
                WHERE region_id = %s AND theme_id = %s AND permit_id = %s
                """,
                (region_id, theme_id, permit_id),
            )
            delete_counts["region_theme_permits"] = int(cur.rowcount or 0)

            # Detach the theme from the region once its last permit is gone.
            cur.execute(
                """
                SELECT COUNT(*)
                FROM region_theme_permits
                WHERE region_id = %s AND theme_id = %s
                """,
                (region_id, theme_id),
            )
            remaining_theme_permits = int(cur.fetchone()[0] or 0)
            theme_detached = False
            if remaining_theme_permits == 0:
                cur.execute(
                    """
                    DELETE FROM region_themes
                    WHERE region_id = %s AND theme_id = %s
                    """,
                    (region_id, theme_id),
                )
                theme_detached = cur.rowcount > 0
                if theme_detached:
                    logger.info(
                        "[PERMIT-DELETE] Detached theme %s (%s) from region %s (%s) because no permits remain",
                        theme_id,
                        theme_name,
                        region_id,
                        region_name,
                    )
            else:
                logger.info(
                    "[PERMIT-DELETE] Theme %s (%s) retains %d permit(s) in region %s (%s); theme linkage preserved",
                    theme_id,
                    theme_name,
                    remaining_theme_permits,
                    region_id,
                    region_name,
                )

            conn.commit()
            logger.info(
                "[PERMIT-DELETE] Completed deletion for permit %s (%s) in region %s (%s): snapshots=%d, deleted_rows=%s",
                permit_id,
                permit_name,
                region_id,
                region_name,
                total_snapshots,
                delete_counts,
            )
            return {
                "region_id": str(region_id),
                "region_name": region_name,
                "theme_id": str(theme_id),
                "theme_name": theme_name,
                "permit_id": str(permit_id),
                "permit_name": permit_name,
                "risk_ids": risk_ids,
                "snapshot_count": total_snapshots,
                "snapshots": snapshots,
                "snapshot_batch_id": snapshot_batch_id if total_snapshots else "",
                "deleted_rows": delete_counts,
                "theme_detached": theme_detached,
                "remaining_theme_permits": remaining_theme_permits,
            }
        except Exception:
            conn.rollback()
            raise
def restore_permit_risk_snapshot_batch(
    snapshot_batch_id: str,
    *,
    edited_by: Optional[str] = None,
    change_summary: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Re-create region/permit/risk relations from a snapshot batch.

    Snapshots are looked up by the ``snapshot_batch_id`` stored in their
    payload; when none match, the argument is tried as a single
    ``snapshot_id``. Region, permit, detail fields and permit source are
    taken from the first snapshot; theme, scope and subitem ids are the
    union over all snapshots. Link rows are inserted if missing, details and
    risks are upserted, and the permit source is upserted or - when the
    snapshot has no source name - deleted.

    ``edited_by`` and ``change_summary`` are only echoed in the result.

    Raises:
        ValueError: If the id is empty or no snapshot matches it.
    """
    if not snapshot_batch_id:
        raise ValueError("快照批次 ID 不能为空")
    with _lic_pg_conn(autocommit=False) as conn:
        cur = conn.cursor()
        cur.execute(
            """
            SELECT snapshot_id, payload, created_at, edited_by, change_summary
            FROM permit_risk_snapshots
            WHERE payload ->> 'snapshot_batch_id' = %s
            ORDER BY created_at ASC, snapshot_id ASC
            """,
            (snapshot_batch_id,),
        )
        found_rows = cur.fetchall()
        resolved_batch_id = snapshot_batch_id
        if not found_rows:
            # Fall back to treating the argument as a single snapshot id.
            cur.execute(
                """
                SELECT snapshot_id, payload, created_at, edited_by, change_summary
                FROM permit_risk_snapshots
                WHERE snapshot_id::text = %s
                """,
                (snapshot_batch_id,),
            )
            found_rows = cur.fetchall()
        if not found_rows:
            raise ValueError("未找到对应的快照记录")

        snapshots: List[Dict[str, Any]] = []
        for snap_id, raw_payload, created_at, snap_editor, snap_summary in found_rows:
            if isinstance(raw_payload, dict):
                decoded = raw_payload
            else:
                if isinstance(raw_payload, (bytes, bytearray, memoryview)):
                    raw_payload = bytes(raw_payload).decode("utf-8")
                decoded = json.loads(raw_payload)
            stored_batch = decoded.get("snapshot_batch_id")
            if stored_batch:
                resolved_batch_id = stored_batch
            snapshots.append(
                {
                    "snapshot_id": str(snap_id),
                    "payload": decoded,
                    "created_at": _convert_snapshot_value(created_at),
                    "edited_by": snap_editor,
                    "change_summary": snap_summary or "",
                }
            )

        # The first snapshot supplies the region/permit level attributes.
        first_payload = snapshots[0]["payload"]
        region_id = str(first_payload["region_id"])
        permit_id = str(first_payload["permit_id"])
        region_name = str(first_payload.get("region_name") or "")
        permit_name = str(first_payload.get("permit_name") or "")

        def _union_of(field_name: str) -> Set[str]:
            # Non-empty ids of one payload list field, across all snapshots.
            return {
                str(item)
                for snap in snapshots
                for item in (snap["payload"].get(field_name) or [])
                if item
            }

        theme_ids: Set[str] = _union_of("theme_ids")
        scope_ids: Set[str] = _union_of("scope_ids")
        subitem_ids: Set[str] = _union_of("subitem_ids")
        detail_fields = {
            name: first_payload.get(name)
            for name in (
                "permit_status",
                "subitem_summary",
                "responsible_contact",
                "jurisdiction_scope",
            )
        }
        source_name = _clean_text(first_payload.get("permit_source_name"))
        source_type = _clean_text(first_payload.get("permit_source_type")) or "snapshot"
        source_detail = first_payload.get("permit_source_detail")
        insert_counts = {
            "region_themes": 0,
            "region_theme_permits": 0,
            "region_permit_details": 0,
            "region_permit_scopes": 0,
            "region_permit_subitems": 0,
            "region_permit_risks": 0,
            "risks_upserted": 0,
            "permit_sources_synced": 0,
        }
        restored_risk_ids: Set[str] = set()
        try:
            # Theme links: region<->theme first, then region/theme<->permit.
            for theme_id in sorted(theme_ids):
                cur.execute(
                    """
                    INSERT INTO region_themes (region_id, theme_id)
                    VALUES (%s, %s)
                    ON CONFLICT DO NOTHING
                    """,
                    (region_id, theme_id),
                )
                insert_counts["region_themes"] += cur.rowcount or 0
                cur.execute(
                    """
                    INSERT INTO region_theme_permits (region_id, theme_id, permit_id)
                    VALUES (%s, %s, %s)
                    ON CONFLICT DO NOTHING
                    """,
                    (region_id, theme_id, permit_id),
                )
                insert_counts["region_theme_permits"] += cur.rowcount or 0

            # Permit details are overwritten with the snapshot values.
            cur.execute(
                """
                INSERT INTO region_permit_details (
                    region_id,
                    permit_id,
                    permit_status,
                    subitem_summary,
                    responsible_contact,
                    jurisdiction_scope
                )
                VALUES (%s, %s, %s, %s, %s, %s)
                ON CONFLICT (region_id, permit_id)
                DO UPDATE SET
                    permit_status = EXCLUDED.permit_status,
                    subitem_summary = EXCLUDED.subitem_summary,
                    responsible_contact = EXCLUDED.responsible_contact,
                    jurisdiction_scope = EXCLUDED.jurisdiction_scope,
                    updated_at = now()
                """,
                (
                    region_id,
                    permit_id,
                    detail_fields["permit_status"],
                    detail_fields["subitem_summary"],
                    detail_fields["responsible_contact"],
                    detail_fields["jurisdiction_scope"],
                ),
            )
            insert_counts["region_permit_details"] += cur.rowcount or 0

            for scope_id in sorted(scope_ids):
                cur.execute(
                    """
                    INSERT INTO region_permit_scopes (region_id, permit_id, scope_id)
                    VALUES (%s, %s, %s)
                    ON CONFLICT DO NOTHING
                    """,
                    (region_id, permit_id, scope_id),
                )
                insert_counts["region_permit_scopes"] += cur.rowcount or 0

            for subitem_id in sorted(subitem_ids):
                cur.execute(
                    """
                    INSERT INTO region_permit_subitems (region_id, permit_id, subitem_id)
                    VALUES (%s, %s, %s)
                    ON CONFLICT DO NOTHING
                    """,
                    (region_id, permit_id, subitem_id),
                )
                insert_counts["region_permit_subitems"] += cur.rowcount or 0

            # Risks: upsert the risk row, then re-link it to the permit.
            for snap in snapshots:
                snap_payload = snap["payload"]
                risk_id = str(snap_payload["risk_id"])
                cur.execute(
                    """
                    INSERT INTO risks (id, risk_content, legal_basis, document_no, summary)
                    VALUES (%s, %s, %s, %s, %s)
                    ON CONFLICT (id)
                    DO UPDATE SET
                        risk_content = EXCLUDED.risk_content,
                        legal_basis = EXCLUDED.legal_basis,
                        document_no = EXCLUDED.document_no,
                        summary = EXCLUDED.summary
                    """,
                    (
                        risk_id,
                        snap_payload.get("risk_content"),
                        snap_payload.get("legal_basis"),
                        snap_payload.get("document_no"),
                        snap_payload.get("summary"),
                    ),
                )
                insert_counts["risks_upserted"] += cur.rowcount or 0
                cur.execute(
                    """
                    INSERT INTO region_permit_risks (region_id, permit_id, risk_id)
                    VALUES (%s, %s, %s)
                    ON CONFLICT DO NOTHING
                    """,
                    (region_id, permit_id, risk_id),
                )
                # Only mappings that were actually inserted count as restored.
                if cur.rowcount:
                    insert_counts["region_permit_risks"] += cur.rowcount
                    restored_risk_ids.add(risk_id)

            if source_name:
                _ensure_permit_sources_table(conn)
                if isinstance(source_detail, (dict, list)):
                    source_detail_text = json.dumps(source_detail, ensure_ascii=False)
                elif source_detail is not None:
                    source_detail_text = _clean_text(source_detail)
                else:
                    source_detail_text = None
                cur.execute(
                    """
                    INSERT INTO permit_sources (
                        region_id,
                        permit_id,
                        source_type,
                        source_name,
                        source_detail,
                        created_at,
                        updated_at
                    )
                    VALUES (%s, %s, %s, %s, %s, now(), now())
                    ON CONFLICT (region_id, permit_id)
                    DO UPDATE SET
                        source_type = EXCLUDED.source_type,
                        source_name = EXCLUDED.source_name,
                        source_detail = EXCLUDED.source_detail,
                        updated_at = now()
                    """,
                    (
                        region_id,
                        permit_id,
                        source_type,
                        source_name,
                        source_detail_text,
                    ),
                )
                insert_counts["permit_sources_synced"] += 1
            elif _permit_sources_available(conn):
                # The snapshot had no source: drop any source recorded since.
                cur.execute(
                    """
                    DELETE FROM permit_sources
                    WHERE region_id = %s AND permit_id = %s
                    """,
                    (region_id, permit_id),
                )
                if cur.rowcount:
                    insert_counts["permit_sources_synced"] += 1
            conn.commit()
        except Exception:
            conn.rollback()
            raise

        restored_total = len(restored_risk_ids)
        logger.info(
            "[PERMIT-RESTORE] Restored permit %s (%s) in region %s (%s) from snapshot batch %s: %d risk mappings",
            permit_id,
            permit_name,
            region_id,
            region_name,
            resolved_batch_id,
            restored_total,
        )
        return {
            "snapshot_batch_id": resolved_batch_id,
            "snapshot_ids": [snap["snapshot_id"] for snap in snapshots],
            "restored_risk_count": restored_total,
            "restored_risks": sorted(restored_risk_ids),
            "region_id": region_id,
            "region_name": region_name,
            "permit_id": permit_id,
            "permit_name": permit_name,
            "applied_theme_ids": sorted(theme_ids),
            "applied_scope_ids": sorted(scope_ids),
            "applied_subitem_ids": sorted(subitem_ids),
            "detail_fields": detail_fields,
            "insert_counts": insert_counts,
            "edited_by": edited_by,
            "change_summary": change_summary or "",
        }