232 lines
8.4 KiB
Python
232 lines
8.4 KiB
Python
|
|
import os
|
||
|
|
from typing import Iterable, Tuple
|
||
|
|
|
||
|
|
from flask import Flask, request, Response, jsonify
|
||
|
|
from datetime import datetime
|
||
|
|
|
||
|
|
|
||
|
|
def _get_bool(name: str, default: bool = False) -> bool:
    """Interpret the environment variable *name* as a boolean flag.

    Args:
        name: Name of the environment variable to read.
        default: Value returned when the variable is not set at all.

    Returns:
        ``default`` when the variable is unset; otherwise ``True`` only if
        the trimmed, lower-cased value is one of ``1``, ``true``, ``yes``
        or ``on`` (any other text, including the empty string, is ``False``).
    """
    raw = os.getenv(name)
    if raw is not None:
        normalized = str(raw).strip().lower()
        return normalized in {"1", "true", "yes", "on"}
    return default
|
||
|
|
|
||
|
|
|
||
|
|
def _parse_allowed(origins_raw: str | None) -> list[str]:
|
||
|
|
if not origins_raw:
|
||
|
|
return ["*"]
|
||
|
|
parts = [p.strip() for p in origins_raw.split(",") if p.strip()]
|
||
|
|
return parts or ["*"]
|
||
|
|
|
||
|
|
|
||
|
|
def _origin_matches(origin: str, allowed: Iterable[str], strict: bool) -> tuple[bool, str | None]:
|
||
|
|
# strict=True: only exact match items that are not wildcards
|
||
|
|
# strict=False: allow simple suffix matches: .example.com or *.example.com
|
||
|
|
for pat in allowed:
|
||
|
|
if pat == "*":
|
||
|
|
if strict:
|
||
|
|
# strict mode ignores wildcard
|
||
|
|
continue
|
||
|
|
return True, "*"
|
||
|
|
if origin == pat:
|
||
|
|
return True, pat
|
||
|
|
if not strict:
|
||
|
|
if pat.startswith("*."):
|
||
|
|
suf = pat[1:] # .example.com
|
||
|
|
if origin.endswith(suf):
|
||
|
|
return True, pat
|
||
|
|
if pat.startswith("."):
|
||
|
|
if origin.endswith(pat):
|
||
|
|
return True, pat
|
||
|
|
return False, None
|
||
|
|
|
||
|
|
|
||
|
|
def init_smart_cors(app: Flask) -> None:
    """Attach Smart CORS handlers to a Flask app.

    Registers three things on *app*:

    * a ``before_request`` hook that answers every ``OPTIONS`` request with
      a 204 preflight response,
    * an ``after_request`` hook that adds CORS headers to responses for
      requests carrying an ``Origin`` header,
    * a ``GET /api/cors-diagnosis`` endpoint that reports the effective
      configuration and how the caller's ``Origin`` was evaluated.

    The allow-list is the union of ``ALLOWED_ORIGINS``, the hard-coded
    front-end origins and the local development origins. Explicit rules are
    always evaluated before the ``*`` wildcard, so a listed origin is echoed
    back (with credentials) regardless of whether ``*`` is also present.

    Env vars:
    - ALLOWED_ORIGINS: comma-separated, e.g. https://a.com,https://b.com or *
    - CORS_STRICT: true/false (default false)
    - CORS_DEBUG: true/false
    - NGINX_CORS_MODE: true/false (when Nginx sets Allow-Origin; app only supplements others)
    - CORS_MAX_AGE: seconds for preflight caching (default 86400)
    - CORS_EXPOSE_HEADERS: override exposed headers list

    Args:
        app: The Flask application to attach the handlers to.
    """

    allowed = _parse_allowed(os.getenv("ALLOWED_ORIGINS"))

    # Origins actually used by the deployed front end.
    frontend_origins = [
        "http://chinaweal.com.cn:8090",
        "http://www.chinaweal.com.cn:8090",
        "https://chinaweal.com.cn",
        "https://www.chinaweal.com.cn",
        "http://172.22.80.130:8000",  # referer observed in the logs
    ]

    # Default origins for local development.
    default_origins = [
        "http://localhost:3000",
        "http://localhost:8000",
        "http://localhost:8081",
        "http://127.0.0.1:3000",
        "http://127.0.0.1:8000",
        "http://127.0.0.1:8081",
    ]

    # Merge and de-duplicate while keeping a deterministic order.
    # _origin_matches() returns the FIRST matching rule, so the order decides
    # whether a listed origin is reported as an exact match (origin echoed,
    # credentials allowed) or as a "*" match (no credentials). A plain set()
    # has a hash-randomized order for strings, which made that outcome vary
    # between process starts. Explicit rules go first; the wildcard goes last.
    merged = list(dict.fromkeys(allowed + frontend_origins + default_origins))
    all_origins = [o for o in merged if o != "*"] + [o for o in merged if o == "*"]

    strict = _get_bool("CORS_STRICT", False)
    debug = _get_bool("CORS_DEBUG", False)
    nginx_mode = _get_bool("NGINX_CORS_MODE", False)
    max_age = os.getenv("CORS_MAX_AGE", "86400")
    expose_headers = os.getenv(
        "CORS_EXPOSE_HEADERS",
        "Content-Length, Content-Type, Authorization, X-Request-Id",
    )

    def _is_proxy_request() -> bool:
        """Return True when the request carries any common reverse-proxy header."""
        hdrs = request.headers
        return any(h in hdrs for h in ("X-Forwarded-For", "X-Forwarded-Proto", "X-Real-IP"))

    def _log(level: str, msg: str, **fields):
        """Emit a ``[CORS]``-prefixed log line with ``key=value`` fields.

        DEBUG lines are dropped unless CORS_DEBUG is enabled. Any level other
        than "INFO" or "DEBUG" is logged as a warning.
        """
        if not debug and level == "DEBUG":
            return
        logger = app.logger
        kv = " ".join(f"{k}={v}" for k, v in fields.items()) if fields else ""
        logger.log(
            # Numeric stdlib logging levels: 20=INFO, 10=DEBUG, 30=WARNING.
            20 if level == "INFO" else 10 if level == "DEBUG" else 30,
            f"[CORS] {msg} {kv}".rstrip(),
        )

    @app.before_request
    def _handle_preflight():  # type: ignore
        """Short-circuit every OPTIONS request with a 204 preflight response."""
        if request.method != "OPTIONS":
            return None
        origin = request.headers.get("Origin", "")
        acrm = request.headers.get("Access-Control-Request-Method", "")
        acrh = request.headers.get("Access-Control-Request-Headers", "")

        allowed_ok, matched = _origin_matches(origin, all_origins, strict) if origin else (False, None)

        # Build a minimal preflight response
        resp = Response(status=204)

        # When Nginx handles Allow-Origin, do not override it
        if not nginx_mode:
            if allowed_ok:
                if matched == "*":
                    # If behind proxy (likely Nginx will add AO), skip to avoid duplicates
                    if not _is_proxy_request():
                        resp.headers["Access-Control-Allow-Origin"] = "*"
                    else:
                        _log("DEBUG", "proxy-skip-allow-origin(preflight)")
                else:
                    # Specific rule matched: echo the origin, which (unlike
                    # "*") may be combined with credentials.
                    resp.headers["Access-Control-Allow-Origin"] = origin
                    resp.headers["Access-Control-Allow-Credentials"] = "true"
                    resp.headers.add("Vary", "Origin")

        # Always supplement the rest; the requested method/headers are
        # reflected back when the browser supplied them.
        resp.headers["Access-Control-Allow-Methods"] = acrm or "GET, POST, OPTIONS"
        if acrh:
            resp.headers["Access-Control-Allow-Headers"] = acrh
        else:
            resp.headers["Access-Control-Allow-Headers"] = (
                "Content-Type, Authorization, X-Requested-With, X-Request-Id"
            )
        resp.headers["Access-Control-Max-Age"] = max_age
        resp.headers["Access-Control-Expose-Headers"] = expose_headers

        decision = "deny"
        if not origin:
            decision = "no-origin"
        elif allowed_ok:
            decision = "allow-*" if matched == "*" else "allow-origin"
        resp.headers["X-CORS-Decision"] = f"preflight; {decision}; nginx_mode={nginx_mode}"

        _log(
            "INFO",
            "preflight",
            path=request.path,
            origin=origin,
            request_method=acrm or "",
            request_headers=acrh or "",
            allowed=allowed_ok,
            matched=matched or "",
            nginx_mode=nginx_mode,
            decision=decision,
        )
        return resp

    @app.after_request
    def _add_cors_headers(response: Response):  # type: ignore
        """Add CORS headers to *response* when the request has an Origin header.

        Headers already present on the response are never overwritten
        (``setdefault``), except ``Vary``, which is extended with ``Origin``.
        """
        origin = request.headers.get("Origin", "")
        if not origin:
            return response

        allowed_ok, matched = _origin_matches(origin, all_origins, strict)

        # When Nginx sets Allow-Origin, we only supplement others to avoid duplicates
        if not nginx_mode:
            if allowed_ok:
                if matched == "*":
                    # Do not set credentials when wildcard is used; if proxied, skip to avoid dup
                    if not _is_proxy_request():
                        response.headers.setdefault("Access-Control-Allow-Origin", "*")
                    else:
                        _log("DEBUG", "proxy-skip-allow-origin(response)")
                else:
                    response.headers.setdefault("Access-Control-Allow-Origin", origin)
                    response.headers.setdefault("Access-Control-Allow-Credentials", "true")
                    # Ensure caches vary by Origin when dynamic. Compare whole
                    # header tokens case-insensitively: a substring test would
                    # be fooled by values such as "X-Origin-Id".
                    vary_val = response.headers.get("Vary")
                    if not vary_val:
                        response.headers["Vary"] = "Origin"
                    else:
                        vary_tokens = {tok.strip().lower() for tok in vary_val.split(",")}
                        if "origin" not in vary_tokens:
                            response.headers["Vary"] = f"{vary_val}, Origin"

        # Expose common headers (safe to always include)
        response.headers.setdefault("Access-Control-Expose-Headers", expose_headers)

        # Add decision header and log
        decision = "deny"
        if allowed_ok:
            decision = "allow-*" if matched == "*" else "allow-origin"
        response.headers.setdefault(
            "X-CORS-Decision",
            f"response; {decision}; nginx_mode={nginx_mode}",
        )
        _log(
            "INFO",
            "response",
            path=request.path,
            method=request.method,
            origin=origin,
            allowed=allowed_ok,
            matched=matched or "",
            nginx_mode=nginx_mode,
            decision=decision,
        )
        return response

    # Diagnosis endpoint to aid verification and debugging
    @app.get("/api/cors-diagnosis")
    def cors_diagnosis():  # type: ignore
        """Return the effective CORS settings and the verdict for the caller's Origin as JSON."""
        origin = request.headers.get("Origin", "")
        ok, matched = _origin_matches(origin, all_origins, strict) if origin else (False, None)
        data = {
            "nginx_mode": nginx_mode,
            "detected_proxy": _is_proxy_request(),
            "allowed_origins": all_origins,
            "strict": strict,
            "request_origin": origin,
            "matched_rule": matched,
            "origin_allowed": ok,
            "notes": (
                "Nginx handles Allow-Origin; app supplements other headers"
                if nginx_mode
                else "App sets CORS headers end-to-end"
            ),
        }
        return jsonify(data)
|