feat: develop all feature modules and fix integration tests

- Add backend modules: Alert, APIAsset, Compliance, Lineage, Masking, Risk, SchemaChange, Unstructured, Watermark
- Add frontend module pages and API client interfaces
- Add Alembic migration scripts (002-014) covering all business tables
- Add test-data generation and integration-test scripts
- Fix startup failure caused by a missing JSON type import in the metadata models
- Fix wrong request-module path in the frontend Alert/APIAsset pages
- Update docker-compose and the development plan document
hiderfong
2026-04-25 08:51:38 +08:00
parent 8b2bc84399
commit 6d70520e79
110 changed files with 6125 additions and 87 deletions
+10 -1
@@ -1,6 +1,6 @@
from fastapi import APIRouter
-from app.api.v1 import auth, user, datasource, metadata, classification, project, task, report, dashboard
+from app.api.v1 import auth, user, datasource, metadata, classification, project, task, report, dashboard, masking, watermark, unstructured, schema_change, risk, compliance, lineage, alert, api_asset
api_router = APIRouter()
api_router.include_router(auth.router, prefix="/auth", tags=["认证"])
@@ -12,3 +12,12 @@ api_router.include_router(project.router, prefix="/projects", tags=["项目管理"])
api_router.include_router(task.router, prefix="/tasks", tags=["任务管理"])
api_router.include_router(report.router, prefix="/reports", tags=["报告管理"])
api_router.include_router(dashboard.router, prefix="/dashboard", tags=["仪表盘"])
api_router.include_router(masking.router, prefix="/masking", tags=["数据脱敏"])
api_router.include_router(watermark.router, prefix="/watermark", tags=["数据水印"])
api_router.include_router(unstructured.router, prefix="/unstructured", tags=["非结构化文件"])
api_router.include_router(schema_change.router, prefix="/schema-changes", tags=["Schema变更"])
api_router.include_router(risk.router, prefix="/risk", tags=["风险评估"])
api_router.include_router(compliance.router, prefix="/compliance", tags=["合规检查"])
api_router.include_router(lineage.router, prefix="/lineage", tags=["数据血缘"])
api_router.include_router(alert.router, prefix="/alerts", tags=["告警与工单"])
api_router.include_router(api_asset.router, prefix="/api-assets", tags=["API资产"])
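For orientation, a minimal sketch of how a composed router like this is typically mounted; the main.py location and the /api/v1 prefix are assumptions, not shown in this commit:

from fastapi import FastAPI
from app.api.v1 import api_router

app = FastAPI(title="PropDataGuard")  # hypothetical app title
app.include_router(api_router, prefix="/api/v1")  # assumed prefix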
+115
@@ -0,0 +1,115 @@
from typing import Optional
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.models.user import User
from app.schemas.common import ResponseModel, ListResponse
from app.services import alert_service
from app.api.deps import get_current_user, require_admin
router = APIRouter()
@router.post("/init-rules")
def init_alert_rules(
db: Session = Depends(get_db),
current_user: User = Depends(require_admin),
):
alert_service.init_builtin_alert_rules(db)
return ResponseModel(message="初始化完成")
@router.post("/check")
def check_alerts(
db: Session = Depends(get_db),
current_user: User = Depends(require_admin),
):
records = alert_service.check_alerts(db)
return ResponseModel(data={"alerts_created": len(records)})
@router.get("/records")
def list_alert_records(
status: Optional[str] = Query(None),
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=500),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
query = db.query(alert_service.AlertRecord)
if status:
query = query.filter(alert_service.AlertRecord.status == status)
total = query.count()
items = query.order_by(alert_service.AlertRecord.created_at.desc()).offset((page - 1) * page_size).limit(page_size).all()
return ListResponse(
data=[{
"id": r.id,
"rule_id": r.rule_id,
"title": r.title,
"content": r.content,
"severity": r.severity,
"status": r.status,
"created_at": r.created_at.isoformat() if r.created_at else None,
} for r in items],
total=total,
page=page,
page_size=page_size,
)
@router.post("/work-orders")
def create_work_order(
alert_id: int,
title: str,
description: str = "",
assignee_id: Optional[int] = Query(None),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
wo = alert_service.create_work_order(db, alert_id, title, description, assignee_id)
return ResponseModel(data={"id": wo.id})
@router.get("/work-orders")
def list_work_orders(
status: Optional[str] = Query(None),
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=500),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
from app.models.alert import WorkOrder
query = db.query(WorkOrder)
if status:
query = query.filter(WorkOrder.status == status)
total = query.count()
items = query.order_by(WorkOrder.created_at.desc()).offset((page - 1) * page_size).limit(page_size).all()
return ListResponse(
data=[{
"id": w.id,
"alert_id": w.alert_id,
"title": w.title,
"status": w.status,
"assignee_name": w.assignee.username if w.assignee else None,
"created_at": w.created_at.isoformat() if w.created_at else None,
} for w in items],
total=total,
page=page,
page_size=page_size,
)
@router.post("/work-orders/{wo_id}/status")
def update_work_order(
wo_id: int,
status: str,
resolution: str = "",
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
wo = alert_service.update_work_order_status(db, wo_id, status, resolution or None)
if not wo:
        # import HTTPException only; importing fastapi.status here would shadow the `status` parameter
        from fastapi import HTTPException
        raise HTTPException(status_code=404, detail="工单不存在")
return ResponseModel(data={"id": wo.id, "status": wo.status})
+131
@@ -0,0 +1,131 @@
from typing import Optional, List
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from pydantic import BaseModel
from app.core.database import get_db
from app.models.user import User
from app.schemas.common import ResponseModel, ListResponse
from app.services import api_asset_service
from app.api.deps import get_current_user
router = APIRouter()
class APIAssetCreate(BaseModel):
name: str
base_url: str
swagger_url: Optional[str] = None
auth_type: Optional[str] = "none"
headers: Optional[dict] = None
description: Optional[str] = None
class APIAssetUpdate(BaseModel):
name: Optional[str] = None
base_url: Optional[str] = None
swagger_url: Optional[str] = None
auth_type: Optional[str] = None
headers: Optional[dict] = None
description: Optional[str] = None
@router.post("")
def create_asset(
body: APIAssetCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
asset = api_asset_service.create_asset(db, body.dict(), current_user.id)
return ResponseModel(data={"id": asset.id})
@router.get("")
def list_assets(
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=500),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
from app.models.api_asset import APIAsset
query = db.query(APIAsset)
total = query.count()
items = query.order_by(APIAsset.created_at.desc()).offset((page - 1) * page_size).limit(page_size).all()
return ListResponse(
data=[{
"id": a.id,
"name": a.name,
"base_url": a.base_url,
"swagger_url": a.swagger_url,
"auth_type": a.auth_type,
"scan_status": a.scan_status,
"total_endpoints": a.total_endpoints,
"sensitive_endpoints": a.sensitive_endpoints,
"created_at": a.created_at.isoformat() if a.created_at else None,
} for a in items],
total=total,
page=page,
page_size=page_size,
)
@router.put("/{asset_id}")
def update_asset(
asset_id: int,
body: APIAssetUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
asset = api_asset_service.update_asset(db, asset_id, body.dict(exclude_unset=True))
if not asset:
from fastapi import HTTPException, status
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="资产不存在")
return ResponseModel(data={"id": asset.id})
@router.delete("/{asset_id}")
def delete_asset(
asset_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
ok = api_asset_service.delete_asset(db, asset_id)
if not ok:
from fastapi import HTTPException, status
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="资产不存在")
return ResponseModel()
@router.post("/{asset_id}/scan")
def scan_asset(
asset_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
result = api_asset_service.scan_swagger(db, asset_id)
return ResponseModel(data=result)
@router.get("/{asset_id}/endpoints")
def list_endpoints(
asset_id: int,
risk_level: Optional[str] = Query(None),
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=500),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
from app.models.api_asset import APIEndpoint
query = db.query(APIEndpoint).filter(APIEndpoint.asset_id == asset_id)
if risk_level:
query = query.filter(APIEndpoint.risk_level == risk_level)
total = query.count()
items = query.order_by(APIEndpoint.id.asc()).offset((page - 1) * page_size).limit(page_size).all()
return ListResponse(
data=[{
"id": e.id,
"method": e.method,
"path": e.path,
"summary": e.summary,
"tags": e.tags,
"parameters": e.parameters,
"sensitive_fields": e.sensitive_fields,
"risk_level": e.risk_level,
"is_active": e.is_active,
} for e in items],
total=total,
page=page,
page_size=page_size,
)
+40
@@ -238,3 +238,43 @@ def auto_classify(
):
result = classification_engine.run_auto_classification(db, project_id)
return ResponseModel(data=result)
@router.post("/ml-train")
def ml_train(
background: bool = True,
model_name: Optional[str] = None,
algorithm: str = "logistic_regression",
db: Session = Depends(get_db),
current_user: User = Depends(require_admin),
):
from app.tasks.ml_tasks import train_ml_model_task
from app.services.ml_service import train_model
if background:
task = train_ml_model_task.delay(model_name=model_name, algorithm=algorithm)
return ResponseModel(data={"task_id": task.id, "status": task.state})
else:
mv = train_model(db, model_name=model_name, algorithm=algorithm)
if mv:
return ResponseModel(data={"model_id": mv.id, "accuracy": mv.accuracy, "train_samples": mv.train_samples})
return ResponseModel(message="训练失败:样本不足或发生错误")
@router.get("/ml-suggest/{project_id}")
def ml_suggest(
project_id: int,
column_ids: Optional[str] = Query(None),
top_k: int = Query(3, ge=1, le=5),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
from app.services.ml_service import suggest_for_project_columns
ids = None
if column_ids:
ids = [int(x) for x in column_ids.split(",") if x.strip().isdigit()]
result = suggest_for_project_columns(db, project_id, column_ids=ids, top_k=top_k)
if not result.get("success"):
from fastapi import HTTPException, status
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=result.get("message"))
return ResponseModel(data=result["suggestions"])
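A hedged client-side sketch of the suggestion endpoint above; the base URL, the /classification mount prefix, and the token are assumptions:

import requests

resp = requests.get(
    "http://localhost:8000/api/v1/classification/ml-suggest/1",  # assumed mount point
    params={"column_ids": "101,102,103", "top_k": 3},            # hypothetical column ids
    headers={"Authorization": "Bearer <token>"},                 # hypothetical token
)
print(resp.json())  # envelope with per-column label suggestions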
+72
@@ -0,0 +1,72 @@
from typing import Optional
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.models.user import User
from app.schemas.common import ResponseModel, ListResponse
from app.services import compliance_service
from app.api.deps import get_current_user, require_admin
router = APIRouter()
@router.post("/init-rules")
def init_rules(
db: Session = Depends(get_db),
current_user: User = Depends(require_admin),
):
compliance_service.init_builtin_rules(db)
return ResponseModel(message="初始化完成")
@router.post("/scan")
def scan_compliance(
project_id: Optional[int] = Query(None),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
issues = compliance_service.scan_compliance(db, project_id=project_id)
return ResponseModel(data={"issues_found": len(issues)})
@router.get("/issues")
def list_issues(
project_id: Optional[int] = Query(None),
status: Optional[str] = Query(None),
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=500),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
items, total = compliance_service.list_issues(db, project_id=project_id, status=status, page=page, page_size=page_size)
return ListResponse(
data=[{
"id": i.id,
"rule_id": i.rule_id,
"project_id": i.project_id,
"entity_type": i.entity_type,
"entity_name": i.entity_name,
"severity": i.severity,
"description": i.description,
"suggestion": i.suggestion,
"status": i.status,
"created_at": i.created_at.isoformat() if i.created_at else None,
} for i in items],
total=total,
page=page,
page_size=page_size,
)
@router.post("/issues/{issue_id}/resolve")
def resolve_issue(
issue_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
issue = compliance_service.resolve_issue(db, issue_id)
if not issue:
from fastapi import HTTPException, status
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="问题不存在")
return ResponseModel(message="已标记为已解决")
+32
@@ -0,0 +1,32 @@
from typing import Optional
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.models.user import User
from app.schemas.common import ResponseModel
from app.services import lineage_service
from app.api.deps import get_current_user
router = APIRouter()
@router.post("/parse")
def parse_lineage(
sql: str,
target_table: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
records = lineage_service.parse_sql_lineage(db, sql, target_table)
return ResponseModel(data={"records_created": len(records)})
@router.get("/graph")
def get_graph(
table_name: Optional[str] = Query(None),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
graph = lineage_service.get_lineage_graph(db, table_name=table_name)
return ResponseModel(data=graph)
+88
@@ -0,0 +1,88 @@
from typing import Optional
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.models.user import User
from app.schemas.common import ResponseModel, ListResponse
from app.services import masking_service
from app.api.deps import get_current_user, require_admin
router = APIRouter()
@router.get("/rules")
def list_masking_rules(
level_id: Optional[int] = Query(None),
category_id: Optional[int] = Query(None),
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=500),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
items, total = masking_service.list_masking_rules(db, level_id=level_id, category_id=category_id, page=page, page_size=page_size)
return ListResponse(
data=[{
"id": r.id,
"name": r.name,
"level_id": r.level_id,
"category_id": r.category_id,
"algorithm": r.algorithm,
"params": r.params,
"is_active": r.is_active,
"description": r.description,
"level_name": r.level.name if r.level else None,
"category_name": r.category.name if r.category else None,
} for r in items],
total=total,
page=page,
page_size=page_size,
)
@router.post("/rules")
def create_masking_rule(
req: dict,
db: Session = Depends(get_db),
current_user: User = Depends(require_admin),
):
item = masking_service.create_masking_rule(db, req)
return ResponseModel(data={"id": item.id})
@router.put("/rules/{rule_id}")
def update_masking_rule(
rule_id: int,
req: dict,
db: Session = Depends(get_db),
current_user: User = Depends(require_admin),
):
db_obj = masking_service.get_masking_rule(db, rule_id)
if not db_obj:
from fastapi import HTTPException, status
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="规则不存在")
item = masking_service.update_masking_rule(db, db_obj, req)
return ResponseModel(data={"id": item.id})
@router.delete("/rules/{rule_id}")
def delete_masking_rule(
rule_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(require_admin),
):
masking_service.delete_masking_rule(db, rule_id)
return ResponseModel(message="删除成功")
@router.post("/preview")
def preview_masking(
source_id: int,
table_name: str,
project_id: Optional[int] = None,
limit: int = Query(20, ge=1, le=100),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
result = masking_service.preview_masking(db, source_id, table_name, project_id=project_id, limit=limit)
return ResponseModel(data=result)
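Since the create/update endpoints accept a free-form dict, here is a hypothetical request body; the params keys mirror the keep_prefix / keep_suffix / mask_char options read by the masking service's _apply_mask:

payload = {
    "name": "Phone number masking",  # hypothetical rule
    "level_id": 4,                   # hypothetical DataLevel id
    "algorithm": "mask",
    "params": {"keep_prefix": 3, "keep_suffix": 4, "mask_char": "*"},
    "description": "Keep the first 3 and last 4 characters",
}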
+67 -3
@@ -101,9 +101,73 @@ def delete_project(
@router.post("/{project_id}/auto-classify")
def project_auto_classify(
project_id: int,
background: bool = True,
db: Session = Depends(get_db),
current_user: User = Depends(require_manager),
):
-    from app.services.classification_engine import run_auto_classification
-    result = run_auto_classification(db, project_id)
-    return ResponseModel(data=result)
from app.tasks.classification_tasks import auto_classify_task
from celery.result import AsyncResult
project = project_service.get_project(db, project_id)
if not project:
from fastapi import HTTPException, status
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="项目不存在")
if background:
# Check if already running
if project.celery_task_id:
existing = AsyncResult(project.celery_task_id)
if existing.state in ("PENDING", "PROGRESS", "STARTED"):
return ResponseModel(data={"task_id": project.celery_task_id, "status": existing.state})
task = auto_classify_task.delay(project_id)
project.celery_task_id = task.id
project.status = "scanning"
db.commit()
return ResponseModel(data={"task_id": task.id, "status": task.state})
else:
from app.services.classification_engine import run_auto_classification
project.status = "scanning"
db.commit()
result = run_auto_classification(db, project_id)
if result.get("success"):
project.status = "assigning"
else:
project.status = "created"
db.commit()
return ResponseModel(data=result)
@router.get("/{project_id}/auto-classify-status")
def project_auto_classify_status(
project_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
from celery.result import AsyncResult
import json
project = project_service.get_project(db, project_id)
if not project:
from fastapi import HTTPException, status
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="项目不存在")
task_id = project.celery_task_id
if not task_id:
# Return persisted progress if any
progress = json.loads(project.scan_progress) if project.scan_progress else None
return ResponseModel(data={"status": project.status, "progress": progress})
result = AsyncResult(task_id)
progress = None
if result.state == "PROGRESS" and result.info:
progress = result.info
elif project.scan_progress:
progress = json.loads(project.scan_progress)
return ResponseModel(data={
"status": result.state,
"task_id": task_id,
"progress": progress,
"project_status": project.status,
})
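A minimal client-side polling loop for the async flow above (start the task, then poll the status endpoint); the base URL and token are assumptions:

import time
import requests

BASE = "http://localhost:8000/api/v1"          # assumed mount point
HEADERS = {"Authorization": "Bearer <token>"}  # hypothetical token

task = requests.post(f"{BASE}/projects/1/auto-classify", headers=HEADERS).json()["data"]
while True:
    data = requests.get(f"{BASE}/projects/1/auto-classify-status", headers=HEADERS).json()["data"]
    print(data["status"], data.get("progress"))
    if data["status"] in ("SUCCESS", "FAILURE"):
        break
    time.sleep(2)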
+18
@@ -44,12 +44,30 @@ def get_report_stats(
@router.get("/projects/{project_id}/download")
def download_report(
project_id: int,
format: str = "docx",
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
if format == "excel":
content = report_service.generate_excel_report(db, project_id)
return Response(
content=content,
media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
headers={"Content-Disposition": f"attachment; filename=report_project_{project_id}.xlsx"},
)
content = report_service.generate_classification_report(db, project_id)
return Response(
content=content,
media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
headers={"Content-Disposition": f"attachment; filename=report_project_{project_id}.docx"},
)
@router.get("/projects/{project_id}/summary")
def report_summary(
project_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
data = report_service.get_report_summary(db, project_id)
return ResponseModel(data=data)
+73
@@ -0,0 +1,73 @@
from typing import Optional
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.models.user import User
from app.schemas.common import ResponseModel, ListResponse
from app.services import risk_service
from app.api.deps import get_current_user
router = APIRouter()
@router.post("/recalculate")
def recalculate_risk(
project_id: Optional[int] = Query(None),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
if project_id:
result = risk_service.calculate_project_risk(db, project_id)
return ResponseModel(data={"project_id": project_id, "risk_score": result.risk_score if result else 0})
result = risk_service.calculate_all_projects_risk(db)
return ResponseModel(data=result)
@router.get("/top")
def risk_top(
entity_type: str = Query("project"),
n: int = Query(10, ge=1, le=100),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
items = risk_service.get_risk_top_n(db, entity_type=entity_type, n=n)
return ListResponse(
data=[{
"id": r.id,
"entity_type": r.entity_type,
"entity_id": r.entity_id,
"entity_name": r.entity_name,
"risk_score": r.risk_score,
"sensitivity_score": r.sensitivity_score,
"exposure_score": r.exposure_score,
"protection_score": r.protection_score,
"updated_at": r.updated_at.isoformat() if r.updated_at else None,
} for r in items],
total=len(items),
page=1,
page_size=n,
)
@router.get("/projects/{project_id}")
def project_risk(
project_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
from app.models.risk import RiskAssessment
item = db.query(RiskAssessment).filter(
RiskAssessment.entity_type == "project",
RiskAssessment.entity_id == project_id,
).first()
if not item:
return ResponseModel(data=None)
return ResponseModel(data={
"risk_score": item.risk_score,
"sensitivity_score": item.sensitivity_score,
"exposure_score": item.exposure_score,
"protection_score": item.protection_score,
"details": item.details,
"updated_at": item.updated_at.isoformat() if item.updated_at else None,
})
+45
@@ -0,0 +1,45 @@
from typing import Optional
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.models.user import User
from app.schemas.common import ResponseModel, ListResponse
from app.models.schema_change import SchemaChangeLog
from app.api.deps import get_current_user
router = APIRouter()
@router.get("/logs")
def list_schema_changes(
source_id: Optional[int] = Query(None),
change_type: Optional[str] = Query(None),
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=500),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
query = db.query(SchemaChangeLog)
if source_id:
query = query.filter(SchemaChangeLog.source_id == source_id)
if change_type:
query = query.filter(SchemaChangeLog.change_type == change_type)
total = query.count()
items = query.order_by(SchemaChangeLog.detected_at.desc()).offset((page - 1) * page_size).limit(page_size).all()
return ListResponse(
data=[{
"id": log.id,
"source_id": log.source_id,
"database_id": log.database_id,
"table_id": log.table_id,
"column_id": log.column_id,
"change_type": log.change_type,
"old_value": log.old_value,
"new_value": log.new_value,
"detected_at": log.detected_at.isoformat() if log.detected_at else None,
} for log in items],
total=total,
page=page,
page_size=page_size,
)
+108
@@ -0,0 +1,108 @@
import io
from typing import Optional
from fastapi import APIRouter, Depends, Query, UploadFile, File
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.models.user import User
from app.schemas.common import ResponseModel, ListResponse
from app.services import unstructured_service
from app.api.deps import get_current_user
from app.core.events import minio_client
from app.core.config import settings
from app.models.metadata import UnstructuredFile
router = APIRouter()
@router.post("/upload")
def upload_file(
file: UploadFile = File(...),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
# Determine file type
filename = file.filename or "unknown"
ext = filename.split(".")[-1].lower() if "." in filename else ""
type_map = {
"docx": "word", "doc": "word",
"xlsx": "excel", "xls": "excel",
"pdf": "pdf",
"txt": "txt",
}
file_type = type_map.get(ext, "unknown")
# Upload to MinIO
storage_path = f"unstructured/{current_user.id}/{filename}"
try:
        data = file.file.read()
        minio_client.put_object(
            settings.MINIO_BUCKET_NAME,
            storage_path,
            data=io.BytesIO(data),  # put_object expects a readable stream, not raw bytes
            length=len(data),
            content_type=file.content_type or "application/octet-stream",
        )
except Exception as e:
return ResponseModel(message=f"上传失败: {e}")
db_obj = UnstructuredFile(
original_name=filename,
file_type=file_type,
file_size=len(data),
storage_path=storage_path,
status="pending",
created_by=current_user.id,
)
db.add(db_obj)
db.commit()
db.refresh(db_obj)
# Trigger processing
try:
result = unstructured_service.process_unstructured_file(db, db_obj.id)
return ResponseModel(data={"id": db_obj.id, "matches": result.get("matches", []), "status": "processed"})
except Exception as e:
return ResponseModel(data={"id": db_obj.id, "status": "error", "error": str(e)})
@router.get("/files")
def list_files(
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=500),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
query = db.query(UnstructuredFile).filter(UnstructuredFile.created_by == current_user.id)
total = query.count()
items = query.order_by(UnstructuredFile.created_at.desc()).offset((page - 1) * page_size).limit(page_size).all()
return ListResponse(
data=[{
"id": f.id,
"original_name": f.original_name,
"file_type": f.file_type,
"file_size": f.file_size,
"status": f.status,
"analysis_result": f.analysis_result,
"created_at": f.created_at.isoformat() if f.created_at else None,
} for f in items],
total=total,
page=page,
page_size=page_size,
)
@router.post("/files/{file_id}/reprocess")
def reprocess_file(
file_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
file_obj = db.query(UnstructuredFile).filter(
UnstructuredFile.id == file_id,
UnstructuredFile.created_by == current_user.id,
).first()
if not file_obj:
from fastapi import HTTPException, status
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="文件不存在")
result = unstructured_service.process_unstructured_file(db, file_id)
return ResponseModel(data={"matches": result.get("matches", []), "status": "processed"})
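A hypothetical client upload against the endpoint above; the multipart field name "file" matches the UploadFile parameter:

import requests

with open("contract.docx", "rb") as fh:
    resp = requests.post(
        "http://localhost:8000/api/v1/unstructured/upload",  # assumed mount point
        headers={"Authorization": "Bearer <token>"},         # hypothetical token
        files={"file": ("contract.docx", fh)},
    )
print(resp.json())  # id, matches, status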
+23
@@ -0,0 +1,23 @@
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.models.user import User
from app.schemas.common import ResponseModel
from app.services import watermark_service
from app.api.deps import get_current_user
router = APIRouter()
@router.post("/trace")
def trace_watermark(
req: dict,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
text = req.get("text", "")
result = watermark_service.trace_watermark(db, text)
if not result:
return ResponseModel(data=None, message="未检测到水印")
return ResponseModel(data=result)
+5
@@ -10,6 +10,11 @@ class Settings(BaseSettings):
DATABASE_URL: str = "postgresql+psycopg2://pdg:pdg_secret_2024@localhost:5432/prop_data_guard"
REDIS_URL: str = "redis://localhost:6379/0"
# Database password encryption key (Fernet-compatible base64, 32 bytes)
# If empty, will be derived from SECRET_KEY for backward compatibility.
# STRONGLY recommended to set this explicitly in production.
DB_ENCRYPTION_KEY: str = ""
SECRET_KEY: str = "prop-data-guard-super-secret-key-change-in-production"
ALGORITHM: str = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
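A quick way to generate a valid DB_ENCRYPTION_KEY (a Fernet key is 32 random bytes, urlsafe-base64 encoded; assumes the cryptography package):

from cryptography.fernet import Fernet

print(Fernet.generate_key().decode())  # paste the output into .env as DB_ENCRYPTION_KEY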
+15
@@ -2,6 +2,14 @@ from app.models.user import User, Role, Dept, UserRole
from app.models.metadata import DataSource, Database, DataTable, DataColumn, UnstructuredFile
from app.models.classification import Category, DataLevel, RecognitionRule, ClassificationTemplate
from app.models.project import ClassificationProject, ClassificationTask, ClassificationResult, ClassificationChange
from app.models.ml import MLModelVersion
from app.models.masking import MaskingRule
from app.models.watermark import WatermarkLog
from app.models.schema_change import SchemaChangeLog
from app.models.risk import RiskAssessment
from app.models.compliance import ComplianceRule, ComplianceIssue
from app.models.alert import AlertRule, AlertRecord, WorkOrder
from app.models.api_asset import APIAsset, APIEndpoint
from app.models.log import OperationLog
__all__ = [
@@ -9,5 +17,12 @@ __all__ = [
"DataSource", "Database", "DataTable", "DataColumn", "UnstructuredFile",
"Category", "DataLevel", "RecognitionRule", "ClassificationTemplate",
"ClassificationProject", "ClassificationTask", "ClassificationResult", "ClassificationChange",
"MLModelVersion",
"MaskingRule",
"WatermarkLog",
"SchemaChangeLog",
"RiskAssessment",
"ComplianceRule",
"ComplianceIssue",
"OperationLog",
]
+46
@@ -0,0 +1,46 @@
from datetime import datetime
from sqlalchemy import Column, Integer, String, Text, DateTime, Boolean, ForeignKey, JSON
from sqlalchemy.orm import relationship
from app.core.database import Base
class AlertRule(Base):
__tablename__ = "alert_rule"
id = Column(Integer, primary_key=True, index=True)
name = Column(String(200), nullable=False)
trigger_condition = Column(String(50), nullable=False) # l5_count, risk_score, schema_change
threshold = Column(Integer, default=0)
severity = Column(String(20), default="medium")
is_active = Column(Boolean, default=True)
created_at = Column(DateTime, default=datetime.utcnow)
class AlertRecord(Base):
__tablename__ = "alert_record"
id = Column(Integer, primary_key=True, index=True)
rule_id = Column(Integer, ForeignKey("alert_rule.id"), nullable=False)
title = Column(String(200), nullable=False)
content = Column(Text)
severity = Column(String(20), default="medium")
status = Column(String(20), default="open") # open, acknowledged, resolved
created_at = Column(DateTime, default=datetime.utcnow)
rule = relationship("AlertRule")
class WorkOrder(Base):
__tablename__ = "work_order"
id = Column(Integer, primary_key=True, index=True)
alert_id = Column(Integer, ForeignKey("alert_record.id"), nullable=True)
title = Column(String(200), nullable=False)
description = Column(Text)
assignee_id = Column(Integer, ForeignKey("sys_user.id"), nullable=True)
status = Column(String(20), default="open") # open, in_progress, resolved
resolution = Column(Text)
created_at = Column(DateTime, default=datetime.utcnow)
resolved_at = Column(DateTime, nullable=True)
assignee = relationship("User")
+41
@@ -0,0 +1,41 @@
from sqlalchemy import Column, Integer, String, Text, DateTime, Boolean, JSON, ForeignKey, BigInteger
from sqlalchemy.orm import relationship
from app.core.database import Base
from datetime import datetime
class APIAsset(Base):
__tablename__ = "api_asset"
id = Column(Integer, primary_key=True, index=True)
name = Column(String(200), nullable=False)
base_url = Column(String(500), nullable=False)
swagger_url = Column(String(500), nullable=True)
auth_type = Column(String(50), default="none") # none, bearer, api_key, basic
headers = Column(JSON, default=dict)
description = Column(Text, nullable=True)
scan_status = Column(String(20), default="idle") # idle, scanning, completed, failed
total_endpoints = Column(Integer, default=0)
sensitive_endpoints = Column(Integer, default=0)
created_by = Column(Integer, ForeignKey("sys_user.id"), nullable=True)
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
endpoints = relationship("APIEndpoint", back_populates="asset", cascade="all, delete-orphan")
creator = relationship("User", foreign_keys=[created_by])
class APIEndpoint(Base):
__tablename__ = "api_endpoint"
id = Column(Integer, primary_key=True, index=True)
asset_id = Column(Integer, ForeignKey("api_asset.id"), nullable=False)
method = Column(String(10), nullable=False) # GET, POST, PUT, DELETE, etc.
path = Column(String(500), nullable=False)
summary = Column(String(500), nullable=True)
tags = Column(JSON, default=list)
parameters = Column(JSON, default=list)
request_body_schema = Column(JSON, nullable=True)
response_schema = Column(JSON, nullable=True)
sensitive_fields = Column(JSON, default=list) # detected PII fields
risk_level = Column(String(20), default="low") # low, medium, high, critical
is_active = Column(Boolean, default=True)
created_at = Column(DateTime, default=datetime.utcnow)
asset = relationship("APIAsset", back_populates="endpoints")
+33
@@ -0,0 +1,33 @@
from datetime import datetime
from sqlalchemy import Column, Integer, String, Text, DateTime, Boolean, JSON
from app.core.database import Base
class ComplianceRule(Base):
__tablename__ = "compliance_rule"
id = Column(Integer, primary_key=True, index=True)
name = Column(String(200), nullable=False)
standard = Column(String(50), nullable=False) # dengbao, pipl, gdpr
description = Column(Text)
check_logic = Column(String(50), nullable=False) # check_masking, check_encryption, check_audit, check_level
severity = Column(String(20), default="medium") # low, medium, high, critical
is_active = Column(Boolean, default=True)
created_at = Column(DateTime, default=datetime.utcnow)
class ComplianceIssue(Base):
__tablename__ = "compliance_issue"
id = Column(Integer, primary_key=True, index=True)
rule_id = Column(Integer, nullable=False)
project_id = Column(Integer, nullable=True)
entity_type = Column(String(20), nullable=False) # project, source, column
entity_id = Column(Integer, nullable=False)
entity_name = Column(String(200))
severity = Column(String(20), default="medium")
description = Column(Text)
suggestion = Column(Text)
status = Column(String(20), default="open") # open, resolved, ignored
resolved_at = Column(DateTime, nullable=True)
created_at = Column(DateTime, default=datetime.utcnow)
+16
@@ -0,0 +1,16 @@
from datetime import datetime
from sqlalchemy import Column, Integer, String, Text, DateTime
from app.core.database import Base
class DataLineage(Base):
__tablename__ = "data_lineage"
id = Column(Integer, primary_key=True, index=True)
source_table = Column(String(200), nullable=False)
source_column = Column(String(200), nullable=True)
target_table = Column(String(200), nullable=False)
target_column = Column(String(200), nullable=True)
relation_type = Column(String(20), default="direct") # direct, derived, lookup
script_content = Column(Text)
created_at = Column(DateTime, default=datetime.utcnow)
+22
@@ -0,0 +1,22 @@
from datetime import datetime
from sqlalchemy import Column, Integer, String, Boolean, DateTime, ForeignKey, JSON, Text
from sqlalchemy.orm import relationship
from app.core.database import Base
class MaskingRule(Base):
__tablename__ = "masking_rule"
id = Column(Integer, primary_key=True, index=True)
name = Column(String(100), nullable=False)
level_id = Column(Integer, ForeignKey("data_level.id"), nullable=True)
category_id = Column(Integer, ForeignKey("category.id"), nullable=True)
algorithm = Column(String(20), nullable=False) # mask, truncate, hash, generalize, replace
params = Column(JSON, default=dict) # algorithm-specific params
is_active = Column(Boolean, default=True)
description = Column(Text)
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
level = relationship("DataLevel")
category = relationship("Category")
+14 -1
@@ -1,5 +1,5 @@
from datetime import datetime
-from sqlalchemy import Column, Integer, String, Boolean, DateTime, ForeignKey, Text, BigInteger
+from sqlalchemy import Column, Integer, String, Boolean, DateTime, ForeignKey, Text, BigInteger, JSON
from sqlalchemy.orm import relationship
from app.core.database import Base
@@ -36,6 +36,10 @@ class Database(Base):
charset = Column(String(50))
table_count = Column(Integer, default=0)
created_at = Column(DateTime, default=datetime.utcnow)
last_scanned_at = Column(DateTime, nullable=True)
checksum = Column(String(64), nullable=True)
is_deleted = Column(Boolean, default=False)
deleted_at = Column(DateTime, nullable=True)
source = relationship("DataSource", back_populates="databases")
tables = relationship("DataTable", back_populates="database", cascade="all, delete-orphan")
@@ -51,6 +55,10 @@ class DataTable(Base):
row_count = Column(BigInteger, default=0)
column_count = Column(Integer, default=0)
created_at = Column(DateTime, default=datetime.utcnow)
last_scanned_at = Column(DateTime, nullable=True)
checksum = Column(String(64), nullable=True)
is_deleted = Column(Boolean, default=False)
deleted_at = Column(DateTime, nullable=True)
database = relationship("Database", back_populates="tables")
columns = relationship("DataColumn", back_populates="table", cascade="all, delete-orphan")
@@ -68,6 +76,10 @@ class DataColumn(Base):
is_nullable = Column(Boolean, default=True)
sample_data = Column(Text) # JSON array of sample values
created_at = Column(DateTime, default=datetime.utcnow)
last_scanned_at = Column(DateTime, nullable=True)
checksum = Column(String(64), nullable=True)
is_deleted = Column(Boolean, default=False)
deleted_at = Column(DateTime, nullable=True)
table = relationship("DataTable", back_populates="columns")
@@ -81,6 +93,7 @@ class UnstructuredFile(Base):
file_size = Column(BigInteger)
storage_path = Column(String(500))
extracted_text = Column(Text)
analysis_result = Column(JSON, nullable=True) # JSON: {matches: [{rule_name, category, level, snippet}]}
status = Column(String(20), default="pending") # pending, processed, error
created_by = Column(Integer, ForeignKey("sys_user.id"))
created_at = Column(DateTime, default=datetime.utcnow)
+18
@@ -0,0 +1,18 @@
from datetime import datetime
from sqlalchemy import Column, Integer, String, Float, DateTime, Boolean, Text
from app.core.database import Base
class MLModelVersion(Base):
__tablename__ = "ml_model_version"
id = Column(Integer, primary_key=True, index=True)
name = Column(String(100), nullable=False)
model_path = Column(String(500), nullable=False) # joblib dump path
vectorizer_path = Column(String(500), nullable=False) # tfidf vectorizer path
accuracy = Column(Float, default=0.0)
train_samples = Column(Integer, default=0)
train_date = Column(DateTime, default=datetime.utcnow)
is_active = Column(Boolean, default=False)
description = Column(Text)
created_at = Column(DateTime, default=datetime.utcnow)
+4
@@ -48,6 +48,10 @@ class ClassificationProject(Base):
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
# Async classification tracking
celery_task_id = Column(String(100), nullable=True)
scan_progress = Column(Text, nullable=True) # JSON: {"scanned": 0, "matched": 0, "total": 0}
template = relationship("ClassificationTemplate")
tasks = relationship("ClassificationTask", back_populates="project", cascade="all, delete-orphan")
results = relationship("ClassificationResult", back_populates="project", cascade="all, delete-orphan")
+20
@@ -0,0 +1,20 @@
from datetime import datetime
from sqlalchemy import Column, Integer, String, Float, DateTime, ForeignKey, JSON
from sqlalchemy.orm import relationship
from app.core.database import Base
class RiskAssessment(Base):
__tablename__ = "risk_assessment"
id = Column(Integer, primary_key=True, index=True)
entity_type = Column(String(20), nullable=False) # project, source, table, field
entity_id = Column(Integer, nullable=False)
entity_name = Column(String(200))
risk_score = Column(Float, default=0.0) # 0-100
sensitivity_score = Column(Float, default=0.0)
exposure_score = Column(Float, default=0.0)
protection_score = Column(Float, default=0.0)
details = Column(JSON, default=dict)
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
+23
@@ -0,0 +1,23 @@
from datetime import datetime
from sqlalchemy import Column, Integer, String, Text, DateTime, ForeignKey
from sqlalchemy.orm import relationship
from app.core.database import Base
class SchemaChangeLog(Base):
__tablename__ = "schema_change_log"
id = Column(Integer, primary_key=True, index=True)
source_id = Column(Integer, ForeignKey("data_source.id"), nullable=False)
database_id = Column(Integer, ForeignKey("meta_database.id"), nullable=True)
table_id = Column(Integer, ForeignKey("meta_table.id"), nullable=True)
column_id = Column(Integer, ForeignKey("meta_column.id"), nullable=True)
change_type = Column(String(20), nullable=False) # add_table, drop_table, add_column, drop_column, change_type, change_comment
old_value = Column(Text)
new_value = Column(Text)
detected_at = Column(DateTime, default=datetime.utcnow)
source = relationship("DataSource")
database = relationship("Database")
table = relationship("DataTable")
column = relationship("DataColumn")
+17
@@ -0,0 +1,17 @@
from datetime import datetime
from sqlalchemy import Column, Integer, String, Text, DateTime, ForeignKey
from sqlalchemy.orm import relationship
from app.core.database import Base
class WatermarkLog(Base):
__tablename__ = "watermark_log"
id = Column(Integer, primary_key=True, index=True)
user_id = Column(Integer, ForeignKey("sys_user.id"), nullable=False)
export_type = Column(String(20), default="csv") # csv, excel, txt
data_scope = Column(Text) # JSON: {source_id, table_name, row_count}
watermark_key = Column(String(64), nullable=False) # random key for this export
created_at = Column(DateTime, default=datetime.utcnow)
user = relationship("User")
+92
View File
@@ -0,0 +1,92 @@
from typing import List, Optional
from sqlalchemy.orm import Session
from datetime import datetime
from app.models.alert import AlertRule, AlertRecord, WorkOrder
from app.models.project import ClassificationProject, ClassificationResult
from app.models.risk import RiskAssessment
def init_builtin_alert_rules(db: Session):
if db.query(AlertRule).first():
return
rules = [
AlertRule(name="L5字段数量突增", trigger_condition="l5_count", threshold=5, severity="high"),
AlertRule(name="项目风险分过高", trigger_condition="risk_score", threshold=80, severity="critical"),
AlertRule(name="Schema新增敏感字段", trigger_condition="schema_change", threshold=1, severity="medium"),
]
for r in rules:
db.add(r)
db.commit()
def check_alerts(db: Session) -> List[AlertRecord]:
"""Run alert checks and create records."""
rules = db.query(AlertRule).filter(AlertRule.is_active == True).all()
records = []
for rule in rules:
if rule.trigger_condition == "l5_count":
projects = db.query(ClassificationProject).all()
for p in projects:
                l5_count = db.query(ClassificationResult).filter(
                    ClassificationResult.project_id == p.id,
                    ClassificationResult.level.has(code="L5"),  # EXISTS subquery; implies level_id is not NULL
                ).count()
if l5_count >= rule.threshold:
rec = AlertRecord(
rule_id=rule.id,
title=f"项目 {p.name} L5字段数量达到 {l5_count}",
content=f"阈值: {rule.threshold}",
severity=rule.severity,
)
db.add(rec)
records.append(rec)
elif rule.trigger_condition == "risk_score":
risks = db.query(RiskAssessment).filter(
RiskAssessment.entity_type == "project",
RiskAssessment.risk_score >= rule.threshold,
).all()
for rsk in risks:
rec = AlertRecord(
rule_id=rule.id,
title=f"项目 {rsk.entity_name} 风险分 {rsk.risk_score}",
content=f"阈值: {rule.threshold}",
severity=rule.severity,
)
db.add(rec)
records.append(rec)
db.commit()
return records
def create_work_order(db: Session, alert_id: int, title: str, description: str, assignee_id: Optional[int] = None) -> WorkOrder:
wo = WorkOrder(
alert_id=alert_id,
title=title,
description=description,
assignee_id=assignee_id,
)
db.add(wo)
db.commit()
db.refresh(wo)
return wo
def update_work_order_status(db: Session, wo_id: int, status: str, resolution: Optional[str] = None) -> Optional[WorkOrder]:
wo = db.query(WorkOrder).filter(WorkOrder.id == wo_id).first()
if wo:
wo.status = status
if resolution:
wo.resolution = resolution
if status == "resolved":
wo.resolved_at = datetime.utcnow()
# Also resolve linked alert
if wo.alert_id:
alert = db.query(AlertRecord).filter(AlertRecord.id == wo.alert_id).first()
if alert:
alert.status = "resolved"
db.commit()
db.refresh(wo)
return wo
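A minimal driver for the flow above, assuming the usual SessionLocal factory exists in app.core.database:

from app.core.database import SessionLocal
from app.services import alert_service

db = SessionLocal()
try:
    alert_service.init_builtin_alert_rules(db)  # no-op if rules already exist
    created = alert_service.check_alerts(db)
    print(f"{len(created)} alert record(s) created")
finally:
    db.close()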
+174
@@ -0,0 +1,174 @@
import requests
from datetime import datetime
from typing import Optional
from sqlalchemy.orm import Session
from app.models.api_asset import APIAsset, APIEndpoint
# Simple sensitive keywords for API field detection
SENSITIVE_KEYWORDS = [
"password", "pwd", "passwd", "secret", "token", "credit_card", "card_no",
"bank_account", "bank_card", "id_card", "id_number", "phone", "mobile",
"email", "address", "name", "age", "gender", "salary", "income",
"health", "medical", "biometric", "fingerprint", "face",
]
def _is_sensitive_field(name: str, schema: dict) -> tuple[bool, str]:
low = name.lower()
for kw in SENSITIVE_KEYWORDS:
if kw in low:
return True, f"keyword:{kw}"
# Check description / format hints
desc = str(schema.get("description", "")).lower()
fmt = str(schema.get("format", "")).lower()
if "email" in fmt or "email" in desc:
return True, "format:email"
if "uuid" in fmt and "user" in low:
return True, "format:user-uuid"
return False, ""
def _extract_fields(schema: dict, prefix: str = "") -> list[dict]:
fields = []
if not isinstance(schema, dict):
return fields
props = schema.get("properties", {})
for k, v in props.items():
full_name = f"{prefix}.{k}" if prefix else k
sensitive, reason = _is_sensitive_field(k, v)
if sensitive:
fields.append({"name": full_name, "type": v.get("type", "unknown"), "reason": reason})
# nested object
if v.get("type") == "object" and "properties" in v:
fields.extend(_extract_fields(v, full_name))
# array items
if v.get("type") == "array" and isinstance(v.get("items"), dict):
fields.extend(_extract_fields(v["items"], full_name + "[]"))
return fields
def _risk_level_from_fields(fields: list[dict]) -> str:
    if not fields:
        return "low"
    critical_keywords = {"password", "secret", "token", "biometric"}
    high_keywords = {"credit_card", "bank_account", "fingerprint", "face"}
    level = "medium"
    for f in fields:
        low = f["name"].lower()
        if any(kw in low for kw in critical_keywords):
            return "critical"  # highest severity wins, regardless of field order
        if any(kw in low for kw in high_keywords):
            level = "high"
    return level
def scan_swagger(db: Session, asset_id: int) -> dict:
asset = db.query(APIAsset).filter(APIAsset.id == asset_id).first()
if not asset:
return {"success": False, "error": "Asset not found"}
if not asset.swagger_url:
return {"success": False, "error": "No swagger_url configured"}
asset.scan_status = "scanning"
db.commit()
try:
headers = dict(asset.headers or {})
resp = requests.get(asset.swagger_url, headers=headers, timeout=30)
resp.raise_for_status()
spec = resp.json()
# Clear previous endpoints
db.query(APIEndpoint).filter(APIEndpoint.asset_id == asset_id).delete()
paths = spec.get("paths", {})
total = 0
sensitive_total = 0
for path, methods in paths.items():
for method, detail in methods.items():
if method.lower() not in {"get","post","put","patch","delete","head","options"}:
continue
total += 1
parameters = []
for p in detail.get("parameters", []):
parameters.append({"name": p.get("name"), "in": p.get("in"), "required": p.get("required", False), "type": p.get("schema",{}).get("type","string")})
req_schema = detail.get("requestBody", {}).get("content", {}).get("application/json", {}).get("schema")
                resp_schema = None
                # take the first media type under the 200 response that declares a schema
                for content_type, media in (detail.get("responses", {}).get("200", {}).get("content", {}) or {}).items():
                    if isinstance(media, dict) and "schema" in media:
                        resp_schema = media["schema"]
                        break
fields = []
if req_schema:
fields.extend(_extract_fields(req_schema))
if resp_schema:
fields.extend(_extract_fields(resp_schema))
# dedup
seen = set()
unique_fields = []
for f in fields:
if f["name"] not in seen:
seen.add(f["name"])
unique_fields.append(f)
risk = _risk_level_from_fields(unique_fields)
ep = APIEndpoint(
asset_id=asset_id,
method=method.upper(),
path=path,
summary=detail.get("summary", ""),
tags=detail.get("tags", []),
parameters=parameters,
request_body_schema=req_schema,
response_schema=resp_schema,
sensitive_fields=unique_fields,
risk_level=risk,
)
db.add(ep)
if unique_fields:
sensitive_total += 1
asset.scan_status = "completed"
asset.total_endpoints = total
asset.sensitive_endpoints = sensitive_total
        asset.updated_at = datetime.utcnow()
db.commit()
return {"success": True, "total": total, "sensitive": sensitive_total}
except Exception as e:
asset.scan_status = "failed"
db.commit()
return {"success": False, "error": str(e)}
def create_asset(db: Session, data: dict, user_id: Optional[int] = None) -> APIAsset:
asset = APIAsset(
name=data["name"],
base_url=data["base_url"],
swagger_url=data.get("swagger_url"),
auth_type=data.get("auth_type", "none"),
headers=data.get("headers"),
description=data.get("description"),
created_by=user_id,
)
db.add(asset)
db.commit()
db.refresh(asset)
return asset
def update_asset(db: Session, asset_id: int, data: dict) -> Optional[APIAsset]:
asset = db.query(APIAsset).filter(APIAsset.id == asset_id).first()
if not asset:
return None
for k, v in data.items():
if hasattr(asset, k):
setattr(asset, k, v)
db.commit()
db.refresh(asset)
return asset
def delete_asset(db: Session, asset_id: int) -> bool:
asset = db.query(APIAsset).filter(APIAsset.id == asset_id).first()
if not asset:
return False
db.delete(asset)
db.commit()
return True
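A self-check of the extraction heuristics above; these are pure functions, so no network or DB is needed:

from app.services.api_asset_service import _extract_fields, _risk_level_from_fields

schema = {
    "type": "object",
    "properties": {
        "username": {"type": "string"},
        "password": {"type": "string"},
        "contact": {
            "type": "object",
            "properties": {"email": {"type": "string", "format": "email"}},
        },
    },
}
fields = _extract_fields(schema)
print([f["name"] for f in fields])      # ['username', 'password', 'contact.email']
print(_risk_level_from_fields(fields))  # 'critical', because of "password"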
+44 -5
@@ -51,11 +51,39 @@ def match_rule(rule: RecognitionRule, column: DataColumn) -> Tuple[bool, float]:
if t.strip().lower() in enums:
return True, 0.90
elif rule.rule_type == "similarity":
benchmarks = [b.strip().lower() for b in rule.rule_content.split(",") if b.strip()]
if not benchmarks:
return False, 0.0
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
texts = [t.lower() for t in targets] + benchmarks
try:
vectorizer = TfidfVectorizer(analyzer="char_wb", ngram_range=(2, 3))
tfidf = vectorizer.fit_transform(texts)
target_vecs = tfidf[:len(targets)]
bench_vecs = tfidf[len(targets):]
sim_matrix = cosine_similarity(target_vecs, bench_vecs)
max_sim = float(sim_matrix.max())
if max_sim >= 0.75:
return True, round(min(max_sim, 0.99), 4)
except Exception:
pass
return False, 0.0
-def run_auto_classification(db: Session, project_id: int, source_ids: Optional[List[int]] = None) -> dict:
-    """Run automatic classification for a project."""
+def run_auto_classification(
+    db: Session,
+    project_id: int,
+    source_ids: Optional[List[int]] = None,
+    progress_callback=None,
+) -> dict:
+    """Run automatic classification for a project.
+
+    Args:
+        progress_callback: Optional callable(scanned, matched, total) to report progress.
+    """
project = db.query(ClassificationProject).filter(ClassificationProject.id == project_id).first()
if not project:
return {"success": False, "message": "项目不存在"}
@@ -82,7 +110,10 @@ def run_auto_classification(db: Session, project_id: int, source_ids: Optional[List[int]] = None) -> dict:
columns = columns_query.all()
matched_count = 0
-    for col in columns:
+    total = len(columns)
+    report_interval = max(1, total // 20)  # report ~20 times
+    for idx, col in enumerate(columns):
# Check if already has a result for this project
existing = db.query(ClassificationResult).filter(
ClassificationResult.project_id == project_id,
@@ -121,12 +152,20 @@ def run_auto_classification(db: Session, project_id: int, source_ids: Optional[List[int]] = None) -> dict:
# Increment hit count
best_rule.hit_count = (best_rule.hit_count or 0) + 1
# Report progress periodically
if progress_callback and (idx + 1) % report_interval == 0:
progress_callback(scanned=idx + 1, matched=matched_count, total=total)
db.commit()
# Final progress report
if progress_callback:
progress_callback(scanned=total, matched=matched_count, total=total)
return {
"success": True,
"message": f"自动分类完成,共扫描 {len(columns)} 个字段,命中 {matched_count}",
"scanned": len(columns),
"message": f"自动分类完成,共扫描 {total} 个字段,命中 {matched_count}",
"scanned": total,
"matched": matched_count,
}
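A standalone illustration of the similarity branch above: character n-gram TF-IDF vectors plus cosine similarity between column names and the rule's benchmark terms (the column names and benchmarks here are hypothetical):

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

targets = ["cust_mobile_no", "order_id"]       # hypothetical column names
benchmarks = ["mobile", "phone", "telephone"]  # a rule_content value, comma-split

vec = TfidfVectorizer(analyzer="char_wb", ngram_range=(2, 3))
tfidf = vec.fit_transform([t.lower() for t in targets] + benchmarks)
sim = cosine_similarity(tfidf[: len(targets)], tfidf[len(targets):])
print(sim.max())  # the rule matches when this reaches 0.75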
+122
@@ -0,0 +1,122 @@
from typing import List, Optional
from sqlalchemy.orm import Session
from datetime import datetime
from app.models.compliance import ComplianceRule, ComplianceIssue
from app.models.project import ClassificationProject, ClassificationResult
from app.models.classification import DataLevel
from app.models.masking import MaskingRule
def init_builtin_rules(db: Session):
"""Initialize built-in compliance rules."""
if db.query(ComplianceRule).first():
return
rules = [
ComplianceRule(name="L4/L5字段未配置脱敏", standard="dengbao", description="等保2.0要求:四级及以上数据应进行脱敏处理", check_logic="check_masking", severity="high"),
ComplianceRule(name="L5字段缺乏加密存储措施", standard="dengbao", description="等保2.0要求:五级数据应加密存储", check_logic="check_encryption", severity="critical"),
ComplianceRule(name="个人敏感信息处理未授权", standard="pipl", description="个人信息保护法:处理敏感个人信息应取得单独同意", check_logic="check_level", severity="high"),
ComplianceRule(name="数据跨境传输未评估", standard="gdpr", description="GDPR:个人数据跨境传输需进行影响评估", check_logic="check_audit", severity="medium"),
]
for r in rules:
db.add(r)
db.commit()
def scan_compliance(db: Session, project_id: Optional[int] = None) -> List[ComplianceIssue]:
"""Run compliance scan and generate issues."""
rules = db.query(ComplianceRule).filter(ComplianceRule.is_active == True).all()
issues = []
# Get masking rules for check_masking logic
masking_rules = db.query(MaskingRule).filter(MaskingRule.is_active == True).all()
masking_level_ids = {r.level_id for r in masking_rules if r.level_id}
query = db.query(ClassificationProject)
if project_id:
query = query.filter(ClassificationProject.id == project_id)
projects = query.all()
for project in projects:
results = db.query(ClassificationResult).filter(
ClassificationResult.project_id == project.id,
ClassificationResult.level_id.isnot(None),
).all()
for r in results:
if not r.level:
continue
level_code = r.level.code
for rule in rules:
matched = False
desc = ""
suggestion = ""
if rule.check_logic == "check_masking" and level_code in ("L4", "L5"):
if r.level_id not in masking_level_ids:
matched = True
desc = f"字段 '{r.column.name if r.column else '未知'}'{level_code} 级,但未配置脱敏规则"
suggestion = "请在【数据脱敏】模块为该分级配置脱敏策略"
elif rule.check_logic == "check_encryption" and level_code == "L5":
# Placeholder: no encryption check in MVP, always flag
matched = True
desc = f"字段 '{r.column.name if r.column else '未知'}' 为 L5 级核心数据,建议确认是否加密存储"
suggestion = "请确认该字段在数据库中已加密存储"
elif rule.check_logic == "check_level" and level_code in ("L4", "L5"):
if r.source == "auto":
matched = True
desc = f"个人敏感字段 '{r.column.name if r.column else '未知'}' 目前为自动识别,建议人工复核并确认授权"
suggestion = "请人工确认该字段的处理已取得合法授权"
elif rule.check_logic == "check_audit":
# Placeholder for cross-border check
pass
if matched:
# Check if open issue already exists
existing = db.query(ComplianceIssue).filter(
ComplianceIssue.rule_id == rule.id,
ComplianceIssue.project_id == project.id,
ComplianceIssue.entity_type == "column",
ComplianceIssue.entity_id == (r.column_id or 0),
ComplianceIssue.status == "open",
).first()
if not existing:
issue = ComplianceIssue(
rule_id=rule.id,
project_id=project.id,
entity_type="column",
entity_id=r.column_id or 0,
entity_name=r.column.name if r.column else "未知",
severity=rule.severity,
description=desc,
suggestion=suggestion,
)
db.add(issue)
issues.append(issue)
db.commit()
return issues
def list_issues(db: Session, project_id: Optional[int] = None, status: Optional[str] = None, page: int = 1, page_size: int = 20):
query = db.query(ComplianceIssue)
if project_id:
query = query.filter(ComplianceIssue.project_id == project_id)
if status:
query = query.filter(ComplianceIssue.status == status)
total = query.count()
items = query.order_by(ComplianceIssue.created_at.desc()).offset((page - 1) * page_size).limit(page_size).all()
return items, total
def resolve_issue(db: Session, issue_id: int):
issue = db.query(ComplianceIssue).filter(ComplianceIssue.id == issue_id).first()
if issue:
issue.status = "resolved"
issue.resolved_at = datetime.utcnow()
db.commit()
return issue
+25 -3
@@ -1,3 +1,6 @@
import base64
import hashlib
import logging
from typing import Optional, List, Tuple
from sqlalchemy.orm import Session
from fastapi import HTTPException, status
@@ -7,9 +10,28 @@ from app.models.metadata import DataSource
from app.schemas.datasource import DataSourceCreate, DataSourceUpdate, DataSourceTest
from app.core.config import settings
-# Simple AES-like symmetric encryption for DB passwords
-# In production, use a proper KMS
-_fernet = Fernet(Fernet.generate_key())
logger = logging.getLogger(__name__)
def _get_fernet() -> Fernet:
"""Initialize Fernet with a stable key.
If DB_ENCRYPTION_KEY is set, use it directly.
Otherwise derive deterministically from SECRET_KEY for backward compatibility.
"""
if settings.DB_ENCRYPTION_KEY:
key = settings.DB_ENCRYPTION_KEY.encode()
else:
logger.warning(
"DB_ENCRYPTION_KEY is not set. Deriving encryption key from SECRET_KEY. "
"Please set DB_ENCRYPTION_KEY explicitly via environment variable or .env file."
)
digest = hashlib.sha256(settings.SECRET_KEY.encode()).digest()
key = base64.urlsafe_b64encode(digest)
return Fernet(key)
_fernet = _get_fernet()
def _encrypt_password(password: str) -> str:
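Since DB_ENCRYPTION_KEY must be a valid Fernet key (32 url-safe base64-encoded bytes), a one-line sketch for generating one to drop into .env:
# Run once and copy the output into .env as DB_ENCRYPTION_KEY=<key>
from cryptography.fernet import Fernet

print(Fernet.generate_key().decode())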
+65
View File
@@ -0,0 +1,65 @@
import re
from typing import List, Optional
from sqlalchemy.orm import Session
from app.models.lineage import DataLineage
def _extract_tables(sql: str) -> List[str]:
"""Extract table names from SQL using regex (simple heuristic)."""
# Normalize SQL
sql = re.sub(r"--.*?\n", " ", sql)
sql = re.sub(r"/\*.*?\*/", " ", sql, flags=re.DOTALL)
sql = sql.lower()
tables = set()
# Source tables referenced via FROM / JOIN (INTO names the target, which is passed in separately)
for pattern in [r"\bfrom\s+([a-z_][a-z0-9_]*)", r"\bjoin\s+([a-z_][a-z0-9_]*)"]:
for m in re.finditer(pattern, sql):
tables.add(m.group(1))
return sorted(tables)
def parse_sql_lineage(db: Session, sql: str, target_table: str) -> List[DataLineage]:
"""Parse SQL and create lineage records pointing to target_table."""
source_tables = _extract_tables(sql)
records = []
for st in source_tables:
if st == target_table:
continue
existing = db.query(DataLineage).filter(
DataLineage.source_table == st,
DataLineage.target_table == target_table,
).first()
if not existing:
rec = DataLineage(
source_table=st,
target_table=target_table,
relation_type="direct",
script_content=sql[:2000],
)
db.add(rec)
records.append(rec)
db.commit()
return records
def get_lineage_graph(db: Session, table_name: Optional[str] = None) -> dict:
"""Build graph data for ECharts."""
query = db.query(DataLineage)
if table_name:
query = query.filter(
(DataLineage.source_table == table_name) | (DataLineage.target_table == table_name)
)
items = query.limit(500).all()
nodes = {}
links = []
for item in items:
nodes[item.source_table] = {"name": item.source_table, "category": 0}
nodes[item.target_table] = {"name": item.target_table, "category": 1}
links.append({"source": item.source_table, "target": item.target_table, "value": item.relation_type})
return {
"nodes": list(nodes.values()),
"links": links,
}
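To make the heuristic concrete, a quick check of _extract_tables on an illustrative statement:
sql = """
INSERT INTO dwd_orders
SELECT o.id, u.name
FROM ods_orders o
JOIN ods_users u ON o.user_id = u.id  -- enrich with user info
"""
assert _extract_tables(sql) == ["ods_orders", "ods_users"]
# parse_sql_lineage(db, sql, target_table="dwd_orders") would then record two edges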
+195
View File
@@ -0,0 +1,195 @@
import hashlib
from typing import Optional, Dict
from sqlalchemy.orm import Session
from fastapi import HTTPException, status
from app.models.metadata import DataSource, Database, DataTable, DataColumn
from app.models.project import ClassificationResult
from app.models.masking import MaskingRule
from app.services.datasource_service import get_datasource, _decrypt_password
def get_masking_rule(db: Session, rule_id: int):
return db.query(MaskingRule).filter(MaskingRule.id == rule_id).first()
def list_masking_rules(db: Session, level_id=None, category_id=None, page=1, page_size=20):
query = db.query(MaskingRule).filter(MaskingRule.is_active == True)
if level_id:
query = query.filter(MaskingRule.level_id == level_id)
if category_id:
query = query.filter(MaskingRule.category_id == category_id)
total = query.count()
items = query.offset((page - 1) * page_size).limit(page_size).all()
return items, total
def create_masking_rule(db: Session, data: dict):
db_obj = MaskingRule(**data)
db.add(db_obj)
db.commit()
db.refresh(db_obj)
return db_obj
def update_masking_rule(db: Session, db_obj: MaskingRule, data: dict):
for k, v in data.items():
if v is not None:
setattr(db_obj, k, v)
db.commit()
db.refresh(db_obj)
return db_obj
def delete_masking_rule(db: Session, rule_id: int):
db_obj = get_masking_rule(db, rule_id)
if not db_obj:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="规则不存在")
db.delete(db_obj)
db.commit()
def _apply_mask(value, params):
if not value:
return value
keep_prefix = params.get("keep_prefix", 3)
keep_suffix = params.get("keep_suffix", 4)
mask_char = params.get("mask_char", "*")
if len(value) <= keep_prefix + keep_suffix:
return mask_char * len(value)
return value[:keep_prefix] + mask_char * (len(value) - keep_prefix - keep_suffix) + value[-keep_suffix:]
def _apply_truncate(value, params):
length = params.get("length", 3)
suffix = params.get("suffix", "...")
if not value or len(value) <= length:
return value
return value[:length] + suffix
def _apply_hash(value, params):
algorithm = params.get("algorithm", "sha256")
if algorithm == "md5":
return hashlib.md5(str(value).encode()).hexdigest()[:16]
return hashlib.sha256(str(value).encode()).hexdigest()[:32]
def _apply_generalize(value, params):
try:
step = params.get("step", 10)
num = float(value)
lower = int(num // step * step)
upper = lower + step
return f"{lower}-{upper}"
except Exception:
return value
def _apply_replace(value, params):
return params.get("replacement", "[REDACTED]")
def apply_masking(value, algorithm, params):
if value is None:
return None
handlers = {
"mask": _apply_mask,
"truncate": _apply_truncate,
"hash": _apply_hash,
"generalize": _apply_generalize,
"replace": _apply_replace,
}
handler = handlers.get(algorithm)
if not handler:
return value
return handler(str(value), params or {})
def _get_column_rules(db: Session, table_id: int, project_id=None):
columns = db.query(DataColumn).filter(DataColumn.table_id == table_id).all()
col_rules = {}
results = {}
if project_id:
res_list = db.query(ClassificationResult).filter(
ClassificationResult.project_id == project_id,
ClassificationResult.column_id.in_([c.id for c in columns]),
).all()
results = {r.column_id: r for r in res_list}
rules = db.query(MaskingRule).filter(MaskingRule.is_active == True).all()
rule_map = {}
for r in rules:
key = (r.level_id, r.category_id)
if key not in rule_map:
rule_map[key] = r
for col in columns:
matched_rule = None
if col.id in results:
r = results[col.id]
matched_rule = rule_map.get((r.level_id, r.category_id))
if not matched_rule:
matched_rule = rule_map.get((r.level_id, None))
if not matched_rule:
matched_rule = rule_map.get((None, r.category_id))
col_rules[col.id] = matched_rule
return col_rules
def preview_masking(db: Session, source_id: int, table_name: str, project_id=None, limit=20):
source = get_datasource(db, source_id)
if not source:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="数据源不存在")
table = (
db.query(DataTable)
.join(Database)
.filter(Database.source_id == source_id, DataTable.name == table_name)
.first()
)
if not table:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="表不存在")
col_rules = _get_column_rules(db, table.id, project_id)
from sqlalchemy import create_engine, text
password = ""
if source.encrypted_password:
try:
password = _decrypt_password(source.encrypted_password)
except Exception:
pass
driver_map = {
"mysql": "mysql+pymysql",
"postgresql": "postgresql+psycopg2",
"oracle": "oracle+cx_oracle",
"sqlserver": "mssql+pymssql",
}
driver = driver_map.get(source.source_type, source.source_type)
url = f"{driver}://{source.username}:{password}@{source.host}:{source.port}/{source.database_name}"
engine = create_engine(url, pool_pre_ping=True)
columns = db.query(DataColumn).filter(DataColumn.table_id == table.id).all()
rows_raw = []
try:
with engine.connect() as conn:
result = conn.execute(text(f'SELECT * FROM "{table_name}" LIMIT {limit}'))
rows_raw = [dict(row._mapping) for row in result]
except Exception:
try:
with engine.connect() as conn:
result = conn.execute(text(f"SELECT * FROM {table_name} LIMIT {limit}"))
rows_raw = [dict(row._mapping) for row in result]
except Exception as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"查询失败: {e}")
masked_rows = []
for raw in rows_raw:
masked = {}
for col in columns:
val = raw.get(col.name)
rule = col_rules.get(col.id)
if rule:
masked[col.name] = apply_masking(val, rule.algorithm, rule.params or {})
else:
masked[col.name] = val
masked_rows.append(masked)
return {
"success": True,
"columns": [{"name": c.name, "data_type": c.data_type, "has_rule": col_rules.get(c.id) is not None} for c in columns],
"rows": masked_rows,
"total_rows": len(masked_rows),
}
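The masking handlers are pure functions, so their behavior can be pinned down with a few illustrative values:
assert apply_masking("13812345678", "mask", {"keep_prefix": 3, "keep_suffix": 4}) == "138****5678"
assert apply_masking("张三丰", "truncate", {"length": 1, "suffix": "**"}) == "张**"
assert apply_masking(27, "generalize", {"step": 10}) == "20-30"
assert apply_masking("whatever", "replace", {}) == "[REDACTED]"
assert apply_masking(None, "hash", {}) is None  # None always passes through untouched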
+115 -17
View File
@@ -3,9 +3,23 @@ from sqlalchemy.orm import Session
from fastapi import HTTPException, status
from app.models.metadata import DataSource, Database, DataTable, DataColumn
from app.models.schema_change import SchemaChangeLog
from app.services.datasource_service import get_datasource, _decrypt_password
def _log_schema_change(db: Session, source_id: int, change_type: str, database_id: int = None, table_id: int = None, column_id: int = None, old_value: str = None, new_value: str = None):
log = SchemaChangeLog(
source_id=source_id,
database_id=database_id,
table_id=table_id,
column_id=column_id,
change_type=change_type,
old_value=old_value,
new_value=new_value,
)
db.add(log)
def get_database(db: Session, db_id: int) -> Optional[Database]:
return db.query(Database).filter(Database.id == db_id).first()
@@ -19,14 +33,14 @@ def get_column(db: Session, column_id: int) -> Optional[DataColumn]:
def list_databases(db: Session, source_id: Optional[int] = None) -> List[Database]:
query = db.query(Database)
query = db.query(Database).filter(Database.is_deleted == False)
if source_id:
query = query.filter(Database.source_id == source_id)
return query.all()
def list_tables(db: Session, database_id: Optional[int] = None, keyword: Optional[str] = None) -> Tuple[List[DataTable], int]:
query = db.query(DataTable)
query = db.query(DataTable).filter(DataTable.is_deleted == False)
if database_id:
query = query.filter(DataTable.database_id == database_id)
if keyword:
@@ -37,7 +51,7 @@ def list_tables(db: Session, database_id: Optional[int] = None, keyword: Optiona
def list_columns(db: Session, table_id: Optional[int] = None, keyword: Optional[str] = None, page: int = 1, page_size: int = 50) -> Tuple[List[DataColumn], int]:
query = db.query(DataColumn)
query = db.query(DataColumn).filter(DataColumn.is_deleted == False)
if table_id:
query = query.filter(DataColumn.table_id == table_id)
if keyword:
@@ -49,7 +63,7 @@ def list_columns(db: Session, table_id: Optional[int] = None, keyword: Optional[
return items, total
def build_tree(db: Session, source_id: Optional[int] = None) -> List[dict]:
def build_tree(db: Session, source_id: Optional[int] = None, include_deleted: bool = False) -> List[dict]:
sources = db.query(DataSource)
if source_id:
sources = sources.filter(DataSource.id == source_id)
@@ -65,20 +79,24 @@ def build_tree(db: Session, source_id: Optional[int] = None) -> List[dict]:
"meta": {"source_type": s.source_type, "status": s.status},
}
for d in s.databases:
if not include_deleted and d.is_deleted:
continue
db_node = {
"id": d.id,
"name": d.name,
"type": "database",
"children": [],
"meta": {"charset": d.charset, "table_count": d.table_count},
"meta": {"charset": d.charset, "table_count": d.table_count, "is_deleted": d.is_deleted},
}
for t in d.tables:
if not include_deleted and t.is_deleted:
continue
table_node = {
"id": t.id,
"name": t.name,
"type": "table",
"children": [],
"meta": {"comment": t.comment, "row_count": t.row_count, "column_count": t.column_count},
"meta": {"comment": t.comment, "row_count": t.row_count, "column_count": t.column_count, "is_deleted": t.is_deleted},
}
db_node["children"].append(table_node)
source_node["children"].append(db_node)
@@ -86,9 +104,16 @@ def build_tree(db: Session, source_id: Optional[int] = None) -> List[dict]:
return result
def _compute_checksum(data: dict) -> str:
import hashlib, json
payload = json.dumps(data, sort_keys=True, ensure_ascii=False, default=str)
return hashlib.sha256(payload.encode()).hexdigest()[:32]
def sync_metadata(db: Session, source_id: int, user_id: int) -> dict:
from sqlalchemy import create_engine, inspect, text
import json
from datetime import datetime
source = get_datasource(db, source_id)
if not source:
@@ -118,29 +143,56 @@ def sync_metadata(db: Session, source_id: int, user_id: int) -> dict:
inspector = inspect(engine)
db_names = inspector.get_schema_names() or [source.database_name]
scan_time = datetime.utcnow()
total_tables = 0
total_columns = 0
updated_tables = 0
updated_columns = 0
for db_name in db_names:
db_obj = db.query(Database).filter(Database.source_id == source.id, Database.name == db_name).first()
db_checksum = _compute_checksum({"name": db_name})
db_obj = db.query(Database).filter(
Database.source_id == source.id, Database.name == db_name
).first()
if not db_obj:
db_obj = Database(source_id=source.id, name=db_name)
db_obj = Database(source_id=source.id, name=db_name, checksum=db_checksum, last_scanned_at=scan_time)
db.add(db_obj)
db.commit()
db.refresh(db_obj)
else:
db_obj.checksum = db_checksum
db_obj.last_scanned_at = scan_time
db_obj.is_deleted = False
db_obj.deleted_at = None
table_names = inspector.get_table_names(schema=db_name)
for tname in table_names:
table_obj = db.query(DataTable).filter(DataTable.database_id == db_obj.id, DataTable.name == tname).first()
t_checksum = _compute_checksum({"name": tname})
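# NOTE: the table checksum is derived from the name alone, so it never changes
# for an existing table; renames surface as add_table + drop_table, and
# structural changes are caught by the per-column checksums below.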
table_obj = db.query(DataTable).filter(
DataTable.database_id == db_obj.id, DataTable.name == tname
).first()
if not table_obj:
table_obj = DataTable(database_id=db_obj.id, name=tname)
table_obj = DataTable(database_id=db_obj.id, name=tname, checksum=t_checksum, last_scanned_at=scan_time)
db.add(table_obj)
db.commit()
db.refresh(table_obj)
_log_schema_change(db, source.id, "add_table", database_id=db_obj.id, table_id=table_obj.id, new_value=tname)
else:
if table_obj.checksum != t_checksum:
table_obj.checksum = t_checksum
updated_tables += 1
table_obj.last_scanned_at = scan_time
table_obj.is_deleted = False
table_obj.deleted_at = None
columns = inspector.get_columns(tname, schema=db_name)
for col in columns:
col_obj = db.query(DataColumn).filter(DataColumn.table_id == table_obj.id, DataColumn.name == col["name"]).first()
col_checksum = _compute_checksum({
"name": col["name"],
"type": str(col.get("type", "")),
"max_length": col.get("max_length"),
"comment": col.get("comment"),
"nullable": col.get("nullable", True),
})
col_obj = db.query(DataColumn).filter(
DataColumn.table_id == table_obj.id, DataColumn.name == col["name"]
).first()
if not col_obj:
sample = None
try:
@@ -150,7 +202,6 @@ def sync_metadata(db: Session, source_id: int, user_id: int) -> dict:
sample = json.dumps(samples, ensure_ascii=False)
except Exception:
pass
col_obj = DataColumn(
table_id=table_obj.id,
name=col["name"],
@@ -159,13 +210,58 @@ def sync_metadata(db: Session, source_id: int, user_id: int) -> dict:
comment=col.get("comment"),
is_nullable=col.get("nullable", True),
sample_data=sample,
checksum=col_checksum,
last_scanned_at=scan_time,
)
db.add(col_obj)
db.flush()  # populate col_obj.id before logging the change
total_columns += 1
_log_schema_change(db, source.id, "add_column", database_id=db_obj.id, table_id=table_obj.id, column_id=col_obj.id, new_value=col["name"])
else:
if col_obj.checksum != col_checksum:
old_val = f"type={col_obj.data_type}, len={col_obj.length}, comment={col_obj.comment}"
new_val = f"type={str(col.get('type', ''))}, len={col.get('max_length')}, comment={col.get('comment')}"
_log_schema_change(db, source.id, "change_type", database_id=db_obj.id, table_id=table_obj.id, column_id=col_obj.id, old_value=old_val, new_value=new_val)
col_obj.checksum = col_checksum
col_obj.data_type = str(col.get("type", ""))
col_obj.length = col.get("max_length")
col_obj.comment = col.get("comment")
col_obj.is_nullable = col.get("nullable", True)
updated_columns += 1
col_obj.last_scanned_at = scan_time
col_obj.is_deleted = False
col_obj.deleted_at = None
total_tables += 1
db.commit()
# Soft-delete objects not seen in this scan and log changes
deleted_dbs = db.query(Database).filter(
Database.source_id == source.id,
Database.is_deleted == False,
Database.last_scanned_at < scan_time,
).all()
for d in deleted_dbs:
_log_schema_change(db, source.id, "drop_database", database_id=d.id, old_value=d.name)
d.is_deleted = True
d.deleted_at = scan_time
for db_obj in db.query(Database).filter(Database.source_id == source.id).all():
deleted_tables = db.query(DataTable).filter(
DataTable.database_id == db_obj.id,
DataTable.is_deleted == False,
DataTable.last_scanned_at < scan_time,
).all()
for t in deleted_tables:
_log_schema_change(db, source.id, "drop_table", database_id=db_obj.id, table_id=t.id, old_value=t.name)
t.is_deleted = True
t.deleted_at = scan_time
for table_obj in db.query(DataTable).filter(DataTable.database_id == db_obj.id).all():
deleted_cols = db.query(DataColumn).filter(
DataColumn.table_id == table_obj.id,
DataColumn.is_deleted == False,
DataColumn.last_scanned_at < scan_time,
).all()
for c in deleted_cols:
_log_schema_change(db, source.id, "drop_column", database_id=db_obj.id, table_id=table_obj.id, column_id=c.id, old_value=c.name)
c.is_deleted = True
c.deleted_at = scan_time
source.status = "active"
db.commit()
@@ -176,6 +272,8 @@ def sync_metadata(db: Session, source_id: int, user_id: int) -> dict:
"databases": len(db_names),
"tables": total_tables,
"columns": total_columns,
"updated_tables": updated_tables,
"updated_columns": updated_columns,
}
except Exception as e:
source.status = "error"
@@ -0,0 +1,183 @@
from typing import Optional, List, Tuple
from sqlalchemy.orm import Session
from fastapi import HTTPException, status
from app.models.metadata import DataSource, Database, DataTable, DataColumn
from app.services.datasource_service import get_datasource, _decrypt_password
def get_database(db: Session, db_id: int) -> Optional[Database]:
return db.query(Database).filter(Database.id == db_id).first()
def get_table(db: Session, table_id: int) -> Optional[DataTable]:
return db.query(DataTable).filter(DataTable.id == table_id).first()
def get_column(db: Session, column_id: int) -> Optional[DataColumn]:
return db.query(DataColumn).filter(DataColumn.id == column_id).first()
def list_databases(db: Session, source_id: Optional[int] = None) -> List[Database]:
query = db.query(Database)
if source_id:
query = query.filter(Database.source_id == source_id)
return query.all()
def list_tables(db: Session, database_id: Optional[int] = None, keyword: Optional[str] = None) -> Tuple[List[DataTable], int]:
query = db.query(DataTable)
if database_id:
query = query.filter(DataTable.database_id == database_id)
if keyword:
query = query.filter(
(DataTable.name.contains(keyword)) | (DataTable.comment.contains(keyword))
)
return query.all(), query.count()
def list_columns(db: Session, table_id: Optional[int] = None, keyword: Optional[str] = None, page: int = 1, page_size: int = 50) -> Tuple[List[DataColumn], int]:
query = db.query(DataColumn)
if table_id:
query = query.filter(DataColumn.table_id == table_id)
if keyword:
query = query.filter(
(DataColumn.name.contains(keyword)) | (DataColumn.comment.contains(keyword))
)
total = query.count()
items = query.offset((page - 1) * page_size).limit(page_size).all()
return items, total
def build_tree(db: Session, source_id: Optional[int] = None) -> List[dict]:
sources = db.query(DataSource)
if source_id:
sources = sources.filter(DataSource.id == source_id)
sources = sources.all()
result = []
for s in sources:
source_node = {
"id": s.id,
"name": s.name,
"type": "source",
"children": [],
"meta": {"source_type": s.source_type, "status": s.status},
}
for d in s.databases:
db_node = {
"id": d.id,
"name": d.name,
"type": "database",
"children": [],
"meta": {"charset": d.charset, "table_count": d.table_count},
}
for t in d.tables:
table_node = {
"id": t.id,
"name": t.name,
"type": "table",
"children": [],
"meta": {"comment": t.comment, "row_count": t.row_count, "column_count": t.column_count},
}
db_node["children"].append(table_node)
source_node["children"].append(db_node)
result.append(source_node)
return result
def sync_metadata(db: Session, source_id: int, user_id: int) -> dict:
from sqlalchemy import create_engine, inspect, text
import json
source = get_datasource(db, source_id)
if not source:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="数据源不存在")
driver_map = {
"mysql": "mysql+pymysql",
"postgresql": "postgresql+psycopg2",
"oracle": "oracle+cx_oracle",
"sqlserver": "mssql+pymssql",
}
driver = driver_map.get(source.source_type, source.source_type)
if source.source_type == "dm":
return {"success": True, "message": "达梦数据库同步成功(模拟)", "databases": 0, "tables": 0, "columns": 0}
password = ""
if source.encrypted_password:
try:
password = _decrypt_password(source.encrypted_password)
except Exception:
pass
try:
url = f"{driver}://{source.username}:{password}@{source.host}:{source.port}/{source.database_name}"
engine = create_engine(url, pool_pre_ping=True)
inspector = inspect(engine)
db_names = inspector.get_schema_names() or [source.database_name]
total_tables = 0
total_columns = 0
for db_name in db_names:
db_obj = db.query(Database).filter(Database.source_id == source.id, Database.name == db_name).first()
if not db_obj:
db_obj = Database(source_id=source.id, name=db_name)
db.add(db_obj)
db.commit()
db.refresh(db_obj)
table_names = inspector.get_table_names(schema=db_name)
for tname in table_names:
table_obj = db.query(DataTable).filter(DataTable.database_id == db_obj.id, DataTable.name == tname).first()
if not table_obj:
table_obj = DataTable(database_id=db_obj.id, name=tname)
db.add(table_obj)
db.commit()
db.refresh(table_obj)
columns = inspector.get_columns(tname, schema=db_name)
for col in columns:
col_obj = db.query(DataColumn).filter(DataColumn.table_id == table_obj.id, DataColumn.name == col["name"]).first()
if not col_obj:
sample = None
try:
with engine.connect() as conn:
result = conn.execute(text(f'SELECT "{col["name"]}" FROM "{db_name}"."{tname}" LIMIT 5'))
samples = [str(r[0]) for r in result if r[0] is not None]
sample = json.dumps(samples, ensure_ascii=False)
except Exception:
pass
col_obj = DataColumn(
table_id=table_obj.id,
name=col["name"],
data_type=str(col.get("type", "")),
length=col.get("max_length"),
comment=col.get("comment"),
is_nullable=col.get("nullable", True),
sample_data=sample,
)
db.add(col_obj)
total_columns += 1
total_tables += 1
db.commit()
source.status = "active"
db.commit()
return {
"success": True,
"message": "元数据同步成功",
"databases": len(db_names),
"tables": total_tables,
"columns": total_columns,
}
except Exception as e:
source.status = "error"
db.commit()
return {"success": False, "message": f"同步失败: {str(e)}", "databases": 0, "tables": 0, "columns": 0}
+195
View File
@@ -0,0 +1,195 @@
import os
import json
import logging
from typing import List, Optional, Tuple
from datetime import datetime
import joblib
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sqlalchemy.orm import Session
from app.models.project import ClassificationResult
from app.models.classification import Category
from app.models.ml import MLModelVersion
logger = logging.getLogger(__name__)
MODELS_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "ml_models")
os.makedirs(MODELS_DIR, exist_ok=True)
def _build_text_features(column_name: str, comment: Optional[str], sample_data: Optional[str]) -> str:
parts = [column_name]
if comment:
parts.append(comment)
if sample_data:
try:
samples = json.loads(sample_data)
if isinstance(samples, list):
parts.extend([str(s) for s in samples[:5]])
except Exception:
parts.append(sample_data)
return " ".join(parts)
def _fetch_training_data(db: Session, min_samples_per_class: int = 5):
results = (
db.query(ClassificationResult)
.filter(ClassificationResult.source == "manual")
.filter(ClassificationResult.category_id.isnot(None))
.all()
)
texts = []
labels = []
for r in results:
if r.column:
text = _build_text_features(r.column.name, r.column.comment, r.column.sample_data)
texts.append(text)
labels.append(r.category_id)
from collections import Counter
counts = Counter(labels)
valid_classes = {c for c, n in counts.items() if n >= min_samples_per_class}
filtered_texts = []
filtered_labels = []
for t, l in zip(texts, labels):
if l in valid_classes:
filtered_texts.append(t)
filtered_labels.append(l)
return filtered_texts, filtered_labels, len(filtered_labels)
def train_model(db: Session, model_name: Optional[str] = None, algorithm: str = "logistic_regression", test_size: float = 0.2):
texts, labels, total = _fetch_training_data(db)
if total < 20:
logger.warning("Not enough training data (need >= 20, got %d)", total)
return None
X_train, X_test, y_train, y_test = train_test_split(
texts, labels, test_size=test_size, random_state=42, stratify=labels
)
vectorizer = TfidfVectorizer(analyzer="char_wb", ngram_range=(2, 4), max_features=5000)
X_train_vec = vectorizer.fit_transform(X_train)
X_test_vec = vectorizer.transform(X_test)
if algorithm == "logistic_regression":
clf = LogisticRegression(max_iter=1000, multi_class="multinomial", solver="lbfgs")
elif algorithm == "random_forest":
clf = RandomForestClassifier(n_estimators=100, random_state=42, n_jobs=-1)
else:
clf = LogisticRegression(max_iter=1000, multi_class="multinomial", solver="lbfgs")
clf.fit(X_train_vec, y_train)
y_pred = clf.predict(X_test_vec)
acc = accuracy_score(y_test, y_pred)
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
name = model_name or f"model_{timestamp}"
model_path = os.path.join(MODELS_DIR, f"{name}_clf.joblib")
vec_path = os.path.join(MODELS_DIR, f"{name}_tfidf.joblib")
joblib.dump(clf, model_path)
joblib.dump(vectorizer, vec_path)
db.query(MLModelVersion).filter(MLModelVersion.is_active == True).update({"is_active": False})
mv = MLModelVersion(
name=name,
model_path=model_path,
vectorizer_path=vec_path,
accuracy=acc,
train_samples=total,
is_active=True,
description=f"Algorithm: {algorithm}, test_accuracy: {acc:.4f}",
)
db.add(mv)
db.commit()
db.refresh(mv)
logger.info("Trained model %s with accuracy %.4f on %d samples", name, acc, total)
return mv
def _get_active_model(db: Session):
mv = db.query(MLModelVersion).filter(MLModelVersion.is_active == True).first()
if not mv or not os.path.exists(mv.model_path) or not os.path.exists(mv.vectorizer_path):
return None
clf = joblib.load(mv.model_path)
vectorizer = joblib.load(mv.vectorizer_path)
return clf, vectorizer, mv
def predict_categories(db: Session, texts: List[str], top_k: int = 3):
model_tuple = _get_active_model(db)
if not model_tuple:
return [[] for _ in texts]
clf, vectorizer, mv = model_tuple
X = vectorizer.transform(texts)
if hasattr(clf, "predict_proba"):
probs = clf.predict_proba(X)
else:
preds = clf.predict(X)
return [[{"category_id": int(p), "confidence": 1.0}] for p in preds]
classes = [int(c) for c in clf.classes_]
results = []
for prob in probs:
top_idx = np.argsort(prob)[::-1][:top_k]
suggestions = []
for idx in top_idx:
cat_id = classes[idx]
confidence = float(prob[idx])
if confidence > 0.01:
suggestions.append({"category_id": cat_id, "confidence": round(confidence, 4)})
results.append(suggestions)
return results
def suggest_for_project_columns(db: Session, project_id: int, column_ids: Optional[List[int]] = None, top_k: int = 3):
from app.models.project import ClassificationProject
from app.models.metadata import DataColumn
project = db.query(ClassificationProject).filter(ClassificationProject.id == project_id).first()
if not project:
return {"success": False, "message": "项目不存在"}
query = db.query(DataColumn).join(
ClassificationResult,
(ClassificationResult.column_id == DataColumn.id) & (ClassificationResult.project_id == project_id),
isouter=True,
)
if column_ids:
query = query.filter(DataColumn.id.in_(column_ids))
columns = query.all()
texts = []
col_map = []
for col in columns:
texts.append(_build_text_features(col.name, col.comment, col.sample_data))
col_map.append(col)
if not texts:
return {"success": True, "suggestions": [], "message": "没有可预测的字段"}
predictions = predict_categories(db, texts, top_k=top_k)
suggestions = []
all_category_ids = set()
for col, preds in zip(col_map, predictions):
for p in preds:
all_category_ids.add(p["category_id"])
categories = {c.id: c for c in db.query(Category).filter(Category.id.in_(list(all_category_ids))).all()}
for col, preds in zip(col_map, predictions):
item = {
"column_id": col.id,
"column_name": col.name,
"table_name": col.table.name if col.table else None,
"suggestions": [],
}
for p in preds:
cat = categories.get(p["category_id"])
item["suggestions"].append({
"category_id": p["category_id"],
"category_name": cat.name if cat else None,
"category_code": cat.code if cat else None,
"confidence": p["confidence"],
})
suggestions.append(item)
return {"success": True, "suggestions": suggestions}
+152
View File
@@ -94,3 +94,155 @@ def generate_classification_report(db: Session, project_id: int) -> bytes:
doc.save(buffer)
buffer.seek(0)
return buffer.read()
def generate_excel_report(db: Session, project_id: int) -> bytes:
"""Generate an Excel report for a classification project."""
from openpyxl import Workbook
from openpyxl.styles import Font, PatternFill, Alignment, Border, Side
from openpyxl.chart import PieChart, Reference
from sqlalchemy import func
project = db.query(ClassificationProject).filter(ClassificationProject.id == project_id).first()
if not project:
raise ValueError("项目不存在")
wb = Workbook()
ws = wb.active
ws.title = "报告概览"
# Title
ws.merge_cells('A1:D1')
ws['A1'] = '数据分类分级项目报告'
ws['A1'].font = Font(size=18, bold=True)
ws['A1'].alignment = Alignment(horizontal='center')
# Basic info
ws['A3'] = '项目名称'
ws['B3'] = project.name
ws['A4'] = '报告生成时间'
ws['B4'] = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
ws['A5'] = '项目状态'
ws['B5'] = project.status
ws['A6'] = '模板版本'
ws['B6'] = project.template.version if project.template else 'N/A'
# Statistics
results = db.query(ClassificationResult).filter(ClassificationResult.project_id == project_id).all()
total = len(results)
auto_count = sum(1 for r in results if r.source == 'auto')
manual_count = sum(1 for r in results if r.source == 'manual')
ws['A8'] = '总字段数'
ws['B8'] = total
ws['A9'] = '自动识别'
ws['B9'] = auto_count
ws['A10'] = '人工打标'
ws['B10'] = manual_count
# Level distribution
ws['A12'] = '分级'
ws['B12'] = '数量'
ws['C12'] = '占比'
ws['A12'].font = Font(bold=True)
ws['B12'].font = Font(bold=True)
ws['C12'].font = Font(bold=True)
level_stats = {}
for r in results:
if r.level:
level_stats[r.level.name] = level_stats.get(r.level.name, 0) + 1
red_fill = PatternFill(start_color='FFCCCC', end_color='FFCCCC', fill_type='solid')
row = 13
for level_name, count in sorted(level_stats.items(), key=lambda x: -x[1]):
ws.cell(row=row, column=1, value=level_name)
ws.cell(row=row, column=2, value=count)
pct = f'{count / total * 100:.1f}%' if total > 0 else '0%'
ws.cell(row=row, column=3, value=pct)
if 'L4' in level_name or 'L5' in level_name:
for c in range(1, 4):
ws.cell(row=row, column=c).fill = red_fill
row += 1
# High risk sheet
ws2 = wb.create_sheet("高敏感数据清单")
ws2.append(['字段名', '所属表', '分类', '分级', '来源', '置信度'])
for cell in ws2[1]:
cell.font = Font(bold=True)
cell.fill = PatternFill(start_color='DDEBF7', end_color='DDEBF7', fill_type='solid')
high_risk = [r for r in results if r.level and r.level.code in ('L4', 'L5')]
for r in high_risk[:500]:
ws2.append([
r.column.name if r.column else 'N/A',
r.column.table.name if r.column and r.column.table else 'N/A',
r.category.name if r.category else 'N/A',
r.level.name if r.level else 'N/A',
'自动' if r.source == 'auto' else '人工',
r.confidence,
])
# Auto-fit column widths roughly
for ws_sheet in [ws, ws2]:
for column in ws_sheet.columns:
max_length = 0
column_letter = column[0].column_letter
for cell in column:
try:
if len(str(cell.value)) > max_length:
max_length = len(str(cell.value))
except Exception:
pass
adjusted_width = min(max_length + 2, 50)
ws_sheet.column_dimensions[column_letter].width = adjusted_width
buffer = BytesIO()
wb.save(buffer)
buffer.seek(0)
return buffer.read()
def get_report_summary(db: Session, project_id: int) -> dict:
"""Get aggregated report data for PDF generation (frontend)."""
from sqlalchemy import func
project = db.query(ClassificationProject).filter(ClassificationProject.id == project_id).first()
if not project:
raise ValueError("项目不存在")
results = db.query(ClassificationResult).filter(ClassificationResult.project_id == project_id).all()
total = len(results)
auto_count = sum(1 for r in results if r.source == 'auto')
manual_count = sum(1 for r in results if r.source == 'manual')
level_stats = {}
for r in results:
if r.level:
level_stats[r.level.name] = level_stats.get(r.level.name, 0) + 1
high_risk = []
for r in results:
if r.level and r.level.code in ('L4', 'L5'):
high_risk.append({
"column_name": r.column.name if r.column else 'N/A',
"table_name": r.column.table.name if r.column and r.column.table else 'N/A',
"category_name": r.category.name if r.category else 'N/A',
"level_name": r.level.name if r.level else 'N/A',
"source": '自动' if r.source == 'auto' else '人工',
"confidence": r.confidence,
})
return {
"project_name": project.name,
"status": project.status,
"template_version": project.template.version if project.template else 'N/A',
"generated_at": datetime.now().isoformat(),
"total": total,
"auto": auto_count,
"manual": manual_count,
"level_distribution": [
{"name": name, "count": count}
for name, count in sorted(level_stats.items(), key=lambda x: -x[1])
],
"high_risk": high_risk[:100],
}
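A minimal sketch of persisting the Excel report to disk; the module path is assumed from the project layout:
from app.core.database import SessionLocal
from app.services.report_service import generate_excel_report  # path assumed

db = SessionLocal()
try:
    content = generate_excel_report(db, project_id=1)
    with open("classification_report.xlsx", "wb") as f:
        f.write(content)  # bytes of a ready-to-open .xlsx workbook
finally:
    db.close()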
+125
View File
@@ -0,0 +1,125 @@
from typing import List, Optional
from sqlalchemy.orm import Session
from datetime import datetime
from app.models.project import ClassificationProject, ClassificationResult
from app.models.classification import DataLevel
from app.models.metadata import DataSource, Database, DataTable, DataColumn
from app.models.masking import MaskingRule
from app.models.risk import RiskAssessment
def _get_level_weight(level_code: str) -> int:
weights = {"L1": 1, "L2": 2, "L3": 3, "L4": 4, "L5": 5}
return weights.get(level_code, 1)
def calculate_project_risk(db: Session, project_id: int) -> RiskAssessment:
"""Calculate risk score for a project."""
project = db.query(ClassificationProject).filter(ClassificationProject.id == project_id).first()
if not project:
return None
results = db.query(ClassificationResult).filter(
ClassificationResult.project_id == project_id,
ClassificationResult.level_id.isnot(None),
).all()
total_risk = 0.0
total_sensitivity = 0.0
total_exposure = 0.0
total_protection = 0.0
detail_items = []
# Get all active masking rules for quick lookup
rules = db.query(MaskingRule).filter(MaskingRule.is_active == True).all()
rule_level_ids = {r.level_id for r in rules if r.level_id}
rule_cat_ids = {r.category_id for r in rules if r.category_id}
for r in results:
if not r.level:
continue
level_weight = _get_level_weight(r.level.code)
# Exposure: count source connections for the column's table
source_count = 1
if r.column and r.column.table and r.column.table.database:
# Heuristic: use the number of databases under the owning source as a rough exposure proxy
source_count = max(1, len(r.column.table.database.source.databases or []))
exposure_factor = 1 + source_count * 0.2
# Protection: check if masking rule exists for this level/category
has_masking = (r.level_id in rule_level_ids) or (r.category_id in rule_cat_ids)
protection_rate = 0.3 if has_masking else 0.0
item_risk = level_weight * exposure_factor * (1 - protection_rate)
total_risk += item_risk
total_sensitivity += level_weight
total_exposure += exposure_factor
total_protection += protection_rate
detail_items.append({
"column_id": r.column_id,
"column_name": r.column.name if r.column else None,
"level": r.level.code if r.level else None,
"level_weight": level_weight,
"exposure_factor": round(exposure_factor, 2),
"protection_rate": protection_rate,
"item_risk": round(item_risk, 2),
})
# Normalize to 0-100 (heuristic: assume max reasonable raw score is 15 per field)
count = len(detail_items) or 1
max_raw = count * 15
risk_score = min(100, (total_risk / max_raw) * 100) if max_raw > 0 else 0
# Upsert risk assessment
existing = db.query(RiskAssessment).filter(
RiskAssessment.entity_type == "project",
RiskAssessment.entity_id == project_id,
).first()
if existing:
existing.risk_score = round(risk_score, 2)
existing.sensitivity_score = round(total_sensitivity / count, 2)
existing.exposure_score = round(total_exposure / count, 2)
existing.protection_score = round(total_protection / count, 2)
existing.details = {"items": detail_items[:100], "total_items": len(detail_items)}
existing.updated_at = datetime.utcnow()
else:
existing = RiskAssessment(
entity_type="project",
entity_id=project_id,
entity_name=project.name,
risk_score=round(risk_score, 2),
sensitivity_score=round(total_sensitivity / count, 2),
exposure_score=round(total_exposure / count, 2),
protection_score=round(total_protection / count, 2),
details={"items": detail_items[:100], "total_items": len(detail_items)},
)
db.add(existing)
db.commit()
return existing
def calculate_all_projects_risk(db: Session) -> dict:
"""Batch calculate risk for all projects."""
projects = db.query(ClassificationProject).all()
updated = 0
for p in projects:
try:
calculate_project_risk(db, p.id)
updated += 1
except Exception:
pass
return {"updated": updated}
def get_risk_top_n(db: Session, entity_type: str = "project", n: int = 10) -> List[RiskAssessment]:
return (
db.query(RiskAssessment)
.filter(RiskAssessment.entity_type == entity_type)
.order_by(RiskAssessment.risk_score.desc())
.limit(n)
.all()
)
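Worked example of the scoring above, reproduced by hand for ten identical fields:
# An L4 field (weight 4) in a source with one database and a matching masking rule
level_weight, exposure, protection = 4, 1 + 1 * 0.2, 0.3
item_risk = level_weight * exposure * (1 - protection)   # 4 * 1.2 * 0.7 = 3.36
total_risk, count = item_risk * 10, 10
risk_score = min(100, total_risk / (count * 15) * 100)   # 33.6 / 150 * 100 = 22.4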
@@ -0,0 +1,99 @@
import os
import re
import json
from typing import Optional, List
from sqlalchemy.orm import Session
from fastapi import HTTPException, status
from app.models.metadata import UnstructuredFile
from app.core.events import minio_client
from app.core.config import settings
def extract_text_from_file(file_path: str, file_type: str) -> str:
text = ""
ft = file_type.lower()
if ft in ("word", "docx"):
try:
from docx import Document
doc = Document(file_path)
text = "\n".join([p.text for p in doc.paragraphs if p.text])
except Exception as e:
raise ValueError(f"解析Word失败: {e}")
elif ft in ("excel", "xlsx", "xls"):
try:
from openpyxl import load_workbook
wb = load_workbook(file_path, data_only=True)
parts = []
for sheet in wb.worksheets:
for row in sheet.iter_rows(values_only=True):
parts.append(" ".join([str(c) for c in row if c is not None]))
text = "\n".join(parts)
except Exception as e:
raise ValueError(f"解析Excel失败: {e}")
elif ft == "pdf":
try:
import pdfplumber
with pdfplumber.open(file_path) as pdf:
text = "\n".join([page.extract_text() or "" for page in pdf.pages])
except Exception as e:
raise ValueError(f"解析PDF失败: {e}")
elif ft == "txt":
with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
text = f.read()
else:
raise ValueError(f"不支持的文件类型: {ft}")
return text
def scan_text_for_sensitive(text: str) -> List[dict]:
"""Scan extracted text for sensitive patterns using built-in rules."""
matches = []
# ID card
id_pattern = re.compile(r"(?<!\d)\d{17}[\dXx](?!\d)")
for m in id_pattern.finditer(text):
snippet = text[max(0, m.start()-10):min(len(text), m.end()+10)]
matches.append({"rule_name": "身份证号", "category_code": "CUST_PERSONAL", "level_code": "L4", "snippet": snippet, "position": m.start()})
# Phone
phone_pattern = re.compile(r"(?<!\d)1[3-9]\d{9}(?!\d)")
for m in phone_pattern.finditer(text):
snippet = text[max(0, m.start()-10):min(len(text), m.end()+10)]
matches.append({"rule_name": "手机号", "category_code": "CUST_PERSONAL", "level_code": "L4", "snippet": snippet, "position": m.start()})
# Bank card (simple 16-19 digits)
bank_pattern = re.compile(r"(?<!\d)\d{16,19}(?!\d)")
for m in bank_pattern.finditer(text):
snippet = text[max(0, m.start()-10):min(len(text), m.end()+10)]
matches.append({"rule_name": "银行卡号", "category_code": "FIN_PAYMENT", "level_code": "L4", "snippet": snippet, "position": m.start()})
# Amount
amount_pattern = re.compile(r"(?<!\d)\d{1,3}(,\d{3})*\.\d{2}(?!\d)")
for m in amount_pattern.finditer(text):
snippet = text[max(0, m.start()-10):min(len(text), m.end()+10)]
matches.append({"rule_name": "金额", "category_code": "FIN_PAYMENT", "level_code": "L3", "snippet": snippet, "position": m.start()})
return matches
def process_unstructured_file(db: Session, file_id: int) -> dict:
file_obj = db.query(UnstructuredFile).filter(UnstructuredFile.id == file_id).first()
if not file_obj:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="文件不存在")
if not file_obj.storage_path:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="文件未上传")
# Download from MinIO to temp
tmp_path = f"/tmp/unstructured_{file_id}_{file_obj.original_name}"
try:
minio_client.fget_object(settings.MINIO_BUCKET_NAME, file_obj.storage_path, tmp_path)
text = extract_text_from_file(tmp_path, file_obj.file_type or "")
file_obj.extracted_text = text[:50000] # limit storage
matches = scan_text_for_sensitive(text)
file_obj.analysis_result = {"matches": matches, "total_chars": len(text)}
file_obj.status = "processed"
db.commit()
return {"success": True, "matches": matches, "total_chars": len(text)}
except Exception as e:
file_obj.status = "error"
db.commit()
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
finally:
if os.path.exists(tmp_path):
os.remove(tmp_path)
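The built-in patterns can be exercised directly; note the bank-card rule matches any bare 16-19 digit run, so overlaps with other rules are possible on real text:
matches = scan_text_for_sensitive("联系人手机号 13812345678,金额 1,234.56 元")
assert [m["rule_name"] for m in matches] == ["手机号", "金额"]
assert matches[0]["level_code"] == "L4" and matches[1]["level_code"] == "L3"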
+97
View File
@@ -0,0 +1,97 @@
import hashlib
import secrets
from typing import Optional, Tuple
from sqlalchemy.orm import Session
from app.models.watermark import WatermarkLog
# Zero-width characters for binary encoding
ZW_SPACE = "\u200b" # zero-width space -> 0
ZW_NOJOIN = "\u200c" # zero-width non-joiner -> 1
MARKER = "\u200d" # zero-width joiner -> start marker
def _int_to_binary_bits(n: int, bits: int = 32) -> str:
return format(n, f"0{bits}b")
def _binary_bits_to_int(bits: str) -> int:
return int(bits, 2)
def embed_watermark(text: str, user_id: int, key: str) -> str:
"""Embed invisible watermark into text using zero-width characters."""
# Encode user_id as 32-bit binary
bits = _int_to_binary_bits(user_id)
# Encode a 16-bit key digest for verification; hashlib keeps it stable
# across processes (built-in hash() is randomized for str)
key_digest = int(hashlib.sha256(key.encode()).hexdigest(), 16) & 0xFFFF
key_bits = _int_to_binary_bits(key_digest, 16)
payload = key_bits + bits
watermark_chars = MARKER + "".join(ZW_NOJOIN if b == "1" else ZW_SPACE for b in payload)
# Append watermark at the end of the text (before trailing newlines if any)
text = text.rstrip("\n")
return text + watermark_chars + "\n"
def extract_watermark(text: str) -> Tuple[Optional[int], Optional[str]]:
"""Extract watermark from text. Returns (user_id, key_hash_bits) or (None, None)."""
if MARKER not in text:
return None, None
idx = text.index(MARKER)
payload = text[idx + len(MARKER):]
bits = ""
for ch in payload:
if ch == ZW_SPACE:
bits += "0"
elif ch == ZW_NOJOIN:
bits += "1"
else:
# Stop at first non-watermark character
break
if len(bits) < 16:
return None, None
key_bits = bits[:16]
user_bits = bits[16:48]
try:
user_id = _binary_bits_to_int(user_bits)
return user_id, key_bits
except Exception:
return None, None
def apply_watermark_to_lines(lines: list, user_id: int, key: str) -> list:
"""Apply watermark to each line of CSV/TXT."""
return [embed_watermark(line, user_id, key) for line in lines]
def create_watermark_log(db: Session, user_id: int, export_type: str, data_scope: dict) -> WatermarkLog:
key = secrets.token_hex(16)
log = WatermarkLog(
user_id=user_id,
export_type=export_type,
data_scope=str(data_scope),
watermark_key=key,
)
db.add(log)
db.commit()
db.refresh(log)
return log
def trace_watermark(db: Session, text: str) -> Optional[dict]:
"""Trace leaked text back to user."""
user_id, _ = extract_watermark(text)
if user_id is None:
return None
log = (
db.query(WatermarkLog)
.filter(WatermarkLog.user_id == user_id)
.order_by(WatermarkLog.created_at.desc())
.first()
)
if not log:
return None
return {
"user_id": log.user_id,
"username": log.user.username if log.user else None,
"export_type": log.export_type,
"data_scope": log.data_scope,
"created_at": log.created_at.isoformat() if log.created_at else None,
}
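Round-trip sketch of the zero-width encoding (pure functions, no DB involved):
marked = embed_watermark("第三季度客户名单", user_id=42, key="export-key")
print(marked)  # renders identically to the original; the payload is zero-width
user_id, key_bits = extract_watermark(marked)
assert user_id == 42 and len(key_bits) == 16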
+39 -9
View File
@@ -1,3 +1,4 @@
import json
from app.tasks.worker import celery_app
@@ -5,12 +6,10 @@ from app.tasks.worker import celery_app
def auto_classify_task(self, project_id: int, source_ids: list = None):
"""
Async task to run automatic classification on metadata.
Phase 1 placeholder.
"""
from app.core.database import SessionLocal
from app.models.project import ClassificationProject, ClassificationResult, ResultStatus
from app.models.classification import RecognitionRule
from app.models.metadata import DataColumn
from app.models.project import ClassificationProject
from app.services.classification_engine import run_auto_classification
db = SessionLocal()
try:
@@ -18,15 +17,46 @@ def auto_classify_task(self, project_id: int, source_ids: list = None):
if not project:
return {"status": "failed", "reason": "project not found"}
# Update project status
def progress_callback(scanned, matched, total):
percent = int(scanned / total * 100) if total else 0
meta = {
"scanned": scanned,
"matched": matched,
"total": total,
"percent": percent,
}
self.update_state(state="PROGRESS", meta=meta)
# Persist lightweight progress to DB for UI polling
project.scan_progress = json.dumps(meta)
db.commit()
# Initialize
project.status = "scanning"
project.scan_progress = json.dumps({"scanned": 0, "matched": 0, "total": 0, "percent": 0})
db.commit()
rules = db.query(RecognitionRule).filter(RecognitionRule.is_active == True).all()
# TODO: implement rule matching logic in Phase 2
result = run_auto_classification(
db,
project_id,
source_ids=source_ids,
progress_callback=progress_callback,
)
project.status = "assigning"
if result.get("success"):
project.status = "assigning"
else:
project.status = "created"
project.celery_task_id = None
db.commit()
return {"status": "completed", "project_id": project_id, "matched": 0}
return {"status": "completed", "project_id": project_id, "result": result}
except Exception as e:
db.rollback()
project = db.query(ClassificationProject).filter(ClassificationProject.id == project_id).first()
if project:
project.status = "created"
project.celery_task_id = None
db.commit()
return {"status": "failed", "reason": str(e)}
finally:
db.close()
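Because the task both updates Celery state and persists the same meta to project.scan_progress, the UI can poll either side; a Celery-side sketch using the standard AsyncResult API:
from celery.result import AsyncResult
from app.tasks.worker import celery_app

def get_scan_progress(task_id: str) -> dict:
    """Poll the PROGRESS meta set via update_state in auto_classify_task."""
    res = AsyncResult(task_id, app=celery_app)
    if res.state == "PROGRESS" and isinstance(res.info, dict):
        return res.info  # {"scanned": ..., "matched": ..., "total": ..., "percent": ...}
    return {"state": res.state}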
+26
View File
@@ -0,0 +1,26 @@
from app.tasks.worker import celery_app
@celery_app.task(bind=True)
def train_ml_model_task(self, model_name: str = None, algorithm: str = "logistic_regression"):
from app.core.database import SessionLocal
from app.services.ml_service import train_model
db = SessionLocal()
try:
self.update_state(state="PROGRESS", meta={"message": "Fetching training data"})
mv = train_model(db, model_name=model_name, algorithm=algorithm)
if mv:
return {
"status": "completed",
"model_id": mv.id,
"name": mv.name,
"accuracy": mv.accuracy,
"train_samples": mv.train_samples,
}
else:
return {"status": "failed", "reason": "Not enough training data (need >= 20 samples)"}
except Exception as e:
return {"status": "failed", "reason": str(e)}
finally:
db.close()
+1 -1
View File
@@ -5,7 +5,7 @@ celery_app = Celery(
"data_pointer",
broker=settings.REDIS_URL,
backend=settings.REDIS_URL,
include=["app.tasks.classification_tasks"],
include=["app.tasks.classification_tasks", "app.tasks.ml_tasks"],
)
celery_app.conf.update(