fs-lawrisk/PATCH_CHECKPOINT_SECURITY.md

12 KiB
Raw Blame History

检查点功能安全补丁

修复清单

⚠️ 当前代码存在的安全问题

  1. 数据永久丢失风险 - 使用 TRUNCATE CASCADE 无回退机制
  2. 违反外键约束 - 恢复顺序不考虑依赖关系
  3. 并发写入冲突 - 无表级锁保护
  4. 部分失败不一致 - checkpoint创建无事务保护

立即修复方案

修复1: 添加表依赖排序 (lawrisk/services/licensing_repo.py)

在 _get_all_tables() 函数之后添加以下两个函数:

def _get_table_dependencies(conn: pg.Connection) -> Dict[str, List[str]]:
    """获取表依赖关系: {表名: [依赖该表的子表列表]}"""
    sql = """
        SELECT
            ccu.table_name AS referenced_table,
            tc.table_name AS dependent_table
        FROM information_schema.table_constraints tc
        JOIN information_schema.key_column_usage kcu
          ON tc.constraint_name = kcu.constraint_name
        JOIN information_schema.constraint_column_usage ccu
          ON tc.constraint_name = ccu.constraint_name
        WHERE tc.constraint_type = 'FOREIGN KEY'
          AND tc.table_schema = 'public'
    """
    cur = conn.cursor()
    cur.execute(sql)

    dependencies = {}
    for row in cur.fetchall():
        referenced_table, dependent_table = row
        if referenced_table not in dependencies:
            dependencies[referenced_table] = []
        dependencies[referenced_table].append(dependent_table)

    return dependencies


def _topological_sort_tables(all_tables: List[str], dependencies: Dict[str, List[str]]) -> List[str]:
    """拓扑排序: 先返回父表,再返回子表"""
    from collections import deque

    # 计算每个表的入度(被多少表依赖)
    in_degree = {table: 0 for table in all_tables}
    for parent_table, children in dependencies.items():
        for child in children:
            if child in all_tables:
                in_degree[child] += 1

    # 父表入度为0优先处理
    queue = deque([table for table in all_tables if in_degree[table] == 0])
    result = []

    while queue:
        table = queue.popleft()
        result.append(table)

        # 减少依赖该表的子表的入度
        for parent, children in dependencies.items():
            if parent == table:
                for child in children:
                    if child in all_tables and in_degree[child] > 0:
                        in_degree[child] -= 1
                        if in_degree[child] == 0:
                            queue.append(child)

    # 如果有环,添加剩余表
    for table in all_tables:
        if table not in result:
            result.append(table)

    return result

修复2: 改进 _restore_table() 函数 (第380-409行)

def _restore_table_safe(conn: pg.Connection, table_name: str, data: List[Dict[str, Any]]) -> int:
    """安全恢复表,使用单事务"""
    if not data:
        return 0

    conn.autocommit = False
    try:
        cur = conn.cursor()

        # 获取列信息
        first_row = data[0]
        columns = list(first_row.keys())
        placeholders = ", ".join(["%s"] * len(columns))

        # TRUNCATE表 (使用CASCADE处理外键)
        truncate_sql = f"TRUNCATE TABLE {table_name} CASCADE"
        cur.execute(truncate_sql)

        # 批量插入
        insert_sql = f"INSERT INTO {table_name} ({', '.join(columns)}) VALUES ({placeholders})"
        for row in data:
            values = [row.get(col) for col in columns]
            cur.execute(insert_sql, values)

        conn.commit()
        return len(data)
    except Exception as e:
        conn.rollback()
        raise e
    finally:
        conn.autocommit = False

修复3: 改进 restore_checkpoint() 函数 (第493-526行)

def _quote_pg_identifier(name: str) -> str:
    """Return ``name`` as a double-quoted PostgreSQL identifier."""
    return '"' + str(name).replace('"', '""') + '"'


def _quote_pg_table(table_name: str) -> str:
    """Quote a table name, treating dots as schema separators."""
    return ".".join(_quote_pg_identifier(part) for part in str(table_name).split("."))


def _insert_checkpoint_rows(cur, table_name: str, rows: List[Dict[str, Any]]) -> int:
    """Insert checkpoint rows into an already emptied table; never commits."""
    if not rows:
        return 0
    # Column list comes from the first row; keys missing later become NULL.
    columns = list(rows[0].keys())
    column_sql = ", ".join(_quote_pg_identifier(col) for col in columns)
    placeholders = ", ".join(["%s"] * len(columns))
    insert_sql = f"INSERT INTO {_quote_pg_table(table_name)} ({column_sql}) VALUES ({placeholders})"
    cur.executemany(insert_sql, [[row.get(col) for col in columns] for row in rows])
    return len(rows)


def restore_checkpoint(checkpoint_id: str, create_auto_backup: bool = True) -> Dict[str, Any]:
    """Restore the database from a checkpoint file.

    WARNING: destructive operation. Every table listed in the checkpoint is
    emptied and refilled from the checkpoint data.

    The lock, truncate and insert steps all run in one transaction on one
    connection, so a failure rolls the database back to its previous state.
    Rows are written by a helper local to this block rather than by
    ``_restore_table_safe``, because that helper manages its own
    commit/rollback and would end this transaction early.

    Args:
        checkpoint_id: ID of the checkpoint to restore (file name without the
            ``.json`` extension).
        create_auto_backup: Whether to checkpoint the current state before
            restoring.

    Returns:
        A dict with ``status`` (``"success"`` or ``"error"``), ``message`` and
        ``summary``. On error it also carries ``auto_backup_available`` and
        ``recovery_suggestion``, and the failure text is appended to
        ``summary["errors"]``.

    Raises:
        ValueError: If ``checkpoint_id`` is not a plain file name or the
            checkpoint file does not exist.
    """
    # The ID ends up in a file path; reject anything that could leave the
    # checkpoints directory.
    if (
        not checkpoint_id
        or checkpoint_id in {".", ".."}
        or "/" in checkpoint_id
        or "\\" in checkpoint_id
        or os.path.basename(checkpoint_id) != checkpoint_id
    ):
        raise ValueError(f"Invalid checkpoint id: {checkpoint_id!r}")

    checkpoints_dir = _get_checkpoints_dir()
    checkpoint_file = os.path.join(checkpoints_dir, f"{checkpoint_id}.json")

    if not os.path.exists(checkpoint_file):
        raise ValueError(f"Checkpoint {checkpoint_id} not found")

    with open(checkpoint_file, "r", encoding="utf-8") as f:
        checkpoint_data = json.load(f)

    # Back up the current state first (optional but strongly recommended).
    auto_backup_info = None
    if create_auto_backup:
        auto_backup_info = create_checkpoint_safe(
            f"auto_backup_before_restore_{checkpoint_id}"
        )

    tables = checkpoint_data.get("tables") or {}
    restore_summary = {
        "checkpoint_id": checkpoint_id,
        "auto_backup": auto_backup_info.get("checkpoint_id") if auto_backup_info else None,
        "tables_restored": 0,
        "total_rows_restored": 0,
        "table_details": {},
        "errors": []
    }

    with _lic_pg_conn(autocommit=False) as conn:
        current_table = None
        try:
            # 1. Build the foreign-key dependency graph.
            dependencies = _get_table_dependencies(conn)
            all_tables = list(tables.keys())

            # 2. Parent tables first, child tables afterwards.
            restore_order = _topological_sort_tables(all_tables, dependencies)

            cur = conn.cursor()

            # 3. Lock every table to keep concurrent writers out. The locks
            #    are held until the commit/rollback below.
            for table in restore_order:
                current_table = table
                cur.execute(f"LOCK TABLE {_quote_pg_table(table)} IN EXCLUSIVE MODE")

            # 4. Empty all tables in a single statement *before* inserting
            #    anything. A per-table TRUNCATE ... CASCADE issued later could
            #    wipe rows that were already restored into a referencing table.
            current_table = None
            if restore_order:
                truncate_targets = ", ".join(_quote_pg_table(t) for t in restore_order)
                cur.execute(f"TRUNCATE TABLE {truncate_targets} CASCADE")

            # 5. Refill in dependency order.
            for table_name in restore_order:
                current_table = table_name
                rows_restored = _insert_checkpoint_rows(cur, table_name, tables.get(table_name) or [])

                restore_summary["tables_restored"] += 1
                restore_summary["total_rows_restored"] += rows_restored
                restore_summary["table_details"][table_name] = rows_restored

            # 6. Make everything visible at once.
            conn.commit()
            return {
                "status": "success",
                "message": f"Successfully restored {restore_summary['tables_restored']} tables",
                "summary": restore_summary
            }

        except Exception as e:
            # Undo the whole restore; nothing has been committed yet.
            conn.rollback()
            failed_at = f" (table {current_table})" if current_table else ""
            restore_summary["errors"].append(f"{str(e)}{failed_at}")
            # After the rollback no table is actually restored.
            restore_summary["tables_restored"] = 0
            restore_summary["total_rows_restored"] = 0
            restore_summary["table_details"] = {}
            return {
                "status": "error",
                "message": f"Restore failed: {str(e)}",
                "summary": restore_summary,
                "auto_backup_available": bool(auto_backup_info),
                "recovery_suggestion": (
                    f"Use auto-backup to restore: {auto_backup_info['checkpoint_id']}"
                    if auto_backup_info else "No auto-backup available"
                )
            }

修复4: 改进 create_checkpoint() 函数 (第411-463行)

def create_checkpoint_safe(description: str = "") -> Dict[str, Any]:
    """Create a checkpoint of all tables from one consistent snapshot.

    All tables are read inside a single REPEATABLE READ transaction, so the
    checkpoint reflects one point in time even while other sessions write.
    The JSON file is written to a temporary name and renamed into place, so a
    crash cannot leave a truncated checkpoint file behind.

    Args:
        description: Free-text description stored inside the checkpoint.

    Returns:
        A dict with ``checkpoint_id``, ``timestamp``, ``description``,
        ``total_rows`` and ``table_counts``.

    Raises:
        Exception: Any database or file error is re-raised; the read
            transaction is rolled back and the temporary file removed first.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    checkpoints_dir = _get_checkpoints_dir()

    # The timestamp has one-second resolution. Add a numeric suffix instead of
    # overwriting a checkpoint that was created within the same second.
    checkpoint_id = f"checkpoint_{timestamp}"
    suffix = 1
    while os.path.exists(os.path.join(checkpoints_dir, f"{checkpoint_id}.json")):
        checkpoint_id = f"checkpoint_{timestamp}_{suffix}"
        suffix += 1
    checkpoint_file = os.path.join(checkpoints_dir, f"{checkpoint_id}.json")

    tables = _get_all_tables()
    checkpoint_data = {
        "checkpoint_id": checkpoint_id,
        "timestamp": timestamp,
        "description": description,
        "tables": {}
    }

    total_rows = 0
    table_counts = {}

    with _lic_pg_conn() as conn:
        conn.autocommit = False
        try:
            # Under the default READ COMMITTED level each statement sees its
            # own snapshot, so tables read one after another could disagree.
            # This must be the first statement of the transaction.
            snapshot_cur = conn.cursor()
            snapshot_cur.execute("SET TRANSACTION ISOLATION LEVEL REPEATABLE READ")

            for table in tables:
                data, row_count = _backup_table(conn, table)
                checkpoint_data["tables"][table] = data
                table_counts[table] = row_count
                total_rows += row_count

            conn.commit()  # ends the snapshot transaction
        except Exception:
            conn.rollback()
            raise

    checkpoint_data["table_counts"] = table_counts
    checkpoint_data["total_rows"] = total_rows

    def json_serializer(obj):
        # Fallback for values the json module cannot encode by itself.
        import decimal
        import uuid

        if isinstance(obj, (uuid.UUID, decimal.Decimal)):
            return str(obj)
        if hasattr(obj, 'isoformat'):
            # date / time / datetime values
            return str(obj)
        raise TypeError(f"Object of type {type(obj)} is not JSON serializable")

    # Write to a temporary file, then rename it into place.
    tmp_file = f"{checkpoint_file}.tmp"
    try:
        with open(tmp_file, "w", encoding="utf-8") as f:
            json.dump(checkpoint_data, f, ensure_ascii=False, indent=2, default=json_serializer)
        os.replace(tmp_file, checkpoint_file)
    except BaseException:
        if os.path.exists(tmp_file):
            os.remove(tmp_file)
        raise

    return {
        "checkpoint_id": checkpoint_id,
        "timestamp": timestamp,
        "description": description,
        "total_rows": total_rows,
        "table_counts": table_counts
    }

修复5: 改进API路由 (lawrisk/api/v2.py)

修改 admin_restore_checkpoint 路由,添加警告:

@v2_bp.route('/admin/checkpoints/<checkpoint_id>/restore', methods=['POST'])
def admin_restore_checkpoint(checkpoint_id):
    """
    DANGEROUS OPERATION: restore the database from a checkpoint.

    This operation will:
    1. Permanently delete all data currently in the database
    2. Restore the data from the given checkpoint
    3. Possibly lose data if it fails

    Recommendations:
    - Make sure a backup exists
    - Be extra careful in production
    - Consider passing create_auto_backup=true

    Request body (JSON object or form data):
        create_auto_backup: optional flag; "1", "true", "yes" and "on"
            (case-insensitive) enable the backup, any other value disables it.
            A missing or null value means enabled.

    Returns:
        200 with ``success: true`` when the restore succeeded, otherwise 500
        with ``success: false`` and the failure message.
    """
    try:
        # Read the parameters.
        if request.is_json:
            payload = request.get_json(silent=True)
        else:
            payload = request.form.to_dict(flat=True) if request.form else {}

        # A JSON body may legally be a list, string or number; treat anything
        # that is not an object as "no parameters" instead of crashing on .get().
        if not isinstance(payload, dict):
            payload = {}

        # A missing or null flag must keep the safe default. Stringifying None
        # would give "none" and silently switch the backup off.
        raw_flag = payload.get("create_auto_backup")
        if raw_flag is None:
            create_auto_backup = True
        else:
            create_auto_backup = str(raw_flag).strip().lower() in {"1", "true", "yes", "on"}

        # Run the restore.
        restore_result = restore_checkpoint(checkpoint_id, create_auto_backup=create_auto_backup)

        if restore_result.get("status") == "success":
            return jsonify({
                "success": True,
                "message": restore_result["message"],
                "data": {
                    "checkpoint_id": restore_result["summary"]["checkpoint_id"],
                    "tables_restored": restore_result["summary"]["tables_restored"],
                    "total_rows": restore_result["summary"]["total_rows_restored"],
                    "auto_backup": restore_result["summary"].get("auto_backup")
                }
            })
        else:
            return jsonify({
                "success": False,
                "message": restore_result["message"],
                "data": {
                    "errors": restore_result["summary"].get("errors", []),
                    "auto_backup_available": restore_result.get("auto_backup_available", False),
                    "recovery_suggestion": restore_result.get("recovery_suggestion")
                }
            }), 500

    except Exception as exc:
        print(f"admin_restore_checkpoint error: {exc}")
        return jsonify({
            "success": False,
            "message": f"Restore failed: {str(exc)}",
            "data": {}
        }), 500

建议的测试用例

测试1: 正常恢复流程

def test_restore_checkpoint_normal():
    """Outline of the normal restore flow test (placeholder, no assertions yet)."""
    # 1. Create data
    # 2. Create a checkpoint
    # 3. Modify the data
    # 4. Restore the checkpoint
    # 5. Verify the data was restored correctly
    pass

测试2: 自动备份功能

def test_auto_backup_before_restore():
    """Outline of the auto-backup test (placeholder, no assertions yet)."""
    # 1. Create the initial data
    # 2. Create checkpoint A
    # 3. Modify the data
    # 4. Create checkpoint B
    # 5. Restore checkpoint A (should automatically back up the current state)
    # 6. Verify that checkpoint B still exists
    pass

测试3: 并发写入保护

def test_restore_with_concurrent_writes():
    """Outline of the concurrent-write protection test (placeholder, no assertions yet)."""
    # 1. Create a checkpoint
    # 2. Start a background thread that writes continuously
    # 3. Restore the checkpoint
    # 4. Verify the restore shows no data races
    pass

总结

立即要做的事情:

  1. 应用修复补丁到 licensing_repo.py
  2. 更新API文档添加安全警告
  3. 创建测试用例
  4. 在生产环境使用前充分测试

不要做的事情:

  • 不要在生产环境使用未修复的 restore_checkpoint
  • 不要在恢复过程中强行停止或重启应用(恢复尚未完成时中断,可能使数据库处于不一致状态)
  • 不要依赖未验证的checkpoint文件

检查清单:

  • 依赖关系排序已实现
  • 表锁已添加
  • 自动备份已实现
  • 事务保护已添加
  • API警告已添加
  • 测试已通过