From fa4e668d2e009728479e0050dda03b10b0934da9 Mon Sep 17 00:00:00 2001 From: laowang Date: Fri, 20 Mar 2026 21:01:58 +0800 Subject: [PATCH] list --- .env.example | 32 + README.md | 140 +++ openai_route.py | 2168 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 2340 insertions(+) create mode 100644 .env.example create mode 100644 README.md create mode 100644 openai_route.py diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..28bf20d --- /dev/null +++ b/.env.example @@ -0,0 +1,32 @@ +# 代理接口鉴权 建议直接填 token 本体 +# 例如 sk-abc123 +# 客户端可填 sk-abc123,系统会自动兼容 Bearer 格式 +OPENAI_PROXY_AUTH=请改成你的代理令牌 + +# 可视化后台管理员口令 访问 /admin/login 使用 +OPENAI_PROXY_ADMIN_TOKEN=请改成你的后台口令 + +# Flask 会话密钥 建议改成随机长字符串 +OPENAI_PROXY_SECRET_KEY=请改成随机长密钥 + +# 数据库文件路径 +OPENAI_PROXY_SQLITE_PATH=./data/openai_proxy.db + +# 监听地址和端口 +OPENAI_PROXY_LISTEN_HOST=0.0.0.0 +OPENAI_PROXY_LISTEN_PORT=8056 + +# 失败自动切换相关 +OPENAI_PROXY_MAX_FAILOVER_ATTEMPTS=5 +OPENAI_PROXY_FAILURE_COOLDOWN_SEC=300 +OPENAI_PROXY_AUTO_GROUP_FALLBACK=true + +# 后台开关 true 或 false +OPENAI_PROXY_ADMIN_ENABLED=true + +# 日志策略 +OPENAI_PROXY_LOG_FULL_PAYLOAD=false +OPENAI_PROXY_LOG_TEXT_LIMIT=100000 + +# 调试开关 生产建议 false +OPENAI_PROXY_DEBUG=false diff --git a/README.md b/README.md new file mode 100644 index 0000000..42fbbb7 --- /dev/null +++ b/README.md @@ -0,0 +1,140 @@ +# OpenAI 路由代理(SQLite + 可视化管理) + +这是一个面向个人和小团队的 OpenAI 兼容代理,重点是: + +1. 不依赖 MySQL,单文件 SQLite 即可运行。 +2. 一个统一入口接入多个上游 API。 +3. 上游失败时自动切换,减少手工改配置。 +4. 提供中文可视化后台,支持上游管理、统计查看、连通性测试、日志清理。 + +## 适用场景 + +1. 你有多个免费/限额 API key,想统一管理。 +2. 客户端(如 Cherry Studio / OpenClaw)只想配一个地址和一个 key。 +3. 希望某个上游挂了能自动切到下一个上游。 + +## 快速开始 + +1. 安装依赖 + +```bash +pip install flask requests +``` + +2. 复制配置模板 + +```powershell +Copy-Item .env.example .env +``` + +3. 打开 `.env`,至少修改这三项 + +1. `OPENAI_PROXY_AUTH`:客户端访问代理用的 key(建议填 token 本体)。 +2. `OPENAI_PROXY_ADMIN_TOKEN`:后台登录口令。 +3. `OPENAI_PROXY_SECRET_KEY`:Flask 会话密钥(建议随机长字符串)。 + +4. 初始化数据库(首次一次) + +```bash +python openai_route.py --init-db +``` + +5. 启动服务 + +```bash +python openai_route.py +``` + +启动后访问: + +1. 首页:`http://127.0.0.1:8056/` +2. 管理后台:`http://127.0.0.1:8056/admin` + +## 后台功能 + +后台首页 `/admin` 支持: + +1. 新增上游 +2. 编辑上游 +3. 启用/停用上游 +4. 删除上游(会自动清理关联日志) +5. 一键测试单条上游连通性(按钮:`测试连通性`) +6. 清理历史日志(按钮:`清理日志`) +7. 查看最近调用命中哪条上游(“最近 50 次调用”) + +统计页 `/admin/stats` 支持: + +1. 近 14 天请求统计 +2. 最近失败日志明细 + +## 字段说明(后台) + +1. 中文别名:客户端可直接填这个值作为 `model`。 +2. 分组编号:用于分组路由,支持数字和英文(例如 `1`、`nvidia_pool`)。 +3. 上游模型名:转发给上游时实际使用的模型名。 +4. 上游地址:上游 `.../v1` 地址。 +5. API 密钥:该上游 key。 +6. 强制参数(JSON):会覆盖或补充客户端请求参数。 +7. 请求限额:周期内请求数,`0` 表示不限额。 +8. Token 限额:周期内 token,`0` 表示不限额。 + +## 客户端如何填写 + +1. `base_url`:`http://127.0.0.1:8056/v1` +2. `api_key`:`.env` 里的 `OPENAI_PROXY_AUTH` +3. `model`:建议填“中文别名”或“分组编号” + +示例: + +1. 使用中文别名池:`model=通用聊天` +2. 使用数字分组:`model=1` +3. 使用英文分组:`model=group:nvidia_pool`(也支持直接填 `nvidia_pool`) + +## 路由识别规则 + +`model` 的识别顺序: + +1. 以 `group:` 开头:按分组路由 +2. 包含逗号:按多分组路由 +3. 纯数字:按分组路由 +4. 其他:先按“中文别名/上游模型名”精确匹配,未命中再按“分组名”匹配 + +## 自动切换规则 + +当命中池子后,系统会在可用上游中选择并自动切换: + +1. 默认“稳定优先”:按你配置顺序(上游编号从小到大)优先使用,不会每轮都轮换模型。 +2. 仅当本次请求失败时,才会在同一请求内自动切到下一个上游(无须客户端手工改模型)。 +3. 这些状态会触发自动切换:`401/403/404/408/409/425/429/500/502/503/504`。 +4. `400/422` 仅在错误文本明显属于“模型不存在/不可用”时才切换,避免误切换。 +5. 网络异常也会切换到下一个上游。 + +可通过响应头查看命中信息: + +1. `X-Route-Upstream-Id` +2. `X-Route-Group` +3. `X-Route-Model` +4. `X-Route-Tried-Ids` + +## 429 常见误解 + +很多客户端会把所有 `429` 翻译成“请求速率超过限制”,但不一定真是限速。 +请以后端返回的 `error.message` 为准,里面会给出详细原因(匹配数、可用数、冷却数、限额阻塞数)。 + +## 常用环境变量 + +1. `OPENAI_PROXY_AUTH` +2. `OPENAI_PROXY_ADMIN_TOKEN` +3. `OPENAI_PROXY_SECRET_KEY` +4. `OPENAI_PROXY_SQLITE_PATH` +5. `OPENAI_PROXY_MAX_FAILOVER_ATTEMPTS` +6. `OPENAI_PROXY_FAILURE_COOLDOWN_SEC` +7. `OPENAI_PROXY_AUTO_GROUP_FALLBACK` +8. `OPENAI_PROXY_ADMIN_ENABLED` + +## 生产建议 + +1. 务必修改默认口令。 +2. 后台建议仅内网可访问。 +3. 定期清理日志,避免 SQLite 持续膨胀。 +4. 若后续是多实例高并发,再迁移 MySQL/Postgres。 diff --git a/openai_route.py b/openai_route.py new file mode 100644 index 0000000..b97c094 --- /dev/null +++ b/openai_route.py @@ -0,0 +1,2168 @@ +from __future__ import annotations + +import argparse +import html +import json +import logging +import os +import sqlite3 +import time +import uuid +from dataclasses import dataclass +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Iterator + +import requests +from flask import Flask, Response, abort, g, jsonify, redirect, request, session, stream_with_context, url_for +from requests.adapters import HTTPAdapter + + +LOGGER = logging.getLogger("openai_proxy") +DEFAULT_AUTH = "Bearer change-me" +DEFAULT_SECRET_KEY = "openai-proxy-dev-secret-change-me" +RETRYABLE_STATUSES = {401, 403, 404, 408, 409, 425, 429, 500, 502, 503, 504} + + +def env_int(name: str, default: int) -> int: + raw = os.getenv(name) + if raw is None: + return default + try: + return int(raw.strip()) + except (TypeError, ValueError): + return default + + +def env_bool(name: str, default: bool) -> bool: + raw = os.getenv(name) + if raw is None: + return default + value = raw.strip().lower() + return value in {"1", "true", "yes", "on"} + + +def load_dotenv_file(path: Path) -> None: + if not path.exists() or not path.is_file(): + return + try: + lines = path.read_text(encoding="utf-8-sig").splitlines() + except OSError: + return + for line in lines: + text = line.strip() + if not text or text.startswith("#"): + continue + if text.startswith("export "): + text = text[7:].strip() + if "=" not in text: + continue + key, value = text.split("=", 1) + key = key.strip() + value = value.strip().strip("'").strip('"') + if not key: + continue + if key not in os.environ: + os.environ[key] = value + + +@dataclass(frozen=True) +class Settings: + expected_auth: str + sqlite_path: Path + busy_timeout_ms: int + listen_host: str + listen_port: int + connect_timeout_sec: int + read_timeout_sec: int + stream_read_timeout_sec: int + max_failover_attempts: int + auto_group_fallback: bool + failure_cooldown_sec: int + log_full_payload: bool + log_text_limit: int + admin_enabled: bool + admin_token: str + secret_key: str + debug: bool + + +def load_settings() -> Settings: + sqlite_path = Path(os.getenv("OPENAI_PROXY_SQLITE_PATH", "./data/openai_proxy.db")).expanduser() + return Settings( + expected_auth=os.getenv("OPENAI_PROXY_AUTH", DEFAULT_AUTH), + sqlite_path=sqlite_path, + busy_timeout_ms=env_int("OPENAI_PROXY_DB_BUSY_TIMEOUT_MS", 5000), + listen_host=os.getenv("OPENAI_PROXY_LISTEN_HOST", "0.0.0.0"), + listen_port=env_int("OPENAI_PROXY_LISTEN_PORT", 8056), + connect_timeout_sec=env_int("OPENAI_PROXY_CONNECT_TIMEOUT_SEC", 10), + read_timeout_sec=env_int("OPENAI_PROXY_READ_TIMEOUT_SEC", 120), + stream_read_timeout_sec=env_int("OPENAI_PROXY_STREAM_READ_TIMEOUT_SEC", 600), + max_failover_attempts=max(1, env_int("OPENAI_PROXY_MAX_FAILOVER_ATTEMPTS", 5)), + auto_group_fallback=env_bool("OPENAI_PROXY_AUTO_GROUP_FALLBACK", True), + failure_cooldown_sec=max(0, env_int("OPENAI_PROXY_FAILURE_COOLDOWN_SEC", 30)), + log_full_payload=env_bool("OPENAI_PROXY_LOG_FULL_PAYLOAD", False), + log_text_limit=max(256, env_int("OPENAI_PROXY_LOG_TEXT_LIMIT", 100000)), + admin_enabled=env_bool("OPENAI_PROXY_ADMIN_ENABLED", True), + admin_token=os.getenv("OPENAI_PROXY_ADMIN_TOKEN", "").strip(), + secret_key=os.getenv("OPENAI_PROXY_SECRET_KEY", DEFAULT_SECRET_KEY).strip() or DEFAULT_SECRET_KEY, + debug=env_bool("OPENAI_PROXY_DEBUG", False), + ) + + +load_dotenv_file(Path(os.getenv("OPENAI_PROXY_ENV_FILE", ".env"))) +SETTINGS = load_settings() +APP = Flask(__name__) +APP.config["SECRET_KEY"] = SETTINGS.secret_key +APP.config["SESSION_COOKIE_HTTPONLY"] = True +APP.config["SESSION_COOKIE_SAMESITE"] = "Lax" +HTTP_SESSION = requests.Session() +HTTP_SESSION.mount("http://", HTTPAdapter(pool_connections=50, pool_maxsize=50, max_retries=0)) +HTTP_SESSION.mount("https://", HTTPAdapter(pool_connections=50, pool_maxsize=50, max_retries=0)) + + +def utc_now_ts() -> int: + return int(time.time()) + + +def limit_type_alias(limit_type: str | None) -> str: + value = (limit_type or "").strip().lower() + alias_map = { + "miao": "second", + "fen": "minute", + "shi": "hour", + "tian": "day", + "yue": "month", + "nian": "year", + "sec": "second", + "min": "minute", + } + if value in {"second", "minute", "hour", "day", "month", "year"}: + return value + return alias_map.get(value, "day") + + +def cycle_start_ts(limit_type: str | None, ts: int) -> int: + unit = limit_type_alias(limit_type) + dt = datetime.fromtimestamp(ts, tz=timezone.utc) + if unit == "second": + return ts + if unit == "minute": + return int(dt.replace(second=0, microsecond=0).timestamp()) + if unit == "hour": + return int(dt.replace(minute=0, second=0, microsecond=0).timestamp()) + if unit == "day": + return int(dt.replace(hour=0, minute=0, second=0, microsecond=0).timestamp()) + if unit == "month": + return int(dt.replace(day=1, hour=0, minute=0, second=0, microsecond=0).timestamp()) + if unit == "year": + return int(dt.replace(month=1, day=1, hour=0, minute=0, second=0, microsecond=0).timestamp()) + return int(dt.replace(hour=0, minute=0, second=0, microsecond=0).timestamp()) + + +def open_sqlite_conn(db_path: Path) -> sqlite3.Connection: + db_path.parent.mkdir(parents=True, exist_ok=True) + conn = sqlite3.connect(str(db_path), timeout=SETTINGS.busy_timeout_ms / 1000, isolation_level=None) + conn.row_factory = sqlite3.Row + conn.execute("PRAGMA journal_mode=WAL;") + conn.execute(f"PRAGMA busy_timeout={SETTINGS.busy_timeout_ms};") + conn.execute("PRAGMA synchronous=NORMAL;") + conn.execute("PRAGMA foreign_keys=ON;") + return conn + + +def init_db(db_path: Path) -> None: + conn = open_sqlite_conn(db_path) + try: + conn.executescript( + """ + CREATE TABLE IF NOT EXISTS upstreams ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + group_name TEXT NOT NULL DEFAULT '1', + base_url TEXT NOT NULL, + api_key TEXT NOT NULL, + model TEXT NOT NULL, + force_parameter TEXT, + limit_type TEXT NOT NULL DEFAULT 'day', + limit_qty INTEGER NOT NULL DEFAULT 0, + limit_tokens INTEGER NOT NULL DEFAULT 0, + used_cycle_qty INTEGER NOT NULL DEFAULT 0, + used_cycle_tokens INTEGER NOT NULL DEFAULT 0, + used_all_qty INTEGER NOT NULL DEFAULT 0, + used_all_tokens INTEGER NOT NULL DEFAULT 0, + cycle_started_at INTEGER NOT NULL DEFAULT (unixepoch()), + last_used_at INTEGER, + consecutive_failures INTEGER NOT NULL DEFAULT 0, + cooldown_until INTEGER, + enabled INTEGER NOT NULL DEFAULT 1, + created_at INTEGER NOT NULL DEFAULT (unixepoch()), + updated_at INTEGER NOT NULL DEFAULT (unixepoch()) + ); + + CREATE INDEX IF NOT EXISTS idx_upstreams_enabled_model + ON upstreams(enabled, model, name, group_name); + + CREATE INDEX IF NOT EXISTS idx_upstreams_scheduling + ON upstreams(enabled, cooldown_until, consecutive_failures, last_used_at); + + CREATE TABLE IF NOT EXISTS logs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + req_id TEXT NOT NULL UNIQUE, + upstream_id INTEGER NOT NULL, + request_at INTEGER NOT NULL, + finish_at INTEGER, + status_code INTEGER, + used_tokens INTEGER NOT NULL DEFAULT 0, + request_payload TEXT, + finish_text TEXT, + error_text TEXT, + FOREIGN KEY (upstream_id) REFERENCES upstreams(id) + ); + + CREATE INDEX IF NOT EXISTS idx_logs_request_at + ON logs(request_at); + """ + ) + finally: + conn.close() + + +def get_db() -> sqlite3.Connection: + conn = getattr(g, "_sqlite_conn", None) + if conn is None: + conn = open_sqlite_conn(SETTINGS.sqlite_path) + g._sqlite_conn = conn + return conn + + +@APP.teardown_appcontext +def close_db(_: BaseException | None) -> None: + conn = getattr(g, "_sqlite_conn", None) + if conn is not None: + conn.close() + g._sqlite_conn = None + + +def sqlite_row_to_dict(row: sqlite3.Row | None) -> dict[str, Any] | None: + if row is None: + return None + return {key: row[key] for key in row.keys()} + + +def parse_selector(raw_model: str) -> dict[str, Any] | None: + value = (raw_model or "").strip()[:100] + if not value: + return None + if value.lower().startswith("group:"): + group_part = value[6:] + groups = [part.strip() for part in group_part.split(",") if part.strip()] + if not groups: + return None + return {"type": "groups", "groups": groups, "value": value} + if "," in value: + groups = [part.strip() for part in value.split(",") if part.strip()] + if not groups: + return None + return {"type": "groups", "groups": groups, "value": value} + if value.isdigit(): + return {"type": "groups", "groups": [value], "value": value} + return {"type": "exact", "model": value, "value": value} + + +def fetch_candidate_ids( + conn: sqlite3.Connection, selector: dict[str, Any], excluded_ids: list[int], now_ts: int +) -> list[int]: + params: list[Any] = [] + sql = "SELECT id FROM upstreams WHERE enabled = 1" + if excluded_ids: + placeholders = ",".join("?" for _ in excluded_ids) + sql += f" AND id NOT IN ({placeholders})" + params.extend(excluded_ids) + if selector["type"] == "groups": + groups = list(selector["groups"]) + if SETTINGS.auto_group_fallback and len(groups) == 1 and groups[0].strip().isdigit(): + start_group = int(groups[0].strip()) + group_rows = conn.execute("SELECT DISTINCT group_name FROM upstreams WHERE enabled = 1").fetchall() + numeric_groups = sorted( + { + int(str(row["group_name"]).strip()) + for row in group_rows + if str(row["group_name"]).strip().isdigit() + } + ) + expanded_groups = [str(item) for item in numeric_groups if item >= start_group] + if expanded_groups: + groups = expanded_groups + placeholders = ",".join("?" for _ in groups) + sql += f" AND group_name IN ({placeholders})" + params.extend(groups) + else: + sql += " AND (name = ? OR model = ?)" + params.append(selector["model"]) + params.append(selector["model"]) + sql += """ + ORDER BY + CASE WHEN COALESCE(cooldown_until, 0) <= ? THEN 0 ELSE 1 END ASC, + -- 稳定优先:默认按配置顺序(编号)选择,失败才在单次请求内切到下一个。 + id ASC + """ + params.append(now_ts) + rows = conn.execute(sql, params).fetchall() + return [int(row["id"]) for row in rows] + + +def selector_group_fallback(selector: dict[str, Any]) -> dict[str, Any] | None: + if selector.get("type") != "exact": + return None + raw = str(selector.get("model") or "").strip() + if not raw: + return None + return {"type": "groups", "groups": [raw], "value": f"group:{raw}"} + + +def evaluate_row_available_now(row: dict[str, Any], now_ts: int) -> tuple[bool, str]: + cooldown_until = int(row.get("cooldown_until") or 0) + if cooldown_until > now_ts: + return False, "cooldown" + + cycle_started_at = int(row.get("cycle_started_at") or 0) + target_cycle_start = cycle_start_ts(str(row.get("limit_type") or "day"), now_ts) + used_cycle_qty = int(row.get("used_cycle_qty") or 0) + used_cycle_tokens = int(row.get("used_cycle_tokens") or 0) + if cycle_started_at < target_cycle_start: + used_cycle_qty = 0 + used_cycle_tokens = 0 + + limit_qty = int(row.get("limit_qty") or 0) + limit_tokens = int(row.get("limit_tokens") or 0) + qty_ok = limit_qty <= 0 or used_cycle_qty < limit_qty + token_ok = limit_tokens <= 0 or used_cycle_tokens < limit_tokens + if qty_ok and token_ok: + return True, "ok" + return False, "quota" + + +def collect_no_available_reason(conn: sqlite3.Connection, selector: dict[str, Any]) -> str: + now_ts = utc_now_ts() + selectors = [selector] + fallback_selector = selector_group_fallback(selector) + if fallback_selector is not None: + selectors.append(fallback_selector) + + matched_ids: set[int] = set() + for item in selectors: + matched_ids.update(fetch_candidate_ids(conn, item, [], now_ts)) + + if not matched_ids: + model_value = str(selector.get("value") or selector.get("model") or "") + return f"No available model endpoint: 未匹配到可用上游,请检查 model/分组。当前 model={model_value}" + + placeholders = ",".join("?" for _ in matched_ids) + rows = conn.execute( + f""" + SELECT id, name, group_name, model, limit_type, limit_qty, limit_tokens, + used_cycle_qty, used_cycle_tokens, cycle_started_at, cooldown_until + FROM upstreams + WHERE id IN ({placeholders}) + """, + tuple(sorted(matched_ids)), + ).fetchall() + cooldown_count = 0 + quota_count = 0 + ok_count = 0 + for row in rows: + row_dict = sqlite_row_to_dict(row) or {} + available, reason = evaluate_row_available_now(row_dict, now_ts) + if available: + ok_count += 1 + elif reason == "cooldown": + cooldown_count += 1 + elif reason == "quota": + quota_count += 1 + + return ( + "No available model endpoint: " + f"匹配到 {len(matched_ids)} 条上游,当前可用 {ok_count} 条,冷却中 {cooldown_count} 条,限额阻塞 {quota_count} 条。" + ) + + +def refresh_cycle_if_needed(conn: sqlite3.Connection, row: dict[str, Any], now_ts: int) -> dict[str, Any]: + current_cycle_start = int(row.get("cycle_started_at") or 0) + target_cycle_start = cycle_start_ts(row.get("limit_type"), now_ts) + if current_cycle_start >= target_cycle_start: + return row + conn.execute( + """ + UPDATE upstreams + SET used_cycle_qty = 0, + used_cycle_tokens = 0, + cycle_started_at = ?, + updated_at = ? + WHERE id = ? + """, + (target_cycle_start, now_ts, row["id"]), + ) + row["used_cycle_qty"] = 0 + row["used_cycle_tokens"] = 0 + row["cycle_started_at"] = target_cycle_start + return row + + +def has_quota(row: dict[str, Any]) -> bool: + limit_qty = int(row.get("limit_qty") or 0) + limit_tokens = int(row.get("limit_tokens") or 0) + used_cycle_qty = int(row.get("used_cycle_qty") or 0) + used_cycle_tokens = int(row.get("used_cycle_tokens") or 0) + qty_ok = limit_qty <= 0 or used_cycle_qty < limit_qty + token_ok = limit_tokens <= 0 or used_cycle_tokens < limit_tokens + return qty_ok and token_ok + + +def reserve_upstream(conn: sqlite3.Connection, upstream_id: int, now_ts: int) -> dict[str, Any] | None: + try: + conn.execute("BEGIN IMMEDIATE") + row = sqlite_row_to_dict( + conn.execute("SELECT * FROM upstreams WHERE id = ? AND enabled = 1", (upstream_id,)).fetchone() + ) + if row is None: + conn.execute("ROLLBACK") + return None + row = refresh_cycle_if_needed(conn, row, now_ts) + cooldown_until = int(row.get("cooldown_until") or 0) + if cooldown_until > now_ts: + conn.execute("ROLLBACK") + return None + if not has_quota(row): + conn.execute("ROLLBACK") + return None + conn.execute( + """ + UPDATE upstreams + SET used_cycle_qty = used_cycle_qty + 1, + used_all_qty = used_all_qty + 1, + last_used_at = ?, + updated_at = ? + WHERE id = ? + """, + (now_ts, now_ts, upstream_id), + ) + conn.execute("COMMIT") + row["used_cycle_qty"] = int(row.get("used_cycle_qty") or 0) + 1 + row["used_all_qty"] = int(row.get("used_all_qty") or 0) + 1 + row["last_used_at"] = now_ts + return row + except sqlite3.OperationalError: + try: + conn.execute("ROLLBACK") + except sqlite3.Error: + pass + return None + except Exception: + try: + conn.execute("ROLLBACK") + except sqlite3.Error: + pass + raise + + +def mark_upstream_success(conn: sqlite3.Connection, upstream_id: int, used_tokens: int) -> None: + now_ts = utc_now_ts() + conn.execute( + """ + UPDATE upstreams + SET used_cycle_tokens = used_cycle_tokens + ?, + used_all_tokens = used_all_tokens + ?, + consecutive_failures = 0, + cooldown_until = NULL, + updated_at = ? + WHERE id = ? + """, + (used_tokens, used_tokens, now_ts, upstream_id), + ) + + +def mark_upstream_failure(conn: sqlite3.Connection, upstream_id: int) -> None: + now_ts = utc_now_ts() + cooldown_until = now_ts + SETTINGS.failure_cooldown_sec if SETTINGS.failure_cooldown_sec > 0 else now_ts + conn.execute( + """ + UPDATE upstreams + SET consecutive_failures = consecutive_failures + 1, + cooldown_until = ?, + updated_at = ? + WHERE id = ? + """, + (cooldown_until, now_ts, upstream_id), + ) + + +def is_retryable_status(status_code: int) -> bool: + return status_code in RETRYABLE_STATUSES + + +def is_retryable_failure(status_code: int, response_text: str) -> bool: + if is_retryable_status(status_code): + return True + if status_code not in {400, 422}: + return False + text = (response_text or "").lower() + # 对常见“模型不存在/模型不可用”类错误放行重试到下一个上游。 + model_hints = [ + "model", + "not found", + "no such model", + "unknown model", + "invalid model", + "model not", + "模型", + "不存在", + "不可用", + ] + return any(token in text for token in model_hints) + + +def parse_int_like(value: Any) -> int | None: + if isinstance(value, int): + return value + if isinstance(value, float): + return int(value) + if isinstance(value, str): + text = value.strip() + if not text: + return None + if text.isdigit() or (text.startswith("-") and text[1:].isdigit()): + try: + return int(text) + except ValueError: + return None + try: + return int(float(text)) + except ValueError: + return None + return None + + +def sanitize_request_payload(payload: dict[str, Any]) -> dict[str, Any]: + if SETTINGS.log_full_payload: + return payload + messages = payload.get("messages") + message_count = len(messages) if isinstance(messages, list) else 0 + return { + "model": payload.get("model"), + "stream": bool(payload.get("stream", False)), + "max_tokens": payload.get("max_tokens"), + "max_output_tokens": payload.get("max_output_tokens"), + "message_count": message_count, + } + + +def create_request_log(conn: sqlite3.Connection, upstream_id: int, payload: dict[str, Any]) -> str: + req_id = uuid.uuid4().hex + safe_payload = sanitize_request_payload(payload) + payload_json = json.dumps(safe_payload, ensure_ascii=False, allow_nan=False) + conn.execute( + """ + INSERT INTO logs(req_id, upstream_id, request_at, request_payload) + VALUES (?, ?, ?, ?) + """, + (req_id, upstream_id, utc_now_ts(), payload_json), + ) + return req_id + + +def finish_request_log( + conn: sqlite3.Connection, + req_id: str, + status_code: int, + used_tokens: int, + finish_text: str | None = None, + error_text: str | None = None, +) -> None: + trimmed_finish = (finish_text or "")[: SETTINGS.log_text_limit] + trimmed_error = (error_text or "")[: SETTINGS.log_text_limit] + conn.execute( + """ + UPDATE logs + SET finish_at = ?, + status_code = ?, + used_tokens = ?, + finish_text = ?, + error_text = ? + WHERE req_id = ? + """, + (utc_now_ts(), status_code, used_tokens, trimmed_finish, trimmed_error, req_id), + ) + + +def merge_force_parameter(payload: dict[str, Any], upstream: dict[str, Any]) -> dict[str, Any]: + merged = dict(payload) + merged["model"] = upstream.get("model") + force_raw = upstream.get("force_parameter") + if force_raw is None: + return merged + force_str = str(force_raw).strip() + if not force_str: + return merged + try: + parsed = json.loads(force_str) + except json.JSONDecodeError: + return merged + force_payload: dict[str, Any] = {} + if isinstance(parsed, dict): + force_payload = dict(parsed) + elif isinstance(parsed, list): + for item in parsed: + if isinstance(item, dict): + force_payload.update(item) + if force_payload: + merged.update(force_payload) + return merged + + +def build_request_meta(upstream: dict[str, Any]) -> tuple[str, dict[str, str]]: + base = str(upstream.get("base_url", "")).rstrip("/") + "/" + api_key = str(upstream.get("api_key", "")).strip() + auth_value = api_key if api_key.startswith("Bearer ") else f"Bearer {api_key}" + return ( + base + "chat/completions", + { + "Authorization": auth_value, + "Content-Type": "application/json", + }, + ) + + +def build_models_meta(upstream: dict[str, Any]) -> tuple[str, dict[str, str]]: + base = str(upstream.get("base_url", "")).rstrip("/") + "/" + api_key = str(upstream.get("api_key", "")).strip() + auth_value = api_key if api_key.startswith("Bearer ") else f"Bearer {api_key}" + return ( + base + "models", + { + "Authorization": auth_value, + }, + ) + + +def build_admin_probe_payload(upstream: dict[str, Any]) -> dict[str, Any]: + # 仅用于后台连通性探测,尽量减少消耗。 + payload = { + "model": str(upstream.get("model") or ""), + "messages": [{"role": "user", "content": "ping"}], + "stream": False, + "max_tokens": 1, + "temperature": 0, + } + return merge_force_parameter(payload, upstream) + + +def extract_total_tokens(usage: Any) -> int: + if not usage: + return 0 + if hasattr(usage, "model_dump"): + usage = usage.model_dump(mode="json") + if not isinstance(usage, dict): + return 0 + total_tokens = parse_int_like(usage.get("total_tokens")) + if isinstance(total_tokens, int) and total_tokens > 0: + return total_tokens + prompt_tokens = parse_int_like(usage.get("prompt_tokens")) + completion_tokens = parse_int_like(usage.get("completion_tokens")) + if isinstance(prompt_tokens, int) and isinstance(completion_tokens, int): + return prompt_tokens + completion_tokens + return 0 + + +def extract_finish_text_from_response(resp_json: Any) -> str: + if not isinstance(resp_json, dict): + return "" + choices = resp_json.get("choices") + if not isinstance(choices, list) or not choices: + return "" + first = choices[0] + if not isinstance(first, dict): + return "" + message = first.get("message") + if isinstance(message, dict): + content = message.get("content") + if isinstance(content, str): + return content + text = first.get("text") + if isinstance(text, str): + return text + return "" + + +def extract_delta_text_from_chunk(chunk_dict: Any) -> str: + if not isinstance(chunk_dict, dict): + return "" + choices = chunk_dict.get("choices") + if not isinstance(choices, list) or not choices: + return "" + first = choices[0] + if not isinstance(first, dict): + return "" + delta = first.get("delta") + if isinstance(delta, dict): + content = delta.get("content") + if isinstance(content, str): + return content + return "" + + +def unauthorized_response() -> tuple[Response, int]: + return jsonify({"error": {"message": "Unauthorized"}}), 401 + + +def upstream_error_response(message: str, status_code: int = 502) -> Response: + response = jsonify({"error": {"message": message}}) + response.status_code = status_code + return response + + +def auth_header_ok(auth_header: str) -> bool: + expected_raw = (SETTINGS.expected_auth or "").strip() + got_raw = (auth_header or "").strip() + if not expected_raw: + return False + if got_raw == expected_raw: + return True + # 允许 Bearer 前缀大小写或客户端仅传 token 本体,减少接入出错概率。 + return token_body(got_raw) == token_body(expected_raw) + + +def token_body(token: str) -> str: + text = (token or "").strip() + if text.lower().startswith("bearer "): + return text[7:].strip() + return text + + +def effective_admin_token() -> str: + if SETTINGS.admin_token: + return token_body(SETTINGS.admin_token) + return token_body(SETTINGS.expected_auth) + + +def admin_token_ok(user_input: str) -> bool: + expected = effective_admin_token() + if not expected: + return False + typed = token_body(user_input) + return typed == expected + + +def is_local_request() -> bool: + remote_addr = (request.remote_addr or "").strip() + return remote_addr in {"127.0.0.1", "::1", "localhost"} + + +def ensure_admin_access() -> Response | None: + if not SETTINGS.admin_enabled: + abort(404) + if session.get("admin_ok") is True: + return None + # 默认没有额外管理员口令时,允许本机直接访问可视化页面,方便个人使用。 + if not SETTINGS.admin_token and is_local_request(): + return None + return redirect(url_for("admin_login", next=request.path)) + + +def to_form_int(raw: str, default: int = 0) -> int: + try: + return int((raw or "").strip()) + except (TypeError, ValueError): + return default + + +def parse_force_parameter_text(raw_text: str) -> tuple[bool, str | None]: + text = (raw_text or "").strip() + if not text: + return True, None + try: + parsed = json.loads(text) + except json.JSONDecodeError: + return False, "强制参数不是合法 JSON" + if not isinstance(parsed, (dict, list)): + return False, "强制参数必须是 JSON 对象或数组" + return True, json.dumps(parsed, ensure_ascii=False) + + +def format_ts(ts_value: Any) -> str: + if ts_value is None: + return "-" + try: + ts_int = int(ts_value) + except (TypeError, ValueError): + return "-" + if ts_int <= 0: + return "-" + return datetime.fromtimestamp(ts_int).strftime("%Y-%m-%d %H:%M:%S") + + +def attach_route_headers(response: Response, upstream: dict[str, Any], tried_ids: list[int]) -> Response: + current_id = int(upstream.get("id") or 0) + chain_ids = list(tried_ids) + [current_id] + response.headers["X-Route-Upstream-Id"] = str(current_id) + response.headers["X-Route-Group"] = str(upstream.get("group_name") or "") + response.headers["X-Route-Model"] = str(upstream.get("model") or "") + response.headers["X-Route-Tried-Ids"] = ",".join(str(item) for item in chain_ids if int(item) > 0) + return response + + +def choose_upstream(conn: sqlite3.Connection, selector: dict[str, Any], tried_ids: list[int]) -> dict[str, Any] | None: + now_ts = utc_now_ts() + selectors = [selector] + fallback_selector = selector_group_fallback(selector) + if fallback_selector is not None: + selectors.append(fallback_selector) + checked_ids: set[int] = set() + for item in selectors: + candidate_ids = fetch_candidate_ids(conn, item, tried_ids, now_ts) + for candidate_id in candidate_ids: + if candidate_id in checked_ids: + continue + checked_ids.add(candidate_id) + reserved = reserve_upstream(conn, candidate_id, now_ts) + if reserved is not None: + return reserved + return None + + +@APP.get("/healthz") +def healthz() -> dict[str, str]: + return {"status": "ok"} + + +@APP.get("/") +def index_page() -> Response: + admin_text = "已开启" if SETTINGS.admin_enabled else "未开启" + page = f""" + + + + + OpenAI 路由服务 + + + +
+

OpenAI 路由服务

+

服务状态:正常

+

可视化管理:{admin_text}

+

接口入口:/v1/chat/completions/v1/models

+ 进入可视化管理 + 健康检查 +
+ +""" + return Response(page, content_type="text/html; charset=utf-8") + + +def limit_type_display(limit_type: str | None) -> str: + mapping = { + "second": "秒", + "minute": "分钟", + "hour": "小时", + "day": "天", + "month": "月", + "year": "年", + } + alias = limit_type_alias(limit_type) + return mapping.get(alias, "天") + + +def render_limit_type_options(selected: str | None) -> str: + choices = [ + ("second", "秒"), + ("minute", "分钟"), + ("hour", "小时"), + ("day", "天"), + ("month", "月"), + ("year", "年"), + ] + current = limit_type_alias(selected) + parts: list[str] = [] + for value, text in choices: + selected_attr = " selected" if current == value else "" + parts.append(f"") + return "".join(parts) + + +@APP.get("/admin/login") +def admin_login() -> Response: + if not SETTINGS.admin_enabled: + abort(404) + next_url = request.args.get("next", "/admin") + if not next_url.startswith("/"): + next_url = "/admin" + err = request.args.get("err", "") + err_html = f"

{html.escape(err)}

" if err else "" + page = f""" + + + + + 管理员登录 + + + +
+

管理后台登录

+

请输入管理员口令后继续。

+ {err_html} +
+ + + + +
+
+ +""" + return Response(page, content_type="text/html; charset=utf-8") + + +@APP.post("/admin/login") +def admin_login_submit() -> Response: + if not SETTINGS.admin_enabled: + abort(404) + next_url = request.form.get("next", "/admin") + if not next_url.startswith("/"): + next_url = "/admin" + token = request.form.get("token", "") + if admin_token_ok(token): + session["admin_ok"] = True + return redirect(next_url) + return redirect(url_for("admin_login", next=next_url, err="口令错误,请重试")) + + +@APP.get("/admin/logout") +def admin_logout() -> Response: + if not SETTINGS.admin_enabled: + abort(404) + session.pop("admin_ok", None) + return redirect(url_for("admin_login", next="/admin")) + + +@APP.get("/admin") +def admin_console() -> Response: + guard = ensure_admin_access() + if guard is not None: + return guard + conn = get_db() + rows = conn.execute( + """ + SELECT id, name, group_name, model, base_url, limit_type, limit_qty, limit_tokens, + used_cycle_qty, used_cycle_tokens, used_all_qty, used_all_tokens, + consecutive_failures, cooldown_until, last_used_at, enabled, updated_at + FROM upstreams + ORDER BY id ASC + """ + ).fetchall() + summary_row = conn.execute( + """ + SELECT + COUNT(*) AS req_total, + COALESCE(SUM(CASE WHEN status_code = 200 THEN 1 ELSE 0 END), 0) AS req_success, + COALESCE(SUM(CASE WHEN status_code IS NOT NULL AND status_code <> 200 THEN 1 ELSE 0 END), 0) AS req_fail, + COALESCE(SUM(used_tokens), 0) AS token_total + FROM logs + """ + ).fetchone() + day_since_ts = utc_now_ts() - 86400 + day_row = conn.execute( + """ + SELECT + COUNT(*) AS req_day, + COALESCE(SUM(used_tokens), 0) AS token_day + FROM logs + WHERE request_at >= ? + """, + (day_since_ts,), + ).fetchone() + top_rows = conn.execute( + """ + SELECT + u.id AS upstream_id, + u.name AS name, + u.group_name AS group_name, + u.model AS model, + COUNT(l.id) AS req_total, + COALESCE(SUM(CASE WHEN l.status_code = 200 THEN 1 ELSE 0 END), 0) AS req_success, + COALESCE(SUM(CASE WHEN l.status_code IS NOT NULL AND l.status_code <> 200 THEN 1 ELSE 0 END), 0) AS req_fail, + COALESCE(SUM(l.used_tokens), 0) AS token_total, + MAX(l.request_at) AS last_at + FROM upstreams u + LEFT JOIN logs l ON l.upstream_id = u.id + GROUP BY u.id, u.name, u.group_name, u.model + ORDER BY req_total DESC, token_total DESC, u.id ASC + LIMIT 20 + """ + ).fetchall() + recent_rows = conn.execute( + """ + SELECT + l.id AS log_id, + l.request_at AS request_at, + l.status_code AS status_code, + l.used_tokens AS used_tokens, + u.id AS upstream_id, + u.name AS name, + u.group_name AS group_name, + u.model AS model + FROM logs l + LEFT JOIN upstreams u ON u.id = l.upstream_id + ORDER BY l.id DESC + LIMIT 50 + """ + ).fetchall() + msg = request.args.get("msg", "") + err = request.args.get("err", "") + banner_html = "" + if msg: + banner_html += f"

{html.escape(msg)}

" + if err: + banner_html += f"

{html.escape(err)}

" + summary = sqlite_row_to_dict(summary_row) or {} + day_summary = sqlite_row_to_dict(day_row) or {} + top_parts: list[str] = [] + for item in top_rows: + row = sqlite_row_to_dict(item) or {} + top_parts.append( + f"" + f"{row.get('upstream_id')}" + f"{html.escape(str(row.get('name') or ''))}" + f"{html.escape(str(row.get('group_name') or ''))}" + f"{html.escape(str(row.get('model') or ''))}" + f"{row.get('req_total') or 0}" + f"{row.get('req_success') or 0}" + f"{row.get('req_fail') or 0}" + f"{row.get('token_total') or 0}" + f"{format_ts(row.get('last_at'))}" + f"" + ) + top_html = "".join(top_parts) if top_parts else "暂无统计数据" + recent_parts: list[str] = [] + for item in recent_rows: + row = sqlite_row_to_dict(item) or {} + recent_parts.append( + f"" + f"{row.get('log_id') or '-'}" + f"{format_ts(row.get('request_at'))}" + f"{row.get('status_code') or '-'}" + f"{row.get('used_tokens') or 0}" + f"{row.get('upstream_id') or '-'}" + f"{html.escape(str(row.get('name') or ''))}" + f"{html.escape(str(row.get('group_name') or ''))}" + f"{html.escape(str(row.get('model') or ''))}" + f"" + ) + recent_html = "".join(recent_parts) if recent_parts else "暂无调用记录" + body_rows: list[str] = [] + for row in rows: + row_dict = sqlite_row_to_dict(row) or {} + enabled_checked = "checked" if int(row_dict.get("enabled") or 0) == 1 else "" + options_html = render_limit_type_options(str(row_dict.get("limit_type") or "day")) + body_rows.append( + f""" + + {row_dict.get("id")} + +
+ + + + + + + + + + +
+ 周期请求: {row_dict.get('used_cycle_qty')} | + 周期Token: {row_dict.get('used_cycle_tokens')} | + 总请求: {row_dict.get('used_all_qty')} | + 总Token: {row_dict.get('used_all_tokens')} | + 连续失败: {row_dict.get('consecutive_failures')} | + 冷却到: {format_ts(row_dict.get('cooldown_until'))} | + 周期限额单位: {limit_type_display(str(row_dict.get('limit_type') or 'day'))} +
+ + + +
+ + + """ + ) + body_html = "".join(body_rows) if body_rows else "暂无上游配置" + page = f""" + + + + + 可视化管理后台 + + + +

上游可视化管理

+

这里可以新增、编辑、启停、删除上游。建议给客户端使用中文别名(例如:通用聊天、快速模型)。

+ {banner_html} +
+

总览统计

+
+
总请求数
{summary.get("req_total", 0)}
+
成功请求
{summary.get("req_success", 0)}
+
失败请求
{summary.get("req_fail", 0)}
+
总 Token
{summary.get("token_total", 0)}
+
近24小时请求
{day_summary.get("req_day", 0)}
+
近24小时 Token
{day_summary.get("token_day", 0)}
+
+
+
+

新增上游

+
+ + + +
+
+
+
+
+
+
+
+
+
+
+
+
+ + 打开完整统计页 + 退出管理 +
+
+
+
+

上游统计 TOP 20

+
+ + + + + + + {top_html} +
编号中文别名分组模型总请求成功失败总Token最后调用
+
+
+
+

最近 50 次调用(可直接看到命中上游)

+
+ + + + + + + {recent_html} +
日志ID时间状态码Token上游ID中文别名分组模型
+
+
+
+ + + {body_html} +
编号上游配置
+
+ +""" + return Response(page, content_type="text/html; charset=utf-8") + + +@APP.get("/admin/stats") +def admin_stats() -> Response: + guard = ensure_admin_access() + if guard is not None: + return guard + conn = get_db() + daily_rows = conn.execute( + """ + SELECT + strftime('%Y-%m-%d', datetime(request_at, 'unixepoch', 'localtime')) AS day_key, + COUNT(*) AS req_total, + COALESCE(SUM(CASE WHEN status_code = 200 THEN 1 ELSE 0 END), 0) AS req_success, + COALESCE(SUM(CASE WHEN status_code IS NOT NULL AND status_code <> 200 THEN 1 ELSE 0 END), 0) AS req_fail, + COALESCE(SUM(used_tokens), 0) AS token_total + FROM logs + WHERE request_at >= ? + GROUP BY day_key + ORDER BY day_key DESC + LIMIT 14 + """, + (utc_now_ts() - 14 * 86400,), + ).fetchall() + error_rows = conn.execute( + """ + SELECT + l.req_id AS req_id, + l.status_code AS status_code, + l.error_text AS error_text, + l.finish_at AS finish_at, + u.name AS name, + u.group_name AS group_name, + u.model AS model + FROM logs l + LEFT JOIN upstreams u ON u.id = l.upstream_id + WHERE l.status_code IS NOT NULL AND l.status_code <> 200 + ORDER BY l.id DESC + LIMIT 100 + """ + ).fetchall() + + day_parts: list[str] = [] + for item in daily_rows: + row = sqlite_row_to_dict(item) or {} + day_parts.append( + f"" + f"{html.escape(str(row.get('day_key') or '-'))}" + f"{row.get('req_total') or 0}" + f"{row.get('req_success') or 0}" + f"{row.get('req_fail') or 0}" + f"{row.get('token_total') or 0}" + f"" + ) + day_html = "".join(day_parts) if day_parts else "暂无数据" + + err_parts: list[str] = [] + for item in error_rows: + row = sqlite_row_to_dict(item) or {} + err_parts.append( + f"" + f"{html.escape(str(row.get('req_id') or ''))}" + f"{row.get('status_code') or '-'}" + f"{html.escape(str(row.get('name') or ''))}" + f"{html.escape(str(row.get('group_name') or ''))}" + f"{html.escape(str(row.get('model') or ''))}" + f"{format_ts(row.get('finish_at'))}" + f"{html.escape(str(row.get('error_text') or ''))}" + f"" + ) + err_html = "".join(err_parts) if err_parts else "暂无失败日志" + + page = f""" + + + + + 完整统计 + + + +

完整统计

+ 返回管理首页 +
+

近 14 天按天统计

+
+ + + {day_html} +
日期总请求成功失败总 Token
+
+
+
+

最近 100 条失败日志

+
+ + + {err_html} +
请求ID状态码中文别名分组模型时间错误
+
+
+ +""" + return Response(page, content_type="text/html; charset=utf-8") + + +@APP.post("/admin/logs/cleanup") +def admin_logs_cleanup() -> Response: + guard = ensure_admin_access() + if guard is not None: + return guard + days = max(1, min(3650, to_form_int(request.form.get("days"), 7))) + cutoff_ts = utc_now_ts() - days * 86400 + conn = get_db() + cursor = conn.execute("DELETE FROM logs WHERE request_at < ?", (cutoff_ts,)) + deleted_count = int(cursor.rowcount or 0) + return redirect(url_for("admin_console", msg=f"日志清理完成:删除 {deleted_count} 条(保留最近 {days} 天)")) + + +@APP.post("/admin/upstreams//test") +def admin_upstream_test(upstream_id: int) -> Response: + guard = ensure_admin_access() + if guard is not None: + return guard + conn = get_db() + row = conn.execute("SELECT * FROM upstreams WHERE id = ?", (upstream_id,)).fetchone() + upstream = sqlite_row_to_dict(row) + if upstream is None: + return redirect(url_for("admin_console", err=f"连通性测试失败:编号 {upstream_id} 不存在")) + + models_url, models_headers = build_models_meta(upstream) + try: + resp = HTTP_SESSION.get( + models_url, + headers=models_headers, + timeout=(SETTINGS.connect_timeout_sec, min(20, SETTINGS.read_timeout_sec)), + ) + except requests.RequestException as exc: + resp = None + models_error = str(exc) + else: + models_error = "" + + if resp is not None and resp.status_code == 200: + return redirect(url_for("admin_console", msg=f"编号 {upstream_id} 连通性正常:/models 返回 200")) + + # /models 失败时,再用极小 token 的 chat/completions 探测,避免误判。 + probe_payload = build_admin_probe_payload(upstream) + probe_url, probe_headers = build_request_meta(upstream) + try: + probe_resp = HTTP_SESSION.post( + probe_url, + headers=probe_headers, + json=probe_payload, + timeout=(SETTINGS.connect_timeout_sec, min(30, SETTINGS.read_timeout_sec)), + ) + except requests.RequestException as exc: + reason = models_error or str(exc) + return redirect(url_for("admin_console", err=f"编号 {upstream_id} 连通性失败:{reason}")) + + if probe_resp.status_code == 200: + return redirect(url_for("admin_console", msg=f"编号 {upstream_id} 连通性正常:/chat/completions 返回 200")) + + detail = probe_resp.text[:180].replace("\n", " ").replace("\r", " ") + if resp is not None: + detail = f"/models={resp.status_code}, /chat={probe_resp.status_code}, detail={detail}" + else: + detail = f"/models网络异常, /chat={probe_resp.status_code}, detail={detail}" + return redirect(url_for("admin_console", err=f"编号 {upstream_id} 连通性失败:{detail}")) + + +@APP.post("/admin/upstreams/create") +def admin_upstream_create() -> Response: + guard = ensure_admin_access() + if guard is not None: + return guard + name = (request.form.get("name") or "").strip() + group_name = (request.form.get("group_name") or "1").strip() + model = (request.form.get("model") or "").strip() + base_url = (request.form.get("base_url") or "").strip() + api_key = (request.form.get("api_key") or "").strip() + limit_type = limit_type_alias(request.form.get("limit_type")) + limit_qty = max(0, to_form_int(request.form.get("limit_qty"), 0)) + limit_tokens = max(0, to_form_int(request.form.get("limit_tokens"), 0)) + enabled = 1 if (request.form.get("enabled") or "").strip().lower() in {"1", "on", "true"} else 0 + ok_force, force_parameter = parse_force_parameter_text(request.form.get("force_parameter", "")) + + if not name or not model or not base_url or not api_key: + return redirect(url_for("admin_console", err="新增失败:请把必填项填写完整")) + if not ok_force: + return redirect(url_for("admin_console", err=force_parameter or "新增失败:强制参数格式错误")) + + now_ts = utc_now_ts() + conn = get_db() + conn.execute( + """ + INSERT INTO upstreams( + name, group_name, base_url, api_key, model, force_parameter, + limit_type, limit_qty, limit_tokens, enabled, + cycle_started_at, created_at, updated_at + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + name, + group_name or "1", + base_url, + api_key, + model, + force_parameter, + limit_type, + limit_qty, + limit_tokens, + enabled, + cycle_start_ts(limit_type, now_ts), + now_ts, + now_ts, + ), + ) + return redirect(url_for("admin_console", msg="新增成功")) + + +@APP.post("/admin/upstreams//update") +def admin_upstream_update(upstream_id: int) -> Response: + guard = ensure_admin_access() + if guard is not None: + return guard + name = (request.form.get("name") or "").strip() + group_name = (request.form.get("group_name") or "1").strip() + model = (request.form.get("model") or "").strip() + base_url = (request.form.get("base_url") or "").strip() + api_key = (request.form.get("api_key") or "").strip() + limit_type = limit_type_alias(request.form.get("limit_type")) + limit_qty = max(0, to_form_int(request.form.get("limit_qty"), 0)) + limit_tokens = max(0, to_form_int(request.form.get("limit_tokens"), 0)) + enabled = 1 if (request.form.get("enabled") or "").strip().lower() in {"1", "on", "true"} else 0 + ok_force, force_parameter = parse_force_parameter_text(request.form.get("force_parameter", "")) + + if not name or not model or not base_url or not api_key: + return redirect(url_for("admin_console", err=f"保存失败:编号 {upstream_id} 的必填项不能为空")) + if not ok_force: + return redirect(url_for("admin_console", err=f"保存失败:编号 {upstream_id} 的强制参数格式错误")) + + conn = get_db() + current = conn.execute("SELECT id, limit_type FROM upstreams WHERE id = ?", (upstream_id,)).fetchone() + if current is None: + return redirect(url_for("admin_console", err=f"保存失败:编号 {upstream_id} 不存在")) + now_ts = utc_now_ts() + next_cycle_start = cycle_start_ts(limit_type, now_ts) + old_limit_type = limit_type_alias(str(current["limit_type"] or "day")) + if old_limit_type == limit_type: + cycle_started_at_sql = "cycle_started_at" + cycle_started_at_value = None + else: + cycle_started_at_sql = "?" + cycle_started_at_value = next_cycle_start + + params: list[Any] = [ + name, + group_name or "1", + base_url, + api_key, + model, + force_parameter, + limit_type, + limit_qty, + limit_tokens, + enabled, + now_ts, + upstream_id, + ] + if cycle_started_at_value is None: + conn.execute( + """ + UPDATE upstreams + SET name = ?, group_name = ?, base_url = ?, api_key = ?, model = ?, force_parameter = ?, + limit_type = ?, limit_qty = ?, limit_tokens = ?, enabled = ?, updated_at = ? + WHERE id = ? + """, + params, + ) + else: + conn.execute( + f""" + UPDATE upstreams + SET name = ?, group_name = ?, base_url = ?, api_key = ?, model = ?, force_parameter = ?, + limit_type = ?, limit_qty = ?, limit_tokens = ?, enabled = ?, cycle_started_at = {cycle_started_at_sql}, updated_at = ? + WHERE id = ? + """, + [ + name, + group_name or "1", + base_url, + api_key, + model, + force_parameter, + limit_type, + limit_qty, + limit_tokens, + enabled, + next_cycle_start, + now_ts, + upstream_id, + ], + ) + + return redirect(url_for("admin_console", msg=f"编号 {upstream_id} 已保存")) + + +@APP.post("/admin/upstreams//delete") +def admin_upstream_delete(upstream_id: int) -> Response: + guard = ensure_admin_access() + if guard is not None: + return guard + conn = get_db() + try: + conn.execute("BEGIN IMMEDIATE") + conn.execute("DELETE FROM logs WHERE upstream_id = ?", (upstream_id,)) + conn.execute("DELETE FROM upstreams WHERE id = ?", (upstream_id,)) + conn.execute("COMMIT") + except Exception as exc: + try: + conn.execute("ROLLBACK") + except sqlite3.Error: + pass + return redirect(url_for("admin_console", err=f"删除失败:{exc}")) + return redirect(url_for("admin_console", msg=f"编号 {upstream_id} 已删除(关联日志已同步清理)")) + + +@APP.get("/v1/models") +def list_models() -> Response | tuple[Response, int]: + auth = request.headers.get("Authorization", "") + if not auth_header_ok(auth): + return unauthorized_response() + conn = get_db() + rows = conn.execute( + """ + SELECT model, MIN(created_at) AS created_at + FROM upstreams + WHERE enabled = 1 AND model IS NOT NULL AND TRIM(model) <> '' + GROUP BY model + ORDER BY model ASC + """ + ).fetchall() + data = [] + now_ts = utc_now_ts() + for row in rows: + model_id = str(row["model"]).strip() + if not model_id: + continue + created_at = int(row["created_at"]) if row["created_at"] is not None else now_ts + data.append( + { + "id": model_id, + "object": "model", + "created": created_at, + "owned_by": "openai-proxy", + } + ) + return jsonify({"object": "list", "data": data}) + + +def forward_non_stream_request( + conn: sqlite3.Connection, + upstream: dict[str, Any], + request_url: str, + request_headers: dict[str, str], + payload: dict[str, Any], + log_id: str, + tried_ids: list[int], +) -> tuple[bool, Response]: + upstream_id = int(upstream["id"]) + try: + result = HTTP_SESSION.post( + request_url, + headers=request_headers, + json=payload, + timeout=(SETTINGS.connect_timeout_sec, SETTINGS.read_timeout_sec), + ) + except requests.RequestException as exc: + error_text = str(exc) + mark_upstream_failure(conn, upstream_id) + finish_request_log(conn, log_id, status_code=599, used_tokens=0, error_text=error_text) + return False, attach_route_headers(upstream_error_response(error_text, status_code=502), upstream, tried_ids) + + status_code = int(result.status_code) + content_type = result.headers.get("Content-Type") or "application/json; charset=utf-8" + body_bytes = result.content + body_text = result.text + + if status_code != 200: + mark_upstream_failure(conn, upstream_id) + finish_request_log( + conn, + log_id, + status_code=status_code, + used_tokens=0, + finish_text=None, + error_text=body_text, + ) + response = attach_route_headers(Response(body_bytes, status=status_code, content_type=content_type), upstream, tried_ids) + return (not is_retryable_failure(status_code, body_text)), response + + used_tokens = 0 + finish_text = "" + try: + result_json = result.json() + used_tokens = extract_total_tokens(result_json.get("usage")) + finish_text = extract_finish_text_from_response(result_json) + except ValueError: + pass + + mark_upstream_success(conn, upstream_id, used_tokens) + finish_request_log( + conn, + log_id, + status_code=status_code, + used_tokens=used_tokens, + finish_text=finish_text or body_text, + error_text=None, + ) + response = attach_route_headers(Response(body_bytes, status=status_code, content_type=content_type), upstream, tried_ids) + return True, response + + +def finalize_stream_outcome( + upstream_id: int, + log_id: str, + used_tokens: int, + finish_text: str, + stream_error: str | None, +) -> None: + # 流式响应结束时,请求上下文可能已经关闭,必须使用独立连接落库。 + conn: sqlite3.Connection | None = None + try: + conn = open_sqlite_conn(SETTINGS.sqlite_path) + if stream_error: + mark_upstream_failure(conn, upstream_id) + finish_request_log( + conn, + log_id, + status_code=599, + used_tokens=used_tokens, + finish_text=finish_text, + error_text=stream_error, + ) + else: + mark_upstream_success(conn, upstream_id, used_tokens) + finish_request_log( + conn, + log_id, + status_code=200, + used_tokens=used_tokens, + finish_text=finish_text, + error_text=None, + ) + except Exception: + LOGGER.exception("流式请求收尾写库失败 upstream_id=%s log_id=%s", upstream_id, log_id) + finally: + if conn is not None: + conn.close() + + +def stream_response_generator( + upstream_id: int, + log_id: str, + upstream_resp: requests.Response, +) -> Iterator[str]: + used_tokens = 0 + finish_parts: list[str] = [] + stream_error: str | None = None + try: + for line in upstream_resp.iter_lines(decode_unicode=True): + if line is None: + continue + if line == "": + yield "\n" + continue + if line.startswith("data:"): + data = line[len("data:") :].lstrip() + if data == "[DONE]": + yield "data: [DONE]\n\n" + continue + try: + chunk_dict = json.loads(data) + except json.JSONDecodeError: + yield line + "\n" + continue + usage_tokens = extract_total_tokens(chunk_dict.get("usage")) + if usage_tokens > 0: + used_tokens = usage_tokens + delta_text = extract_delta_text_from_chunk(chunk_dict) + if delta_text: + finish_parts.append(delta_text) + yield "data: " + json.dumps(chunk_dict, ensure_ascii=False) + "\n\n" + continue + yield line + "\n" + except requests.RequestException as exc: + stream_error = str(exc) + raise + finally: + upstream_resp.close() + finish_text = "".join(finish_parts) + finalize_stream_outcome(upstream_id, log_id, used_tokens, finish_text, stream_error) + + +def forward_stream_request( + conn: sqlite3.Connection, + upstream: dict[str, Any], + request_url: str, + request_headers: dict[str, str], + payload: dict[str, Any], + log_id: str, + tried_ids: list[int], +) -> tuple[bool, Response]: + upstream_id = int(upstream["id"]) + try: + upstream_resp = HTTP_SESSION.post( + request_url, + headers=request_headers, + json=payload, + stream=True, + timeout=(SETTINGS.connect_timeout_sec, SETTINGS.stream_read_timeout_sec), + ) + except requests.RequestException as exc: + error_text = str(exc) + mark_upstream_failure(conn, upstream_id) + finish_request_log(conn, log_id, status_code=599, used_tokens=0, error_text=error_text) + return False, attach_route_headers(upstream_error_response(error_text, status_code=502), upstream, tried_ids) + + status_code = int(upstream_resp.status_code) + if status_code != 200: + body_text = upstream_resp.text + body_bytes = upstream_resp.content + content_type = upstream_resp.headers.get("Content-Type") or "application/json; charset=utf-8" + upstream_resp.close() + mark_upstream_failure(conn, upstream_id) + finish_request_log( + conn, + log_id, + status_code=status_code, + used_tokens=0, + finish_text=None, + error_text=body_text, + ) + response = attach_route_headers(Response(body_bytes, status=status_code, content_type=content_type), upstream, tried_ids) + return (not is_retryable_failure(status_code, body_text)), response + + content_type = upstream_resp.headers.get("Content-Type") or "text/event-stream; charset=utf-8" + stream_iter = stream_with_context(stream_response_generator(upstream_id, log_id, upstream_resp)) + return True, attach_route_headers(Response(stream_iter, status=200, content_type=content_type), upstream, tried_ids) + + +@APP.post("/v1/chat/completions") +def proxy_chat_completions() -> Response | tuple[Response, int]: + auth = request.headers.get("Authorization", "") + if not auth_header_ok(auth): + return unauthorized_response() + + user_post = request.get_json(silent=True) + if not isinstance(user_post, dict): + return jsonify({"error": {"message": "Invalid JSON body"}}), 400 + + selector = parse_selector(str(user_post.get("model", ""))) + if selector is None: + return jsonify({"error": {"message": "model is required"}}), 400 + + conn = get_db() + tried_ids: list[int] = [] + last_retryable_response: Response | None = None + + for _ in range(SETTINGS.max_failover_attempts): + upstream = choose_upstream(conn, selector, tried_ids) + if upstream is None: + break + upstream_id = int(upstream["id"]) + payload = merge_force_parameter(user_post, upstream) + request_url, request_headers = build_request_meta(upstream) + log_id = create_request_log(conn, upstream_id, payload) + + stream_mode = bool(payload.get("stream", False)) + if stream_mode: + should_stop, response = forward_stream_request( + conn, upstream, request_url, request_headers, payload, log_id, tried_ids + ) + else: + should_stop, response = forward_non_stream_request( + conn, upstream, request_url, request_headers, payload, log_id, tried_ids + ) + if should_stop: + return response + tried_ids.append(upstream_id) + last_retryable_response = response + + if last_retryable_response is not None: + return last_retryable_response + no_available_message = collect_no_available_reason(conn, selector) + return jsonify( + { + "error": { + "message": no_available_message, + "type": "routing_error", + "code": "no_available_model_endpoint", + } + } + ), 429 + + +def setup_logging() -> None: + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s %(name)s %(message)s", + ) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="OpenAI compatible proxy with SQLite routing") + parser.add_argument("--init-db", action="store_true", help="Only initialize SQLite schema and exit") + return parser.parse_args() + + +def startup_checks() -> None: + if SETTINGS.expected_auth == DEFAULT_AUTH: + LOGGER.warning( + "OPENAI_PROXY_AUTH is using default value '%s'. Please set a strong bearer token before exposing the service.", + DEFAULT_AUTH, + ) + + +def main() -> None: + setup_logging() + startup_checks() + args = parse_args() + init_db(SETTINGS.sqlite_path) + if args.init_db: + LOGGER.info("SQLite schema initialized: %s", SETTINGS.sqlite_path) + return + APP.run( + host=SETTINGS.listen_host, + port=SETTINGS.listen_port, + debug=SETTINGS.debug, + threaded=True, + ) + + +if __name__ == "__main__": + main()