2169 lines
73 KiB
Python
2169 lines
73 KiB
Python
from __future__ import annotations

import argparse
import hmac
import html
import ipaddress
import json
import logging
import os
import sqlite3
import time
import uuid
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Iterator

import requests
from flask import Flask, Response, abort, g, jsonify, redirect, request, session, stream_with_context, url_for
from requests.adapters import HTTPAdapter
|
||
|
||
|
||
# Module-wide logger for the proxy service.
LOGGER = logging.getLogger("openai_proxy")

# Placeholder credentials used when OPENAI_PROXY_AUTH / OPENAI_PROXY_SECRET_KEY are
# not provided (see load_settings); they are public and must be overridden in production.
DEFAULT_AUTH = "Bearer change-me"
DEFAULT_SECRET_KEY = "openai-proxy-dev-secret-change-me"

# Upstream HTTP statuses that make the proxy fail over to the next candidate upstream.
# 401/403/404 are included presumably because they indicate a problem with that
# particular upstream (key, account, route) rather than with the client request.
RETRYABLE_STATUSES = {
    401, 403, 404,            # per-upstream credential / routing problems
    408, 409, 425, 429,       # timeout, conflict, too early, rate limited
    500, 502, 503, 504,       # upstream server-side errors
}
|
||
|
||
|
||
def env_int(name: str, default: int) -> int:
    """Read environment variable *name* as an integer.

    Surrounding whitespace is ignored. Returns *default* when the variable is
    unset or does not hold a valid integer literal.
    """
    raw = os.getenv(name)
    if raw is not None:
        try:
            return int(raw.strip())
        except (TypeError, ValueError):
            pass
    return default
|
||
|
||
|
||
def env_bool(name: str, default: bool) -> bool:
    """Read environment variable *name* as a boolean flag.

    Returns *default* only when the variable is unset. Any set value is True
    exactly when it is (case-insensitively, whitespace-stripped) one of
    "1", "true", "yes" or "on"; every other value, including "", is False.
    """
    raw = os.getenv(name)
    return default if raw is None else raw.strip().lower() in {"1", "true", "yes", "on"}
|
||
|
||
|
||
def load_dotenv_file(path: Path) -> None:
    """Populate ``os.environ`` from a simple ``KEY=VALUE`` dotenv file.

    Blank lines and ``#`` comment lines are skipped, an optional leading
    ``export `` is accepted, and a value wrapped in one matching pair of single
    or double quotes is unquoted. Variables that already exist in the
    environment are never overwritten. A missing or unreadable file is ignored.
    """
    if not path.is_file():
        return
    try:
        lines = path.read_text(encoding="utf-8-sig").splitlines()
    except OSError:
        return

    for line in lines:
        text = line.strip()
        if not text or text.startswith("#"):
            continue
        if text.startswith("export "):
            text = text[len("export "):].strip()
        key, sep, value = text.partition("=")
        if not sep:
            continue
        key = key.strip()
        if not key:
            continue
        value = value.strip()
        # Remove only a *matching* pair of surrounding quotes. Chained
        # strip("'").strip('"') also ate unbalanced or legitimately embedded
        # quote characters (e.g. KEY=abc' became "abc").
        if len(value) >= 2 and value[0] == value[-1] and value[0] in {"'", '"'}:
            value = value[1:-1]
        # Real environment variables take precedence over the file.
        os.environ.setdefault(key, value)
|
||
|
||
|
||
@dataclass(frozen=True)
class Settings:
    """Immutable runtime configuration for the proxy (constructed by load_settings)."""

    # Client authentication: the Authorization value callers must present.
    expected_auth: str
    # SQLite storage location and lock-wait budget (milliseconds).
    sqlite_path: Path
    busy_timeout_ms: int
    # HTTP listener address.
    listen_host: str
    listen_port: int
    # Upstream request timeouts, in seconds.
    connect_timeout_sec: int
    read_timeout_sec: int
    stream_read_timeout_sec: int
    # Failover and scheduling behaviour.
    max_failover_attempts: int
    auto_group_fallback: bool
    failure_cooldown_sec: int
    # Request logging.
    log_full_payload: bool
    log_text_limit: int
    # Admin console.
    admin_enabled: bool
    admin_token: str
    # Flask session signing key and debug switch.
    secret_key: str
    debug: bool
|
||
|
||
|
||
def load_settings() -> Settings:
    """Build a Settings instance from ``OPENAI_PROXY_*`` environment variables.

    Lower bounds are enforced where they matter: at least one failover attempt,
    a non-negative failure cooldown and a log text limit of at least 256 chars.
    A blank secret key falls back to DEFAULT_SECRET_KEY.
    """
    getenv = os.getenv
    db_file = Path(getenv("OPENAI_PROXY_SQLITE_PATH", "./data/openai_proxy.db")).expanduser()
    secret = getenv("OPENAI_PROXY_SECRET_KEY", DEFAULT_SECRET_KEY).strip()

    return Settings(
        expected_auth=getenv("OPENAI_PROXY_AUTH", DEFAULT_AUTH),
        sqlite_path=db_file,
        busy_timeout_ms=env_int("OPENAI_PROXY_DB_BUSY_TIMEOUT_MS", 5000),
        listen_host=getenv("OPENAI_PROXY_LISTEN_HOST", "0.0.0.0"),
        listen_port=env_int("OPENAI_PROXY_LISTEN_PORT", 8056),
        connect_timeout_sec=env_int("OPENAI_PROXY_CONNECT_TIMEOUT_SEC", 10),
        read_timeout_sec=env_int("OPENAI_PROXY_READ_TIMEOUT_SEC", 120),
        stream_read_timeout_sec=env_int("OPENAI_PROXY_STREAM_READ_TIMEOUT_SEC", 600),
        max_failover_attempts=max(1, env_int("OPENAI_PROXY_MAX_FAILOVER_ATTEMPTS", 5)),
        auto_group_fallback=env_bool("OPENAI_PROXY_AUTO_GROUP_FALLBACK", True),
        failure_cooldown_sec=max(0, env_int("OPENAI_PROXY_FAILURE_COOLDOWN_SEC", 30)),
        log_full_payload=env_bool("OPENAI_PROXY_LOG_FULL_PAYLOAD", False),
        log_text_limit=max(256, env_int("OPENAI_PROXY_LOG_TEXT_LIMIT", 100000)),
        admin_enabled=env_bool("OPENAI_PROXY_ADMIN_ENABLED", True),
        admin_token=getenv("OPENAI_PROXY_ADMIN_TOKEN", "").strip(),
        secret_key=secret or DEFAULT_SECRET_KEY,
        debug=env_bool("OPENAI_PROXY_DEBUG", False),
    )
|
||
|
||
|
||
# Load .env first (real environment variables win), then freeze the settings.
load_dotenv_file(Path(os.getenv("OPENAI_PROXY_ENV_FILE", ".env")))
SETTINGS = load_settings()

if SETTINGS.secret_key == DEFAULT_SECRET_KEY:
    # The default key is public, so anyone could sign a session cookie carrying
    # admin_ok=True. Warn loudly rather than refuse to start, to keep local setups working.
    LOGGER.warning(
        "OPENAI_PROXY_SECRET_KEY is not configured; using the built-in development key. "
        "Session cookies (including admin login) can be forged - set a random secret."
    )
if SETTINGS.expected_auth.strip() == DEFAULT_AUTH:
    LOGGER.warning("OPENAI_PROXY_AUTH still uses the default placeholder token; set a private token.")

APP = Flask(__name__)
APP.config["SECRET_KEY"] = SETTINGS.secret_key
APP.config["SESSION_COOKIE_HTTPONLY"] = True
APP.config["SESSION_COOKIE_SAMESITE"] = "Lax"

# Shared pooled HTTP client for upstream calls. Transport-level retries are off;
# retrying on another upstream is decided by the proxy's own failover logic.
HTTP_SESSION = requests.Session()
for _scheme in ("http://", "https://"):
    HTTP_SESSION.mount(_scheme, HTTPAdapter(pool_connections=50, pool_maxsize=50, max_retries=0))
del _scheme
|
||
|
||
|
||
def utc_now_ts() -> int:
    """Return the current Unix time as whole seconds (fraction truncated)."""
    now = time.time()
    return int(now)
|
||
|
||
|
||
# Canonical quota-cycle units (mapped to themselves) plus accepted shorthands:
# pinyin miao/fen/shi/tian/yue/nian and the English abbreviations sec/min.
_LIMIT_TYPE_UNITS = {
    "second": "second",
    "minute": "minute",
    "hour": "hour",
    "day": "day",
    "month": "month",
    "year": "year",
    "miao": "second",
    "fen": "minute",
    "shi": "hour",
    "tian": "day",
    "yue": "month",
    "nian": "year",
    "sec": "second",
    "min": "minute",
}


def limit_type_alias(limit_type: str | None) -> str:
    """Normalize a quota-cycle unit name (case/whitespace-insensitive).

    Unknown, empty or None input falls back to "day".
    """
    key = (limit_type or "").strip().lower()
    return _LIMIT_TYPE_UNITS.get(key, "day")
|
||
|
||
|
||
# datetime.replace() arguments that truncate a UTC moment to the start of each cycle unit.
_CYCLE_RESET_FIELDS = {
    "minute": {"second": 0, "microsecond": 0},
    "hour": {"minute": 0, "second": 0, "microsecond": 0},
    "day": {"hour": 0, "minute": 0, "second": 0, "microsecond": 0},
    "month": {"day": 1, "hour": 0, "minute": 0, "second": 0, "microsecond": 0},
    "year": {"month": 1, "day": 1, "hour": 0, "minute": 0, "second": 0, "microsecond": 0},
}


def cycle_start_ts(limit_type: str | None, ts: int) -> int:
    """Return the Unix timestamp at which the quota cycle containing *ts* begins.

    Cycle boundaries are computed in UTC. For the "second" unit the cycle is the
    second itself, so *ts* is returned unchanged; unknown units behave like "day".
    """
    unit = limit_type_alias(limit_type)
    moment = datetime.fromtimestamp(ts, tz=timezone.utc)
    if unit == "second":
        return ts
    reset = _CYCLE_RESET_FIELDS.get(unit, _CYCLE_RESET_FIELDS["day"])
    return int(moment.replace(**reset).timestamp())
|
||
|
||
|
||
def open_sqlite_conn(db_path: Path) -> sqlite3.Connection:
    """Open and configure a SQLite connection to *db_path*.

    The parent directory is created if needed. The connection runs in
    autocommit mode (``isolation_level=None``), so callers manage transactions
    explicitly. Rows come back as ``sqlite3.Row``. WAL journaling, the
    configured busy timeout, ``synchronous=NORMAL`` and foreign-key enforcement
    are enabled. If any PRAGMA fails, the connection is closed before the error
    propagates, so it is not leaked.
    """
    db_path.parent.mkdir(parents=True, exist_ok=True)
    busy_ms = SETTINGS.busy_timeout_ms
    conn = sqlite3.connect(str(db_path), timeout=busy_ms / 1000, isolation_level=None)
    try:
        conn.row_factory = sqlite3.Row
        conn.execute("PRAGMA journal_mode=WAL;")
        conn.execute(f"PRAGMA busy_timeout={busy_ms};")
        conn.execute("PRAGMA synchronous=NORMAL;")
        conn.execute("PRAGMA foreign_keys=ON;")
    except BaseException:
        conn.close()
        raise
    return conn
|
||
|
||
|
||
def init_db(db_path: Path) -> None:
    """Create the ``upstreams`` and ``logs`` tables and their indexes if missing.

    Safe to call on every start-up: every statement uses ``IF NOT EXISTS``.
    Timestamp defaults use ``CAST(strftime('%s','now') AS INTEGER)`` rather than
    ``unixepoch()``. The latter only exists in SQLite >= 3.38, and older builds
    (common in distro Pythons) fail on it. An index on ``logs(upstream_id)`` is
    also created: with ``foreign_keys=ON``, deleting an upstream must look up
    its child log rows, which would otherwise need a full scan of ``logs``.
    """
    conn = open_sqlite_conn(db_path)
    try:
        conn.executescript(
            """
            CREATE TABLE IF NOT EXISTS upstreams (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL,
                group_name TEXT NOT NULL DEFAULT '1',
                base_url TEXT NOT NULL,
                api_key TEXT NOT NULL,
                model TEXT NOT NULL,
                force_parameter TEXT,
                limit_type TEXT NOT NULL DEFAULT 'day',
                limit_qty INTEGER NOT NULL DEFAULT 0,
                limit_tokens INTEGER NOT NULL DEFAULT 0,
                used_cycle_qty INTEGER NOT NULL DEFAULT 0,
                used_cycle_tokens INTEGER NOT NULL DEFAULT 0,
                used_all_qty INTEGER NOT NULL DEFAULT 0,
                used_all_tokens INTEGER NOT NULL DEFAULT 0,
                cycle_started_at INTEGER NOT NULL DEFAULT (CAST(strftime('%s', 'now') AS INTEGER)),
                last_used_at INTEGER,
                consecutive_failures INTEGER NOT NULL DEFAULT 0,
                cooldown_until INTEGER,
                enabled INTEGER NOT NULL DEFAULT 1,
                created_at INTEGER NOT NULL DEFAULT (CAST(strftime('%s', 'now') AS INTEGER)),
                updated_at INTEGER NOT NULL DEFAULT (CAST(strftime('%s', 'now') AS INTEGER))
            );

            CREATE INDEX IF NOT EXISTS idx_upstreams_enabled_model
            ON upstreams(enabled, model, name, group_name);

            CREATE INDEX IF NOT EXISTS idx_upstreams_scheduling
            ON upstreams(enabled, cooldown_until, consecutive_failures, last_used_at);

            CREATE TABLE IF NOT EXISTS logs (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                req_id TEXT NOT NULL UNIQUE,
                upstream_id INTEGER NOT NULL,
                request_at INTEGER NOT NULL,
                finish_at INTEGER,
                status_code INTEGER,
                used_tokens INTEGER NOT NULL DEFAULT 0,
                request_payload TEXT,
                finish_text TEXT,
                error_text TEXT,
                FOREIGN KEY (upstream_id) REFERENCES upstreams(id)
            );

            CREATE INDEX IF NOT EXISTS idx_logs_request_at
            ON logs(request_at);

            CREATE INDEX IF NOT EXISTS idx_logs_upstream_id
            ON logs(upstream_id);
            """
        )
    finally:
        conn.close()
|
||
|
||
|
||
def get_db() -> sqlite3.Connection:
    """Return the SQLite connection for the current app context.

    The connection is opened lazily on first use, cached on ``flask.g`` and
    closed by the ``close_db`` teardown hook.
    """
    cached = getattr(g, "_sqlite_conn", None)
    if cached is not None:
        return cached
    cached = open_sqlite_conn(SETTINGS.sqlite_path)
    g._sqlite_conn = cached
    return cached
|
||
|
||
|
||
@APP.teardown_appcontext
def close_db(_: BaseException | None) -> None:
    """App-context teardown hook that closes the connection opened by get_db, if any."""
    conn = getattr(g, "_sqlite_conn", None)
    if conn is None:
        return
    conn.close()
    g._sqlite_conn = None
|
||
|
||
|
||
def sqlite_row_to_dict(row: sqlite3.Row | None) -> dict[str, Any] | None:
    """Copy a ``sqlite3.Row`` into a plain mutable dict keyed by column name.

    ``None`` (no row fetched) is passed through unchanged.
    """
    if row is None:
        return None
    return dict((column, row[column]) for column in row.keys())
|
||
|
||
|
||
def parse_selector(raw_model: str) -> dict[str, Any] | None:
    """Interpret a client's ``model`` field as an upstream selector.

    The value is stripped and truncated to 100 characters, then:

    * ``group:a,b`` (prefix case-insensitive) or any comma-separated list
      becomes ``{"type": "groups", "groups": [...]}``;
    * a bare digit string such as ``"2"`` becomes a single-group selector;
    * anything else becomes ``{"type": "exact", "model": value}``.

    Every selector also carries the normalized input as ``"value"``. Returns
    None when the value, or its list of group names, is empty.
    """
    value = (raw_model or "").strip()[:100]
    if not value:
        return None

    has_prefix = value.lower().startswith("group:")
    if has_prefix or "," in value:
        body = value[6:] if has_prefix else value
        groups = [piece.strip() for piece in body.split(",") if piece.strip()]
        return {"type": "groups", "groups": groups, "value": value} if groups else None

    if value.isdigit():
        return {"type": "groups", "groups": [value], "value": value}
    return {"type": "exact", "model": value, "value": value}
|
||
|
||
|
||
def fetch_candidate_ids(
    conn: sqlite3.Connection, selector: dict[str, Any], excluded_ids: list[int], now_ts: int
) -> list[int]:
    """Return ids of enabled upstreams matching *selector*, best candidates first.

    * ``excluded_ids`` (upstreams already tried for this request) are skipped.
    * A ``"groups"`` selector matches on ``group_name``. When auto_group_fallback
      is on and exactly one numeric group is requested, it is widened to every
      enabled numeric group whose number is >= the requested one.
    * An ``"exact"`` selector matches when either ``name`` or ``model`` equals it.

    Upstreams not cooling down at *now_ts* come first. Within each bucket the
    order is by id, i.e. configuration order: the choice is stable and only moves
    on to the next upstream on failure within a single request.
    """
    clauses = ["enabled = 1"]
    params: list[Any] = []

    if excluded_ids:
        clauses.append(f"id NOT IN ({','.join('?' for _ in excluded_ids)})")
        params.extend(excluded_ids)

    if selector["type"] == "groups":
        groups = list(selector["groups"])
        requested = groups[0].strip() if len(groups) == 1 else ""
        # isdecimal(), not isdigit(): characters such as "²" pass isdigit() but make
        # int() raise ValueError, which let a client-supplied model string crash the request.
        if SETTINGS.auto_group_fallback and requested.isdecimal():
            start_group = int(requested)
            group_rows = conn.execute(
                "SELECT DISTINCT group_name FROM upstreams WHERE enabled = 1"
            ).fetchall()
            # Compare numerically but keep the stored names verbatim, so groups such as
            # "01" still match the IN (...) filter below (str(int("01")) == "1" does not).
            numeric_names = [
                name
                for name in (str(row["group_name"]) for row in group_rows)
                if name.strip().isdecimal() and int(name.strip()) >= start_group
            ]
            if numeric_names:
                groups = sorted(numeric_names, key=lambda name: (int(name.strip()), name))
        clauses.append(f"group_name IN ({','.join('?' for _ in groups)})")
        params.extend(groups)
    else:
        clauses.append("(name = ? OR model = ?)")
        params.extend((selector["model"], selector["model"]))

    sql = (
        "SELECT id FROM upstreams WHERE "
        + " AND ".join(clauses)
        + " ORDER BY CASE WHEN COALESCE(cooldown_until, 0) <= ? THEN 0 ELSE 1 END ASC, id ASC"
    )
    params.append(now_ts)
    return [int(row["id"]) for row in conn.execute(sql, params).fetchall()]
|
||
|
||
|
||
def selector_group_fallback(selector: dict[str, Any]) -> dict[str, Any] | None:
    """Derive a group selector from an exact selector by reusing its model string as a group name.

    Returns None for non-exact selectors or when the model string is empty.
    """
    if selector.get("type") == "exact":
        name = str(selector.get("model") or "").strip()
        if name:
            return {"type": "groups", "groups": [name], "value": f"group:{name}"}
    return None
|
||
|
||
|
||
def evaluate_row_available_now(row: dict[str, Any], now_ts: int) -> tuple[bool, str]:
    """Read-only availability check for an upstream row at *now_ts*.

    Returns ``(False, "cooldown")`` while the row is cooling down,
    ``(False, "quota")`` when its cycle request or token limit is used up, and
    ``(True, "ok")`` otherwise. If the stored cycle has already rolled over, the
    cycle counters are treated as zero, but nothing is written back to the database.
    """
    if int(row.get("cooldown_until") or 0) > now_ts:
        return False, "cooldown"

    stored_start = int(row.get("cycle_started_at") or 0)
    current_start = cycle_start_ts(str(row.get("limit_type") or "day"), now_ts)
    used_qty = int(row.get("used_cycle_qty") or 0)
    used_tokens = int(row.get("used_cycle_tokens") or 0)
    if stored_start < current_start:
        used_qty = used_tokens = 0

    max_qty = int(row.get("limit_qty") or 0)
    max_tokens = int(row.get("limit_tokens") or 0)
    # A limit <= 0 means "unlimited".
    blocked = (max_qty > 0 and used_qty >= max_qty) or (max_tokens > 0 and used_tokens >= max_tokens)
    return (False, "quota") if blocked else (True, "ok")
|
||
|
||
|
||
def collect_no_available_reason(conn: sqlite3.Connection, selector: dict[str, Any]) -> str:
    """Build a diagnostic message explaining why no upstream could serve *selector*.

    Candidates matched by the selector and its group fallback are counted by
    state: available, cooling down, or blocked by quota. If nothing matches at
    all, the message names the requested model value instead.
    """
    now_ts = utc_now_ts()
    fallback = selector_group_fallback(selector)
    candidate_selectors = [selector] if fallback is None else [selector, fallback]

    matched_ids: set[int] = set()
    for sel in candidate_selectors:
        matched_ids.update(fetch_candidate_ids(conn, sel, [], now_ts))

    if not matched_ids:
        model_value = str(selector.get("value") or selector.get("model") or "")
        return f"No available model endpoint: 未匹配到可用上游,请检查 model/分组。当前 model={model_value}"

    ordered_ids = tuple(sorted(matched_ids))
    placeholders = ",".join("?" * len(ordered_ids))
    rows = conn.execute(
        f"""
        SELECT id, name, group_name, model, limit_type, limit_qty, limit_tokens,
               used_cycle_qty, used_cycle_tokens, cycle_started_at, cooldown_until
        FROM upstreams
        WHERE id IN ({placeholders})
        """,
        ordered_ids,
    ).fetchall()

    tally = {"ok": 0, "cooldown": 0, "quota": 0}
    for row in rows:
        available, reason = evaluate_row_available_now(sqlite_row_to_dict(row) or {}, now_ts)
        bucket = "ok" if available else reason
        if bucket in tally:
            tally[bucket] += 1

    return (
        "No available model endpoint: "
        f"匹配到 {len(matched_ids)} 条上游,当前可用 {tally['ok']} 条,冷却中 {tally['cooldown']} 条,限额阻塞 {tally['quota']} 条。"
    )
|
||
|
||
|
||
def refresh_cycle_if_needed(conn: sqlite3.Connection, row: dict[str, Any], now_ts: int) -> dict[str, Any]:
    """Roll an upstream's quota cycle over when *now_ts* has entered a newer cycle.

    When a rollover is due, the cycle counters are reset and the cycle start is
    moved, both in the database and in *row*, which is mutated in place. The
    same dict is returned either way.
    """
    stored_start = int(row.get("cycle_started_at") or 0)
    new_start = cycle_start_ts(row.get("limit_type"), now_ts)
    if stored_start < new_start:
        conn.execute(
            """
            UPDATE upstreams
            SET used_cycle_qty = 0,
                used_cycle_tokens = 0,
                cycle_started_at = :start,
                updated_at = :now
            WHERE id = :id
            """,
            {"start": new_start, "now": now_ts, "id": row["id"]},
        )
        row.update(used_cycle_qty=0, used_cycle_tokens=0, cycle_started_at=new_start)
    return row
|
||
|
||
|
||
def has_quota(row: dict[str, Any]) -> bool:
    """Return True when the row is still under both its request and token cycle limits.

    A limit of 0 (or a negative/missing limit) means "unlimited".
    """
    limits = (int(row.get("limit_qty") or 0), int(row.get("limit_tokens") or 0))
    used = (int(row.get("used_cycle_qty") or 0), int(row.get("used_cycle_tokens") or 0))
    return all(limit <= 0 or spent < limit for limit, spent in zip(limits, used))
|
||
|
||
|
||
def reserve_upstream(conn: sqlite3.Connection, upstream_id: int, now_ts: int) -> dict[str, Any] | None:
    """Atomically claim one request slot on upstream *upstream_id*.

    Inside a ``BEGIN IMMEDIATE`` transaction the row is re-read and its quota
    cycle rolled over if needed. The reservation is refused (None) when the row
    is missing or disabled, cooling down, or out of quota. Otherwise the cycle
    and all-time request counters are incremented, ``last_used_at`` is stamped,
    and the updated row dict is returned.

    A ``sqlite3.OperationalError`` (for example, the database is still locked
    after the busy timeout) is logged and treated as "not available", so the
    caller can try the next candidate. Any other exception is re-raised after
    rolling back.
    """

    def _rollback_quietly() -> None:
        # Best effort: the transaction may already be gone (e.g. BEGIN itself failed).
        try:
            conn.execute("ROLLBACK")
        except sqlite3.Error:
            pass

    try:
        conn.execute("BEGIN IMMEDIATE")
        row = sqlite_row_to_dict(
            conn.execute("SELECT * FROM upstreams WHERE id = ? AND enabled = 1", (upstream_id,)).fetchone()
        )
        if row is None:
            conn.execute("ROLLBACK")
            return None
        row = refresh_cycle_if_needed(conn, row, now_ts)
        if int(row.get("cooldown_until") or 0) > now_ts or not has_quota(row):
            # Rolling back also discards the cycle refresh above; it is redone next time.
            conn.execute("ROLLBACK")
            return None
        conn.execute(
            """
            UPDATE upstreams
            SET used_cycle_qty = used_cycle_qty + 1,
                used_all_qty = used_all_qty + 1,
                last_used_at = ?,
                updated_at = ?
            WHERE id = ?
            """,
            (now_ts, now_ts, upstream_id),
        )
        conn.execute("COMMIT")
    except sqlite3.OperationalError:
        # Previously swallowed silently, which hid real problems (locks, schema errors).
        LOGGER.warning("reserve_upstream: SQLite error while reserving upstream %s; skipping it", upstream_id, exc_info=True)
        _rollback_quietly()
        return None
    except Exception:
        _rollback_quietly()
        raise

    # Mirror the committed increments in the returned dict.
    row["used_cycle_qty"] = int(row.get("used_cycle_qty") or 0) + 1
    row["used_all_qty"] = int(row.get("used_all_qty") or 0) + 1
    row["last_used_at"] = now_ts
    return row
|
||
|
||
|
||
def mark_upstream_success(conn: sqlite3.Connection, upstream_id: int, used_tokens: int) -> None:
    """Record a successful upstream call.

    Adds *used_tokens* to the cycle and all-time token counters, resets the
    consecutive-failure count and clears any cooldown.
    """
    stamp = utc_now_ts()
    conn.execute(
        """
        UPDATE upstreams
        SET used_cycle_tokens = used_cycle_tokens + :tokens,
            used_all_tokens = used_all_tokens + :tokens,
            consecutive_failures = 0,
            cooldown_until = NULL,
            updated_at = :now
        WHERE id = :id
        """,
        {"tokens": used_tokens, "now": stamp, "id": upstream_id},
    )
|
||
|
||
|
||
def mark_upstream_failure(conn: sqlite3.Connection, upstream_id: int) -> None:
    """Record a failed upstream call.

    Increments ``consecutive_failures`` and puts the upstream into cooldown for
    ``failure_cooldown_sec`` seconds. With a cooldown of 0, ``cooldown_until``
    is simply "now", which does not block the upstream.
    """
    stamp = utc_now_ts()
    until = stamp + max(SETTINGS.failure_cooldown_sec, 0)
    conn.execute(
        """
        UPDATE upstreams
        SET consecutive_failures = consecutive_failures + 1,
            cooldown_until = :until,
            updated_at = :now
        WHERE id = :id
        """,
        {"until": until, "now": stamp, "id": upstream_id},
    )
|
||
|
||
|
||
def is_retryable_status(status_code: int) -> bool:
    """Return True when an upstream HTTP status should trigger failover (see RETRYABLE_STATUSES)."""
    should_retry = status_code in RETRYABLE_STATUSES
    return should_retry
|
||
|
||
|
||
# Lower-cased substrings marking a 400/422 body as a "model missing/unavailable" error,
# which is worth retrying on another upstream. Includes Chinese for model / does not
# exist / unavailable.
# NOTE(review): the bare "model" entry matches any error mentioning a model (e.g.
# context-length errors) and makes the more specific entries redundant; confirm intent.
_MODEL_ERROR_HINTS = (
    "model",
    "not found",
    "no such model",
    "unknown model",
    "invalid model",
    "model not",
    "模型",
    "不存在",
    "不可用",
)


def is_retryable_failure(status_code: int, response_text: str) -> bool:
    """Decide whether a failed upstream response should fail over to the next upstream.

    Statuses listed in RETRYABLE_STATUSES always fail over. A 400/422 fails over
    only when its body looks like a model-availability error (see _MODEL_ERROR_HINTS).
    """
    if is_retryable_status(status_code):
        return True
    if status_code in {400, 422}:
        lowered = (response_text or "").lower()
        return any(hint in lowered for hint in _MODEL_ERROR_HINTS)
    return False
|
||
|
||
|
||
def parse_int_like(value: Any) -> int | None:
    """Best-effort conversion of an int, float or numeric string to ``int``.

    Floats, and float-looking strings such as ``"12.7"`` or ``"1e3"``, are
    truncated toward zero. Returns None for anything else: empty or non-numeric
    strings, other types, and non-finite numbers (inf/NaN). Non-finite numbers
    matter because Python's json module decodes ``Infinity``/``NaN`` from
    upstream responses, and ``int()`` on them raises OverflowError/ValueError.
    """
    if isinstance(value, int):
        return value
    if isinstance(value, str):
        text = value.strip()
        if not text:
            return None
        try:
            return int(text)
        except ValueError:
            pass
        try:
            value = float(text)
        except ValueError:
            return None
    if isinstance(value, float):
        try:
            return int(value)
        except (OverflowError, ValueError):
            return None
    return None
|
||
|
||
|
||
def sanitize_request_payload(payload: dict[str, Any]) -> dict[str, Any]:
    """Return what should be stored in the request log for *payload*.

    With ``log_full_payload`` enabled, the payload itself is returned.
    Otherwise only a small summary is kept: model, stream flag, the max-token
    fields and the number of messages. Message content is never stored.
    """
    if SETTINGS.log_full_payload:
        return payload

    messages = payload.get("messages")
    summary: dict[str, Any] = {"model": payload.get("model")}
    summary["stream"] = bool(payload.get("stream", False))
    summary["max_tokens"] = payload.get("max_tokens")
    summary["max_output_tokens"] = payload.get("max_output_tokens")
    summary["message_count"] = len(messages) if isinstance(messages, list) else 0
    return summary
|
||
|
||
|
||
def create_request_log(conn: sqlite3.Connection, upstream_id: int, payload: dict[str, Any]) -> str:
    """Insert the opening ``logs`` row for a proxied request and return its new ``req_id``.

    The payload is reduced by sanitize_request_payload and stored as JSON.
    """
    req_id = uuid.uuid4().hex
    safe_payload = sanitize_request_payload(payload)
    try:
        payload_json = json.dumps(safe_payload, ensure_ascii=False, allow_nan=False)
    except ValueError:
        # Request JSON decoded by Python's json module may contain NaN/Infinity,
        # which allow_nan=False rejects. Writing an audit record must not abort
        # the proxied request, so fall back to the non-standard tokens here.
        payload_json = json.dumps(safe_payload, ensure_ascii=False, allow_nan=True)
    conn.execute(
        """
        INSERT INTO logs(req_id, upstream_id, request_at, request_payload)
        VALUES (?, ?, ?, ?)
        """,
        (req_id, upstream_id, utc_now_ts(), payload_json),
    )
    return req_id
|
||
|
||
|
||
def finish_request_log(
    conn: sqlite3.Connection,
    req_id: str,
    status_code: int,
    used_tokens: int,
    finish_text: str | None = None,
    error_text: str | None = None,
) -> None:
    """Complete the ``logs`` row *req_id* with the outcome of the request.

    The finish and error texts are stored as strings (None becomes ""),
    truncated to ``SETTINGS.log_text_limit`` characters.
    """
    limit = SETTINGS.log_text_limit
    conn.execute(
        """
        UPDATE logs
        SET finish_at = :finished,
            status_code = :status,
            used_tokens = :tokens,
            finish_text = :finish,
            error_text = :error
        WHERE req_id = :req_id
        """,
        {
            "finished": utc_now_ts(),
            "status": status_code,
            "tokens": used_tokens,
            "finish": (finish_text or "")[:limit],
            "error": (error_text or "")[:limit],
            "req_id": req_id,
        },
    )
|
||
|
||
|
||
def merge_force_parameter(payload: dict[str, Any], upstream: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of *payload* adapted for *upstream*; the input dict is not modified.

    ``model`` is replaced by the upstream's model. If
    ``upstream["force_parameter"]`` holds a JSON object, or a JSON list of
    objects merged in order, its keys are then laid over the copy and override
    client-supplied values, including ``model``. Blank, invalid or non-object
    force parameters are ignored.
    """
    merged = {**payload, "model": upstream.get("model")}

    raw = upstream.get("force_parameter")
    text = "" if raw is None else str(raw).strip()
    if not text:
        return merged
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        return merged

    if isinstance(parsed, dict):
        overrides = [parsed]
    elif isinstance(parsed, list):
        overrides = [item for item in parsed if isinstance(item, dict)]
    else:
        overrides = []
    for override in overrides:
        merged.update(override)
    return merged
|
||
|
||
|
||
def build_request_meta(upstream: dict[str, Any]) -> tuple[str, dict[str, str]]:
    """Return the chat-completions URL and request headers for *upstream*.

    The URL is ``<base_url>/chat/completions`` with exactly one slash between
    the parts. The stripped ``api_key`` is sent as a Bearer token. A key that
    already starts with a ``Bearer `` scheme, in any letter case, is used as-is
    instead of being prefixed twice. Missing or None fields are treated as "".
    """
    base = str(upstream.get("base_url") or "").rstrip("/") + "/"
    api_key = str(upstream.get("api_key") or "").strip()
    # Auth scheme names are case-insensitive, so "bearer sk-..." must not
    # become "Bearer bearer sk-...".
    has_scheme = api_key[:7].lower() == "bearer "
    auth_value = api_key if has_scheme else f"Bearer {api_key}"
    return (
        base + "chat/completions",
        {
            "Authorization": auth_value,
            "Content-Type": "application/json",
        },
    )
|
||
|
||
|
||
def build_models_meta(upstream: dict[str, Any]) -> tuple[str, dict[str, str]]:
    """Return the model-listing URL (``<base_url>/models``) and auth header for *upstream*.

    The stripped ``api_key`` is sent as a Bearer token. A key that already
    starts with a ``Bearer `` scheme, in any letter case, is used unchanged.
    Missing or None fields are treated as "".
    """
    base = str(upstream.get("base_url") or "").rstrip("/") + "/"
    api_key = str(upstream.get("api_key") or "").strip()
    # Case-insensitive scheme check avoids producing "Bearer bearer sk-...".
    has_scheme = api_key[:7].lower() == "bearer "
    auth_value = api_key if has_scheme else f"Bearer {api_key}"
    return (
        base + "models",
        {
            "Authorization": auth_value,
        },
    )
|
||
|
||
|
||
def build_admin_probe_payload(upstream: dict[str, Any]) -> dict[str, Any]:
    """Build a minimal chat request used only by the admin console to probe an upstream.

    It asks for a single output token to keep the probe as cheap as possible.
    The upstream's model and force parameters are applied via merge_force_parameter.
    """
    probe = dict(
        model=str(upstream.get("model") or ""),
        messages=[{"role": "user", "content": "ping"}],
        stream=False,
        max_tokens=1,
        temperature=0,
    )
    return merge_force_parameter(probe, upstream)
|
||
|
||
|
||
def extract_total_tokens(usage: Any) -> int:
    """Return the token count reported in a ``usage`` block, or 0 if it cannot be determined.

    *usage* may be a dict or an object exposing ``model_dump(mode="json")``
    (presumably an SDK/pydantic model). A positive ``total_tokens`` wins;
    otherwise ``prompt_tokens + completion_tokens`` is used when both parse
    as integers.
    """
    if not usage:
        return 0
    data = usage.model_dump(mode="json") if hasattr(usage, "model_dump") else usage
    if not isinstance(data, dict):
        return 0

    total = parse_int_like(data.get("total_tokens"))
    if total is not None and total > 0:
        return total

    prompt = parse_int_like(data.get("prompt_tokens"))
    completion = parse_int_like(data.get("completion_tokens"))
    if prompt is None or completion is None:
        return 0
    return prompt + completion
|
||
|
||
|
||
def extract_finish_text_from_response(resp_json: Any) -> str:
    """Return the text of the first choice in a non-streaming completion response.

    ``choices[0].message.content`` is preferred, then ``choices[0].text``.
    Returns "" when the structure is missing or malformed.
    """
    choices = resp_json.get("choices") if isinstance(resp_json, dict) else None
    if not isinstance(choices, list) or not choices or not isinstance(choices[0], dict):
        return ""
    first = choices[0]

    message = first.get("message")
    if isinstance(message, dict) and isinstance(message.get("content"), str):
        return message["content"]

    text = first.get("text")
    return text if isinstance(text, str) else ""
|
||
|
||
|
||
def extract_delta_text_from_chunk(chunk_dict: Any) -> str:
    """Return ``choices[0].delta.content`` from a streaming chunk, or "" if absent or malformed."""
    if not isinstance(chunk_dict, dict):
        return ""
    choices = chunk_dict.get("choices")
    first = choices[0] if isinstance(choices, list) and choices else None
    delta = first.get("delta") if isinstance(first, dict) else None
    content = delta.get("content") if isinstance(delta, dict) else None
    return content if isinstance(content, str) else ""
|
||
|
||
|
||
def unauthorized_response() -> tuple[Response, int]:
    """Return a JSON 401 body of the form ``{"error": {"message": "Unauthorized"}}``."""
    body = {"error": {"message": "Unauthorized"}}
    return jsonify(body), 401
|
||
|
||
|
||
def upstream_error_response(message: str, status_code: int = 502) -> Response:
    """Return a JSON ``{"error": {"message": ...}}`` response with *status_code* (502 by default)."""
    resp = jsonify({"error": {"message": message}})
    resp.status_code = status_code
    return resp
|
||
|
||
|
||
def auth_header_ok(auth_header: str) -> bool:
    """Check a client's Authorization header against ``SETTINGS.expected_auth``.

    Accepts an exact match, or a match on the token body alone, so clients may
    omit the ``Bearer`` prefix or write it in any case. Always False when no
    expected value is configured. Both comparisons are constant-time.
    """
    expected_raw = (SETTINGS.expected_auth or "").strip()
    got_raw = (auth_header or "").strip()
    if not expected_raw:
        return False

    def _same(left: str, right: str) -> bool:
        # compare_digest avoids leaking, through response timing, how much of the
        # secret matched. Compare bytes, because it rejects non-ASCII str.
        return hmac.compare_digest(
            left.encode("utf-8", "surrogatepass"), right.encode("utf-8", "surrogatepass")
        )

    if _same(got_raw, expected_raw):
        return True
    return _same(token_body(got_raw), token_body(expected_raw))
|
||
|
||
|
||
def token_body(token: str) -> str:
    """Strip surrounding whitespace and a leading case-insensitive ``Bearer `` scheme from *token*."""
    text = (token or "").strip()
    scheme, rest = text[:7], text[7:]
    return rest.strip() if scheme.lower() == "bearer " else text
|
||
|
||
|
||
def effective_admin_token() -> str:
    """Return the token body the admin login must match.

    This is ``OPENAI_PROXY_ADMIN_TOKEN`` when configured, otherwise the client
    API auth token.
    """
    source = SETTINGS.admin_token or SETTINGS.expected_auth
    return token_body(source)
|
||
|
||
|
||
def admin_token_ok(user_input: str) -> bool:
    """Return True when the typed admin password matches effective_admin_token().

    An optional ``Bearer`` prefix in the input is ignored. Always False if no
    token is configured. The comparison is constant-time to avoid a timing
    side channel.
    """
    expected = effective_admin_token()
    if not expected:
        return False
    typed = token_body(user_input)
    # Compare bytes: hmac.compare_digest rejects non-ASCII str arguments.
    return hmac.compare_digest(
        typed.encode("utf-8", "surrogatepass"), expected.encode("utf-8", "surrogatepass")
    )
|
||
|
||
|
||
def is_local_request() -> bool:
    """Return True only for requests that genuinely originate from this machine.

    The peer address must be a loopback address: 127.0.0.0/8, ::1, or an
    IPv4-mapped loopback such as ::ffff:127.0.0.1. Requests carrying
    reverse-proxy forwarding headers are never treated as local. Behind a proxy
    on the same host, every client appears as 127.0.0.1, and ensure_admin_access
    would otherwise grant those remote users passwordless admin access.
    """
    if any(request.headers.get(name) for name in ("X-Forwarded-For", "X-Real-IP", "Forwarded")):
        return False
    remote_addr = (request.remote_addr or "").strip()
    if remote_addr == "localhost":
        return True
    try:
        address = ipaddress.ip_address(remote_addr)
    except ValueError:
        return False
    mapped = getattr(address, "ipv4_mapped", None)
    return bool(address.is_loopback or (mapped is not None and mapped.is_loopback))
|
||
|
||
|
||
def ensure_admin_access() -> Response | None:
    """Gate an admin page.

    Aborts with 404 when the admin console is disabled. Returns None when
    access is allowed: the session is logged in, or the request is local and no
    dedicated admin token is configured (a convenience for personal use).
    Otherwise returns a redirect to the login page that comes back to the
    current path.
    """
    if not SETTINGS.admin_enabled:
        abort(404)
    if session.get("admin_ok") is True or (not SETTINGS.admin_token and is_local_request()):
        return None
    return redirect(url_for("admin_login", next=request.path))
|
||
|
||
|
||
def to_form_int(raw: str, default: int = 0) -> int:
    """Parse a submitted form field as an int; blank, missing or malformed input yields *default*."""
    text = (raw or "").strip()
    try:
        return int(text)
    except (TypeError, ValueError):
        return default
|
||
|
||
|
||
def parse_force_parameter_text(raw_text: str) -> tuple[bool, str | None]:
    """Validate and normalize the admin form's "force parameter" JSON text.

    Returns ``(True, None)`` for blank input, ``(True, normalized_json)`` for a
    JSON object or array, and ``(False, error_message)`` for invalid JSON or any
    other JSON type.
    """
    text = (raw_text or "").strip()
    if not text:
        return True, None
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        return False, "强制参数不是合法 JSON"
    if isinstance(parsed, (dict, list)):
        return True, json.dumps(parsed, ensure_ascii=False)
    return False, "强制参数必须是 JSON 对象或数组"
|
||
|
||
|
||
def format_ts(ts_value: Any) -> str:
    """Render a Unix timestamp as ``YYYY-MM-DD HH:MM:SS`` in the server's local time zone.

    Returns "-" for None, values that do not convert to int, non-positive
    timestamps, and timestamps outside the platform's representable date range.
    Those out-of-range values used to raise from ``fromtimestamp`` and break the
    admin page render.
    """
    if ts_value is None:
        return "-"
    try:
        ts_int = int(ts_value)
    except (TypeError, ValueError, OverflowError):
        return "-"
    if ts_int <= 0:
        return "-"
    try:
        moment = datetime.fromtimestamp(ts_int)
    except (OverflowError, OSError, ValueError):
        return "-"
    return moment.strftime("%Y-%m-%d %H:%M:%S")
|
||
|
||
|
||
def attach_route_headers(response: Response, upstream: dict[str, Any], tried_ids: list[int]) -> Response:
    """Add ``X-Route-*`` diagnostic headers describing which upstream served the request.

    ``X-Route-Tried-Ids`` lists the previously tried ids followed by the serving
    upstream's id, skipping non-positive ids. *response* is mutated and returned.
    """
    current_id = int(upstream.get("id") or 0)
    chain = [*tried_ids, current_id]
    positive_ids = (str(item) for item in chain if int(item) > 0)
    headers = response.headers
    headers["X-Route-Upstream-Id"] = str(current_id)
    headers["X-Route-Group"] = str(upstream.get("group_name") or "")
    headers["X-Route-Model"] = str(upstream.get("model") or "")
    headers["X-Route-Tried-Ids"] = ",".join(positive_ids)
    return response
|
||
|
||
|
||
def choose_upstream(conn: sqlite3.Connection, selector: dict[str, Any], tried_ids: list[int]) -> dict[str, Any] | None:
    """Reserve and return the first usable upstream for *selector*, or None.

    Candidates from *selector* are tried first, then those from its group
    fallback, if any. Ids in *tried_ids* are excluded. Each candidate is offered
    to reserve_upstream at most once, even if both selectors match it.
    """
    now_ts = utc_now_ts()
    fallback = selector_group_fallback(selector)
    candidate_selectors = [selector] if fallback is None else [selector, fallback]

    seen: set[int] = set()
    for current in candidate_selectors:
        for candidate_id in fetch_candidate_ids(conn, current, tried_ids, now_ts):
            if candidate_id in seen:
                continue
            seen.add(candidate_id)
            reserved = reserve_upstream(conn, candidate_id, now_ts)
            if reserved is not None:
                return reserved
    return None
|
||
|
||
|
||
@APP.get("/healthz")
def healthz() -> dict[str, str]:
    """Liveness probe: always returns ``{"status": "ok"}``."""
    return dict(status="ok")
|
||
|
||
|
||
@APP.get("/")
def index_page() -> Response:
    """Serve the static landing page.

    The page shows service status, whether the admin console is enabled, the
    API entry points, and links to /admin and /healthz.
    """
    admin_state = {True: "已开启", False: "未开启"}[bool(SETTINGS.admin_enabled)]
    document = f"""<!doctype html>
<html lang="zh-CN">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<title>OpenAI 路由服务</title>
<style>
body {{
margin: 0;
padding: 32px;
font-family: "Microsoft YaHei", "PingFang SC", "Segoe UI", sans-serif;
background: #f4f6fb;
color: #1c2432;
}}
.card {{
max-width: 840px;
margin: 0 auto;
background: #fff;
border: 1px solid #e2e8f0;
border-radius: 14px;
padding: 24px;
box-shadow: 0 8px 30px rgba(23, 35, 67, 0.08);
}}
h1 {{
margin: 0 0 12px;
font-size: 26px;
}}
p {{
margin: 0 0 10px;
color: #445268;
line-height: 1.6;
}}
.btn {{
display: inline-block;
margin-right: 12px;
margin-top: 14px;
padding: 10px 14px;
border-radius: 10px;
text-decoration: none;
border: 1px solid #c6d2ea;
color: #163a78;
background: #eef4ff;
font-weight: 600;
}}
</style>
</head>
<body>
<div class="card">
<h1>OpenAI 路由服务</h1>
<p>服务状态:正常</p>
<p>可视化管理:{admin_state}</p>
<p>接口入口:<code>/v1/chat/completions</code>、<code>/v1/models</code></p>
<a class="btn" href="/admin">进入可视化管理</a>
<a class="btn" href="/healthz">健康检查</a>
</div>
</body>
</html>"""
    return Response(document, content_type="text/html; charset=utf-8")
|
||
|
||
|
||
def limit_type_display(limit_type: str | None) -> str:
    """Return the Chinese display label for a quota-cycle unit (falls back to "天", day)."""
    labels = {
        "second": "秒",
        "minute": "分钟",
        "hour": "小时",
        "day": "天",
        "month": "月",
        "year": "年",
    }
    return labels.get(limit_type_alias(limit_type), "天")
|
||
|
||
|
||
def render_limit_type_options(selected: str | None) -> str:
    """Render the ``<option>`` elements for the quota-cycle unit select box.

    The option matching *selected*, after alias normalization, gets the
    ``selected`` attribute.
    """
    units = (
        ("second", "秒"),
        ("minute", "分钟"),
        ("hour", "小时"),
        ("day", "天"),
        ("month", "月"),
        ("year", "年"),
    )
    current = limit_type_alias(selected)
    return "".join(
        f"<option value='{value}'{' selected' if value == current else ''}>{label}</option>"
        for value, label in units
    )
|
||
|
||
|
||
@APP.get("/admin/login")
def admin_login() -> Response:
    """Render the admin login form.

    ``next`` (where to go after login) is carried in a hidden field, and an
    optional ``err`` message is shown HTML-escaped. Only same-site absolute
    paths are accepted for ``next``. Anything else falls back to "/admin".
    """
    if not SETTINGS.admin_enabled:
        abort(404)
    next_url = request.args.get("next", "/admin")
    # "//host/..." and "/\host/..." start with "/" but browsers resolve them to
    # another origin, so they must be rejected (open-redirect target).
    if not next_url.startswith("/") or next_url.startswith("//") or "\\" in next_url:
        next_url = "/admin"
    err = request.args.get("err", "")
    err_html = f"<p class='err'>{html.escape(err)}</p>" if err else ""
    page = f"""<!doctype html>
<html lang="zh-CN">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<title>管理员登录</title>
<style>
body {{
margin: 0;
min-height: 100vh;
display: grid;
place-items: center;
background: radial-gradient(circle at top, #f6fafc, #e8eef8);
font-family: "Microsoft YaHei", "PingFang SC", "Segoe UI", sans-serif;
color: #20304a;
}}
.card {{
width: min(460px, 92vw);
background: #ffffff;
border: 1px solid #dbe5f5;
border-radius: 14px;
box-shadow: 0 12px 32px rgba(23, 45, 88, 0.12);
padding: 24px;
}}
h1 {{
margin: 0 0 8px;
font-size: 24px;
}}
p {{
color: #4c5f7c;
}}
.err {{
color: #b42318;
font-weight: 700;
}}
input {{
width: 100%;
box-sizing: border-box;
border: 1px solid #c7d4eb;
border-radius: 10px;
padding: 10px 12px;
font-size: 14px;
margin-top: 6px;
}}
button {{
margin-top: 14px;
width: 100%;
border: 0;
border-radius: 10px;
padding: 11px 12px;
font-size: 15px;
font-weight: 700;
color: #fff;
background: #2563eb;
cursor: pointer;
}}
</style>
</head>
<body>
<div class="card">
<h1>管理后台登录</h1>
<p>请输入管理员口令后继续。</p>
{err_html}
<form method="post" action="/admin/login">
<input type="hidden" name="next" value="{html.escape(next_url)}" />
<label>管理员口令</label>
<input type="password" name="token" placeholder="请输入口令" required />
<button type="submit">进入后台</button>
</form>
</div>
</body>
</html>"""
    return Response(page, content_type="text/html; charset=utf-8")
|
||
|
||
|
||
def _admin_safe_next_url(raw: str | None, default: str = "/admin") -> str:
    """Return *raw* if it is a same-site absolute path, otherwise *default*.

    ``startswith("/")`` alone is not enough to prevent an open redirect:
    browsers treat ``//host/...`` as scheme-relative, normalize ``\\`` to
    ``/`` (so ``/\\host`` leaves the site too) and strip tab/newline
    characters (so ``/<TAB>/host`` collapses into ``//host``).
    """
    candidate = raw or ""
    if not candidate.startswith("/") or candidate.startswith("//"):
        return default
    if "\\" in candidate or any(ord(ch) < 0x20 or ord(ch) == 0x7F for ch in candidate):
        return default
    return candidate


@APP.post("/admin/login")
def admin_login_submit() -> Response:
    """Check the submitted admin token and mark the session as admin.

    On success the browser is redirected to the form's ``next`` path (only
    same-site paths are honoured, anything else falls back to ``/admin``);
    on failure it is sent back to the login page with an error message.
    Responds 404 when the admin UI is disabled.
    """
    if not SETTINGS.admin_enabled:
        abort(404)
    next_url = _admin_safe_next_url(request.form.get("next", "/admin"))
    token = request.form.get("token", "")
    if admin_token_ok(token):
        session["admin_ok"] = True
        return redirect(next_url)
    return redirect(url_for("admin_login", next=next_url, err="口令错误,请重试"))
|
||
|
||
|
||
@APP.get("/admin/logout")
def admin_logout() -> Response:
    """Drop the admin flag from the session and return to the login page.

    Responds 404 when the admin UI is disabled.
    """
    if SETTINGS.admin_enabled:
        session.pop("admin_ok", None)
        login_url = url_for("admin_login", next="/admin")
        return redirect(login_url)
    abort(404)
|
||
|
||
|
||
@APP.get("/admin")
def admin_console() -> Response:
    """Render the HTML admin dashboard.

    The page shows overall and last-24h request statistics, the 20 upstreams
    with the most logged requests, the 50 most recent calls, the "add
    upstream" and log-cleanup forms, and one inline edit form per configured
    upstream. ``msg`` / ``err`` query parameters (set by the admin POST
    handlers' redirects) are rendered as banners.
    """
    guard = ensure_admin_access()
    if guard is not None:
        return guard

    def esc(value: Any, default: str = "") -> str:
        # Escape a possibly-missing DB/query value for HTML text or attributes.
        return html.escape(str(value or default))

    conn = get_db()
    # api_key and force_parameter are required to pre-fill the edit forms
    # below. Without them the (required) API key input rendered empty, so
    # every save forced the admin to re-type the key, and saving silently
    # cleared the upstream's forced parameters.
    rows = conn.execute(
        """
        SELECT id, name, group_name, model, base_url, api_key, force_parameter,
               limit_type, limit_qty, limit_tokens,
               used_cycle_qty, used_cycle_tokens, used_all_qty, used_all_tokens,
               consecutive_failures, cooldown_until, last_used_at, enabled, updated_at
        FROM upstreams
        ORDER BY id ASC
        """
    ).fetchall()
    summary_row = conn.execute(
        """
        SELECT
            COUNT(*) AS req_total,
            COALESCE(SUM(CASE WHEN status_code = 200 THEN 1 ELSE 0 END), 0) AS req_success,
            COALESCE(SUM(CASE WHEN status_code IS NOT NULL AND status_code <> 200 THEN 1 ELSE 0 END), 0) AS req_fail,
            COALESCE(SUM(used_tokens), 0) AS token_total
        FROM logs
        """
    ).fetchone()
    day_since_ts = utc_now_ts() - 86400
    day_row = conn.execute(
        """
        SELECT
            COUNT(*) AS req_day,
            COALESCE(SUM(used_tokens), 0) AS token_day
        FROM logs
        WHERE request_at >= ?
        """,
        (day_since_ts,),
    ).fetchone()
    top_rows = conn.execute(
        """
        SELECT
            u.id AS upstream_id,
            u.name AS name,
            u.group_name AS group_name,
            u.model AS model,
            COUNT(l.id) AS req_total,
            COALESCE(SUM(CASE WHEN l.status_code = 200 THEN 1 ELSE 0 END), 0) AS req_success,
            COALESCE(SUM(CASE WHEN l.status_code IS NOT NULL AND l.status_code <> 200 THEN 1 ELSE 0 END), 0) AS req_fail,
            COALESCE(SUM(l.used_tokens), 0) AS token_total,
            MAX(l.request_at) AS last_at
        FROM upstreams u
        LEFT JOIN logs l ON l.upstream_id = u.id
        GROUP BY u.id, u.name, u.group_name, u.model
        ORDER BY req_total DESC, token_total DESC, u.id ASC
        LIMIT 20
        """
    ).fetchall()
    recent_rows = conn.execute(
        """
        SELECT
            l.id AS log_id,
            l.request_at AS request_at,
            l.status_code AS status_code,
            l.used_tokens AS used_tokens,
            u.id AS upstream_id,
            u.name AS name,
            u.group_name AS group_name,
            u.model AS model
        FROM logs l
        LEFT JOIN upstreams u ON u.id = l.upstream_id
        ORDER BY l.id DESC
        LIMIT 50
        """
    ).fetchall()

    msg = request.args.get("msg", "")
    err = request.args.get("err", "")
    banner_html = ""
    if msg:
        banner_html += f"<p class='msg'>{html.escape(msg)}</p>"
    if err:
        banner_html += f"<p class='err'>{html.escape(err)}</p>"

    summary = sqlite_row_to_dict(summary_row) or {}
    day_summary = sqlite_row_to_dict(day_row) or {}

    top_parts: list[str] = []
    for item in top_rows:
        row = sqlite_row_to_dict(item) or {}
        top_parts.append(
            "<tr>"
            f"<td>{row.get('upstream_id')}</td>"
            f"<td>{esc(row.get('name'))}</td>"
            f"<td>{esc(row.get('group_name'))}</td>"
            f"<td>{esc(row.get('model'))}</td>"
            f"<td>{row.get('req_total') or 0}</td>"
            f"<td>{row.get('req_success') or 0}</td>"
            f"<td>{row.get('req_fail') or 0}</td>"
            f"<td>{row.get('token_total') or 0}</td>"
            f"<td>{format_ts(row.get('last_at'))}</td>"
            "</tr>"
        )
    top_html = "".join(top_parts) if top_parts else "<tr><td colspan='9'>暂无统计数据</td></tr>"

    recent_parts: list[str] = []
    for item in recent_rows:
        row = sqlite_row_to_dict(item) or {}
        recent_parts.append(
            "<tr>"
            f"<td>{row.get('log_id') or '-'}</td>"
            f"<td>{format_ts(row.get('request_at'))}</td>"
            f"<td>{row.get('status_code') or '-'}</td>"
            f"<td>{row.get('used_tokens') or 0}</td>"
            f"<td>{row.get('upstream_id') or '-'}</td>"
            f"<td>{esc(row.get('name'))}</td>"
            f"<td>{esc(row.get('group_name'))}</td>"
            f"<td>{esc(row.get('model'))}</td>"
            "</tr>"
        )
    recent_html = "".join(recent_parts) if recent_parts else "<tr><td colspan='8'>暂无调用记录</td></tr>"

    body_rows: list[str] = []
    for row in rows:
        row_dict = sqlite_row_to_dict(row) or {}
        upstream_id = row_dict.get("id")
        limit_type = str(row_dict.get("limit_type") or "day")
        enabled_checked = "checked" if int(row_dict.get("enabled") or 0) == 1 else ""
        # The edit table has exactly two columns (id + form); the form cell
        # must not span extra columns or the header row gets phantom cells.
        body_rows.append(
            f"""
<tr>
  <td>{upstream_id}</td>
  <td>
    <form method="post" action="/admin/upstreams/{upstream_id}/update" class="row-form">
      <label>中文别名</label><input name="name" value="{esc(row_dict.get('name'))}" required />
      <label>分组编号</label><input name="group_name" value="{esc(row_dict.get('group_name'), '1')}" required />
      <label>上游模型名</label><input name="model" value="{esc(row_dict.get('model'))}" required />
      <label>上游地址</label><input name="base_url" value="{esc(row_dict.get('base_url'))}" required />
      <label>API 密钥</label><input name="api_key" value="{esc(row_dict.get('api_key'))}" required />
      <label>强制参数(JSON)</label><textarea name="force_parameter">{esc(row_dict.get('force_parameter'))}</textarea>
      <label>限额周期</label><select name="limit_type">{render_limit_type_options(limit_type)}</select>
      <label>请求限额</label><input name="limit_qty" value="{row_dict.get('limit_qty')}" />
      <label>Token 限额</label><input name="limit_tokens" value="{row_dict.get('limit_tokens')}" />
      <label class="check"><input type="checkbox" name="enabled" value="1" {enabled_checked} /> 启用</label>
      <div class="stats">
        周期请求: {row_dict.get('used_cycle_qty')} |
        周期Token: {row_dict.get('used_cycle_tokens')} |
        总请求: {row_dict.get('used_all_qty')} |
        总Token: {row_dict.get('used_all_tokens')} |
        连续失败: {row_dict.get('consecutive_failures')} |
        冷却到: {format_ts(row_dict.get('cooldown_until'))} |
        周期限额单位: {limit_type_display(limit_type)}
      </div>
      <button type="submit" class="btn-save">保存修改</button>
      <button type="submit" formaction="/admin/upstreams/{upstream_id}/test" formnovalidate class="btn-test">测试连通性</button>
      <button type="submit" formaction="/admin/upstreams/{upstream_id}/delete" formnovalidate class="btn-del" onclick="return confirm('确认删除这条上游吗?');">删除</button>
    </form>
  </td>
</tr>
"""
        )
    body_html = "".join(body_rows) if body_rows else "<tr><td colspan='2'>暂无上游配置</td></tr>"

    page = f"""<!doctype html>
<html lang="zh-CN">
<head>
  <meta charset="utf-8" />
  <meta name="viewport" content="width=device-width, initial-scale=1" />
  <title>可视化管理后台</title>
  <style>
    body {{ margin: 0; padding: 24px; font-family: "Microsoft YaHei", "PingFang SC", sans-serif; background: linear-gradient(180deg, #eef3ff 0%, #f7f9ff 48%, #f5f8ff 100%); color: #21304b; }}
    h1 {{ margin: 0 0 8px 0; font-size: 22px; }}
    .hint {{ margin: 0 0 14px 0; color: #4e5f7a; }}
    .msg {{ background: #ecfdf3; border: 1px solid #abefc6; color: #067647; padding: 10px 12px; border-radius: 10px; }}
    .err {{ background: #fef3f2; border: 1px solid #fecdca; color: #b42318; padding: 10px 12px; border-radius: 10px; }}
    .panel {{ border: 1px solid #d5e1f5; border-radius: 14px; background: #fff; padding: 16px; margin-bottom: 16px; box-shadow: 0 8px 20px rgba(32, 48, 75, 0.06); }}
    .stats-cards {{ display: grid; grid-template-columns: repeat(auto-fit, minmax(180px, 1fr)); gap: 10px; }}
    .card {{ border: 1px solid #d8e3f5; border-radius: 10px; padding: 10px 12px; background: #f8fbff; }}
    .card .num {{ margin-top: 6px; font-size: 20px; font-weight: 700; color: #15428b; }}
    .form-grid {{ display: grid; grid-template-columns: repeat(auto-fit, minmax(180px, 1fr)); gap: 8px; }}
    label {{ font-size: 12px; font-weight: 700; color: #486185; display: block; margin-top: 8px; }}
    input, textarea, select {{ width: 100%; box-sizing: border-box; padding: 8px 10px; border-radius: 8px; border: 1px solid #c8d6ee; margin-top: 4px; font-size: 13px; background: #fff; color: #21304b; }}
    textarea {{ min-height: 62px; resize: vertical; }}
    .check {{ display: flex; align-items: center; gap: 8px; margin-top: 18px; font-size: 13px; }}
    .check input {{ width: auto; margin: 0; }}
    .table-wrap {{ border: 1px solid #d5e1f5; border-radius: 10px; overflow: auto; background: #fff; box-shadow: inset 0 0 0 1px rgba(255,255,255,0.4); }}
    table {{ border-collapse: collapse; width: 100%; min-width: 1180px; }}
    th, td {{ border-bottom: 1px solid #edf2fa; border-right: 1px solid #edf2fa; padding: 10px 12px; font-size: 13px; text-align: left; vertical-align: top; word-break: break-word; }}
    th {{ position: sticky; top: 0; background: #eff5ff; z-index: 1; }}
    th:last-child, td:last-child {{ border-right: none; }}
    tr:nth-child(even) td {{ background: #fbfdff; }}
    tr:hover td {{ background: #f1f6ff; }}
    .row-form {{ display: grid; grid-template-columns: repeat(auto-fit, minmax(170px, 1fr)); gap: 6px 10px; }}
    .stats {{ grid-column: 1 / -1; color: #5b708f; font-size: 12px; margin-top: 4px; }}
    .btn-save, .btn-test, .btn-del, .btn-add, .btn-link, .btn-clean {{ border: 0; border-radius: 9px; padding: 8px 12px; font-weight: 700; cursor: pointer; margin-top: 12px; color: white; text-decoration: none; display: inline-block; width: fit-content; }}
    .btn-save {{ background: #2563eb; }}
    .btn-test {{ background: #0d9488; }}
    .btn-add {{ background: #0f766e; }}
    .btn-link {{ background: #475467; }}
    .btn-clean {{ background: #9f1239; }}
    .btn-del {{ background: #b42318; }}
    .btn-save:hover, .btn-test:hover, .btn-del:hover, .btn-add:hover, .btn-link:hover, .btn-clean:hover {{ filter: brightness(1.05); transform: translateY(-1px); transition: all .15s ease; }}
    .mini-form {{ display: flex; align-items: center; gap: 10px; flex-wrap: wrap; }}
    .mini-form input {{ width: 120px; margin-top: 0; }}
  </style>
</head>
<body>
  <h1>上游可视化管理</h1>
  <p class="hint">这里可以新增、编辑、启停、删除上游。建议给客户端使用中文别名(例如:通用聊天、快速模型)。</p>
  {banner_html}
  <div class="panel">
    <h2>总览统计</h2>
    <div class="stats-cards">
      <div class="card"><div>总请求数</div><div class="num">{summary.get('req_total', 0)}</div></div>
      <div class="card"><div>成功请求</div><div class="num">{summary.get('req_success', 0)}</div></div>
      <div class="card"><div>失败请求</div><div class="num">{summary.get('req_fail', 0)}</div></div>
      <div class="card"><div>总 Token</div><div class="num">{summary.get('token_total', 0)}</div></div>
      <div class="card"><div>近24小时请求</div><div class="num">{day_summary.get('req_day', 0)}</div></div>
      <div class="card"><div>近24小时 Token</div><div class="num">{day_summary.get('token_day', 0)}</div></div>
    </div>
  </div>
  <div class="panel">
    <h2>新增上游</h2>
    <form method="post" action="/admin/logs/cleanup" class="mini-form">
      <label style="margin:0;">日志清理(保留最近天数)</label>
      <input type="number" name="days" min="1" max="3650" value="7" />
      <button type="submit" class="btn-clean" onclick="return confirm('确认清理旧日志吗?');">清理日志</button>
    </form>
    <form method="post" action="/admin/upstreams/create" class="form-grid">
      <div><label>中文别名</label><input name="name" placeholder="例如:通用聊天" required /></div>
      <div><label>分组编号</label><input name="group_name" value="1" required /></div>
      <div><label>上游模型名</label><input name="model" placeholder="例如:gpt-4o-mini" required /></div>
      <div><label>上游地址</label><input name="base_url" placeholder="https://xxx.com/v1" required /></div>
      <div><label>API 密钥</label><input name="api_key" placeholder="sk-xxx" required /></div>
      <div><label>强制参数(JSON)</label><textarea name="force_parameter" placeholder='例如:{{"temperature":0.7}}'></textarea></div>
      <div><label>限额周期</label><select name="limit_type">{render_limit_type_options('day')}</select></div>
      <div><label>请求限额(0=不限)</label><input name="limit_qty" value="0" /></div>
      <div><label>Token 限额(0=不限)</label><input name="limit_tokens" value="0" /></div>
      <div><label class="check"><input type="checkbox" name="enabled" value="1" checked /> 启用</label></div>
      <div>
        <button type="submit" class="btn-add">新增上游</button>
        <a class="btn-link" href="/admin/stats">打开完整统计页</a>
        <a class="btn-link" href="/admin/logout">退出管理</a>
      </div>
    </form>
  </div>
  <div class="panel">
    <h2>上游统计 TOP 20</h2>
    <div class="table-wrap">
      <table>
        <thead>
          <tr>
            <th>编号</th><th>中文别名</th><th>分组</th><th>模型</th><th>总请求</th><th>成功</th><th>失败</th><th>总Token</th><th>最后调用</th>
          </tr>
        </thead>
        <tbody>{top_html}</tbody>
      </table>
    </div>
  </div>
  <div class="panel">
    <h2>最近 50 次调用(可直接看到命中上游)</h2>
    <div class="table-wrap">
      <table>
        <thead>
          <tr>
            <th>日志ID</th><th>时间</th><th>状态码</th><th>Token</th><th>上游ID</th><th>中文别名</th><th>分组</th><th>模型</th>
          </tr>
        </thead>
        <tbody>{recent_html}</tbody>
      </table>
    </div>
  </div>
  <div class="table-wrap">
    <table>
      <thead><tr><th style="width:80px;">编号</th><th>上游配置</th></tr></thead>
      <tbody>{body_html}</tbody>
    </table>
  </div>
</body>
</html>"""
    return Response(page, content_type="text/html; charset=utf-8")
|
||
|
||
|
||
@APP.get("/admin/stats")
def admin_stats() -> Response:
    """Render the detailed statistics page.

    Shows per-day request/token totals for log rows from the last 14 * 24h
    (days are bucketed with SQLite's ``'localtime'`` modifier; at most 14
    days, newest first) and the 100 most recent log rows whose status code
    is set and differs from 200.
    """
    guard = ensure_admin_access()
    if guard is not None:
        return guard

    db = get_db()
    window_start = utc_now_ts() - 14 * 86400
    daily_rows = db.execute(
        """
        SELECT
        strftime('%Y-%m-%d', datetime(request_at, 'unixepoch', 'localtime')) AS day_key,
        COUNT(*) AS req_total,
        COALESCE(SUM(CASE WHEN status_code = 200 THEN 1 ELSE 0 END), 0) AS req_success,
        COALESCE(SUM(CASE WHEN status_code IS NOT NULL AND status_code <> 200 THEN 1 ELSE 0 END), 0) AS req_fail,
        COALESCE(SUM(used_tokens), 0) AS token_total
        FROM logs
        WHERE request_at >= ?
        GROUP BY day_key
        ORDER BY day_key DESC
        LIMIT 14
        """,
        (window_start,),
    ).fetchall()
    error_rows = db.execute(
        """
        SELECT
        l.req_id AS req_id,
        l.status_code AS status_code,
        l.error_text AS error_text,
        l.finish_at AS finish_at,
        u.name AS name,
        u.group_name AS group_name,
        u.model AS model
        FROM logs l
        LEFT JOIN upstreams u ON u.id = l.upstream_id
        WHERE l.status_code IS NOT NULL AND l.status_code <> 200
        ORDER BY l.id DESC
        LIMIT 100
        """
    ).fetchall()

    daily = [sqlite_row_to_dict(item) or {} for item in daily_rows]
    failures = [sqlite_row_to_dict(item) or {} for item in error_rows]

    # Every generated row is non-empty, so an empty join means "no rows".
    day_html = "".join(
        "<tr>"
        f"<td>{html.escape(str(d.get('day_key') or '-'))}</td>"
        f"<td>{d.get('req_total') or 0}</td>"
        f"<td>{d.get('req_success') or 0}</td>"
        f"<td>{d.get('req_fail') or 0}</td>"
        f"<td>{d.get('token_total') or 0}</td>"
        "</tr>"
        for d in daily
    ) or "<tr><td colspan='5'>暂无数据</td></tr>"

    err_html = "".join(
        "<tr>"
        f"<td>{html.escape(str(f.get('req_id') or ''))}</td>"
        f"<td>{f.get('status_code') or '-'}</td>"
        f"<td>{html.escape(str(f.get('name') or ''))}</td>"
        f"<td>{html.escape(str(f.get('group_name') or ''))}</td>"
        f"<td>{html.escape(str(f.get('model') or ''))}</td>"
        f"<td>{format_ts(f.get('finish_at'))}</td>"
        f"<td>{html.escape(str(f.get('error_text') or ''))}</td>"
        "</tr>"
        for f in failures
    ) or "<tr><td colspan='7'>暂无失败日志</td></tr>"

    page = f"""<!doctype html>
<html lang="zh-CN">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<title>完整统计</title>
<style>
body {{
margin: 0;
padding: 24px;
font-family: "Microsoft YaHei", "PingFang SC", sans-serif;
background: linear-gradient(180deg, #eef3ff 0%, #f7f9ff 48%, #f5f8ff 100%);
color: #21304b;
}}
h1 {{
margin: 0 0 10px;
font-size: 24px;
}}
.link {{
display: inline-block;
margin-bottom: 14px;
text-decoration: none;
color: #0b4ecf;
font-weight: 700;
background: #eef4ff;
border: 1px solid #cfe0ff;
border-radius: 10px;
padding: 8px 12px;
}}
.panel {{
border: 1px solid #d5e1f5;
border-radius: 12px;
background: #fff;
padding: 14px;
margin-bottom: 14px;
box-shadow: 0 8px 20px rgba(32, 48, 75, 0.06);
}}
table {{
border-collapse: collapse;
width: 100%;
}}
th, td {{
border-bottom: 1px solid #edf2fa;
border-right: 1px solid #edf2fa;
padding: 8px 10px;
font-size: 13px;
text-align: left;
word-break: break-word;
vertical-align: top;
}}
th {{
background: #eff5ff;
position: sticky;
top: 0;
}}
th:last-child, td:last-child {{
border-right: none;
}}
.table-wrap {{
border: 1px solid #d5e1f5;
border-radius: 10px;
overflow: auto;
background: #fff;
}}
tr:nth-child(even) td {{
background: #fbfdff;
}}
tr:hover td {{
background: #f1f6ff;
}}
</style>
</head>
<body>
<h1>完整统计</h1>
<a class="link" href="/admin">返回管理首页</a>
<div class="panel">
<h2>近 14 天按天统计</h2>
<div class="table-wrap">
<table>
<thead><tr><th>日期</th><th>总请求</th><th>成功</th><th>失败</th><th>总 Token</th></tr></thead>
<tbody>{day_html}</tbody>
</table>
</div>
</div>
<div class="panel">
<h2>最近 100 条失败日志</h2>
<div class="table-wrap">
<table>
<thead><tr><th>请求ID</th><th>状态码</th><th>中文别名</th><th>分组</th><th>模型</th><th>时间</th><th>错误</th></tr></thead>
<tbody>{err_html}</tbody>
</table>
</div>
</div>
</body>
</html>"""
    return Response(page, content_type="text/html; charset=utf-8")
|
||
|
||
|
||
@APP.post("/admin/logs/cleanup")
def admin_logs_cleanup() -> Response:
    """Delete request logs older than the requested number of days.

    ``days`` comes from the dashboard form. It falls back to 7 when
    ``to_form_int`` cannot parse it and is clamped to the range 1..3650.
    Redirects back to the dashboard with the number of deleted rows.
    """
    guard = ensure_admin_access()
    if guard is not None:
        return guard

    requested_days = to_form_int(request.form.get("days"), 7)
    keep_days = min(max(requested_days, 1), 3650)
    cutoff = utc_now_ts() - keep_days * 86400
    removed = get_db().execute("DELETE FROM logs WHERE request_at < ?", (cutoff,)).rowcount
    notice = f"日志清理完成:删除 {int(removed or 0)} 条(保留最近 {keep_days} 天)"
    return redirect(url_for("admin_console", msg=notice))
|
||
|
||
|
||
@APP.post("/admin/upstreams/<int:upstream_id>/test")
def admin_upstream_test(upstream_id: int) -> Response:
    """Probe one upstream and report the result on the dashboard.

    First tries GET on the URL built by ``build_models_meta``. If that does
    not return 200, it falls back to a small POST built by
    ``build_admin_probe_payload`` / ``build_request_meta``, so that upstreams
    without a working models endpoint are not reported as broken. The outcome
    is passed back to the dashboard as a ``msg`` or ``err`` query parameter.
    """
    guard = ensure_admin_access()
    if guard is not None:
        return guard

    def back_to_console(**banner: str) -> Response:
        return redirect(url_for("admin_console", **banner))

    record = get_db().execute("SELECT * FROM upstreams WHERE id = ?", (upstream_id,)).fetchone()
    upstream = sqlite_row_to_dict(record)
    if upstream is None:
        return back_to_console(err=f"连通性测试失败:编号 {upstream_id} 不存在")

    models_url, models_headers = build_models_meta(upstream)
    models_resp = None
    models_error = ""
    try:
        models_resp = HTTP_SESSION.get(
            models_url,
            headers=models_headers,
            timeout=(SETTINGS.connect_timeout_sec, min(20, SETTINGS.read_timeout_sec)),
        )
    except requests.RequestException as exc:
        models_error = str(exc)

    if models_resp is not None and models_resp.status_code == 200:
        return back_to_console(msg=f"编号 {upstream_id} 连通性正常:/models 返回 200")

    # The models endpoint failed: confirm with a tiny chat completion before
    # declaring the upstream unreachable.
    probe_payload = build_admin_probe_payload(upstream)
    probe_url, probe_headers = build_request_meta(upstream)
    try:
        probe_resp = HTTP_SESSION.post(
            probe_url,
            headers=probe_headers,
            json=probe_payload,
            timeout=(SETTINGS.connect_timeout_sec, min(30, SETTINGS.read_timeout_sec)),
        )
    except requests.RequestException as exc:
        reason = models_error or str(exc)
        return back_to_console(err=f"编号 {upstream_id} 连通性失败:{reason}")

    if probe_resp.status_code == 200:
        return back_to_console(msg=f"编号 {upstream_id} 连通性正常:/chat/completions 返回 200")

    snippet = probe_resp.text[:180].replace("\n", " ").replace("\r", " ")
    models_part = "/models网络异常" if models_resp is None else f"/models={models_resp.status_code}"
    detail = f"{models_part}, /chat={probe_resp.status_code}, detail={snippet}"
    return back_to_console(err=f"编号 {upstream_id} 连通性失败:{detail}")
|
||
|
||
|
||
@APP.post("/admin/upstreams/create")
def admin_upstream_create() -> Response:
    """Insert a new upstream from the dashboard's "add upstream" form.

    Text fields are stripped. Limits are parsed with ``to_form_int`` and
    floored at 0. ``enabled`` accepts "1"/"on"/"true" (case-insensitive). The
    usage cycle starts at ``cycle_start_ts(limit_type, now)``. Missing required
    fields or an invalid forced-parameter JSON redirect back with an error.
    """
    guard = ensure_admin_access()
    if guard is not None:
        return guard

    form = request.form

    def text_field(key: str, fallback: str = "") -> str:
        return (form.get(key) or fallback).strip()

    name = text_field("name")
    group_name = text_field("group_name", "1")
    model = text_field("model")
    base_url = text_field("base_url")
    api_key = text_field("api_key")
    limit_type = limit_type_alias(form.get("limit_type"))
    limit_qty = max(0, to_form_int(form.get("limit_qty"), 0))
    limit_tokens = max(0, to_form_int(form.get("limit_tokens"), 0))
    enabled = int(text_field("enabled").lower() in {"1", "on", "true"})
    force_ok, force_value = parse_force_parameter_text(form.get("force_parameter", ""))

    if not all((name, model, base_url, api_key)):
        return redirect(url_for("admin_console", err="新增失败:请把必填项填写完整"))
    if not force_ok:
        # On failure the second element is used as the error text when present.
        return redirect(url_for("admin_console", err=force_value or "新增失败:强制参数格式错误"))

    now_ts = utc_now_ts()
    conn = get_db()
    values = (
        name,
        group_name or "1",
        base_url,
        api_key,
        model,
        force_value,
        limit_type,
        limit_qty,
        limit_tokens,
        enabled,
        cycle_start_ts(limit_type, now_ts),
        now_ts,
        now_ts,
    )
    conn.execute(
        """
        INSERT INTO upstreams(
        name, group_name, base_url, api_key, model, force_parameter,
        limit_type, limit_qty, limit_tokens, enabled,
        cycle_started_at, created_at, updated_at
        )
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """,
        values,
    )
    return redirect(url_for("admin_console", msg="新增成功"))
|
||
|
||
|
||
@APP.post("/admin/upstreams/<int:upstream_id>/update")
def admin_upstream_update(upstream_id: int) -> Response:
    """Save the inline edit form of one upstream.

    All editable columns are overwritten from the form (parsed like the
    create handler). ``cycle_started_at`` is reset to
    ``cycle_start_ts(new_limit_type, now)`` only when the limit period
    changes; otherwise the stored value is kept. Usage counters are not
    modified here. Errors redirect back to the dashboard with a message.
    """
    guard = ensure_admin_access()
    if guard is not None:
        return guard

    form = request.form
    name = (form.get("name") or "").strip()
    group_name = (form.get("group_name") or "1").strip()
    model = (form.get("model") or "").strip()
    base_url = (form.get("base_url") or "").strip()
    api_key = (form.get("api_key") or "").strip()
    limit_type = limit_type_alias(form.get("limit_type"))
    limit_qty = max(0, to_form_int(form.get("limit_qty"), 0))
    limit_tokens = max(0, to_form_int(form.get("limit_tokens"), 0))
    enabled = 1 if (form.get("enabled") or "").strip().lower() in {"1", "on", "true"} else 0
    ok_force, force_parameter = parse_force_parameter_text(form.get("force_parameter", ""))

    if not name or not model or not base_url or not api_key:
        return redirect(url_for("admin_console", err=f"保存失败:编号 {upstream_id} 的必填项不能为空"))
    if not ok_force:
        return redirect(url_for("admin_console", err=f"保存失败:编号 {upstream_id} 的强制参数格式错误"))

    conn = get_db()
    current = conn.execute("SELECT id, limit_type FROM upstreams WHERE id = ?", (upstream_id,)).fetchone()
    if current is None:
        return redirect(url_for("admin_console", err=f"保存失败:编号 {upstream_id} 不存在"))

    now_ts = utc_now_ts()
    old_limit_type = limit_type_alias(str(current["limit_type"] or "day"))
    reset_cycle = 1 if old_limit_type != limit_type else 0
    new_cycle_start = cycle_start_ts(limit_type, now_ts) if reset_cycle else None

    # A single static statement covers both cases: the CASE keeps the stored
    # cycle_started_at unless the period changed. (CASE rather than COALESCE
    # so a NULL new cycle start is still written when a reset is requested.)
    conn.execute(
        """
        UPDATE upstreams
        SET name = ?, group_name = ?, base_url = ?, api_key = ?, model = ?, force_parameter = ?,
            limit_type = ?, limit_qty = ?, limit_tokens = ?, enabled = ?,
            cycle_started_at = CASE WHEN ? THEN ? ELSE cycle_started_at END,
            updated_at = ?
        WHERE id = ?
        """,
        (
            name,
            group_name or "1",
            base_url,
            api_key,
            model,
            force_parameter,
            limit_type,
            limit_qty,
            limit_tokens,
            enabled,
            reset_cycle,
            new_cycle_start,
            now_ts,
            upstream_id,
        ),
    )
    return redirect(url_for("admin_console", msg=f"编号 {upstream_id} 已保存"))
|
||
|
||
|
||
@APP.post("/admin/upstreams/<int:upstream_id>/delete")
def admin_upstream_delete(upstream_id: int) -> Response:
    """Delete one upstream together with its request logs, atomically.

    Both deletes run in one ``BEGIN IMMEDIATE`` transaction. If no upstream
    row matched, the transaction is rolled back and a "不存在" error is shown
    instead of a false success message. Database errors roll back and are
    reported on the dashboard.
    """
    guard = ensure_admin_access()
    if guard is not None:
        return guard

    conn = get_db()
    try:
        conn.execute("BEGIN IMMEDIATE")
        conn.execute("DELETE FROM logs WHERE upstream_id = ?", (upstream_id,))
        deleted = conn.execute("DELETE FROM upstreams WHERE id = ?", (upstream_id,)).rowcount
        if not deleted:
            conn.execute("ROLLBACK")
            return redirect(url_for("admin_console", err=f"删除失败:编号 {upstream_id} 不存在"))
        conn.execute("COMMIT")
    except sqlite3.Error as exc:
        try:
            conn.execute("ROLLBACK")
        except sqlite3.Error:
            # No transaction active (e.g. BEGIN itself failed); nothing to undo.
            pass
        return redirect(url_for("admin_console", err=f"删除失败:{exc}"))
    return redirect(url_for("admin_console", msg=f"编号 {upstream_id} 已删除(关联日志已同步清理)"))
|
||
|
||
|
||
@APP.get("/v1/models")
def list_models() -> Response | tuple[Response, int]:
    """Return the OpenAI-style model list built from enabled upstreams.

    Each distinct non-blank model name appears once. ``created`` is the
    earliest ``created_at`` among upstreams serving it, or the current time
    when that is NULL. Requests failing ``auth_header_ok`` get
    ``unauthorized_response()``.
    """
    if not auth_header_ok(request.headers.get("Authorization", "")):
        return unauthorized_response()

    records = get_db().execute(
        """
        SELECT model, MIN(created_at) AS created_at
        FROM upstreams
        WHERE enabled = 1 AND model IS NOT NULL AND TRIM(model) <> ''
        GROUP BY model
        ORDER BY model ASC
        """
    ).fetchall()

    fallback_created = utc_now_ts()
    pairs = ((str(rec["model"]).strip(), rec["created_at"]) for rec in records)
    models = [
        {
            "id": model_id,
            "object": "model",
            "created": fallback_created if created is None else int(created),
            "owned_by": "openai-proxy",
        }
        for model_id, created in pairs
        if model_id
    ]
    return jsonify({"object": "list", "data": models})
|
||
|
||
|
||
def forward_non_stream_request(
    conn: sqlite3.Connection,
    upstream: dict[str, Any],
    request_url: str,
    request_headers: dict[str, str],
    payload: dict[str, Any],
    log_id: str,
    tried_ids: list[int],
) -> tuple[bool, Response]:
    """POST one non-streaming request to *upstream*, log it and relay the reply.

    Returns ``(flag, response)``. The flag is False for a network error, or
    for a non-200 status that ``is_retryable_failure`` reports as retryable.
    It is True otherwise (presumably False lets the caller try another
    upstream; confirm against the caller). Before returning, the upstream's
    success/failure counters and log row *log_id* are updated. Upstream
    bodies are relayed unchanged with their Content-Type.
    """
    upstream_id = int(upstream["id"])
    try:
        result = HTTP_SESSION.post(
            request_url,
            headers=request_headers,
            json=payload,
            timeout=(SETTINGS.connect_timeout_sec, SETTINGS.read_timeout_sec),
        )
    except requests.RequestException as exc:
        error_text = str(exc)
        mark_upstream_failure(conn, upstream_id)
        finish_request_log(conn, log_id, status_code=599, used_tokens=0, error_text=error_text)
        return False, attach_route_headers(upstream_error_response(error_text, status_code=502), upstream, tried_ids)

    status_code = int(result.status_code)
    content_type = result.headers.get("Content-Type") or "application/json; charset=utf-8"
    body_bytes = result.content
    body_text = result.text

    if status_code != 200:
        mark_upstream_failure(conn, upstream_id)
        finish_request_log(
            conn,
            log_id,
            status_code=status_code,
            used_tokens=0,
            finish_text=None,
            error_text=body_text,
        )
        response = attach_route_headers(Response(body_bytes, status=status_code, content_type=content_type), upstream, tried_ids)
        return (not is_retryable_failure(status_code, body_text)), response

    used_tokens = 0
    finish_text = ""
    try:
        result_json = result.json()
        # A 200 body can be valid JSON that is not an object (a list, string,
        # number...). Calling .get() on it used to raise AttributeError and
        # turn a successful upstream call into a proxy 500.
        if isinstance(result_json, dict):
            used_tokens = extract_total_tokens(result_json.get("usage"))
            finish_text = extract_finish_text_from_response(result_json)
    except ValueError:
        # Not JSON at all: fall back to logging the raw body below.
        pass

    mark_upstream_success(conn, upstream_id, used_tokens)
    finish_request_log(
        conn,
        log_id,
        status_code=status_code,
        used_tokens=used_tokens,
        finish_text=finish_text or body_text,
        error_text=None,
    )
    response = attach_route_headers(Response(body_bytes, status=status_code, content_type=content_type), upstream, tried_ids)
    return True, response
|
||
|
||
|
||
def finalize_stream_outcome(
    upstream_id: int,
    log_id: str,
    used_tokens: int,
    finish_text: str,
    stream_error: str | None,
) -> None:
    """Record the result of a finished streaming request.

    A truthy *stream_error* marks the upstream as failed and closes the log
    row with status 599 and that error. Otherwise the upstream is marked
    successful and the log row gets status 200. This is best-effort: any
    exception is logged and swallowed.
    """
    # When a streamed response finishes, the request context may already be
    # torn down, so this uses its own connection rather than the per-request one.
    conn: sqlite3.Connection | None = None
    try:
        conn = open_sqlite_conn(SETTINGS.sqlite_path)
        failed = bool(stream_error)
        if failed:
            mark_upstream_failure(conn, upstream_id)
        else:
            mark_upstream_success(conn, upstream_id, used_tokens)
        finish_request_log(
            conn,
            log_id,
            status_code=599 if failed else 200,
            used_tokens=used_tokens,
            finish_text=finish_text,
            error_text=stream_error if failed else None,
        )
    except Exception:
        LOGGER.exception("流式请求收尾写库失败 upstream_id=%s log_id=%s", upstream_id, log_id)
    finally:
        if conn is not None:
            conn.close()
|
||
|
||
|
||
def stream_response_generator(
    upstream_id: int,
    log_id: str,
    upstream_resp: requests.Response,
) -> Iterator[str]:
    r"""Relay an upstream SSE response to the client and record its outcome.

    Each upstream line is re-emitted as follows:

    * blank line -> ``"\n"``;
    * ``data: [DONE]`` -> ``"data: [DONE]\n\n"``;
    * ``data: <JSON object>`` -> re-serialized as ``"data: <json>\n\n"`` after
      the token count (``extract_total_tokens`` on its ``usage``) and the delta
      text (``extract_delta_text_from_chunk``) have been collected;
    * any other line (non-object ``data:`` payloads, ``event:``/``id:`` fields,
      comments) -> forwarded unchanged followed by ``"\n"``.

    When iteration stops (upstream finished, an exception was raised, or the
    consumer closed the generator), the upstream response is closed and the
    outcome is persisted with ``finalize_stream_outcome``. Only an exception
    raised while relaying is recorded as a failure. A consumer closing the
    generator early is recorded as success with whatever was collected.

    Args:
        upstream_id: id of the upstream the request was routed to.
        log_id: id of the request-log row to finish.
        upstream_resp: the streaming (``stream=True``) upstream response; it is
            always closed by this generator.

    Yields:
        Text fragments to write to the client.

    Raises:
        Re-raises any exception hit while reading or relaying, after recording it.
    """
    # SSE streams are UTF-8 by definition. requests, however, decodes "text/*"
    # responses without a charset as ISO-8859-1, and yields undecoded bytes when
    # there is no Content-Type at all. The first garbles non-ASCII output; the
    # second breaks the str handling below. Force UTF-8 unless a charset was declared.
    if "charset=" not in (upstream_resp.headers.get("Content-Type") or "").lower():
        upstream_resp.encoding = "utf-8"

    used_tokens = 0
    finish_parts: list[str] = []
    stream_error: str | None = None
    try:
        # NOTE(review): without a delimiter, iter_lines() splits on
        # str.splitlines() boundaries, which also include U+2028/U+2029/U+0085;
        # confirm upstreams escape those characters inside JSON payloads.
        for line in upstream_resp.iter_lines(decode_unicode=True):
            if line is None:
                continue
            if line == "":
                yield "\n"
                continue
            if not line.startswith("data:"):
                yield line + "\n"
                continue

            data = line[len("data:") :].lstrip()
            if data == "[DONE]":
                yield "data: [DONE]\n\n"
                continue

            try:
                chunk_dict = json.loads(data)
            except json.JSONDecodeError:
                chunk_dict = None
            if not isinstance(chunk_dict, dict):
                # Malformed JSON or a non-object value (e.g. a bare number):
                # nothing to account for, so relay the line untouched.
                yield line + "\n"
                continue

            usage_tokens = extract_total_tokens(chunk_dict.get("usage"))
            if usage_tokens > 0:
                # Usage normally arrives once, near the end; keep the latest positive value.
                used_tokens = usage_tokens
            delta_text = extract_delta_text_from_chunk(chunk_dict)
            if delta_text:
                finish_parts.append(delta_text)
            yield "data: " + json.dumps(chunk_dict, ensure_ascii=False) + "\n\n"
    except Exception as exc:
        # Any failure while relaying means the stream did not complete, so it must
        # not be logged as a success. str(exc) can be empty, which would read as
        # "no error" downstream, so fall back to the exception class name.
        stream_error = str(exc) or type(exc).__name__
        raise
    finally:
        upstream_resp.close()
        finish_text = "".join(finish_parts)
        finalize_stream_outcome(upstream_id, log_id, used_tokens, finish_text, stream_error)
|
||
|
||
|
||
def _stream_transport_failure(
    conn: sqlite3.Connection,
    upstream: dict[str, Any],
    log_id: str,
    tried_ids: list[int],
    error_text: str,
) -> Response:
    """Mark ``upstream`` failed, close its request log with 599 and build a 502 reply."""
    mark_upstream_failure(conn, int(upstream["id"]))
    finish_request_log(
        conn,
        log_id,
        status_code=599,
        used_tokens=0,
        finish_text=None,
        error_text=error_text,
    )
    return attach_route_headers(upstream_error_response(error_text, status_code=502), upstream, tried_ids)


def forward_stream_request(
    conn: sqlite3.Connection,
    upstream: dict[str, Any],
    request_url: str,
    request_headers: dict[str, str],
    payload: dict[str, Any],
    log_id: str,
    tried_ids: list[int],
) -> tuple[bool, Response]:
    """Forward a streaming chat-completion request to a single upstream.

    Args:
        conn: request-scoped SQLite connection used for bookkeeping.
        upstream: upstream row; its ``"id"`` is used for health accounting.
        request_url: full upstream URL to POST to.
        request_headers: headers to send upstream.
        payload: JSON body to send upstream.
        log_id: id of the request-log row opened for this attempt.
        tried_ids: upstream ids already tried, forwarded to the route headers.

    Returns:
        ``(should_stop, response)``. ``should_stop`` is False when the caller may
        fail over to another upstream: a transport error (reported as a 502), or a
        non-200 status that ``is_retryable_failure`` deems retryable. It is True
        otherwise. On a 200 the response streams the upstream body through
        ``stream_response_generator``, which finishes the request log itself.
    """
    upstream_id = int(upstream["id"])
    try:
        upstream_resp = HTTP_SESSION.post(
            request_url,
            headers=request_headers,
            json=payload,
            stream=True,
            timeout=(SETTINGS.connect_timeout_sec, SETTINGS.stream_read_timeout_sec),
        )
    except requests.RequestException as exc:
        return False, _stream_transport_failure(conn, upstream, log_id, tried_ids, str(exc))

    status_code = int(upstream_resp.status_code)
    if status_code != 200:
        # With stream=True the error body has not been read yet. Reading it can
        # still fail (dropped connection, read timeout), and the response must be
        # released either way.
        try:
            body_bytes = upstream_resp.content
            body_text = upstream_resp.text
        except requests.RequestException as exc:
            return False, _stream_transport_failure(conn, upstream, log_id, tried_ids, str(exc))
        finally:
            upstream_resp.close()
        content_type = upstream_resp.headers.get("Content-Type") or "application/json; charset=utf-8"
        mark_upstream_failure(conn, upstream_id)
        finish_request_log(
            conn,
            log_id,
            status_code=status_code,
            used_tokens=0,
            finish_text=None,
            error_text=body_text,
        )
        response = attach_route_headers(Response(body_bytes, status=status_code, content_type=content_type), upstream, tried_ids)
        return (not is_retryable_failure(status_code, body_text)), response

    content_type = upstream_resp.headers.get("Content-Type") or "text/event-stream; charset=utf-8"
    stream_iter = stream_with_context(stream_response_generator(upstream_id, log_id, upstream_resp))
    return True, attach_route_headers(Response(stream_iter, status=200, content_type=content_type), upstream, tried_ids)
|
||
|
||
|
||
@APP.post("/v1/chat/completions")
def proxy_chat_completions() -> Response | tuple[Response, int]:
    """Proxy an OpenAI-compatible chat completion, failing over between upstreams.

    The request is authenticated, then its JSON body is validated and the
    ``model`` field is turned into a routing selector. Up to
    ``SETTINGS.max_failover_attempts`` upstreams are tried; each is chosen by
    ``choose_upstream`` excluding those already tried. The merged payload's
    ``stream`` flag selects streaming or non-streaming forwarding.

    Returns the first response a forwarder marks as final. If every attempt was
    retryable, the last such response is returned. If no upstream could be
    chosen at all, a 429 ``routing_error`` is returned.
    """
    auth = request.headers.get("Authorization", "")
    if not auth_header_ok(auth):
        return unauthorized_response()

    user_post = request.get_json(silent=True)
    if not isinstance(user_post, dict):
        return jsonify({"error": {"message": "Invalid JSON body"}}), 400

    # str(None) would be "None" and get routed as a model name; treat an explicit
    # JSON null exactly like a missing "model" field.
    raw_model = user_post.get("model")
    selector = parse_selector("" if raw_model is None else str(raw_model))
    if selector is None:
        return jsonify({"error": {"message": "model is required"}}), 400

    conn = get_db()
    tried_ids: list[int] = []
    last_retryable_response: Response | None = None

    for _ in range(SETTINGS.max_failover_attempts):
        upstream = choose_upstream(conn, selector, tried_ids)
        if upstream is None:
            break
        upstream_id = int(upstream["id"])
        payload = merge_force_parameter(user_post, upstream)
        request_url, request_headers = build_request_meta(upstream)
        log_id = create_request_log(conn, upstream_id, payload)

        # Decide on the merged payload so upstream-forced parameters apply.
        stream_mode = bool(payload.get("stream", False))
        if stream_mode:
            should_stop, response = forward_stream_request(
                conn, upstream, request_url, request_headers, payload, log_id, tried_ids
            )
        else:
            should_stop, response = forward_non_stream_request(
                conn, upstream, request_url, request_headers, payload, log_id, tried_ids
            )
        if should_stop:
            return response
        tried_ids.append(upstream_id)
        last_retryable_response = response

    if last_retryable_response is not None:
        return last_retryable_response

    no_available_message = collect_no_available_reason(conn, selector)
    return jsonify(
        {
            "error": {
                "message": no_available_message,
                "type": "routing_error",
                "code": "no_available_model_endpoint",
            }
        }
    ), 429
|
||
|
||
|
||
def setup_logging() -> None:
    """Configure root logging at INFO level with a timestamped one-line format.

    Delegates to ``logging.basicConfig``, so it is a no-op when the root
    logger already has handlers.
    """
    log_format = "%(asctime)s %(levelname)s %(name)s %(message)s"
    logging.basicConfig(format=log_format, level=logging.INFO)
|
||
|
||
|
||
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the proxy's command-line options.

    Args:
        argv: arguments to parse; ``None`` (the default) means ``sys.argv[1:]``,
            as with ``argparse.ArgumentParser.parse_args``. Passing an explicit
            list lets callers and tests parse without touching ``sys.argv``.

    Returns:
        Namespace with ``init_db`` (bool): only initialize the schema and exit.
    """
    parser = argparse.ArgumentParser(description="OpenAI compatible proxy with SQLite routing")
    parser.add_argument("--init-db", action="store_true", help="Only initialize SQLite schema and exit")
    return parser.parse_args(argv)
|
||
|
||
|
||
def startup_checks() -> None:
    """Warn about insecure configuration before the server starts.

    Currently this only flags a configured auth token (``SETTINGS.expected_auth``)
    that is still the built-in ``DEFAULT_AUTH`` placeholder.
    """
    using_default_token = SETTINGS.expected_auth == DEFAULT_AUTH
    if not using_default_token:
        return
    LOGGER.warning(
        "OPENAI_PROXY_AUTH is using default value '%s'. Please set a strong bearer token before exposing the service.",
        DEFAULT_AUTH,
    )
|
||
|
||
|
||
def main() -> None:
    """Command-line entry point.

    Configures logging, runs the startup configuration checks, parses CLI
    arguments and initializes the SQLite schema. With ``--init-db`` it stops
    there. Otherwise it serves ``APP`` via ``APP.run`` (threaded) on the
    configured host and port.
    """
    setup_logging()
    startup_checks()
    init_only = parse_args().init_db
    db_path = SETTINGS.sqlite_path
    init_db(db_path)
    if init_only:
        LOGGER.info("SQLite schema initialized: %s", db_path)
        return
    server_options = dict(
        host=SETTINGS.listen_host,
        port=SETTINGS.listen_port,
        debug=SETTINGS.debug,
        threaded=True,
    )
    APP.run(**server_options)


# Only start the server when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|