diff --git a/src/main/java/com/chinaweal/youfool/devops/ai/controller/AIAnswerMCPController.java b/src/main/java/com/chinaweal/youfool/devops/ai/controller/AIAnswerMCPController.java index 80a3c8d..5dcc78a 100644 --- a/src/main/java/com/chinaweal/youfool/devops/ai/controller/AIAnswerMCPController.java +++ b/src/main/java/com/chinaweal/youfool/devops/ai/controller/AIAnswerMCPController.java @@ -65,11 +65,15 @@ public class AIAnswerMCPController { public RestResult generateAnswerWithMCP( @Valid @RequestBody AIAnswerRequest request) { try { - log.info("收到MCP版本AI回答请求: 工单={}, 会话={}", request.getRepairId(), request.getSessionId()); + // 敏感信息脱敏的日志记录 + log.info("收到MCP版本AI回答请求: 工单ID存在={}, 会话ID存在={}", + request.getRepairId() != null, request.getSessionId() != null); - // 验证请求参数 - if (request.getRepairId() == null || request.getRepairId().trim().isEmpty()) { - return RestResult.error("工单ID不能为空"); + // 严格的参数验证 + String validationError = validateAIAnswerRequest(request); + if (validationError != null) { + log.warn("MCP请求参数验证失败: {}", validationError); + return RestResult.error(validationError); } AIAnswerResponse response = aiAnswerServiceMCP.generateAnswerWithMCP(request); @@ -282,4 +286,80 @@ public class AIAnswerMCPController { public String getAnalysis() { return analysis; } public void setAnalysis(String analysis) { this.analysis = analysis; } } + + /** + * 验证AI回答请求参数 + */ + private String validateAIAnswerRequest(AIAnswerRequest request) { + if (request == null) { + return "请求参数不能为空"; + } + + // 验证工单ID + if (request.getRepairId() == null || request.getRepairId().trim().isEmpty()) { + return "工单ID不能为空"; + } + + String repairId = request.getRepairId().trim(); + if (repairId.length() > 50) { + return "工单ID长度不能超过50个字符"; + } + + // 检查工单ID是否包含危险字符 + if (containsDangerousCharacters(repairId)) { + return "工单ID包含非法字符"; + } + + // 验证会话ID(如果提供) + if (request.getSessionId() != null) { + String sessionId = request.getSessionId().trim(); + if (sessionId.length() > 100) { + return "会话ID长度不能超过100个字符"; + } + if (containsDangerousCharacters(sessionId)) { + return "会话ID包含非法字符"; + } + } + + // 验证温度参数 + if (request.getTemperature() != null) { + if (request.getTemperature() < 0.0 || request.getTemperature() > 2.0) { + return "温度参数必须在0.0-2.0之间"; + } + } + + // 验证最大Token数 + if (request.getMaxTokens() != null) { + if (request.getMaxTokens() <= 0 || request.getMaxTokens() > 8000) { + return "最大Token数必须在1-8000之间"; + } + } + + return null; // 验证通过 + } + + /** + * 检查字符串是否包含危险字符 + */ + private boolean containsDangerousCharacters(String input) { + if (input == null) { + return false; + } + + String lowerCase = input.toLowerCase(); + String[] dangerousPatterns = { + "", "javascript:", "onclick=", "onerror=", + "onload=", "alert(", "eval(", "document.cookie", + "'", "\"", ";", "--", "/*", "*/", + "select ", "insert ", "update ", "delete ", "drop ", "union " + }; + + for (String pattern : dangerousPatterns) { + if (lowerCase.contains(pattern)) { + return true; + } + } + + return false; + } } \ No newline at end of file diff --git a/src/main/java/com/chinaweal/youfool/devops/ai/mcp/MCPServer.java b/src/main/java/com/chinaweal/youfool/devops/ai/mcp/MCPServer.java index 456c4fb..ec81c3b 100644 --- a/src/main/java/com/chinaweal/youfool/devops/ai/mcp/MCPServer.java +++ b/src/main/java/com/chinaweal/youfool/devops/ai/mcp/MCPServer.java @@ -8,6 +8,7 @@ import com.chinaweal.youfool.devops.repair.entity.Repair; import com.chinaweal.youfool.devops.repair.entity.RepairHandle; import com.chinaweal.youfool.devops.repair.mapper.RepairHandleMapper; import com.chinaweal.youfool.devops.repair.service.IRepairService; +import com.chinaweal.youfool.devops.ai.util.SecurityValidationUtils; import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import com.fasterxml.jackson.databind.ObjectMapper; import lombok.RequiredArgsConstructor; @@ -15,6 +16,7 @@ import lombok.extern.slf4j.Slf4j; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; import java.util.*; import java.util.stream.Collectors; @@ -108,7 +110,13 @@ public class MCPServer { */ public MCPResponse executeTool(String toolName, Map arguments) { try { - log.info("执行MCP工具调用: {}, 参数: {}", toolName, arguments); + // 敏感信息脱敏的日志记录 + log.info("执行MCP工具调用: 工具={}, 参数数量={}", toolName, arguments.size()); + + // 工具名称验证 + if (toolName == null || toolName.trim().isEmpty()) { + return MCPResponse.error("工具名称不能为空"); + } switch (toolName) { case "repair_query": @@ -120,12 +128,14 @@ public class MCPServer { case "knowledge_query": return handleKnowledgeQuery(arguments); default: - return MCPResponse.error("未知的工具: " + toolName); + log.warn("尝试调用未知的MCP工具: {}", toolName); + return MCPResponse.error("不支持的工具类型"); } } catch (Exception e) { - log.error("MCP工具调用失败: {}", e.getMessage(), e); - return MCPResponse.error("工具调用失败: " + e.getMessage()); + log.error("MCP工具调用异常: 工具={}, 错误类型={}", toolName, e.getClass().getSimpleName()); + // 不暴露详细的异常信息给客户端 + return MCPResponse.error("工具调用失败,请稍后重试"); } } @@ -134,14 +144,17 @@ public class MCPServer { */ private MCPResponse handleRepairQuery(Map arguments) { String repairId = (String) arguments.get("repairId"); - if (repairId == null || repairId.trim().isEmpty()) { - return MCPResponse.error("工单ID不能为空"); + + // 严格的输入验证 + if (!SecurityValidationUtils.isValidRepairId(repairId)) { + return MCPResponse.error("工单ID格式无效"); } try { Repair repair = repairService.getById(repairId); if (repair == null) { - return MCPResponse.error("工单不存在: " + repairId); + log.warn("查询不存在的工单: {}", repairId); + return MCPResponse.error("工单不存在"); } Map result = new HashMap<>(); @@ -156,12 +169,12 @@ public class MCPServer { result.put("createTime", repair.getCreateTime()); result.put("launchTime", repair.getLaunchTime()); - log.info("MCP工具 repair_query 成功返回工单: {}", repairId); + log.info("MCP工具 repair_query 执行成功"); return MCPResponse.success(result); } catch (Exception e) { - log.error("查询工单失败: {}", repairId, e); - return MCPResponse.error("查询工单失败: " + e.getMessage()); + log.error("查询工单异常: 工单ID={}, 错误类型={}", repairId, e.getClass().getSimpleName()); + return MCPResponse.error("查询工单失败,请稍后重试"); } } @@ -170,8 +183,10 @@ public class MCPServer { */ private MCPResponse handleFeedbackQuery(Map arguments) { String repairId = (String) arguments.get("repairId"); - if (repairId == null || repairId.trim().isEmpty()) { - return MCPResponse.error("工单ID不能为空"); + + // 严格的输入验证 + if (!SecurityValidationUtils.isValidRepairId(repairId)) { + return MCPResponse.error("工单ID格式无效"); } try { @@ -204,25 +219,35 @@ public class MCPServer { response.put("latestFeedback", results.get(0)); } - log.info("MCP工具 repair_feedback_query 成功返回 {} 条feedback记录", results.size()); + log.info("MCP工具 repair_feedback_query 执行成功, 返回记录数: {}", results.size()); return MCPResponse.success(response); } catch (Exception e) { - log.error("查询feedback失败: {}", repairId, e); - return MCPResponse.error("查询feedback失败: " + e.getMessage()); + log.error("查询feedback异常: 工单ID={}, 错误类型={}", repairId, e.getClass().getSimpleName()); + return MCPResponse.error("查询feedback失败,请稍后重试"); } } /** * 处理相似度检索 */ + @Transactional(readOnly = true) private MCPResponse handleSimilaritySearch(Map arguments) { String queryText = (String) arguments.get("queryText"); Integer topK = (Integer) arguments.getOrDefault("topK", 5); Double threshold = ((Number) arguments.getOrDefault("threshold", 0.7)).doubleValue(); - if (queryText == null || queryText.trim().isEmpty()) { - return MCPResponse.error("查询文本不能为空"); + // 严格的输入验证 + if (!SecurityValidationUtils.isValidQueryText(queryText)) { + return MCPResponse.error("查询文本格式无效或过长"); + } + + if (!SecurityValidationUtils.isValidNumberRange(topK, 1, 50, "topK")) { + return MCPResponse.error("topK参数必须在1-50之间"); + } + + if (!SecurityValidationUtils.isValidNumberRange(threshold, 0.0, 1.0, "threshold")) { + return MCPResponse.error("阈值参数必须在0-1之间"); } try { @@ -363,4 +388,5 @@ public class MCPServer { return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB)); } + } \ No newline at end of file diff --git a/src/main/java/com/chinaweal/youfool/devops/ai/service/QwenChatService.java b/src/main/java/com/chinaweal/youfool/devops/ai/service/QwenChatService.java index 11fd3ea..96f9c07 100644 --- a/src/main/java/com/chinaweal/youfool/devops/ai/service/QwenChatService.java +++ b/src/main/java/com/chinaweal/youfool/devops/ai/service/QwenChatService.java @@ -659,7 +659,8 @@ public class QwenChatService { */ private ChatRequest processMCPToolCalls(ChatRequest request, MCPServer mcpServer, List usedMcpTools) { try { - log.info("开始MCP工具调用处理,会话: {}", request.getSessionId()); + // 敏感信息脱敏的日志记录 + log.info("开始MCP工具调用处理,会话ID存在: {}", request.getSessionId() != null); // 基于用户请求分析需要调用的工具 String userMessage = extractUserMessage(request); @@ -669,6 +670,12 @@ public class QwenChatService { log.warn("无法从请求中提取工单ID,跳过MCP工具调用"); return request; } + + // 验证提取的工单ID + if (!isValidRepairId(repairId)) { + log.warn("提取的工单ID格式无效,跳过MCP工具调用"); + return request; + } StringBuilder mcpResults = new StringBuilder(); mcpResults.append("=== MCP工具调用结果 ===\n\n"); @@ -973,4 +980,36 @@ public class QwenChatService { return ""; } } + + /** + * 验证工单ID格式(复用MCPServer的验证逻辑) + */ + private boolean isValidRepairId(String repairId) { + if (repairId == null || repairId.trim().isEmpty()) { + return false; + } + + repairId = repairId.trim(); + + // 检查长度限制 + if (repairId.length() < 1 || repairId.length() > 50) { + return false; + } + + // 检查是否包含危险字符 + String[] dangerousPatterns = { + "'", "\"", ";", "--", "/*", "*/", "xp_", "sp_", + "exec", "execute", "select", "insert", "update", "delete", + "union", "script", "<", ">" + }; + + String lowerCaseId = repairId.toLowerCase(); + for (String pattern : dangerousPatterns) { + if (lowerCaseId.contains(pattern)) { + return false; + } + } + + return true; + } } \ No newline at end of file diff --git a/src/main/java/com/chinaweal/youfool/devops/ai/util/SecurityValidationUtils.java b/src/main/java/com/chinaweal/youfool/devops/ai/util/SecurityValidationUtils.java new file mode 100644 index 0000000..3d8a867 --- /dev/null +++ b/src/main/java/com/chinaweal/youfool/devops/ai/util/SecurityValidationUtils.java @@ -0,0 +1,266 @@ +package com.chinaweal.youfool.devops.ai.util; + +import lombok.extern.slf4j.Slf4j; + +import java.util.regex.Pattern; + +/** + * 安全验证工具类 + * + * 提供统一的输入验证和安全检查功能 + * + * @author AI开发团队 + * @since 1.0.0 + */ +@Slf4j +public class SecurityValidationUtils { + + // 危险字符模式(SQL注入防护) + private static final String[] SQL_INJECTION_PATTERNS = { + "'", "\"", ";", "--", "/*", "*/", "xp_", "sp_", + "exec", "execute", "select", "insert", "update", "delete", + "drop", "union", "alter", "create", "truncate" + }; + + // XSS攻击模式 + private static final String[] XSS_PATTERNS = { + "", "javascript:", "onclick=", "onerror=", + "onload=", "onmouseover=", "onfocus=", "onblur=", "onchange=", + "alert(", "eval(", "document.cookie", "window.location", "document.write" + }; + + // 工单ID格式正则(允许字母、数字、下划线、短横线) + private static final Pattern REPAIR_ID_PATTERN = Pattern.compile("^[a-zA-Z0-9_-]+$"); + + // 会话ID格式正则(允许字母、数字、短横线、下划线) + private static final Pattern SESSION_ID_PATTERN = Pattern.compile("^[a-zA-Z0-9_-]+$"); + + /** + * 验证工单ID格式 + * + * @param repairId 工单ID + * @return 是否有效 + */ + public static boolean isValidRepairId(String repairId) { + if (repairId == null || repairId.trim().isEmpty()) { + return false; + } + + String trimmedId = repairId.trim(); + + // 检查长度限制(1-50个字符) + if (trimmedId.length() < 1 || trimmedId.length() > 50) { + log.warn("工单ID长度无效: {}", trimmedId.length()); + return false; + } + + // 检查格式(只允许字母数字下划线短横线) + if (!REPAIR_ID_PATTERN.matcher(trimmedId).matches()) { + log.warn("工单ID格式无效: 包含非法字符"); + return false; + } + + // 检查是否包含SQL注入模式 + if (containsSQLInjectionPattern(trimmedId)) { + log.warn("工单ID包含疑似SQL注入模式"); + return false; + } + + return true; + } + + /** + * 验证会话ID格式 + * + * @param sessionId 会话ID + * @return 是否有效 + */ + public static boolean isValidSessionId(String sessionId) { + if (sessionId == null || sessionId.trim().isEmpty()) { + return true; // 会话ID是可选的 + } + + String trimmedId = sessionId.trim(); + + // 检查长度限制(1-100个字符) + if (trimmedId.length() > 100) { + log.warn("会话ID长度超限: {}", trimmedId.length()); + return false; + } + + // 检查格式 + if (!SESSION_ID_PATTERN.matcher(trimmedId).matches()) { + log.warn("会话ID格式无效: 包含非法字符"); + return false; + } + + return true; + } + + /** + * 验证查询文本格式 + * + * @param queryText 查询文本 + * @return 是否有效 + */ + public static boolean isValidQueryText(String queryText) { + if (queryText == null || queryText.trim().isEmpty()) { + return false; + } + + String trimmedText = queryText.trim(); + + // 检查长度限制(最大2000字符) + if (trimmedText.length() > 2000) { + log.warn("查询文本长度超限: {}", trimmedText.length()); + return false; + } + + // 检查是否包含XSS攻击模式 + if (containsXSSPattern(trimmedText)) { + log.warn("查询文本包含疑似XSS攻击模式"); + return false; + } + + // 检查是否包含SQL注入模式 + if (containsSQLInjectionPattern(trimmedText)) { + log.warn("查询文本包含疑似SQL注入模式"); + return false; + } + + return true; + } + + /** + * 验证数值参数范围 + * + * @param value 数值 + * @param min 最小值 + * @param max 最大值 + * @param paramName 参数名称(用于日志) + * @return 是否有效 + */ + public static boolean isValidNumberRange(Number value, Number min, Number max, String paramName) { + if (value == null) { + return true; // 可选参数 + } + + double doubleValue = value.doubleValue(); + double minValue = min.doubleValue(); + double maxValue = max.doubleValue(); + + if (doubleValue < minValue || doubleValue > maxValue) { + log.warn("参数{}值超出范围: {}, 有效范围: [{}, {}]", paramName, doubleValue, minValue, maxValue); + return false; + } + + return true; + } + + /** + * 清理文本内容(移除潜在的危险字符) + * + * @param input 输入文本 + * @return 清理后的文本 + */ + public static String sanitizeText(String input) { + if (input == null) { + return null; + } + + String cleaned = input.trim(); + + // 移除HTML标签 + cleaned = cleaned.replaceAll("<[^>]*>", ""); + + // 移除JavaScript事件处理器 + cleaned = cleaned.replaceAll("(?i)on\\w+\\s*=", ""); + + // 移除javascript协议 + cleaned = cleaned.replaceAll("(?i)javascript:", ""); + + // 限制长度 + if (cleaned.length() > 2000) { + cleaned = cleaned.substring(0, 2000) + "..."; + } + + return cleaned; + } + + /** + * 检查是否包含SQL注入模式 + */ + private static boolean containsSQLInjectionPattern(String input) { + String lowerCase = input.toLowerCase(); + + for (String pattern : SQL_INJECTION_PATTERNS) { + if (lowerCase.contains(pattern)) { + return true; + } + } + + return false; + } + + /** + * 检查是否包含XSS攻击模式 + */ + private static boolean containsXSSPattern(String input) { + String lowerCase = input.toLowerCase(); + + for (String pattern : XSS_PATTERNS) { + if (lowerCase.contains(pattern)) { + return true; + } + } + + return false; + } + + /** + * 生成安全的错误消息(不暴露系统内部信息) + * + * @param internalError 内部错误信息 + * @param userFriendlyMessage 用户友好的错误信息 + * @return 安全的错误消息 + */ + public static String createSafeErrorMessage(String internalError, String userFriendlyMessage) { + // 记录详细的内部错误用于调试 + log.error("内部错误: {}", internalError); + + // 返回用户友好的错误信息 + return userFriendlyMessage; + } + + /** + * 验证文件名安全性 + * + * @param filename 文件名 + * @return 是否安全 + */ + public static boolean isValidFilename(String filename) { + if (filename == null || filename.trim().isEmpty()) { + return false; + } + + String trimmedName = filename.trim(); + + // 检查长度 + if (trimmedName.length() > 255) { + return false; + } + + // 检查是否包含路径遍历字符 + String[] dangerousPatterns = { + "..", "/", "\\", ":", "*", "?", "\"", "<", ">", "|" + }; + + for (String pattern : dangerousPatterns) { + if (trimmedName.contains(pattern)) { + return false; + } + } + + return true; + } +} \ No newline at end of file