fs-lawrisk/lawrisk/middleware/smart_cors_middleware.py

232 lines
8.4 KiB
Python
Raw Normal View History

2025-10-22 19:59:48 +08:00
import os
from typing import Iterable, Tuple
from flask import Flask, request, Response, jsonify
from datetime import datetime
def _get_bool(name: str, default: bool = False) -> bool:
    """Read environment variable *name* and interpret it as a boolean flag.

    Args:
        name: Name of the environment variable to look up.
        default: Value returned when the variable is not set at all.

    Returns:
        ``default`` if the variable is unset; otherwise ``True`` only when the
        value, after trimming whitespace and lower-casing, is one of
        ``1``, ``true``, ``yes`` or ``on``. Any other value (including the
        empty string) yields ``False``.
    """
    truthy_tokens = {"1", "true", "yes", "on"}
    raw_value = os.environ.get(name)
    if raw_value is not None:
        normalized = raw_value.strip().lower()
        return normalized in truthy_tokens
    return default
def _parse_allowed(origins_raw: str | None) -> list[str]:
if not origins_raw:
return ["*"]
parts = [p.strip() for p in origins_raw.split(",") if p.strip()]
return parts or ["*"]
def _origin_matches(origin: str, allowed: Iterable[str], strict: bool) -> tuple[bool, str | None]:
# strict=True: only exact match items that are not wildcards
# strict=False: allow simple suffix matches: .example.com or *.example.com
for pat in allowed:
if pat == "*":
if strict:
# strict mode ignores wildcard
continue
return True, "*"
if origin == pat:
return True, pat
if not strict:
if pat.startswith("*."):
suf = pat[1:] # .example.com
if origin.endswith(suf):
return True, pat
if pat.startswith("."):
if origin.endswith(pat):
return True, pat
return False, None
def init_smart_cors(app: Flask) -> None:
    """Attach Smart CORS handlers to a Flask app.

    Registers three things on *app*:
        * a ``before_request`` hook that answers every ``OPTIONS`` request
          with a 204 preflight response,
        * an ``after_request`` hook that adds CORS headers to responses for
          requests carrying an ``Origin`` header,
        * a ``GET /api/cors-diagnosis`` endpoint reporting the effective
          configuration and the decision for the calling origin.

    The effective allow-list is the union of ``ALLOWED_ORIGINS``, the
    hard-coded frontend origins and the local development origins below.

    Env vars:
        - ALLOWED_ORIGINS: comma-separated, e.g. https://a.com,https://b.com or *
        - CORS_STRICT: true/false (default false)
        - CORS_DEBUG: true/false
        - NGINX_CORS_MODE: true/false (when Nginx sets Allow-Origin; app only
          supplements others)
        - CORS_MAX_AGE: seconds for preflight caching (default 86400); values
          that are not a non-negative integer fall back to the default
        - CORS_EXPOSE_HEADERS: override exposed headers list

    Args:
        app: The Flask application to attach the handlers to.
    """
    # Function-scope import so this block does not depend on the module's
    # import section.
    import logging

    allowed = _parse_allowed(os.getenv("ALLOWED_ORIGINS"))
    # Origins actually used by the deployed frontend.
    frontend_origins = [
        "http://chinaweal.com.cn:8090",
        "http://www.chinaweal.com.cn:8090",
        "https://chinaweal.com.cn",
        "https://www.chinaweal.com.cn",
        "http://172.22.80.130:8000",  # referer observed in the logs
    ]
    # Default origins for local development.
    default_origins = [
        "http://localhost:3000",
        "http://localhost:8000",
        "http://localhost:8081",
        "http://127.0.0.1:3000",
        "http://127.0.0.1:8000",
        "http://127.0.0.1:8081",
    ]
    # Merge and de-duplicate. A plain list(set(...)) has a hash-seed-dependent
    # order, and the matcher's answer can depend on whether "*" is seen before
    # a specific origin. Sort deterministically and put "*" last so specific
    # origins are always tried first.
    all_origins = sorted(
        {*allowed, *frontend_origins, *default_origins},
        key=lambda item: (item == "*", item),
    )

    strict = _get_bool("CORS_STRICT", False)
    debug = _get_bool("CORS_DEBUG", False)
    nginx_mode = _get_bool("NGINX_CORS_MODE", False)

    default_max_age = "86400"
    max_age = os.getenv("CORS_MAX_AGE", default_max_age).strip()
    if not (max_age.isascii() and max_age.isdigit()):
        # A malformed Access-Control-Max-Age is useless to clients; fall back
        # to the documented default instead of emitting it.
        app.logger.warning(
            "[CORS] invalid CORS_MAX_AGE=%r; falling back to %s",
            max_age,
            default_max_age,
        )
        max_age = default_max_age

    expose_headers = os.getenv(
        "CORS_EXPOSE_HEADERS",
        "Content-Length, Content-Type, Authorization, X-Request-Id",
    )

    # Level names accepted by _log; anything else is logged as WARNING.
    log_levels = {"INFO": logging.INFO, "DEBUG": logging.DEBUG}

    def _is_proxy_request() -> bool:
        """Return True if the request carries any common reverse-proxy header."""
        hdrs = request.headers
        return any(h in hdrs for h in ("X-Forwarded-For", "X-Forwarded-Proto", "X-Real-IP"))

    def _log(level: str, msg: str, **fields) -> None:
        """Log ``[CORS] msg k=v ...``; DEBUG lines only when CORS_DEBUG is on."""
        if level == "DEBUG" and not debug:
            return
        kv = " ".join(f"{k}={v}" for k, v in fields.items())
        app.logger.log(
            log_levels.get(level, logging.WARNING),
            f"[CORS] {msg} {kv}".rstrip(),
        )

    def _match(origin: str) -> tuple[bool, str | None]:
        """Match *origin* against the allow-list; an empty origin never matches."""
        if not origin:
            return False, None
        return _origin_matches(origin, all_origins, strict)

    def _decision(origin: str, allowed_ok: bool, matched: str | None) -> str:
        """Summarise the outcome for the X-CORS-Decision header and the logs."""
        if not origin:
            return "no-origin"
        if not allowed_ok:
            return "deny"
        return "allow-*" if matched == "*" else "allow-origin"

    def _ensure_vary_origin(headers) -> None:
        """Make sure the ``Vary`` header lists ``Origin`` (token-wise check)."""
        current = headers.get("Vary")
        if not current:
            headers["Vary"] = "Origin"
            return
        tokens = {token.strip().lower() for token in current.split(",")}
        if "origin" not in tokens:
            headers["Vary"] = f"{current}, Origin"

    def _apply_allow_origin(headers, origin: str, allowed_ok: bool,
                            matched: str | None, stage: str) -> None:
        """Set Allow-Origin (and credentials/Vary) unless Nginx owns that header.

        Headers that are already present are left untouched (``setdefault``).
        *stage* is ``"preflight"`` or ``"response"`` and only affects the
        debug log message.
        """
        if nginx_mode or not allowed_ok:
            # When Nginx handles Allow-Origin, do not override it.
            return
        if matched == "*":
            # Wildcard: never send credentials. If behind a proxy (likely
            # Nginx will add Allow-Origin itself), skip to avoid duplicates.
            if _is_proxy_request():
                _log("DEBUG", f"proxy-skip-allow-origin({stage})")
            else:
                headers.setdefault("Access-Control-Allow-Origin", "*")
            return
        headers.setdefault("Access-Control-Allow-Origin", origin)
        headers.setdefault("Access-Control-Allow-Credentials", "true")
        # The echoed origin is dynamic, so caches must vary by Origin.
        _ensure_vary_origin(headers)

    @app.before_request
    def _handle_preflight():  # type: ignore
        """Answer every OPTIONS request with a minimal 204 preflight response."""
        if request.method != "OPTIONS":
            return None
        origin = request.headers.get("Origin", "")
        acrm = request.headers.get("Access-Control-Request-Method", "")
        acrh = request.headers.get("Access-Control-Request-Headers", "")
        allowed_ok, matched = _match(origin)

        # Build a minimal preflight response.
        resp = Response(status=204)
        _apply_allow_origin(resp.headers, origin, allowed_ok, matched, "preflight")

        # Always supplement the rest: echo what the client asked for, or fall
        # back to fixed defaults.
        resp.headers["Access-Control-Allow-Methods"] = acrm or "GET, POST, OPTIONS"
        resp.headers["Access-Control-Allow-Headers"] = (
            acrh or "Content-Type, Authorization, X-Requested-With, X-Request-Id"
        )
        resp.headers["Access-Control-Max-Age"] = max_age
        resp.headers["Access-Control-Expose-Headers"] = expose_headers

        decision = _decision(origin, allowed_ok, matched)
        resp.headers["X-CORS-Decision"] = f"preflight; {decision}; nginx_mode={nginx_mode}"
        _log(
            "INFO",
            "preflight",
            path=request.path,
            origin=origin,
            request_method=acrm or "",
            request_headers=acrh or "",
            allowed=allowed_ok,
            matched=matched or "",
            nginx_mode=nginx_mode,
            decision=decision,
        )
        return resp

    @app.after_request
    def _add_cors_headers(response: Response):  # type: ignore
        """Add CORS headers to the response of any request with an Origin."""
        origin = request.headers.get("Origin", "")
        if not origin:
            return response
        allowed_ok, matched = _match(origin)

        # When Nginx sets Allow-Origin, we only supplement others to avoid
        # duplicates.
        _apply_allow_origin(response.headers, origin, allowed_ok, matched, "response")

        # Expose common headers (safe to always include).
        response.headers.setdefault("Access-Control-Expose-Headers", expose_headers)

        # Add decision header and log.
        decision = _decision(origin, allowed_ok, matched)
        response.headers.setdefault(
            "X-CORS-Decision",
            f"response; {decision}; nginx_mode={nginx_mode}",
        )
        _log(
            "INFO",
            "response",
            path=request.path,
            method=request.method,
            origin=origin,
            allowed=allowed_ok,
            matched=matched or "",
            nginx_mode=nginx_mode,
            decision=decision,
        )
        return response

    # Diagnosis endpoint to aid verification and debugging.
    # NOTE(review): this is registered unconditionally and returns the full
    # allow-list to any caller -- confirm that is acceptable in production.
    @app.get("/api/cors-diagnosis")
    def cors_diagnosis():  # type: ignore
        """Report the effective CORS configuration and the caller's decision."""
        origin = request.headers.get("Origin", "")
        ok, matched = _match(origin)
        data = {
            "nginx_mode": nginx_mode,
            "detected_proxy": _is_proxy_request(),
            "allowed_origins": all_origins,
            "strict": strict,
            "request_origin": origin,
            "matched_rule": matched,
            "origin_allowed": ok,
            "notes": (
                "Nginx handles Allow-Origin; app supplements other headers"
                if nginx_mode
                else "App sets CORS headers end-to-end"
            ),
        }
        return jsonify(data)