fs-lawrisk/PATCH_CHECKPOINT_SECURITY.md

397 lines
12 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# 检查点功能安全补丁
## 修复清单
### ⚠️ 当前代码存在的安全问题
1. **数据永久丢失风险** - 使用 `TRUNCATE CASCADE` 无回退机制
2. **违反外键约束** - 恢复顺序不考虑依赖关系
3. **并发写入冲突** - 无表级锁保护
4. **部分失败不一致** - checkpoint创建无事务保护
---
## 立即修复方案
### 修复1: 添加表依赖排序 (lawrisk/services/licensing_repo.py)
`_get_all_tables()` 函数后添加:
```python
def _get_table_dependencies(conn: pg.Connection) -> Dict[str, List[str]]:
"""获取表依赖关系: {表名: [依赖该表的子表列表]}"""
sql = """
SELECT
ccu.table_name AS referenced_table,
tc.table_name AS dependent_table
FROM information_schema.table_constraints tc
JOIN information_schema.key_column_usage kcu
ON tc.constraint_name = kcu.constraint_name
JOIN information_schema.constraint_column_usage ccu
ON tc.constraint_name = ccu.constraint_name
WHERE tc.constraint_type = 'FOREIGN KEY'
AND tc.table_schema = 'public'
"""
cur = conn.cursor()
cur.execute(sql)
dependencies = {}
for row in cur.fetchall():
referenced_table, dependent_table = row
if referenced_table not in dependencies:
dependencies[referenced_table] = []
dependencies[referenced_table].append(dependent_table)
return dependencies
def _topological_sort_tables(all_tables: List[str], dependencies: Dict[str, List[str]]) -> List[str]:
"""拓扑排序: 先返回父表,再返回子表"""
from collections import deque
# 计算每个表的入度(被多少表依赖)
in_degree = {table: 0 for table in all_tables}
for parent_table, children in dependencies.items():
for child in children:
if child in all_tables:
in_degree[child] += 1
# 父表入度为0优先处理
queue = deque([table for table in all_tables if in_degree[table] == 0])
result = []
while queue:
table = queue.popleft()
result.append(table)
# 减少依赖该表的子表的入度
for parent, children in dependencies.items():
if parent == table:
for child in children:
if child in all_tables and in_degree[child] > 0:
in_degree[child] -= 1
if in_degree[child] == 0:
queue.append(child)
# 如果有环,添加剩余表
for table in all_tables:
if table not in result:
result.append(table)
return result
```
### 修复2: 改进 _restore_table() 函数 (第380-409行)
```python
def _restore_table_safe(conn: pg.Connection, table_name: str, data: List[Dict[str, Any]]) -> int:
"""安全恢复表,使用单事务"""
if not data:
return 0
conn.autocommit = False
try:
cur = conn.cursor()
# 获取列信息
first_row = data[0]
columns = list(first_row.keys())
placeholders = ", ".join(["%s"] * len(columns))
# TRUNCATE表 (使用CASCADE处理外键)
truncate_sql = f"TRUNCATE TABLE {table_name} CASCADE"
cur.execute(truncate_sql)
# 批量插入
insert_sql = f"INSERT INTO {table_name} ({', '.join(columns)}) VALUES ({placeholders})"
for row in data:
values = [row.get(col) for col in columns]
cur.execute(insert_sql, values)
conn.commit()
return len(data)
except Exception as e:
conn.rollback()
raise e
finally:
conn.autocommit = False
```
### 修复3: 改进 restore_checkpoint() 函数 (第493-526行)
```python
def restore_checkpoint(checkpoint_id: str, create_auto_backup: bool = True) -> Dict[str, Any]:
"""
恢复数据库从checkpoint。
⚠️ 危险操作: 会永久删除现有数据!
Args:
checkpoint_id: 要恢复的checkpoint ID
create_auto_backup: 是否在恢复前自动备份当前状态
Returns:
恢复结果字典
"""
checkpoints_dir = _get_checkpoints_dir()
checkpoint_file = os.path.join(checkpoints_dir, f"{checkpoint_id}.json")
if not os.path.exists(checkpoint_file):
raise ValueError(f"Checkpoint {checkpoint_id} not found")
with open(checkpoint_file, "r", encoding="utf-8") as f:
checkpoint_data = json.load(f)
# 自动备份当前状态(可选但强烈推荐)
auto_backup_info = None
if create_auto_backup:
auto_backup_info = create_checkpoint_safe(
f"auto_backup_before_restore_{checkpoint_id}"
)
tables = checkpoint_data.get("tables", {})
restore_summary = {
"checkpoint_id": checkpoint_id,
"auto_backup": auto_backup_info.get("checkpoint_id") if auto_backup_info else None,
"tables_restored": 0,
"total_rows_restored": 0,
"table_details": {},
"errors": []
}
with _lic_pg_conn(autocommit=False) as conn:
try:
# 1. 构建依赖关系图
dependencies = _get_table_dependencies(conn)
all_tables = list(tables.keys())
# 2. 拓扑排序获取恢复顺序
restore_order = _topological_sort_tables(all_tables, dependencies)
# 3. 锁定所有表(防止并发写入)
cur = conn.cursor()
for table in restore_order:
cur.execute(f"LOCK TABLE {table} IN EXCLUSIVE MODE")
# 4. 按依赖顺序恢复
for table_name in restore_order:
data = tables.get(table_name, [])
rows_restored = _restore_table_safe(conn, table_name, data)
restore_summary["tables_restored"] += 1
restore_summary["total_rows_restored"] += rows_restored
restore_summary["table_details"][table_name] = rows_restored
# 5. 提交事务
conn.commit()
return {
"status": "success",
"message": f"Successfully restored {restore_summary['tables_restored']} tables",
"summary": restore_summary
}
except Exception as e:
# 回滚事务
conn.rollback()
return {
"status": "error",
"message": f"Restore failed: {str(e)}",
"summary": restore_summary,
"auto_backup_available": bool(auto_backup_info),
"recovery_suggestion": (
f"Use auto-backup to restore: {auto_backup_info['checkpoint_id']}"
if auto_backup_info else "No auto-backup available"
)
}
```
### 修复4: 改进 create_checkpoint() 函数 (第411-463行)
```python
def create_checkpoint_safe(description: str = "") -> Dict[str, Any]:
"""安全创建checkpoint单事务保护"""
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
checkpoint_id = f"checkpoint_{timestamp}"
tables = _get_all_tables()
checkpoint_data = {
"checkpoint_id": checkpoint_id,
"timestamp": timestamp,
"description": description,
"tables": {}
}
total_rows = 0
table_counts = {}
with _lic_pg_conn() as conn:
conn.autocommit = False
try:
for table in tables:
data, row_count = _backup_table(conn, table)
checkpoint_data["tables"][table] = data
table_counts[table] = row_count
total_rows += row_count
conn.commit() # 全部成功或失败
except Exception as e:
conn.rollback()
raise e
finally:
conn.autocommit = False
checkpoint_data["table_counts"] = table_counts
checkpoint_data["total_rows"] = total_rows
# 保存到文件
checkpoints_dir = _get_checkpoints_dir()
checkpoint_file = os.path.join(checkpoints_dir, f"{checkpoint_id}.json")
def json_serializer(obj):
try:
import uuid
if isinstance(obj, uuid.UUID):
return str(obj)
except ImportError:
pass
if hasattr(obj, 'isoformat'):
return str(obj)
raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
with open(checkpoint_file, "w", encoding="utf-8") as f:
json.dump(checkpoint_data, f, ensure_ascii=False, indent=2, default=json_serializer)
return {
"checkpoint_id": checkpoint_id,
"timestamp": timestamp,
"description": description,
"total_rows": total_rows,
"table_counts": table_counts
}
```
### 修复5: 改进API路由 (lawrisk/api/v2.py)
修改 admin_restore_checkpoint 路由,添加警告:
```python
@v2_bp.route('/admin/checkpoints/<checkpoint_id>/restore', methods=['POST'])
def admin_restore_checkpoint(checkpoint_id):
"""
⚠️ DANGEROUS OPERATION ⚠️
恢复数据库从checkpoint。
此操作会:
1. 永久删除当前数据库中的所有数据
2. 从指定的checkpoint恢复数据
3. 如果失败,可能导致数据丢失
建议:
- 确保已创建备份
- 在生产环境中使用时格外小心
- 考虑使用 create_auto_backup=true 参数
"""
try:
# 获取参数
if request.is_json:
payload = request.get_json(silent=True) or {}
else:
payload = request.form.to_dict(flat=True) if request.form else {}
create_auto_backup = str(payload.get("create_auto_backup", "true")).lower() in {"1", "true", "yes", "on"}
# 执行恢复
restore_result = restore_checkpoint(checkpoint_id, create_auto_backup=create_auto_backup)
if restore_result.get("status") == "success":
return jsonify({
"success": True,
"message": restore_result["message"],
"data": {
"checkpoint_id": restore_result["summary"]["checkpoint_id"],
"tables_restored": restore_result["summary"]["tables_restored"],
"total_rows": restore_result["summary"]["total_rows_restored"],
"auto_backup": restore_result["summary"].get("auto_backup")
}
})
else:
return jsonify({
"success": False,
"message": restore_result["message"],
"data": {
"errors": restore_result["summary"].get("errors", []),
"auto_backup_available": restore_result.get("auto_backup_available", False),
"recovery_suggestion": restore_result.get("recovery_suggestion")
}
}), 500
except Exception as exc:
print(f"admin_restore_checkpoint error: {exc}")
return jsonify({
"success": False,
"message": f"Restore failed: {str(exc)}",
"data": {}
}), 500
```
---
## 建议的测试用例
### 测试1: 正常恢复流程
```python
def test_restore_checkpoint_normal():
# 1. 创建数据
# 2. 创建checkpoint
# 3. 修改数据
# 4. 恢复checkpoint
# 5. 验证数据恢复正确
pass
```
### 测试2: 自动备份功能
```python
def test_auto_backup_before_restore():
# 1. 创建初始数据
# 2. 创建checkpoint A
# 3. 修改数据
# 4. 创建checkpoint B
# 5. 恢复checkpoint A (应该自动备份当前状态)
# 6. 验证checkpoint B仍然存在
pass
```
### 测试3: 并发写入保护
```python
def test_restore_with_concurrent_writes():
# 1. 创建checkpoint
# 2. 启动后台线程持续写入
# 3. 恢复checkpoint
# 4. 验证恢复没有数据竞争
pass
```
---
## 总结
**立即要做的事情**:
1. ✅ 应用修复补丁到 `licensing_repo.py`
2. ✅ 更新API文档添加安全警告
3. ✅ 创建测试用例
4. ✅ 在生产环境使用前充分测试
**不要做的事情**:
- ❌ 不要在生产环境使用未修复的 restore_checkpoint
- ❌ 不要在恢复过程中停止应用
- ❌ 不要依赖未验证的checkpoint文件
**检查清单**:
- [ ] 依赖关系排序已实现
- [ ] 表锁已添加
- [ ] 自动备份已实现
- [ ] 事务保护已添加
- [ ] API警告已添加
- [ ] 测试已通过