# 检查点功能安全补丁

## 修复清单

### ⚠️ 当前代码存在的安全问题

1. **数据永久丢失风险** - 使用 `TRUNCATE CASCADE`,且没有任何回退机制
2. **违反外键约束** - 恢复顺序没有考虑表之间的依赖关系
3. **并发写入冲突** - 没有表级锁保护
4. **部分失败导致不一致** - checkpoint 的创建没有事务保护

---

## 立即修复方案

> 以下代码沿用原文件中已有的名字:`os`、`json`、`datetime`、`typing` 中的 `Dict`/`List`/`Any`、`pg`,以及 `_lic_pg_conn`、`_get_all_tables`、`_get_checkpoints_dir`、`_backup_table`。新增的标准库依赖均在函数内局部导入。

### 修复1: 添加表依赖排序 (lawrisk/services/licensing_repo.py)

在 `_get_all_tables()` 函数后添加:

```python
def _quote_ident(name: str) -> str:
    """Quote a PostgreSQL identifier (table or column name) for f-string SQL.

    Wraps the name in double quotes and doubles any embedded double quote, so
    names read from a checkpoint file cannot inject SQL. Quoting also keeps the
    exact catalog spelling (no lower-case folding).

    NOTE(review): assumes unqualified names resolvable via search_path
    (e.g. ``users``, not ``public.users``) -- confirm against _get_all_tables().
    """
    return '"' + name.replace('"', '""') + '"'


def _get_table_dependencies(conn: pg.Connection) -> Dict[str, List[str]]:
    """Return the foreign-key graph of the ``public`` schema.

    Result: ``{referenced (parent) table: [tables whose FKs point at it]}``.

    Reads ``pg_constraint`` directly instead of ``information_schema``: the
    ``constraint_column_usage`` view only lists tables owned by a currently
    enabled role, and PostgreSQL does not require constraint names to be
    unique within a schema, so joining those views on ``constraint_name``
    alone can pair unrelated constraints.
    """
    sql = """
        SELECT DISTINCT parent.relname AS referenced_table,
                        child.relname  AS dependent_table
        FROM pg_constraint con
        JOIN pg_class child         ON child.oid = con.conrelid
        JOIN pg_class parent        ON parent.oid = con.confrelid
        JOIN pg_namespace child_ns  ON child_ns.oid = child.relnamespace
        JOIN pg_namespace parent_ns ON parent_ns.oid = parent.relnamespace
        WHERE con.contype = 'f'
          AND child_ns.nspname = 'public'
          AND parent_ns.nspname = 'public'
    """
    dependencies: Dict[str, List[str]] = {}
    with conn.cursor() as cur:
        cur.execute(sql)
        for referenced_table, dependent_table in cur.fetchall():
            dependencies.setdefault(referenced_table, []).append(dependent_table)
    return dependencies


def _topological_sort_tables(all_tables: List[str],
                             dependencies: Dict[str, List[str]]) -> List[str]:
    """Order ``all_tables`` so that every parent table precedes its children.

    Kahn's algorithm over the sub-graph induced by ``all_tables``:

    * Edges whose parent is not in ``all_tables`` are ignored. Counting them
      would leave the child's in-degree stuck above zero and push it behind
      its own children.
    * Self-references are ignored; table order cannot satisfy them.
    * Tables left over because of an FK cycle are appended at the end, in
      their input order.

    Args:
        all_tables: Tables to order.
        dependencies: ``{parent: [child, ...]}`` as built by
            ``_get_table_dependencies``.

    Returns:
        A permutation of ``all_tables``.
    """
    from collections import deque

    table_set = set(all_tables)
    children_of: Dict[str, List[str]] = {table: [] for table in all_tables}
    # in_degree[t] = number of distinct parent tables t still waits for.
    in_degree = {table: 0 for table in all_tables}
    for parent, children in dependencies.items():
        if parent not in table_set:
            continue
        for child in dict.fromkeys(children):  # de-duplicate, keep order
            if child in table_set and child != parent:
                children_of[parent].append(child)
                in_degree[child] += 1

    # Tables with no pending parent can be restored first.
    queue = deque(table for table in all_tables if in_degree[table] == 0)
    result: List[str] = []
    while queue:
        table = queue.popleft()
        result.append(table)
        for child in children_of[table]:
            in_degree[child] -= 1
            if in_degree[child] == 0:
                queue.append(child)

    # Whatever remains is part of a cycle.
    placed = set(result)
    result.extend(table for table in all_tables if table not in placed)
    return result
```

说明:

- 直接查询 `pg_catalog.pg_constraint`,不再使用 `information_schema`,原因见函数 docstring。
- 拓扑排序只统计两端都在待恢复列表中的外键边。
- `_quote_ident()` 用于安全拼接来自 checkpoint 文件的表名和列名。

### 修复2: 改进 _restore_table() 函数 (第380-409行)

```python
def _restore_table_safe(conn: pg.Connection, table_name: str,
                        data: List[Dict[str, Any]]) -> int:
    """Insert checkpoint rows into ``table_name`` within the caller's transaction.

    Deliberately does NOT truncate, commit, roll back or touch
    ``conn.autocommit``: ``restore_checkpoint`` empties all tables and owns the
    single transaction, so any failure rolls back the entire restore.
    (psycopg refuses to change ``autocommit`` while a transaction is open, and
    a per-table commit would leave a half-restored database behind.)

    Args:
        conn: Non-autocommit connection owned by the caller.
        table_name: Target table; must be an existing table name.
        data: Rows as ``{column: value}`` dicts, as stored in the checkpoint.

    Returns:
        Number of rows inserted (0 for empty ``data``).
    """
    if not data:
        return 0

    # Union of keys over all rows (first-seen order), so a column absent from
    # the first row is not silently dropped for the other rows.
    columns = list(dict.fromkeys(col for row in data for col in row))
    column_sql = ", ".join(_quote_ident(col) for col in columns)
    placeholders = ", ".join(["%s"] * len(columns))
    # NOTE(review): GENERATED ALWAYS identity columns reject explicit values;
    # add OVERRIDING SYSTEM VALUE if the schema uses them.
    insert_sql = (
        f"INSERT INTO {_quote_ident(table_name)} ({column_sql}) "
        f"VALUES ({placeholders})"
    )
    with conn.cursor() as cur:
        cur.executemany(insert_sql, [[row.get(col) for col in columns] for row in data])
    return len(data)


def _reset_table_sequences(conn: pg.Connection, table_name: str) -> None:
    """Move every serial/identity sequence of ``table_name`` past its MAX value.

    Rows restored with explicit ids do not advance their sequences, so without
    this step the next ordinary INSERT could collide with a restored primary
    key. For an empty table the sequence is reset so that its next value is 1.
    Runs inside the caller's transaction.
    """
    quoted_table = _quote_ident(table_name)
    with conn.cursor() as cur:
        # pg_get_serial_sequence parses its first argument as an SQL name (so
        # the quoted form is correct) and takes the column name literally.
        cur.execute(
            """
            SELECT a.attname, pg_get_serial_sequence(%s, a.attname)
            FROM pg_attribute a
            WHERE a.attrelid = %s::regclass
              AND a.attnum > 0
              AND NOT a.attisdropped
            """,
            (quoted_table, quoted_table),
        )
        sequences = [(column, seq) for column, seq in cur.fetchall() if seq]
        for column, sequence in sequences:
            quoted_col = _quote_ident(column)
            cur.execute(
                f"SELECT setval(%s::regclass, COALESCE(MAX({quoted_col}), 1), "
                f"MAX({quoted_col}) IS NOT NULL) FROM {quoted_table}",
                (sequence,),
            )
```

说明:

- 原方案在函数内部设置 `conn.autocommit` 并逐表 `commit()`,这有两个问题:
  - 在 `restore_checkpoint()` 已开启的事务中修改 autocommit,会直接抛出异常。
  - 即使不抛异常,逐表提交也会在中途失败时留下一个"恢复了一半"的数据库。
- 新版本只负责插入;加锁、清空和提交统一由 `restore_checkpoint()` 完成。
- 插入完成后,`_reset_table_sequences()` 会把自增序列同步到当前最大值。

### 修复3: 改进 restore_checkpoint() 函数 (第493-526行)

```python
def _checkpoint_path(checkpoint_id: str) -> str:
    """Return the checkpoint JSON path for ``checkpoint_id``.

    The id comes straight from the URL, so it is restricted to
    ``[A-Za-z0-9_-]`` to prevent path traversal (e.g. ``../../x``).

    Raises:
        ValueError: if the id is empty or contains any other character.
    """
    import re

    if not re.fullmatch(r"[A-Za-z0-9_-]+", checkpoint_id or ""):
        raise ValueError(f"Invalid checkpoint id: {checkpoint_id!r}")
    return os.path.join(_get_checkpoints_dir(), f"{checkpoint_id}.json")


def restore_checkpoint(checkpoint_id: str, create_auto_backup: bool = True) -> Dict[str, Any]:
    """Restore the database from a checkpoint.

    ⚠️ Dangerous: every table listed in the checkpoint is emptied and then
    refilled from the checkpoint file.

    Everything happens in ONE transaction:
    1. Lock all tables.
    2. Truncate them in a single statement (no CASCADE).
    3. Insert parent tables before child tables.
    4. Re-sync sequences.
    5. Commit.
    Any error rolls the whole restore back.

    Args:
        checkpoint_id: ID of the checkpoint to restore; must match
            ``[A-Za-z0-9_-]+``.
        create_auto_backup: Snapshot the current state with
            ``create_checkpoint_safe`` before restoring.

    Returns:
        On success: ``{"status": "success", "message", "summary"}``.
        If the database step fails (rolled back): ``{"status": "error",
        "message", "summary", "auto_backup_available", "recovery_suggestion"}``.

    Raises:
        ValueError: Invalid id, a missing or unparsable checkpoint file, or a
            checkpoint that names tables absent from the live schema. Raised
            before the auto-backup and before any database write.
    """
    checkpoint_file = _checkpoint_path(checkpoint_id)
    if not os.path.exists(checkpoint_file):
        raise ValueError(f"Checkpoint {checkpoint_id} not found")

    with open(checkpoint_file, "r", encoding="utf-8") as f:
        checkpoint_data = json.load(f)

    tables: Dict[str, List[Dict[str, Any]]] = checkpoint_data.get("tables", {})

    # Table names end up in f-string SQL (identifiers cannot be bound as
    # parameters), so only accept names that exist in the live schema.
    unknown_tables = sorted(set(tables) - set(_get_all_tables()))
    if unknown_tables:
        raise ValueError(
            f"Checkpoint {checkpoint_id} references unknown tables: "
            f"{', '.join(unknown_tables)}"
        )

    # Back up the current state first (optional, strongly recommended).
    auto_backup_info = None
    if create_auto_backup:
        auto_backup_info = create_checkpoint_safe(
            f"auto_backup_before_restore_{checkpoint_id}"
        )

    restore_summary: Dict[str, Any] = {
        "checkpoint_id": checkpoint_id,
        "auto_backup": auto_backup_info.get("checkpoint_id") if auto_backup_info else None,
        "tables_restored": 0,
        "total_rows_restored": 0,
        "table_details": {},
        "errors": [],
    }

    with _lic_pg_conn(autocommit=False) as conn:
        try:
            # 1. Parent-before-child insert order.
            dependencies = _get_table_dependencies(conn)
            restore_order = _topological_sort_tables(list(tables), dependencies)

            if restore_order:
                table_list = ", ".join(_quote_ident(t) for t in restore_order)
                with conn.cursor() as cur:
                    # 2. Check DEFERRABLE foreign keys only at COMMIT, so
                    #    self-referencing rows may load in any order. This has
                    #    no effect on NOT DEFERRABLE constraints.
                    cur.execute("SET CONSTRAINTS ALL DEFERRED")
                    # 3. Lock every table up front, before any data change, in
                    #    the mode TRUNCATE needs anyway (ACCESS EXCLUSIVE). This
                    #    blocks concurrent reads and writes and avoids a lock
                    #    upgrade in the middle of the restore.
                    cur.execute(f"LOCK TABLE {table_list} IN ACCESS EXCLUSIVE MODE")
                    # 4. Empty all checkpoint tables at once, including those
                    #    that are empty in the checkpoint. No CASCADE: if a table
                    #    outside the checkpoint references one of these,
                    #    PostgreSQL raises instead of silently wiping it.
                    cur.execute(f"TRUNCATE TABLE {table_list}")

            # 5. Refill in dependency order and re-sync sequences.
            for table_name in restore_order:
                rows_restored = _restore_table_safe(conn, table_name, tables.get(table_name) or [])
                _reset_table_sequences(conn, table_name)
                restore_summary["tables_restored"] += 1
                restore_summary["total_rows_restored"] += rows_restored
                restore_summary["table_details"][table_name] = rows_restored

            # 6. Single commit; deferred FK checks run here.
            conn.commit()
            return {
                "status": "success",
                "message": f"Successfully restored {restore_summary['tables_restored']} tables",
                "summary": restore_summary,
            }
        except Exception as e:
            conn.rollback()
            restore_summary["errors"].append(str(e))
            # Nothing was committed, so report nothing as restored.
            restore_summary["tables_restored"] = 0
            restore_summary["total_rows_restored"] = 0
            restore_summary["table_details"] = {}
            return {
                "status": "error",
                "message": f"Restore failed: {str(e)}",
                "summary": restore_summary,
                "auto_backup_available": bool(auto_backup_info),
                "recovery_suggestion": (
                    f"Use auto-backup to restore: {auto_backup_info['checkpoint_id']}"
                    if auto_backup_info else "No auto-backup available"
                ),
            }
```

说明:

- `checkpoint_id` 来自 URL,必须先校验,否则可以用 `../` 读取任意 `.json` 文件。
- checkpoint 中的表名必须在当前数据库中存在,才会被拼进 SQL。
- 所有表用一条 `LOCK TABLE`、一条 `TRUNCATE`(不带 `CASCADE`)处理,全部步骤在同一个事务中完成。任何一步失败都会整体回滚。

### 修复4: 改进 create_checkpoint() 函数 (第411-463行)

```python
def create_checkpoint_safe(description: str = "") -> Dict[str, Any]:
    """Snapshot all tables into a new JSON checkpoint file.

    Every table is read inside ONE ``REPEATABLE READ, READ ONLY`` transaction,
    so the checkpoint is consistent across tables. Under PostgreSQL's default
    READ COMMITTED, each SELECT would see a different moment.

    The JSON is written to a temp file in the same directory and moved into
    place with ``os.replace``, so a crash never leaves a half-written
    checkpoint behind.

    Args:
        description: Free-text label stored in the checkpoint.

    Returns:
        ``{"checkpoint_id", "timestamp", "description", "total_rows", "table_counts"}``.

    Raises:
        TypeError: A column value has a type ``json_serializer`` cannot encode.
        Any database error raised while reading the tables, after rollback.
    """
    import tempfile
    import uuid
    from decimal import Decimal

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    checkpoints_dir = _get_checkpoints_dir()

    # Pick an unused id. A second checkpoint within the same second must not
    # overwrite an existing file, e.g. restore's auto-backup taken right after
    # the checkpoint being restored was created.
    # NOTE(review): not race-free against truly concurrent creators.
    checkpoint_id = f"checkpoint_{timestamp}"
    suffix = 1
    while os.path.exists(os.path.join(checkpoints_dir, f"{checkpoint_id}.json")):
        checkpoint_id = f"checkpoint_{timestamp}_{suffix}"
        suffix += 1
    checkpoint_file = os.path.join(checkpoints_dir, f"{checkpoint_id}.json")

    tables = _get_all_tables()
    checkpoint_data: Dict[str, Any] = {
        "checkpoint_id": checkpoint_id,
        "timestamp": timestamp,
        "description": description,
        "tables": {},
    }
    total_rows = 0
    table_counts: Dict[str, int] = {}

    with _lic_pg_conn() as conn:
        previous_autocommit = conn.autocommit
        conn.autocommit = False
        try:
            with conn.cursor() as cur:
                # Must be the first statement of the new transaction.
                cur.execute("SET TRANSACTION ISOLATION LEVEL REPEATABLE READ, READ ONLY")
            # NOTE(review): assumes _backup_table only runs SELECTs and never
            # commits; a commit inside it would end the snapshot early.
            for table in tables:
                data, row_count = _backup_table(conn, table)
                checkpoint_data["tables"][table] = data
                table_counts[table] = row_count
                total_rows += row_count
            conn.commit()  # all tables or nothing
        except Exception:
            conn.rollback()
            raise
        finally:
            # The transaction is closed here, so autocommit can be restored.
            conn.autocommit = previous_autocommit

    checkpoint_data["table_counts"] = table_counts
    checkpoint_data["total_rows"] = total_rows

    def json_serializer(obj: Any) -> Any:
        """Encode values json cannot handle natively (restore re-inserts them as strings)."""
        if isinstance(obj, (uuid.UUID, Decimal)):
            # str() keeps Decimal's full precision; float() would not.
            return str(obj)
        if hasattr(obj, "isoformat"):
            # datetime / date / time -> ISO 8601
            return obj.isoformat()
        raise TypeError(f"Object of type {type(obj)} is not JSON serializable")

    fd, tmp_path = tempfile.mkstemp(dir=checkpoints_dir, prefix=f".{checkpoint_id}.", suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(checkpoint_data, f, ensure_ascii=False, indent=2, default=json_serializer)
        os.replace(tmp_path, checkpoint_file)
    finally:
        # Only still present if dump or replace failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    return {
        "checkpoint_id": checkpoint_id,
        "timestamp": timestamp,
        "description": description,
        "total_rows": total_rows,
        "table_counts": table_counts,
    }
```

说明:

- 单个事务默认是 READ COMMITTED,并不能保证各表快照一致,因此改用 `REPEATABLE READ, READ ONLY`。
- `finally` 中恢复连接原有的 autocommit 设置,而不是一律设为 `False`。
- 同一秒内生成的 checkpoint 不再互相覆盖。
- 新增对 `Decimal`(`numeric` 列)的序列化支持。
- 文件先写入临时文件再原子替换,不会留下写了一半的 checkpoint。

### 修复5: 改进API路由 (lawrisk/api/v2.py)

修改 `admin_restore_checkpoint` 路由:补上缺失的 `<checkpoint_id>` URL 参数,并添加警告:

```python
@v2_bp.route('/admin/checkpoints/<checkpoint_id>/restore', methods=['POST'])
def admin_restore_checkpoint(checkpoint_id):
    """⚠️ DANGEROUS OPERATION ⚠️ Restore the database from a checkpoint.

    This operation:
    1. Empties every table contained in the checkpoint.
    2. Refills those tables from the specified checkpoint.
    3. Runs in a single transaction. On failure everything is rolled back,
       and, when enabled, an auto-backup of the pre-restore state exists.

    Recommendations:
    - Make sure a backup exists.
    - Be especially careful in production.
    - Keep ``create_auto_backup`` enabled (the default is ``"true"``).

    Responses:
        200: success.
        400: invalid id, missing or corrupt checkpoint, or unknown tables.
        500: the restore itself failed.

    NOTE(review): confirm this route is protected by admin authentication.
    """
    try:
        # Read parameters from a JSON body or form data.
        if request.is_json:
            payload = request.get_json(silent=True) or {}
        else:
            payload = request.form.to_dict(flat=True) if request.form else {}
        if not isinstance(payload, dict):  # e.g. a JSON array body
            payload = {}
        create_auto_backup = str(payload.get("create_auto_backup", "true")).lower() in {"1", "true", "yes", "on"}

        # Run the restore.
        restore_result = restore_checkpoint(checkpoint_id, create_auto_backup=create_auto_backup)
        summary = restore_result["summary"]

        if restore_result.get("status") == "success":
            return jsonify({
                "success": True,
                "message": restore_result["message"],
                "data": {
                    "checkpoint_id": summary["checkpoint_id"],
                    "tables_restored": summary["tables_restored"],
                    "total_rows": summary["total_rows_restored"],
                    "auto_backup": summary.get("auto_backup"),
                },
            })

        return jsonify({
            "success": False,
            "message": restore_result["message"],
            "data": {
                "errors": summary.get("errors", []),
                "auto_backup_available": restore_result.get("auto_backup_available", False),
                "recovery_suggestion": restore_result.get("recovery_suggestion"),
            },
        }), 500
    except ValueError as exc:
        # restore_checkpoint raises ValueError for a bad id, a missing or
        # corrupt checkpoint file, or unknown tables: a client-side problem.
        return jsonify({
            "success": False,
            "message": str(exc),
            "data": {},
        }), 400
    except Exception as exc:
        import logging
        logging.getLogger(__name__).exception("admin_restore_checkpoint error")
        return jsonify({
            "success": False,
            "message": f"Restore failed: {str(exc)}",
            "data": {},
        }), 500
```

说明:

- 原路由规则 `'/admin/checkpoints//restore'` 缺少 `<checkpoint_id>`,视图函数无法拿到参数。
- 输入错误现在返回 400,而不是 500。
- 异常会连同完整堆栈写入日志。

---

## 建议的测试用例

### 测试1: 正常恢复流程

```python
def test_restore_checkpoint_normal():
    # 1. 创建数据
    # 2. 创建checkpoint
    # 3. 修改数据
    # 4. 恢复checkpoint
    # 5. 验证数据恢复正确
    # 6. 再插入一行新数据,验证自增主键不与恢复的数据冲突
    pass
```

### 测试2: 自动备份功能

```python
def test_auto_backup_before_restore():
    # 1. 创建初始数据
    # 2. 创建checkpoint A
    # 3. 修改数据
    # 4. 创建checkpoint B
    # 5. 恢复checkpoint A (应该自动备份当前状态)
    # 6. 验证checkpoint B仍然存在
    pass
```

### 测试3: 并发写入保护

```python
def test_restore_with_concurrent_writes():
    # 1. 创建checkpoint
    # 2. 启动后台线程持续写入
    # 3. 恢复checkpoint
    # 4. 验证恢复没有数据竞争
    pass
```

### 测试4: 恢复失败时整体回滚

```python
def test_restore_rolls_back_on_failure():
    # 1. 创建checkpoint
    # 2. 修改数据
    # 3. 篡改checkpoint文件中靠后某张表的一行数据,使其插入失败(例如违反 NOT NULL)
    # 4. 恢复checkpoint,断言返回 status == "error"
    # 5. 验证所有表仍保持第2步之后的数据(没有任何表被部分恢复)
    pass
```

### 测试5: 拒绝非法 checkpoint_id

```python
def test_restore_rejects_invalid_checkpoint_id():
    # 1. 以 "../secret" 等非法ID调用 restore_checkpoint,断言抛出 ValueError
    # 2. 通过API以非法或不存在的ID调用,断言返回 400 且数据库未被修改
    pass
```

---

## 总结

**立即要做的事情**:

1. ✅ 将修复补丁应用到 `licensing_repo.py` 和 `lawrisk/api/v2.py`
2. ✅ 更新API文档,添加安全警告
3. ✅ 创建测试用例
4. ✅ 在生产环境使用前充分测试

**不要做的事情**:

- ❌ 不要在生产环境使用未修复的 `restore_checkpoint`
- ❌ 不要在恢复过程中强制终止应用进程或数据库会话
- ❌ 不要依赖未经验证的 checkpoint 文件

**检查清单**:

- [ ] 依赖关系排序已实现
- [ ] 表锁已添加
- [ ] 自动备份已实现
- [ ] 事务保护已添加(恢复全程单事务,创建使用 REPEATABLE READ)
- [ ] checkpoint_id、表名、列名校验与转义已添加
- [ ] 自增序列重置已实现
- [ ] API警告已添加
- [ ] 测试已通过