# 检查点功能安全补丁
|
|||
|
|
|
|||
|
|
## 修复清单
|
|||
|
|
|
|||
|
|
### ⚠️ 当前代码存在的安全问题
|
|||
|
|
|
|||
|
|
1. **数据永久丢失风险** - 使用 `TRUNCATE CASCADE` 无回退机制
|
|||
|
|
2. **违反外键约束** - 恢复顺序不考虑依赖关系
|
|||
|
|
3. **并发写入冲突** - 无表级锁保护
|
|||
|
|
4. **部分失败不一致** - checkpoint创建无事务保护
|
|||
|
|
|
|||
|
|
---
|
|||
|
|
|
|||
|
|
## 立即修复方案
|
|||
|
|
|
|||
|
|
### 修复1: 添加表依赖排序 (lawrisk/services/licensing_repo.py)
|
|||
|
|
|
|||
|
|
在 `_get_all_tables()` 函数后添加:
|
|||
|
|
|
|||
|
|
```python
|
|||
|
|
def _get_table_dependencies(conn: pg.Connection, schema: str = "public") -> Dict[str, List[str]]:
    """Return the foreign-key dependencies between the tables of one schema.

    The result maps each referenced (parent) table to the tables whose foreign
    keys point at it: ``{parent: [child, ...]}``. Only foreign keys where both
    tables live in ``schema`` are returned (names are unqualified, so
    cross-schema edges would be ambiguous). Each parent/child pair appears
    once. Self-referencing foreign keys are omitted because they impose no
    ordering between tables.

    The query reads ``pg_catalog.pg_constraint`` directly instead of joining
    ``information_schema`` views on constraint name. PostgreSQL does not
    require constraint names to be unique within a schema, so that join can
    pair a foreign key with the wrong table. Multi-column keys would also
    produce duplicate rows.

    Args:
        conn: Open database connection.
        schema: Schema whose tables are inspected (default ``"public"``).

    Returns:
        Dict from referenced table name to the list of dependent table names.
    """
    sql = """
        SELECT DISTINCT
            parent.relname AS referenced_table,
            child.relname AS dependent_table
        FROM pg_catalog.pg_constraint AS con
        JOIN pg_catalog.pg_class AS child ON child.oid = con.conrelid
        JOIN pg_catalog.pg_namespace AS child_ns ON child_ns.oid = child.relnamespace
        JOIN pg_catalog.pg_class AS parent ON parent.oid = con.confrelid
        JOIN pg_catalog.pg_namespace AS parent_ns ON parent_ns.oid = parent.relnamespace
        WHERE con.contype = 'f'
          AND con.conrelid <> con.confrelid
          AND child_ns.nspname = %s
          AND parent_ns.nspname = %s
        ORDER BY referenced_table, dependent_table
    """
    dependencies: Dict[str, List[str]] = {}
    # Context-managed cursor so it is closed even if the query fails.
    with conn.cursor() as cur:
        cur.execute(sql, (schema, schema))
        for referenced_table, dependent_table in cur.fetchall():
            dependencies.setdefault(referenced_table, []).append(dependent_table)
    return dependencies
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _topological_sort_tables(all_tables: List[str], dependencies: Dict[str, List[str]]) -> List[str]:
|
|||
|
|
"""拓扑排序: 先返回父表,再返回子表"""
|
|||
|
|
from collections import deque
|
|||
|
|
|
|||
|
|
# 计算每个表的入度(被多少表依赖)
|
|||
|
|
in_degree = {table: 0 for table in all_tables}
|
|||
|
|
for parent_table, children in dependencies.items():
|
|||
|
|
for child in children:
|
|||
|
|
if child in all_tables:
|
|||
|
|
in_degree[child] += 1
|
|||
|
|
|
|||
|
|
# 父表入度为0,优先处理
|
|||
|
|
queue = deque([table for table in all_tables if in_degree[table] == 0])
|
|||
|
|
result = []
|
|||
|
|
|
|||
|
|
while queue:
|
|||
|
|
table = queue.popleft()
|
|||
|
|
result.append(table)
|
|||
|
|
|
|||
|
|
# 减少依赖该表的子表的入度
|
|||
|
|
for parent, children in dependencies.items():
|
|||
|
|
if parent == table:
|
|||
|
|
for child in children:
|
|||
|
|
if child in all_tables and in_degree[child] > 0:
|
|||
|
|
in_degree[child] -= 1
|
|||
|
|
if in_degree[child] == 0:
|
|||
|
|
queue.append(child)
|
|||
|
|
|
|||
|
|
# 如果有环,添加剩余表
|
|||
|
|
for table in all_tables:
|
|||
|
|
if table not in result:
|
|||
|
|
result.append(table)
|
|||
|
|
|
|||
|
|
return result
|
|||
|
|
```
|
|||
|
|
|
|||
|
|
### 修复2: 改进 `_restore_table()` 函数 (第380-409行),新实现为 `_restore_table_safe()`
|
|||
|
|
|
|||
|
|
```python
|
|||
|
|
def _restore_table_safe(conn: pg.Connection, table_name: str, data: List[Dict[str, Any]]) -> int:
    """Replace the contents of one table with rows from a checkpoint.

    The table is truncated with ``CASCADE``, which also empties tables that
    reference it through foreign keys. Then ``data`` is inserted. The table is
    truncated even when ``data`` is empty, so a table that was empty when the
    checkpoint was taken is empty again after the restore.

    Transaction handling:
      * If ``conn`` is in autocommit mode, the truncate and the inserts run in
        their own transaction. It is committed on success and rolled back on
        error, and autocommit is switched back on afterwards.
      * Otherwise the statements join the caller's open transaction and are
        neither committed nor rolled back here. A caller restoring several
        tables can therefore keep the whole restore atomic. (Changing
        ``autocommit`` inside an open transaction is rejected by the driver,
        so it is never touched in this case.)

    Args:
        conn: Open database connection.
        table_name: Table to restore, optionally schema-qualified ("schema.table").
        data: Rows to insert; each row maps column name to value. Columns are
            taken from the first row.

    Returns:
        Number of rows inserted.

    Raises:
        Any database error from the truncate/insert. When this function owns
        the transaction, it rolls back before re-raising.
    """

    def quote(identifier: str) -> str:
        # Double-quote the identifier (doubling embedded quotes) so names read
        # from a checkpoint file cannot inject SQL.
        return '"' + identifier.replace('"', '""') + '"'

    rows = data or []
    qualified_table = ".".join(quote(part) for part in table_name.split("."))

    owns_transaction = bool(conn.autocommit)
    if owns_transaction:
        # In autocommit mode the TRUNCATE would be committed on its own;
        # group it with the inserts in one transaction instead.
        conn.autocommit = False
    try:
        with conn.cursor() as cur:
            cur.execute(f"TRUNCATE TABLE {qualified_table} CASCADE")
            if rows:
                columns = list(rows[0].keys())
                column_list = ", ".join(quote(column) for column in columns)
                placeholders = ", ".join(["%s"] * len(columns))
                insert_sql = f"INSERT INTO {qualified_table} ({column_list}) VALUES ({placeholders})"
                cur.executemany(insert_sql, [[row.get(column) for column in columns] for row in rows])
        if owns_transaction:
            conn.commit()
    except Exception:
        if owns_transaction:
            conn.rollback()
        raise
    finally:
        if owns_transaction:
            conn.autocommit = True
    return len(rows)
|
|||
|
|
```
|
|||
|
|
|
|||
|
|
### 修复3: 改进 restore_checkpoint() 函数 (第493-526行)
|
|||
|
|
|
|||
|
|
```python
|
|||
|
|
def restore_checkpoint(checkpoint_id: str, create_auto_backup: bool = True) -> Dict[str, Any]:
    """Restore the database from a checkpoint file.

    WARNING: destructive. Every table stored in the checkpoint is emptied with
    ``TRUNCATE ... CASCADE`` and then refilled from the checkpoint rows.
    CASCADE also empties any table that references it through a foreign key,
    even if that table is not in the checkpoint.

    Dependency lookup, locking, truncation and all inserts run in ONE
    transaction. If anything fails, the transaction is rolled back and the
    database is unchanged.

    NOTE(review): sequences behind serial/identity columns are not adjusted
    after the explicit-id inserts, and GENERATED ALWAYS identity columns would
    reject them -- confirm neither matters for these tables.

    Args:
        checkpoint_id: Checkpoint to restore, i.e. the file
            ``<checkpoints dir>/<checkpoint_id>.json``. Must be a plain file
            name without directory components.
        create_auto_backup: If True, snapshot the current state with
            ``create_checkpoint_safe`` before changing anything, so a restore
            that succeeds but was unwanted can be undone.

    Returns:
        On success: ``{"status": "success", "message", "summary"}``.
        When the database work failed: ``{"status": "error", "message",
        "summary", "auto_backup_available", "recovery_suggestion"}``. The
        error is also appended to ``summary["errors"]``.

    Raises:
        ValueError: ``checkpoint_id`` contains a path component, or the
            checkpoint file does not exist. A file that is not valid JSON
            raises ``json.JSONDecodeError``, a ValueError subclass.
    """

    def quote_table(name: str) -> str:
        # Quote each dot-separated part ("schema.table" stays qualified) and
        # double embedded quotes, so names read from the checkpoint file
        # cannot inject SQL.
        return ".".join('"' + part.replace('"', '""') + '"' for part in name.split("."))

    def insert_rows(cur, table_name: str, rows: List[Dict[str, Any]]) -> int:
        # Insert checkpoint rows into an already-emptied table; columns come from the first row.
        if not rows:
            return 0
        columns = list(rows[0].keys())
        column_list = ", ".join('"' + column.replace('"', '""') + '"' for column in columns)
        placeholders = ", ".join(["%s"] * len(columns))
        cur.executemany(
            f"INSERT INTO {quote_table(table_name)} ({column_list}) VALUES ({placeholders})",
            [[row.get(column) for column in columns] for row in rows],
        )
        return len(rows)

    # The id becomes part of a file path: refuse anything with a directory
    # component such as "../x" or "/etc/x".
    if os.path.basename(checkpoint_id) != checkpoint_id:
        raise ValueError(f"Invalid checkpoint id: {checkpoint_id!r}")

    checkpoints_dir = _get_checkpoints_dir()
    checkpoint_file = os.path.join(checkpoints_dir, f"{checkpoint_id}.json")

    if not os.path.exists(checkpoint_file):
        raise ValueError(f"Checkpoint {checkpoint_id} not found")

    with open(checkpoint_file, "r", encoding="utf-8") as f:
        checkpoint_data = json.load(f)

    # Snapshot the current state first (optional but strongly recommended).
    auto_backup_info = None
    if create_auto_backup:
        auto_backup_info = create_checkpoint_safe(
            f"auto_backup_before_restore_{checkpoint_id}"
        )

    tables = checkpoint_data.get("tables", {})
    restore_summary = {
        "checkpoint_id": checkpoint_id,
        "auto_backup": auto_backup_info.get("checkpoint_id") if auto_backup_info else None,
        "tables_restored": 0,
        "total_rows_restored": 0,
        "table_details": {},
        "errors": []
    }

    with _lic_pg_conn(autocommit=False) as conn:
        try:
            # Parents before children, so immediate FK checks pass on insert.
            restore_order = _topological_sort_tables(list(tables.keys()), _get_table_dependencies(conn))

            with conn.cursor() as cur:
                if restore_order:
                    table_list = ", ".join(quote_table(table) for table in restore_order)
                    # Affects only DEFERRABLE foreign keys: lets rows of
                    # mutually-referencing tables be inserted in any order.
                    cur.execute("SET CONSTRAINTS ALL DEFERRED")
                    # TRUNCATE needs ACCESS EXCLUSIVE anyway. Taking it up
                    # front for all tables blocks concurrent writers for the
                    # whole restore. It also avoids an EXCLUSIVE -> ACCESS
                    # EXCLUSIVE upgrade, which can deadlock with a concurrent
                    # transaction that reads and then writes a table.
                    cur.execute(f"LOCK TABLE {table_list} IN ACCESS EXCLUSIVE MODE")
                    # Empty everything in one statement, so a later CASCADE
                    # cannot wipe a table that was already refilled (cycles).
                    cur.execute(f"TRUNCATE TABLE {table_list} CASCADE")

                for table_name in restore_order:
                    rows_restored = insert_rows(cur, table_name, tables.get(table_name) or [])
                    restore_summary["tables_restored"] += 1
                    restore_summary["total_rows_restored"] += rows_restored
                    restore_summary["table_details"][table_name] = rows_restored

            conn.commit()
            return {
                "status": "success",
                "message": f"Successfully restored {restore_summary['tables_restored']} tables",
                "summary": restore_summary
            }

        except Exception as e:
            # Nothing was committed: undo the truncates and inserts together.
            conn.rollback()
            restore_summary["errors"].append(str(e))
            return {
                "status": "error",
                "message": f"Restore failed: {str(e)}",
                "summary": restore_summary,
                "auto_backup_available": bool(auto_backup_info),
                "recovery_suggestion": (
                    f"Use auto-backup to restore: {auto_backup_info['checkpoint_id']}"
                    if auto_backup_info else "No auto-backup available"
                )
            }
|
|||
|
|
```
|
|||
|
|
|
|||
|
|
### 修复4: 改进 `create_checkpoint()` 函数 (第411-463行),新实现为 `create_checkpoint_safe()`
|
|||
|
|
|
|||
|
|
```python
|
|||
|
|
def create_checkpoint_safe(description: str = "") -> Dict[str, Any]:
    """Snapshot every table into a new JSON checkpoint file.

    All tables are read inside one REPEATABLE READ transaction. The checkpoint
    therefore reflects a single consistent database snapshot even while other
    sessions keep writing.

    The file is created exclusively. If a checkpoint with the same
    second-resolution id already exists (e.g. an auto-backup taken in the same
    second as another checkpoint), a numeric suffix is appended to the id
    instead of overwriting the existing file.

    Args:
        description: Free-text description stored in the checkpoint.

    Returns:
        Metadata of the new checkpoint: id, timestamp, description, total row
        count and per-table row counts.

    Raises:
        Database errors from reading the tables; the read transaction is
        rolled back first. ``TypeError`` if a value cannot be serialized to
        JSON; no partial file is left behind. ``OSError`` from creating or
        writing the file.
    """
    import decimal
    import uuid

    def json_serializer(obj):
        # UUID and Decimal (NUMERIC columns) are stored in their exact string
        # form. Date/time objects are stored via str().
        if isinstance(obj, (uuid.UUID, decimal.Decimal)):
            return str(obj)
        if hasattr(obj, 'isoformat'):
            return str(obj)
        raise TypeError(f"Object of type {type(obj)} is not JSON serializable")

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    tables = _get_all_tables()

    table_data: Dict[str, Any] = {}
    table_counts: Dict[str, int] = {}
    total_rows = 0

    with _lic_pg_conn() as conn:
        previous_autocommit = conn.autocommit
        if previous_autocommit:
            conn.autocommit = False
        try:
            with conn.cursor() as cur:
                # Must be the first statement of the transaction. Every
                # _backup_table query below then sees the same snapshot.
                cur.execute("SET TRANSACTION ISOLATION LEVEL REPEATABLE READ")
            for table in tables:
                data, row_count = _backup_table(conn, table)
                table_data[table] = data
                table_counts[table] = row_count
                total_rows += row_count
            conn.commit()
        except Exception:
            conn.rollback()
            raise
        finally:
            # Put the connection back in the mode it was handed over in.
            if previous_autocommit:
                conn.autocommit = True

    checkpoints_dir = _get_checkpoints_dir()
    checkpoint_id = f"checkpoint_{timestamp}"
    attempt = 0
    while True:
        checkpoint_file = os.path.join(checkpoints_dir, f"{checkpoint_id}.json")
        try:
            # Mode "x" fails instead of overwriting an existing checkpoint.
            out = open(checkpoint_file, "x", encoding="utf-8")
            break
        except FileExistsError:
            attempt += 1
            checkpoint_id = f"checkpoint_{timestamp}_{attempt}"

    checkpoint_data = {
        "checkpoint_id": checkpoint_id,
        "timestamp": timestamp,
        "description": description,
        "tables": table_data,
        "table_counts": table_counts,
        "total_rows": total_rows,
    }
    try:
        with out:
            json.dump(checkpoint_data, out, ensure_ascii=False, indent=2, default=json_serializer)
    except BaseException:
        # Do not leave a truncated, unparsable checkpoint file behind.
        os.remove(checkpoint_file)
        raise

    return {
        "checkpoint_id": checkpoint_id,
        "timestamp": timestamp,
        "description": description,
        "total_rows": total_rows,
        "table_counts": table_counts
    }
|
|||
|
|
```
|
|||
|
|
|
|||
|
|
### 修复5: 改进API路由 (lawrisk/api/v2.py)
|
|||
|
|
|
|||
|
|
修改 `admin_restore_checkpoint` 路由,添加警告:
|
|||
|
|
|
|||
|
|
```python
|
|||
|
|
@v2_bp.route('/admin/checkpoints/<checkpoint_id>/restore', methods=['POST'])
def admin_restore_checkpoint(checkpoint_id):
    """DANGEROUS: restore the database from a checkpoint.

    Tables stored in the checkpoint are emptied with TRUNCATE ... CASCADE and
    refilled from the checkpoint file. CASCADE also empties tables that
    reference them.

    Request body (JSON or form):
        create_auto_backup: "1"/"true"/"yes"/"on" (default "true") to snapshot
            the current state before restoring. Strongly recommended.

    Responses:
        200 with restore statistics on success.
        404 when restore_checkpoint raises ValueError (e.g. checkpoint not found).
        500 on any other failure. When an auto-backup was taken, the response
            names it as the recovery point.
    """
    # NOTE(review): no authentication/authorization check is visible on this
    # route. Confirm an admin guard protects it (e.g. a blueprint-level
    # before_request hook), since it can wipe the database.
    import json
    import logging

    logger = logging.getLogger(__name__)
    try:
        if request.is_json:
            payload = request.get_json(silent=True) or {}
        else:
            payload = request.form.to_dict(flat=True) if request.form else {}
        if not isinstance(payload, dict):
            # A JSON body such as [1] or "x" carries no options; use the
            # defaults instead of failing on payload.get().
            payload = {}

        create_auto_backup = str(payload.get("create_auto_backup", "true")).lower() in {"1", "true", "yes", "on"}

        restore_result = restore_checkpoint(checkpoint_id, create_auto_backup=create_auto_backup)
        summary = restore_result["summary"]

        if restore_result.get("status") == "success":
            return jsonify({
                "success": True,
                "message": restore_result["message"],
                "data": {
                    "checkpoint_id": summary["checkpoint_id"],
                    "tables_restored": summary["tables_restored"],
                    "total_rows": summary["total_rows_restored"],
                    "auto_backup": summary.get("auto_backup")
                }
            })
        return jsonify({
            "success": False,
            "message": restore_result["message"],
            "data": {
                "errors": summary.get("errors", []),
                "auto_backup_available": restore_result.get("auto_backup_available", False),
                "recovery_suggestion": restore_result.get("recovery_suggestion")
            }
        }), 500

    except json.JSONDecodeError as exc:
        # A corrupt checkpoint file on the server is a server error, even
        # though JSONDecodeError subclasses ValueError.
        logger.exception("admin_restore_checkpoint: checkpoint %s is not valid JSON", checkpoint_id)
        return jsonify({
            "success": False,
            "message": f"Restore failed: {str(exc)}",
            "data": {}
        }), 500
    except ValueError as exc:
        # restore_checkpoint signals an unknown checkpoint id with ValueError.
        logger.warning("admin_restore_checkpoint: %s", exc)
        return jsonify({
            "success": False,
            "message": f"Restore failed: {str(exc)}",
            "data": {}
        }), 404
    except Exception as exc:
        logger.exception("admin_restore_checkpoint error: %s", exc)
        return jsonify({
            "success": False,
            "message": f"Restore failed: {str(exc)}",
            "data": {}
        }), 500
|
|||
|
|
```
|
|||
|
|
|
|||
|
|
---
|
|||
|
|
|
|||
|
|
## 建议的测试用例
|
|||
|
|
|
|||
|
|
### 测试1: 正常恢复流程
|
|||
|
|
```python
|
|||
|
|
def test_restore_checkpoint_normal():
    """Round trip: data changed after a checkpoint is reverted by restoring it.

    Planned steps:
      1. Create data.
      2. Create a checkpoint.
      3. Modify the data.
      4. Restore the checkpoint.
      5. Verify the data matches the checkpoint.
    """
    import unittest

    # A bare `pass` reported this placeholder as passed. Skip explicitly so
    # the missing test is visible in the report (pytest treats SkipTest as a skip).
    raise unittest.SkipTest("TODO: test_restore_checkpoint_normal is not implemented yet")
|
|||
|
|
```
|
|||
|
|
|
|||
|
|
### 测试2: 自动备份功能
|
|||
|
|
```python
|
|||
|
|
def test_auto_backup_before_restore():
    """Restoring an older checkpoint first snapshots the current state.

    Planned steps:
      1. Create initial data.
      2. Create checkpoint A.
      3. Modify the data.
      4. Create checkpoint B.
      5. Restore checkpoint A (the current state should be auto-backed-up).
      6. Verify checkpoint B still exists.
    """
    import unittest

    # A bare `pass` reported this placeholder as passed. Skip explicitly so
    # the missing test is visible in the report (pytest treats SkipTest as a skip).
    raise unittest.SkipTest("TODO: test_auto_backup_before_restore is not implemented yet")
|
|||
|
|
```
|
|||
|
|
|
|||
|
|
### 测试3: 并发写入保护
|
|||
|
|
```python
|
|||
|
|
def test_restore_with_concurrent_writes():
    """A restore running alongside concurrent writers does not interleave with them.

    Planned steps:
      1. Create a checkpoint.
      2. Start a background thread that keeps writing.
      3. Restore the checkpoint.
      4. Verify the restore saw no data race.
    """
    import unittest

    # A bare `pass` reported this placeholder as passed. Skip explicitly so
    # the missing test is visible in the report (pytest treats SkipTest as a skip).
    raise unittest.SkipTest("TODO: test_restore_with_concurrent_writes is not implemented yet")
|
|||
|
|
```
|
|||
|
|
|
|||
|
|
---
|
|||
|
|
|
|||
|
|
## 总结
|
|||
|
|
|
|||
|
|
**立即要做的事情**:
|
|||
|
|
1. ✅ 应用修复补丁到 `licensing_repo.py`
|
|||
|
|
2. ✅ 更新API文档,添加安全警告
|
|||
|
|
3. ✅ 创建测试用例
|
|||
|
|
4. ✅ 在生产环境使用前充分测试
|
|||
|
|
|
|||
|
|
**不要做的事情**:
|
|||
|
|
- ❌ 不要在生产环境使用未修复的 restore_checkpoint
|
|||
|
|
- ❌ 不要在恢复过程中停止应用
|
|||
|
|
- ❌ 不要依赖未验证的checkpoint文件
|
|||
|
|
|
|||
|
|
**检查清单**:
|
|||
|
|
- [ ] 依赖关系排序已实现
|
|||
|
|
- [ ] 表锁已添加
|
|||
|
|
- [ ] 自动备份已实现
|
|||
|
|
- [ ] 事务保护已添加
|
|||
|
|
- [ ] API警告已添加
|
|||
|
|
- [ ] 测试已通过
|