"""Masking-rule CRUD helpers, masking algorithms and masked data preview."""
import hashlib
import logging
from typing import Optional, Dict

from sqlalchemy import create_engine, literal_column, select, table as sa_table
from sqlalchemy.engine import URL
from sqlalchemy.orm import Session
from fastapi import HTTPException, status

from app.models.metadata import DataSource, Database, DataTable, DataColumn
from app.models.project import ClassificationResult
from app.models.masking import MaskingRule
from app.services.datasource_service import get_datasource, _decrypt_password

logger = logging.getLogger(__name__)

# Maps a data source's ``source_type`` to the SQLAlchemy driver name used in
# the connection URL. Unknown types are passed through unchanged.
_DRIVER_MAP = {
    "mysql": "mysql+pymysql",
    "postgresql": "postgresql+psycopg2",
    "oracle": "oracle+cx_oracle",
    "sqlserver": "mssql+pymssql",
}


def get_masking_rule(db: Session, rule_id: int):
    """Return the masking rule whose id is ``rule_id``, or ``None`` if absent."""
    return db.query(MaskingRule).filter(MaskingRule.id == rule_id).first()


def list_masking_rules(db: Session, level_id=None, category_id=None, page=1, page_size=20):
    """Return one page of active masking rules plus the total match count.

    Args:
        db: Session used for the query.
        level_id: When truthy, only rules with this ``level_id`` are returned.
        category_id: When truthy, only rules with this ``category_id`` are returned.
        page: 1-based page number.
        page_size: Maximum number of rules per page.

    Returns:
        A ``(items, total)`` tuple: the rules on the requested page and the
        number of rules matching the filters across all pages.
    """
    query = db.query(MaskingRule).filter(MaskingRule.is_active == True)  # noqa: E712
    if level_id:
        query = query.filter(MaskingRule.level_id == level_id)
    if category_id:
        query = query.filter(MaskingRule.category_id == category_id)
    total = query.count()
    # OFFSET/LIMIT without ORDER BY has no guaranteed row order, so pages
    # could overlap or skip rows; order by primary key to keep paging stable.
    items = (
        query.order_by(MaskingRule.id)
        .offset((page - 1) * page_size)
        .limit(page_size)
        .all()
    )
    return items, total


def create_masking_rule(db: Session, data: dict):
    """Create a masking rule from ``data``, commit it and return the refreshed row.

    ``data`` is expanded as keyword arguments to the ``MaskingRule`` constructor.
    """
    db_obj = MaskingRule(**data)
    db.add(db_obj)
    db.commit()
    db.refresh(db_obj)
    return db_obj


def update_masking_rule(db: Session, db_obj: MaskingRule, data: dict):
    """Apply the non-``None`` entries of ``data`` to ``db_obj`` and commit.

    Entries whose value is ``None`` are skipped, so this function cannot be
    used to reset an attribute to ``None``.

    Returns:
        The refreshed ``db_obj``.
    """
    for k, v in data.items():
        if v is not None:
            setattr(db_obj, k, v)
    db.commit()
    db.refresh(db_obj)
    return db_obj


def delete_masking_rule(db: Session, rule_id: int):
    """Delete the masking rule with id ``rule_id`` and commit.

    Raises:
        HTTPException: 404 when no rule with that id exists.
    """
    db_obj = get_masking_rule(db, rule_id)
    if not db_obj:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="规则不存在")
    db.delete(db_obj)
    db.commit()


def _apply_mask(value, params):
    """Replace the middle of ``value`` with a mask character.

    Recognised ``params`` keys: ``keep_prefix`` (default 3), ``keep_suffix``
    (default 4) and ``mask_char`` (default ``"*"``). Values no longer than
    ``keep_prefix + keep_suffix`` are masked entirely; falsy values are
    returned unchanged.
    """
    if not value:
        return value
    keep_prefix = params.get("keep_prefix", 3)
    keep_suffix = params.get("keep_suffix", 4)
    mask_char = params.get("mask_char", "*")
    if len(value) <= keep_prefix + keep_suffix:
        return mask_char * len(value)
    hidden = len(value) - keep_prefix - keep_suffix
    # ``value[-0:]`` is the whole string, so a zero-length suffix must be
    # special-cased or the clear text would be appended after the mask.
    tail = value[-keep_suffix:] if keep_suffix else ""
    return value[:keep_prefix] + mask_char * hidden + tail


def _apply_truncate(value, params):
    """Cut ``value`` to ``params["length"]`` characters (default 3) and append a suffix.

    The suffix comes from ``params["suffix"]`` (default ``"..."``). Falsy
    values and values not longer than the length are returned unchanged.
    """
    length = params.get("length", 3)
    suffix = params.get("suffix", "...")
    if not value or len(value) <= length:
        return value
    return value[:length] + suffix


def _apply_hash(value, params):
    """Return a shortened hex digest of ``str(value)``.

    ``params["algorithm"] == "md5"`` yields the first 16 hex characters of
    the MD5 digest; any other value yields the first 32 hex characters of
    the SHA-256 digest.
    """
    algorithm = params.get("algorithm", "sha256")
    if algorithm == "md5":
        return hashlib.md5(str(value).encode()).hexdigest()[:16]
    return hashlib.sha256(str(value).encode()).hexdigest()[:32]


def _apply_generalize(value, params):
    """Replace a numeric ``value`` with the ``"lower-upper"`` bucket containing it.

    The bucket width is ``params["step"]`` (default 10). When the value
    cannot be converted or bucketed, it is returned unchanged.
    """
    try:
        step = params.get("step", 10)
        num = float(value)
        lower = int(num // step * step)
        upper = lower + step
        return f"{lower}-{upper}"
    except (TypeError, ValueError, ZeroDivisionError, OverflowError):
        # Non-numeric input, a non-numeric or zero step, or inf/nan.
        return value


def _apply_replace(value, params):
    """Return ``params["replacement"]`` (default ``"[REDACTED]"``), ignoring ``value``."""
    return params.get("replacement", "[REDACTED]")


def apply_masking(value, algorithm, params):
    """Mask ``value`` with the named algorithm.

    Args:
        value: Raw value; ``None`` is returned as ``None``.
        algorithm: One of ``"mask"``, ``"truncate"``, ``"hash"``,
            ``"generalize"`` or ``"replace"``.
        params: Algorithm parameters; a falsy value is treated as ``{}``.

    Returns:
        The masked value. When ``algorithm`` is not recognised, ``value`` is
        returned unchanged (not stringified); otherwise the handler receives
        ``str(value)``.
    """
    if value is None:
        return None
    handlers = {
        "mask": _apply_mask,
        "truncate": _apply_truncate,
        "hash": _apply_hash,
        "generalize": _apply_generalize,
        "replace": _apply_replace,
    }
    handler = handlers.get(algorithm)
    if not handler:
        return value
    return handler(str(value), params or {})


def _get_column_rules(db: Session, table_id: int, project_id=None):
    """Map each column id of a table to the masking rule that applies to it.

    A column gets a rule only when ``project_id`` is given and the column has
    a ``ClassificationResult`` in that project. Rules are looked up by the
    result's ``(level_id, category_id)``, then ``(level_id, None)``, then
    ``(None, category_id)``.

    Returns:
        ``{column_id: MaskingRule or None}`` for every column of the table.
    """
    columns = db.query(DataColumn).filter(DataColumn.table_id == table_id).all()
    results = {}
    # Skip the lookup when there are no columns: an empty IN () list is pointless.
    if project_id and columns:
        res_list = db.query(ClassificationResult).filter(
            ClassificationResult.project_id == project_id,
            ClassificationResult.column_id.in_([c.id for c in columns]),
        ).all()
        results = {r.column_id: r for r in res_list}

    # The first rule seen for a (level, category) key wins, so order by id to
    # make the winner deterministic instead of depending on database row order.
    rules = (
        db.query(MaskingRule)
        .filter(MaskingRule.is_active == True)  # noqa: E712
        .order_by(MaskingRule.id)
        .all()
    )
    rule_map = {}
    for rule in rules:
        rule_map.setdefault((rule.level_id, rule.category_id), rule)

    col_rules = {}
    for col in columns:
        matched_rule = None
        if col.id in results:
            r = results[col.id]
            matched_rule = rule_map.get((r.level_id, r.category_id))
            if not matched_rule:
                matched_rule = rule_map.get((r.level_id, None))
            if not matched_rule:
                matched_rule = rule_map.get((None, r.category_id))
        col_rules[col.id] = matched_rule
    return col_rules


def _build_source_url(source, source_id):
    """Build the SQLAlchemy connection URL for a data source."""
    password = ""
    if source.encrypted_password:
        try:
            password = _decrypt_password(source.encrypted_password)
        except Exception:
            # Best effort, as before: fall back to an empty password, but
            # leave a trace so a later connection failure can be explained.
            logger.warning(
                "Could not decrypt password for data source %s; "
                "connecting with an empty password",
                source_id,
                exc_info=True,
            )
    driver = _DRIVER_MAP.get(source.source_type, source.source_type)
    # URL.create escapes special characters in the credentials, which a
    # hand-formatted URL string does not.
    return URL.create(
        drivername=driver,
        username=source.username,
        password=password,
        host=source.host,
        port=source.port,
        database=source.database_name,
    )


def _fetch_preview_rows(engine, table_name, limit):
    """Return up to ``limit`` rows of ``table_name`` as a list of dicts."""
    # Let SQLAlchemy quote the identifier and render the row limit for the
    # target dialect (LIMIT / TOP / FETCH FIRST) instead of formatting SQL
    # by hand, which was injectable and invalid on Oracle and SQL Server.
    stmt = (
        select(literal_column("*"))
        .select_from(sa_table(table_name))
        .limit(max(int(limit), 0))
    )
    with engine.connect() as conn:
        return [dict(row._mapping) for row in conn.execute(stmt)]


def _mask_rows(rows_raw, columns, col_rules):
    """Apply each column's masking rule, if any, to every raw row."""
    masked_rows = []
    for raw in rows_raw:
        masked = {}
        for col in columns:
            val = raw.get(col.name)
            rule = col_rules.get(col.id)
            if rule:
                masked[col.name] = apply_masking(val, rule.algorithm, rule.params or {})
            else:
                masked[col.name] = val
        masked_rows.append(masked)
    return masked_rows


def preview_masking(db: Session, source_id: int, table_name: str, project_id=None, limit=20):
    """Fetch sample rows from a source table and return them with masking applied.

    Args:
        db: Session for the metadata database.
        source_id: Id of the data source to connect to.
        table_name: Name of a table registered under that data source.
        project_id: Project whose classification results select the masking
            rules; without it no column is masked.
        limit: Maximum number of rows to fetch.

    Returns:
        A dict with ``success``, ``columns`` (name, data type and whether a
        rule applies), ``rows`` (masked row dicts) and ``total_rows``.

    Raises:
        HTTPException: 404 when the data source or table is not registered,
            400 when querying the source database fails.
    """
    source = get_datasource(db, source_id)
    if not source:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="数据源不存在")

    table = (
        db.query(DataTable)
        .join(Database)
        .filter(Database.source_id == source_id, DataTable.name == table_name)
        .first()
    )
    if not table:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="表不存在")

    col_rules = _get_column_rules(db, table.id, project_id)
    columns = db.query(DataColumn).filter(DataColumn.table_id == table.id).all()

    engine = create_engine(_build_source_url(source, source_id), pool_pre_ping=True)
    try:
        rows_raw = _fetch_preview_rows(engine, table_name, limit)
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=f"查询失败: {e}"
        ) from e
    finally:
        # The engine is created per call; dispose it so its connection pool
        # does not leak.
        engine.dispose()

    masked_rows = _mask_rows(rows_raw, columns, col_rules)
    return {
        "success": True,
        "columns": [
            {
                "name": c.name,
                "data_type": c.data_type,
                "has_rule": col_rules.get(c.id) is not None,
            }
            for c in columns
        ],
        "rows": masked_rows,
        "total_rows": len(masked_rows),
    }