from __future__ import annotations

import argparse
import html
import json
import logging
import os
import sqlite3
import time
import uuid
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Iterator

import requests
from flask import Flask, Response, abort, g, jsonify, redirect, request, session, stream_with_context, url_for
from requests.adapters import HTTPAdapter

LOGGER = logging.getLogger("openai_proxy")
DEFAULT_AUTH = "Bearer change-me"
DEFAULT_SECRET_KEY = "openai-proxy-dev-secret-change-me"
# Upstream statuses that trigger failover to the next candidate upstream.
RETRYABLE_STATUSES = {401, 403, 404, 408, 409, 425, 429, 500, 502, 503, 504}


def env_int(name: str, default: int) -> int:
    """Read an integer environment variable, falling back to `default` if unset or invalid."""
    raw = os.getenv(name)
    if raw is None:
        return default
    try:
        return int(raw.strip())
    except (TypeError, ValueError):
        return default


def env_bool(name: str, default: bool) -> bool:
    """Read a boolean environment variable; "1", "true", "yes" and "on" count as true."""
    raw = os.getenv(name)
    if raw is None:
        return default
    value = raw.strip().lower()
    return value in {"1", "true", "yes", "on"}


def load_dotenv_file(path: Path) -> None:
    """Load KEY=VALUE pairs from a dotenv file without overriding existing variables."""
    if not path.exists() or not path.is_file():
        return
    try:
        lines = path.read_text(encoding="utf-8-sig").splitlines()
    except OSError:
        return
    for line in lines:
        text = line.strip()
        if not text or text.startswith("#"):
            continue
        if text.startswith("export "):
            text = text[7:].strip()
        if "=" not in text:
            continue
        key, value = text.split("=", 1)
        key = key.strip()
        value = value.strip().strip("'").strip('"')
        if not key:
            continue
        if key not in os.environ:
            os.environ[key] = value


@dataclass(frozen=True)
class Settings:
    expected_auth: str
    sqlite_path: Path
    busy_timeout_ms: int
    listen_host: str
    listen_port: int
    connect_timeout_sec: int
    read_timeout_sec: int
    stream_read_timeout_sec: int
    max_failover_attempts: int
    auto_group_fallback: bool
    failure_cooldown_sec: int
    log_full_payload: bool
    log_text_limit: int
    admin_enabled: bool
    admin_token: str
    secret_key: str
    debug: bool


def load_settings() -> Settings:
    sqlite_path = Path(os.getenv("OPENAI_PROXY_SQLITE_PATH", "./data/openai_proxy.db")).expanduser()
    return Settings(
        expected_auth=os.getenv("OPENAI_PROXY_AUTH", DEFAULT_AUTH),
        sqlite_path=sqlite_path,
        busy_timeout_ms=env_int("OPENAI_PROXY_DB_BUSY_TIMEOUT_MS", 5000),
        listen_host=os.getenv("OPENAI_PROXY_LISTEN_HOST", "0.0.0.0"),
        listen_port=env_int("OPENAI_PROXY_LISTEN_PORT", 8056),
        connect_timeout_sec=env_int("OPENAI_PROXY_CONNECT_TIMEOUT_SEC", 10),
        read_timeout_sec=env_int("OPENAI_PROXY_READ_TIMEOUT_SEC", 120),
        stream_read_timeout_sec=env_int("OPENAI_PROXY_STREAM_READ_TIMEOUT_SEC", 600),
        max_failover_attempts=max(1, env_int("OPENAI_PROXY_MAX_FAILOVER_ATTEMPTS", 5)),
        auto_group_fallback=env_bool("OPENAI_PROXY_AUTO_GROUP_FALLBACK", True),
        failure_cooldown_sec=max(0, env_int("OPENAI_PROXY_FAILURE_COOLDOWN_SEC", 30)),
        log_full_payload=env_bool("OPENAI_PROXY_LOG_FULL_PAYLOAD", False),
        log_text_limit=max(256, env_int("OPENAI_PROXY_LOG_TEXT_LIMIT", 100000)),
        admin_enabled=env_bool("OPENAI_PROXY_ADMIN_ENABLED", True),
        admin_token=os.getenv("OPENAI_PROXY_ADMIN_TOKEN", "").strip(),
        secret_key=os.getenv("OPENAI_PROXY_SECRET_KEY", DEFAULT_SECRET_KEY).strip() or DEFAULT_SECRET_KEY,
        debug=env_bool("OPENAI_PROXY_DEBUG", False),
    )


load_dotenv_file(Path(os.getenv("OPENAI_PROXY_ENV_FILE", ".env")))
SETTINGS = load_settings()

APP = Flask(__name__)
APP.config["SECRET_KEY"] = SETTINGS.secret_key
APP.config["SESSION_COOKIE_HTTPONLY"] = True
APP.config["SESSION_COOKIE_SAMESITE"] = "Lax"

# Shared pooled HTTP session; retries are handled by the proxy's own failover logic.
HTTP_SESSION = requests.Session()
HTTP_SESSION.mount("http://", HTTPAdapter(pool_connections=50, pool_maxsize=50, max_retries=0))
HTTP_SESSION.mount("https://", HTTPAdapter(pool_connections=50, pool_maxsize=50, max_retries=0))


def utc_now_ts() -> int:
    return int(time.time())


def limit_type_alias(limit_type: str | None) -> str:
    """Normalize a limit-cycle unit; unknown values fall back to "day"."""
    value = (limit_type or "").strip().lower()
    # "miao", "fen", "shi", "tian", "yue" and "nian" are pinyin spellings of the units.
    alias_map = {
        "miao": "second",
        "fen": "minute",
        "shi": "hour",
        "tian": "day",
        "yue": "month",
        "nian": "year",
        "sec": "second",
        "min": "minute",
    }
    if value in {"second", "minute", "hour", "day", "month", "year"}:
        return value
    return alias_map.get(value, "day")


def cycle_start_ts(limit_type: str | None, ts: int) -> int:
    """Return the UTC timestamp at which the quota cycle containing `ts` started."""
    unit = limit_type_alias(limit_type)
    dt = datetime.fromtimestamp(ts, tz=timezone.utc)
    if unit == "second":
        return ts
    if unit == "minute":
        return int(dt.replace(second=0, microsecond=0).timestamp())
    if unit == "hour":
        return int(dt.replace(minute=0, second=0, microsecond=0).timestamp())
    if unit == "day":
        return int(dt.replace(hour=0, minute=0, second=0, microsecond=0).timestamp())
    if unit == "month":
        return int(dt.replace(day=1, hour=0, minute=0, second=0, microsecond=0).timestamp())
    if unit == "year":
        return int(dt.replace(month=1, day=1, hour=0, minute=0, second=0, microsecond=0).timestamp())
    return int(dt.replace(hour=0, minute=0, second=0, microsecond=0).timestamp())


def open_sqlite_conn(db_path: Path) -> sqlite3.Connection:
    db_path.parent.mkdir(parents=True, exist_ok=True)
    # isolation_level=None selects autocommit mode; transactions are opened explicitly where needed.
    conn = sqlite3.connect(str(db_path), timeout=SETTINGS.busy_timeout_ms / 1000, isolation_level=None)
    conn.row_factory = sqlite3.Row
    conn.execute("PRAGMA journal_mode=WAL;")
    conn.execute(f"PRAGMA busy_timeout={SETTINGS.busy_timeout_ms};")
    conn.execute("PRAGMA synchronous=NORMAL;")
    conn.execute("PRAGMA foreign_keys=ON;")
    return conn


def init_db(db_path: Path) -> None:
    conn = open_sqlite_conn(db_path)
    try:
        conn.executescript(
            """
            CREATE TABLE IF NOT EXISTS upstreams (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL,
                group_name TEXT NOT NULL DEFAULT '1',
                base_url TEXT NOT NULL,
                api_key TEXT NOT NULL,
                model TEXT NOT NULL,
                force_parameter TEXT,
                limit_type TEXT NOT NULL DEFAULT 'day',
                limit_qty INTEGER NOT NULL DEFAULT 0,
                limit_tokens INTEGER NOT NULL DEFAULT 0,
                used_cycle_qty INTEGER NOT NULL DEFAULT 0,
                used_cycle_tokens INTEGER NOT NULL DEFAULT 0,
                used_all_qty INTEGER NOT NULL DEFAULT 0,
                used_all_tokens INTEGER NOT NULL DEFAULT 0,
                cycle_started_at INTEGER NOT NULL DEFAULT (unixepoch()),
                last_used_at INTEGER,
                consecutive_failures INTEGER NOT NULL DEFAULT 0,
                cooldown_until INTEGER,
                enabled INTEGER NOT NULL DEFAULT 1,
                created_at INTEGER NOT NULL DEFAULT (unixepoch()),
                updated_at INTEGER NOT NULL DEFAULT (unixepoch())
            );
            CREATE INDEX IF NOT EXISTS idx_upstreams_enabled_model
                ON upstreams(enabled, model, name, group_name);
            CREATE INDEX IF NOT EXISTS idx_upstreams_scheduling
                ON upstreams(enabled, cooldown_until, consecutive_failures, last_used_at);
            CREATE TABLE IF NOT EXISTS logs (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                req_id TEXT NOT NULL UNIQUE,
                upstream_id INTEGER NOT NULL,
                request_at INTEGER NOT NULL,
                finish_at INTEGER,
                status_code INTEGER,
                used_tokens INTEGER NOT NULL DEFAULT 0,
                request_payload TEXT,
                finish_text TEXT,
                error_text TEXT,
                FOREIGN KEY (upstream_id) REFERENCES upstreams(id)
            );
            CREATE INDEX IF NOT EXISTS idx_logs_request_at ON logs(request_at);
            """
        )
    finally:
        conn.close()


def get_db() -> sqlite3.Connection:
    # One connection per Flask application context, created lazily.
    conn = getattr(g, "_sqlite_conn", None)
    if conn is None:
        conn = open_sqlite_conn(SETTINGS.sqlite_path)
        g._sqlite_conn = conn
    return conn


@APP.teardown_appcontext
def close_db(_: BaseException | None) -> None:
    conn = getattr(g, "_sqlite_conn", None)
    if conn is not None:
        conn.close()
        g._sqlite_conn = None


def sqlite_row_to_dict(row: sqlite3.Row | None) -> dict[str, Any] | None:
    if row is None:
        return None
    return {key: row[key] for key in row.keys()}


def parse_selector(raw_model: str) -> dict[str, Any] | None:
    # Selector forms: "group:a,b", a comma-separated group list, a bare numeric
    # group, or an exact upstream name/model.
    value = (raw_model or "").strip()[:100]
    if not value:
        return None
    if value.lower().startswith("group:"):
        group_part = value[6:]
        groups = [part.strip() for part in group_part.split(",") if part.strip()]
        if not groups:
            return None
        return {"type": "groups", "groups": groups, "value": value}
    if "," in value:
        groups = [part.strip() for part in value.split(",") if part.strip()]
        if not groups:
            return None
        return {"type": "groups", "groups": groups, "value": value}
    if value.isdigit():
        return {"type": "groups", "groups": [value], "value": value}
    return {"type": "exact", "model": value, "value": value}


def fetch_candidate_ids(
    conn: sqlite3.Connection, selector: dict[str, Any], excluded_ids: list[int], now_ts: int
) -> list[int]:
    params: list[Any] = []
    sql = "SELECT id FROM upstreams WHERE enabled = 1"
    if excluded_ids:
        placeholders = ",".join("?" for _ in excluded_ids)
        sql += f" AND id NOT IN ({placeholders})"
        params.extend(excluded_ids)
    if selector["type"] == "groups":
        groups = list(selector["groups"])
        # A single numeric group N expands to every enabled numeric group >= N.
        if SETTINGS.auto_group_fallback and len(groups) == 1 and groups[0].strip().isdigit():
            start_group = int(groups[0].strip())
            group_rows = conn.execute("SELECT DISTINCT group_name FROM upstreams WHERE enabled = 1").fetchall()
            numeric_groups = sorted(
                {
                    int(str(row["group_name"]).strip())
                    for row in group_rows
                    if str(row["group_name"]).strip().isdigit()
                }
            )
            expanded_groups = [str(item) for item in numeric_groups if item >= start_group]
            if expanded_groups:
                groups = expanded_groups
        placeholders = ",".join("?" for _ in groups)
        sql += f" AND group_name IN ({placeholders})"
        params.extend(groups)
    else:
        sql += " AND (name = ? OR model = ?)"
        params.append(selector["model"])
        params.append(selector["model"])
    sql += """
        ORDER BY
            CASE WHEN COALESCE(cooldown_until, 0) <= ?
            THEN 0 ELSE 1 END ASC,
            -- Stability first: pick upstreams in configured (id) order by default;
            -- switch to the next one within a single request only after a failure.
            id ASC
    """
    params.append(now_ts)
    rows = conn.execute(sql, params).fetchall()
    return [int(row["id"]) for row in rows]


def selector_group_fallback(selector: dict[str, Any]) -> dict[str, Any] | None:
    # An exact selector that matches nothing is retried as a group name.
    if selector.get("type") != "exact":
        return None
    raw = str(selector.get("model") or "").strip()
    if not raw:
        return None
    return {"type": "groups", "groups": [raw], "value": f"group:{raw}"}


def evaluate_row_available_now(row: dict[str, Any], now_ts: int) -> tuple[bool, str]:
    cooldown_until = int(row.get("cooldown_until") or 0)
    if cooldown_until > now_ts:
        return False, "cooldown"
    cycle_started_at = int(row.get("cycle_started_at") or 0)
    target_cycle_start = cycle_start_ts(str(row.get("limit_type") or "day"), now_ts)
    used_cycle_qty = int(row.get("used_cycle_qty") or 0)
    used_cycle_tokens = int(row.get("used_cycle_tokens") or 0)
    # Counters from an earlier cycle no longer count against the quota.
    if cycle_started_at < target_cycle_start:
        used_cycle_qty = 0
        used_cycle_tokens = 0
    limit_qty = int(row.get("limit_qty") or 0)
    limit_tokens = int(row.get("limit_tokens") or 0)
    # A limit of 0 (or less) means "unlimited".
    qty_ok = limit_qty <= 0 or used_cycle_qty < limit_qty
    token_ok = limit_tokens <= 0 or used_cycle_tokens < limit_tokens
    if qty_ok and token_ok:
        return True, "ok"
    return False, "quota"


def collect_no_available_reason(conn: sqlite3.Connection, selector: dict[str, Any]) -> str:
    now_ts = utc_now_ts()
    selectors = [selector]
    fallback_selector = selector_group_fallback(selector)
    if fallback_selector is not None:
        selectors.append(fallback_selector)
    matched_ids: set[int] = set()
    for item in selectors:
        matched_ids.update(fetch_candidate_ids(conn, item, [], now_ts))
    if not matched_ids:
        model_value = str(selector.get("value") or selector.get("model") or "")
        return (
            "No available model endpoint: no upstream matched; check the model/group. "
            f"Current model={model_value}"
        )
    placeholders = ",".join("?"
                            for _ in matched_ids)
    rows = conn.execute(
        f"""
        SELECT id, name, group_name, model, limit_type, limit_qty, limit_tokens,
               used_cycle_qty, used_cycle_tokens, cycle_started_at, cooldown_until
        FROM upstreams
        WHERE id IN ({placeholders})
        """,
        tuple(sorted(matched_ids)),
    ).fetchall()
    cooldown_count = 0
    quota_count = 0
    ok_count = 0
    for row in rows:
        row_dict = sqlite_row_to_dict(row) or {}
        available, reason = evaluate_row_available_now(row_dict, now_ts)
        if available:
            ok_count += 1
        elif reason == "cooldown":
            cooldown_count += 1
        elif reason == "quota":
            quota_count += 1
    return (
        "No available model endpoint: "
        f"matched {len(matched_ids)} upstream(s): {ok_count} available, "
        f"{cooldown_count} cooling down, {quota_count} blocked by quota."
    )


def refresh_cycle_if_needed(conn: sqlite3.Connection, row: dict[str, Any], now_ts: int) -> dict[str, Any]:
    current_cycle_start = int(row.get("cycle_started_at") or 0)
    target_cycle_start = cycle_start_ts(row.get("limit_type"), now_ts)
    if current_cycle_start >= target_cycle_start:
        return row
    conn.execute(
        """
        UPDATE upstreams
        SET used_cycle_qty = 0,
            used_cycle_tokens = 0,
            cycle_started_at = ?,
            updated_at = ?
        WHERE id = ?
        """,
        (target_cycle_start, now_ts, row["id"]),
    )
    row["used_cycle_qty"] = 0
    row["used_cycle_tokens"] = 0
    row["cycle_started_at"] = target_cycle_start
    return row


def has_quota(row: dict[str, Any]) -> bool:
    limit_qty = int(row.get("limit_qty") or 0)
    limit_tokens = int(row.get("limit_tokens") or 0)
    used_cycle_qty = int(row.get("used_cycle_qty") or 0)
    used_cycle_tokens = int(row.get("used_cycle_tokens") or 0)
    qty_ok = limit_qty <= 0 or used_cycle_qty < limit_qty
    token_ok = limit_tokens <= 0 or used_cycle_tokens < limit_tokens
    return qty_ok and token_ok


def reserve_upstream(conn: sqlite3.Connection, upstream_id: int, now_ts: int) -> dict[str, Any] | None:
    """Atomically check cooldown and quota, then count one request against the upstream."""
    try:
        # BEGIN IMMEDIATE takes the write lock up front so check-and-increment cannot race.
        conn.execute("BEGIN IMMEDIATE")
        row = sqlite_row_to_dict(
            conn.execute("SELECT * FROM upstreams WHERE id = ? AND enabled = 1", (upstream_id,)).fetchone()
        )
        if row is None:
            conn.execute("ROLLBACK")
            return None
        row = refresh_cycle_if_needed(conn, row, now_ts)
        cooldown_until = int(row.get("cooldown_until") or 0)
        if cooldown_until > now_ts:
            conn.execute("ROLLBACK")
            return None
        if not has_quota(row):
            conn.execute("ROLLBACK")
            return None
        conn.execute(
            """
            UPDATE upstreams
            SET used_cycle_qty = used_cycle_qty + 1,
                used_all_qty = used_all_qty + 1,
                last_used_at = ?,
                updated_at = ?
            WHERE id = ?
            """,
            (now_ts, now_ts, upstream_id),
        )
        conn.execute("COMMIT")
        row["used_cycle_qty"] = int(row.get("used_cycle_qty") or 0) + 1
        row["used_all_qty"] = int(row.get("used_all_qty") or 0) + 1
        row["last_used_at"] = now_ts
        return row
    except sqlite3.OperationalError:
        # Typically a lock timeout: treat this upstream as unavailable for now.
        try:
            conn.execute("ROLLBACK")
        except sqlite3.Error:
            pass
        return None
    except Exception:
        try:
            conn.execute("ROLLBACK")
        except sqlite3.Error:
            pass
        raise


def mark_upstream_success(conn: sqlite3.Connection, upstream_id: int, used_tokens: int) -> None:
    now_ts = utc_now_ts()
    conn.execute(
        """
        UPDATE upstreams
        SET used_cycle_tokens = used_cycle_tokens + ?,
            used_all_tokens = used_all_tokens + ?,
            consecutive_failures = 0,
            cooldown_until = NULL,
            updated_at = ?
        WHERE id = ?
        """,
        (used_tokens, used_tokens, now_ts, upstream_id),
    )


def mark_upstream_failure(conn: sqlite3.Connection, upstream_id: int) -> None:
    now_ts = utc_now_ts()
    cooldown_until = now_ts + SETTINGS.failure_cooldown_sec if SETTINGS.failure_cooldown_sec > 0 else now_ts
    conn.execute(
        """
        UPDATE upstreams
        SET consecutive_failures = consecutive_failures + 1,
            cooldown_until = ?,
            updated_at = ?
        WHERE id = ?
        """,
        (cooldown_until, now_ts, upstream_id),
    )


def is_retryable_status(status_code: int) -> bool:
    return status_code in RETRYABLE_STATUSES


def is_retryable_failure(status_code: int, response_text: str) -> bool:
    if is_retryable_status(status_code):
        return True
    if status_code not in {400, 422}:
        return False
    text = (response_text or "").lower()
    # Let common "model does not exist / model unavailable" errors fail over to the next upstream.
    # The Chinese entries are matched against upstream error bodies, so they stay untranslated.
    model_hints = [
        "model",
        "not found",
        "no such model",
        "unknown model",
        "invalid model",
        "model not",
        "模型",
        "不存在",
        "不可用",
    ]
    return any(token in text for token in model_hints)


def parse_int_like(value: Any) -> int | None:
    if isinstance(value, int):
        return value
    if isinstance(value, float):
        # int() raises ValueError for NaN and OverflowError for infinity.
        try:
            return int(value)
        except (ValueError, OverflowError):
            return None
    if isinstance(value, str):
        text = value.strip()
        if not text:
            return None
        if text.isdigit() or (text.startswith("-") and text[1:].isdigit()):
            try:
                return int(text)
            except ValueError:
                return None
        try:
            return int(float(text))
        except (ValueError, OverflowError):
            return None
    return None


def sanitize_request_payload(payload: dict[str, Any]) -> dict[str, Any]:
    # Unless full-payload logging is enabled, log only request metadata, never message content.
    if SETTINGS.log_full_payload:
        return payload
    messages = payload.get("messages")
    message_count = len(messages) if isinstance(messages, list) else 0
    return {
        "model": payload.get("model"),
        "stream": bool(payload.get("stream", False)),
        "max_tokens": payload.get("max_tokens"),
        "max_output_tokens": payload.get("max_output_tokens"),
        "message_count": message_count,
    }


def create_request_log(conn: sqlite3.Connection, upstream_id: int, payload: dict[str, Any]) -> str:
    req_id = uuid.uuid4().hex
    safe_payload = sanitize_request_payload(payload)
    payload_json = json.dumps(safe_payload, ensure_ascii=False, allow_nan=False)
    conn.execute(
        """
        INSERT INTO logs(req_id, upstream_id, request_at, request_payload)
        VALUES (?, ?, ?, ?)
        """,
        (req_id, upstream_id, utc_now_ts(), payload_json),
    )
    return req_id


def finish_request_log(
    conn: sqlite3.Connection,
    req_id: str,
    status_code: int,
    used_tokens: int,
    finish_text: str | None = None,
    error_text: str | None = None,
) -> None:
    trimmed_finish = (finish_text or "")[: SETTINGS.log_text_limit]
    trimmed_error = (error_text or "")[: SETTINGS.log_text_limit]
    conn.execute(
        """
        UPDATE logs
        SET finish_at = ?, status_code = ?, used_tokens = ?, finish_text = ?, error_text = ?
        WHERE req_id = ?
        """,
        (utc_now_ts(), status_code, used_tokens, trimmed_finish, trimmed_error, req_id),
    )


def merge_force_parameter(payload: dict[str, Any], upstream: dict[str, Any]) -> dict[str, Any]:
    """Rewrite the model to the upstream's own and overlay its forced parameters, if any."""
    merged = dict(payload)
    merged["model"] = upstream.get("model")
    force_raw = upstream.get("force_parameter")
    if force_raw is None:
        return merged
    force_str = str(force_raw).strip()
    if not force_str:
        return merged
    try:
        parsed = json.loads(force_str)
    except json.JSONDecodeError:
        return merged
    force_payload: dict[str, Any] = {}
    if isinstance(parsed, dict):
        force_payload = dict(parsed)
    elif isinstance(parsed, list):
        # A list of objects is merged in order; later entries win.
        for item in parsed:
            if isinstance(item, dict):
                force_payload.update(item)
    if force_payload:
        merged.update(force_payload)
    return merged


def build_request_meta(upstream: dict[str, Any]) -> tuple[str, dict[str, str]]:
    base = str(upstream.get("base_url", "")).rstrip("/") + "/"
    api_key = str(upstream.get("api_key", "")).strip()
    auth_value = api_key if api_key.startswith("Bearer ") else f"Bearer {api_key}"
    return (
        base + "chat/completions",
        {
            "Authorization": auth_value,
            "Content-Type": "application/json",
        },
    )


def build_models_meta(upstream: dict[str, Any]) -> tuple[str, dict[str, str]]:
    base = str(upstream.get("base_url", "")).rstrip("/") + "/"
    api_key = str(upstream.get("api_key", "")).strip()
    auth_value = api_key if api_key.startswith("Bearer ") else f"Bearer {api_key}"
    return (
        base + "models",
        {
            "Authorization": auth_value,
        },
    )


def build_admin_probe_payload(upstream: dict[str, Any]) -> dict[str, Any]:
    # Used only for admin connectivity probes; keeps token usage as low as possible.
    payload = {
        "model": str(upstream.get("model") or ""),
        "messages": [{"role": "user", "content": "ping"}],
        "stream": False,
        "max_tokens": 1,
        "temperature": 0,
    }
    return merge_force_parameter(payload, upstream)


def extract_total_tokens(usage: Any) -> int:
    if not usage:
        return 0
    if hasattr(usage, "model_dump"):
        usage = usage.model_dump(mode="json")
    if not isinstance(usage, dict):
        return 0
    total_tokens = parse_int_like(usage.get("total_tokens"))
    if isinstance(total_tokens, int) and total_tokens > 0:
        return total_tokens
    prompt_tokens = parse_int_like(usage.get("prompt_tokens"))
    completion_tokens = parse_int_like(usage.get("completion_tokens"))
    if isinstance(prompt_tokens, int) and isinstance(completion_tokens, int):
        return prompt_tokens + completion_tokens
    return 0


def extract_finish_text_from_response(resp_json: Any) -> str:
    if not isinstance(resp_json, dict):
        return ""
    choices = resp_json.get("choices")
    if not isinstance(choices, list) or not choices:
        return ""
    first = choices[0]
    if not isinstance(first, dict):
        return ""
    message = first.get("message")
    if isinstance(message, dict):
        content = message.get("content")
        if isinstance(content, str):
            return content
    text = first.get("text")
    if isinstance(text, str):
        return text
    return ""


def extract_delta_text_from_chunk(chunk_dict: Any) -> str:
    if not isinstance(chunk_dict, dict):
        return ""
    choices = chunk_dict.get("choices")
    if not isinstance(choices, list) or not choices:
        return ""
    first = choices[0]
    if not isinstance(first, dict):
        return ""
    delta = first.get("delta")
    if isinstance(delta, dict):
        content = delta.get("content")
        if isinstance(content, str):
            return content
    return ""


def unauthorized_response() -> tuple[Response, int]:
    return jsonify({"error": {"message": "Unauthorized"}}), 401


def upstream_error_response(message: str, status_code: int = 502) -> Response:
    response = jsonify({"error": {"message": message}})
    response.status_code = status_code
    return response


def auth_header_ok(auth_header: str) -> bool:
    expected_raw = (SETTINGS.expected_auth or "").strip()
    got_raw = (auth_header or "").strip()
    if not expected_raw:
        return False
    if got_raw == expected_raw:
        return True
    # Accept any casing of the "Bearer" prefix, or a bare token, to reduce client integration errors.
    return token_body(got_raw) == token_body(expected_raw)


def token_body(token: str) -> str:
    text = (token or "").strip()
    if text.lower().startswith("bearer "):
        return text[7:].strip()
    return text


def effective_admin_token() -> str:
    # Without a dedicated admin token, the proxy auth token doubles as the admin token.
    if SETTINGS.admin_token:
        return token_body(SETTINGS.admin_token)
    return token_body(SETTINGS.expected_auth)


def admin_token_ok(user_input: str) -> bool:
    expected = effective_admin_token()
    if not expected:
        return False
    typed = token_body(user_input)
    return typed == expected


def is_local_request() -> bool:
    remote_addr = (request.remote_addr or "").strip()
    return remote_addr in {"127.0.0.1", "::1", "localhost"}


def ensure_admin_access() -> Response | None:
    if not SETTINGS.admin_enabled:
        abort(404)
    if session.get("admin_ok") is True:
        return None
    # When no separate admin token is configured, local requests may open the
    # admin pages directly, which is convenient for personal use.
    if not SETTINGS.admin_token and is_local_request():
        return None
    return redirect(url_for("admin_login", next=request.path))


def to_form_int(raw: str, default: int = 0) -> int:
    try:
        return int((raw or "").strip())
    except (TypeError, ValueError):
        return default


def parse_force_parameter_text(raw_text: str) -> tuple[bool, str | None]:
    """Validate force-parameter JSON; return (ok, normalized JSON or error message)."""
    text = (raw_text or "").strip()
    if not text:
        return True, None
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        return False, "Force parameter is not valid JSON"
    if not isinstance(parsed, (dict, list)):
        return False, "Force parameter must be a JSON object or array"
    return True, json.dumps(parsed, ensure_ascii=False)


def format_ts(ts_value: Any) -> str:
    # Formats in the server's local time zone; "-" stands for a missing or invalid timestamp.
    if ts_value is None:
        return "-"
    try:
        ts_int = int(ts_value)
    except (TypeError, ValueError):
        return "-"
    if ts_int <= 0:
        return "-"
    return datetime.fromtimestamp(ts_int).strftime("%Y-%m-%d %H:%M:%S")


def attach_route_headers(response: Response, upstream: dict[str, Any], tried_ids:
                         list[int]) -> Response:
    # Expose which upstream served the request and which ones were tried before it.
    current_id = int(upstream.get("id") or 0)
    chain_ids = list(tried_ids) + [current_id]
    response.headers["X-Route-Upstream-Id"] = str(current_id)
    response.headers["X-Route-Group"] = str(upstream.get("group_name") or "")
    response.headers["X-Route-Model"] = str(upstream.get("model") or "")
    response.headers["X-Route-Tried-Ids"] = ",".join(str(item) for item in chain_ids if int(item) > 0)
    return response


def choose_upstream(conn: sqlite3.Connection, selector: dict[str, Any], tried_ids: list[int]) -> dict[str, Any] | None:
    """Reserve the first candidate upstream that is enabled, out of cooldown and within quota."""
    now_ts = utc_now_ts()
    selectors = [selector]
    fallback_selector = selector_group_fallback(selector)
    if fallback_selector is not None:
        selectors.append(fallback_selector)
    checked_ids: set[int] = set()
    for item in selectors:
        candidate_ids = fetch_candidate_ids(conn, item, tried_ids, now_ts)
        for candidate_id in candidate_ids:
            if candidate_id in checked_ids:
                continue
            checked_ids.add(candidate_id)
            reserved = reserve_upstream(conn, candidate_id, now_ts)
            if reserved is not None:
                return reserved
    return None


@APP.get("/healthz")
def healthz() -> dict[str, str]:
    return {"status": "ok"}


@APP.get("/")
def index_page() -> Response:
    admin_text = "enabled" if SETTINGS.admin_enabled else "disabled"
    page = f""" OpenAI Routing Service

OpenAI Routing Service

Service status: OK

Admin console: {admin_text}

API endpoints: /v1/chat/completions, /v1/models

Open admin console | Health check
    """
    return Response(page, content_type="text/html; charset=utf-8")


def limit_type_display(limit_type: str | None) -> str:
    mapping = {
        "second": "second",
        "minute": "minute",
        "hour": "hour",
        "day": "day",
        "month": "month",
        "year": "year",
    }
    alias = limit_type_alias(limit_type)
    return mapping.get(alias, "day")


def render_limit_type_options(selected: str | None) -> str:
    choices = [
        ("second", "second"),
        ("minute", "minute"),
        ("hour", "hour"),
        ("day", "day"),
        ("month", "month"),
        ("year", "year"),
    ]
    current = limit_type_alias(selected)
    parts: list[str] = []
    for value, text in choices:
        selected_attr = " selected" if current == value else ""
        parts.append(f'<option value="{value}"{selected_attr}>{text}</option>')
    return "".join(parts)


@APP.get("/admin/login")
def admin_login() -> Response:
    if not SETTINGS.admin_enabled:
        abort(404)
    next_url = request.args.get("next", "/admin")
    # Accept only same-site paths; "//host" and "/\host" are treated by browsers as external URLs.
    if not next_url.startswith("/") or next_url.startswith(("//", "/\\")):
        next_url = "/admin"
    err = request.args.get("err", "")
    err_html = f"{html.escape(err)}" if err else ""
    page = f""" Admin login

Admin console login

Enter the admin token to continue.

{err_html}
    """
    return Response(page, content_type="text/html; charset=utf-8")


@APP.post("/admin/login")
def admin_login_submit() -> Response:
    if not SETTINGS.admin_enabled:
        abort(404)
    next_url = request.form.get("next", "/admin")
    # Accept only same-site paths; "//host" and "/\host" are treated by browsers as external URLs.
    if not next_url.startswith("/") or next_url.startswith(("//", "/\\")):
        next_url = "/admin"
    token = request.form.get("token", "")
    if admin_token_ok(token):
        session["admin_ok"] = True
        return redirect(next_url)
    return redirect(url_for("admin_login", next=next_url, err="Incorrect token, please try again"))


@APP.get("/admin/logout")
def admin_logout() -> Response:
    if not SETTINGS.admin_enabled:
        abort(404)
    session.pop("admin_ok", None)
    return redirect(url_for("admin_login", next="/admin"))


@APP.get("/admin")
def admin_console() -> Response:
    guard = ensure_admin_access()
    if guard is not None:
        return guard
    conn = get_db()
    rows = conn.execute(
        """
        SELECT id, name, group_name, model, base_url, limit_type, limit_qty, limit_tokens,
               used_cycle_qty, used_cycle_tokens, used_all_qty, used_all_tokens,
               consecutive_failures, cooldown_until, last_used_at, enabled, updated_at
        FROM upstreams
        ORDER BY id ASC
        """
    ).fetchall()
    summary_row = conn.execute(
        """
        SELECT
            COUNT(*) AS req_total,
            COALESCE(SUM(CASE WHEN status_code = 200 THEN 1 ELSE 0 END), 0) AS req_success,
            COALESCE(SUM(CASE WHEN status_code IS NOT NULL AND status_code <> 200 THEN 1 ELSE 0 END), 0) AS req_fail,
            COALESCE(SUM(used_tokens), 0) AS token_total
        FROM logs
        """
    ).fetchone()
    day_since_ts = utc_now_ts() - 86400
    day_row = conn.execute(
        """
        SELECT COUNT(*) AS req_day, COALESCE(SUM(used_tokens), 0) AS token_day
        FROM logs
        WHERE request_at >= ?
        """,
        (day_since_ts,),
    ).fetchone()
    top_rows = conn.execute(
        """
        SELECT
            u.id AS upstream_id,
            u.name AS name,
            u.group_name AS group_name,
            u.model AS model,
            COUNT(l.id) AS req_total,
            COALESCE(SUM(CASE WHEN l.status_code = 200 THEN 1 ELSE 0 END), 0) AS req_success,
            COALESCE(SUM(CASE WHEN l.status_code IS NOT NULL AND l.status_code <> 200 THEN 1 ELSE 0 END), 0) AS req_fail,
            COALESCE(SUM(l.used_tokens), 0) AS token_total,
            MAX(l.request_at) AS last_at
        FROM upstreams u
        LEFT JOIN logs l ON l.upstream_id = u.id
        GROUP BY u.id, u.name, u.group_name, u.model
        ORDER BY req_total DESC, token_total DESC, u.id ASC
        LIMIT 20
        """
    ).fetchall()
    recent_rows = conn.execute(
        """
        SELECT
            l.id AS log_id,
            l.request_at AS request_at,
            l.status_code AS status_code,
            l.used_tokens AS used_tokens,
            u.id AS upstream_id,
            u.name AS name,
            u.group_name AS group_name,
            u.model AS model
        FROM logs l
        LEFT JOIN upstreams u ON u.id = l.upstream_id
        ORDER BY l.id DESC
        LIMIT 50
        """
    ).fetchall()
    msg = request.args.get("msg", "")
    err = request.args.get("err", "")
    banner_html = ""
    if msg:
        banner_html += f"{html.escape(msg)}"
    if err:
        banner_html += f"{html.escape(err)}"
    summary = sqlite_row_to_dict(summary_row) or {}
    day_summary = sqlite_row_to_dict(day_row) or {}
    top_parts: list[str] = []
    for item in top_rows:
        row = sqlite_row_to_dict(item) or {}
        top_parts.append(
            f"<tr>"
            f"<td>{row.get('upstream_id')}</td>"
            f"<td>{html.escape(str(row.get('name') or ''))}</td>"
            f"<td>{html.escape(str(row.get('group_name') or ''))}</td>"
            f"<td>{html.escape(str(row.get('model') or ''))}</td>"
            f"<td>{row.get('req_total') or 0}</td>"
            f"<td>{row.get('req_success') or 0}</td>"
            f"<td>{row.get('req_fail') or 0}</td>"
            f"<td>{row.get('token_total') or 0}</td>"
            f"<td>{format_ts(row.get('last_at'))}</td>"
            f"</tr>"
        )
    top_html = "".join(top_parts) if top_parts else "No statistics yet"
    recent_parts: list[str] = []
    for item in recent_rows:
        row = sqlite_row_to_dict(item) or {}
        recent_parts.append(
            f"<tr>"
            f"<td>{row.get('log_id') or '-'}</td>"
            f"<td>{format_ts(row.get('request_at'))}</td>"
            f"<td>{row.get('status_code') or '-'}</td>"
            f"<td>{row.get('used_tokens') or 0}</td>"
            f"<td>{row.get('upstream_id') or '-'}</td>"
            f"<td>{html.escape(str(row.get('name') or ''))}</td>"
            f"<td>{html.escape(str(row.get('group_name') or ''))}</td>"
            f"<td>{html.escape(str(row.get('model') or ''))}</td>"
            f"</tr>"
        )
    recent_html = "".join(recent_parts) if recent_parts else "No calls recorded yet"
    body_rows: list[str] = []
    for row in rows:
        row_dict = sqlite_row_to_dict(row) or {}
        enabled_checked = "checked" if int(row_dict.get("enabled") or 0) == 1 else ""
        options_html = render_limit_type_options(str(row_dict.get("limit_type") or "day"))
        body_rows.append(
            f""" {row_dict.get("id")}
Cycle requests: {row_dict.get('used_cycle_qty')} | Cycle tokens: {row_dict.get('used_cycle_tokens')} | Total requests: {row_dict.get('used_all_qty')} | Total tokens: {row_dict.get('used_all_tokens')} | Consecutive failures: {row_dict.get('consecutive_failures')} | Cooldown until: {format_ts(row_dict.get('cooldown_until'))} | Cycle limit unit: {limit_type_display(str(row_dict.get('limit_type') or 'day'))}
            """
        )
    body_html = "".join(body_rows) if body_rows else "No upstreams configured yet"
    page = f""" Admin console

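Upstream management

Add, edit, enable or disable, and delete upstreams here. Consider giving each upstream a readable alias for clients to use (for example: General Chat, Fast Model).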

{banner_html}

Overview

Total requests
{summary.get("req_total", 0)}
Successful requests
{summary.get("req_success", 0)}
Failed requests
{summary.get("req_fail", 0)}
Total tokens
{summary.get("token_total", 0)}
Requests (last 24 hours)
{day_summary.get("req_day", 0)}
Tokens (last 24 hours)
{day_summary.get("token_day", 0)}

Add upstream

Open full statistics page | Log out

Top 20 upstreams by usage

{top_html}
ID | Alias | Group | Model | Total requests | Succeeded | Failed | Total tokens | Last call

Last 50 calls (shows which upstream served each one)

{recent_html}
Log ID | Time | Status code | Tokens | Upstream ID | Alias | Group | Model
{body_html}
ID | Upstream configuration
    """
    return Response(page, content_type="text/html; charset=utf-8")


@APP.get("/admin/stats")
def admin_stats() -> Response:
    guard = ensure_admin_access()
    if guard is not None:
        return guard
    conn = get_db()
    # Days are bucketed in the server's local time zone.
    daily_rows = conn.execute(
        """
        SELECT
            strftime('%Y-%m-%d', datetime(request_at, 'unixepoch', 'localtime')) AS day_key,
            COUNT(*) AS req_total,
            COALESCE(SUM(CASE WHEN status_code = 200 THEN 1 ELSE 0 END), 0) AS req_success,
            COALESCE(SUM(CASE WHEN status_code IS NOT NULL AND status_code <> 200 THEN 1 ELSE 0 END), 0) AS req_fail,
            COALESCE(SUM(used_tokens), 0) AS token_total
        FROM logs
        WHERE request_at >= ?
        GROUP BY day_key
        ORDER BY day_key DESC
        LIMIT 14
        """,
        (utc_now_ts() - 14 * 86400,),
    ).fetchall()
    error_rows = conn.execute(
        """
        SELECT
            l.req_id AS req_id,
            l.status_code AS status_code,
            l.error_text AS error_text,
            l.finish_at AS finish_at,
            u.name AS name,
            u.group_name AS group_name,
            u.model AS model
        FROM logs l
        LEFT JOIN upstreams u ON u.id = l.upstream_id
        WHERE l.status_code IS NOT NULL AND l.status_code <> 200
        ORDER BY l.id DESC
        LIMIT 100
        """
    ).fetchall()
    day_parts: list[str] = []
    for item in daily_rows:
        row = sqlite_row_to_dict(item) or {}
        day_parts.append(
            f"<tr>"
            f"<td>{html.escape(str(row.get('day_key') or '-'))}</td>"
            f"<td>{row.get('req_total') or 0}</td>"
            f"<td>{row.get('req_success') or 0}</td>"
            f"<td>{row.get('req_fail') or 0}</td>"
            f"<td>{row.get('token_total') or 0}</td>"
            f"</tr>"
        )
    day_html = "".join(day_parts) if day_parts else "No data yet"
    err_parts: list[str] = []
    for item in error_rows:
        row = sqlite_row_to_dict(item) or {}
        err_parts.append(
            f"<tr>"
            f"<td>{html.escape(str(row.get('req_id') or ''))}</td>"
            f"<td>{row.get('status_code') or '-'}</td>"
            f"<td>{html.escape(str(row.get('name') or ''))}</td>"
            f"<td>{html.escape(str(row.get('group_name') or ''))}</td>"
            f"<td>{html.escape(str(row.get('model') or ''))}</td>"
            f"<td>{format_ts(row.get('finish_at'))}</td>"
            f"<td>{html.escape(str(row.get('error_text') or ''))}</td>"
            f"</tr>"
        )
    err_html = "".join(err_parts) if err_parts else "No failure logs yet"
    page = f""" Full statistics

Full statistics

Back to admin home

Daily statistics for the last 14 days

{day_html}
Date | Total requests | Succeeded | Failed | Total tokens

Last 100 failure logs

{err_html}
Request ID | Status code | Alias | Group | Model | Time | Error
    """
    return Response(page, content_type="text/html; charset=utf-8")


@APP.post("/admin/logs/cleanup")
def admin_logs_cleanup() -> Response:
    guard = ensure_admin_access()
    if guard is not None:
        return guard
    # Retention window is clamped to 1..3650 days, defaulting to 7.
    days = max(1, min(3650, to_form_int(request.form.get("days"), 7)))
    cutoff_ts = utc_now_ts() - days * 86400
    conn = get_db()
    cursor = conn.execute("DELETE FROM logs WHERE request_at < ?", (cutoff_ts,))
    deleted_count = int(cursor.rowcount or 0)
    return redirect(
        url_for(
            "admin_console",
            msg=f"Log cleanup finished: deleted {deleted_count} entries (kept the last {days} days)",
        )
    )


@APP.post("/admin/upstreams/<int:upstream_id>/test")
def admin_upstream_test(upstream_id: int) -> Response:
    guard = ensure_admin_access()
    if guard is not None:
        return guard
    conn = get_db()
    row = conn.execute("SELECT * FROM upstreams WHERE id = ?", (upstream_id,)).fetchone()
    upstream = sqlite_row_to_dict(row)
    if upstream is None:
        return redirect(
            url_for("admin_console", err=f"Connectivity test failed: upstream #{upstream_id} does not exist")
        )
    models_url, models_headers = build_models_meta(upstream)
    try:
        resp = HTTP_SESSION.get(
            models_url,
            headers=models_headers,
            timeout=(SETTINGS.connect_timeout_sec, min(20, SETTINGS.read_timeout_sec)),
        )
    except requests.RequestException as exc:
        resp = None
        models_error = str(exc)
    else:
        models_error = ""
    if resp is not None and resp.status_code == 200:
        return redirect(url_for("admin_console", msg=f"Upstream #{upstream_id} is reachable: /models returned 200"))
    # If /models fails, probe chat/completions with a minimal token budget to avoid a false negative.
    probe_payload = build_admin_probe_payload(upstream)
    probe_url, probe_headers = build_request_meta(upstream)
    try:
        probe_resp = HTTP_SESSION.post(
            probe_url,
            headers=probe_headers,
            json=probe_payload,
            timeout=(SETTINGS.connect_timeout_sec, min(30, SETTINGS.read_timeout_sec)),
        )
    except requests.RequestException as exc:
        reason = models_error or str(exc)
        return redirect(url_for("admin_console", err=f"Upstream #{upstream_id} connectivity check failed: {reason}"))
    if probe_resp.status_code == 200:
        return redirect(
            url_for("admin_console", msg=f"Upstream #{upstream_id} is reachable: /chat/completions returned 200")
        )
    detail = probe_resp.text[:180].replace("\n", " ").replace("\r", " ")
    if resp is not None:
        detail = f"/models={resp.status_code}, /chat={probe_resp.status_code}, detail={detail}"
    else:
        detail = f"/models network error, /chat={probe_resp.status_code}, detail={detail}"
    return redirect(url_for("admin_console", err=f"Upstream #{upstream_id} connectivity check failed: {detail}"))


@APP.post("/admin/upstreams/create")
def admin_upstream_create() -> Response:
    guard = ensure_admin_access()
    if guard is not None:
        return guard
    name = (request.form.get("name") or "").strip()
    group_name = (request.form.get("group_name") or "1").strip()
    model = (request.form.get("model") or "").strip()
    base_url = (request.form.get("base_url") or "").strip()
    api_key = (request.form.get("api_key") or "").strip()
    limit_type = limit_type_alias(request.form.get("limit_type"))
    limit_qty = max(0, to_form_int(request.form.get("limit_qty"), 0))
    limit_tokens = max(0, to_form_int(request.form.get("limit_tokens"), 0))
    enabled = 1 if (request.form.get("enabled") or "").strip().lower() in {"1", "on", "true"} else 0
    ok_force, force_parameter = parse_force_parameter_text(request.form.get("force_parameter", ""))
    if not name or not model or not base_url or not api_key:
        return redirect(url_for("admin_console", err="Create failed: please fill in all required fields"))
    if not ok_force:
        return redirect(url_for("admin_console", err=force_parameter or "Create failed: invalid force parameter format"))
    now_ts = utc_now_ts()
    conn = get_db()
    conn.execute(
        """
        INSERT INTO upstreams(
            name, group_name, base_url, api_key, model, force_parameter,
            limit_type, limit_qty, limit_tokens, enabled,
            cycle_started_at, created_at, updated_at
        )
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """,
        (
            name,
            group_name or "1",
            base_url,
            api_key,
            model,
            force_parameter,
            limit_type,
            limit_qty,
            limit_tokens,
            enabled,
            cycle_start_ts(limit_type, now_ts),
            now_ts,
            now_ts,
        ),
    )
    return redirect(url_for("admin_console", msg="Upstream created"))


@APP.post("/admin/upstreams/<int:upstream_id>/update")
def admin_upstream_update(upstream_id: int) -> Response:
    guard = ensure_admin_access()
    if guard is not None:
        return guard
    name = (request.form.get("name") or "").strip()
    group_name = (request.form.get("group_name") or "1").strip()
    model = (request.form.get("model") or "").strip()
    base_url = (request.form.get("base_url") or "").strip()
    api_key = (request.form.get("api_key") or "").strip()
    limit_type = limit_type_alias(request.form.get("limit_type"))
    limit_qty = max(0, to_form_int(request.form.get("limit_qty"), 0))
    limit_tokens = max(0, to_form_int(request.form.get("limit_tokens"), 0))
    enabled = 1 if (request.form.get("enabled") or "").strip().lower() in {"1", "on", "true"} else 0
    ok_force, force_parameter = parse_force_parameter_text(request.form.get("force_parameter", ""))
    if not name or not model or not base_url or not api_key:
        return redirect(
            url_for("admin_console", err=f"Save failed: required fields for upstream #{upstream_id} cannot be empty")
        )
    if not ok_force:
        return redirect(
            url_for("admin_console", err=f"Save failed: invalid force parameter format for upstream #{upstream_id}")
        )
    conn = get_db()
    current = conn.execute("SELECT id, limit_type FROM upstreams WHERE id = ?", (upstream_id,)).fetchone()
    if current is None:
        return redirect(url_for("admin_console", err=f"Save failed: upstream #{upstream_id} does not exist"))
    now_ts = utc_now_ts()
    next_cycle_start = cycle_start_ts(limit_type, now_ts)
    old_limit_type = limit_type_alias(str(current["limit_type"] or "day"))
    # Keep the current cycle start unless the limit unit changed.
    if old_limit_type == limit_type:
        cycle_started_at_sql = "cycle_started_at"
        cycle_started_at_value = None
    else:
        cycle_started_at_sql = "?"
        cycle_started_at_value = next_cycle_start

    params: list[Any] = [
        name,
        group_name or "1",
        base_url,
        api_key,
        model,
        force_parameter,
        limit_type,
        limit_qty,
        limit_tokens,
        enabled,
        now_ts,
        upstream_id,
    ]
    if cycle_started_at_value is None:
        conn.execute(
            """
            UPDATE upstreams
            SET name = ?, group_name = ?, base_url = ?, api_key = ?, model = ?,
                force_parameter = ?, limit_type = ?, limit_qty = ?, limit_tokens = ?,
                enabled = ?, updated_at = ?
            WHERE id = ?
            """,
            params,
        )
    else:
        conn.execute(
            f"""
            UPDATE upstreams
            SET name = ?, group_name = ?, base_url = ?, api_key = ?, model = ?,
                force_parameter = ?, limit_type = ?, limit_qty = ?, limit_tokens = ?,
                enabled = ?, cycle_started_at = {cycle_started_at_sql}, updated_at = ?
            WHERE id = ?
            """,
            [
                name,
                group_name or "1",
                base_url,
                api_key,
                model,
                force_parameter,
                limit_type,
                limit_qty,
                limit_tokens,
                enabled,
                next_cycle_start,
                now_ts,
                upstream_id,
            ],
        )
    return redirect(url_for("admin_console", msg=f"ID {upstream_id} saved"))


@APP.post("/admin/upstreams/<int:upstream_id>/delete")
def admin_upstream_delete(upstream_id: int) -> Response:
    guard = ensure_admin_access()
    if guard is not None:
        return guard
    conn = get_db()
    # Delete the upstream and its logs in one transaction so neither is left orphaned.
    try:
        conn.execute("BEGIN IMMEDIATE")
        conn.execute("DELETE FROM logs WHERE upstream_id = ?", (upstream_id,))
        conn.execute("DELETE FROM upstreams WHERE id = ?", (upstream_id,))
        conn.execute("COMMIT")
    except Exception as exc:
        try:
            conn.execute("ROLLBACK")
        except sqlite3.Error:
            pass
        return redirect(url_for("admin_console", err=f"Delete failed: {exc}"))
    return redirect(
        url_for("admin_console", msg=f"ID {upstream_id} deleted (related logs were removed as well)")
    )


@APP.get("/v1/models")
def list_models() -> Response | tuple[Response, int]:
    auth = request.headers.get("Authorization", "")
    if not auth_header_ok(auth):
        return unauthorized_response()

    conn = get_db()
    rows = conn.execute(
        """
        SELECT model, MIN(created_at) AS created_at
        FROM upstreams
        WHERE enabled = 1 AND model IS NOT NULL AND TRIM(model) <> ''
        GROUP BY model
        ORDER BY model ASC
        """
    ).fetchall()

    data = []
    now_ts = utc_now_ts()
    for row in rows:
        model_id = str(row["model"]).strip()
        if not model_id:
            continue
        created_at = int(row["created_at"]) if row["created_at"] is not None else now_ts
        data.append(
            {
                "id": model_id,
                "object": "model",
                "created": created_at,
                "owned_by": "openai-proxy",
            }
        )
    return jsonify({"object": "list", "data": data})


def forward_non_stream_request(
    conn: sqlite3.Connection,
    upstream: dict[str, Any],
    request_url: str,
    request_headers: dict[str, str],
    payload: dict[str, Any],
    log_id: str,
    tried_ids: list[int],
) -> tuple[bool, Response]:
    # Returns (should_stop, response); should_stop is False when the caller may fail over.
    upstream_id = int(upstream["id"])
    try:
        result = HTTP_SESSION.post(
            request_url,
            headers=request_headers,
            json=payload,
            timeout=(SETTINGS.connect_timeout_sec, SETTINGS.read_timeout_sec),
        )
    except requests.RequestException as exc:
        error_text = str(exc)
        mark_upstream_failure(conn, upstream_id)
        finish_request_log(conn, log_id, status_code=599, used_tokens=0, error_text=error_text)
        return False, attach_route_headers(upstream_error_response(error_text, status_code=502), upstream, tried_ids)

    status_code = int(result.status_code)
    content_type = result.headers.get("Content-Type") or "application/json; charset=utf-8"
    body_bytes = result.content
    body_text = result.text

    if status_code != 200:
        mark_upstream_failure(conn, upstream_id)
        finish_request_log(
            conn,
            log_id,
            status_code=status_code,
            used_tokens=0,
            finish_text=None,
            error_text=body_text,
        )
        response = attach_route_headers(
            Response(body_bytes, status=status_code, content_type=content_type), upstream, tried_ids
        )
        return (not is_retryable_failure(status_code, body_text)), response

    used_tokens = 0
    finish_text = ""
    try:
        result_json = result.json()
        used_tokens = extract_total_tokens(result_json.get("usage"))
        finish_text = extract_finish_text_from_response(result_json)
    except ValueError:
        pass

    mark_upstream_success(conn, upstream_id, used_tokens)
    finish_request_log(
        conn,
        log_id,
        status_code=status_code,
        used_tokens=used_tokens,
        finish_text=finish_text or body_text,
        error_text=None,
    )
    response = attach_route_headers(
        Response(body_bytes, status=status_code, content_type=content_type), upstream, tried_ids
    )
    return True, response


def finalize_stream_outcome(
    upstream_id: int,
    log_id: str,
    used_tokens: int,
    finish_text: str,
    stream_error: str | None,
) -> None:
    # By the time a streamed response finishes, the request context may already be
    # closed, so a dedicated connection is required to persist the outcome.
    conn: sqlite3.Connection | None = None
    try:
        conn = open_sqlite_conn(SETTINGS.sqlite_path)
        if stream_error:
            mark_upstream_failure(conn, upstream_id)
            finish_request_log(
                conn,
                log_id,
                status_code=599,
                used_tokens=used_tokens,
                finish_text=finish_text,
                error_text=stream_error,
            )
        else:
            mark_upstream_success(conn, upstream_id, used_tokens)
            finish_request_log(
                conn,
                log_id,
                status_code=200,
                used_tokens=used_tokens,
                finish_text=finish_text,
                error_text=None,
            )
    except Exception:
        LOGGER.exception("Failed to persist stream outcome upstream_id=%s log_id=%s", upstream_id, log_id)
    finally:
        if conn is not None:
            conn.close()


def stream_response_generator(
    upstream_id: int,
    log_id: str,
    upstream_resp: requests.Response,
) -> Iterator[str]:
    used_tokens = 0
    finish_parts: list[str] = []
    stream_error: str | None = None
    try:
        for line in upstream_resp.iter_lines(decode_unicode=True):
            if line is None:
                continue
            if line == "":
                yield "\n"
                continue
            if line.startswith("data:"):
                data = line[len("data:") :].lstrip()
                if data == "[DONE]":
                    yield "data: [DONE]\n\n"
                    continue
                try:
                    chunk_dict = json.loads(data)
                except json.JSONDecodeError:
                    # Pass through payloads that are not valid JSON unchanged.
                    yield line + "\n"
                    continue
                usage_tokens = extract_total_tokens(chunk_dict.get("usage"))
                if usage_tokens > 0:
                    used_tokens = usage_tokens
                delta_text = extract_delta_text_from_chunk(chunk_dict)
                if delta_text:
                    finish_parts.append(delta_text)
                yield "data: " + json.dumps(chunk_dict, ensure_ascii=False) + "\n\n"
                continue
            yield line + "\n"
    except requests.RequestException as exc:
        stream_error = str(exc)
        raise
    finally:
        upstream_resp.close()
        finish_text = "".join(finish_parts)
        finalize_stream_outcome(upstream_id, log_id, used_tokens, finish_text, stream_error)


def forward_stream_request(
    conn: sqlite3.Connection,
    upstream: dict[str, Any],
    request_url: str,
    request_headers: dict[str, str],
    payload: dict[str, Any],
    log_id: str,
    tried_ids: list[int],
) -> tuple[bool, Response]:
    # Returns (should_stop, response); should_stop is False when the caller may fail over.
    upstream_id = int(upstream["id"])
    try:
        upstream_resp = HTTP_SESSION.post(
            request_url,
            headers=request_headers,
            json=payload,
            stream=True,
            timeout=(SETTINGS.connect_timeout_sec, SETTINGS.stream_read_timeout_sec),
        )
    except requests.RequestException as exc:
        error_text = str(exc)
        mark_upstream_failure(conn, upstream_id)
        finish_request_log(conn, log_id, status_code=599, used_tokens=0, error_text=error_text)
        return False, attach_route_headers(upstream_error_response(error_text, status_code=502), upstream, tried_ids)

    status_code = int(upstream_resp.status_code)
    if status_code != 200:
        body_text = upstream_resp.text
        body_bytes = upstream_resp.content
        content_type = upstream_resp.headers.get("Content-Type") or "application/json; charset=utf-8"
        upstream_resp.close()
        mark_upstream_failure(conn, upstream_id)
        finish_request_log(
            conn,
            log_id,
            status_code=status_code,
            used_tokens=0,
            finish_text=None,
            error_text=body_text,
        )
        response = attach_route_headers(
            Response(body_bytes, status=status_code, content_type=content_type), upstream, tried_ids
        )
        return (not is_retryable_failure(status_code, body_text)), response

    content_type = upstream_resp.headers.get("Content-Type") or "text/event-stream; charset=utf-8"
    stream_iter = stream_with_context(stream_response_generator(upstream_id, log_id, upstream_resp))
    return True, attach_route_headers(Response(stream_iter, status=200, content_type=content_type), upstream, tried_ids)


@APP.post("/v1/chat/completions")
def proxy_chat_completions() -> Response | tuple[Response, int]:
    auth = request.headers.get("Authorization", "")
    if not auth_header_ok(auth):
        return unauthorized_response()

    user_post = request.get_json(silent=True)
    if not isinstance(user_post, dict):
        return jsonify({"error": {"message": "Invalid JSON body"}}), 400

    selector = parse_selector(str(user_post.get("model", "")))
    if selector is None:
        return jsonify({"error": {"message": "model is required"}}), 400

    conn = get_db()
    tried_ids: list[int] = []
    last_retryable_response: Response | None = None

    # Try upstreams one at a time, skipping any that already failed for this request.
    for _ in range(SETTINGS.max_failover_attempts):
        upstream = choose_upstream(conn, selector, tried_ids)
        if upstream is None:
            break

        upstream_id = int(upstream["id"])
        payload = merge_force_parameter(user_post, upstream)
        request_url, request_headers = build_request_meta(upstream)
        log_id = create_request_log(conn, upstream_id, payload)
        stream_mode = bool(payload.get("stream", False))

        if stream_mode:
            should_stop, response = forward_stream_request(
                conn, upstream, request_url, request_headers, payload, log_id, tried_ids
            )
        else:
            should_stop, response = forward_non_stream_request(
                conn, upstream, request_url, request_headers, payload, log_id, tried_ids
            )

        if should_stop:
            return response

        tried_ids.append(upstream_id)
        last_retryable_response = response

    if last_retryable_response is not None:
        return last_retryable_response

    no_available_message = collect_no_available_reason(conn, selector)
    return jsonify(
        {
            "error": {
                "message": no_available_message,
                "type": "routing_error",
                "code": "no_available_model_endpoint",
            }
        }
    ), 429


def setup_logging() -> None:
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s %(message)s",
    )


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="OpenAI compatible proxy with SQLite routing")
    parser.add_argument("--init-db", action="store_true", help="Only initialize SQLite schema and exit")
    return parser.parse_args()


def startup_checks() -> None:
    if SETTINGS.expected_auth == DEFAULT_AUTH:
        LOGGER.warning(
            "OPENAI_PROXY_AUTH is using default value '%s'. "
            "Please set a strong bearer token before exposing the service.",
            DEFAULT_AUTH,
        )


def main() -> None:
    setup_logging()
    startup_checks()
    args = parse_args()
    init_db(SETTINGS.sqlite_path)
    if args.init_db:
        LOGGER.info("SQLite schema initialized: %s", SETTINGS.sqlite_path)
        return
    APP.run(
        host=SETTINGS.listen_host,
        port=SETTINGS.listen_port,
        debug=SETTINGS.debug,
        threaded=True,
    )


if __name__ == "__main__":
    main()