import logging
from datetime import datetime
from typing import List, Optional

from sqlalchemy.orm import Session

from app.models.project import ClassificationProject, ClassificationResult
from app.models.classification import DataLevel
from app.models.metadata import DataSource, Database, DataTable, DataColumn
from app.models.masking import MaskingRule
from app.models.risk import RiskAssessment

logger = logging.getLogger(__name__)

# Sensitivity weight per data-level code; unknown codes fall back to 1.
_LEVEL_WEIGHTS = {"L1": 1, "L2": 2, "L3": 3, "L4": 4, "L5": 5}

# Heuristic ceiling for a single field's raw risk, used to scale to 0-100.
_MAX_RAW_RISK_PER_FIELD = 15

# Fraction of risk removed when an active masking rule covers the field.
_MASKED_PROTECTION_RATE = 0.3

# Only this many per-field entries are persisted in the assessment details.
_MAX_DETAIL_ITEMS = 100


def _get_level_weight(level_code: str) -> int:
    """Return the sensitivity weight for *level_code* (1 when unknown)."""
    return _LEVEL_WEIGHTS.get(level_code, 1)


def _count_source_databases(result) -> int:
    """Return how many databases the result's data source exposes (min 1).

    The count is taken from ``column.table.database.source.databases``.
    Any missing link in that chain (including a database without a
    source) yields the default of 1 instead of raising.
    """
    column = result.column
    if not (column and column.table and column.table.database):
        return 1
    source = column.table.database.source
    if source is None:
        return 1
    return max(1, len(source.databases or []))


def calculate_project_risk(db: Session, project_id: int) -> Optional[RiskAssessment]:
    """Calculate and persist the risk assessment of one project.

    Every classification result of the project that has a level assigned
    contributes ``level_weight * exposure_factor * (1 - protection_rate)``
    to the raw risk. The raw total is scaled to 0-100 and stored (insert
    or update) as the ``RiskAssessment`` row with ``entity_type="project"``.

    Args:
        db: Session used for all queries; it is committed before returning.
        project_id: Primary key of the project to score.

    Returns:
        The inserted or updated assessment, or ``None`` when no project
        with ``project_id`` exists.
    """
    project = (
        db.query(ClassificationProject)
        .filter(ClassificationProject.id == project_id)
        .first()
    )
    if not project:
        return None

    results = (
        db.query(ClassificationResult)
        .filter(
            ClassificationResult.project_id == project_id,
            ClassificationResult.level_id.isnot(None),
        )
        .all()
    )

    # Load the active masking rules once so the per-result check is a set lookup.
    # `== True` builds a SQL expression here; it is not a Python comparison.
    rules = db.query(MaskingRule).filter(MaskingRule.is_active == True).all()  # noqa: E712
    rule_level_ids = {rule.level_id for rule in rules if rule.level_id}
    rule_cat_ids = {rule.category_id for rule in rules if rule.category_id}

    total_risk = 0.0
    total_sensitivity = 0.0
    total_exposure = 0.0
    total_protection = 0.0
    detail_items = []

    for result in results:
        level = result.level
        if not level:
            continue
        level_weight = _get_level_weight(level.code)

        # Exposure grows by 0.2 for each database on the owning data source.
        exposure_factor = 1 + _count_source_databases(result) * 0.2

        # Protection: an active rule on the level or on the category counts.
        has_masking = (
            result.level_id in rule_level_ids
            or result.category_id in rule_cat_ids
        )
        protection_rate = _MASKED_PROTECTION_RATE if has_masking else 0.0

        item_risk = level_weight * exposure_factor * (1 - protection_rate)
        total_risk += item_risk
        total_sensitivity += level_weight
        total_exposure += exposure_factor
        total_protection += protection_rate

        detail_items.append({
            "column_id": result.column_id,
            "column_name": result.column.name if result.column else None,
            "level": level.code,
            "level_weight": level_weight,
            "exposure_factor": round(exposure_factor, 2),
            "protection_rate": protection_rate,
            "item_risk": round(item_risk, 2),
        })

    # With no scored fields, divide by 1 so every average comes out as 0.
    count = len(detail_items) or 1
    max_raw = count * _MAX_RAW_RISK_PER_FIELD
    risk_score = min(100, (total_risk / max_raw) * 100)

    scores = {
        "risk_score": round(risk_score, 2),
        "sensitivity_score": round(total_sensitivity / count, 2),
        "exposure_score": round(total_exposure / count, 2),
        "protection_score": round(total_protection / count, 2),
        "details": {
            "items": detail_items[:_MAX_DETAIL_ITEMS],
            "total_items": len(detail_items),
        },
    }

    # Upsert: reuse the project's existing assessment row when there is one.
    assessment = (
        db.query(RiskAssessment)
        .filter(
            RiskAssessment.entity_type == "project",
            RiskAssessment.entity_id == project_id,
        )
        .first()
    )
    if assessment:
        for field, value in scores.items():
            setattr(assessment, field, value)
        assessment.updated_at = datetime.utcnow()
    else:
        assessment = RiskAssessment(
            entity_type="project",
            entity_id=project_id,
            entity_name=project.name,
            **scores,
        )
        db.add(assessment)

    db.commit()
    return assessment


def calculate_all_projects_risk(db: Session) -> dict:
    """Recalculate the risk assessment of every project, best effort.

    A failure on one project is logged and the session is rolled back so
    the remaining projects can still be processed.

    Returns:
        ``{"updated": n}`` where ``n`` is the number of projects whose
        calculation finished without raising.
    """
    # Fetch plain ids up front: ORM instances are expired by each commit or
    # rollback, and reading their attributes afterwards would hit the
    # database again (or fail on a broken session).
    project_ids = [pid for (pid,) in db.query(ClassificationProject.id).all()]

    updated = 0
    for project_id in project_ids:
        try:
            calculate_project_risk(db, project_id)
        except Exception:
            logger.exception(
                "Risk calculation failed for project %s", project_id
            )
            # A failed flush/commit leaves the session unusable until it
            # is rolled back.
            db.rollback()
        else:
            updated += 1
    return {"updated": updated}


def get_risk_top_n(db: Session, entity_type: str = "project", n: int = 10) -> List[RiskAssessment]:
    """Return up to *n* assessments of *entity_type*, highest risk first."""
    return (
        db.query(RiskAssessment)
        .filter(RiskAssessment.entity_type == entity_type)
        .order_by(RiskAssessment.risk_score.desc())
        .limit(n)
        .all()
    )