diff --git a/src/vitals/core/database.py b/src/vitals/core/database.py index 12a42c9..5187800 100644 --- a/src/vitals/core/database.py +++ b/src/vitals/core/database.py @@ -1,40 +1,61 @@ -"""SQLite 数据库操作""" +"""数据库操作 - 支持 MySQL""" import json import os -import sqlite3 from contextlib import contextmanager from datetime import date, time, datetime, timedelta from pathlib import Path from typing import Optional +import mysql.connector +from mysql.connector import pooling + from .models import Exercise, Meal, Sleep, Weight, UserConfig, User, Reading, Invite -def get_db_path() -> Path: - """获取数据库路径(支持环境变量配置)""" - # 优先使用环境变量 - env_path = os.environ.get("VITALS_DB_PATH") - if env_path: - db_path = Path(env_path) - db_path.parent.mkdir(parents=True, exist_ok=True) - return db_path +# 数据库连接池(全局) +_connection_pool = None - # 默认路径 - db_dir = Path.home() / ".vitals" - db_dir.mkdir(exist_ok=True) - return db_dir / "vitals.db" + +def get_mysql_config() -> dict: + """获取 MySQL 配置""" + return { + "host": os.environ.get("MYSQL_HOST", "localhost"), + "port": int(os.environ.get("MYSQL_PORT", "3306")), + "user": os.environ.get("MYSQL_USER", "vitals"), + "password": os.environ.get("MYSQL_PASSWORD", ""), + "database": os.environ.get("MYSQL_DATABASE", "vitals"), + } + + +def init_connection_pool(): + """初始化数据库连接池""" + global _connection_pool + if _connection_pool is None: + config = get_mysql_config() + _connection_pool = pooling.MySQLConnectionPool( + pool_name="vitals_pool", + pool_size=5, + pool_reset_session=True, + **config + ) + return _connection_pool @contextmanager def get_connection(): """获取数据库连接""" - conn = sqlite3.connect(get_db_path()) - conn.row_factory = sqlite3.Row + pool = init_connection_pool() + conn = pool.get_connection() try: - yield conn + cursor = conn.cursor(dictionary=True) + yield conn, cursor conn.commit() + except Exception: + conn.rollback() + raise finally: + cursor.close() conn.close()