import logging
import os
from datetime import datetime  # noqa: F401  (currently unused; kept for compatibility)
from typing import Iterable, Tuple  # noqa: F401

from flask import Flask, request, Response, jsonify


def _get_bool(name: str, default: bool = False) -> bool:
    """Read environment variable *name* as a boolean.

    Returns *default* when the variable is unset; otherwise True only for
    "1", "true", "yes" or "on" (case-insensitive, surrounding whitespace ignored).
    """
    v = os.getenv(name)
    if v is None:
        return default
    return str(v).strip().lower() in {"1", "true", "yes", "on"}


def _parse_allowed(origins_raw: str | None) -> list[str]:
    """Split a comma-separated origins string into a list of patterns.

    Empty entries are dropped. Returns ``["*"]`` when the input is None/empty
    or contains no non-blank entries.
    """
    if not origins_raw:
        return ["*"]
    parts = [p.strip() for p in origins_raw.split(",") if p.strip()]
    return parts or ["*"]


def _origin_matches(
    origin: str, allowed: Iterable[str], strict: bool
) -> tuple[bool, str | None]:
    """Check *origin* against the *allowed* patterns.

    Pattern kinds:
      - exact origin, e.g. "https://a.com" (honoured in both modes)
      - "*.example.com" or ".example.com" subdomain suffix (only when not strict)
      - "*" wildcard (only when not strict; strict mode ignores it)

    Specific matches (exact or suffix) always win over "*", independent of the
    iteration order of *allowed*. Callers answer a specific match with an
    echoed origin plus credentials, but "*" without credentials, so letting the
    wildcard win would break cookie/Authorization requests.

    Returns:
        ``(True, matched_pattern)`` on a match, otherwise ``(False, None)``.
    """
    wildcard_allowed = False
    for pat in allowed:
        if pat == "*":
            # Defer the wildcard so a more specific pattern can still match.
            if not strict:
                wildcard_allowed = True
            continue
        if origin == pat:
            return True, pat
        if strict:
            continue
        if pat.startswith("*."):
            if origin.endswith(pat[1:]):  # "*.example.com" -> ".example.com"
                return True, pat
        elif pat.startswith("."):
            if origin.endswith(pat):
                return True, pat
    if wildcard_allowed:
        return True, "*"
    return False, None


def init_smart_cors(app: Flask) -> None:
    """Attach Smart CORS handlers to a Flask app.

    Env vars:
    - ALLOWED_ORIGINS: comma-separated, e.g. https://a.com,https://b.com or *
      (treated as "*" when unset or empty)
    - CORS_STRICT: true/false (default false)
    - CORS_DEBUG: true/false
    - NGINX_CORS_MODE: true/false (when Nginx sets Allow-Origin; app only supplements others)
    - CORS_MAX_AGE: seconds for preflight caching (default 86400)
    - CORS_EXPOSE_HEADERS: override exposed headers list

    Registers:
    - a ``before_request`` hook answering every OPTIONS request with a 204
      preflight response;
    - an ``after_request`` hook adding CORS headers to responses for requests
      that carry an ``Origin`` header;
    - a ``GET /api/cors-diagnosis`` endpoint reporting the effective config.
    """
    allowed = _parse_allowed(os.getenv("ALLOWED_ORIGINS"))

    # Front-end origins actually used by the deployed UI.
    frontend_origins = [
        "http://chinaweal.com.cn:8090",
        "http://www.chinaweal.com.cn:8090",
        "https://chinaweal.com.cn",
        "https://www.chinaweal.com.cn",
        "http://172.22.80.130:8000",  # referer observed in the logs
    ]
    # Default local development origins.
    # NOTE(review): these and the front-end origins above are always allowed
    # (with credentials), even in strict mode — confirm this is intended for
    # production deployments.
    default_origins = [
        "http://localhost:3000",
        "http://localhost:8000",
        "http://localhost:8081",
        "http://127.0.0.1:3000",
        "http://127.0.0.1:8000",
        "http://127.0.0.1:8081",
    ]
    # Merge and de-duplicate. dict.fromkeys keeps first-seen order, so the list
    # (and the diagnosis output) is identical across processes, unlike set(),
    # whose str ordering depends on per-process hash randomisation.
    all_origins = list(dict.fromkeys(allowed + frontend_origins + default_origins))

    strict = _get_bool("CORS_STRICT", False)
    debug = _get_bool("CORS_DEBUG", False)
    nginx_mode = _get_bool("NGINX_CORS_MODE", False)
    max_age = os.getenv("CORS_MAX_AGE", "86400")
    expose_headers = os.getenv(
        "CORS_EXPOSE_HEADERS",
        "Content-Length, Content-Type, Authorization, X-Request-Id",
    )

    def _is_proxy_request() -> bool:
        """Return True when the request carries common reverse-proxy headers."""
        hdrs = request.headers
        return any(h in hdrs for h in ("X-Forwarded-For", "X-Forwarded-Proto", "X-Real-IP"))

    _log_levels = {"DEBUG": logging.DEBUG, "INFO": logging.INFO}

    def _log(level: str, msg: str, **fields) -> None:
        """Log a ``[CORS]`` line with ``key=value`` fields.

        DEBUG messages are dropped unless CORS_DEBUG is enabled; unknown level
        names are logged at WARNING.
        """
        if not debug and level == "DEBUG":
            return
        kv = " ".join(f"{k}={v}" for k, v in fields.items()) if fields else ""
        app.logger.log(
            _log_levels.get(level, logging.WARNING),
            f"[CORS] {msg} {kv}".rstrip(),
        )

    def _ensure_vary_origin(headers) -> None:
        """Add ``Origin`` to the Vary header unless already covered.

        Compares comma-separated tokens case-insensitively across every Vary
        header (a substring test would be fooled by e.g. ``X-Origin-Id``), and
        treats ``Vary: *`` as already covering Origin.
        """
        values = [v for v in headers.getlist("Vary") if v.strip()]
        tokens = {t.strip().lower() for v in values for t in v.split(",")}
        if "origin" in tokens or "*" in tokens:
            return
        headers["Vary"] = ", ".join([*values, "Origin"])

    @app.before_request
    def _handle_preflight():  # type: ignore
        """Short-circuit every OPTIONS request with a 204 preflight response."""
        if request.method != "OPTIONS":
            return None

        origin = request.headers.get("Origin", "")
        acrm = request.headers.get("Access-Control-Request-Method", "")
        acrh = request.headers.get("Access-Control-Request-Headers", "")
        allowed_ok, matched = (
            _origin_matches(origin, all_origins, strict) if origin else (False, None)
        )

        # Build a minimal preflight response
        resp = Response(status=204)

        # When Nginx handles Allow-Origin, do not override it
        if not nginx_mode and allowed_ok:
            if matched == "*":
                # If behind proxy (likely Nginx will add AO), skip to avoid duplicates
                if not _is_proxy_request():
                    resp.headers["Access-Control-Allow-Origin"] = "*"
                else:
                    _log("DEBUG", "proxy-skip-allow-origin(preflight)")
            else:
                # Echo the specific origin so credentialed requests are permitted.
                resp.headers["Access-Control-Allow-Origin"] = origin
                resp.headers["Access-Control-Allow-Credentials"] = "true"
                resp.headers.add("Vary", "Origin")

        # Always supplement the rest
        resp.headers["Access-Control-Allow-Methods"] = acrm or "GET, POST, OPTIONS"
        resp.headers["Access-Control-Allow-Headers"] = acrh or (
            "Content-Type, Authorization, X-Requested-With, X-Request-Id"
        )
        resp.headers["Access-Control-Max-Age"] = max_age
        resp.headers["Access-Control-Expose-Headers"] = expose_headers

        decision = "deny"
        if not origin:
            decision = "no-origin"
        elif allowed_ok:
            decision = "allow-*" if matched == "*" else "allow-origin"
        resp.headers["X-CORS-Decision"] = f"preflight; {decision}; nginx_mode={nginx_mode}"

        _log(
            "INFO",
            "preflight",
            path=request.path,
            origin=origin,
            request_method=acrm or "",
            request_headers=acrh or "",
            allowed=allowed_ok,
            matched=matched or "",
            nginx_mode=nginx_mode,
            decision=decision,
        )
        return resp

    @app.after_request
    def _add_cors_headers(response: Response):  # type: ignore
        """Add CORS headers to responses for requests that carry an Origin."""
        origin = request.headers.get("Origin", "")
        if not origin:
            return response

        allowed_ok, matched = _origin_matches(origin, all_origins, strict)

        # When Nginx sets Allow-Origin, we only supplement others to avoid duplicates
        if not nginx_mode and allowed_ok:
            if matched == "*":
                # Do not set credentials when wildcard is used; if proxied, skip to avoid dup
                if not _is_proxy_request():
                    response.headers.setdefault("Access-Control-Allow-Origin", "*")
                else:
                    _log("DEBUG", "proxy-skip-allow-origin(response)")
            else:
                response.headers.setdefault("Access-Control-Allow-Origin", origin)
                response.headers.setdefault("Access-Control-Allow-Credentials", "true")

        # Ensure caches vary by Origin, since the headers above depend on it
        _ensure_vary_origin(response.headers)

        # Expose common headers (safe to always include)
        response.headers.setdefault("Access-Control-Expose-Headers", expose_headers)

        # Add decision header and log
        decision = "deny"
        if allowed_ok:
            decision = "allow-*" if matched == "*" else "allow-origin"
        response.headers.setdefault(
            "X-CORS-Decision",
            f"response; {decision}; nginx_mode={nginx_mode}",
        )
        _log(
            "INFO",
            "response",
            path=request.path,
            method=request.method,
            origin=origin,
            allowed=allowed_ok,
            matched=matched or "",
            nginx_mode=nginx_mode,
            decision=decision,
        )
        return response

    # Diagnosis endpoint to aid verification and debugging
    @app.get("/api/cors-diagnosis")
    def cors_diagnosis():  # type: ignore
        """Report the effective CORS configuration and how this request's Origin matched."""
        origin = request.headers.get("Origin", "")
        ok, matched = (
            _origin_matches(origin, all_origins, strict) if origin else (False, None)
        )
        data = {
            "nginx_mode": nginx_mode,
            "detected_proxy": _is_proxy_request(),
            "allowed_origins": all_origins,
            "strict": strict,
            "request_origin": origin,
            "matched_rule": matched,
            "origin_allowed": ok,
            "notes": (
                "Nginx handles Allow-Origin; app supplements other headers"
                if nginx_mode
                else "App sets CORS headers end-to-end"
            ),
        }
        return jsonify(data)