from __future__ import annotations import json import logging import os import re from collections import OrderedDict from datetime import datetime, date from decimal import Decimal from typing import Any, Dict, List, Optional, Tuple import uuid import pg8000.dbapi as pg # Configure logger logger = logging.getLogger(__name__) if not logger.handlers: handler = logging.StreamHandler() handler.setFormatter(logging.Formatter("[%(levelname)s] %(name)s: %(message)s")) logger.addHandler(handler) logger.setLevel(logging.INFO) logger.propagate = False # Separate configuration so legacy fs_law_risk integration keeps using PG_* LIC_DEFAULT_DB = "licensing_risks" _UNSET = object() ARTICLE_HEADING_RE = re.compile(r"(?m)^(第[一二三四五六七八九十百零0-9]+条)") ARTICLE_TOKEN_RE = re.compile(r"(? str: """Render Chinese legal excerpts as Markdown-friendly text.""" if not summary: return "" text = summary.replace("\r\n", "\n").strip() if not text: return "" text = ARTICLE_HEADING_RE.sub(lambda m: f"**{m.group(1)}**", text) text = CN_ENUM_INLINE_RE.sub(lambda m: f"{m.group(1)}\n- ({m.group(2)}) ", text) text = CN_ENUM_LINE_RE.sub(lambda m: f"- ({m.group(1)}) ", text) text = ARABIC_ENUM_INLINE_RE.sub(lambda m: f"{m.group(1)}\n {m.group(2)}", text) text = ARABIC_ENUM_LINE_RE.sub(lambda m: f" {m.group(1)}.", text) text = NESTED_ENUM_INLINE_RE.sub(lambda m: f"{m.group(1)}\n - ({m.group(2)})", text) text = NESTED_ENUM_LINE_RE.sub(lambda m: f" - ({m.group(1)})", text) text = ARTICLE_TOKEN_RE.sub(lambda m: f"**{m.group(1)}**", text) text = ARTICLE_NEWLINE_RE.sub(lambda m: f"\n{m.group(1)}", text) text = COLON_NEWLINE_RE.sub(":\n", text) text = EXTRA_NEWLINES_RE.sub("\n\n", text) text = TRAILING_SPACE_RE.sub("\n", text) text = re.sub(r"\n\s+\n", "\n\n", text) return text.strip() def _lic_pg_conn(autocommit: bool = False) -> pg.Connection: host = os.getenv("LIC_PG_HOST", "172.24.240.1") port = int(os.getenv("LIC_PG_PORT", os.getenv("PG_PORT", "5432"))) user = os.getenv("LIC_PG_USER", os.getenv("PG_USER", "postgres")) password = os.getenv("LIC_PG_PASSWORD", "") database = os.getenv("LIC_PG_DATABASE", LIC_DEFAULT_DB) conn = pg.connect(host=host, port=port, user=user, password=password, database=database) conn.autocommit = autocommit return conn def list_region_theme_options() -> List[Dict[str, str]]: """Return all region-theme pairs usable for LLM selection.""" sql = """ SELECT rt.region_id, r.name AS region_name, rt.theme_id, t.name AS theme_name FROM region_themes rt JOIN regions r ON r.id = rt.region_id JOIN themes t ON t.id = rt.theme_id ORDER BY r.name, t.name """ out: List[Dict[str, str]] = [] with _lic_pg_conn() as conn: cur = conn.cursor() cur.execute(sql) for region_id, region_name, theme_id, theme_name in cur.fetchall(): rid = str(region_id) tid = str(theme_id) out.append( { "option_id": f"{rid}:{tid}", "region_id": rid, "region_name": str(region_name), "theme_id": tid, "theme_name": str(theme_name), "display_name": f"{region_name} · {theme_name}", } ) return out def load_business_scopes(region_id: str) -> List[Dict[str, str]]: """List business scopes bound to a region.""" sql = """ SELECT bs.id, bs.description FROM region_scopes rs JOIN business_scopes bs ON bs.id = rs.scope_id WHERE rs.region_id = %s ORDER BY bs.description """ scopes: List[Dict[str, str]] = [] with _lic_pg_conn() as conn: cur = conn.cursor() cur.execute(sql, (region_id,)) for scope_id, description in cur.fetchall(): scopes.append({"id": str(scope_id), "description": str(description)}) return scopes def list_permits_for_region(region: str) -> List[Dict[str, str]]: """Return all permits available within a region (accepts id or name).""" sql = """ SELECT DISTINCT p.id, p.name FROM region_theme_permits rtp JOIN permits p ON p.id = rtp.permit_id JOIN regions r ON r.id = rtp.region_id WHERE rtp.region_id::text = %s OR LOWER(r.name) = LOWER(%s) ORDER BY p.name """ permits: List[Dict[str, str]] = [] with _lic_pg_conn() as conn: cur = conn.cursor() cur.execute(sql, (region, region)) for permit_id, permit_name in cur.fetchall(): permits.append({"id": str(permit_id), "name": str(permit_name)}) return permits def _load_permit_scopes_for_region( conn: pg.Connection, region_id: str, permit_ids: List[str] ) -> Dict[str, List[Dict[str, str]]]: """Return mapping of permit_id -> business scopes for that permit within region.""" scope_map: Dict[str, List[Dict[str, str]]] = {pid: [] for pid in permit_ids} if not permit_ids: return scope_map sql = """ SELECT rps.permit_id, bs.id, bs.description FROM region_permit_scopes rps JOIN business_scopes bs ON bs.id = rps.scope_id WHERE rps.region_id = %s ORDER BY rps.permit_id, bs.description """ cur = conn.cursor() try: cur.execute(sql, (region_id,)) except pg.ProgrammingError as exc: # 42P01 => undefined_table; allow fallback when migration not yet applied. sqlstate = getattr(exc, "sqlstate", "") if sqlstate == "42P01": return scope_map raise for permit_id, scope_id, description in cur.fetchall(): pid = str(permit_id) if pid not in scope_map: continue scope_map[pid].append({"id": str(scope_id), "description": str(description)}) return scope_map def load_permits_and_risks( region_id: str, theme_id: str, permit_id: Optional[str] = None ) -> List[Dict[str, object]]: """Return permits with attached risk entries for a region-theme pair.""" sql = """ SELECT p.id AS permit_id, p.name AS permit_name, rk.id AS risk_id, rk.risk_content, rk.legal_basis, rk.document_no, rk.summary, rpd.permit_status, rpd.subitem_summary, rpd.responsible_contact, rpd.jurisdiction_scope FROM region_theme_permits rtp JOIN permits p ON p.id = rtp.permit_id LEFT JOIN region_permit_risks rpr ON rpr.region_id = rtp.region_id AND rpr.permit_id = rtp.permit_id LEFT JOIN risks rk ON rk.id = rpr.risk_id LEFT JOIN region_permit_details rpd ON rpd.region_id = rtp.region_id AND rpd.permit_id = rtp.permit_id WHERE rtp.region_id = %s AND rtp.theme_id = %s """ params: List[Any] = [region_id, theme_id] if permit_id is not None: sql += " AND rtp.permit_id = %s" params.append(permit_id) sql += """ ORDER BY p.name, rk.risk_content """ permits: Dict[str, Dict[str, object]] = {} with _lic_pg_conn() as conn: cur = conn.cursor() cur.execute(sql, tuple(params)) for row in cur.fetchall(): ( permit_id, permit_name, risk_id, risk_content, legal_basis, document_no, summary, permit_status, subitem_summary, responsible_contact, jurisdiction_scope, ) = row pid = str(permit_id) entry = permits.setdefault( pid, { "id": pid, "name": str(permit_name), "business_scopes": [], "risks": [], "permit_status": None, "subitem_summary": None, "responsible_contact": None, "jurisdiction_scope": None, }, ) if entry["permit_status"] is None and permit_status: entry["permit_status"] = permit_status.strip() or None if entry["subitem_summary"] is None and subitem_summary: entry["subitem_summary"] = subitem_summary.strip() or None if entry["responsible_contact"] is None and responsible_contact: entry["responsible_contact"] = responsible_contact.strip() or None if entry["jurisdiction_scope"] is None and jurisdiction_scope: entry["jurisdiction_scope"] = jurisdiction_scope.strip() or None if risk_id is not None: summary_markdown = _format_summary_markdown(summary or "") entry["risks"].append( { "id": str(risk_id), "risk_content": risk_content or "", "legal_basis": legal_basis or "", "document_no": document_no or "", "summary": summary_markdown, } ) permit_ids = list(permits.keys()) scope_map = _load_permit_scopes_for_region(conn, region_id, permit_ids) for pid in permit_ids: permits[pid]["business_scopes"] = scope_map.get(pid, []) return list(permits.values()) def find_permit_contexts_by_name(permit_name: str) -> List[Dict[str, str]]: """Return region/theme contexts for permits with an exact name match.""" if not permit_name: return [] sql = """ SELECT rtp.region_id, r.name AS region_name, rtp.theme_id, t.name AS theme_name, p.id AS permit_id, p.name AS permit_name FROM region_theme_permits rtp JOIN permits p ON p.id = rtp.permit_id JOIN regions r ON r.id = rtp.region_id JOIN themes t ON t.id = rtp.theme_id WHERE p.name = %s ORDER BY r.name, t.name """ ordered: OrderedDict[Tuple[str, str], Dict[str, str]] = OrderedDict() with _lic_pg_conn() as conn: cur = conn.cursor() cur.execute(sql, (permit_name,)) for row in cur.fetchall(): region_id, region_name, theme_id, theme_name, permit_id, canonical_name = row rid = str(region_id) pid = str(permit_id) key = (rid, pid) if key in ordered: continue ordered[key] = { "region_id": rid, "region_name": str(region_name), "theme_id": str(theme_id), "theme_name": str(theme_name), "permit_id": pid, "permit_name": str(canonical_name), } return list(ordered.values()) def load_theme_payload(region_id: str, theme_id: str) -> Dict[str, object]: """Assemble full data bundle for a region-theme selection.""" info_sql = """ SELECT r.id, r.name, t.id, t.name FROM regions r JOIN region_themes rt ON rt.region_id = r.id JOIN themes t ON t.id = rt.theme_id WHERE r.id = %s AND t.id = %s LIMIT 1 """ with _lic_pg_conn() as conn: cur = conn.cursor() cur.execute(info_sql, (region_id, theme_id)) row = cur.fetchone() if not row: raise ValueError("Region/theme combination not found") region_uuid, region_name, theme_uuid, theme_name = row permits = load_permits_and_risks(region_id, theme_id) return { "region": {"id": str(region_uuid), "name": str(region_name)}, "theme": {"id": str(theme_uuid), "name": str(theme_name)}, "permits": permits, } def _get_checkpoints_dir() -> str: """Get the directory for storing checkpoint files.""" base_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "data") checkpoints_dir = os.path.join(base_dir, "checkpoints") os.makedirs(checkpoints_dir, exist_ok=True) return checkpoints_dir def _get_all_tables() -> List[str]: """Get list of all tables in the licensing_risks database.""" sql = """ SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_type = 'BASE TABLE' ORDER BY table_name """ with _lic_pg_conn() as conn: cur = conn.cursor() cur.execute(sql) return [row[0] for row in cur.fetchall()] def _get_table_dependencies(conn: pg.Connection) -> Dict[str, List[str]]: """ 获取表依赖关系图。 返回: {被引用表名: [引用它的表列表]} 例如: {'regions': ['region_themes', 'region_scopes'], 'themes': ['region_themes']} """ sql = """ SELECT ccu.table_name AS referenced_table, tc.table_name AS dependent_table FROM information_schema.table_constraints tc JOIN information_schema.key_column_usage kcu ON tc.constraint_name = kcu.constraint_name AND tc.table_schema = kcu.table_schema JOIN information_schema.constraint_column_usage ccu ON tc.constraint_name = ccu.constraint_name AND tc.table_schema = ccu.table_schema WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_schema = 'public' ORDER BY ccu.table_name, tc.table_name """ cur = conn.cursor() cur.execute(sql) dependencies = {} for row in cur.fetchall(): referenced_table, dependent_table = row if referenced_table not in dependencies: dependencies[referenced_table] = [] if dependent_table not in dependencies[referenced_table]: dependencies[referenced_table].append(dependent_table) return dependencies def _topological_sort_tables(all_tables: List[str], dependencies: Dict[str, List[str]]) -> List[str]: """ 使用拓扑排序确定表恢复顺序。 先返回父表(无外键依赖),再返回子表(引用其他表)。 这确保恢复时不会违反外键约束。 """ from collections import deque # 计算每个表的入度(被多少表引用/依赖) in_degree = {table: 0 for table in all_tables} for parent_table, children in dependencies.items(): for child in children: if child in all_tables: in_degree[child] += 1 # 入度为0的表是父表,可以先恢复 queue = deque([table for table in all_tables if in_degree[table] == 0]) result = [] while queue: table = queue.popleft() result.append(table) # 减少依赖该表的子表的入度 for parent, children in dependencies.items(): if parent == table: for child in children: if child in all_tables and in_degree[child] > 0: in_degree[child] -= 1 if in_degree[child] == 0: queue.append(child) # 如果有循环依赖或孤立节点,将剩余表添加到末尾 for table in all_tables: if table not in result: result.append(table) return result def _backup_table(conn: pg.Connection, table_name: str) -> Tuple[List[Dict[str, Any]], int]: """Backup a single table and return its data and row count.""" logger.info(f"[CHECKPOINT] Backing up table: {table_name}") sql = f"SELECT * FROM {table_name}" cur = conn.cursor() cur.execute(sql) rows = cur.fetchall() colnames = [desc[0] for desc in cur.description] data = [] for row in rows: row_dict = {} for i, col in enumerate(colnames): value = row[i] if value is not None: # Convert UUID and other non-serializable types to strings if hasattr(value, 'isoformat'): # UUID, datetime, etc. row_dict[col] = str(value) else: row_dict[col] = value data.append(row_dict) row_count = len(data) logger.info(f"[CHECKPOINT] Backup complete: {table_name} - {row_count} rows, {len(colnames)} columns") return data, row_count def _restore_table(conn: pg.Connection, table_name: str, data: List[Dict[str, Any]], batch_size: int = 1000) -> int: """Restore a table from backup data. Returns number of rows restored.""" if not data: logger.info(f"[CHECKPOINT] Skipping empty table: {table_name}") return 0 conn.autocommit = False try: cur = conn.cursor() logger.info(f"[CHECKPOINT] Truncating table: {table_name}") # Use TRUNCATE with CASCADE to handle foreign key dependencies # This will automatically remove dependent records truncate_sql = f"TRUNCATE TABLE {table_name} CASCADE" cur.execute(truncate_sql) columns = list(data[0].keys()) placeholders = ", ".join(["%s"] * len(columns)) logger.info(f"[CHECKPOINT] Restoring {len(data)} rows into {table_name} (columns: {', '.join(columns)})") # 🚀 优化: 使用批量插入而不是逐行插入 # pg8000 使用 executemany 进行批量插入 batch_size = 1000 # 每批1000行 if len(data) <= batch_size: # 小数据量,直接批量插入 values_list = [[row.get(col) for col in columns] for row in data] cur.executemany(f"INSERT INTO {table_name} ({', '.join(columns)}) VALUES ({placeholders})", values_list) logger.info(f"[CHECKPOINT] Bulk insert complete: {table_name} - {len(data)} rows inserted") else: # 大数据量,分批插入 total_rows = len(data) for i in range(0, total_rows, batch_size): batch_end = min(i + batch_size, total_rows) batch_data = data[i:batch_end] values_list = [[row.get(col) for col in columns] for row in batch_data] cur.executemany(f"INSERT INTO {table_name} ({', '.join(columns)}) VALUES ({placeholders})", values_list) logger.info(f"[CHECKPOINT] Progress: {table_name} - {batch_end}/{total_rows} rows inserted") conn.commit() logger.info(f"[CHECKPOINT] Restore complete: {table_name} - {len(data)} rows successfully inserted") return len(data) except Exception as e: conn.rollback() logger.error(f"[CHECKPOINT] Restore FAILED: {table_name} - {str(e)}") raise e finally: conn.autocommit = False def create_checkpoint(description: str = "") -> Dict[str, Any]: """ 安全创建checkpoint,所有表操作在一个事务中。 Args: description: checkpoint的描述信息 Returns: 包含checkpoint信息的字典 """ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") checkpoint_id = f"checkpoint_{timestamp}" logger.info("=" * 80) logger.info(f"[CHECKPOINT] Starting checkpoint creation: {checkpoint_id}") if description: logger.info(f"[CHECKPOINT] Description: {description}") logger.info("=" * 80) tables = _get_all_tables() logger.info(f"[CHECKPOINT] Found {len(tables)} tables to backup: {', '.join(tables)}") checkpoint_data = { "checkpoint_id": checkpoint_id, "timestamp": timestamp, "description": description, "tables": {} } total_rows = 0 table_counts = {} with _lic_pg_conn(autocommit=False) as conn: try: for i, table in enumerate(tables, 1): logger.info(f"[CHECKPOINT] [{i}/{len(tables)}] Processing table: {table}") data, row_count = _backup_table(conn, table) checkpoint_data["tables"][table] = data table_counts[table] = row_count total_rows += row_count logger.info(f"[CHECKPOINT] [{i}/{len(tables)}] Table {table} backed up: {row_count} rows") # 全部成功后才提交 logger.info("[CHECKPOINT] All tables backed up successfully, committing transaction...") conn.commit() logger.info("[CHECKPOINT] Transaction committed") except Exception as e: # 任何失败都回滚 logger.error(f"[CHECKPOINT] ERROR during backup, rolling back: {str(e)}") conn.rollback() raise e finally: conn.autocommit = False checkpoint_data["table_counts"] = table_counts checkpoint_data["total_rows"] = total_rows checkpoints_dir = _get_checkpoints_dir() checkpoint_file = os.path.join(checkpoints_dir, f"{checkpoint_id}.json") logger.info(f"[CHECKPOINT] Saving checkpoint file: {checkpoint_file}") def json_serializer(obj): """Convert non-JSON serializable objects to strings.""" try: import uuid if isinstance(obj, uuid.UUID): return str(obj) except ImportError: pass if hasattr(obj, 'isoformat'): return str(obj) raise TypeError(f"Object of type {type(obj)} is not JSON serializable") with open(checkpoint_file, "w", encoding="utf-8") as f: json.dump(checkpoint_data, f, ensure_ascii=False, indent=2, default=json_serializer) logger.info("=" * 80) logger.info(f"[CHECKPOINT] Checkpoint creation COMPLETED: {checkpoint_id}") logger.info(f"[CHECKPOINT] Total rows backed up: {total_rows}") logger.info(f"[CHECKPOINT] File: {checkpoint_file}") logger.info("=" * 80) return { "checkpoint_id": checkpoint_id, "timestamp": timestamp, "description": description, "total_rows": total_rows, "table_counts": table_counts } def list_checkpoints() -> List[Dict[str, Any]]: """List all available checkpoints.""" checkpoints_dir = _get_checkpoints_dir() checkpoints = [] if not os.path.exists(checkpoints_dir): return checkpoints for filename in os.listdir(checkpoints_dir): if filename.endswith(".json"): filepath = os.path.join(checkpoints_dir, filename) try: with open(filepath, "r", encoding="utf-8") as f: data = json.load(f) checkpoints.append({ "checkpoint_id": data["checkpoint_id"], "timestamp": data["timestamp"], "description": data.get("description", ""), "total_rows": data.get("total_rows", 0), "table_counts": data.get("table_counts", {}), "filename": filename }) except Exception as e: print(f"Error reading checkpoint {filename}: {e}") return sorted(checkpoints, key=lambda x: x["timestamp"], reverse=True) def restore_checkpoint( checkpoint_id: str, create_auto_backup: bool = True, batch_size: int = 1000 ) -> Dict[str, Any]: """ 安全恢复数据库从checkpoint。 ⚠️ 危险操作: 会覆盖现有数据! Args: checkpoint_id: 要恢复的checkpoint ID create_auto_backup: 是否在恢复前自动备份当前状态 batch_size: 批量插入的批次大小 (默认1000行/批) Returns: 包含恢复结果的字典 """ checkpoints_dir = _get_checkpoints_dir() checkpoint_file = os.path.join(checkpoints_dir, f"{checkpoint_id}.json") logger.warning("=" * 80) logger.warning(f"[CHECKPOINT] WARNING: Starting restore operation: {checkpoint_id}") logger.warning("[CHECKPOINT] This will OVERWRITE all existing data in the database!") logger.warning("=" * 80) if not os.path.exists(checkpoint_file): error_msg = f"Checkpoint {checkpoint_id} not found" logger.error(f"[CHECKPOINT] {error_msg}") raise ValueError(error_msg) with open(checkpoint_file, "r", encoding="utf-8") as f: checkpoint_data = json.load(f) # 自动备份当前状态(可选但推荐) auto_backup_info = None if create_auto_backup: logger.info("[CHECKPOINT] Creating auto-backup before restore... (THIS MAY TAKE TIME)") try: import time start_time = time.time() auto_backup_info = create_checkpoint(f"auto_backup_before_restore_{checkpoint_id}") elapsed = time.time() - start_time logger.info(f"[CHECKPOINT] Auto-backup created in {elapsed:.2f}s: {auto_backup_info['checkpoint_id']}") logger.info(f"[CHECKPOINT] Auto-backup contains {auto_backup_info['total_rows']} rows") except Exception as e: logger.error(f"[CHECKPOINT] WARNING: Failed to create auto-backup: {e}") logger.warning("[CHECKPOINT] Continuing without auto-backup...") else: logger.info("[CHECKPOINT] Auto-backup DISABLED by user") tables = checkpoint_data.get("tables", {}) total_rows_in_checkpoint = sum(len(data) for data in tables.values()) logger.info("=" * 80) logger.info(f"[CHECKPOINT] Checkpoint details:") logger.info(f"[CHECKPOINT] ID: {checkpoint_id}") logger.info(f"[CHECKPOINT] Tables: {len(tables)}") logger.info(f"[CHECKPOINT] Total rows: {total_rows_in_checkpoint}") logger.info(f"[CHECKPOINT] Auto-backup: {'Yes (' + auto_backup_info['checkpoint_id'] + ')' if auto_backup_info else 'No'}") logger.info("=" * 80) restore_summary = { "checkpoint_id": checkpoint_id, "auto_backup": auto_backup_info.get("checkpoint_id") if auto_backup_info else None, "tables_restored": 0, "total_rows_restored": 0, "table_details": {}, "errors": [] } with _lic_pg_conn(autocommit=False) as conn: try: # 1. 构建表依赖关系图 logger.info("[CHECKPOINT] Building table dependency graph...") dependencies = _get_table_dependencies(conn) all_tables = list(tables.keys()) logger.info(f"[CHECKPOINT] Found {len(dependencies)} table dependencies") # 2. 拓扑排序获取恢复顺序 logger.info("[CHECKPOINT] Calculating restore order...") restore_order = _topological_sort_tables(all_tables, dependencies) logger.info(f"[CHECKPOINT] Restore order: {' -> '.join(restore_order)}") # 3. 锁定所有表(防止并发写入) logger.info("[CHECKPOINT] Acquiring exclusive locks on all tables...") cur = conn.cursor() for table in restore_order: cur.execute(f"LOCK TABLE {table} IN EXCLUSIVE MODE") logger.info("[CHECKPOINT] All tables locked exclusively") # 4. 按依赖顺序恢复表 logger.info("=" * 80) logger.info("[CHECKPOINT] Starting restore process...") logger.info("=" * 80) import time restore_start_time = time.time() for i, table_name in enumerate(restore_order, 1): data = tables.get(table_name, []) table_start_time = time.time() logger.info(f"[CHECKPOINT] [{i}/{len(restore_order)}] Preparing to restore table: {table_name}") try: rows_restored = _restore_table(conn, table_name, data, batch_size=batch_size) table_elapsed = time.time() - table_start_time restore_summary["tables_restored"] += 1 restore_summary["total_rows_restored"] += rows_restored restore_summary["table_details"][table_name] = rows_restored logger.info(f"[CHECKPOINT] [{i}/{len(restore_order)}] Table {table_name} restored: {rows_restored} rows in {table_elapsed:.2f}s") except Exception as e: error_msg = f"Failed to restore table {table_name}: {str(e)}" logger.error(f"[CHECKPOINT] ERROR: {error_msg}") restore_summary["errors"].append(error_msg) raise e restore_elapsed = time.time() - restore_start_time total_tables = len(restore_order) logger.info(f"[CHECKPOINT] All {total_tables} tables restored in {restore_elapsed:.2f}s") # 5. 提交事务 logger.info("=" * 80) logger.info("[CHECKPOINT] All tables restored successfully, committing transaction...") conn.commit() logger.info("[CHECKPOINT] Transaction committed successfully") logger.info("=" * 80) logger.warning("=" * 80) logger.warning(f"[CHECKPOINT] RESTORE COMPLETED SUCCESSFULLY: {checkpoint_id}") logger.warning(f"[CHECKPOINT] Tables restored: {restore_summary['tables_restored']}/{len(restore_order)}") logger.warning(f"[CHECKPOINT] Total rows restored: {restore_summary['total_rows_restored']}") if auto_backup_info: logger.warning(f"[CHECKPOINT] Auto-backup available: {auto_backup_info['checkpoint_id']}") logger.warning("=" * 80) return { "status": "success", "message": f"Successfully restored {restore_summary['tables_restored']} tables, " f"{restore_summary['total_rows_restored']} total rows", "summary": restore_summary } except Exception as e: # 回滚事务 logger.error(f"[CHECKPOINT] ERROR during restore, rolling back: {str(e)}") conn.rollback() logger.warning("[CHECKPOINT] Transaction rolled back, changes reverted") error_info = { "status": "error", "message": f"Restore failed: {str(e)}", "summary": restore_summary, "auto_backup_available": bool(auto_backup_info) } # 如果有自动备份,提供恢复建议 if auto_backup_info: error_info["recovery_suggestion"] = ( f"Use auto-backup to restore current state: " f"restore_checkpoint('{auto_backup_info['checkpoint_id']}')" ) logger.warning(f"[CHECKPOINT] Recovery suggestion: Use auto-backup '{auto_backup_info['checkpoint_id']}'") logger.error("=" * 80) logger.error(f"[CHECKPOINT] RESTORE FAILED: {str(e)}") if auto_backup_info: logger.error(f"[CHECKPOINT] You can restore from auto-backup: {auto_backup_info['checkpoint_id']}") logger.error("=" * 80) return error_info def delete_checkpoint(checkpoint_id: str) -> bool: """Delete a checkpoint file.""" checkpoints_dir = _get_checkpoints_dir() checkpoint_file = os.path.join(checkpoints_dir, f"{checkpoint_id}.json") if os.path.exists(checkpoint_file): os.remove(checkpoint_file) return True return False # --------------------------------------------------------------------------- # Permit risk snapshot helpers (region_permit_risk_vw) # --------------------------------------------------------------------------- def _convert_snapshot_value(value: Any) -> Any: """Convert database values into JSON-serialisable primitives.""" if isinstance(value, (datetime, date)): return value.isoformat() if isinstance(value, Decimal): return float(value) if isinstance(value, uuid.UUID): return str(value) if isinstance(value, memoryview): return bytes(value).decode("utf-8", errors="replace") if isinstance(value, (bytes, bytearray)): return value.decode("utf-8", errors="replace") if isinstance(value, (list, tuple)): return [_convert_snapshot_value(v) for v in value] return value def _normalize_snapshot_payload(record: Dict[str, Any]) -> Dict[str, Any]: """Return a JSON-safe copy of the permit risk view record.""" return {key: _convert_snapshot_value(val) for key, val in record.items()} def _fetch_permit_risk_row( conn: pg.Connection, region_id: str, permit_id: str, risk_id: str ) -> Dict[str, Any]: """Fetch a single row from the consolidation view.""" sql = """ SELECT region_id, region_name, permit_id, permit_name, risk_id, risk_content, legal_basis, document_no, summary, theme_ids, theme_names, scope_ids, scope_descriptions, subitem_ids, permit_status, subitem_summary, responsible_contact, jurisdiction_scope, permit_detail_updated_at, permit_risk_key FROM region_permit_risk_vw WHERE region_id = %s AND permit_id = %s AND risk_id = %s LIMIT 1 """ cur = conn.cursor() cur.execute(sql, (region_id, permit_id, risk_id)) row = cur.fetchone() if not row: raise ValueError("Permit risk combination not found in consolidation view") columns = [desc[0] for desc in cur.description] return {columns[i]: row[i] for i in range(len(columns))} def _insert_permit_risk_snapshot( conn: pg.Connection, payload: Dict[str, Any], *, edited_by: Optional[str], change_summary: Optional[str], ) -> Dict[str, Any]: """Insert a snapshot row and return metadata.""" permit_risk_key = str(payload["permit_risk_key"]) cur = conn.cursor() cur.execute( """ SELECT COALESCE(MAX(version), 0) FROM permit_risk_snapshots WHERE permit_risk_key = %s FOR UPDATE """, (permit_risk_key,), ) row = cur.fetchone() next_version = (int(row[0]) + 1) if row else 1 cur.execute( """ INSERT INTO permit_risk_snapshots ( region_id, permit_id, risk_id, permit_risk_key, version, payload, edited_by, change_summary ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s) RETURNING snapshot_id, created_at """, ( payload["region_id"], payload["permit_id"], payload["risk_id"], permit_risk_key, next_version, pg.Json(payload), edited_by, change_summary, ), ) snapshot_id, created_at = cur.fetchone() return { "snapshot_id": str(snapshot_id), "region_id": payload["region_id"], "permit_id": payload["permit_id"], "risk_id": payload["risk_id"], "permit_risk_key": permit_risk_key, "version": next_version, "created_at": _convert_snapshot_value(created_at), "edited_by": edited_by, "change_summary": change_summary or "", "payload": payload, } def _create_snapshot_with_connection( conn: pg.Connection, region_id: str, permit_id: str, risk_id: str, *, edited_by: Optional[str], change_summary: Optional[str], ) -> Dict[str, Any]: """Create a snapshot using an existing DB connection (no commit).""" view_record = _fetch_permit_risk_row(conn, region_id, permit_id, risk_id) payload = _normalize_snapshot_payload(view_record) metadata = _insert_permit_risk_snapshot( conn, payload, edited_by=edited_by, change_summary=change_summary, ) logger.info( "[CHECKPOINT] Snapshot created: %s version %s", metadata["permit_risk_key"], metadata["version"], ) return metadata def create_permit_risk_snapshot( region_id: str, permit_id: str, risk_id: str, *, edited_by: Optional[str] = None, change_summary: Optional[str] = None, ) -> Dict[str, Any]: """ Capture the current state of a region/permit/risk record as a versioned snapshot. Returns metadata about the created snapshot including version number. """ with _lic_pg_conn(autocommit=False) as conn: try: snapshot_meta = _create_snapshot_with_connection( conn, region_id, permit_id, risk_id, edited_by=edited_by, change_summary=change_summary, ) conn.commit() return snapshot_meta except Exception: conn.rollback() raise def list_permit_risk_snapshots( region_id: Optional[str] = None, permit_id: Optional[str] = None, risk_id: Optional[str] = None, *, permit_risk_key: Optional[str] = None, limit: int = 20, offset: int = 0, ) -> List[Dict[str, Any]]: """ List snapshots for a region/permit/risk combination ordered by version descending. At least one identifier (permit_risk_key or region/permit/risk) must be provided. """ filters: List[str] = [] params: List[Any] = [] if permit_risk_key: filters.append("permit_risk_key = %s") params.append(permit_risk_key) else: if region_id: filters.append("region_id = %s") params.append(region_id) if permit_id: filters.append("permit_id = %s") params.append(permit_id) if risk_id: filters.append("risk_id = %s") params.append(risk_id) if not filters: raise ValueError("At least one identifier must be provided to list snapshots") filters_clause = " AND ".join(filters) sql = f""" SELECT snapshot_id, version, permit_risk_key, edited_by, change_summary, created_at FROM permit_risk_snapshots WHERE {filters_clause} ORDER BY version DESC LIMIT %s OFFSET %s """ params.extend([limit, offset]) with _lic_pg_conn() as conn: cur = conn.cursor() cur.execute(sql, tuple(params)) rows = cur.fetchall() snapshots: List[Dict[str, Any]] = [] for snapshot_id, version, key, editor, summary, created_at in rows: snapshots.append( { "snapshot_id": str(snapshot_id), "permit_risk_key": key, "version": int(version), "created_at": _convert_snapshot_value(created_at), "edited_by": editor, "change_summary": summary or "", } ) return snapshots def get_permit_risk_snapshot(snapshot_id: str) -> Optional[Dict[str, Any]]: """Fetch a snapshot payload by its identifier.""" sql = """ SELECT snapshot_id, region_id, permit_id, risk_id, permit_risk_key, version, payload, edited_by, change_summary, created_at FROM permit_risk_snapshots WHERE snapshot_id = %s """ with _lic_pg_conn() as conn: cur = conn.cursor() cur.execute(sql, (snapshot_id,)) row = cur.fetchone() if not row: return None ( snap_id, region_id, permit_id, risk_id, permit_risk_key, version, payload, edited_by, change_summary, created_at, ) = row if isinstance(payload, (bytes, bytearray, memoryview)): payload_obj = json.loads(payload) else: payload_obj = payload if isinstance(payload, dict) else json.loads(payload) return { "snapshot_id": str(snap_id), "region_id": str(region_id), "permit_id": str(permit_id), "risk_id": str(risk_id), "permit_risk_key": permit_risk_key, "version": int(version), "created_at": _convert_snapshot_value(created_at), "edited_by": edited_by, "change_summary": change_summary or "", "payload": payload_obj, } def update_permit_risk_record( region_id: str, permit_id: str, risk_id: str, *, risk_content: Any = _UNSET, legal_basis: Any = _UNSET, document_no: Any = _UNSET, summary: Any = _UNSET, permit_status: Any = _UNSET, subitem_summary: Any = _UNSET, responsible_contact: Any = _UNSET, jurisdiction_scope: Any = _UNSET, edited_by: Optional[str] = None, change_summary: Optional[str] = None, ) -> Dict[str, Any]: """ Update the permit risk record while capturing a checkpoint snapshot beforehand. Returns the snapshot metadata (pre-change) and the refreshed view row (post-change). """ update_flags = [ risk_content, legal_basis, document_no, summary, permit_status, subitem_summary, responsible_contact, jurisdiction_scope, ] if all(flag is _UNSET for flag in update_flags): raise ValueError("No fields provided to update.") with _lic_pg_conn(autocommit=False) as conn: try: cur = conn.cursor() cur.execute( """ SELECT 1 FROM region_permit_risks WHERE region_id = %s AND permit_id = %s AND risk_id = %s FOR UPDATE """, (region_id, permit_id, risk_id), ) if cur.fetchone() is None: raise ValueError("Permit risk combination not found.") snapshot_meta = _create_snapshot_with_connection( conn, region_id, permit_id, risk_id, edited_by=edited_by, change_summary=change_summary, ) risk_updates: List[str] = [] risk_params: List[Any] = [] risk_fields = ( ("risk_content", risk_content), ("legal_basis", legal_basis), ("document_no", document_no), ("summary", summary), ) for column, value in risk_fields: if value is not _UNSET: risk_updates.append(f"{column} = %s") risk_params.append(value) if risk_updates: risk_params.append(risk_id) cur.execute( f"UPDATE risks SET {', '.join(risk_updates)} WHERE id = %s", tuple(risk_params), ) detail_columns: List[str] = [] detail_values: List[Any] = [] detail_fields = ( ("permit_status", permit_status), ("subitem_summary", subitem_summary), ("responsible_contact", responsible_contact), ("jurisdiction_scope", jurisdiction_scope), ) for column, value in detail_fields: if value is not _UNSET: detail_columns.append(column) detail_values.append(value) details_updated = False if detail_columns: insert_cols = ", ".join(["region_id", "permit_id"] + detail_columns) insert_placeholders = ", ".join(["%s"] * (2 + len(detail_values))) update_assignments = ", ".join( [f"{col} = EXCLUDED.{col}" for col in detail_columns] ) sql = f""" INSERT INTO region_permit_details ({insert_cols}) VALUES ({insert_placeholders}) ON CONFLICT (region_id, permit_id) DO UPDATE SET {update_assignments}, updated_at = now() """ cur.execute( sql, (region_id, permit_id, *detail_values), ) details_updated = True updated_record = _normalize_snapshot_payload( _fetch_permit_risk_row(conn, region_id, permit_id, risk_id) ) conn.commit() logger.info( "[CHECKPOINT] Permit risk updated: %s version %s -> new snapshot ready", snapshot_meta["permit_risk_key"], snapshot_meta["version"], ) return { "snapshot": snapshot_meta, "current": updated_record, "risk_updated": bool(risk_updates), "details_updated": details_updated, } except Exception: conn.rollback() raise