from typing import Optional

from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from sqlalchemy import func

from app.core.database import get_db
from app.models.user import User
from app.models.metadata import DataSource, DataTable, DataColumn
from app.models.project import ClassificationResult, ClassificationProject
from app.models.classification import Category, DataLevel
from app.schemas.common import ResponseModel
from app.api.deps import get_current_user

router = APIRouter()

# Level codes counted as "sensitive" in the overview statistics.
_SENSITIVE_LEVEL_CODES = ("L4", "L5")
# Maximum number of categories returned for the category chart.
_TOP_CATEGORY_LIMIT = 8
# Maximum number of data sources (lowest ids first) shown in the heatmap.
_HEATMAP_SOURCE_LIMIT = 8


@router.get("/stats")
def get_dashboard_stats(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return overview counters for the dashboard.

    The payload contains row counts of data sources, tables, columns,
    classification results ("labeled"), classification results whose level
    code is one of ``_SENSITIVE_LEVEL_CODES`` ("sensitive"), and projects.

    ``current_user`` is not read; it is declared as a dependency so that
    ``get_current_user`` runs for every request to this route.
    """
    sensitive = (
        db.query(ClassificationResult)
        .join(DataLevel)
        .filter(DataLevel.code.in_(_SENSITIVE_LEVEL_CODES))
        .count()
    )
    return ResponseModel(data={
        "data_sources": db.query(DataSource).count(),
        "tables": db.query(DataTable).count(),
        "columns": db.query(DataColumn).count(),
        "labeled": db.query(ClassificationResult).count(),
        "sensitive": sensitive,
        "projects": db.query(ClassificationProject).count(),
    })


def _project_progress(db: Session) -> list:
    """Build one progress entry per classification project.

    Progress is the rounded percentage of a project's classification results
    whose status is ``'reviewed'``; a project without results reports 0.
    """
    # Two grouped queries replace the former two COUNT queries per project.
    totals = {
        project_id: count
        for project_id, count in (
            db.query(
                ClassificationResult.project_id,
                func.count(ClassificationResult.id),
            )
            .group_by(ClassificationResult.project_id)
            .all()
        )
    }
    reviewed_counts = {
        project_id: count
        for project_id, count in (
            db.query(
                ClassificationResult.project_id,
                func.count(ClassificationResult.id),
            )
            .filter(ClassificationResult.status == 'reviewed')
            .group_by(ClassificationResult.project_id)
            .all()
        )
    }

    progress = []
    for project in db.query(ClassificationProject).all():
        total = totals.get(project.id, 0)
        reviewed = reviewed_counts.get(project.id, 0)
        progress.append({
            "id": project.id,
            "name": project.name,
            "status": project.status,
            "progress": round(reviewed / total * 100) if total else 0,
            "planned_end": (
                project.planned_end.isoformat() if project.planned_end else None
            ),
        })
    return progress


def _source_level_heatmap(db: Session) -> list:
    """Build the source-by-level heatmap cells.

    One cell is emitted for every (source, level) pair, sources ordered by id
    (first ``_HEATMAP_SOURCE_LIMIT`` only) and levels ordered by
    ``sort_order``. Pairs without any classification result get a count of 0.
    """
    sources = (
        db.query(DataSource)
        .order_by(DataSource.id)
        .limit(_HEATMAP_SOURCE_LIMIT)
        .all()
    )
    levels = db.query(DataLevel).order_by(DataLevel.sort_order).all()
    if not sources or not levels:
        # No cells to emit; also avoids issuing an IN () query.
        return []

    # A single grouped query replaces one join query per (source, level) cell.
    rows = (
        db.query(
            DataSource.id,
            ClassificationResult.level_id,
            func.count(ClassificationResult.id),
        )
        # Pin the left side of the join chain now that DataSource.id is also
        # part of the selected columns.
        .select_from(ClassificationResult)
        .join(DataColumn, ClassificationResult.column_id == DataColumn.id)
        .join(DataTable, DataColumn.table_id == DataTable.id)
        .join(DataSource, DataTable.database_id == DataSource.id)
        .filter(DataSource.id.in_([source.id for source in sources]))
        .group_by(DataSource.id, ClassificationResult.level_id)
        .all()
    )
    cell_counts = {
        (source_id, level_id): count for source_id, level_id, count in rows
    }

    return [
        {
            "source_name": source.name,
            "level_code": level.code,
            "count": cell_counts.get((source.id, level.id), 0),
        }
        for source in sources
        for level in levels
    ]


@router.get("/distribution")
def get_dashboard_distribution(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return the aggregated data sets behind the dashboard charts.

    The payload contains:

    * ``level_distribution`` - result count per data level, ordered by the
      level's ``sort_order`` (levels without results are omitted by the
      inner join).
    * ``category_distribution`` - the ``_TOP_CATEGORY_LIMIT`` categories with
      the most results, highest count first.
    * ``source_distribution`` - result count per
      ``ClassificationResult.source`` value.
    * ``project_progress`` - reviewed percentage per project.
    * ``heatmap`` - result count per (data source, level) pair.

    ``current_user`` is not read; it is declared as a dependency so that
    ``get_current_user`` runs for every request to this route.
    """
    result_count = func.count(ClassificationResult.id)

    # Level distribution
    level_dist = (
        db.query(DataLevel.name, DataLevel.code, DataLevel.color, result_count)
        .join(ClassificationResult, DataLevel.id == ClassificationResult.level_id)
        .group_by(DataLevel.id)
        .order_by(DataLevel.sort_order)
        .all()
    )

    # Category distribution
    category_dist = (
        db.query(Category.name, result_count)
        .join(ClassificationResult, Category.id == ClassificationResult.category_id)
        .group_by(Category.id)
        .order_by(result_count.desc())
        .limit(_TOP_CATEGORY_LIMIT)
        .all()
    )

    # Source distribution
    source_dist = (
        db.query(ClassificationResult.source, result_count)
        .group_by(ClassificationResult.source)
        .all()
    )

    return ResponseModel(data={
        "level_distribution": [
            {"name": name, "code": code, "color": color, "count": count}
            for name, code, color, count in level_dist
        ],
        "category_distribution": [
            {"name": name, "count": count} for name, count in category_dist
        ],
        "source_distribution": [
            {"source": src, "count": count} for src, count in source_dist
        ],
        "project_progress": _project_progress(db),
        "heatmap": _source_level_heatmap(db),
    })