from datetime import datetime, timezone
from typing import List, Optional, Set, Tuple

from sqlalchemy.orm import Session, joinedload

from app.models.compliance import ComplianceRule, ComplianceIssue
from app.models.project import ClassificationResult
from app.models.masking import MaskingRule

# Level codes targeted by the masking ("check_masking") and
# authorization ("check_level") checks.
_SENSITIVE_LEVEL_CODES = ("L4", "L5")

# Display name used when a classification result has no related column.
_UNKNOWN_COLUMN_NAME = "未知"


def init_builtin_rules(db: Session) -> None:
    """Seed the built-in compliance rules.

    Does nothing if at least one ``ComplianceRule`` row already exists;
    otherwise inserts the four built-in rules and commits.
    """
    if db.query(ComplianceRule).first():
        return

    db.add_all([
        ComplianceRule(
            name="L4/L5字段未配置脱敏",
            standard="dengbao",
            description="等保2.0要求:四级及以上数据应进行脱敏处理",
            check_logic="check_masking",
            severity="high",
        ),
        ComplianceRule(
            name="L5字段缺乏加密存储措施",
            standard="dengbao",
            description="等保2.0要求:五级数据应加密存储",
            check_logic="check_encryption",
            severity="critical",
        ),
        ComplianceRule(
            name="个人敏感信息处理未授权",
            standard="pipl",
            description="个人信息保护法:处理敏感个人信息应取得单独同意",
            check_logic="check_level",
            severity="high",
        ),
        ComplianceRule(
            name="数据跨境传输未评估",
            standard="gdpr",
            description="GDPR:个人数据跨境传输需进行影响评估",
            check_logic="check_audit",
            severity="medium",
        ),
    ])
    db.commit()


def _evaluate_rule(
    check_logic: str,
    result: ClassificationResult,
    level_code: str,
    column_name: Optional[str],
    masking_level_ids: Set[int],
) -> Optional[Tuple[str, str]]:
    """Return ``(description, suggestion)`` if *result* violates the rule, else ``None``."""
    if check_logic == "check_masking":
        if (
            level_code in _SENSITIVE_LEVEL_CODES
            and result.level_id not in masking_level_ids
        ):
            return (
                f"字段 '{column_name}' 为 {level_code} 级,但未配置脱敏规则",
                "请在【数据脱敏】模块为该分级配置脱敏策略",
            )
    elif check_logic == "check_encryption":
        # Placeholder: no real encryption check in the MVP, so every L5
        # field is flagged for manual confirmation.
        if level_code == "L5":
            return (
                f"字段 '{column_name}' 为 L5 级核心数据,建议确认是否加密存储",
                "请确认该字段在数据库中已加密存储",
            )
    elif check_logic == "check_level":
        if level_code in _SENSITIVE_LEVEL_CODES and result.source == "auto":
            return (
                f"个人敏感字段 '{column_name}' 目前为自动识别,建议人工复核并确认授权",
                "请人工确认该字段的处理已取得合法授权",
            )
    # "check_audit" (cross-border transfer) is a placeholder and never
    # matches; unknown check_logic values are ignored as well.
    return None


def scan_compliance(db: Session, project_id: Optional[int] = None) -> List[ComplianceIssue]:
    """Run the compliance scan and persist newly detected issues.

    Every classification result that has a level is evaluated against each
    active rule. A new ``ComplianceIssue`` is created for a violation
    unless an open issue with the same (rule, project, entity type,
    entity id) already exists or was already created during this scan.

    Args:
        db: Active database session; committed if any issue is created.
        project_id: Restrict the scan to this project. A falsy value
            (``None`` or ``0``) scans all projects.

    Returns:
        The issues created by this scan (empty if nothing new was found).
    """
    rules = db.query(ComplianceRule).filter(ComplianceRule.is_active == True).all()  # noqa: E712
    if not rules:
        return []

    # Levels that have at least one active masking rule configured.
    active_masking_rules = (
        db.query(MaskingRule).filter(MaskingRule.is_active == True).all()  # noqa: E712
    )
    masking_level_ids = {m.level_id for m in active_masking_rules if m.level_id}

    result_filter = [ClassificationResult.level_id.isnot(None)]
    if project_id:
        result_filter.append(ClassificationResult.project_id == project_id)
    else:
        # Scanning all projects: only results attached to a project count.
        result_filter.append(ClassificationResult.project_id.isnot(None))

    # Eager-load level and column so the loop below issues no extra queries.
    results = (
        db.query(ClassificationResult)
        .options(
            joinedload(ClassificationResult.level),
            joinedload(ClassificationResult.column),
        )
        .filter(*result_filter)
        .all()
    )
    if not results:
        return []

    # Only projects that actually appear in the results can yield duplicates.
    project_ids = list({res.project_id for res in results})
    open_issues = (
        db.query(ComplianceIssue)
        .filter(
            ComplianceIssue.project_id.in_(project_ids),
            ComplianceIssue.status == "open",
        )
        .all()
    )
    existing_keys: Set[Tuple[int, int, str, int]] = {
        (i.rule_id, i.project_id, i.entity_type, i.entity_id) for i in open_issues
    }

    new_issues: List[ComplianceIssue] = []
    for res in results:
        if not res.level:
            continue
        level_code = res.level.code
        column_name = res.column.name if res.column else _UNKNOWN_COLUMN_NAME
        entity_id = res.column_id or 0

        for rule in rules:
            finding = _evaluate_rule(
                rule.check_logic, res, level_code, column_name, masking_level_ids
            )
            if finding is None:
                continue

            key = (rule.id, res.project_id, "column", entity_id)
            if key in existing_keys:
                continue
            # Record the key so repeated results for the same column do not
            # produce duplicate issues within one scan.
            existing_keys.add(key)

            description, suggestion = finding
            issue = ComplianceIssue(
                rule_id=rule.id,
                project_id=res.project_id,
                entity_type="column",
                entity_id=entity_id,
                entity_name=column_name,
                severity=rule.severity,
                description=description,
                suggestion=suggestion,
            )
            db.add(issue)
            new_issues.append(issue)

    if new_issues:
        db.commit()
    return new_issues


def list_issues(
    db: Session,
    project_id: Optional[int] = None,
    status: Optional[str] = None,
    page: int = 1,
    page_size: int = 20,
) -> Tuple[List[ComplianceIssue], int]:
    """Return one page of compliance issues, newest first, plus the total count.

    Args:
        db: Active database session.
        project_id: Only include issues of this project; a falsy value
            disables the filter.
        status: Only include issues with this status; a falsy value
            disables the filter.
        page: 1-based page number; values below 1 are treated as 1.
        page_size: Maximum number of items per page; negative values are
            treated as 0.

    Returns:
        ``(items, total)`` where ``total`` counts all rows matching the
        filters, ignoring pagination.
    """
    # Clamp so OFFSET/LIMIT are never negative.
    page = max(page, 1)
    page_size = max(page_size, 0)

    query = db.query(ComplianceIssue)
    if project_id:
        query = query.filter(ComplianceIssue.project_id == project_id)
    if status:
        query = query.filter(ComplianceIssue.status == status)

    total = query.count()
    items = (
        query
        # The id tie-breaker keeps paging stable when several issues share
        # the same created_at value.
        .order_by(ComplianceIssue.created_at.desc(), ComplianceIssue.id.desc())
        .offset((page - 1) * page_size)
        .limit(page_size)
        .all()
    )
    return items, total


def resolve_issue(db: Session, issue_id: int) -> Optional[ComplianceIssue]:
    """Mark the issue with the given id as resolved.

    Sets ``status`` to ``"resolved"`` and ``resolved_at`` to the current
    UTC time, then commits.

    Returns:
        The updated issue, or ``None`` if no issue has that id.
    """
    issue = db.query(ComplianceIssue).filter(ComplianceIssue.id == issue_id).first()
    if issue:
        issue.status = "resolved"
        # Naive UTC timestamp, the same value datetime.utcnow() produced,
        # without calling the deprecated API.
        issue.resolved_at = datetime.now(timezone.utc).replace(tzinfo=None)
        db.commit()
    return issue