Files
KeyNexus/openai_route.py
2026-03-20 21:01:58 +08:00

2169 lines
73 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations

import argparse
import hmac
import html
import json
import logging
import os
import sqlite3
import time
import uuid
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Iterator

import requests
from flask import Flask, Response, abort, g, jsonify, redirect, request, session, stream_with_context, url_for
from requests.adapters import HTTPAdapter
# Module-wide logger for the proxy.
LOGGER = logging.getLogger("openai_proxy")
# Fallback for the expected client Authorization value when OPENAI_PROXY_AUTH is unset.
DEFAULT_AUTH = "Bearer change-me"
# Fallback Flask session secret when OPENAI_PROXY_SECRET_KEY is unset or blank.
DEFAULT_SECRET_KEY = "openai-proxy-dev-secret-change-me"
# Upstream HTTP status codes that is_retryable_status() reports as retryable.
RETRYABLE_STATUSES = {401, 403, 404, 408, 409, 425, 429, 500, 502, 503, 504}
def env_int(name: str, default: int) -> int:
    """Read environment variable *name* as an integer.

    Returns *default* when the variable is unset or when its stripped text
    is not a valid integer literal.
    """
    text = os.getenv(name)
    if text is not None:
        try:
            return int(text.strip())
        except (TypeError, ValueError):
            pass
    return default
def env_bool(name: str, default: bool) -> bool:
    """Read environment variable *name* as a boolean flag.

    Returns *default* when the variable is unset. Otherwise the value is true
    only if its stripped, lower-cased text is one of "1", "true", "yes", "on".
    """
    truthy = {"1", "true", "yes", "on"}
    text = os.getenv(name)
    return default if text is None else text.strip().lower() in truthy
def load_dotenv_file(path: Path) -> None:
    """Load KEY=VALUE pairs from a dotenv-style file into ``os.environ``.

    Blank lines, ``#`` comment lines, lines without ``=`` and lines with an
    empty key are ignored. A leading ``export `` is accepted. Variables that
    already exist in the environment are never overwritten. A missing,
    unreadable or undecodable file is silently skipped.

    Args:
        path: Location of the dotenv file.
    """
    if not path.is_file():
        return
    try:
        lines = path.read_text(encoding="utf-8-sig").splitlines()
    except (OSError, UnicodeDecodeError):
        # Best effort: a broken .env file must not prevent start-up.
        return
    for line in lines:
        text = line.strip()
        if not text or text.startswith("#"):
            continue
        if text.startswith("export "):
            text = text[7:].strip()
        key, separator, value = text.partition("=")
        if not separator:
            continue
        key = key.strip()
        if not key:
            continue
        value = value.strip()
        # Remove exactly one pair of matching surrounding quotes. Stripping
        # each quote character independently would damage values such as
        # 'say "hi"', which legitimately end with the other quote character.
        if len(value) >= 2 and value[0] == value[-1] and value[0] in "'\"":
            value = value[1:-1]
        os.environ.setdefault(key, value)
@dataclass(frozen=True)
class Settings:
    """Immutable runtime configuration for the proxy.

    Instances are built by load_settings() from OPENAI_PROXY_* environment
    variables.
    """

    # Expected client Authorization value (OPENAI_PROXY_AUTH).
    expected_auth: str
    # SQLite database file (OPENAI_PROXY_SQLITE_PATH, "~" expanded).
    sqlite_path: Path
    # SQLite busy timeout in milliseconds (OPENAI_PROXY_DB_BUSY_TIMEOUT_MS).
    busy_timeout_ms: int
    # Listen address and port (OPENAI_PROXY_LISTEN_HOST / OPENAI_PROXY_LISTEN_PORT).
    listen_host: str
    listen_port: int
    # Timeouts in seconds (OPENAI_PROXY_CONNECT_TIMEOUT_SEC,
    # OPENAI_PROXY_READ_TIMEOUT_SEC, OPENAI_PROXY_STREAM_READ_TIMEOUT_SEC).
    connect_timeout_sec: int
    read_timeout_sec: int
    stream_read_timeout_sec: int
    # OPENAI_PROXY_MAX_FAILOVER_ATTEMPTS, clamped to at least 1 by load_settings().
    max_failover_attempts: int
    # When true, a single numeric group selector is expanded to every enabled
    # numeric group with an equal or higher number (see fetch_candidate_ids()).
    auto_group_fallback: bool
    # Seconds an upstream stays in cooldown after a failure; clamped to >= 0.
    failure_cooldown_sec: int
    # When true, request payloads are logged in full instead of a summary.
    log_full_payload: bool
    # Maximum characters stored for finish/error text in logs; clamped to >= 256.
    log_text_limit: int
    # Whether the /admin pages are served at all.
    admin_enabled: bool
    # Dedicated admin token; an empty string means "fall back to expected_auth".
    admin_token: str
    # Flask session signing key.
    secret_key: str
    # OPENAI_PROXY_DEBUG flag.
    debug: bool
def load_settings() -> Settings:
    """Assemble a Settings instance from OPENAI_PROXY_* environment variables.

    Missing variables fall back to built-in defaults. Numeric values that
    must stay within bounds (failover attempts, cooldown, log text limit)
    are clamped here.
    """
    db_file = Path(os.getenv("OPENAI_PROXY_SQLITE_PATH", "./data/openai_proxy.db")).expanduser()

    # A blank secret key is treated the same as an unset one.
    session_secret = os.getenv("OPENAI_PROXY_SECRET_KEY", DEFAULT_SECRET_KEY).strip()
    if not session_secret:
        session_secret = DEFAULT_SECRET_KEY

    failover_attempts = max(1, env_int("OPENAI_PROXY_MAX_FAILOVER_ATTEMPTS", 5))
    cooldown_seconds = max(0, env_int("OPENAI_PROXY_FAILURE_COOLDOWN_SEC", 30))
    text_limit = max(256, env_int("OPENAI_PROXY_LOG_TEXT_LIMIT", 100000))

    return Settings(
        expected_auth=os.getenv("OPENAI_PROXY_AUTH", DEFAULT_AUTH),
        sqlite_path=db_file,
        busy_timeout_ms=env_int("OPENAI_PROXY_DB_BUSY_TIMEOUT_MS", 5000),
        listen_host=os.getenv("OPENAI_PROXY_LISTEN_HOST", "0.0.0.0"),
        listen_port=env_int("OPENAI_PROXY_LISTEN_PORT", 8056),
        connect_timeout_sec=env_int("OPENAI_PROXY_CONNECT_TIMEOUT_SEC", 10),
        read_timeout_sec=env_int("OPENAI_PROXY_READ_TIMEOUT_SEC", 120),
        stream_read_timeout_sec=env_int("OPENAI_PROXY_STREAM_READ_TIMEOUT_SEC", 600),
        max_failover_attempts=failover_attempts,
        auto_group_fallback=env_bool("OPENAI_PROXY_AUTO_GROUP_FALLBACK", True),
        failure_cooldown_sec=cooldown_seconds,
        log_full_payload=env_bool("OPENAI_PROXY_LOG_FULL_PAYLOAD", False),
        log_text_limit=text_limit,
        admin_enabled=env_bool("OPENAI_PROXY_ADMIN_ENABLED", True),
        admin_token=os.getenv("OPENAI_PROXY_ADMIN_TOKEN", "").strip(),
        secret_key=session_secret,
        debug=env_bool("OPENAI_PROXY_DEBUG", False),
    )
# Load the dotenv file first so its values are visible to load_settings().
# load_dotenv_file() does not override variables already in the environment.
load_dotenv_file(Path(os.getenv("OPENAI_PROXY_ENV_FILE", ".env")))
SETTINGS = load_settings()
APP = Flask(__name__)
APP.config["SECRET_KEY"] = SETTINGS.secret_key
APP.config["SESSION_COOKIE_HTTPONLY"] = True
APP.config["SESSION_COOKIE_SAMESITE"] = "Lax"
# Shared outbound HTTP session. Both adapters use pool sizes of 50 and
# disable adapter-level retries (max_retries=0).
HTTP_SESSION = requests.Session()
HTTP_SESSION.mount("http://", HTTPAdapter(pool_connections=50, pool_maxsize=50, max_retries=0))
HTTP_SESSION.mount("https://", HTTPAdapter(pool_connections=50, pool_maxsize=50, max_retries=0))
def utc_now_ts() -> int:
    """Return the current Unix timestamp truncated to whole seconds."""
    now = time.time()
    return int(now)
def limit_type_alias(limit_type: str | None) -> str:
    """Normalize a limit-type label to its canonical unit name.

    Accepts the canonical names ("second", "minute", "hour", "day", "month",
    "year"), the pinyin aliases ("miao", "fen", "shi", "tian", "yue", "nian")
    and the short forms "sec" / "min". Matching ignores case and surrounding
    whitespace. Anything else, including ``None``, maps to "day".
    """
    canonical_units = ("second", "minute", "hour", "day", "month", "year")
    aliases = {
        "miao": "second",
        "fen": "minute",
        "shi": "hour",
        "tian": "day",
        "yue": "month",
        "nian": "year",
        "sec": "second",
        "min": "minute",
    }
    key = (limit_type or "").strip().lower()
    if key in canonical_units:
        return key
    return aliases.get(key, "day")
def cycle_start_ts(limit_type: str | None, ts: int) -> int:
    """Return the Unix timestamp at which the quota cycle containing *ts* began.

    The cycle unit is resolved with limit_type_alias(). Boundaries are
    computed in UTC. For the "second" unit the timestamp itself is the start.
    """
    unit = limit_type_alias(limit_type)
    moment = datetime.fromtimestamp(ts, tz=timezone.utc)
    if unit == "second":
        return ts
    # Fields to reset for each unit; everything finer than the unit is zeroed.
    truncations = {
        "minute": {"second": 0, "microsecond": 0},
        "hour": {"minute": 0, "second": 0, "microsecond": 0},
        "day": {"hour": 0, "minute": 0, "second": 0, "microsecond": 0},
        "month": {"day": 1, "hour": 0, "minute": 0, "second": 0, "microsecond": 0},
        "year": {"month": 1, "day": 1, "hour": 0, "minute": 0, "second": 0, "microsecond": 0},
    }
    # Unknown units fall back to the start of the day.
    reset_fields = truncations.get(unit, truncations["day"])
    return int(moment.replace(**reset_fields).timestamp())
def open_sqlite_conn(db_path: Path) -> sqlite3.Connection:
    """Open a SQLite connection to *db_path* in autocommit mode.

    The parent directory is created if needed. Rows are returned as
    sqlite3.Row, and the connection is configured with WAL journaling, the
    configured busy timeout, synchronous=NORMAL and foreign keys enabled.
    """
    db_path.parent.mkdir(parents=True, exist_ok=True)
    timeout_ms = SETTINGS.busy_timeout_ms
    conn = sqlite3.connect(str(db_path), timeout=timeout_ms / 1000, isolation_level=None)
    conn.row_factory = sqlite3.Row
    startup_pragmas = (
        "PRAGMA journal_mode=WAL;",
        f"PRAGMA busy_timeout={timeout_ms};",
        "PRAGMA synchronous=NORMAL;",
        "PRAGMA foreign_keys=ON;",
    )
    for pragma in startup_pragmas:
        conn.execute(pragma)
    return conn
def init_db(db_path: Path) -> None:
    """Create the ``upstreams`` and ``logs`` tables and their indexes if missing.

    Every statement uses IF NOT EXISTS, so running this against an already
    initialized database leaves existing tables untouched. The connection
    opened here is always closed before returning.
    """
    conn = open_sqlite_conn(db_path)
    try:
        # NOTE(review): the column defaults call unixepoch(), which SQLite
        # only provides from version 3.38.0 -- confirm the SQLite library
        # bundled with the deployed Python is at least that new.
        conn.executescript(
            """
CREATE TABLE IF NOT EXISTS upstreams (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
group_name TEXT NOT NULL DEFAULT '1',
base_url TEXT NOT NULL,
api_key TEXT NOT NULL,
model TEXT NOT NULL,
force_parameter TEXT,
limit_type TEXT NOT NULL DEFAULT 'day',
limit_qty INTEGER NOT NULL DEFAULT 0,
limit_tokens INTEGER NOT NULL DEFAULT 0,
used_cycle_qty INTEGER NOT NULL DEFAULT 0,
used_cycle_tokens INTEGER NOT NULL DEFAULT 0,
used_all_qty INTEGER NOT NULL DEFAULT 0,
used_all_tokens INTEGER NOT NULL DEFAULT 0,
cycle_started_at INTEGER NOT NULL DEFAULT (unixepoch()),
last_used_at INTEGER,
consecutive_failures INTEGER NOT NULL DEFAULT 0,
cooldown_until INTEGER,
enabled INTEGER NOT NULL DEFAULT 1,
created_at INTEGER NOT NULL DEFAULT (unixepoch()),
updated_at INTEGER NOT NULL DEFAULT (unixepoch())
);
CREATE INDEX IF NOT EXISTS idx_upstreams_enabled_model
ON upstreams(enabled, model, name, group_name);
CREATE INDEX IF NOT EXISTS idx_upstreams_scheduling
ON upstreams(enabled, cooldown_until, consecutive_failures, last_used_at);
CREATE TABLE IF NOT EXISTS logs (
id INTEGER PRIMARY KEY AUTOINCREMENT,
req_id TEXT NOT NULL UNIQUE,
upstream_id INTEGER NOT NULL,
request_at INTEGER NOT NULL,
finish_at INTEGER,
status_code INTEGER,
used_tokens INTEGER NOT NULL DEFAULT 0,
request_payload TEXT,
finish_text TEXT,
error_text TEXT,
FOREIGN KEY (upstream_id) REFERENCES upstreams(id)
);
CREATE INDEX IF NOT EXISTS idx_logs_request_at
ON logs(request_at);
"""
        )
    finally:
        conn.close()
def get_db() -> sqlite3.Connection:
    """Return the SQLite connection cached on Flask's ``g``, opening it on first use."""
    cached = getattr(g, "_sqlite_conn", None)
    if cached is not None:
        return cached
    fresh = open_sqlite_conn(SETTINGS.sqlite_path)
    g._sqlite_conn = fresh
    return fresh
@APP.teardown_appcontext
def close_db(_: BaseException | None) -> None:
    """Close the connection cached by get_db(), if any, when the app context ends."""
    conn = getattr(g, "_sqlite_conn", None)
    if conn is None:
        return
    conn.close()
    g._sqlite_conn = None
def sqlite_row_to_dict(row: sqlite3.Row | None) -> dict[str, Any] | None:
    """Convert a sqlite3.Row into a plain dict keyed by column name.

    Returns ``None`` when *row* is ``None``.
    """
    if row is not None:
        return dict((column, row[column]) for column in row.keys())
    return None
def parse_selector(raw_model: str) -> dict[str, Any] | None:
    """Parse the client's ``model`` value into a routing selector.

    The input is stripped and cut to 100 characters. Results:
      * ``None`` for empty input or a group list with no usable names;
      * ``{"type": "groups", ...}`` for "group:a,b", any comma-separated
        value, or a value made only of digits;
      * ``{"type": "exact", ...}`` for everything else.
    """
    value = (raw_model or "").strip()[:100]
    if not value:
        return None

    has_group_prefix = value.lower().startswith("group:")
    if has_group_prefix or "," in value:
        listing = value[6:] if has_group_prefix else value
        names = [piece for piece in (chunk.strip() for chunk in listing.split(",")) if piece]
        if names:
            return {"type": "groups", "groups": names, "value": value}
        return None

    if value.isdigit():
        return {"type": "groups", "groups": [value], "value": value}
    return {"type": "exact", "model": value, "value": value}
def fetch_candidate_ids(
    conn: sqlite3.Connection, selector: dict[str, Any], excluded_ids: list[int], now_ts: int
) -> list[int]:
    """Return ids of enabled upstreams matching *selector*, in scheduling order.

    Args:
        conn: Open SQLite connection.
        selector: Selector dict as produced by parse_selector().
        excluded_ids: Upstream ids to leave out (already tried in this request).
        now_ts: Current Unix timestamp, used to rank cooling-down upstreams last.

    Returns:
        Matching ids; upstreams not in cooldown come first, then ascending id.

    For a "groups" selector holding a single numeric group, and when
    SETTINGS.auto_group_fallback is on, the selection is widened to every
    enabled numeric group whose number is greater than or equal to the
    requested one.
    """
    params: list[Any] = []
    sql = "SELECT id FROM upstreams WHERE enabled = 1"
    if excluded_ids:
        placeholders = ",".join("?" for _ in excluded_ids)
        sql += f" AND id NOT IN ({placeholders})"
        params.extend(excluded_ids)
    if selector["type"] == "groups":
        groups = list(selector["groups"])
        start_text = groups[0].strip() if len(groups) == 1 else ""
        # isdecimal() only accepts characters int() can parse; isdigit() also
        # accepts e.g. superscript digits, for which int() raises ValueError.
        if SETTINGS.auto_group_fallback and start_text.isdecimal():
            start_group = int(start_text)
            group_rows = conn.execute("SELECT DISTINCT group_name FROM upstreams WHERE enabled = 1").fetchall()
            numbered: list[tuple[int, str]] = []
            for row in group_rows:
                stored_name = str(row["group_name"])
                if stored_name.strip().isdecimal():
                    numbered.append((int(stored_name.strip()), stored_name))
            # Bind the names exactly as stored. Rebuilding them with str(int)
            # would turn "01" or " 2" into "1" / "2", which never matches the
            # stored value in the IN clause below.
            expanded_groups = [name for number, name in sorted(numbered) if number >= start_group]
            if expanded_groups:
                groups = expanded_groups
        placeholders = ",".join("?" for _ in groups)
        sql += f" AND group_name IN ({placeholders})"
        params.extend(groups)
    else:
        sql += " AND (name = ? OR model = ?)"
        params.append(selector["model"])
        params.append(selector["model"])
    sql += """
ORDER BY
CASE WHEN COALESCE(cooldown_until, 0) <= ? THEN 0 ELSE 1 END ASC,
-- 稳定优先:默认按配置顺序(编号)选择,失败才在单次请求内切到下一个。
id ASC
"""
    params.append(now_ts)
    rows = conn.execute(sql, params).fetchall()
    return [int(row["id"]) for row in rows]
def selector_group_fallback(selector: dict[str, Any]) -> dict[str, Any] | None:
    """Derive a group selector from an "exact" selector.

    Returns a "groups" selector that treats the exact model string as a
    group name, or ``None`` when *selector* is not "exact" or has no model.
    """
    if selector.get("type") == "exact":
        name = str(selector.get("model") or "").strip()
        if name:
            return {"type": "groups", "groups": [name], "value": f"group:{name}"}
    return None
def evaluate_row_available_now(row: dict[str, Any], now_ts: int) -> tuple[bool, str]:
    """Report whether an upstream row could serve a request at *now_ts*.

    This is a read-only check; nothing is written back. Returns
    ``(True, "ok")``, ``(False, "cooldown")`` or ``(False, "quota")``.
    Usage counters belonging to an earlier cycle are treated as zero.
    """
    cooldown_until = int(row.get("cooldown_until") or 0)
    if cooldown_until > now_ts:
        return False, "cooldown"

    recorded_start = int(row.get("cycle_started_at") or 0)
    window_start = cycle_start_ts(str(row.get("limit_type") or "day"), now_ts)
    used_qty = int(row.get("used_cycle_qty") or 0)
    used_tokens = int(row.get("used_cycle_tokens") or 0)
    if recorded_start < window_start:
        # The stored counters are from a previous cycle.
        used_qty, used_tokens = 0, 0

    cap_qty = int(row.get("limit_qty") or 0)
    cap_tokens = int(row.get("limit_tokens") or 0)
    # A cap of zero or less means "unlimited".
    qty_exhausted = cap_qty > 0 and used_qty >= cap_qty
    tokens_exhausted = cap_tokens > 0 and used_tokens >= cap_tokens
    if qty_exhausted or tokens_exhausted:
        return False, "quota"
    return True, "ok"
def collect_no_available_reason(conn: sqlite3.Connection, selector: dict[str, Any]) -> str:
    """Build the error message explaining why no upstream could be chosen.

    Matches upstreams with both the selector and its group fallback, then
    either reports that nothing matched or summarizes how many matches are
    usable, cooling down, or blocked by quota.
    """
    now_ts = utc_now_ts()
    probes = [selector]
    group_probe = selector_group_fallback(selector)
    if group_probe is not None:
        probes.append(group_probe)

    matched_ids: set[int] = set()
    for probe in probes:
        matched_ids |= set(fetch_candidate_ids(conn, probe, [], now_ts))

    if not matched_ids:
        model_value = str(selector.get("value") or selector.get("model") or "")
        return f"No available model endpoint: 未匹配到可用上游,请检查 model/分组。当前 model={model_value}"

    ordered_ids = tuple(sorted(matched_ids))
    placeholders = ",".join("?" for _ in ordered_ids)
    rows = conn.execute(
        f"""
SELECT id, name, group_name, model, limit_type, limit_qty, limit_tokens,
used_cycle_qty, used_cycle_tokens, cycle_started_at, cooldown_until
FROM upstreams
WHERE id IN ({placeholders})
""",
        ordered_ids,
    ).fetchall()

    tally = {"ok": 0, "cooldown": 0, "quota": 0}
    for row in rows:
        available, reason = evaluate_row_available_now(sqlite_row_to_dict(row) or {}, now_ts)
        bucket = "ok" if available else reason
        if bucket in tally:
            tally[bucket] += 1

    return (
        "No available model endpoint: "
        f"匹配到 {len(matched_ids)} 条上游,当前可用 {tally['ok']} 条,冷却中 {tally['cooldown']} 条,限额阻塞 {tally['quota']} 条。"
    )
def refresh_cycle_if_needed(conn: sqlite3.Connection, row: dict[str, Any], now_ts: int) -> dict[str, Any]:
    """Reset an upstream's per-cycle counters when a new quota cycle has begun.

    When the stored ``cycle_started_at`` is older than the current cycle
    start, both the database row and the passed-in *row* dict are reset.
    The same dict object is returned either way.
    """
    recorded_start = int(row.get("cycle_started_at") or 0)
    window_start = cycle_start_ts(row.get("limit_type"), now_ts)
    if recorded_start < window_start:
        conn.execute(
            """
UPDATE upstreams
SET used_cycle_qty = 0,
used_cycle_tokens = 0,
cycle_started_at = ?,
updated_at = ?
WHERE id = ?
""",
            (window_start, now_ts, row["id"]),
        )
        row.update(used_cycle_qty=0, used_cycle_tokens=0, cycle_started_at=window_start)
    return row
def has_quota(row: dict[str, Any]) -> bool:
    """Return True when the row is below both its request and token limits.

    A limit of zero or less (or a missing limit) means "unlimited".
    """

    def _below_cap(limit_key: str, used_key: str) -> bool:
        cap = int(row.get(limit_key) or 0)
        used = int(row.get(used_key) or 0)
        return cap <= 0 or used < cap

    qty_ok = _below_cap("limit_qty", "used_cycle_qty")
    tokens_ok = _below_cap("limit_tokens", "used_cycle_tokens")
    return qty_ok and tokens_ok
def reserve_upstream(conn: sqlite3.Connection, upstream_id: int, now_ts: int) -> dict[str, Any] | None:
    """Atomically claim one request slot on an upstream.

    Inside a ``BEGIN IMMEDIATE`` transaction the row is re-read, its quota
    cycle is refreshed, and cooldown and quota are re-checked. On success
    the request counters are incremented and the transaction is committed.

    Args:
        conn: Open SQLite connection (autocommit mode; transactions are explicit).
        upstream_id: Id of the upstream to reserve.
        now_ts: Current Unix timestamp.

    Returns:
        The upstream row as a dict with the incremented counters, or ``None``
        when the upstream is missing, disabled, cooling down, out of quota,
        or when SQLite reported an operational error.

    Raises:
        Exception: Any error other than sqlite3.OperationalError is re-raised
            after a rollback attempt.
    """
    try:
        conn.execute("BEGIN IMMEDIATE")
        row = sqlite_row_to_dict(
            conn.execute("SELECT * FROM upstreams WHERE id = ? AND enabled = 1", (upstream_id,)).fetchone()
        )
        if row is None:
            conn.execute("ROLLBACK")
            return None
        row = refresh_cycle_if_needed(conn, row, now_ts)
        cooldown_until = int(row.get("cooldown_until") or 0)
        if cooldown_until > now_ts:
            # Rolling back also discards a cycle reset done just above.
            conn.execute("ROLLBACK")
            return None
        if not has_quota(row):
            conn.execute("ROLLBACK")
            return None
        conn.execute(
            """
UPDATE upstreams
SET used_cycle_qty = used_cycle_qty + 1,
used_all_qty = used_all_qty + 1,
last_used_at = ?,
updated_at = ?
WHERE id = ?
""",
            (now_ts, now_ts, upstream_id),
        )
        conn.execute("COMMIT")
        row["used_cycle_qty"] = int(row.get("used_cycle_qty") or 0) + 1
        row["used_all_qty"] = int(row.get("used_all_qty") or 0) + 1
        row["last_used_at"] = now_ts
        return row
    except sqlite3.OperationalError as exc:
        # The candidate is skipped, but the reason is recorded: without this
        # a locked or broken database looks like "upstream unavailable".
        LOGGER.warning("reserve_upstream: skipping upstream %s after SQLite error: %s", upstream_id, exc)
        try:
            conn.execute("ROLLBACK")
        except sqlite3.Error:
            # Best effort: there may be no open transaction to roll back.
            pass
        return None
    except Exception:
        try:
            conn.execute("ROLLBACK")
        except sqlite3.Error:
            # Best effort: the original exception is the one that matters.
            pass
        raise
def mark_upstream_success(conn: sqlite3.Connection, upstream_id: int, used_tokens: int) -> None:
    """Record a successful call: add token usage and clear failure state.

    Adds *used_tokens* to both the cycle and lifetime counters, resets
    ``consecutive_failures`` to zero and clears ``cooldown_until``.
    """
    stamp = utc_now_ts()
    bindings = (used_tokens, used_tokens, stamp, upstream_id)
    conn.execute(
        """
UPDATE upstreams
SET used_cycle_tokens = used_cycle_tokens + ?,
used_all_tokens = used_all_tokens + ?,
consecutive_failures = 0,
cooldown_until = NULL,
updated_at = ?
WHERE id = ?
""",
        bindings,
    )
def mark_upstream_failure(conn: sqlite3.Connection, upstream_id: int) -> None:
    """Record a failed call: bump the failure counter and set the cooldown.

    ``cooldown_until`` becomes now plus SETTINGS.failure_cooldown_sec, or
    just now when the configured cooldown is not positive.
    """
    now_ts = utc_now_ts()
    penalty = SETTINGS.failure_cooldown_sec
    if penalty > 0:
        cooldown_until = now_ts + penalty
    else:
        cooldown_until = now_ts
    conn.execute(
        """
UPDATE upstreams
SET consecutive_failures = consecutive_failures + 1,
cooldown_until = ?,
updated_at = ?
WHERE id = ?
""",
        (cooldown_until, now_ts, upstream_id),
    )
def is_retryable_status(status_code: int) -> bool:
    """Return True when *status_code* is listed in RETRYABLE_STATUSES."""
    retryable = status_code in RETRYABLE_STATUSES
    return retryable
def is_retryable_failure(status_code: int, response_text: str) -> bool:
    """Decide whether a failed upstream response should trigger failover.

    Always true for statuses accepted by is_retryable_status(). For 400 and
    422 it is true only when the lower-cased response text contains one of
    the model-related hints below. Every other status is not retryable.
    """
    if is_retryable_status(status_code):
        return True
    if status_code in {400, 422}:
        lowered = (response_text or "").lower()
        # Let common "model does not exist / model unavailable" errors fall
        # through to the next upstream. Note that the bare "model" hint
        # already matches any text containing the longer "... model" hints.
        model_hints = (
            "model",
            "not found",
            "no such model",
            "unknown model",
            "invalid model",
            "model not",
            "模型",
            "不存在",
            "不可用",
        )
        for hint in model_hints:
            if hint in lowered:
                return True
    return False
def parse_int_like(value: Any) -> int | None:
    """Coerce an int, float or numeric string to ``int``.

    Floats (and float-formatted strings such as "1.9" or "1e3") are
    truncated toward zero. Bools are returned as plain ints.

    Returns:
        The integer value, or ``None`` for empty or non-numeric strings,
        non-finite numbers ("inf", "nan", "1e400") and unsupported types.
    """
    if isinstance(value, bool):
        # bool is a subclass of int; hand back a real int.
        return int(value)
    if isinstance(value, int):
        return value
    if isinstance(value, str):
        text = value.strip()
        if not text:
            return None
        try:
            return int(text)
        except ValueError:
            pass
        try:
            value = float(text)
        except ValueError:
            return None
    if isinstance(value, float):
        try:
            return int(value)
        except (ValueError, OverflowError):
            # int(nan) raises ValueError and int(inf) raises OverflowError.
            return None
    return None
def sanitize_request_payload(payload: dict[str, Any]) -> dict[str, Any]:
    """Return the form of *payload* that is safe to store in the request log.

    With SETTINGS.log_full_payload enabled the payload is returned as is.
    Otherwise only the model, stream flag, token limits and the number of
    messages are kept; message contents are dropped.
    """
    if SETTINGS.log_full_payload:
        return payload
    messages = payload.get("messages")
    summary: dict[str, Any] = {
        "model": payload.get("model"),
        "stream": bool(payload.get("stream", False)),
    }
    for key in ("max_tokens", "max_output_tokens"):
        summary[key] = payload.get(key)
    summary["message_count"] = len(messages) if isinstance(messages, list) else 0
    return summary
def create_request_log(conn: sqlite3.Connection, upstream_id: int, payload: dict[str, Any]) -> str:
    """Insert a new ``logs`` row for a request and return its request id.

    The stored payload is the output of sanitize_request_payload(),
    serialized as JSON with NaN/Infinity disallowed.
    """
    req_id = uuid.uuid4().hex
    payload_json = json.dumps(sanitize_request_payload(payload), ensure_ascii=False, allow_nan=False)
    row_values = (req_id, upstream_id, utc_now_ts(), payload_json)
    conn.execute(
        """
INSERT INTO logs(req_id, upstream_id, request_at, request_payload)
VALUES (?, ?, ?, ?)
""",
        row_values,
    )
    return req_id
def finish_request_log(
    conn: sqlite3.Connection,
    req_id: str,
    status_code: int,
    used_tokens: int,
    finish_text: str | None = None,
    error_text: str | None = None,
) -> None:
    """Complete the ``logs`` row identified by *req_id*.

    Stores the finish time, status code, token usage and the finish/error
    texts. Each text is cut to SETTINGS.log_text_limit characters, and
    ``None`` is stored as an empty string.
    """
    limit = SETTINGS.log_text_limit
    clipped_finish, clipped_error = ((text or "")[:limit] for text in (finish_text, error_text))
    conn.execute(
        """
UPDATE logs
SET finish_at = ?,
status_code = ?,
used_tokens = ?,
finish_text = ?,
error_text = ?
WHERE req_id = ?
""",
        (utc_now_ts(), status_code, used_tokens, clipped_finish, clipped_error, req_id),
    )
def merge_force_parameter(payload: dict[str, Any], upstream: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of *payload* adjusted for *upstream*.

    ``model`` is replaced with the upstream's model. If the upstream has a
    ``force_parameter`` JSON value, its keys are then laid over the result:
    a JSON object is applied directly, and a JSON array applies each object
    element in order. Invalid JSON and other JSON types are ignored. The
    input payload is not modified.
    """
    merged = {**payload, "model": upstream.get("model")}

    raw = upstream.get("force_parameter")
    text = "" if raw is None else str(raw).strip()
    if not text:
        return merged
    try:
        decoded = json.loads(text)
    except json.JSONDecodeError:
        return merged

    if isinstance(decoded, dict):
        overrides = [decoded]
    elif isinstance(decoded, list):
        overrides = [entry for entry in decoded if isinstance(entry, dict)]
    else:
        overrides = []
    for override in overrides:
        merged.update(override)
    return merged
def build_request_meta(upstream: dict[str, Any]) -> tuple[str, dict[str, str]]:
    """Build the chat-completions URL and request headers for *upstream*.

    Returns:
        ``(url, headers)`` where the URL is ``<base_url>/chat/completions``
        and the headers carry a ``Bearer`` Authorization value and a JSON
        content type.

    A stored key that already starts with a "Bearer " scheme, in any letter
    case, is not prefixed a second time. A missing or ``None`` base URL or
    key is treated as an empty string instead of the text "None".
    """
    base = str(upstream.get("base_url") or "").rstrip("/") + "/"
    api_key = str(upstream.get("api_key") or "").strip()
    # The auth scheme is matched case-insensitively so that "bearer sk-..."
    # does not end up as "Bearer bearer sk-...".
    if api_key[:7].lower() == "bearer ":
        api_key = api_key[7:].strip()
    return (
        base + "chat/completions",
        {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        },
    )
def build_models_meta(upstream: dict[str, Any]) -> tuple[str, dict[str, str]]:
    """Build the model-listing URL and request headers for *upstream*.

    Returns:
        ``(url, headers)`` where the URL is ``<base_url>/models`` and the
        headers carry a ``Bearer`` Authorization value.

    A stored key that already starts with a "Bearer " scheme, in any letter
    case, is not prefixed a second time. A missing or ``None`` base URL or
    key is treated as an empty string instead of the text "None".
    """
    base = str(upstream.get("base_url") or "").rstrip("/") + "/"
    api_key = str(upstream.get("api_key") or "").strip()
    # The auth scheme is matched case-insensitively so that "bearer sk-..."
    # does not end up as "Bearer bearer sk-...".
    if api_key[:7].lower() == "bearer ":
        api_key = api_key[7:].strip()
    return (
        base + "models",
        {
            "Authorization": f"Bearer {api_key}",
        },
    )
def build_admin_probe_payload(upstream: dict[str, Any]) -> dict[str, Any]:
    """Build the minimal chat payload used by the admin connectivity probe.

    The request is a one-token, non-streaming "ping" to keep consumption
    low. The upstream's forced parameters are applied on top.
    """
    probe = dict(
        model=str(upstream.get("model") or ""),
        messages=[{"role": "user", "content": "ping"}],
        stream=False,
        max_tokens=1,
        temperature=0,
    )
    return merge_force_parameter(probe, upstream)
def extract_total_tokens(usage: Any) -> int:
    """Extract the total token count from a usage object or dict.

    Objects exposing ``model_dump`` are converted to a dict first. A
    positive ``total_tokens`` wins; otherwise the sum of ``prompt_tokens``
    and ``completion_tokens`` is used when both parse as integers.
    Returns 0 in every other case.
    """
    if not usage:
        return 0
    data = usage.model_dump(mode="json") if hasattr(usage, "model_dump") else usage
    if not isinstance(data, dict):
        return 0

    total = parse_int_like(data.get("total_tokens"))
    if isinstance(total, int) and total > 0:
        return total

    prompt = parse_int_like(data.get("prompt_tokens"))
    completion = parse_int_like(data.get("completion_tokens"))
    if all(isinstance(part, int) for part in (prompt, completion)):
        return prompt + completion
    return 0
def extract_finish_text_from_response(resp_json: Any) -> str:
    """Return the text of the first choice in a non-streaming response.

    Prefers ``choices[0].message.content`` and falls back to
    ``choices[0].text``. Returns "" when the structure does not match or
    the value is not a string.
    """
    choices = resp_json.get("choices") if isinstance(resp_json, dict) else None
    if not isinstance(choices, list) or not choices:
        return ""
    head = choices[0]
    if not isinstance(head, dict):
        return ""

    message = head.get("message")
    if isinstance(message, dict) and isinstance(message.get("content"), str):
        return message["content"]

    legacy_text = head.get("text")
    return legacy_text if isinstance(legacy_text, str) else ""
def extract_delta_text_from_chunk(chunk_dict: Any) -> str:
    """Return ``choices[0].delta.content`` from a streaming chunk.

    Returns "" when the structure does not match or the content is not a
    string.
    """
    choices = chunk_dict.get("choices") if isinstance(chunk_dict, dict) else None
    if not isinstance(choices, list) or not choices:
        return ""
    head = choices[0]
    if not isinstance(head, dict):
        return ""
    delta = head.get("delta")
    piece = delta.get("content") if isinstance(delta, dict) else None
    return piece if isinstance(piece, str) else ""
def unauthorized_response() -> tuple[Response, int]:
    """Return the JSON error body and 401 status used for failed authentication."""
    body = jsonify({"error": {"message": "Unauthorized"}})
    return body, 401
def upstream_error_response(message: str, status_code: int = 502) -> Response:
    """Return a JSON error response carrying *message* with the given status (default 502)."""
    error_body = {"error": {"message": message}}
    response = jsonify(error_body)
    response.status_code = status_code
    return response
def auth_header_ok(auth_header: str) -> bool:
    """Check a client's Authorization header against the configured secret.

    To reduce integration mistakes, the "Bearer" prefix may be in any letter
    case or left out entirely; only the token bodies are compared. Always
    returns False when no expected value is configured.
    """
    expected_raw = (SETTINGS.expected_auth or "").strip()
    if not expected_raw:
        return False
    got_raw = (auth_header or "").strip()
    # Identical raw headers always have identical token bodies, so comparing
    # the bodies covers both cases. The comparison is constant-time so the
    # response time does not reveal how much of the secret matched; bytes are
    # used because compare_digest() rejects non-ASCII str arguments.
    expected_token = token_body(expected_raw).encode("utf-8")
    got_token = token_body(got_raw).encode("utf-8")
    return hmac.compare_digest(got_token, expected_token)
def token_body(token: str) -> str:
    """Return *token* stripped, without a leading "Bearer " scheme (any letter case)."""
    cleaned = (token or "").strip()
    scheme = "bearer "
    if not cleaned.lower().startswith(scheme):
        return cleaned
    return cleaned[len(scheme):].strip()
def effective_admin_token() -> str:
    """Return the token body that grants admin access.

    Uses SETTINGS.admin_token when set, otherwise SETTINGS.expected_auth.
    """
    source = SETTINGS.admin_token or SETTINGS.expected_auth
    return token_body(source)
def admin_token_ok(user_input: str) -> bool:
    """Check a token typed on the admin login form.

    Both sides are reduced to their token body (an optional "Bearer " prefix
    is ignored). Always returns False when no admin token can be derived.
    """
    expected = effective_admin_token()
    if not expected:
        return False
    typed = token_body(user_input)
    # Constant-time comparison so the response time does not reveal how much
    # of the token matched; bytes are used because compare_digest() rejects
    # non-ASCII str arguments.
    return hmac.compare_digest(typed.encode("utf-8"), expected.encode("utf-8"))
def is_local_request() -> bool:
    """Return True when the request's remote address is a loopback literal."""
    # NOTE(review): if the app runs behind a reverse proxy on the same host,
    # remote_addr is presumably the proxy's loopback address for every
    # request -- confirm the deployment before relying on this check.
    loopback_names = ("127.0.0.1", "::1", "localhost")
    peer = (request.remote_addr or "").strip()
    return peer in loopback_names
def ensure_admin_access() -> Response | None:
    """Gate for admin pages.

    Aborts with 404 when the admin UI is disabled. Returns ``None`` when
    access is allowed, otherwise a redirect to the login page that carries
    the current path as ``next``.
    """
    if not SETTINGS.admin_enabled:
        abort(404)
    logged_in = session.get("admin_ok") is True
    # When no dedicated admin token is configured, requests from the local
    # machine may open the console directly, for personal-use convenience.
    if logged_in or (not SETTINGS.admin_token and is_local_request()):
        return None
    return redirect(url_for("admin_login", next=request.path))
def to_form_int(raw: str, default: int = 0) -> int:
    """Parse a form field as an integer, returning *default* when it is empty or invalid."""
    try:
        parsed = int((raw or "").strip())
    except (TypeError, ValueError):
        parsed = default
    return parsed
def parse_force_parameter_text(raw_text: str) -> tuple[bool, str | None]:
    """Validate and normalize the forced-parameter text from the admin form.

    Returns:
        ``(True, None)`` for empty input, ``(True, normalized_json)`` for a
        JSON object or array, and ``(False, message)`` for invalid JSON or
        any other JSON type.
    """
    text = (raw_text or "").strip()
    if text == "":
        return True, None
    try:
        decoded = json.loads(text)
    except json.JSONDecodeError:
        return False, "强制参数不是合法 JSON"
    if isinstance(decoded, (dict, list)):
        return True, json.dumps(decoded, ensure_ascii=False)
    return False, "强制参数必须是 JSON 对象或数组"
def format_ts(ts_value: Any) -> str:
    """Format a Unix timestamp as ``YYYY-MM-DD HH:MM:SS`` for display.

    The time is rendered in the process's local time zone.

    Returns:
        The formatted time, or "-" when the value is ``None``, not
        convertible to an integer, not positive, or outside the range the
        platform can represent.
    """
    if ts_value is None:
        return "-"
    try:
        ts_int = int(ts_value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError covers int(float("inf")).
        return "-"
    if ts_int <= 0:
        return "-"
    try:
        moment = datetime.fromtimestamp(ts_int)
    except (OverflowError, OSError, ValueError):
        # A corrupt, very large timestamp must not break the page it is shown on.
        return "-"
    return moment.strftime("%Y-%m-%d %H:%M:%S")
def attach_route_headers(response: Response, upstream: dict[str, Any], tried_ids: list[int]) -> Response:
    """Add X-Route-* diagnostic headers describing the chosen upstream.

    ``X-Route-Tried-Ids`` lists the previously tried ids followed by the
    current one, skipping ids that are not positive. The same response
    object is returned.
    """
    current_id = int(upstream.get("id") or 0)
    chain = [*tried_ids, current_id]
    route_headers = {
        "X-Route-Upstream-Id": str(current_id),
        "X-Route-Group": str(upstream.get("group_name") or ""),
        "X-Route-Model": str(upstream.get("model") or ""),
        "X-Route-Tried-Ids": ",".join(str(entry) for entry in chain if int(entry) > 0),
    }
    for header_name, header_value in route_headers.items():
        response.headers[header_name] = header_value
    return response
def choose_upstream(conn: sqlite3.Connection, selector: dict[str, Any], tried_ids: list[int]) -> dict[str, Any] | None:
    """Pick and reserve the first usable upstream for *selector*.

    Candidates from the selector are tried first, then candidates from its
    group fallback (if any). Ids in *tried_ids* are excluded, and each
    candidate is attempted at most once. Returns the reserved upstream row,
    or ``None`` when nothing could be reserved.
    """
    now_ts = utc_now_ts()
    group_probe = selector_group_fallback(selector)
    probes = [selector] if group_probe is None else [selector, group_probe]

    attempted: set[int] = set()
    for probe in probes:
        for candidate_id in fetch_candidate_ids(conn, probe, tried_ids, now_ts):
            if candidate_id not in attempted:
                attempted.add(candidate_id)
                reserved = reserve_upstream(conn, candidate_id, now_ts)
                if reserved is not None:
                    return reserved
    return None
@APP.get("/healthz")
def healthz() -> dict[str, str]:
    """Liveness endpoint; always answers with a fixed "ok" status."""
    status_body = {"status": "ok"}
    return status_body
@APP.get("/")
def index_page() -> Response:
    """Serve the static landing page.

    The page shows whether the admin UI is enabled and links to /admin and
    /healthz. The only interpolated value is a fixed label chosen from
    SETTINGS.admin_enabled, so no request data reaches the HTML.
    """
    admin_text = "已开启" if SETTINGS.admin_enabled else "未开启"
    # Literal CSS braces are doubled because this is an f-string.
    page = f"""<!doctype html>
<html lang="zh-CN">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<title>OpenAI 路由服务</title>
<style>
body {{
margin: 0;
padding: 32px;
font-family: "Microsoft YaHei", "PingFang SC", "Segoe UI", sans-serif;
background: #f4f6fb;
color: #1c2432;
}}
.card {{
max-width: 840px;
margin: 0 auto;
background: #fff;
border: 1px solid #e2e8f0;
border-radius: 14px;
padding: 24px;
box-shadow: 0 8px 30px rgba(23, 35, 67, 0.08);
}}
h1 {{
margin: 0 0 12px;
font-size: 26px;
}}
p {{
margin: 0 0 10px;
color: #445268;
line-height: 1.6;
}}
.btn {{
display: inline-block;
margin-right: 12px;
margin-top: 14px;
padding: 10px 14px;
border-radius: 10px;
text-decoration: none;
border: 1px solid #c6d2ea;
color: #163a78;
background: #eef4ff;
font-weight: 600;
}}
</style>
</head>
<body>
<div class="card">
<h1>OpenAI 路由服务</h1>
<p>服务状态:正常</p>
<p>可视化管理:{admin_text}</p>
<p>接口入口:<code>/v1/chat/completions</code>、<code>/v1/models</code></p>
<a class="btn" href="/admin">进入可视化管理</a>
<a class="btn" href="/healthz">健康检查</a>
</div>
</body>
</html>"""
    return Response(page, content_type="text/html; charset=utf-8")
def limit_type_display(limit_type: str | None) -> str:
    """Return the Chinese display label for a limit type.

    The value is first normalized with limit_type_alias(), so aliases and
    unknown values get the label of their canonical unit.
    """
    # Four of these labels were empty strings, which rendered blank text in
    # the admin UI; they now match the style of the two that were present.
    mapping = {
        "second": "秒",
        "minute": "分钟",
        "hour": "小时",
        "day": "天",
        "month": "月",
        "year": "年",
    }
    alias = limit_type_alias(limit_type)
    return mapping.get(alias, "")
def render_limit_type_options(selected: str | None) -> str:
    """Render the ``<option>`` elements for the limit-type ``<select>``.

    Args:
        selected: Currently stored limit type; it is normalized with
            limit_type_alias() and the matching option is marked selected.

    Returns:
        The concatenated option tags, one per supported unit.
    """
    # Four of these labels were empty strings, which produced options with
    # no visible text; they now match the style of the two that were present.
    choices = [
        ("second", "秒"),
        ("minute", "分钟"),
        ("hour", "小时"),
        ("day", "天"),
        ("month", "月"),
        ("year", "年"),
    ]
    current = limit_type_alias(selected)
    parts: list[str] = []
    for value, text in choices:
        selected_attr = " selected" if current == value else ""
        parts.append(f"<option value='{value}'{selected_attr}>{text}</option>")
    return "".join(parts)
@APP.get("/admin/login")
def admin_login() -> Response:
    """Render the admin login form.

    Query parameters:
        next: Local path to return to after login; defaults to "/admin".
        err: Optional error message shown above the form (HTML-escaped).

    Aborts with 404 when the admin UI is disabled.
    """
    if not SETTINGS.admin_enabled:
        abort(404)
    next_url = request.args.get("next", "/admin")
    # Accept only same-site paths. "//host" is a protocol-relative URL, and
    # browsers treat "\" like "/", so both forms would leave the site.
    if not next_url.startswith("/") or next_url.startswith("//") or "\\" in next_url:
        next_url = "/admin"
    err = request.args.get("err", "")
    err_html = f"<p class='err'>{html.escape(err)}</p>" if err else ""
    # Literal CSS braces are doubled because this is an f-string.
    page = f"""<!doctype html>
<html lang="zh-CN">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<title>管理员登录</title>
<style>
body {{
margin: 0;
min-height: 100vh;
display: grid;
place-items: center;
background: radial-gradient(circle at top, #f6fafc, #e8eef8);
font-family: "Microsoft YaHei", "PingFang SC", "Segoe UI", sans-serif;
color: #20304a;
}}
.card {{
width: min(460px, 92vw);
background: #ffffff;
border: 1px solid #dbe5f5;
border-radius: 14px;
box-shadow: 0 12px 32px rgba(23, 45, 88, 0.12);
padding: 24px;
}}
h1 {{
margin: 0 0 8px;
font-size: 24px;
}}
p {{
color: #4c5f7c;
}}
.err {{
color: #b42318;
font-weight: 700;
}}
input {{
width: 100%;
box-sizing: border-box;
border: 1px solid #c7d4eb;
border-radius: 10px;
padding: 10px 12px;
font-size: 14px;
margin-top: 6px;
}}
button {{
margin-top: 14px;
width: 100%;
border: 0;
border-radius: 10px;
padding: 11px 12px;
font-size: 15px;
font-weight: 700;
color: #fff;
background: #2563eb;
cursor: pointer;
}}
</style>
</head>
<body>
<div class="card">
<h1>管理后台登录</h1>
<p>请输入管理员口令后继续。</p>
{err_html}
<form method="post" action="/admin/login">
<input type="hidden" name="next" value="{html.escape(next_url)}" />
<label>管理员口令</label>
<input type="password" name="token" placeholder="请输入口令" required />
<button type="submit">进入后台</button>
</form>
</div>
</body>
</html>"""
    return Response(page, content_type="text/html; charset=utf-8")
@APP.post("/admin/login")
def admin_login_submit() -> Response:
    """Handle the admin login form.

    On a correct token the session is marked as admin and the browser is
    redirected to ``next``. Otherwise it is sent back to the login page
    with an error message. Aborts with 404 when the admin UI is disabled.
    """
    if not SETTINGS.admin_enabled:
        abort(404)
    next_url = request.form.get("next", "/admin")
    # Accept only same-site paths. Checking for a leading "/" alone lets
    # "//evil.example" (protocol-relative) and "/\evil.example" through,
    # which would turn the post-login redirect into an open redirect.
    if not next_url.startswith("/") or next_url.startswith("//") or "\\" in next_url:
        next_url = "/admin"
    token = request.form.get("token", "")
    if admin_token_ok(token):
        session["admin_ok"] = True
        return redirect(next_url)
    return redirect(url_for("admin_login", next=next_url, err="口令错误,请重试"))
@APP.get("/admin/logout")
def admin_logout() -> Response:
    """Clear the admin session flag and redirect to the login page.

    Aborts with 404 when the admin UI is disabled.
    """
    if SETTINGS.admin_enabled:
        session.pop("admin_ok", None)
        return redirect(url_for("admin_login", next="/admin"))
    abort(404)
@APP.get("/admin")
def admin_console() -> Response:
    """Render the admin console page.

    The page shows overall request/token counters, per-upstream statistics
    (top 20), the 50 most recent calls and one editable form per upstream row.
    When ``ensure_admin_access()`` returns a response it is returned as-is.

    Optional ``msg`` / ``err`` query parameters are rendered (HTML-escaped) as
    banners at the top of the page.
    """
    guard = ensure_admin_access()
    if guard is not None:
        return guard
    conn = get_db()
    # api_key and force_parameter must be selected as well: the edit form below
    # renders both, and posting the form writes both back. Leaving them out of
    # the query made every form show them blank, so saving a row silently
    # erased its stored force_parameter.
    rows = conn.execute(
        """
SELECT id, name, group_name, model, base_url, api_key, force_parameter,
limit_type, limit_qty, limit_tokens,
used_cycle_qty, used_cycle_tokens, used_all_qty, used_all_tokens,
consecutive_failures, cooldown_until, last_used_at, enabled, updated_at
FROM upstreams
ORDER BY id ASC
"""
    ).fetchall()
    # Lifetime counters over the whole logs table.
    summary_row = conn.execute(
        """
SELECT
COUNT(*) AS req_total,
COALESCE(SUM(CASE WHEN status_code = 200 THEN 1 ELSE 0 END), 0) AS req_success,
COALESCE(SUM(CASE WHEN status_code IS NOT NULL AND status_code <> 200 THEN 1 ELSE 0 END), 0) AS req_fail,
COALESCE(SUM(used_tokens), 0) AS token_total
FROM logs
"""
    ).fetchone()
    # Same counters restricted to the last 24 hours (86400 seconds).
    day_since_ts = utc_now_ts() - 86400
    day_row = conn.execute(
        """
SELECT
COUNT(*) AS req_day,
COALESCE(SUM(used_tokens), 0) AS token_day
FROM logs
WHERE request_at >= ?
""",
        (day_since_ts,),
    ).fetchone()
    # LEFT JOIN keeps upstreams that have no log rows yet (their counters are 0).
    top_rows = conn.execute(
        """
SELECT
u.id AS upstream_id,
u.name AS name,
u.group_name AS group_name,
u.model AS model,
COUNT(l.id) AS req_total,
COALESCE(SUM(CASE WHEN l.status_code = 200 THEN 1 ELSE 0 END), 0) AS req_success,
COALESCE(SUM(CASE WHEN l.status_code IS NOT NULL AND l.status_code <> 200 THEN 1 ELSE 0 END), 0) AS req_fail,
COALESCE(SUM(l.used_tokens), 0) AS token_total,
MAX(l.request_at) AS last_at
FROM upstreams u
LEFT JOIN logs l ON l.upstream_id = u.id
GROUP BY u.id, u.name, u.group_name, u.model
ORDER BY req_total DESC, token_total DESC, u.id ASC
LIMIT 20
"""
    ).fetchall()
    # LEFT JOIN keeps log rows whose upstream no longer exists.
    recent_rows = conn.execute(
        """
SELECT
l.id AS log_id,
l.request_at AS request_at,
l.status_code AS status_code,
l.used_tokens AS used_tokens,
u.id AS upstream_id,
u.name AS name,
u.group_name AS group_name,
u.model AS model
FROM logs l
LEFT JOIN upstreams u ON u.id = l.upstream_id
ORDER BY l.id DESC
LIMIT 50
"""
    ).fetchall()
    # Banner text comes from the query string, so it is escaped before rendering.
    msg = request.args.get("msg", "")
    err = request.args.get("err", "")
    banner_html = ""
    if msg:
        banner_html += f"<p class='msg'>{html.escape(msg)}</p>"
    if err:
        banner_html += f"<p class='err'>{html.escape(err)}</p>"
    summary = sqlite_row_to_dict(summary_row) or {}
    day_summary = sqlite_row_to_dict(day_row) or {}
    # Per-upstream statistics table body.
    top_parts: list[str] = []
    for item in top_rows:
        row = sqlite_row_to_dict(item) or {}
        top_parts.append(
            f"<tr>"
            f"<td>{row.get('upstream_id')}</td>"
            f"<td>{html.escape(str(row.get('name') or ''))}</td>"
            f"<td>{html.escape(str(row.get('group_name') or ''))}</td>"
            f"<td>{html.escape(str(row.get('model') or ''))}</td>"
            f"<td>{row.get('req_total') or 0}</td>"
            f"<td>{row.get('req_success') or 0}</td>"
            f"<td>{row.get('req_fail') or 0}</td>"
            f"<td>{row.get('token_total') or 0}</td>"
            f"<td>{format_ts(row.get('last_at'))}</td>"
            f"</tr>"
        )
    top_html = "".join(top_parts) if top_parts else "<tr><td colspan='9'>暂无统计数据</td></tr>"
    # Recent-calls table body.
    recent_parts: list[str] = []
    for item in recent_rows:
        row = sqlite_row_to_dict(item) or {}
        recent_parts.append(
            f"<tr>"
            f"<td>{row.get('log_id') or '-'}</td>"
            f"<td>{format_ts(row.get('request_at'))}</td>"
            f"<td>{row.get('status_code') or '-'}</td>"
            f"<td>{row.get('used_tokens') or 0}</td>"
            f"<td>{row.get('upstream_id') or '-'}</td>"
            f"<td>{html.escape(str(row.get('name') or ''))}</td>"
            f"<td>{html.escape(str(row.get('group_name') or ''))}</td>"
            f"<td>{html.escape(str(row.get('model') or ''))}</td>"
            f"</tr>"
        )
    recent_html = "".join(recent_parts) if recent_parts else "<tr><td colspan='8'>暂无调用记录</td></tr>"
    # One edit form per upstream; the save/test/delete buttons post to separate routes.
    body_rows: list[str] = []
    for row in rows:
        row_dict = sqlite_row_to_dict(row) or {}
        enabled_checked = "checked" if int(row_dict.get("enabled") or 0) == 1 else ""
        options_html = render_limit_type_options(str(row_dict.get("limit_type") or "day"))
        body_rows.append(
            f"""
<tr>
<td>{row_dict.get("id")}</td>
<td colspan="11">
<form method="post" action="/admin/upstreams/{row_dict.get('id')}/update" class="row-form">
<label>中文别名</label><input name="name" value="{html.escape(str(row_dict.get('name') or ''))}" required />
<label>分组编号</label><input name="group_name" value="{html.escape(str(row_dict.get('group_name') or '1'))}" required />
<label>上游模型名</label><input name="model" value="{html.escape(str(row_dict.get('model') or ''))}" required />
<label>上游地址</label><input name="base_url" value="{html.escape(str(row_dict.get('base_url') or ''))}" required />
<label>API 密钥</label><input name="api_key" value="{html.escape(str(row_dict.get('api_key') or ''))}" required />
<label>强制参数(JSON)</label><textarea name="force_parameter">{html.escape(str(row_dict.get('force_parameter') or ''))}</textarea>
<label>限额周期</label><select name="limit_type">{options_html}</select>
<label>请求限额</label><input name="limit_qty" value="{row_dict.get('limit_qty')}" />
<label>Token 限额</label><input name="limit_tokens" value="{row_dict.get('limit_tokens')}" />
<label class="check"><input type="checkbox" name="enabled" value="1" {enabled_checked} /> 启用</label>
<div class="stats">
周期请求: {row_dict.get('used_cycle_qty')} |
周期Token: {row_dict.get('used_cycle_tokens')} |
总请求: {row_dict.get('used_all_qty')} |
总Token: {row_dict.get('used_all_tokens')} |
连续失败: {row_dict.get('consecutive_failures')} |
冷却到: {format_ts(row_dict.get('cooldown_until'))} |
周期限额单位: {limit_type_display(str(row_dict.get('limit_type') or 'day'))}
</div>
<button type="submit" class="btn-save">保存修改</button>
<button type="submit" formaction="/admin/upstreams/{row_dict.get('id')}/test" formnovalidate class="btn-test">测试连通性</button>
<button type="submit" formaction="/admin/upstreams/{row_dict.get('id')}/delete" formnovalidate class="btn-del" onclick="return confirm('确认删除这条上游吗?');">删除</button>
</form>
</td>
</tr>
"""
        )
    body_html = "".join(body_rows) if body_rows else "<tr><td colspan='12'>暂无上游配置</td></tr>"
    # Doubled braces in the CSS below are literal braces inside the f-string.
    page = f"""<!doctype html>
<html lang="zh-CN">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<title>可视化管理后台</title>
<style>
body {{
margin: 0;
padding: 24px;
font-family: "Microsoft YaHei", "PingFang SC", sans-serif;
background: linear-gradient(180deg, #eef3ff 0%, #f7f9ff 48%, #f5f8ff 100%);
color: #21304b;
}}
h1 {{
margin: 0 0 8px 0;
font-size: 22px;
}}
.hint {{
margin: 0 0 14px 0;
color: #4e5f7a;
}}
.msg {{
background: #ecfdf3;
border: 1px solid #abefc6;
color: #067647;
padding: 10px 12px;
border-radius: 10px;
}}
.err {{
background: #fef3f2;
border: 1px solid #fecdca;
color: #b42318;
padding: 10px 12px;
border-radius: 10px;
}}
.panel {{
border: 1px solid #d5e1f5;
border-radius: 14px;
background: #fff;
padding: 16px;
margin-bottom: 16px;
box-shadow: 0 8px 20px rgba(32, 48, 75, 0.06);
}}
.stats-cards {{
display: grid;
grid-template-columns: repeat(auto-fit, minmax(180px, 1fr));
gap: 10px;
}}
.card {{
border: 1px solid #d8e3f5;
border-radius: 10px;
padding: 10px 12px;
background: #f8fbff;
}}
.card .num {{
margin-top: 6px;
font-size: 20px;
font-weight: 700;
color: #15428b;
}}
.form-grid {{
display: grid;
grid-template-columns: repeat(auto-fit, minmax(180px, 1fr));
gap: 8px;
}}
label {{
font-size: 12px;
font-weight: 700;
color: #486185;
display: block;
margin-top: 8px;
}}
input, textarea, select {{
width: 100%;
box-sizing: border-box;
padding: 8px 10px;
border-radius: 8px;
border: 1px solid #c8d6ee;
margin-top: 4px;
font-size: 13px;
background: #fff;
color: #21304b;
}}
textarea {{
min-height: 62px;
resize: vertical;
}}
.check {{
display: flex;
align-items: center;
gap: 8px;
margin-top: 18px;
font-size: 13px;
}}
.check input {{
width: auto;
margin: 0;
}}
.table-wrap {{
border: 1px solid #d5e1f5;
border-radius: 10px;
overflow: auto;
background: #fff;
box-shadow: inset 0 0 0 1px rgba(255,255,255,0.4);
}}
table {{
border-collapse: collapse;
width: 100%;
min-width: 1180px;
}}
th, td {{
border-bottom: 1px solid #edf2fa;
border-right: 1px solid #edf2fa;
padding: 10px 12px;
font-size: 13px;
text-align: left;
vertical-align: top;
word-break: break-word;
}}
th {{
position: sticky;
top: 0;
background: #eff5ff;
z-index: 1;
}}
th:last-child, td:last-child {{
border-right: none;
}}
tr:nth-child(even) td {{
background: #fbfdff;
}}
tr:hover td {{
background: #f1f6ff;
}}
.row-form {{
display: grid;
grid-template-columns: repeat(auto-fit, minmax(170px, 1fr));
gap: 6px 10px;
}}
.stats {{
grid-column: 1 / -1;
color: #5b708f;
font-size: 12px;
margin-top: 4px;
}}
.btn-save, .btn-test, .btn-del, .btn-add, .btn-link, .btn-clean {{
border: 0;
border-radius: 9px;
padding: 8px 12px;
font-weight: 700;
cursor: pointer;
margin-top: 12px;
color: white;
text-decoration: none;
display: inline-block;
width: fit-content;
}}
.btn-save {{
background: #2563eb;
}}
.btn-test {{
background: #0d9488;
}}
.btn-add {{
background: #0f766e;
}}
.btn-link {{
background: #475467;
}}
.btn-clean {{
background: #9f1239;
}}
.btn-del {{
background: #b42318;
}}
.btn-save:hover, .btn-test:hover, .btn-del:hover, .btn-add:hover, .btn-link:hover, .btn-clean:hover {{
filter: brightness(1.05);
transform: translateY(-1px);
transition: all .15s ease;
}}
.mini-form {{
display: flex;
align-items: center;
gap: 10px;
flex-wrap: wrap;
}}
.mini-form input {{
width: 120px;
margin-top: 0;
}}
</style>
</head>
<body>
<h1>上游可视化管理</h1>
<p class="hint">这里可以新增、编辑、启停、删除上游。建议给客户端使用中文别名(例如:通用聊天、快速模型)。</p>
{banner_html}
<div class="panel">
<h2>总览统计</h2>
<div class="stats-cards">
<div class="card"><div>总请求数</div><div class="num">{summary.get("req_total", 0)}</div></div>
<div class="card"><div>成功请求</div><div class="num">{summary.get("req_success", 0)}</div></div>
<div class="card"><div>失败请求</div><div class="num">{summary.get("req_fail", 0)}</div></div>
<div class="card"><div>总 Token</div><div class="num">{summary.get("token_total", 0)}</div></div>
<div class="card"><div>近24小时请求</div><div class="num">{day_summary.get("req_day", 0)}</div></div>
<div class="card"><div>近24小时 Token</div><div class="num">{day_summary.get("token_day", 0)}</div></div>
</div>
</div>
<div class="panel">
<h2>新增上游</h2>
<form method="post" action="/admin/logs/cleanup" class="mini-form">
<label style="margin:0;">日志清理(保留最近天数)</label>
<input type="number" name="days" min="1" max="3650" value="7" />
<button type="submit" class="btn-clean" onclick="return confirm('确认清理旧日志吗?');">清理日志</button>
</form>
<form method="post" action="/admin/upstreams/create" class="form-grid">
<div><label>中文别名</label><input name="name" placeholder="例如:通用聊天" required /></div>
<div><label>分组编号</label><input name="group_name" value="1" required /></div>
<div><label>上游模型名</label><input name="model" placeholder="例如gpt-4o-mini" required /></div>
<div><label>上游地址</label><input name="base_url" placeholder="https://xxx.com/v1" required /></div>
<div><label>API 密钥</label><input name="api_key" placeholder="sk-xxx" required /></div>
<div><label>强制参数(JSON)</label><textarea name="force_parameter" placeholder='例如:{{"temperature":0.7}}'></textarea></div>
<div><label>限额周期</label><select name="limit_type">{render_limit_type_options("day")}</select></div>
<div><label>请求限额(0=不限)</label><input name="limit_qty" value="0" /></div>
<div><label>Token 限额(0=不限)</label><input name="limit_tokens" value="0" /></div>
<div><label class="check"><input type="checkbox" name="enabled" value="1" checked /> 启用</label></div>
<div>
<button type="submit" class="btn-add">新增上游</button>
<a class="btn-link" href="/admin/stats">打开完整统计页</a>
<a class="btn-link" href="/admin/logout">退出管理</a>
</div>
</form>
</div>
<div class="panel">
<h2>上游统计 TOP 20</h2>
<div class="table-wrap">
<table>
<thead>
<tr>
<th>编号</th><th>中文别名</th><th>分组</th><th>模型</th><th>总请求</th><th>成功</th><th>失败</th><th>总Token</th><th>最后调用</th>
</tr>
</thead>
<tbody>{top_html}</tbody>
</table>
</div>
</div>
<div class="panel">
<h2>最近 50 次调用(可直接看到命中上游)</h2>
<div class="table-wrap">
<table>
<thead>
<tr>
<th>日志ID</th><th>时间</th><th>状态码</th><th>Token</th><th>上游ID</th><th>中文别名</th><th>分组</th><th>模型</th>
</tr>
</thead>
<tbody>{recent_html}</tbody>
</table>
</div>
</div>
<div class="table-wrap">
<table>
<thead><tr><th style="width:80px;">编号</th><th>上游配置</th></tr></thead>
<tbody>{body_html}</tbody>
</table>
</div>
</body>
</html>"""
    return Response(page, content_type="text/html; charset=utf-8")
@APP.get("/admin/stats")
def admin_stats() -> Response:
    """Render the full statistics page.

    Shows per-day request/token totals for roughly the last 14 days and the
    100 most recent log rows whose status code is not 200. When
    ``ensure_admin_access()`` returns a response it is returned as-is.
    """
    guard = ensure_admin_access()
    if guard is not None:
        return guard
    conn = get_db()
    # Days are bucketed in SQLite local time; the window is 14 * 86400 seconds back from now.
    window_start_ts = utc_now_ts() - 14 * 86400
    daily_rows = conn.execute(
        """
SELECT
strftime('%Y-%m-%d', datetime(request_at, 'unixepoch', 'localtime')) AS day_key,
COUNT(*) AS req_total,
COALESCE(SUM(CASE WHEN status_code = 200 THEN 1 ELSE 0 END), 0) AS req_success,
COALESCE(SUM(CASE WHEN status_code IS NOT NULL AND status_code <> 200 THEN 1 ELSE 0 END), 0) AS req_fail,
COALESCE(SUM(used_tokens), 0) AS token_total
FROM logs
WHERE request_at >= ?
GROUP BY day_key
ORDER BY day_key DESC
LIMIT 14
""",
        (window_start_ts,),
    ).fetchall()
    # LEFT JOIN keeps failed log rows whose upstream row no longer exists.
    error_rows = conn.execute(
        """
SELECT
l.req_id AS req_id,
l.status_code AS status_code,
l.error_text AS error_text,
l.finish_at AS finish_at,
u.name AS name,
u.group_name AS group_name,
u.model AS model
FROM logs l
LEFT JOIN upstreams u ON u.id = l.upstream_id
WHERE l.status_code IS NOT NULL AND l.status_code <> 200
ORDER BY l.id DESC
LIMIT 100
"""
    ).fetchall()
    # Every rendered row is a non-empty string, so an empty join means "no rows"
    # and the placeholder row is used instead.
    day_html = "".join(
        "<tr>"
        f"<td>{html.escape(str(row.get('day_key') or '-'))}</td>"
        f"<td>{row.get('req_total') or 0}</td>"
        f"<td>{row.get('req_success') or 0}</td>"
        f"<td>{row.get('req_fail') or 0}</td>"
        f"<td>{row.get('token_total') or 0}</td>"
        "</tr>"
        for row in (sqlite_row_to_dict(item) or {} for item in daily_rows)
    ) or "<tr><td colspan='5'>暂无数据</td></tr>"
    err_html = "".join(
        "<tr>"
        f"<td>{html.escape(str(row.get('req_id') or ''))}</td>"
        f"<td>{row.get('status_code') or '-'}</td>"
        f"<td>{html.escape(str(row.get('name') or ''))}</td>"
        f"<td>{html.escape(str(row.get('group_name') or ''))}</td>"
        f"<td>{html.escape(str(row.get('model') or ''))}</td>"
        f"<td>{format_ts(row.get('finish_at'))}</td>"
        f"<td>{html.escape(str(row.get('error_text') or ''))}</td>"
        "</tr>"
        for row in (sqlite_row_to_dict(item) or {} for item in error_rows)
    ) or "<tr><td colspan='7'>暂无失败日志</td></tr>"
    # Doubled braces in the CSS below are literal braces inside the f-string.
    page = f"""<!doctype html>
<html lang="zh-CN">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<title>完整统计</title>
<style>
body {{
margin: 0;
padding: 24px;
font-family: "Microsoft YaHei", "PingFang SC", sans-serif;
background: linear-gradient(180deg, #eef3ff 0%, #f7f9ff 48%, #f5f8ff 100%);
color: #21304b;
}}
h1 {{
margin: 0 0 10px;
font-size: 24px;
}}
.link {{
display: inline-block;
margin-bottom: 14px;
text-decoration: none;
color: #0b4ecf;
font-weight: 700;
background: #eef4ff;
border: 1px solid #cfe0ff;
border-radius: 10px;
padding: 8px 12px;
}}
.panel {{
border: 1px solid #d5e1f5;
border-radius: 12px;
background: #fff;
padding: 14px;
margin-bottom: 14px;
box-shadow: 0 8px 20px rgba(32, 48, 75, 0.06);
}}
table {{
border-collapse: collapse;
width: 100%;
}}
th, td {{
border-bottom: 1px solid #edf2fa;
border-right: 1px solid #edf2fa;
padding: 8px 10px;
font-size: 13px;
text-align: left;
word-break: break-word;
vertical-align: top;
}}
th {{
background: #eff5ff;
position: sticky;
top: 0;
}}
th:last-child, td:last-child {{
border-right: none;
}}
.table-wrap {{
border: 1px solid #d5e1f5;
border-radius: 10px;
overflow: auto;
background: #fff;
}}
tr:nth-child(even) td {{
background: #fbfdff;
}}
tr:hover td {{
background: #f1f6ff;
}}
</style>
</head>
<body>
<h1>完整统计</h1>
<a class="link" href="/admin">返回管理首页</a>
<div class="panel">
<h2>近 14 天按天统计</h2>
<div class="table-wrap">
<table>
<thead><tr><th>日期</th><th>总请求</th><th>成功</th><th>失败</th><th>总 Token</th></tr></thead>
<tbody>{day_html}</tbody>
</table>
</div>
</div>
<div class="panel">
<h2>最近 100 条失败日志</h2>
<div class="table-wrap">
<table>
<thead><tr><th>请求ID</th><th>状态码</th><th>中文别名</th><th>分组</th><th>模型</th><th>时间</th><th>错误</th></tr></thead>
<tbody>{err_html}</tbody>
</table>
</div>
</div>
</body>
</html>"""
    return Response(page, content_type="text/html; charset=utf-8")
@APP.post("/admin/logs/cleanup")
def admin_logs_cleanup() -> Response:
    """Delete log rows older than the requested number of days.

    The ``days`` form field falls back to 7 via ``to_form_int`` and is clamped
    to the range 1..3650. Redirects to the admin console with a message that
    reports how many rows were deleted.
    """
    guard = ensure_admin_access()
    if guard is not None:
        return guard
    requested_days = to_form_int(request.form.get("days"), 7)
    days = max(1, min(3650, requested_days))
    cutoff_ts = utc_now_ts() - days * 86400
    cursor = get_db().execute("DELETE FROM logs WHERE request_at < ?", (cutoff_ts,))
    # rowcount may be falsy; normalise it to a plain int for the message.
    deleted_count = int(cursor.rowcount or 0)
    summary = f"日志清理完成:删除 {deleted_count} 条(保留最近 {days} 天)"
    return redirect(url_for("admin_console", msg=summary))
@APP.post("/admin/upstreams/<int:upstream_id>/test")
def admin_upstream_test(upstream_id: int) -> Response:
    """Check whether one upstream is reachable and report the result as a banner.

    First issues a GET built by ``build_models_meta``. If that does not return
    200, a POST built by ``build_request_meta`` / ``build_admin_probe_payload``
    is tried. Every outcome redirects to the admin console with ``msg`` on
    success or ``err`` on failure.
    """
    guard = ensure_admin_access()
    if guard is not None:
        return guard
    record = get_db().execute("SELECT * FROM upstreams WHERE id = ?", (upstream_id,)).fetchone()
    upstream = sqlite_row_to_dict(record)
    if upstream is None:
        return redirect(url_for("admin_console", err=f"连通性测试失败:编号 {upstream_id} 不存在"))

    # Step 1: the /models endpoint. A network error is remembered, not raised.
    models_url, models_headers = build_models_meta(upstream)
    resp = None
    models_error = ""
    try:
        resp = HTTP_SESSION.get(
            models_url,
            headers=models_headers,
            timeout=(SETTINGS.connect_timeout_sec, min(20, SETTINGS.read_timeout_sec)),
        )
    except requests.RequestException as exc:
        models_error = str(exc)
    if resp is not None and resp.status_code == 200:
        return redirect(url_for("admin_console", msg=f"编号 {upstream_id} 连通性正常:/models 返回 200"))

    # Step 2: when /models fails, probe chat/completions with a minimal-token
    # request so an upstream without /models is not misjudged as down.
    probe_payload = build_admin_probe_payload(upstream)
    probe_url, probe_headers = build_request_meta(upstream)
    try:
        probe_resp = HTTP_SESSION.post(
            probe_url,
            headers=probe_headers,
            json=probe_payload,
            timeout=(SETTINGS.connect_timeout_sec, min(30, SETTINGS.read_timeout_sec)),
        )
    except requests.RequestException as exc:
        # The /models error, when there was one, takes precedence in the report.
        reason = models_error or str(exc)
        return redirect(url_for("admin_console", err=f"编号 {upstream_id} 连通性失败:{reason}"))
    if probe_resp.status_code == 200:
        return redirect(url_for("admin_console", msg=f"编号 {upstream_id} 连通性正常:/chat/completions 返回 200"))

    # Keep the banner short and single-line: first 180 characters, newlines flattened.
    detail = probe_resp.text[:180].replace("\n", " ").replace("\r", " ")
    if resp is None:
        detail = f"/models网络异常, /chat={probe_resp.status_code}, detail={detail}"
    else:
        detail = f"/models={resp.status_code}, /chat={probe_resp.status_code}, detail={detail}"
    return redirect(url_for("admin_console", err=f"编号 {upstream_id} 连通性失败:{detail}"))
@APP.post("/admin/upstreams/create")
def admin_upstream_create() -> Response:
    """Insert a new upstream row from the submitted admin form.

    ``name``, ``model``, ``base_url`` and ``api_key`` are required. Limits are
    read with ``to_form_int`` and floored at 0. Every outcome redirects to the
    admin console with ``msg`` on success or ``err`` on validation failure.
    """
    guard = ensure_admin_access()
    if guard is not None:
        return guard
    form = request.form
    name = (form.get("name") or "").strip()
    group_name = (form.get("group_name") or "1").strip()
    model = (form.get("model") or "").strip()
    base_url = (form.get("base_url") or "").strip()
    api_key = (form.get("api_key") or "").strip()
    limit_type = limit_type_alias(form.get("limit_type"))
    limit_qty = max(0, to_form_int(form.get("limit_qty"), 0))
    limit_tokens = max(0, to_form_int(form.get("limit_tokens"), 0))
    # An unchecked checkbox is simply absent from the form data.
    enabled = int((form.get("enabled") or "").strip().lower() in {"1", "on", "true"})
    ok_force, force_parameter = parse_force_parameter_text(form.get("force_parameter", ""))

    if not all((name, model, base_url, api_key)):
        return redirect(url_for("admin_console", err="新增失败:请把必填项填写完整"))
    if not ok_force:
        # On failure the second tuple element is shown as the error text when non-empty.
        return redirect(url_for("admin_console", err=force_parameter or "新增失败:强制参数格式错误"))

    now_ts = utc_now_ts()
    values = (
        name,
        group_name or "1",
        base_url,
        api_key,
        model,
        force_parameter,
        limit_type,
        limit_qty,
        limit_tokens,
        enabled,
        cycle_start_ts(limit_type, now_ts),
        now_ts,
        now_ts,
    )
    get_db().execute(
        """
INSERT INTO upstreams(
name, group_name, base_url, api_key, model, force_parameter,
limit_type, limit_qty, limit_tokens, enabled,
cycle_started_at, created_at, updated_at
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
        values,
    )
    return redirect(url_for("admin_console", msg="新增成功"))
@APP.post("/admin/upstreams/<int:upstream_id>/update")
def admin_upstream_update(upstream_id: int) -> Response:
    """Update one upstream row from the submitted admin form.

    ``name``, ``model``, ``base_url`` and ``api_key`` are required. Limits are
    read with ``to_form_int`` and floored at 0. ``cycle_started_at`` is only
    rewritten when the limit type actually changes; otherwise the stored value
    is left untouched. Every outcome redirects to the admin console with
    ``msg`` on success or ``err`` on failure.
    """
    guard = ensure_admin_access()
    if guard is not None:
        return guard
    name = (request.form.get("name") or "").strip()
    group_name = (request.form.get("group_name") or "1").strip()
    model = (request.form.get("model") or "").strip()
    base_url = (request.form.get("base_url") or "").strip()
    api_key = (request.form.get("api_key") or "").strip()
    limit_type = limit_type_alias(request.form.get("limit_type"))
    limit_qty = max(0, to_form_int(request.form.get("limit_qty"), 0))
    limit_tokens = max(0, to_form_int(request.form.get("limit_tokens"), 0))
    # An unchecked checkbox is simply absent from the form data.
    enabled = 1 if (request.form.get("enabled") or "").strip().lower() in {"1", "on", "true"} else 0
    ok_force, force_parameter = parse_force_parameter_text(request.form.get("force_parameter", ""))
    if not name or not model or not base_url or not api_key:
        return redirect(url_for("admin_console", err=f"保存失败:编号 {upstream_id} 的必填项不能为空"))
    if not ok_force:
        return redirect(url_for("admin_console", err=f"保存失败:编号 {upstream_id} 的强制参数格式错误"))
    conn = get_db()
    current = conn.execute("SELECT id, limit_type FROM upstreams WHERE id = ?", (upstream_id,)).fetchone()
    if current is None:
        return redirect(url_for("admin_console", err=f"保存失败:编号 {upstream_id} 不存在"))
    now_ts = utc_now_ts()
    old_limit_type = limit_type_alias(str(current["limit_type"] or "day"))

    # Both statements are fixed SQL text with placeholders only; nothing is
    # interpolated into the statement itself.
    params: list[Any] = [
        name,
        group_name or "1",
        base_url,
        api_key,
        model,
        force_parameter,
        limit_type,
        limit_qty,
        limit_tokens,
        enabled,
    ]
    if old_limit_type == limit_type:
        sql = """
            UPDATE upstreams
            SET name = ?, group_name = ?, base_url = ?, api_key = ?, model = ?, force_parameter = ?,
                limit_type = ?, limit_qty = ?, limit_tokens = ?, enabled = ?, updated_at = ?
            WHERE id = ?
        """
    else:
        # The limit type changed, so the cycle start is recomputed for the new type.
        sql = """
            UPDATE upstreams
            SET name = ?, group_name = ?, base_url = ?, api_key = ?, model = ?, force_parameter = ?,
                limit_type = ?, limit_qty = ?, limit_tokens = ?, enabled = ?, cycle_started_at = ?, updated_at = ?
            WHERE id = ?
        """
        params.append(cycle_start_ts(limit_type, now_ts))
    params.extend([now_ts, upstream_id])
    conn.execute(sql, params)
    return redirect(url_for("admin_console", msg=f"编号 {upstream_id} 已保存"))
@APP.post("/admin/upstreams/<int:upstream_id>/delete")
def admin_upstream_delete(upstream_id: int) -> Response:
    """Delete one upstream together with its log rows in a single transaction.

    Redirects to the admin console with ``msg`` on success, or with ``err``
    carrying the exception text when any statement fails.
    """
    guard = ensure_admin_access()
    if guard is not None:
        return guard
    conn = get_db()
    statements = (
        ("BEGIN IMMEDIATE", ()),
        ("DELETE FROM logs WHERE upstream_id = ?", (upstream_id,)),
        ("DELETE FROM upstreams WHERE id = ?", (upstream_id,)),
        ("COMMIT", ()),
    )
    try:
        for sql, args in statements:
            conn.execute(sql, args)
    except Exception as exc:
        # Best-effort rollback; a failing rollback must not hide the original error.
        try:
            conn.execute("ROLLBACK")
        except sqlite3.Error:
            pass
        return redirect(url_for("admin_console", err=f"删除失败:{exc}"))
    else:
        return redirect(url_for("admin_console", msg=f"编号 {upstream_id} 已删除(关联日志已同步清理)"))
@APP.get("/v1/models")
def list_models() -> Response | tuple[Response, int]:
    """Return the distinct models of all enabled upstreams as a model-list JSON document.

    Requests whose ``Authorization`` header fails ``auth_header_ok`` receive
    ``unauthorized_response()``. Each model's ``created`` value is the earliest
    ``created_at`` among its upstream rows, or the current time when that is NULL.
    """
    if not auth_header_ok(request.headers.get("Authorization", "")):
        return unauthorized_response()
    records = get_db().execute(
        """
SELECT model, MIN(created_at) AS created_at
FROM upstreams
WHERE enabled = 1 AND model IS NOT NULL AND TRIM(model) <> ''
GROUP BY model
ORDER BY model ASC
"""
    ).fetchall()
    fallback_created = utc_now_ts()
    models = []
    for record in records:
        model_id = str(record["model"]).strip()
        if model_id:
            created_raw = record["created_at"]
            models.append(
                {
                    "id": model_id,
                    "object": "model",
                    "created": fallback_created if created_raw is None else int(created_raw),
                    "owned_by": "openai-proxy",
                }
            )
    return jsonify({"object": "list", "data": models})
def forward_non_stream_request(
    conn: sqlite3.Connection,
    upstream: dict[str, Any],
    request_url: str,
    request_headers: dict[str, str],
    payload: dict[str, Any],
    log_id: str,
    tried_ids: list[int],
) -> tuple[bool, Response]:
    """Forward one non-streaming request to *upstream* and record the outcome.

    Returns ``(should_stop, response)``. ``should_stop`` is True when the
    response should go back to the client as-is: either success, or a failure
    that ``is_retryable_failure`` does not classify as retryable. It is False
    when the caller may fail over to another upstream.

    Network errors are logged with status 599 and surfaced to the client as 502.
    """
    upstream_id = int(upstream["id"])
    try:
        result = HTTP_SESSION.post(
            request_url,
            headers=request_headers,
            json=payload,
            timeout=(SETTINGS.connect_timeout_sec, SETTINGS.read_timeout_sec),
        )
    except requests.RequestException as exc:
        error_text = str(exc)
        mark_upstream_failure(conn, upstream_id)
        finish_request_log(conn, log_id, status_code=599, used_tokens=0, error_text=error_text)
        return False, attach_route_headers(upstream_error_response(error_text, status_code=502), upstream, tried_ids)
    status_code = int(result.status_code)
    content_type = result.headers.get("Content-Type") or "application/json; charset=utf-8"
    body_bytes = result.content
    body_text = result.text
    if status_code != 200:
        mark_upstream_failure(conn, upstream_id)
        finish_request_log(
            conn,
            log_id,
            status_code=status_code,
            used_tokens=0,
            finish_text=None,
            error_text=body_text,
        )
        response = attach_route_headers(Response(body_bytes, status=status_code, content_type=content_type), upstream, tried_ids)
        return (not is_retryable_failure(status_code, body_text)), response
    used_tokens = 0
    finish_text = ""
    try:
        result_json = result.json()
    except ValueError:
        # Body is not JSON: keep zero tokens and log the raw body text below.
        result_json = None
    # A 200 body can be valid JSON without being an object (list, string, number).
    # Only an object has "usage"; anything else is relayed untouched instead of
    # raising AttributeError after the upstream already succeeded.
    if isinstance(result_json, dict):
        used_tokens = extract_total_tokens(result_json.get("usage"))
        finish_text = extract_finish_text_from_response(result_json)
    mark_upstream_success(conn, upstream_id, used_tokens)
    finish_request_log(
        conn,
        log_id,
        status_code=status_code,
        used_tokens=used_tokens,
        finish_text=finish_text or body_text,
        error_text=None,
    )
    response = attach_route_headers(Response(body_bytes, status=status_code, content_type=content_type), upstream, tried_ids)
    return True, response
def finalize_stream_outcome(
    upstream_id: int,
    log_id: str,
    used_tokens: int,
    finish_text: str,
    stream_error: str | None,
) -> None:
    """Persist the final result of a streamed request.

    A truthy *stream_error* marks the upstream as failed and logs status 599
    with the error text; otherwise the upstream is marked successful and
    status 200 is logged. Any exception raised while writing is logged and
    swallowed so stream teardown never raises from here.
    """
    failed = bool(stream_error)
    # The request context may already be closed when a stream ends, so the
    # write must go through its own dedicated connection.
    conn: sqlite3.Connection | None = None
    try:
        conn = open_sqlite_conn(SETTINGS.sqlite_path)
        if failed:
            mark_upstream_failure(conn, upstream_id)
        else:
            mark_upstream_success(conn, upstream_id, used_tokens)
        finish_request_log(
            conn,
            log_id,
            status_code=599 if failed else 200,
            used_tokens=used_tokens,
            finish_text=finish_text,
            error_text=stream_error if failed else None,
        )
    except Exception:
        LOGGER.exception("流式请求收尾写库失败 upstream_id=%s log_id=%s", upstream_id, log_id)
    finally:
        if conn is not None:
            conn.close()
def stream_response_generator(
    upstream_id: int,
    log_id: str,
    upstream_resp: requests.Response,
) -> Iterator[str]:
    """Relay an upstream event stream line by line while collecting usage data.

    ``data:`` lines carrying a JSON object are parsed to pick up the total
    token count and the delta text, then re-serialised and yielded. Every
    other line is passed through unchanged. When iteration ends (normally, on
    error, or when the consumer closes the generator) the upstream response is
    closed and the outcome is persisted via ``finalize_stream_outcome``.

    Raises:
        requests.RequestException: re-raised after being recorded as the
            stream error.
    """
    used_tokens = 0
    finish_parts: list[str] = []
    stream_error: str | None = None
    try:
        # Iterate raw bytes and decode as UTF-8 here. With decode_unicode=True
        # requests uses its header-derived guess (ISO-8859-1 for text/* without
        # a charset), which garbles non-ASCII text, and yields undecoded bytes
        # when it has no guess at all. Event streams are defined as UTF-8.
        for raw_line in upstream_resp.iter_lines(decode_unicode=False):
            if raw_line is None:
                continue
            if isinstance(raw_line, bytes):
                line = raw_line.decode("utf-8", errors="replace")
            else:
                line = raw_line
            if line == "":
                yield "\n"
                continue
            if line.startswith("data:"):
                data = line[len("data:") :].lstrip()
                if data == "[DONE]":
                    yield "data: [DONE]\n\n"
                    continue
                try:
                    chunk_dict = json.loads(data)
                except json.JSONDecodeError:
                    yield line + "\n"
                    continue
                if not isinstance(chunk_dict, dict):
                    # Valid JSON but not an object (number, string, list):
                    # nothing to inspect, relay the line as it arrived.
                    yield line + "\n"
                    continue
                usage_tokens = extract_total_tokens(chunk_dict.get("usage"))
                if usage_tokens > 0:
                    used_tokens = usage_tokens
                delta_text = extract_delta_text_from_chunk(chunk_dict)
                if delta_text:
                    finish_parts.append(delta_text)
                yield "data: " + json.dumps(chunk_dict, ensure_ascii=False) + "\n\n"
                continue
            yield line + "\n"
    except requests.RequestException as exc:
        stream_error = str(exc)
        raise
    finally:
        upstream_resp.close()
        finish_text = "".join(finish_parts)
        finalize_stream_outcome(upstream_id, log_id, used_tokens, finish_text, stream_error)
def forward_stream_request(
    conn: sqlite3.Connection,
    upstream: dict[str, Any],
    request_url: str,
    request_headers: dict[str, str],
    payload: dict[str, Any],
    log_id: str,
    tried_ids: list[int],
) -> tuple[bool, Response]:
    """Open a streaming request to *upstream* and wrap it in a streaming response.

    Returns ``(should_stop, response)``. On HTTP 200 the body is relayed lazily
    through ``stream_response_generator`` and ``should_stop`` is True. A network
    error is logged with status 599, answered with 502 and never stops failover.
    Any other status is relayed as-is and stops failover only when
    ``is_retryable_failure`` says it is not retryable.
    """
    upstream_id = int(upstream["id"])
    timeouts = (SETTINGS.connect_timeout_sec, SETTINGS.stream_read_timeout_sec)
    try:
        upstream_resp = HTTP_SESSION.post(
            request_url,
            headers=request_headers,
            json=payload,
            stream=True,
            timeout=timeouts,
        )
    except requests.RequestException as exc:
        error_text = str(exc)
        mark_upstream_failure(conn, upstream_id)
        finish_request_log(conn, log_id, status_code=599, used_tokens=0, error_text=error_text)
        failure = upstream_error_response(error_text, status_code=502)
        return False, attach_route_headers(failure, upstream, tried_ids)

    status_code = int(upstream_resp.status_code)
    if status_code == 200:
        content_type = upstream_resp.headers.get("Content-Type") or "text/event-stream; charset=utf-8"
        # The generator persists the outcome itself once the stream finishes.
        stream_iter = stream_with_context(stream_response_generator(upstream_id, log_id, upstream_resp))
        streaming = Response(stream_iter, status=200, content_type=content_type)
        return True, attach_route_headers(streaming, upstream, tried_ids)

    # Non-200: read the whole error body, then release the connection.
    body_text = upstream_resp.text
    body_bytes = upstream_resp.content
    content_type = upstream_resp.headers.get("Content-Type") or "application/json; charset=utf-8"
    upstream_resp.close()
    mark_upstream_failure(conn, upstream_id)
    finish_request_log(
        conn,
        log_id,
        status_code=status_code,
        used_tokens=0,
        finish_text=None,
        error_text=body_text,
    )
    relayed = Response(body_bytes, status=status_code, content_type=content_type)
    response = attach_route_headers(relayed, upstream, tried_ids)
    return (not is_retryable_failure(status_code, body_text)), response
@APP.post("/v1/chat/completions")
def proxy_chat_completions() -> Response | tuple[Response, int]:
    """Route a chat-completions request to an upstream, failing over on retryable errors.

    Up to ``SETTINGS.max_failover_attempts`` upstreams are tried. The first
    response flagged as final is returned. If every attempt was retryable, the
    last such response is returned. If no upstream could be chosen at all, a
    429 routing error is returned.
    """
    if not auth_header_ok(request.headers.get("Authorization", "")):
        return unauthorized_response()
    user_post = request.get_json(silent=True)
    if not isinstance(user_post, dict):
        return jsonify({"error": {"message": "Invalid JSON body"}}), 400
    selector = parse_selector(str(user_post.get("model", "")))
    if selector is None:
        return jsonify({"error": {"message": "model is required"}}), 400

    conn = get_db()
    tried_ids: list[int] = []
    last_retryable_response: Response | None = None
    for _attempt in range(SETTINGS.max_failover_attempts):
        upstream = choose_upstream(conn, selector, tried_ids)
        if upstream is None:
            break
        upstream_id = int(upstream["id"])
        payload = merge_force_parameter(user_post, upstream)
        request_url, request_headers = build_request_meta(upstream)
        log_id = create_request_log(conn, upstream_id, payload)
        # The merged payload decides the mode, so forced parameters can switch it.
        if bool(payload.get("stream", False)):
            forward = forward_stream_request
        else:
            forward = forward_non_stream_request
        should_stop, response = forward(
            conn, upstream, request_url, request_headers, payload, log_id, tried_ids
        )
        if should_stop:
            return response
        # Record the failed upstream only after the attempt.
        tried_ids.append(upstream_id)
        last_retryable_response = response

    if last_retryable_response is not None:
        return last_retryable_response
    error_body = {
        "error": {
            "message": collect_no_available_reason(conn, selector),
            "type": "routing_error",
            "code": "no_available_model_endpoint",
        }
    }
    return jsonify(error_body), 429
def setup_logging() -> None:
    """Configure root logging at INFO level with a timestamp/level/name/message format."""
    log_format = "%(asctime)s %(levelname)s %(name)s %(message)s"
    logging.basicConfig(level=logging.INFO, format=log_format)
def parse_args() -> argparse.Namespace:
    """Parse command-line arguments; the only option is the ``--init-db`` flag."""
    cli = argparse.ArgumentParser(description="OpenAI compatible proxy with SQLite routing")
    cli.add_argument(
        "--init-db",
        action="store_true",
        help="Only initialize SQLite schema and exit",
    )
    return cli.parse_args()
def startup_checks() -> None:
    """Log a warning when the proxy still uses the built-in default bearer token."""
    if SETTINGS.expected_auth != DEFAULT_AUTH:
        return
    LOGGER.warning(
        "OPENAI_PROXY_AUTH is using default value '%s'. Please set a strong bearer token before exposing the service.",
        DEFAULT_AUTH,
    )
def main() -> None:
    """Entry point: set up logging, initialise the SQLite schema, then serve.

    With ``--init-db`` the function returns right after the schema is
    initialised, without starting the Flask server.
    """
    setup_logging()
    startup_checks()
    cli_args = parse_args()
    init_db(SETTINGS.sqlite_path)
    if not cli_args.init_db:
        APP.run(
            host=SETTINGS.listen_host,
            port=SETTINGS.listen_port,
            debug=SETTINGS.debug,
            threaded=True,
        )
        return
    LOGGER.info("SQLite schema initialized: %s", SETTINGS.sqlite_path)


if __name__ == "__main__":
    main()