import base64
import hashlib
import logging
from typing import Optional, List, Tuple

from sqlalchemy.orm import Session
from fastapi import HTTPException, status
from cryptography.fernet import Fernet

from app.models.metadata import DataSource
from app.schemas.datasource import DataSourceCreate, DataSourceUpdate, DataSourceTest
from app.core.config import settings

logger = logging.getLogger(__name__)


def _get_fernet() -> Fernet:
    """Build the Fernet cipher used to encrypt stored datasource passwords.

    If ``settings.DB_ENCRYPTION_KEY`` is set it is used directly as the Fernet
    key. Otherwise a key is derived deterministically from
    ``settings.SECRET_KEY`` (SHA-256 digest, URL-safe base64 encoded) for
    backward compatibility.

    Raises:
        ValueError: from ``Fernet`` if the resulting key is not a valid
            32-byte URL-safe base64 key.
    """
    if settings.DB_ENCRYPTION_KEY:
        key = settings.DB_ENCRYPTION_KEY.encode()
    else:
        logger.warning(
            "DB_ENCRYPTION_KEY is not set. Deriving encryption key from SECRET_KEY. "
            "Please set DB_ENCRYPTION_KEY explicitly via environment variable or .env file."
        )
        # The derivation must stay deterministic: changing it would make every
        # previously stored ciphertext undecryptable.
        digest = hashlib.sha256(settings.SECRET_KEY.encode()).digest()
        key = base64.urlsafe_b64encode(digest)
    return Fernet(key)


# Built once at import time, so a misconfigured key fails fast on startup.
_fernet = _get_fernet()


def _encrypt_password(password: str) -> str:
    """Encrypt a plaintext password and return the Fernet token as text."""
    return _fernet.encrypt(password.encode()).decode()


def _decrypt_password(encrypted: str) -> str:
    """Decrypt a Fernet token produced by ``_encrypt_password``.

    Raises:
        cryptography.fernet.InvalidToken: if the token was produced with a
            different key or has been tampered with.
    """
    return _fernet.decrypt(encrypted.encode()).decode()


def get_datasource(db: Session, source_id: int) -> Optional[DataSource]:
    """Return the datasource with primary key ``source_id``, or ``None``."""
    return db.query(DataSource).filter(DataSource.id == source_id).first()


def list_datasources(
    db: Session, keyword: Optional[str] = None, page: int = 1, page_size: int = 20
) -> Tuple[List[DataSource], int]:
    """Return one page of datasources and the total matching count.

    Args:
        db: Active SQLAlchemy session.
        keyword: Optional substring matched against ``name`` or ``host``.
            LIKE wildcards (``%``, ``_``) in it are matched literally.
        page: 1-based page number; values below 1 are treated as 1.
        page_size: Maximum items per page; negative values are treated as 0.

    Returns:
        ``(items, total)`` where ``total`` counts all matches, not just this page.
    """
    query = db.query(DataSource)
    if keyword:
        # autoescape=True escapes '%' and '_' so user input cannot act as a
        # LIKE wildcard.
        query = query.filter(
            DataSource.name.contains(keyword, autoescape=True)
            | DataSource.host.contains(keyword, autoescape=True)
        )
    total = query.count()

    # Guard against negative OFFSET/LIMIT, which several databases reject.
    page = max(page, 1)
    page_size = max(page_size, 0)
    # An explicit ORDER BY is required for OFFSET/LIMIT pagination to be stable
    # across requests.
    items = (
        query.order_by(DataSource.id)
        .offset((page - 1) * page_size)
        .limit(page_size)
        .all()
    )
    return items, total


def create_datasource(db: Session, obj_in: DataSourceCreate, user_id: int) -> DataSource:
    """Persist a new datasource created by ``user_id`` and return it.

    The plaintext password (if any) is stored only in encrypted form;
    ``status`` defaults to ``"active"`` when not provided.
    """
    db_obj = DataSource(
        name=obj_in.name,
        source_type=obj_in.source_type,
        host=obj_in.host,
        port=obj_in.port,
        database_name=obj_in.database_name,
        username=obj_in.username,
        encrypted_password=_encrypt_password(obj_in.password) if obj_in.password else None,
        extra_params=obj_in.extra_params,
        status=obj_in.status or "active",
        dept_id=obj_in.dept_id,
        created_by=user_id,
    )
    db.add(db_obj)
    db.commit()
    db.refresh(db_obj)
    return db_obj


def update_datasource(db: Session, db_obj: DataSource, obj_in: DataSourceUpdate) -> DataSource:
    """Apply the fields explicitly set on ``obj_in`` to ``db_obj`` and commit.

    A non-empty ``password`` replaces ``encrypted_password``; an empty or
    ``None`` password is ignored, so the stored password is kept.
    """
    update_data = obj_in.model_dump(exclude_unset=True)
    # Always remove the plaintext password so it is never set on the model.
    new_password = update_data.pop("password", None)
    if new_password:
        update_data["encrypted_password"] = _encrypt_password(new_password)
    for field, value in update_data.items():
        setattr(db_obj, field, value)
    db.commit()
    db.refresh(db_obj)
    return db_obj


def delete_datasource(db: Session, source_id: int) -> None:
    """Delete the datasource ``source_id``.

    Raises:
        HTTPException: 404 if no datasource with that id exists.
    """
    db_obj = get_datasource(db, source_id)
    if not db_obj:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="数据源不存在")
    db.delete(db_obj)
    db.commit()


# SQLAlchemy dialect+driver names per supported source type.
_SQLALCHEMY_DRIVERS = {
    "mysql": "mysql+pymysql",
    "postgresql": "postgresql+psycopg2",
    "oracle": "oracle+cx_oracle",
    "sqlserver": "mssql+pymssql",
}

# Port used when the request does not specify one.
_DEFAULT_PORTS = {
    "mysql": 3306,
    "postgresql": 5432,
    "oracle": 1521,
    "sqlserver": 1433,
}

# Connect-timeout keyword arguments, which differ per DBAPI driver.
# NOTE(review): cx_Oracle.connect() has no connect-timeout keyword, so Oracle
# gets none here; TODO consider a DSN-level timeout if this becomes a problem.
_CONNECT_TIMEOUT_ARGS = {
    "mysql": {"connect_timeout": 5},
    "postgresql": {"connect_timeout": 5},
    "sqlserver": {"login_timeout": 5},
}

# Oracle requires a FROM clause, so a bare "SELECT 1" fails there.
_PING_SQL = {
    "oracle": "SELECT 1 FROM DUAL",
}


def test_connection(obj_in: DataSourceTest) -> dict:
    """Try to connect to the described datasource and run a trivial query.

    Dameng ("dm") is not actually contacted; a simulated success is returned.
    Unknown ``source_type`` values are passed to SQLAlchemy verbatim as the
    driver name.

    Returns:
        ``{"success": bool, "message": str}``. Connection errors are reported
        in ``message`` rather than raised.
    """
    from sqlalchemy import create_engine, text
    from sqlalchemy.engine import URL

    source_type = obj_in.source_type
    if source_type == "dm":
        # For the MVP, the Dameng connection test is mocked.
        return {"success": True, "message": "达梦数据库连接测试通过(模拟)"}

    host = obj_in.host or "localhost"
    port = obj_in.port or _DEFAULT_PORTS.get(source_type)
    engine = None
    try:
        # URL.create escapes credentials, so passwords containing '@', ':',
        # '/' etc. are passed to the driver intact instead of corrupting the URL.
        url = URL.create(
            drivername=_SQLALCHEMY_DRIVERS.get(source_type, source_type),
            username=obj_in.username or None,
            password=obj_in.password or None,
            host=host,
            port=int(port) if port is not None else None,
            database=obj_in.database_name or None,
        )
        engine = create_engine(
            url, connect_args=_CONNECT_TIMEOUT_ARGS.get(source_type, {})
        )
        with engine.connect() as conn:
            conn.execute(text(_PING_SQL.get(source_type, "SELECT 1")))
        return {"success": True, "message": "连接测试通过"}
    except Exception as e:
        # Boundary handler: any failure becomes a user-facing result.
        logger.warning(
            "Datasource connection test failed (%s at %s:%s): %s",
            source_type, host, port, e,
        )
        return {"success": False, "message": f"连接失败: {str(e)}"}
    finally:
        # A throwaway engine must release its pooled connection explicitly.
        if engine is not None:
            engine.dispose()