from collections import Counter
from datetime import datetime
from io import BytesIO
from typing import Optional

from docx import Document
from docx.enum.text import WD_ALIGN_PARAGRAPH
from docx.shared import Inches, Pt, RGBColor
from sqlalchemy.orm import Session

from app.models.project import ClassificationProject, ClassificationResult
from app.models.classification import Category, DataLevel

# Level codes that are listed in the "high sensitivity" sections of every report.
_HIGH_RISK_CODES = ('L4', 'L5')
# Built-in Word table style applied to every table in the Word report.
_DOCX_TABLE_STYLE = 'Light Grid Accent 1'
# Placeholder written when a related object (column, table, category, ...) is missing.
_NOT_AVAILABLE = 'N/A'
# Maximum number of high-risk rows emitted by each output format.
_DOCX_HIGH_RISK_LIMIT = 100
_EXCEL_HIGH_RISK_LIMIT = 500
_SUMMARY_HIGH_RISK_LIMIT = 100
# Upper bound for auto-fitted Excel column widths.
_EXCEL_MAX_COLUMN_WIDTH = 50


def _get_project_or_raise(db: Session, project_id: int):
    """Return the project with the given id, raising ``ValueError`` if absent."""
    project = (
        db.query(ClassificationProject)
        .filter(ClassificationProject.id == project_id)
        .first()
    )
    if not project:
        raise ValueError("项目不存在")
    return project


def _load_results(db: Session, project_id: int) -> list:
    """Return every classification result that belongs to the project."""
    return (
        db.query(ClassificationResult)
        .filter(ClassificationResult.project_id == project_id)
        .all()
    )


def _count_sources(results) -> tuple:
    """Return ``(auto_count, manual_count)`` based on each result's ``source``."""
    auto_count = sum(1 for r in results if r.source == 'auto')
    manual_count = sum(1 for r in results if r.source == 'manual')
    return auto_count, manual_count


def _level_distribution(results) -> list:
    """Return ``(level_name, count)`` pairs ordered by descending count.

    Results without a level are ignored.  ``Counter.most_common`` uses a
    stable sort, so levels with equal counts stay in first-seen order.
    """
    return Counter(r.level.name for r in results if r.level).most_common()


def _format_share(count: int, total: int) -> str:
    """Format ``count / total`` as a percentage with one decimal place."""
    return f'{count / total * 100:.1f}%' if total > 0 else '0%'


def _is_high_risk(result) -> bool:
    """Tell whether the result's level code is one of the high-risk codes."""
    return bool(result.level) and result.level.code in _HIGH_RISK_CODES


def _template_version(project):
    """Return the project's template version, or the N/A placeholder."""
    return project.template.version if project.template else _NOT_AVAILABLE


def _describe_result(result) -> dict:
    """Flatten one result into the display fields shared by all report formats."""
    column = result.column
    return {
        "column_name": column.name if column else _NOT_AVAILABLE,
        "table_name": column.table.name if column and column.table else _NOT_AVAILABLE,
        "category_name": result.category.name if result.category else _NOT_AVAILABLE,
        "level_name": result.level.name if result.level else _NOT_AVAILABLE,
        # Any source other than 'auto' is labelled as manual.
        "source": '自动' if result.source == 'auto' else '人工',
        "confidence": result.confidence,
    }


def _fill_docx_row(cells, values) -> None:
    """Write ``values`` into the given Word table cells, left to right."""
    for cell, value in zip(cells, values):
        cell.text = value


def generate_classification_report(db: Session, project_id: int) -> bytes:
    """Generate a Word report for a classification project.

    The document contains the project's basic info, source statistics, the
    level distribution and a table of at most 100 L4/L5 results.

    Args:
        db: SQLAlchemy session used to load the project and its results.
        project_id: Primary key of the project to report on.

    Returns:
        The serialized ``.docx`` file content.

    Raises:
        ValueError: If no project with ``project_id`` exists.
    """
    project = _get_project_or_raise(db, project_id)

    doc = Document()

    # Title
    title = doc.add_heading('数据分类分级项目报告', 0)
    title.alignment = WD_ALIGN_PARAGRAPH.CENTER

    # Basic info
    doc.add_heading('一、项目基本信息', level=1)
    info_data = [
        ('项目名称', project.name),
        ('报告生成时间', datetime.now().strftime('%Y-%m-%d %H:%M:%S')),
        ('项目状态', project.status),
        ('模板版本', _template_version(project)),
    ]
    # Size the table from the data so adding an entry cannot overflow it.
    info_table = doc.add_table(rows=len(info_data), cols=2)
    info_table.style = _DOCX_TABLE_STYLE
    for table_row, (label, value) in zip(info_table.rows, info_data):
        _fill_docx_row(table_row.cells, (label, str(value)))

    # Statistics
    doc.add_heading('二、分类分级统计', level=1)
    results = _load_results(db, project_id)
    total = len(results)
    auto_count, manual_count = _count_sources(results)

    doc.add_paragraph(f'总字段数: {total}')
    doc.add_paragraph(f'自动识别: {auto_count}')
    doc.add_paragraph(f'人工打标: {manual_count}')

    doc.add_heading('三、分级分布', level=1)
    level_table = doc.add_table(rows=1, cols=3)
    level_table.style = _DOCX_TABLE_STYLE
    _fill_docx_row(level_table.rows[0].cells, ('分级', '数量', '占比'))
    for level_name, count in _level_distribution(results):
        _fill_docx_row(
            level_table.add_row().cells,
            (level_name, str(count), _format_share(count, total)),
        )

    # High risk data
    doc.add_heading('四、高敏感数据清单(L4/L5)', level=1)
    high_risk = [r for r in results if _is_high_risk(r)]
    if high_risk:
        risk_table = doc.add_table(rows=1, cols=5)
        risk_table.style = _DOCX_TABLE_STYLE
        _fill_docx_row(
            risk_table.rows[0].cells,
            ('字段名', '所属表', '分类', '分级', '来源'),
        )
        for result in high_risk[:_DOCX_HIGH_RISK_LIMIT]:
            info = _describe_result(result)
            _fill_docx_row(
                risk_table.add_row().cells,
                (
                    info["column_name"],
                    info["table_name"],
                    info["category_name"],
                    info["level_name"],
                    info["source"],
                ),
            )
    else:
        doc.add_paragraph('暂无L4/L5级高敏感数据。')

    # Save to bytes
    buffer = BytesIO()
    doc.save(buffer)
    return buffer.getvalue()


def generate_excel_report(db: Session, project_id: int) -> bytes:
    """Generate an Excel report for a classification project.

    The workbook has an overview sheet (basic info, source statistics, level
    distribution) and a second sheet listing at most 500 L4/L5 results.

    Args:
        db: SQLAlchemy session used to load the project and its results.
        project_id: Primary key of the project to report on.

    Returns:
        The serialized ``.xlsx`` file content.

    Raises:
        ValueError: If no project with ``project_id`` exists.
    """
    from openpyxl import Workbook
    from openpyxl.styles import Alignment, Font, PatternFill
    from openpyxl.utils import get_column_letter

    project = _get_project_or_raise(db, project_id)

    wb = Workbook()
    ws = wb.active
    ws.title = "报告概览"

    # Title
    ws.merge_cells('A1:D1')
    ws['A1'] = '数据分类分级项目报告'
    ws['A1'].font = Font(size=18, bold=True)
    ws['A1'].alignment = Alignment(horizontal='center')

    # Basic info
    ws['A3'] = '项目名称'
    ws['B3'] = project.name
    ws['A4'] = '报告生成时间'
    ws['B4'] = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    ws['A5'] = '项目状态'
    ws['B5'] = project.status
    ws['A6'] = '模板版本'
    ws['B6'] = _template_version(project)

    # Statistics
    results = _load_results(db, project_id)
    total = len(results)
    auto_count, manual_count = _count_sources(results)

    ws['A8'] = '总字段数'
    ws['B8'] = total
    ws['A9'] = '自动识别'
    ws['B9'] = auto_count
    ws['A10'] = '人工打标'
    ws['B10'] = manual_count

    # Level distribution
    for ref, caption in (('A12', '分级'), ('B12', '数量'), ('C12', '占比')):
        ws[ref] = caption
        ws[ref].font = Font(bold=True)

    # Highlight by level code so the overview agrees with the high-risk sheet;
    # the name-substring check is kept as a fallback for levels named "L4…".
    high_risk_names = {r.level.name for r in results if _is_high_risk(r)}
    red_fill = PatternFill(start_color='FFCCCC', end_color='FFCCCC', fill_type='solid')
    for row_idx, (level_name, count) in enumerate(_level_distribution(results), start=13):
        ws.cell(row=row_idx, column=1, value=level_name)
        ws.cell(row=row_idx, column=2, value=count)
        ws.cell(row=row_idx, column=3, value=_format_share(count, total))
        if level_name in high_risk_names or 'L4' in level_name or 'L5' in level_name:
            for col_idx in range(1, 4):
                ws.cell(row=row_idx, column=col_idx).fill = red_fill

    # High risk sheet
    ws2 = wb.create_sheet("高敏感数据清单")
    ws2.append(['字段名', '所属表', '分类', '分级', '来源', '置信度'])
    header_fill = PatternFill(start_color='DDEBF7', end_color='DDEBF7', fill_type='solid')
    for cell in ws2[1]:
        cell.font = Font(bold=True)
        cell.fill = header_fill

    high_risk = [r for r in results if _is_high_risk(r)]
    for result in high_risk[:_EXCEL_HIGH_RISK_LIMIT]:
        info = _describe_result(result)
        ws2.append([
            info["column_name"],
            info["table_name"],
            info["category_name"],
            info["level_name"],
            info["source"],
            info["confidence"],
        ])

    # Auto-fit column widths roughly (character count, not rendered width).
    for sheet in (ws, ws2):
        for column in sheet.columns:
            lengths = [len(str(cell.value)) for cell in column if cell.value is not None]
            if not lengths:
                # Nothing to measure (e.g. columns only covered by the merged
                # title); leave the default width.
                continue
            # The first cell of a column may be a MergedCell (B1:D1 on the
            # overview sheet), which lacks ``column_letter``; derive the
            # letter from the numeric column index instead.
            letter = get_column_letter(column[0].column)
            sheet.column_dimensions[letter].width = min(
                max(lengths) + 2, _EXCEL_MAX_COLUMN_WIDTH
            )

    buffer = BytesIO()
    wb.save(buffer)
    return buffer.getvalue()


def get_report_summary(db: Session, project_id: int) -> dict:
    """Get aggregated report data for PDF generation (frontend).

    Args:
        db: SQLAlchemy session used to load the project and its results.
        project_id: Primary key of the project to summarize.

    Returns:
        A dict with the project's name, status and template version, the
        generation timestamp (ISO format), total/auto/manual counts, the
        level distribution ordered by descending count, and at most 100
        L4/L5 results.

    Raises:
        ValueError: If no project with ``project_id`` exists.
    """
    project = _get_project_or_raise(db, project_id)
    results = _load_results(db, project_id)
    auto_count, manual_count = _count_sources(results)

    # Truncate before formatting so rows beyond the limit are never expanded.
    high_risk = [r for r in results if _is_high_risk(r)][:_SUMMARY_HIGH_RISK_LIMIT]

    return {
        "project_name": project.name,
        "status": project.status,
        "template_version": _template_version(project),
        "generated_at": datetime.now().isoformat(),
        "total": len(results),
        "auto": auto_count,
        "manual": manual_count,
        "level_distribution": [
            {"name": name, "count": count}
            for name, count in _level_distribution(results)
        ],
        "high_risk": [_describe_result(r) for r in high_risk],
    }