122 lines
4.2 KiB
Python
122 lines
4.2 KiB
Python
from typing import Optional, List, Tuple
|
|
from sqlalchemy.orm import Session
|
|
from fastapi import HTTPException, status
|
|
from cryptography.fernet import Fernet
|
|
|
|
from app.models.metadata import DataSource
|
|
from app.schemas.datasource import DataSourceCreate, DataSourceUpdate, DataSourceTest
|
|
from app.core.config import settings
|
|
|
|
# Symmetric (Fernet) encryption for stored data-source passwords.
# In production, prefer a proper KMS / secrets manager.


def _build_fernet() -> Fernet:
    """Return the Fernet instance used to encrypt/decrypt data-source passwords.

    The key must be stable across restarts and identical in every worker
    process; otherwise passwords encrypted earlier (or by another worker) can
    no longer be decrypted. Key resolution order:

    1. ``settings.DATASOURCE_ENCRYPTION_KEY`` -- a urlsafe-base64 32-byte
       Fernet key, used as-is.
    2. A key derived from ``settings.SECRET_KEY`` with SHA-256. Rotating
       SECRET_KEY therefore makes existing stored passwords undecryptable.
    3. A random per-process key (the previous behaviour), with a warning.

    NOTE(review): both settings names are looked up with ``getattr`` because
    they may not exist in ``app.core.config`` yet -- confirm/add them there.
    """
    import base64
    import hashlib
    import logging

    def _plain(value):
        # Unwrap pydantic SecretStr-style values; str() on them yields a mask.
        return value.get_secret_value() if hasattr(value, "get_secret_value") else value

    explicit_key = _plain(getattr(settings, "DATASOURCE_ENCRYPTION_KEY", None))
    if explicit_key:
        return Fernet(explicit_key)

    secret_key = _plain(getattr(settings, "SECRET_KEY", None))
    if secret_key:
        # Domain-separate this derived key from any other use of SECRET_KEY.
        digest = hashlib.sha256(
            b"datasource-password-key:" + str(secret_key).encode()
        ).digest()
        return Fernet(base64.urlsafe_b64encode(digest))

    logging.getLogger(__name__).warning(
        "No DATASOURCE_ENCRYPTION_KEY or SECRET_KEY configured; using a random "
        "per-process key. Stored data-source passwords will not survive a restart."
    )
    return Fernet(Fernet.generate_key())


_fernet = _build_fernet()
|
|
|
|
|
|
def _encrypt_password(password: str) -> str:
    """Encrypt a plaintext password with the module Fernet key.

    Returns the Fernet token decoded to ``str`` so it can be stored in a
    text column.
    """
    raw_bytes = password.encode()
    token = _fernet.encrypt(raw_bytes)
    return token.decode()
|
|
|
|
|
|
def _decrypt_password(encrypted: str) -> str:
    """Reverse ``_encrypt_password``: turn a stored token back into plaintext.

    Errors raised by ``Fernet.decrypt`` (e.g. a token produced with a
    different key) propagate to the caller unchanged.
    """
    token_bytes = encrypted.encode()
    plain_bytes = _fernet.decrypt(token_bytes)
    return plain_bytes.decode()
|
|
|
|
|
|
def get_datasource(db: Session, source_id: int) -> Optional[DataSource]:
    """Look up a data source by primary key.

    Returns the matching ``DataSource`` row, or ``None`` when no row has
    ``id == source_id``.
    """
    matches = db.query(DataSource).filter_by(id=source_id)
    return matches.first()
|
|
|
|
|
|
def list_datasources(
    db: Session, keyword: Optional[str] = None, page: int = 1, page_size: int = 20
) -> Tuple[List[DataSource], int]:
    """Return one page of data sources plus the total number of matches.

    Args:
        db: Active SQLAlchemy session.
        keyword: Optional substring matched against ``name`` or ``host``.
            ``%`` and ``_`` in the keyword are matched literally, not as
            LIKE wildcards.
        page: 1-based page number; values below 1 are treated as 1.
        page_size: Rows per page; values below 1 are treated as 1.

    Returns:
        ``(items, total)`` where ``items`` is the requested page (ordered by
        ``id``) and ``total`` counts all matching rows, not just the page.
    """
    # Guard against a negative OFFSET / LIMIT, which most databases reject.
    page = max(page, 1)
    page_size = max(page_size, 1)

    query = db.query(DataSource)
    if keyword:
        # autoescape=True so user-supplied "%" / "_" are not LIKE wildcards.
        query = query.filter(
            DataSource.name.contains(keyword, autoescape=True)
            | DataSource.host.contains(keyword, autoescape=True)
        )

    total = query.count()
    # LIMIT/OFFSET pagination is only stable with a deterministic ORDER BY.
    items = (
        query.order_by(DataSource.id)
        .offset((page - 1) * page_size)
        .limit(page_size)
        .all()
    )
    return items, total
|
|
|
|
|
|
def create_datasource(db: Session, obj_in: DataSourceCreate, user_id: int) -> DataSource:
    """Insert a new data source created by ``user_id`` and return the stored row.

    A non-empty plaintext password is encrypted before storage; an empty or
    missing one is stored as ``None``. A falsy ``status`` falls back to
    ``"active"``. The session is committed and the new instance refreshed.
    """
    plain_password = obj_in.password
    columns = dict(
        name=obj_in.name,
        source_type=obj_in.source_type,
        host=obj_in.host,
        port=obj_in.port,
        database_name=obj_in.database_name,
        username=obj_in.username,
        encrypted_password=_encrypt_password(plain_password) if plain_password else None,
        extra_params=obj_in.extra_params,
        status=obj_in.status or "active",
        dept_id=obj_in.dept_id,
        created_by=user_id,
    )
    record = DataSource(**columns)

    db.add(record)
    db.commit()
    db.refresh(record)
    return record
|
|
|
|
|
|
def update_datasource(db: Session, db_obj: DataSource, obj_in: DataSourceUpdate) -> DataSource:
    """Apply the explicitly-set fields of ``obj_in`` to ``db_obj`` and commit.

    Only fields the caller actually set (``exclude_unset=True``) are written.
    A non-empty ``password`` is encrypted into ``encrypted_password``. An empty
    or ``None`` password is discarded, leaving the stored password unchanged.
    Returns the refreshed ``db_obj``.
    """
    changes = obj_in.model_dump(exclude_unset=True)

    # Never write the plaintext "password" key onto the ORM object.
    new_password = changes.pop("password", None)
    if new_password:
        changes["encrypted_password"] = _encrypt_password(new_password)

    for attr_name in changes:
        setattr(db_obj, attr_name, changes[attr_name])

    db.commit()
    db.refresh(db_obj)
    return db_obj
|
|
|
|
|
|
def delete_datasource(db: Session, source_id: int) -> None:
    """Delete the data source with primary key ``source_id`` and commit.

    Raises:
        HTTPException: 404 when no data source with that id exists.
    """
    doomed = get_datasource(db, source_id)
    if doomed:
        db.delete(doomed)
        db.commit()
        return
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="数据源不存在")
|
|
|
|
|
|
def test_connection(obj_in: DataSourceTest) -> dict:
    """Check whether the data source described by ``obj_in`` is reachable.

    Opens a single connection through SQLAlchemy and runs a trivial query.
    Dameng (``source_type == "dm"``) is not contacted; a simulated success is
    returned for it.

    Args:
        obj_in: Connection parameters (``source_type``, ``host``, ``port``,
            ``database_name``, ``username``, ``password``).

    Returns:
        ``{"success": bool, "message": str}``. Any error while building the
        URL, creating the engine or connecting is reported as
        ``success=False`` with the exception text; nothing is raised.
    """
    source_type = obj_in.source_type
    if source_type == "dm":
        # For MVP, mock test for Dameng.
        return {"success": True, "message": "达梦数据库连接测试通过(模拟)"}

    from sqlalchemy import create_engine, text
    from sqlalchemy.engine import URL

    driver_map = {
        "mysql": "mysql+pymysql",
        "postgresql": "postgresql+psycopg2",
        "oracle": "oracle+cx_oracle",
        "sqlserver": "mssql+pymssql",
    }
    # Used only when obj_in.port is not given; unknown types use the driver default.
    default_ports = {"mysql": 3306, "postgresql": 5432, "oracle": 1521, "sqlserver": 1433}
    # Each DBAPI names its connect-timeout keyword differently, and cx_Oracle's
    # connect() has none; an unknown keyword would make every test fail.
    timeout_args = {
        "mysql": {"connect_timeout": 5},
        "postgresql": {"connect_timeout": 5},
        "oracle": {},
        "sqlserver": {"login_timeout": 5},
    }
    # Oracle does not allow a SELECT without a FROM clause.
    probe_sql = "SELECT 1 FROM DUAL" if source_type == "oracle" else "SELECT 1"

    driver = driver_map.get(source_type, source_type)
    port = obj_in.port or default_ports.get(source_type)

    engine = None
    try:
        # URL.create escapes special characters ("@", ":", "/", "#", ...) in
        # the credentials, which a hand-formatted URL string does not.
        url = URL.create(
            driver,
            username=obj_in.username or None,
            password=obj_in.password or None,
            host=obj_in.host or "localhost",
            port=int(port) if port else None,
            database=obj_in.database_name or None,
        )
        engine = create_engine(
            url, connect_args=timeout_args.get(source_type, {"connect_timeout": 5})
        )
        with engine.connect() as conn:
            conn.execute(text(probe_sql))
        return {"success": True, "message": "连接测试通过"}
    except Exception as e:
        # Top-level boundary: report any failure to the caller instead of raising.
        return {"success": False, "message": f"连接失败: {str(e)}"}
    finally:
        # The engine is throwaway; release its pooled connections.
        if engine is not None:
            engine.dispose()
|