from typing import Optional, List, Tuple from sqlalchemy.orm import Session from fastapi import HTTPException, status from app.models.metadata import DataSource, Database, DataTable, DataColumn from app.models.schema_change import SchemaChangeLog from app.services.datasource_service import get_datasource, _decrypt_password def _log_schema_change(db: Session, source_id: int, change_type: str, database_id: int = None, table_id: int = None, column_id: int = None, old_value: str = None, new_value: str = None): log = SchemaChangeLog( source_id=source_id, database_id=database_id, table_id=table_id, column_id=column_id, change_type=change_type, old_value=old_value, new_value=new_value, ) db.add(log) def get_database(db: Session, db_id: int) -> Optional[Database]: return db.query(Database).filter(Database.id == db_id).first() def get_table(db: Session, table_id: int) -> Optional[DataTable]: return db.query(DataTable).filter(DataTable.id == table_id).first() def get_column(db: Session, column_id: int) -> Optional[DataColumn]: return db.query(DataColumn).filter(DataColumn.id == column_id).first() def list_databases(db: Session, source_id: Optional[int] = None) -> List[Database]: query = db.query(Database).filter(Database.is_deleted == False) if source_id: query = query.filter(Database.source_id == source_id) return query.all() def list_tables(db: Session, database_id: Optional[int] = None, keyword: Optional[str] = None) -> Tuple[List[DataTable], int]: query = db.query(DataTable).filter(DataTable.is_deleted == False) if database_id: query = query.filter(DataTable.database_id == database_id) if keyword: query = query.filter( (DataTable.name.contains(keyword)) | (DataTable.comment.contains(keyword)) ) return query.all(), query.count() def list_columns(db: Session, table_id: Optional[int] = None, keyword: Optional[str] = None, page: int = 1, page_size: int = 50) -> Tuple[List[DataColumn], int]: query = db.query(DataColumn).filter(DataColumn.is_deleted == False) if table_id: query = query.filter(DataColumn.table_id == table_id) if keyword: query = query.filter( (DataColumn.name.contains(keyword)) | (DataColumn.comment.contains(keyword)) ) total = query.count() items = query.offset((page - 1) * page_size).limit(page_size).all() return items, total def build_tree(db: Session, source_id: Optional[int] = None, include_deleted: bool = False) -> List[dict]: sources = db.query(DataSource) if source_id: sources = sources.filter(DataSource.id == source_id) sources = sources.all() result = [] for s in sources: source_node = { "id": s.id, "name": s.name, "type": "source", "children": [], "meta": {"source_type": s.source_type, "status": s.status}, } for d in s.databases: if not include_deleted and d.is_deleted: continue db_node = { "id": d.id, "name": d.name, "type": "database", "children": [], "meta": {"charset": d.charset, "table_count": d.table_count, "is_deleted": d.is_deleted}, } for t in d.tables: if not include_deleted and t.is_deleted: continue table_node = { "id": t.id, "name": t.name, "type": "table", "children": [], "meta": {"comment": t.comment, "row_count": t.row_count, "column_count": t.column_count, "is_deleted": t.is_deleted}, } db_node["children"].append(table_node) source_node["children"].append(db_node) result.append(source_node) return result def _compute_checksum(data: dict) -> str: import hashlib, json payload = json.dumps(data, sort_keys=True, ensure_ascii=False, default=str) return hashlib.sha256(payload.encode()).hexdigest()[:32] def sync_metadata(db: Session, source_id: int, user_id: int) -> dict: from sqlalchemy import create_engine, inspect, text import json from datetime import datetime source = get_datasource(db, source_id) if not source: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="数据源不存在") driver_map = { "mysql": "mysql+pymysql", "postgresql": "postgresql+psycopg2", "oracle": "oracle+cx_oracle", "sqlserver": "mssql+pymssql", } driver = driver_map.get(source.source_type, source.source_type) if source.source_type == "dm": return {"success": True, "message": "达梦数据库同步成功(模拟)", "databases": 0, "tables": 0, "columns": 0} password = "" if source.encrypted_password: try: password = _decrypt_password(source.encrypted_password) except Exception: pass try: url = f"{driver}://{source.username}:{password}@{source.host}:{source.port}/{source.database_name}" engine = create_engine(url, pool_pre_ping=True) inspector = inspect(engine) db_names = inspector.get_schema_names() or [source.database_name] scan_time = datetime.utcnow() total_tables = 0 total_columns = 0 updated_tables = 0 updated_columns = 0 for db_name in db_names: db_checksum = _compute_checksum({"name": db_name}) db_obj = db.query(Database).filter( Database.source_id == source.id, Database.name == db_name ).first() if not db_obj: db_obj = Database(source_id=source.id, name=db_name, checksum=db_checksum, last_scanned_at=scan_time) db.add(db_obj) else: db_obj.checksum = db_checksum db_obj.last_scanned_at = scan_time db_obj.is_deleted = False db_obj.deleted_at = None table_names = inspector.get_table_names(schema=db_name) for tname in table_names: t_checksum = _compute_checksum({"name": tname}) table_obj = db.query(DataTable).filter( DataTable.database_id == db_obj.id, DataTable.name == tname ).first() if not table_obj: table_obj = DataTable(database_id=db_obj.id, name=tname, checksum=t_checksum, last_scanned_at=scan_time) db.add(table_obj) _log_schema_change(db, source.id, "add_table", database_id=db_obj.id, table_id=table_obj.id, new_value=tname) else: if table_obj.checksum != t_checksum: table_obj.checksum = t_checksum updated_tables += 1 table_obj.last_scanned_at = scan_time table_obj.is_deleted = False table_obj.deleted_at = None columns = inspector.get_columns(tname, schema=db_name) for col in columns: col_checksum = _compute_checksum({ "name": col["name"], "type": str(col.get("type", "")), "max_length": col.get("max_length"), "comment": col.get("comment"), "nullable": col.get("nullable", True), }) col_obj = db.query(DataColumn).filter( DataColumn.table_id == table_obj.id, DataColumn.name == col["name"] ).first() if not col_obj: sample = None try: with engine.connect() as conn: result = conn.execute(text(f'SELECT "{col["name"]}" FROM "{db_name}"."{tname}" LIMIT 5')) samples = [str(r[0]) for r in result if r[0] is not None] sample = json.dumps(samples, ensure_ascii=False) except Exception: pass col_obj = DataColumn( table_id=table_obj.id, name=col["name"], data_type=str(col.get("type", "")), length=col.get("max_length"), comment=col.get("comment"), is_nullable=col.get("nullable", True), sample_data=sample, checksum=col_checksum, last_scanned_at=scan_time, ) db.add(col_obj) total_columns += 1 _log_schema_change(db, source.id, "add_column", database_id=db_obj.id, table_id=table_obj.id, column_id=col_obj.id, new_value=col["name"]) else: if col_obj.checksum != col_checksum: old_val = f"type={col_obj.data_type}, len={col_obj.length}, comment={col_obj.comment}" new_val = f"type={str(col.get('type', ''))}, len={col.get('max_length')}, comment={col.get('comment')}" _log_schema_change(db, source.id, "change_type", database_id=db_obj.id, table_id=table_obj.id, column_id=col_obj.id, old_value=old_val, new_value=new_val) col_obj.checksum = col_checksum col_obj.data_type = str(col.get("type", "")) col_obj.length = col.get("max_length") col_obj.comment = col.get("comment") col_obj.is_nullable = col.get("nullable", True) updated_columns += 1 col_obj.last_scanned_at = scan_time col_obj.is_deleted = False col_obj.deleted_at = None total_tables += 1 # Soft-delete objects not seen in this scan and log changes deleted_dbs = db.query(Database).filter( Database.source_id == source.id, Database.last_scanned_at < scan_time, ).all() for d in deleted_dbs: _log_schema_change(db, source.id, "drop_database", database_id=d.id, old_value=d.name) d.is_deleted = True d.deleted_at = scan_time for db_obj in db.query(Database).filter(Database.source_id == source.id).all(): deleted_tables = db.query(DataTable).filter( DataTable.database_id == db_obj.id, DataTable.last_scanned_at < scan_time, ).all() for t in deleted_tables: _log_schema_change(db, source.id, "drop_table", database_id=db_obj.id, table_id=t.id, old_value=t.name) t.is_deleted = True t.deleted_at = scan_time for table_obj in db.query(DataTable).filter(DataTable.database_id == db_obj.id).all(): deleted_cols = db.query(DataColumn).filter( DataColumn.table_id == table_obj.id, DataColumn.last_scanned_at < scan_time, ).all() for c in deleted_cols: _log_schema_change(db, source.id, "drop_column", database_id=db_obj.id, table_id=table_obj.id, column_id=c.id, old_value=c.name) c.is_deleted = True c.deleted_at = scan_time source.status = "active" db.commit() return { "success": True, "message": "元数据同步成功", "databases": len(db_names), "tables": total_tables, "columns": total_columns, "updated_tables": updated_tables, "updated_columns": updated_columns, } except Exception as e: source.status = "error" db.commit() return {"success": False, "message": f"同步失败: {str(e)}", "databases": 0, "tables": 0, "columns": 0}