Files
hiderfong 6d70520e79 feat: 全量功能模块开发与集成测试修复
- 新增后端模块:Alert、APIAsset、Compliance、Lineage、Masking、Risk、SchemaChange、Unstructured、Watermark
- 新增前端模块页面与API接口
- 新增Alembic迁移脚本(002-014)覆盖全量业务表
- 新增测试数据生成脚本与集成测试脚本
- 修复metadata模型JSON类型导入缺失导致启动失败的问题
- 修复前端Alert/APIAsset页面request模块路径错误
- 更新docker-compose与开发计划文档
2026-04-25 08:51:38 +08:00

126 lines
4.6 KiB
Python

import logging
from datetime import datetime
from typing import List, Optional

from sqlalchemy.orm import Session

from app.models.classification import DataLevel
from app.models.masking import MaskingRule
from app.models.metadata import DataSource, Database, DataTable, DataColumn
from app.models.project import ClassificationProject, ClassificationResult
from app.models.risk import RiskAssessment


def _get_level_weight(level_code: str) -> int:
weights = {"L1": 1, "L2": 2, "L3": 3, "L4": 4, "L5": 5}
return weights.get(level_code, 1)
def _count_source_databases(result: ClassificationResult) -> int:
    """Return how many databases hang off the data source of *result*'s column.

    Walks column -> table -> database -> source -> databases and falls back
    to 1 as soon as any link in that chain is missing, so an orphaned
    column, table or database never raises ``AttributeError``.
    """
    column = result.column
    table = column.table if column is not None else None
    database = table.database if table is not None else None
    source = database.source if database is not None else None
    if source is None:
        return 1
    return max(1, len(source.databases or []))


def calculate_project_risk(db: Session, project_id: int) -> Optional[RiskAssessment]:
    """Calculate and persist the risk assessment of one classification project.

    For every classification result of the project that has a level:

    * sensitivity = weight of the level code (see ``_get_level_weight``)
    * exposure    = ``1 + 0.2 * <databases under the column's data source>``
    * protection  = 0.3 when an active masking rule targets the result's
      level or category, otherwise 0.0
    * item risk   = ``sensitivity * exposure * (1 - protection)``

    The summed item risk is scaled to 0-100 (capped at 100) under the
    heuristic that 15 is the largest reasonable raw score per field. The
    outcome is upserted into the ``RiskAssessment`` row keyed by
    ``entity_type == "project"`` and ``entity_id == project_id``, then
    committed.

    Args:
        db: Open SQLAlchemy session; this function commits on it.
        project_id: Primary key of the ``ClassificationProject``.

    Returns:
        The inserted or updated ``RiskAssessment``, or ``None`` when no
        project with ``project_id`` exists.

    Raises:
        Exception: Whatever ``db.commit()`` raises is re-raised after the
            session has been rolled back.
    """
    project = (
        db.query(ClassificationProject)
        .filter(ClassificationProject.id == project_id)
        .first()
    )
    if project is None:
        return None

    results = (
        db.query(ClassificationResult)
        .filter(
            ClassificationResult.project_id == project_id,
            ClassificationResult.level_id.isnot(None),
        )
        .all()
    )

    # Active masking rules, reduced to id sets for O(1) membership tests.
    rules = db.query(MaskingRule).filter(MaskingRule.is_active == True).all()  # noqa: E712
    rule_level_ids = {rule.level_id for rule in rules if rule.level_id}
    rule_cat_ids = {rule.category_id for rule in rules if rule.category_id}

    total_risk = 0.0
    total_sensitivity = 0.0
    total_exposure = 0.0
    total_protection = 0.0
    detail_items = []

    for result in results:
        level = result.level
        if not level:
            # level_id is set but the level row could not be loaded; skip it.
            continue

        level_weight = _get_level_weight(level.code)
        exposure_factor = 1 + _count_source_databases(result) * 0.2
        has_masking = (
            result.level_id in rule_level_ids or result.category_id in rule_cat_ids
        )
        protection_rate = 0.3 if has_masking else 0.0
        item_risk = level_weight * exposure_factor * (1 - protection_rate)

        total_risk += item_risk
        total_sensitivity += level_weight
        total_exposure += exposure_factor
        total_protection += protection_rate
        detail_items.append({
            "column_id": result.column_id,
            "column_name": result.column.name if result.column else None,
            "level": level.code,
            "level_weight": level_weight,
            "exposure_factor": round(exposure_factor, 2),
            "protection_rate": protection_rate,
            "item_risk": round(item_risk, 2),
        })

    # `count` is forced to >= 1 so the averages below never divide by zero;
    # with no scored items every total is 0.0 and all scores come out as 0.
    count = len(detail_items) or 1
    max_raw = count * 15  # heuristic ceiling: 15 raw risk points per field
    risk_score = min(100, (total_risk / max_raw) * 100)

    # Built once so the insert and update paths cannot drift apart.
    scores = {
        "risk_score": round(risk_score, 2),
        "sensitivity_score": round(total_sensitivity / count, 2),
        "exposure_score": round(total_exposure / count, 2),
        "protection_score": round(total_protection / count, 2),
        # Only the first 100 items are stored; total_items keeps the real size.
        "details": {"items": detail_items[:100], "total_items": len(detail_items)},
    }

    assessment = (
        db.query(RiskAssessment)
        .filter(
            RiskAssessment.entity_type == "project",
            RiskAssessment.entity_id == project_id,
        )
        .first()
    )
    if assessment is None:
        assessment = RiskAssessment(
            entity_type="project",
            entity_id=project_id,
            entity_name=project.name,
            **scores,
        )
        db.add(assessment)
    else:
        for field, value in scores.items():
            setattr(assessment, field, value)
        # Keep the stored name in step with a project renamed since last run.
        assessment.entity_name = project.name
        assessment.updated_at = datetime.utcnow()

    try:
        db.commit()
    except Exception:
        # Leave the session usable for the caller before propagating.
        db.rollback()
        raise
    return assessment
def calculate_all_projects_risk(db: Session) -> dict:
    """Recalculate the risk assessment of every classification project.

    This is a best-effort batch: a failure for one project is logged with
    its traceback, the session is rolled back so the following projects can
    still be processed, and the loop continues.

    Args:
        db: Open SQLAlchemy session shared by all per-project calculations.

    Returns:
        ``{"updated": <count>}`` where ``<count>`` is the number of projects
        whose calculation finished without raising.
    """
    logger = logging.getLogger(__name__)

    # Fetch plain ids up front: each per-project call commits (or is rolled
    # back here), which expires loaded ORM instances under SQLAlchemy's
    # default session settings, so iterating ids avoids touching them again.
    project_ids = [
        project_id for (project_id,) in db.query(ClassificationProject.id).all()
    ]

    updated = 0
    for project_id in project_ids:
        try:
            calculate_project_risk(db, project_id)
        except Exception:
            # Without a rollback a failed flush/commit would leave the session
            # in an error state and every remaining project would fail too.
            db.rollback()
            logger.exception("Risk calculation failed for project %s", project_id)
        else:
            updated += 1
    return {"updated": updated}
def get_risk_top_n(
    db: Session,
    entity_type: str = "project",
    n: int = 10,
) -> List[RiskAssessment]:
    """Return the *n* highest-scoring risk assessments of one entity type.

    Args:
        db: Open SQLAlchemy session.
        entity_type: Value matched against ``RiskAssessment.entity_type``.
        n: Maximum number of rows to return.

    Returns:
        Up to ``n`` assessments ordered by ``risk_score`` descending, with
        ties broken by ascending ``entity_id`` so repeated calls return the
        same order. An empty list when ``n`` is zero or negative.
    """
    if n <= 0:
        # A non-positive LIMIT is backend-dependent (an error on some
        # databases, "no limit" on others), so answer without querying.
        return []
    return (
        db.query(RiskAssessment)
        .filter(RiskAssessment.entity_type == entity_type)
        .order_by(RiskAssessment.risk_score.desc(), RiskAssessment.entity_id.asc())
        .limit(n)
        .all()
    )