Files
hiderfong 6d70520e79 feat: 全量功能模块开发与集成测试修复
- 新增后端模块:Alert、APIAsset、Compliance、Lineage、Masking、Risk、SchemaChange、Unstructured、Watermark
- 新增前端模块页面与API接口
- 新增Alembic迁移脚本(002-014)覆盖全量业务表
- 新增测试数据生成脚本与集成测试脚本
- 修复metadata模型JSON类型导入缺失导致启动失败的问题
- 修复前端Alert/APIAsset页面request模块路径错误
- 更新docker-compose与开发计划文档
2026-04-25 08:51:38 +08:00

144 lines
4.9 KiB
Python

import base64
import hashlib
import logging
from typing import Optional, List, Tuple
from sqlalchemy.orm import Session
from fastapi import HTTPException, status
from cryptography.fernet import Fernet
from app.models.metadata import DataSource
from app.schemas.datasource import DataSourceCreate, DataSourceUpdate, DataSourceTest
from app.core.config import settings
# Module-level logger; used by _get_fernet() to warn when the encryption key
# has to be derived from SECRET_KEY.
logger = logging.getLogger(__name__)
def _get_fernet() -> Fernet:
    """Build the Fernet cipher used to encrypt stored datasource passwords.

    The explicit ``DB_ENCRYPTION_KEY`` setting wins when it is non-empty and is
    used verbatim as the Fernet key. Otherwise a key is derived deterministically
    from ``SECRET_KEY`` (urlsafe-base64 of its SHA-256 digest), so that values
    encrypted earlier stay decryptable. A warning is logged in that case.

    Returns:
        A ready-to-use ``Fernet`` instance.
    """
    configured_key = settings.DB_ENCRYPTION_KEY
    if not configured_key:
        logger.warning(
            "DB_ENCRYPTION_KEY is not set. Deriving encryption key from SECRET_KEY. "
            "Please set DB_ENCRYPTION_KEY explicitly via environment variable or .env file."
        )
        # Fernet keys are 32 bytes in urlsafe-base64 form; SHA-256 yields exactly 32 bytes.
        secret_digest = hashlib.sha256(settings.SECRET_KEY.encode()).digest()
        return Fernet(base64.urlsafe_b64encode(secret_digest))
    return Fernet(configured_key.encode())
# Single Fernet cipher shared by _encrypt_password/_decrypt_password.
# NOTE(review): this runs at import time, so any exception raised by
# _get_fernet() (e.g. if the configured key is rejected) surfaces as an import
# failure of this module.
_fernet = _get_fernet()
def _encrypt_password(password: str) -> str:
    """Encrypt a plaintext password with the module cipher.

    The password is UTF-8 encoded, encrypted, and the resulting token is
    returned as text suitable for storing in ``encrypted_password``.
    """
    token_bytes = _fernet.encrypt(password.encode())
    return token_bytes.decode()
def _decrypt_password(encrypted: str) -> str:
    """Reverse ``_encrypt_password``: decrypt a stored token back to plaintext.

    Any error raised by the cipher (for example, a token produced under a
    different key) propagates to the caller unchanged.
    """
    plain_bytes = _fernet.decrypt(encrypted.encode())
    return plain_bytes.decode()
def get_datasource(db: Session, source_id: int) -> Optional[DataSource]:
    """Fetch a single datasource by primary key, or ``None`` if it does not exist."""
    matches = db.query(DataSource).filter_by(id=source_id)
    return matches.first()
def list_datasources(
    db: Session, keyword: Optional[str] = None, page: int = 1, page_size: int = 20
) -> Tuple[List[DataSource], int]:
    """Return one page of datasources together with the total match count.

    Args:
        db: Active SQLAlchemy session.
        keyword: Optional substring matched against ``name`` or ``host``.
            LIKE wildcards (``%``, ``_``) inside it are matched literally.
        page: 1-based page number; values below 1 are treated as 1.
        page_size: Maximum number of rows per page; values below 1 are
            treated as 1.

    Returns:
        ``(items, total)``, where ``total`` counts every matching row
        regardless of paging.
    """
    # A negative OFFSET/LIMIT is rejected by databases such as PostgreSQL.
    page = max(page, 1)
    page_size = max(page_size, 1)

    query = db.query(DataSource)
    if keyword:
        # autoescape=True keeps user-supplied '%' / '_' from acting as wildcards.
        query = query.filter(
            DataSource.name.contains(keyword, autoescape=True)
            | DataSource.host.contains(keyword, autoescape=True)
        )
    total = query.count()
    # Without a deterministic ORDER BY the database may return rows in any
    # order, so consecutive LIMIT/OFFSET pages could overlap or skip rows.
    items = (
        query.order_by(DataSource.id)
        .offset((page - 1) * page_size)
        .limit(page_size)
        .all()
    )
    return items, total
def create_datasource(db: Session, obj_in: DataSourceCreate, user_id: int) -> DataSource:
    """Insert a new datasource and return the refreshed ORM object.

    The plaintext password is stored only in encrypted form. An empty or
    missing password stores ``None``. A falsy ``status`` falls back to
    ``"active"``. ``user_id`` is recorded as ``created_by``.
    """
    plain_password = obj_in.password
    column_values = dict(
        name=obj_in.name,
        source_type=obj_in.source_type,
        host=obj_in.host,
        port=obj_in.port,
        database_name=obj_in.database_name,
        username=obj_in.username,
        encrypted_password=_encrypt_password(plain_password) if plain_password else None,
        extra_params=obj_in.extra_params,
        status=obj_in.status or "active",
        dept_id=obj_in.dept_id,
        created_by=user_id,
    )
    new_row = DataSource(**column_values)
    db.add(new_row)
    db.commit()
    db.refresh(new_row)
    return new_row
def update_datasource(db: Session, db_obj: DataSource, obj_in: DataSourceUpdate) -> DataSource:
    """Apply the explicitly-set fields of ``obj_in`` to ``db_obj`` and commit.

    A non-empty ``password`` replaces ``encrypted_password`` (encrypted
    first). An empty or ``None`` password is ignored, keeping the stored one.

    Returns:
        The same ``db_obj``, refreshed from the database.
    """
    changes = obj_in.model_dump(exclude_unset=True)
    # Never write the plaintext password column; translate it to the encrypted field.
    new_password = changes.pop("password", None)
    if new_password:
        changes["encrypted_password"] = _encrypt_password(new_password)
    for attr_name, attr_value in changes.items():
        setattr(db_obj, attr_name, attr_value)
    db.commit()
    db.refresh(db_obj)
    return db_obj
def delete_datasource(db: Session, source_id: int) -> None:
    """Delete the datasource identified by ``source_id`` and commit.

    Raises:
        HTTPException: 404 Not Found when no datasource has that id.
    """
    target = get_datasource(db, source_id)
    if target:
        db.delete(target)
        db.commit()
        return
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="数据源不存在")
def test_connection(obj_in: DataSourceTest) -> dict:
    """Try to connect to the described datasource and run a trivial ping query.

    Dameng (``source_type == "dm"``) is not actually contacted; a simulated
    success is returned. For every other type an SQLAlchemy engine is built:
    a missing host becomes ``localhost`` and a missing port becomes the type's
    conventional port. A ping query is then executed, and the engine is always
    disposed afterwards.

    Args:
        obj_in: Connection parameters (``source_type``, ``host``, ``port``,
            ``database_name``, ``username``, ``password``).

    Returns:
        ``{"success": bool, "message": str}``. Any exception raised while
        building the URL, creating the engine, connecting or querying is
        caught and reported as ``success=False`` with the exception text.
    """
    from sqlalchemy import create_engine, text
    from sqlalchemy.engine import URL

    driver_map = {
        "mysql": "mysql+pymysql",
        "postgresql": "postgresql+psycopg2",
        "oracle": "oracle+cx_oracle",
        "sqlserver": "mssql+pymssql",
    }
    # Conventional server ports; unknown types keep the historical 5432 fallback.
    default_ports = {"mysql": 3306, "postgresql": 5432, "oracle": 1521, "sqlserver": 1433}
    # Each DBAPI names its connect timeout differently. An unknown keyword makes
    # connect() raise TypeError, so only pass one where the driver accepts it
    # (cx_Oracle has no such keyword).
    timeout_args = {
        "mysql": {"connect_timeout": 5},
        "postgresql": {"connect_timeout": 5},
        "sqlserver": {"login_timeout": 5},
    }

    source_type = obj_in.source_type
    if source_type == "dm":
        # For MVP, Dameng is not actually contacted: report a simulated success.
        return {"success": True, "message": "达梦数据库连接测试通过(模拟)"}

    engine = None
    try:
        # URL.create escapes special characters ('@', ':', '/', '#') in the
        # credentials, which plain f-string interpolation would corrupt.
        url = URL.create(
            drivername=driver_map.get(source_type, source_type),
            username=obj_in.username or None,
            password=obj_in.password or None,
            host=obj_in.host or "localhost",
            port=int(obj_in.port or default_ports.get(source_type, 5432)),
            database=obj_in.database_name or None,
        )
        engine = create_engine(
            url,
            pool_pre_ping=True,
            connect_args=timeout_args.get(source_type, {}),
        )
        # Oracle does not allow a SELECT without a FROM clause.
        ping_sql = "SELECT 1 FROM DUAL" if source_type == "oracle" else "SELECT 1"
        with engine.connect() as conn:
            conn.execute(text(ping_sql))
        return {"success": True, "message": "连接测试通过"}
    except Exception as e:
        # Deliberate boundary catch: any failure is reported to the caller, not raised.
        return {"success": False, "message": f"连接失败: {str(e)}"}
    finally:
        # Release the engine's pooled connections; otherwise each test leaks one.
        if engine is not None:
            engine.dispose()