Compare commits

...

10 Commits

Author SHA1 Message Date
hiderfong 34466a1ae9 fix: optimize compliance scan performance and improve error handling
- Refactor scan_compliance to eliminate N+1 queries using joinedload and batch loading
- Add try-except wrapper in compliance scan API endpoint
- Improve frontend axios error interceptor to display detail/message/timeout errors
- Update CORS config and nginx for domain deployment
2026-04-25 21:27:22 +08:00
hiderfong 6d35cfa5b7 chore: nginx config supports domain access via datapointer.cnroc.cn
Update server_name to "datapointer.cnroc.cn localhost _"
so that Nginx responds correctly to requests for this domain.
2026-04-25 11:13:36 +08:00
hiderfong 9590603621 fix: add CORS headers in nginx to resolve the frontend blank-page issue
The script/link tags in the frontend build output carry the crossorigin attribute,
so browsers perform a CORS check when loading these assets.
Add the Access-Control-Allow-Origin response header to support cross-origin asset loading.
2026-04-25 11:04:27 +08:00
hiderfong e7c7f92b69 chore: add production deployment configuration
- Add the requests dependency to requirements.txt
- Add docker-compose.prod.yml production orchestration file
- Add frontend/Dockerfile.prod production frontend image
- Add frontend/nginx.conf reverse proxy configuration
2026-04-25 10:45:25 +08:00
hiderfong 474e7aa543 docs: add DataPointer product white paper and functional architecture diagrams
- Write a complete feature white paper (product overview, feature details, technical architecture, deployment plan)
- Draw the overall functional architecture, data flow, security loop, deployment architecture, and core business process diagrams
2026-04-25 09:34:41 +08:00
hiderfong ddb8cb8471 security: change admin password and remove default account hint from frontend
- Change the default admin password from admin123 to Zhidi@n2023
- Update the admin user's password hash in the database
- Update the password in backend config, the environment variable template, and test scripts
- Remove the default admin credentials hint from the login page
- Clear the login form's default password value to avoid leakage
- Rebuild the frontend dist output
2026-04-25 09:05:08 +08:00
hiderfong 6d70520e79 feat: develop all feature modules and fix integration tests
- Add backend modules: Alert, APIAsset, Compliance, Lineage, Masking, Risk, SchemaChange, Unstructured, Watermark
- Add frontend module pages and API clients
- Add Alembic migration scripts (002-014) covering all business tables
- Add test data generation and integration test scripts
- Fix a startup failure caused by a missing JSON type import in the metadata models
- Fix the wrong request module path in the frontend Alert/APIAsset pages
- Update docker-compose and the development plan document
2026-04-25 08:51:38 +08:00
hiderfong 8b2bc84399 feat: implement full RBAC role-based access control
Backend:
- deps.py: add require_admin, require_manager, require_labeler, require_guest_or_above
- user.py: all write endpoints require admin
- datasource.py: write/sync endpoints require admin
- metadata.py: sync endpoint requires admin
- classification.py: category/rule write requires admin; results query requires guest+ with data isolation
- project.py: GET requires manager with created_by filtering; DELETE checks ownership
- task.py: my-tasks requires labeler with assignee_id filtering; create-task requires manager
- dashboard.py: requires guest_or_above
- report.py: requires guest_or_above
- project_service: list_projects adds created_by filter; list_results adds project_ids filter

Frontend:
- stores/user.ts: add hasRole, hasAnyRole, isAdmin, isManager, isLabeler, isSuperadmin
- router/index.ts: add roles to route meta; beforeEach checks role permissions
- Layout.vue: filter menu routes by user roles
- System.vue: hide add/edit/delete buttons for non-admins
- DataSource.vue: hide add/edit/delete/sync buttons for non-admins
- Project.vue: hide add/delete buttons for non-admins
2026-04-23 12:09:32 +08:00
hiderfong 377e9cba22 feat: user management - add/edit/delete users
Backend:
- Add GET /users/roles endpoint (list all roles)
- Add GET /users/depts endpoint (list all departments)

Frontend:
- Add user.ts API client (getUsers, createUser, updateUser, deleteUser, getRoles, getDepts)
- Rewrite System.vue with full user CRUD:
  - User list table with pagination
  - Add User button + dialog form
  - Edit user (username disabled, password hidden)
  - Delete user (disabled for superadmin)
  - Role and department dropdowns
  - Status radio buttons
2026-04-23 11:32:45 +08:00
hiderfong 9d38180745 rebrand: PropDataGuard → DataPointer
- App name: PropDataGuard → DataPointer
- Frontend title: 财险数据分级分类平台 → 数据分类分级管理平台 (P&C insurance data grading platform → "Data Classification and Grading Management Platform")
- LocalStorage keys: pdg_token/pdg_refresh → dp_token/dp_refresh
- Package name: prop-data-guard-frontend → data-pointer-frontend
- Project config: admin@propdataguard.com → admin@datapo.com
- Celery app name: prop_data_guard → data_pointer
- Layout logo, login title, page title all updated
2026-04-23 11:26:28 +08:00
136 changed files with 8198 additions and 193 deletions
+162
@@ -0,0 +1,162 @@
# DataPointer Development Plan
> Version: v1.0 | Date: 2026-04-23
---
## 1. Current State Assessment
| Module | Current State | Key Debt / Gaps |
|------|---------|-------------|
| Data source management | MySQL/PG/Oracle/SQLServer + Dameng (mock) | Dameng not genuinely supported; the password encryption key is generated randomly at runtime, so stored passwords cannot be decrypted after a restart |
| Metadata collection | Basic database/table/column collection | Full scans only; no incremental collection or schema change tracking |
| Classification engine | Regex/keyword/enum rules | scikit-learn is a dependency but unused; Celery tasks are placeholders |
| Project management | Create/assign/label/publish | No ML-assisted recommendations |
| Reports | Word export | No Excel/PDF; no risk summary |
| Security capabilities | None | No masking, no watermarking |
| Risk management | None | No quantitative scoring, no compliance benchmarking |
| Unstructured data | Model exists (UnstructuredFile) | Feature not implemented |
---
## 2. Overall Milestones
```
Phase 1 (core engine hardening + intelligence)    4 weeks
Phase 2 (security capabilities + UX upgrades)     5 weeks
Phase 3 (risk management + compliance + lineage)  6 weeks
------------------------------------------------------------------------
Total                                             15 weeks (~3.5 months)
```
---
## 3. Phase 1: Core Engine Hardening and Intelligence (4 weeks)
### T1.1 Fix data source password encryption (P0)
- **Approach**: Add DB_ENCRYPTION_KEY to config.py, read from an environment variable; switch datasource_service.py to this key; ship an Alembic migration to handle historical data (sketch below).
- **Acceptance**: Historical data sources still connect after a restart; the key is injected externally.
- **Effort**: 2d
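A minimal sketch of the intended key handling, assuming the cryptography package; the SECRET_KEY fallback follows the note in migration 002, and `get_fernet` is an illustrative name, not the existing function:

```python
import base64
import hashlib
import os

from cryptography.fernet import Fernet


def get_fernet() -> Fernet:
    """Build Fernet from a stable, externally injected key (illustrative helper)."""
    key = os.environ.get("DB_ENCRYPTION_KEY")
    if not key:
        # Assumed fallback: derive a deterministic key from SECRET_KEY so a
        # restart no longer invalidates stored passwords.
        digest = hashlib.sha256(os.environ["SECRET_KEY"].encode()).digest()
        key = base64.urlsafe_b64encode(digest).decode()
    return Fernet(key)


token = get_fernet().encrypt(b"db-password")
assert get_fernet().decrypt(token) == b"db-password"  # survives restarts
```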
### T1.2 Land Celery async classification (P0)
- **Approach**: Move the run_auto_classification logic into a Celery task; add a scan_progress field to Project; expose a progress polling endpoint; add a progress bar and a background-execution toggle to the frontend (sketch below).
- **Acceptance**: Classifying tens of thousands of fields no longer blocks HTTP; the frontend shows progress in real time.
- **Effort**: 4d
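A sketch of the task shape, assuming a configured Celery app on the Redis broker from the environment template; `load_project_columns` and `classify_column` are hypothetical stand-ins for the existing engine:

```python
from celery import Celery

app = Celery("data_pointer", broker="redis://localhost:6379/0")


@app.task(bind=True)
def auto_classify_task(self, project_id: int) -> dict:
    columns = load_project_columns(project_id)   # hypothetical loader
    total = len(columns)
    for done, column in enumerate(columns, start=1):
        classify_column(column)                  # hypothetical rule-engine call
        if done % 100 == 0 or done == total:
            # Polled by the progress endpoint via AsyncResult(task_id).info
            self.update_state(state="PROGRESS", meta={"done": done, "total": total})
    return {"success": True, "total": total}
```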
### T1.3 ML-assisted classification prototype (P1)
- **Approach**: Add an MLModelVersion model; apply TF-IDF to column name/comment/sample_data; train LogisticRegression/RandomForest; provide an ml-suggest endpoint with one-click adoption in the frontend; wrap training as a Celery task (sketch below).
- **Acceptance**: Top-1 accuracy >= 70% on the test set; the frontend shows suggested labels with confidence scores.
- **Effort**: 8d
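A training sketch under stated assumptions (scikit-learn available, toy data inline; real input would be labeled columns loaded from the database):

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# Toy corpus: "name comment sample" per column -> category label
texts = ["mobile_no phone 13812341234", "cell_phone contact 13999998888",
         "id_card national id 110101199001011234", "cert_no identity 310104199212120000",
         "order_amt amount 99.50", "pay_total amount 1200.00"]
labels = ["contact", "contact", "identity", "identity", "transaction", "transaction"]

vectorizer = TfidfVectorizer(analyzer="char_wb", ngram_range=(1, 3))
X = vectorizer.fit_transform(texts)
X_train, X_test, y_train, y_test = train_test_split(X, labels, test_size=1/3, random_state=0)

model = LogisticRegression(max_iter=1000).fit(X_train, y_train)
print("Top-1 accuracy:", model.score(X_test, y_test))  # acceptance gate: >= 0.70
```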
### T1.4 Semantic similarity rules (P1)
- **Approach**: Add a similarity rule type; use TfidfVectorizer + cosine_similarity to compare column names/comments against reference terms (sketch below); default threshold 0.75.
- **Acceptance**: mobile_no / cell_phone / contact_tel are all matched by a single rule.
- **Effort**: 3d
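A rough sketch of the matching idea; character n-grams are an assumption here, and the analyzer and 0.75 threshold would need tuning on real field names before the acceptance criterion holds:

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

REFERENCE_TERMS = ["mobile phone number", "phone", "telephone", "mobile_no"]
THRESHOLD = 0.75  # default from the rule definition


def best_similarity(column_name: str) -> float:
    vec = TfidfVectorizer(analyzer="char_wb", ngram_range=(2, 4))
    matrix = vec.fit_transform(REFERENCE_TERMS + [column_name])
    n = len(REFERENCE_TERMS)
    return cosine_similarity(matrix[n], matrix[:n])[0].max()


for name in ("mobile_no", "cell_phone", "contact_tel"):
    score = best_similarity(name)
    print(name, round(score, 2), score >= THRESHOLD)
```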
### T1.5 Incremental metadata collection (P1)
- **Approach**: Add last_scanned_at and checksum columns to the meta tables; during collection, compare against information_schema and sync only what changed (sketch below); soft-delete removed objects to preserve history.
- **Acceptance**: Re-scanning an unchanged table writes nothing; newly added source tables are written incrementally.
- **Effort**: 4d
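A sketch of the checksum comparison, assuming a SHA-256 fingerprint over the column definitions pulled from information_schema:

```python
import hashlib


def table_checksum(columns: list[tuple[str, str, str]]) -> str:
    """Stable fingerprint over sorted (name, type, comment) triples."""
    canonical = "|".join(f"{n}:{t}:{c or ''}" for n, t, c in sorted(columns))
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()


stored = table_checksum([("id", "int", ""), ("phone", "varchar(20)", "mobile")])
fresh = table_checksum([("id", "int", ""), ("phone", "varchar(20)", "mobile")])
if fresh == stored:
    pass  # unchanged: skip writes, only bump last_scanned_at
else:
    pass  # changed: upsert columns, soft-delete ones missing from the source
```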
---
## 4. Phase 2: Security Capabilities and UX Upgrades (5 weeks)
### T2.1 Static data masking (P1)
- **Approach**: Add a MaskingRule model supporting mask/truncate/hash/generalize/replace (sketch below); masking preview and export APIs; frontend policy configuration with a side-by-side preview page.
- **Acceptance**: ID numbers and phone numbers are masked on export per the rules; batch policy application is supported.
- **Effort**: 8d
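Minimal sketches of three of the five algorithms; the parameter names are illustrative, not the MaskingRule schema:

```python
import hashlib


def mask(value: str, keep_head: int = 3, keep_tail: int = 4) -> str:
    """Keep head/tail, star the middle: 13812341234 -> 138****1234."""
    if len(value) <= keep_head + keep_tail:
        return "*" * len(value)
    return value[:keep_head] + "*" * (len(value) - keep_head - keep_tail) + value[-keep_tail:]


def truncate(value: str, length: int = 4) -> str:
    return value[:length]


def hash_value(value: str) -> str:
    """Irreversible token; equal inputs still compare equal."""
    return hashlib.sha256(value.encode("utf-8")).hexdigest()


print(mask("13812341234"))              # 138****1234
print(truncate("110101199001011234"))   # 1101
print(hash_value("13812341234")[:16])   # stable, non-reversible prefix
```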
### T2.2 Data watermarking and tracing (P2)
- **Approach**: Embed the user ID in text exports via zero-width characters (sketch below); record exports in WatermarkLog; a tracing API extracts the watermark and traces it back to a user.
- **Acceptance**: A copied export CSV can be traced back to its user; readability is unaffected.
- **Effort**: 5d
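A sketch of the zero-width embedding, assuming a fixed 32-bit user-ID payload appended to the export text:

```python
ZERO = "\u200b"  # zero-width space      -> bit 0
ONE = "\u200c"   # zero-width non-joiner -> bit 1


def embed_watermark(text: str, user_id: int) -> str:
    bits = format(user_id, "032b")
    return text + "".join(ONE if b == "1" else ZERO for b in bits)


def extract_watermark(text: str) -> int | None:
    bits = "".join("1" if ch == ONE else "0" for ch in text if ch in (ZERO, ONE))
    return int(bits, 2) if len(bits) == 32 else None


marked = embed_watermark("name,phone\nalice,138****1234", user_id=42)
assert marked.startswith("name,phone")    # visually unchanged
assert extract_watermark(marked) == 42    # traced back to the exporting user
```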
### T2.3 Excel + PDF reports (P1)
- **Approach**: Excel via openpyxl with conditional formatting and charts; PDF via frontend html2canvas + jspdf to cut backend dependencies.
- **Acceptance**: Word/Excel/PDF exports are all supported; the PDF includes statistics charts and a Top-20 sensitive-data list.
- **Effort**: 4d
### T2.4 Real Dameng driver (P2)
- **Approach**: Prefer dmPython, falling back to jaydebeapi + JDBC; adapt metadata collection to Dameng system tables.
- **Acceptance**: Can connect to Dameng and collect database/table/column metadata.
- **Effort**: 4d
### T2.5 Unstructured file recognition (P2)
- **Approach**: Activate UnstructuredFile; store files in MinIO; parse text with python-docx/openpyxl/pdfplumber (sketch below); feed the text to the rule engine to detect sensitive data.
- **Acceptance**: ID numbers and phone numbers inside Word/Excel files are detected and suggested as L4.
- **Effort**: 6d
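A dispatch sketch for text extraction, assuming python-docx, openpyxl, and pdfplumber as named in the approach:

```python
from pathlib import Path


def extract_text(path: str) -> str:
    suffix = Path(path).suffix.lower()
    if suffix == ".docx":
        import docx  # python-docx
        return "\n".join(p.text for p in docx.Document(path).paragraphs)
    if suffix in (".xlsx", ".xlsm"):
        import openpyxl
        wb = openpyxl.load_workbook(path, read_only=True, data_only=True)
        cells = []
        for ws in wb.worksheets:
            for row in ws.iter_rows(values_only=True):
                cells.extend(str(v) for v in row if v is not None)
        return "\n".join(cells)
    if suffix == ".pdf":
        import pdfplumber
        with pdfplumber.open(path) as pdf:
            return "\n".join(page.extract_text() or "" for page in pdf.pages)
    raise ValueError(f"unsupported file type: {suffix}")

# The returned text is then fed to the same rule engine used for columns.
```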
### T2.6 Schema change tracking (P2)
- **Approach**: Add a SchemaChangeLog model; during incremental collection, diff automatically and generate change records; show change history on the data source detail page.
- **Acceptance**: When a sensitive column is added in the source database, the platform generates a change record and flags it in red.
- **Effort**: 4d
---
## 5. Phase 3: Risk Management and Compliance (6 weeks)
### T3.1 Risk scoring model (P1)
- **Approach**: RiskScore = sum(L_i * exposure_i * (1 - protection_rate_i)) (worked sketch below); add a RiskAssessment model with four-level aggregation; recompute daily via Celery Beat; add risk trends and rankings to the Dashboard.
- **Acceptance**: Each project gets a 0-100 risk score; the score rises as unmasked sensitive columns are added.
- **Effort**: 7d
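A worked sketch of the formula; the 0-100 scaling against the worst case (fully exposed, unprotected) is an assumption, since the plan only fixes the raw sum:

```python
def risk_score(fields: list[dict]) -> float:
    """RiskScore = sum(L_i * exposure_i * (1 - protection_rate_i)), scaled to 0-100."""
    raw = sum(f["level"] * f["exposure"] * (1 - f["protection_rate"]) for f in fields)
    worst = sum(f["level"] for f in fields) or 1.0  # exposure = 1, protection = 0
    return round(100 * raw / worst, 1)


fields = [
    {"level": 5, "exposure": 0.8, "protection_rate": 0.0},  # unmasked L5 column
    {"level": 2, "exposure": 0.3, "protection_rate": 0.9},
]
print(risk_score(fields))  # 58.0; adding unmasked L5 columns pushes it higher
```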
### T3.2 Compliance check engine (P1)
- **Approach**: Built-in MLPS (等保)/PIPL/GDPR check rules; a pluggable ComplianceChecker base class (sketch below); add ComplianceScan/ComplianceIssue models; frontend rule library and issue list.
- **Acceptance**: Automatically flags non-compliance such as unmasked L5 data; a compliance gap analysis can be exported.
- **Effort**: 7d
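A sketch of the pluggable checker contract; the issue dict mirrors the compliance_issue columns from migration 011, and the in-memory column list stands in for real queries:

```python
from abc import ABC, abstractmethod


class ComplianceChecker(ABC):
    standard: str  # e.g. "MLPS", "PIPL", "GDPR"

    @abstractmethod
    def check(self, columns: list[dict]) -> list[dict]:
        """Return issue dicts for non-compliant entities."""


class UnmaskedL5Checker(ComplianceChecker):
    standard = "MLPS"

    def check(self, columns):
        return [{"entity_type": "column", "entity_name": c["name"], "severity": "high",
                 "description": "L5 column stored without masking",
                 "suggestion": "apply a masking rule before export"}
                for c in columns if c["level"] == 5 and not c["masked"]]


CHECKERS: list[ComplianceChecker] = [UnmaskedL5Checker()]


def scan(columns: list[dict]) -> list[dict]:
    return [issue for checker in CHECKERS for issue in checker.check(columns)]


print(scan([{"name": "id_card", "level": 5, "masked": False}]))
```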
### T3.3 Data lineage analysis (P2)
- **Approach**: Parse SQL with sqlparse (sketch below); add a DataLineage model; render table-level lineage as an ECharts graph on the frontend, with 3-level expansion.
- **Acceptance**: Typical ETL SQL builds a correct lineage chain.
- **Effort**: 8d
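A rough table-level sketch with sqlparse, grabbing identifiers after FROM/JOIN; it handles simple SELECT/INSERT shapes only, not subqueries or CTEs:

```python
import sqlparse
from sqlparse.sql import Identifier, IdentifierList
from sqlparse.tokens import Keyword


def extract_source_tables(sql: str) -> list[str]:
    tables = []
    for stmt in sqlparse.parse(sql):
        grab = False
        for token in stmt.tokens:
            if grab and isinstance(token, Identifier):
                tables.append(token.get_real_name())
                grab = False
            elif grab and isinstance(token, IdentifierList):
                tables += [i.get_real_name() for i in token.get_identifiers()
                           if isinstance(i, Identifier)]
                grab = False
            if token.ttype is Keyword and token.value.upper() in ("FROM", "JOIN", "LEFT JOIN", "INNER JOIN"):
                grab = True
    return tables


print(extract_source_tables(
    "INSERT INTO dm_customer SELECT c.id, o.amt FROM ods_customer c "
    "JOIN ods_order o ON c.id = o.cid"
))  # ['ods_customer', 'ods_order'] -> two lineage edges into dm_customer
```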
### T3.4 Risk alerts and work orders (P1)
- **Approach**: AlertRule model for trigger conditions; AlertRecord for raised alerts; a simple WorkOrder flow (open -> in_progress -> resolved); in-app message center.
- **Acceptance**: Adding 5 new L5 columns automatically raises an alert that can be converted into an assigned work order.
- **Effort**: 6d
### T3.5 API asset scanning (P2)
- **Approach**: Add an ApiAsset model; upload Swagger/OpenAPI specs and parse parameter and response schemas (sketch below); the rule engine flags sensitive endpoints.
- **Acceptance**: After uploading a Swagger spec containing phone/idCard, the endpoint is flagged as exposing L4 data.
- **Effort**: 5d
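A parsing sketch, assuming a Swagger v2-style JSON document fetched with requests (already a dependency) and flagging by parameter name only:

```python
import re

import requests

SENSITIVE = re.compile(r"phone|mobile|id_?card|ssn|passport|bank", re.I)


def scan_swagger(url: str) -> list[dict]:
    spec = requests.get(url, timeout=10).json()
    findings = []
    for path, methods in spec.get("paths", {}).items():
        for method, op in methods.items():
            if not isinstance(op, dict):
                continue
            names = [p.get("name", "") for p in op.get("parameters", [])]
            hits = [n for n in names if SENSITIVE.search(n)]
            if hits:
                findings.append({"method": method.upper(), "path": path,
                                 "sensitive_fields": hits, "risk_level": "high"})
    return findings

# e.g. scan_swagger("https://example.com/swagger.json") would flag
# GET /users?phone=... as exposing L4 contact data.
```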
### T3.6 Dark mode and performance optimization (P2)
- **Approach**: Element Plus dynamic theming + Pinia persistence; virtual scrolling for large tables; lazy-loaded routes.
- **Acceptance**: One-click dark/light toggle without flicker; a 50,000-field page scrolls at >= 30 fps.
- **Effort**: 4d
---
## 6. Task Overview
| ID | Task | Phase | Priority | Effort | Depends On |
|------|------|------|--------|------|------|
| T1.1 | Fix password encryption key management | P1 | P0 | 2d | None |
| T1.2 | Land Celery async classification | P1 | P0 | 4d | T1.1 |
| T1.3 | ML-assisted classification prototype | P1 | P1 | 8d | T1.2 |
| T1.4 | Semantic similarity rules | P1 | P1 | 3d | None |
| T1.5 | Incremental metadata collection | P1 | P1 | 4d | None |
| T2.1 | Static data masking | P2 | P1 | 8d | T1.5 |
| T2.2 | Data watermarking and tracing | P2 | P2 | 5d | T2.1 |
| T2.3 | Excel + PDF reports | P2 | P1 | 4d | None |
| T2.4 | Real Dameng driver | P2 | P2 | 4d | None |
| T2.5 | Unstructured file recognition | P2 | P2 | 6d | T1.2 |
| T2.6 | Schema change tracking | P2 | P2 | 4d | T1.5 |
| T3.1 | Risk scoring model | P3 | P1 | 7d | T2.1, T2.6 |
| T3.2 | Compliance check engine | P3 | P1 | 7d | T3.1 |
| T3.3 | Data lineage analysis | P3 | P2 | 8d | None |
| T3.4 | Risk alerts and work orders | P3 | P1 | 6d | T3.1, T3.2 |
| T3.5 | API asset scanning | P3 | P2 | 5d | T1.4 |
| T3.6 | Dark mode and performance optimization | P3 | P2 | 4d | None |
**Total: ~89 person-days (about 3.5 months solo; roughly 2 months with two people in parallel)**
---
## 7. Items to Confirm
1. Scope of the three phases: do all 17 tasks need to be built, or can some be cut?
2. Dameng environment: is a real Dameng instance available for integration testing? If not, is the jaydebeapi bridge acceptable?
3. ML training data: roughly how many manually labeled columns exist today? If too few, should simulated data be constructed?
4. Kickoff order: execute sequentially from T1.1, or allow limited parallelism across phases?
+29
@@ -0,0 +1,29 @@
# Database
DATABASE_URL=postgresql+psycopg2://pdg:pdg_secret_2024@localhost:5432/prop_data_guard
REDIS_URL=redis://localhost:6379/0
# Security
SECRET_KEY=prop-data-guard-super-secret-key-change-in-production
# Fernet-compatible encryption key for database passwords (32 bytes, base64 url-safe).
# Generate one with: python -c "from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())"
# IMPORTANT: Keep this key safe and consistent across all backend instances.
DB_ENCRYPTION_KEY=
# MinIO
MINIO_ENDPOINT=localhost:9000
MINIO_ACCESS_KEY=pdgminio
MINIO_SECRET_KEY=pdgminio_secret_2024
MINIO_SECURE=false
MINIO_BUCKET_NAME=pdg-files
# CORS
CORS_ORIGINS=["http://localhost:5173", "http://127.0.0.1:5173"]
# Auth
ACCESS_TOKEN_EXPIRE_MINUTES=30
REFRESH_TOKEN_EXPIRE_DAYS=7
# Default superuser (created on first startup)
FIRST_SUPERUSER_USERNAME=admin
FIRST_SUPERUSER_PASSWORD=Zhidi@n2023
FIRST_SUPERUSER_EMAIL=admin@datapo.com
@@ -0,0 +1,41 @@
"""Fix datasource password encryption stability
Revision ID: 002
Revises: 001
Create Date: 2026-04-23 14:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "002"
down_revision: Union[str, None] = "001"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Historical encrypted_password values are irrecoverable because
# the old implementation generated a random Fernet key on every startup.
# We clear the passwords and mark sources as inactive so admins re-enter them
# with the new stable key derived from DB_ENCRYPTION_KEY / SECRET_KEY.
op.add_column(
"data_source",
sa.Column("password_reset_required", sa.Boolean(), nullable=False, server_default=sa.text("false")),
)
op.execute(
"""
UPDATE data_source
SET encrypted_password = NULL,
status = 'inactive',
password_reset_required = true
WHERE encrypted_password IS NOT NULL
"""
)
def downgrade() -> None:
op.drop_column("data_source", "password_reset_required")
@@ -0,0 +1,27 @@
"""Add celery_task_id and scan_progress to classification_project
Revision ID: 003
Revises: 002
Create Date: 2026-04-23 14:30:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "003"
down_revision: Union[str, None] = "002"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.add_column("classification_project", sa.Column("celery_task_id", sa.String(100), nullable=True))
op.add_column("classification_project", sa.Column("scan_progress", sa.Text(), nullable=True))
def downgrade() -> None:
op.drop_column("classification_project", "scan_progress")
op.drop_column("classification_project", "celery_task_id")
@@ -0,0 +1,37 @@
"""Add ml_model_version table
Revision ID: 004
Revises: 003
Create Date: 2026-04-23 15:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "004"
down_revision: Union[str, None] = "003"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"ml_model_version",
sa.Column("id", sa.Integer(), primary_key=True, index=True),
sa.Column("name", sa.String(100), nullable=False),
sa.Column("model_path", sa.String(500), nullable=False),
sa.Column("vectorizer_path", sa.String(500), nullable=False),
sa.Column("accuracy", sa.Float(), default=0.0),
sa.Column("train_samples", sa.Integer(), default=0),
sa.Column("train_date", sa.DateTime(), default=sa.func.now()),
sa.Column("is_active", sa.Boolean(), default=False),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("created_at", sa.DateTime(), default=sa.func.now()),
)
def downgrade() -> None:
op.drop_table("ml_model_version")
@@ -0,0 +1,33 @@
"""Add incremental scan fields to meta tables
Revision ID: 005
Revises: 004
Create Date: 2026-04-23 16:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "005"
down_revision: Union[str, None] = "004"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
for table in ["meta_database", "meta_table", "meta_column"]:
op.add_column(table, sa.Column("last_scanned_at", sa.DateTime(), nullable=True))
op.add_column(table, sa.Column("checksum", sa.String(64), nullable=True))
op.add_column(table, sa.Column("is_deleted", sa.Boolean(), nullable=False, server_default=sa.text("false")))
op.add_column(table, sa.Column("deleted_at", sa.DateTime(), nullable=True))
def downgrade() -> None:
for table in ["meta_database", "meta_table", "meta_column"]:
op.drop_column(table, "deleted_at")
op.drop_column(table, "is_deleted")
op.drop_column(table, "checksum")
op.drop_column(table, "last_scanned_at")
@@ -0,0 +1,37 @@
"""Add masking_rule table
Revision ID: 006
Revises: 005
Create Date: 2026-04-23 17:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "006"
down_revision: Union[str, None] = "005"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"masking_rule",
sa.Column("id", sa.Integer(), primary_key=True, index=True),
sa.Column("name", sa.String(100), nullable=False),
sa.Column("level_id", sa.Integer(), sa.ForeignKey("data_level.id"), nullable=True),
sa.Column("category_id", sa.Integer(), sa.ForeignKey("category.id"), nullable=True),
sa.Column("algorithm", sa.String(20), nullable=False),
sa.Column("params", sa.JSON(), nullable=True),
sa.Column("is_active", sa.Boolean(), default=True),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("created_at", sa.DateTime(), default=sa.func.now()),
sa.Column("updated_at", sa.DateTime(), default=sa.func.now(), onupdate=sa.func.now()),
)
def downgrade() -> None:
op.drop_table("masking_rule")
@@ -0,0 +1,33 @@
"""Add watermark_log table
Revision ID: 007
Revises: 006
Create Date: 2026-04-23 18:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "007"
down_revision: Union[str, None] = "006"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"watermark_log",
sa.Column("id", sa.Integer(), primary_key=True, index=True),
sa.Column("user_id", sa.Integer(), sa.ForeignKey("sys_user.id"), nullable=False),
sa.Column("export_type", sa.String(20), default="csv"),
sa.Column("data_scope", sa.Text(), nullable=True),
sa.Column("watermark_key", sa.String(64), nullable=False),
sa.Column("created_at", sa.DateTime(), default=sa.func.now()),
)
def downgrade() -> None:
op.drop_table("watermark_log")
@@ -0,0 +1,25 @@
"""Add analysis_result to unstructured_file
Revision ID: 008
Revises: 007
Create Date: 2026-04-23 19:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "008"
down_revision: Union[str, None] = "007"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.add_column("unstructured_file", sa.Column("analysis_result", sa.JSON(), nullable=True))
def downgrade() -> None:
op.drop_column("unstructured_file", "analysis_result")
@@ -0,0 +1,36 @@
"""Add schema_change_log table
Revision ID: 009
Revises: 008
Create Date: 2026-04-23 20:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "009"
down_revision: Union[str, None] = "008"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"schema_change_log",
sa.Column("id", sa.Integer(), primary_key=True, index=True),
sa.Column("source_id", sa.Integer(), sa.ForeignKey("data_source.id"), nullable=False),
sa.Column("database_id", sa.Integer(), sa.ForeignKey("meta_database.id"), nullable=True),
sa.Column("table_id", sa.Integer(), sa.ForeignKey("meta_table.id"), nullable=True),
sa.Column("column_id", sa.Integer(), sa.ForeignKey("meta_column.id"), nullable=True),
sa.Column("change_type", sa.String(20), nullable=False),
sa.Column("old_value", sa.Text(), nullable=True),
sa.Column("new_value", sa.Text(), nullable=True),
sa.Column("detected_at", sa.DateTime(), default=sa.func.now()),
)
def downgrade() -> None:
op.drop_table("schema_change_log")
@@ -0,0 +1,40 @@
"""Add risk_assessment table
Revision ID: 010
Revises: 009
Create Date: 2026-04-23 21:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "010"
down_revision: Union[str, None] = "009"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"risk_assessment",
sa.Column("id", sa.Integer(), primary_key=True, index=True),
sa.Column("entity_type", sa.String(20), nullable=False),
sa.Column("entity_id", sa.Integer(), nullable=False),
sa.Column("entity_name", sa.String(200), nullable=True),
sa.Column("risk_score", sa.Float(), default=0.0),
sa.Column("sensitivity_score", sa.Float(), default=0.0),
sa.Column("exposure_score", sa.Float(), default=0.0),
sa.Column("protection_score", sa.Float(), default=0.0),
sa.Column("details", sa.JSON(), nullable=True),
sa.Column("created_at", sa.DateTime(), default=sa.func.now()),
sa.Column("updated_at", sa.DateTime(), default=sa.func.now(), onupdate=sa.func.now()),
)
op.create_index("idx_risk_entity", "risk_assessment", ["entity_type", "entity_id"])
def downgrade() -> None:
op.drop_index("idx_risk_entity", table_name="risk_assessment")
op.drop_table("risk_assessment")
@@ -0,0 +1,51 @@
"""Add compliance_rule and compliance_issue tables
Revision ID: 011
Revises: 010
Create Date: 2026-04-23 22:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "011"
down_revision: Union[str, None] = "010"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"compliance_rule",
sa.Column("id", sa.Integer(), primary_key=True, index=True),
sa.Column("name", sa.String(200), nullable=False),
sa.Column("standard", sa.String(50), nullable=False),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("check_logic", sa.String(50), nullable=False),
sa.Column("severity", sa.String(20), default="medium"),
sa.Column("is_active", sa.Boolean(), default=True),
sa.Column("created_at", sa.DateTime(), default=sa.func.now()),
)
op.create_table(
"compliance_issue",
sa.Column("id", sa.Integer(), primary_key=True, index=True),
sa.Column("rule_id", sa.Integer(), nullable=False),
sa.Column("project_id", sa.Integer(), nullable=True),
sa.Column("entity_type", sa.String(20), nullable=False),
sa.Column("entity_id", sa.Integer(), nullable=False),
sa.Column("entity_name", sa.String(200), nullable=True),
sa.Column("severity", sa.String(20), default="medium"),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("suggestion", sa.Text(), nullable=True),
sa.Column("status", sa.String(20), default="open"),
sa.Column("resolved_at", sa.DateTime(), nullable=True),
sa.Column("created_at", sa.DateTime(), default=sa.func.now()),
)
def downgrade() -> None:
op.drop_table("compliance_issue")
op.drop_table("compliance_rule")
@@ -0,0 +1,35 @@
"""Add data_lineage table
Revision ID: 012
Revises: 011
Create Date: 2026-04-23 23:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "012"
down_revision: Union[str, None] = "011"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"data_lineage",
sa.Column("id", sa.Integer(), primary_key=True, index=True),
sa.Column("source_table", sa.String(200), nullable=False),
sa.Column("source_column", sa.String(200), nullable=True),
sa.Column("target_table", sa.String(200), nullable=False),
sa.Column("target_column", sa.String(200), nullable=True),
sa.Column("relation_type", sa.String(20), default="direct"),
sa.Column("script_content", sa.Text(), nullable=True),
sa.Column("created_at", sa.DateTime(), default=sa.func.now()),
)
def downgrade() -> None:
op.drop_table("data_lineage")
@@ -0,0 +1,58 @@
"""Add alert and work_order tables
Revision ID: 013
Revises: 012
Create Date: 2026-04-24 00:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "013"
down_revision: Union[str, None] = "012"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"alert_rule",
sa.Column("id", sa.Integer(), primary_key=True, index=True),
sa.Column("name", sa.String(200), nullable=False),
sa.Column("trigger_condition", sa.String(50), nullable=False),
sa.Column("threshold", sa.Integer(), default=0),
sa.Column("severity", sa.String(20), default="medium"),
sa.Column("is_active", sa.Boolean(), default=True),
sa.Column("created_at", sa.DateTime(), default=sa.func.now()),
)
op.create_table(
"alert_record",
sa.Column("id", sa.Integer(), primary_key=True, index=True),
sa.Column("rule_id", sa.Integer(), sa.ForeignKey("alert_rule.id"), nullable=False),
sa.Column("title", sa.String(200), nullable=False),
sa.Column("content", sa.Text(), nullable=True),
sa.Column("severity", sa.String(20), default="medium"),
sa.Column("status", sa.String(20), default="open"),
sa.Column("created_at", sa.DateTime(), default=sa.func.now()),
)
op.create_table(
"work_order",
sa.Column("id", sa.Integer(), primary_key=True, index=True),
sa.Column("alert_id", sa.Integer(), sa.ForeignKey("alert_record.id"), nullable=True),
sa.Column("title", sa.String(200), nullable=False),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("assignee_id", sa.Integer(), sa.ForeignKey("sys_user.id"), nullable=True),
sa.Column("status", sa.String(20), default="open"),
sa.Column("resolution", sa.Text(), nullable=True),
sa.Column("created_at", sa.DateTime(), default=sa.func.now()),
sa.Column("resolved_at", sa.DateTime(), nullable=True),
)
def downgrade() -> None:
op.drop_table("work_order")
op.drop_table("alert_record")
op.drop_table("alert_rule")
@@ -0,0 +1,55 @@
"""Add api_asset and api_endpoint tables
Revision ID: 014
Revises: 013
Create Date: 2026-04-24 00:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
revision: str = "014"
down_revision: Union[str, None] = "013"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"api_asset",
sa.Column("id", sa.Integer(), primary_key=True, index=True),
sa.Column("name", sa.String(200), nullable=False),
sa.Column("base_url", sa.String(500), nullable=False),
sa.Column("swagger_url", sa.String(500), nullable=True),
sa.Column("auth_type", sa.String(50), default="none"),
sa.Column("headers", sa.JSON(), default=dict),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("scan_status", sa.String(20), default="idle"),
sa.Column("total_endpoints", sa.Integer(), default=0),
sa.Column("sensitive_endpoints", sa.Integer(), default=0),
sa.Column("created_by", sa.Integer(), sa.ForeignKey("sys_user.id"), nullable=True),
sa.Column("created_at", sa.DateTime(), default=sa.func.now()),
sa.Column("updated_at", sa.DateTime(), default=sa.func.now(), onupdate=sa.func.now()),
)
op.create_table(
"api_endpoint",
sa.Column("id", sa.Integer(), primary_key=True, index=True),
sa.Column("asset_id", sa.Integer(), sa.ForeignKey("api_asset.id"), nullable=False),
sa.Column("method", sa.String(10), nullable=False),
sa.Column("path", sa.String(500), nullable=False),
sa.Column("summary", sa.String(500), nullable=True),
sa.Column("tags", sa.JSON(), default=list),
sa.Column("parameters", sa.JSON(), default=list),
sa.Column("request_body_schema", sa.JSON(), nullable=True),
sa.Column("response_schema", sa.JSON(), nullable=True),
sa.Column("sensitive_fields", sa.JSON(), default=list),
sa.Column("risk_level", sa.String(20), default="low"),
sa.Column("is_active", sa.Boolean(), default=True),
sa.Column("created_at", sa.DateTime(), default=sa.func.now()),
)
def downgrade() -> None:
op.drop_table("api_endpoint")
op.drop_table("api_asset")
+83 -1
@@ -4,7 +4,7 @@ from jose import JWTError
from app.core.database import get_db
from app.core.security import decode_token
-from app.models.user import User
+from app.models.user import User, Role
from app.services import user_service
@@ -47,3 +47,85 @@ def get_current_active_user(
current_user: User = Depends(get_current_user),
) -> User:
return current_user
# ===================== RBAC Dependencies =====================
# Role code constants
ROLE_SUPERADMIN = "superadmin"
ROLE_ADMIN = "admin"
ROLE_PROJECT_MANAGER = "project_manager"
ROLE_LABELER = "labeler"
ROLE_REVIEWER = "reviewer"
ROLE_GUEST = "guest"
# Role hierarchy (higher index = more permissions)
ROLE_LEVELS = {
ROLE_GUEST: 0,
ROLE_LABELER: 1,
ROLE_REVIEWER: 1,
ROLE_PROJECT_MANAGER: 2,
ROLE_ADMIN: 3,
ROLE_SUPERADMIN: 4,
}
def _get_user_role_codes(user: User) -> list[str]:
"""Get list of role codes for a user."""
return [r.code for r in user.roles]
def _has_role(user: User, role_code: str) -> bool:
"""Check if user has a specific role."""
return role_code in _get_user_role_codes(user)
def _has_any_role(user: User, role_codes: list[str]) -> bool:
"""Check if user has any of the specified roles."""
user_roles = _get_user_role_codes(user)
return any(r in user_roles for r in role_codes)
def _is_admin(user: User) -> bool:
"""Check if user is admin or superadmin."""
return user.is_superuser or _has_any_role(user, [ROLE_SUPERADMIN, ROLE_ADMIN])
def require_admin(current_user: User = Depends(get_current_user)) -> User:
"""Require admin or superadmin role."""
if not _is_admin(current_user):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="需要管理员权限",
)
return current_user
def require_manager(current_user: User = Depends(get_current_user)) -> User:
"""Require project manager or above."""
if _is_admin(current_user):
return current_user
if _has_role(current_user, ROLE_PROJECT_MANAGER):
return current_user
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="需要项目负责人及以上权限",
)
def require_labeler(current_user: User = Depends(get_current_user)) -> User:
"""Require labeler or above (excluding guest)."""
if _is_admin(current_user):
return current_user
allowed = [ROLE_PROJECT_MANAGER, ROLE_LABELER, ROLE_REVIEWER]
if _has_any_role(current_user, allowed):
return current_user
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="权限不足",
)
def require_guest_or_above(current_user: User = Depends(get_current_user)) -> User:
"""Any authenticated user (including guest)."""
return current_user
+10 -1
@@ -1,6 +1,6 @@
from fastapi import APIRouter
-from app.api.v1 import auth, user, datasource, metadata, classification, project, task, report, dashboard
+from app.api.v1 import auth, user, datasource, metadata, classification, project, task, report, dashboard, masking, watermark, unstructured, schema_change, risk, compliance, lineage, alert, api_asset
api_router = APIRouter()
api_router.include_router(auth.router, prefix="/auth", tags=["认证"])
@@ -12,3 +12,12 @@ api_router.include_router(project.router, prefix="/projects", tags=["项目管理"])
api_router.include_router(task.router, prefix="/tasks", tags=["任务管理"])
api_router.include_router(report.router, prefix="/reports", tags=["报告管理"])
api_router.include_router(dashboard.router, prefix="/dashboard", tags=["仪表盘"])
api_router.include_router(masking.router, prefix="/masking", tags=["数据脱敏"])
api_router.include_router(watermark.router, prefix="/watermark", tags=["数据水印"])
api_router.include_router(unstructured.router, prefix="/unstructured", tags=["非结构化文件"])
api_router.include_router(schema_change.router, prefix="/schema-changes", tags=["Schema变更"])
api_router.include_router(risk.router, prefix="/risk", tags=["风险评估"])
api_router.include_router(compliance.router, prefix="/compliance", tags=["合规检查"])
api_router.include_router(lineage.router, prefix="/lineage", tags=["数据血缘"])
api_router.include_router(alert.router, prefix="/alerts", tags=["告警与工单"])
api_router.include_router(api_asset.router, prefix="/api-assets", tags=["API资产"])
+115
@@ -0,0 +1,115 @@
from typing import Optional
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.models.user import User
from app.schemas.common import ResponseModel, ListResponse
from app.services import alert_service
from app.api.deps import get_current_user, require_admin
router = APIRouter()
@router.post("/init-rules")
def init_alert_rules(
db: Session = Depends(get_db),
current_user: User = Depends(require_admin),
):
alert_service.init_builtin_alert_rules(db)
return ResponseModel(message="初始化完成")
@router.post("/check")
def check_alerts(
db: Session = Depends(get_db),
current_user: User = Depends(require_admin),
):
records = alert_service.check_alerts(db)
return ResponseModel(data={"alerts_created": len(records)})
@router.get("/records")
def list_alert_records(
status: Optional[str] = Query(None),
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=500),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
query = db.query(alert_service.AlertRecord)
if status:
query = query.filter(alert_service.AlertRecord.status == status)
total = query.count()
items = query.order_by(alert_service.AlertRecord.created_at.desc()).offset((page - 1) * page_size).limit(page_size).all()
return ListResponse(
data=[{
"id": r.id,
"rule_id": r.rule_id,
"title": r.title,
"content": r.content,
"severity": r.severity,
"status": r.status,
"created_at": r.created_at.isoformat() if r.created_at else None,
} for r in items],
total=total,
page=page,
page_size=page_size,
)
@router.post("/work-orders")
def create_work_order(
alert_id: int,
title: str,
description: str = "",
assignee_id: Optional[int] = Query(None),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
wo = alert_service.create_work_order(db, alert_id, title, description, assignee_id)
return ResponseModel(data={"id": wo.id})
@router.get("/work-orders")
def list_work_orders(
status: Optional[str] = Query(None),
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=500),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
from app.models.alert import WorkOrder
query = db.query(WorkOrder)
if status:
query = query.filter(WorkOrder.status == status)
total = query.count()
items = query.order_by(WorkOrder.created_at.desc()).offset((page - 1) * page_size).limit(page_size).all()
return ListResponse(
data=[{
"id": w.id,
"alert_id": w.alert_id,
"title": w.title,
"status": w.status,
"assignee_name": w.assignee.username if w.assignee else None,
"created_at": w.created_at.isoformat() if w.created_at else None,
} for w in items],
total=total,
page=page,
page_size=page_size,
)
@router.post("/work-orders/{wo_id}/status")
def update_work_order(
wo_id: int,
status: str,
resolution: str = "",
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
wo = alert_service.update_work_order_status(db, wo_id, status, resolution or None)
if not wo:
from fastapi import HTTPException, status
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="工单不存在")
return ResponseModel(data={"id": wo.id, "status": wo.status})
+131
@@ -0,0 +1,131 @@
from typing import Optional, List
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from pydantic import BaseModel
from app.core.database import get_db
from app.models.user import User
from app.schemas.common import ResponseModel, ListResponse
from app.services import api_asset_service
from app.api.deps import get_current_user
router = APIRouter()
class APIAssetCreate(BaseModel):
name: str
base_url: str
swagger_url: Optional[str] = None
auth_type: Optional[str] = "none"
headers: Optional[dict] = None
description: Optional[str] = None
class APIAssetUpdate(BaseModel):
name: Optional[str] = None
base_url: Optional[str] = None
swagger_url: Optional[str] = None
auth_type: Optional[str] = None
headers: Optional[dict] = None
description: Optional[str] = None
@router.post("")
def create_asset(
body: APIAssetCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
asset = api_asset_service.create_asset(db, body.dict(), current_user.id)
return ResponseModel(data={"id": asset.id})
@router.get("")
def list_assets(
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=500),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
from app.models.api_asset import APIAsset
query = db.query(APIAsset)
total = query.count()
items = query.order_by(APIAsset.created_at.desc()).offset((page - 1) * page_size).limit(page_size).all()
return ListResponse(
data=[{
"id": a.id,
"name": a.name,
"base_url": a.base_url,
"swagger_url": a.swagger_url,
"auth_type": a.auth_type,
"scan_status": a.scan_status,
"total_endpoints": a.total_endpoints,
"sensitive_endpoints": a.sensitive_endpoints,
"created_at": a.created_at.isoformat() if a.created_at else None,
} for a in items],
total=total,
page=page,
page_size=page_size,
)
@router.put("/{asset_id}")
def update_asset(
asset_id: int,
body: APIAssetUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
asset = api_asset_service.update_asset(db, asset_id, body.dict(exclude_unset=True))
if not asset:
from fastapi import HTTPException, status
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="资产不存在")
return ResponseModel(data={"id": asset.id})
@router.delete("/{asset_id}")
def delete_asset(
asset_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
ok = api_asset_service.delete_asset(db, asset_id)
if not ok:
from fastapi import HTTPException, status
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="资产不存在")
return ResponseModel()
@router.post("/{asset_id}/scan")
def scan_asset(
asset_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
result = api_asset_service.scan_swagger(db, asset_id)
return ResponseModel(data=result)
@router.get("/{asset_id}/endpoints")
def list_endpoints(
asset_id: int,
risk_level: Optional[str] = Query(None),
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=500),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
from app.models.api_asset import APIEndpoint
query = db.query(APIEndpoint).filter(APIEndpoint.asset_id == asset_id)
if risk_level:
query = query.filter(APIEndpoint.risk_level == risk_level)
total = query.count()
items = query.order_by(APIEndpoint.id.asc()).offset((page - 1) * page_size).limit(page_size).all()
return ListResponse(
data=[{
"id": e.id,
"method": e.method,
"path": e.path,
"summary": e.summary,
"tags": e.tags,
"parameters": e.parameters,
"sensitive_fields": e.sensitive_fields,
"risk_level": e.risk_level,
"is_active": e.is_active,
} for e in items],
total=total,
page=page,
page_size=page_size,
)
+81 -10
@@ -11,7 +11,7 @@ from app.schemas.classification import (
)
from app.schemas.common import ResponseModel, ListResponse
from app.services import classification_service, classification_engine
-from app.api.deps import get_current_user
+from app.api.deps import get_current_user, require_admin, require_labeler, require_guest_or_above, _is_admin, ROLE_PROJECT_MANAGER
router = APIRouter()
@@ -39,7 +39,7 @@ def list_categories(
def create_category(
req: CategoryCreate,
db: Session = Depends(get_db),
-current_user: User = Depends(get_current_user),
+current_user: User = Depends(require_admin),
):
item = classification_service.create_category(db, req)
return ResponseModel(data=CategoryOut.model_validate(item))
@@ -50,7 +50,7 @@ def update_category(
category_id: int,
req: CategoryUpdate,
db: Session = Depends(get_db),
-current_user: User = Depends(get_current_user),
+current_user: User = Depends(require_admin),
):
db_obj = classification_service.get_category(db, category_id)
if not db_obj:
@@ -64,7 +64,7 @@ def update_category(
def delete_category(
category_id: int,
db: Session = Depends(get_db),
-current_user: User = Depends(get_current_user),
+current_user: User = Depends(require_admin),
):
classification_service.delete_category(db, category_id)
return ResponseModel(message="删除成功")
@@ -103,7 +103,7 @@ def list_rules(
def create_rule(
req: RecognitionRuleCreate,
db: Session = Depends(get_db),
-current_user: User = Depends(get_current_user),
+current_user: User = Depends(require_admin),
):
item = classification_service.create_rule(db, req)
data = RecognitionRuleOut.model_validate(item)
@@ -118,7 +118,7 @@ def update_rule(
rule_id: int,
req: RecognitionRuleUpdate,
db: Session = Depends(get_db),
-current_user: User = Depends(get_current_user),
+current_user: User = Depends(require_admin),
):
db_obj = classification_service.get_rule(db, rule_id)
if not db_obj:
@@ -136,7 +136,7 @@ def update_rule(
def delete_rule(
rule_id: int,
db: Session = Depends(get_db),
-current_user: User = Depends(get_current_user),
+current_user: User = Depends(require_admin),
):
classification_service.delete_rule(db, rule_id)
return ResponseModel(message="删除成功")
@@ -159,10 +159,42 @@ def list_results(
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=500),
db: Session = Depends(get_db),
-current_user: User = Depends(get_current_user),
+current_user: User = Depends(require_guest_or_above),
):
from app.services.project_service import list_results as _list_results
-items, total = _list_results(db, project_id=project_id, keyword=keyword, page=page, page_size=page_size)
+from app.models.project import ClassificationProject, ClassificationTask
from app.api.deps import _has_role
# Data isolation: compute allowed project IDs for non-admin users
allowed_project_ids = None
if not _is_admin(current_user):
if _has_role(current_user, ROLE_PROJECT_MANAGER):
# Project managers see their own projects
allowed_project_ids = [
p.id for p in db.query(ClassificationProject.id).filter(
ClassificationProject.created_by == current_user.id
).all()
]
elif _has_role(current_user, 'labeler') or _has_role(current_user, 'reviewer'):
# Labelers/reviewers see projects where they have tasks
task_projects = db.query(ClassificationTask.project_id).filter(
ClassificationTask.assignee_id == current_user.id
).distinct().all()
allowed_project_ids = [p[0] for p in task_projects]
else:
# Guests see all results (read-only)
allowed_project_ids = None
if allowed_project_ids is not None and not allowed_project_ids:
allowed_project_ids = []
# If a specific project_id is requested, check permission
if project_id and allowed_project_ids is not None and project_id not in allowed_project_ids:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="无权查看此项目")
items, total = _list_results(
db, project_id=project_id, keyword=keyword, page=page, page_size=page_size,
project_ids=allowed_project_ids,
)
data = []
for r in items:
@@ -171,7 +203,6 @@ def list_results(
database = table.database if table else None
source = database.source if database else None
-# Filter by level_id if specified
if level_id and r.level_id != level_id:
continue
@@ -207,3 +238,43 @@ def auto_classify(
):
result = classification_engine.run_auto_classification(db, project_id)
return ResponseModel(data=result)
@router.post("/ml-train")
def ml_train(
background: bool = True,
model_name: Optional[str] = None,
algorithm: str = "logistic_regression",
db: Session = Depends(get_db),
current_user: User = Depends(require_admin),
):
from app.tasks.ml_tasks import train_ml_model_task
from app.services.ml_service import train_model
if background:
task = train_ml_model_task.delay(model_name=model_name, algorithm=algorithm)
return ResponseModel(data={"task_id": task.id, "status": task.state})
else:
mv = train_model(db, model_name=model_name, algorithm=algorithm)
if mv:
return ResponseModel(data={"model_id": mv.id, "accuracy": mv.accuracy, "train_samples": mv.train_samples})
return ResponseModel(message="训练失败:样本不足或发生错误")
@router.get("/ml-suggest/{project_id}")
def ml_suggest(
project_id: int,
column_ids: Optional[str] = Query(None),
top_k: int = Query(3, ge=1, le=5),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
from app.services.ml_service import suggest_for_project_columns
ids = None
if column_ids:
ids = [int(x) for x in column_ids.split(",") if x.strip().isdigit()]
result = suggest_for_project_columns(db, project_id, column_ids=ids, top_k=top_k)
if not result.get("success"):
from fastapi import HTTPException, status
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=result.get("message"))
return ResponseModel(data=result["suggestions"])
+79
@@ -0,0 +1,79 @@
from typing import Optional
from fastapi import APIRouter, Depends, Query, HTTPException, status
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.models.user import User
from app.schemas.common import ResponseModel, ListResponse
from app.services import compliance_service
from app.api.deps import get_current_user, require_admin
router = APIRouter()
@router.post("/init-rules")
def init_rules(
db: Session = Depends(get_db),
current_user: User = Depends(require_admin),
):
compliance_service.init_builtin_rules(db)
return ResponseModel(message="初始化完成")
@router.post("/scan")
def scan_compliance(
project_id: Optional[int] = Query(None),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
try:
issues = compliance_service.scan_compliance(db, project_id=project_id)
return ResponseModel(data={"issues_found": len(issues)})
except Exception:
import logging
logging.exception("Compliance scan failed")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="扫描执行失败,请稍后重试"
)
@router.get("/issues")
def list_issues(
project_id: Optional[int] = Query(None),
status: Optional[str] = Query(None),
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=500),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
items, total = compliance_service.list_issues(db, project_id=project_id, status=status, page=page, page_size=page_size)
return ListResponse(
data=[{
"id": i.id,
"rule_id": i.rule_id,
"project_id": i.project_id,
"entity_type": i.entity_type,
"entity_name": i.entity_name,
"severity": i.severity,
"description": i.description,
"suggestion": i.suggestion,
"status": i.status,
"created_at": i.created_at.isoformat() if i.created_at else None,
} for i in items],
total=total,
page=page,
page_size=page_size,
)
@router.post("/issues/{issue_id}/resolve")
def resolve_issue(
issue_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
issue = compliance_service.resolve_issue(db, issue_id)
if not issue:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="问题不存在")
return ResponseModel(message="已标记为已解决")
+3 -3
@@ -9,7 +9,7 @@ from app.models.metadata import DataSource, DataTable, DataColumn
from app.models.project import ClassificationResult, ClassificationProject
from app.models.classification import Category, DataLevel
from app.schemas.common import ResponseModel
-from app.api.deps import get_current_user
+from app.api.deps import get_current_user, require_guest_or_above, _is_admin
router = APIRouter()
@@ -17,7 +17,7 @@ router = APIRouter()
@router.get("/stats") @router.get("/stats")
def get_dashboard_stats( def get_dashboard_stats(
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(require_guest_or_above),
): ):
"""Dashboard overview statistics based on real DB data.""" """Dashboard overview statistics based on real DB data."""
data_sources = db.query(DataSource).count() data_sources = db.query(DataSource).count()
@@ -42,7 +42,7 @@ def get_dashboard_stats(
@router.get("/distribution") @router.get("/distribution")
def get_dashboard_distribution( def get_dashboard_distribution(
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(require_guest_or_above),
): ):
"""Distribution data for charts based on real DB data.""" """Distribution data for charts based on real DB data."""
# Level distribution # Level distribution
+4 -4
@@ -7,7 +7,7 @@ from app.models.user import User
from app.schemas.datasource import DataSourceCreate, DataSourceUpdate, DataSourceOut, DataSourceTest
from app.schemas.common import ResponseModel, ListResponse
from app.services import datasource_service
-from app.api.deps import get_current_user
+from app.api.deps import get_current_user, require_admin
router = APIRouter()
@@ -41,7 +41,7 @@ def get_datasource(
def create_datasource(
req: DataSourceCreate,
db: Session = Depends(get_db),
-current_user: User = Depends(get_current_user),
+current_user: User = Depends(require_admin),
):
item = datasource_service.create_datasource(db, req, current_user.id)
return ResponseModel(data=DataSourceOut.model_validate(item))
@@ -52,7 +52,7 @@ def update_datasource(
source_id: int,
req: DataSourceUpdate,
db: Session = Depends(get_db),
-current_user: User = Depends(get_current_user),
+current_user: User = Depends(require_admin),
):
db_obj = datasource_service.get_datasource(db, source_id)
if not db_obj:
@@ -66,7 +66,7 @@ def update_datasource(
def delete_datasource(
source_id: int,
db: Session = Depends(get_db),
-current_user: User = Depends(get_current_user),
+current_user: User = Depends(require_admin),
):
datasource_service.delete_datasource(db, source_id)
return ResponseModel(message="删除成功")
+32
@@ -0,0 +1,32 @@
from typing import Optional
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.models.user import User
from app.schemas.common import ResponseModel
from app.services import lineage_service
from app.api.deps import get_current_user
router = APIRouter()
@router.post("/parse")
def parse_lineage(
sql: str,
target_table: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
records = lineage_service.parse_sql_lineage(db, sql, target_table)
return ResponseModel(data={"records_created": len(records)})
@router.get("/graph")
def get_graph(
table_name: Optional[str] = Query(None),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
graph = lineage_service.get_lineage_graph(db, table_name=table_name)
return ResponseModel(data=graph)
+88
View File
@@ -0,0 +1,88 @@
from typing import Optional
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.models.user import User
from app.schemas.common import ResponseModel, ListResponse
from app.services import masking_service
from app.api.deps import get_current_user, require_admin
router = APIRouter()
@router.get("/rules")
def list_masking_rules(
level_id: Optional[int] = Query(None),
category_id: Optional[int] = Query(None),
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=500),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
items, total = masking_service.list_masking_rules(db, level_id=level_id, category_id=category_id, page=page, page_size=page_size)
return ListResponse(
data=[{
"id": r.id,
"name": r.name,
"level_id": r.level_id,
"category_id": r.category_id,
"algorithm": r.algorithm,
"params": r.params,
"is_active": r.is_active,
"description": r.description,
"level_name": r.level.name if r.level else None,
"category_name": r.category.name if r.category else None,
} for r in items],
total=total,
page=page,
page_size=page_size,
)
@router.post("/rules")
def create_masking_rule(
req: dict,
db: Session = Depends(get_db),
current_user: User = Depends(require_admin),
):
item = masking_service.create_masking_rule(db, req)
return ResponseModel(data={"id": item.id})
@router.put("/rules/{rule_id}")
def update_masking_rule(
rule_id: int,
req: dict,
db: Session = Depends(get_db),
current_user: User = Depends(require_admin),
):
db_obj = masking_service.get_masking_rule(db, rule_id)
if not db_obj:
from fastapi import HTTPException, status
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="规则不存在")
item = masking_service.update_masking_rule(db, db_obj, req)
return ResponseModel(data={"id": item.id})
@router.delete("/rules/{rule_id}")
def delete_masking_rule(
rule_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(require_admin),
):
masking_service.delete_masking_rule(db, rule_id)
return ResponseModel(message="删除成功")
@router.post("/preview")
def preview_masking(
source_id: int,
table_name: str,
project_id: Optional[int] = None,
limit: int = Query(20, ge=1, le=100),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
result = masking_service.preview_masking(db, source_id, table_name, project_id=project_id, limit=limit)
return ResponseModel(data=result)
+2 -2
@@ -7,7 +7,7 @@ from app.models.user import User
from app.schemas.metadata import DatabaseOut, DataTableOut, DataColumnOut
from app.schemas.common import ResponseModel, ListResponse
from app.services import metadata_service
-from app.api.deps import get_current_user
+from app.api.deps import get_current_user, require_admin
router = APIRouter()
@@ -60,7 +60,7 @@ def list_columns(
def sync_metadata(
source_id: int,
db: Session = Depends(get_db),
-current_user: User = Depends(get_current_user),
+current_user: User = Depends(require_admin),
):
result = metadata_service.sync_metadata(db, source_id, current_user.id)
return ResponseModel(data=result)
+86 -8
@@ -1,12 +1,12 @@
from typing import Optional
-from fastapi import APIRouter, Depends, Query
+from fastapi import APIRouter, Depends, Query, HTTPException, status
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.models.user import User
from app.schemas.common import ResponseModel, ListResponse
from app.services import project_service
-from app.api.deps import get_current_user
+from app.api.deps import get_current_user, require_admin, require_manager, _is_admin
router = APIRouter()
@@ -17,9 +17,11 @@ def list_projects(
page_size: int = Query(20, ge=1, le=500),
keyword: Optional[str] = Query(None),
db: Session = Depends(get_db),
-current_user: User = Depends(get_current_user),
+current_user: User = Depends(require_manager),
):
-items, total = project_service.list_projects(db, keyword=keyword, page=page, page_size=page_size)
+# Data isolation: non-admin users only see their own projects
created_by = None if _is_admin(current_user) else current_user.id
items, total = project_service.list_projects(db, keyword=keyword, page=page, page_size=page_size, created_by=created_by)
data = []
for p in items:
stats = project_service.get_project_stats(db, p.id)
@@ -43,7 +45,7 @@ def create_project(
target_source_ids: Optional[str] = None,
description: Optional[str] = None,
db: Session = Depends(get_db),
-current_user: User = Depends(get_current_user),
+current_user: User = Depends(require_manager),
):
item = project_service.create_project(
db, name=name, template_id=template_id,
@@ -85,16 +87,92 @@ def delete_project(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
p = project_service.get_project(db, project_id)
if not p:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="项目不存在")
# Only admin or project creator can delete
if not _is_admin(current_user) and p.created_by != current_user.id:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="无权删除此项目")
project_service.delete_project(db, project_id) project_service.delete_project(db, project_id)
return ResponseModel(message="删除成功") return ResponseModel(message="删除成功")
@router.post("/{project_id}/auto-classify") @router.post("/{project_id}/auto-classify")
def project_auto_classify( def project_auto_classify(
project_id: int,
background: bool = True,
db: Session = Depends(get_db),
current_user: User = Depends(require_manager),
):
from app.tasks.classification_tasks import auto_classify_task
from celery.result import AsyncResult
project = project_service.get_project(db, project_id)
if not project:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="项目不存在")
try:
if background:
# Check if already running
if project.celery_task_id:
existing = AsyncResult(project.celery_task_id)
if existing.state in ("PENDING", "PROGRESS", "STARTED"):
return ResponseModel(data={"task_id": project.celery_task_id, "status": existing.state})
task = auto_classify_task.delay(project_id)
project.celery_task_id = task.id
project.status = "scanning"
db.commit()
return ResponseModel(data={"task_id": task.id, "status": task.state})
else:
from app.services.classification_engine import run_auto_classification
project.status = "scanning"
db.commit()
result = run_auto_classification(db, project_id)
if result.get("success"):
project.status = "assigning"
else:
project.status = "created"
db.commit()
return ResponseModel(data=result)
except Exception as e:
import logging
logging.exception("Auto classify failed")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"自动分类执行失败: {str(e)}"
)
@router.get("/{project_id}/auto-classify-status")
def project_auto_classify_status(
project_id: int, project_id: int,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
from app.services.classification_engine import run_auto_classification from celery.result import AsyncResult
result = run_auto_classification(db, project_id) import json
return ResponseModel(data=result)
project = project_service.get_project(db, project_id)
if not project:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="项目不存在")
task_id = project.celery_task_id
if not task_id:
# Return persisted progress if any
progress = json.loads(project.scan_progress) if project.scan_progress else None
return ResponseModel(data={"status": project.status, "progress": progress})
result = AsyncResult(task_id)
progress = None
if result.state == "PROGRESS" and result.info:
progress = result.info
elif project.scan_progress:
progress = json.loads(project.scan_progress)
return ResponseModel(data={
"status": result.state,
"task_id": task_id,
"progress": progress,
"project_status": project.status,
})
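Taken together, the two endpoints form a start-then-poll contract: POST kicks off (or re-attaches to) the Celery task, GET reports its state. A minimal client sketch; the router prefix, port, and auth header are assumptions not shown in this compare:

import time
import requests

BASE = "http://localhost:8000/api/v1/projects"  # assumed prefix
HEADERS = {"Authorization": "Bearer <token>"}   # assumed auth scheme

start = requests.post(f"{BASE}/42/auto-classify", headers=HEADERS).json()["data"]
while True:
    st = requests.get(f"{BASE}/42/auto-classify-status", headers=HEADERS).json()["data"]
    if st["status"] in ("SUCCESS", "FAILURE"):
        break
    print(st.get("progress"))  # e.g. {"scanned": 120, "matched": 14, "total": 300}
    time.sleep(2)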
+20 -2
@@ -7,7 +7,7 @@ from app.models.user import User
 from app.models.project import ClassificationResult, ClassificationProject
 from app.models.classification import Category, DataLevel
 from app.schemas.common import ResponseModel
-from app.api.deps import get_current_user
+from app.api.deps import get_current_user, require_guest_or_above
 from app.services import report_service

 router = APIRouter()
@@ -16,7 +16,7 @@ router = APIRouter()
 @router.get("/stats")
 def get_report_stats(
     db: Session = Depends(get_db),
-    current_user: User = Depends(get_current_user),
+    current_user: User = Depends(require_guest_or_above),
 ):
     """Global report statistics."""
     total = db.query(ClassificationResult).count()
@@ -44,12 +44,30 @@ def get_report_stats(
 @router.get("/projects/{project_id}/download")
 def download_report(
     project_id: int,
+    format: str = "docx",
     db: Session = Depends(get_db),
     current_user: User = Depends(get_current_user),
 ):
+    if format == "excel":
+        content = report_service.generate_excel_report(db, project_id)
+        return Response(
+            content=content,
+            media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
+            headers={"Content-Disposition": f"attachment; filename=report_project_{project_id}.xlsx"},
+        )
     content = report_service.generate_classification_report(db, project_id)
     return Response(
         content=content,
         media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
         headers={"Content-Disposition": f"attachment; filename=report_project_{project_id}.docx"},
     )

+@router.get("/projects/{project_id}/summary")
+def report_summary(
+    project_id: int,
+    db: Session = Depends(get_db),
+    current_user: User = Depends(get_current_user),
+):
+    data = report_service.get_report_summary(db, project_id)
+    return ResponseModel(data=data)
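The new format query parameter keeps .docx as the default while opening an Excel path. A download call might look like this; the URL prefix is an assumption:

import requests

resp = requests.get(
    "http://localhost:8000/api/v1/reports/projects/42/download",  # prefix assumed
    params={"format": "excel"},  # omit to get the default .docx report
    headers={"Authorization": "Bearer <token>"},
)
with open("report_project_42.xlsx", "wb") as fh:
    fh.write(resp.content)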
+73
@@ -0,0 +1,73 @@
from typing import Optional
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.models.user import User
from app.schemas.common import ResponseModel, ListResponse
from app.services import risk_service
from app.api.deps import get_current_user
router = APIRouter()
@router.post("/recalculate")
def recalculate_risk(
project_id: Optional[int] = Query(None),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
if project_id:
result = risk_service.calculate_project_risk(db, project_id)
return ResponseModel(data={"project_id": project_id, "risk_score": result.risk_score if result else 0})
result = risk_service.calculate_all_projects_risk(db)
return ResponseModel(data=result)
@router.get("/top")
def risk_top(
entity_type: str = Query("project"),
n: int = Query(10, ge=1, le=100),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
items = risk_service.get_risk_top_n(db, entity_type=entity_type, n=n)
return ListResponse(
data=[{
"id": r.id,
"entity_type": r.entity_type,
"entity_id": r.entity_id,
"entity_name": r.entity_name,
"risk_score": r.risk_score,
"sensitivity_score": r.sensitivity_score,
"exposure_score": r.exposure_score,
"protection_score": r.protection_score,
"updated_at": r.updated_at.isoformat() if r.updated_at else None,
} for r in items],
total=len(items),
page=1,
page_size=n,
)
@router.get("/projects/{project_id}")
def project_risk(
project_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
from app.models.risk import RiskAssessment
item = db.query(RiskAssessment).filter(
RiskAssessment.entity_type == "project",
RiskAssessment.entity_id == project_id,
).first()
if not item:
return ResponseModel(data=None)
return ResponseModel(data={
"risk_score": item.risk_score,
"sensitivity_score": item.sensitivity_score,
"exposure_score": item.exposure_score,
"protection_score": item.protection_score,
"details": item.details,
"updated_at": item.updated_at.isoformat() if item.updated_at else None,
})
+45
@@ -0,0 +1,45 @@
from typing import Optional
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.models.user import User
from app.schemas.common import ResponseModel, ListResponse
from app.models.schema_change import SchemaChangeLog
from app.api.deps import get_current_user
router = APIRouter()
@router.get("/logs")
def list_schema_changes(
source_id: Optional[int] = Query(None),
change_type: Optional[str] = Query(None),
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=500),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
query = db.query(SchemaChangeLog)
if source_id:
query = query.filter(SchemaChangeLog.source_id == source_id)
if change_type:
query = query.filter(SchemaChangeLog.change_type == change_type)
total = query.count()
items = query.order_by(SchemaChangeLog.detected_at.desc()).offset((page - 1) * page_size).limit(page_size).all()
return ListResponse(
data=[{
"id": log.id,
"source_id": log.source_id,
"database_id": log.database_id,
"table_id": log.table_id,
"column_id": log.column_id,
"change_type": log.change_type,
"old_value": log.old_value,
"new_value": log.new_value,
"detected_at": log.detected_at.isoformat() if log.detected_at else None,
} for log in items],
total=total,
page=page,
page_size=page_size,
)
+6 -4
@@ -5,7 +5,7 @@ from sqlalchemy.orm import Session
 from app.core.database import get_db
 from app.models.user import User
 from app.schemas.common import ResponseModel, ListResponse
-from app.api.deps import get_current_user
+from app.api.deps import get_current_user, require_manager, require_labeler, _is_admin
 from app.services import task_service, project_service

 router = APIRouter()
@@ -15,9 +15,11 @@ router = APIRouter()
 def my_tasks(
     status: Optional[str] = Query(None),
     db: Session = Depends(get_db),
-    current_user: User = Depends(get_current_user),
+    current_user: User = Depends(require_labeler),
 ):
-    items, _ = task_service.list_tasks(db, assignee_id=current_user.id, status=status)
+    # Data isolation: non-admin users only see tasks assigned to them
+    assignee_id = None if _is_admin(current_user) else current_user.id
+    items, _ = task_service.list_tasks(db, assignee_id=assignee_id, status=status)
     data = []
     for t in items:
         project = project_service.get_project(db, t.project_id)
@@ -100,7 +102,7 @@ def create_task_for_project(
     assignee_id: int,
     target_type: str = Query("column"),
     db: Session = Depends(get_db),
-    current_user: User = Depends(get_current_user),
+    current_user: User = Depends(require_manager),
 ):
     task = task_service.create_task(
         db, project_id=project_id, name=name,
+108
@@ -0,0 +1,108 @@
from typing import Optional
from fastapi import APIRouter, Depends, Query, UploadFile, File
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.models.user import User
from app.schemas.common import ResponseModel, ListResponse
from app.services import unstructured_service
from app.api.deps import get_current_user
from app.core.events import minio_client
from app.core.config import settings
from app.models.metadata import UnstructuredFile
router = APIRouter()
@router.post("/upload")
def upload_file(
file: UploadFile = File(...),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
# Determine file type
filename = file.filename or "unknown"
ext = filename.split(".")[-1].lower() if "." in filename else ""
type_map = {
"docx": "word", "doc": "word",
"xlsx": "excel", "xls": "excel",
"pdf": "pdf",
"txt": "txt",
}
file_type = type_map.get(ext, "unknown")
# Upload to MinIO
storage_path = f"unstructured/{current_user.id}/{filename}"
try:
data = file.file.read()
minio_client.put_object(
settings.MINIO_BUCKET_NAME,
storage_path,
data=data,
length=len(data),
content_type=file.content_type or "application/octet-stream",
)
except Exception as e:
return ResponseModel(message=f"上传失败: {e}")
db_obj = UnstructuredFile(
original_name=filename,
file_type=file_type,
file_size=len(data),
storage_path=storage_path,
status="pending",
created_by=current_user.id,
)
db.add(db_obj)
db.commit()
db.refresh(db_obj)
# Trigger processing
try:
result = unstructured_service.process_unstructured_file(db, db_obj.id)
return ResponseModel(data={"id": db_obj.id, "matches": result.get("matches", []), "status": "processed"})
except Exception as e:
return ResponseModel(data={"id": db_obj.id, "status": "error", "error": str(e)})
@router.get("/files")
def list_files(
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=500),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
query = db.query(UnstructuredFile).filter(UnstructuredFile.created_by == current_user.id)
total = query.count()
items = query.order_by(UnstructuredFile.created_at.desc()).offset((page - 1) * page_size).limit(page_size).all()
return ListResponse(
data=[{
"id": f.id,
"original_name": f.original_name,
"file_type": f.file_type,
"file_size": f.file_size,
"status": f.status,
"analysis_result": f.analysis_result,
"created_at": f.created_at.isoformat() if f.created_at else None,
} for f in items],
total=total,
page=page,
page_size=page_size,
)
@router.post("/files/{file_id}/reprocess")
def reprocess_file(
file_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
file_obj = db.query(UnstructuredFile).filter(
UnstructuredFile.id == file_id,
UnstructuredFile.created_by == current_user.id,
).first()
if not file_obj:
from fastapi import HTTPException, status
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="文件不存在")
result = unstructured_service.process_unstructured_file(db, file_id)
return ResponseModel(data={"matches": result.get("matches", []), "status": "processed"})
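Client side, the multipart field name must match the UploadFile parameter (file). A hedged sketch of the upload flow; the URL prefix, auth scheme, and outer response envelope shape are assumptions:

import requests

with open("contract.docx", "rb") as fh:
    resp = requests.post(
        "http://localhost:8000/api/v1/unstructured/upload",  # prefix assumed
        files={"file": ("contract.docx", fh)},
        headers={"Authorization": "Bearer <token>"},
    )
# Expected data payload on success: {"id": ..., "matches": [...], "status": "processed"}
print(resp.json())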
+25 -6
@@ -4,10 +4,11 @@ from sqlalchemy.orm import Session
 from app.core.database import get_db
 from app.models.user import User
-from app.schemas.user import UserCreate, UserUpdate, UserOut
+from app.models.user import Role, Dept
+from app.schemas.user import UserCreate, UserUpdate, UserOut, RoleOut, DeptOut
 from app.schemas.common import ResponseModel, ListResponse, PageParams
 from app.services import user_service
-from app.api.deps import get_current_user
+from app.api.deps import get_current_user, require_admin

 router = APIRouter()
@@ -23,7 +24,7 @@ def list_users(
     page_size: int = Query(20, ge=1, le=500),
     keyword: Optional[str] = Query(None),
     db: Session = Depends(get_db),
-    current_user: User = Depends(get_current_user),
+    current_user: User = Depends(require_admin),
 ):
     items, total = user_service.list_users(db, keyword=keyword, page=page, page_size=page_size)
     return ListResponse(data=[UserOut.model_validate(u) for u in items], total=total, page=page, page_size=page_size)
@@ -33,7 +34,7 @@ def list_users(
 def create_user(
     req: UserCreate,
     db: Session = Depends(get_db),
-    current_user: User = Depends(get_current_user),
+    current_user: User = Depends(require_admin),
 ):
     user = user_service.create_user(db, req)
     return ResponseModel(data=UserOut.model_validate(user))
@@ -44,7 +45,7 @@ def update_user(
     user_id: int,
     req: UserUpdate,
     db: Session = Depends(get_db),
-    current_user: User = Depends(get_current_user),
+    current_user: User = Depends(require_admin),
 ):
     user = user_service.get_user_by_id(db, user_id)
     if not user:
@@ -58,7 +59,25 @@ def update_user(
 def delete_user(
     user_id: int,
     db: Session = Depends(get_db),
-    current_user: User = Depends(get_current_user),
+    current_user: User = Depends(require_admin),
 ):
     user_service.delete_user(db, user_id)
     return ResponseModel(message="删除成功")

+@router.get("/roles", response_model=ResponseModel[list[RoleOut]])
+def list_roles(
+    db: Session = Depends(get_db),
+    current_user: User = Depends(get_current_user),
+):
+    items = db.query(Role).order_by(Role.id).all()
+    return ResponseModel(data=[RoleOut.model_validate(r) for r in items])

+@router.get("/depts", response_model=ResponseModel[list[DeptOut]])
+def list_depts(
+    db: Session = Depends(get_db),
+    current_user: User = Depends(get_current_user),
+):
+    items = db.query(Dept).order_by(Dept.sort_order).all()
+    return ResponseModel(data=[DeptOut.model_validate(d) for d in items])
+23
@@ -0,0 +1,23 @@
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.models.user import User
from app.schemas.common import ResponseModel
from app.services import watermark_service
from app.api.deps import get_current_user
router = APIRouter()
@router.post("/trace")
def trace_watermark(
req: dict,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
text = req.get("text", "")
result = watermark_service.trace_watermark(db, text)
if not result:
return ResponseModel(data=None, message="未检测到水印")
return ResponseModel(data=result)
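The endpoint takes a raw dict body with a single text key. An invocation sketch; the URL prefix is an assumption:

import requests

resp = requests.post(
    "http://localhost:8000/api/v1/watermark/trace",  # prefix assumed
    json={"text": "suspected leaked export content ..."},
    headers={"Authorization": "Bearer <token>"},
)
# data is null and message is "未检测到水印" when no watermark key matches.
print(resp.json())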
+8 -3
@@ -3,13 +3,18 @@ from typing import List
 class Settings(BaseSettings):
-    PROJECT_NAME: str = "PropDataGuard"
+    PROJECT_NAME: str = "DataPointer"
     VERSION: str = "0.1.0"
     DESCRIPTION: str = "财产保险行业数据分级分类管理平台"

     DATABASE_URL: str = "postgresql+psycopg2://pdg:pdg_secret_2024@localhost:5432/prop_data_guard"
     REDIS_URL: str = "redis://localhost:6379/0"

+    # Database password encryption key (Fernet-compatible base64, 32 bytes)
+    # If empty, will be derived from SECRET_KEY for backward compatibility.
+    # STRONGLY recommended to set this explicitly in production.
+    DB_ENCRYPTION_KEY: str = ""
+
     SECRET_KEY: str = "prop-data-guard-super-secret-key-change-in-production"
     ALGORITHM: str = "HS256"
     ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
@@ -24,8 +29,8 @@ class Settings(BaseSettings):
     CORS_ORIGINS: List[str] = ["http://localhost:5173", "http://127.0.0.1:5173"]

     FIRST_SUPERUSER_USERNAME: str = "admin"
-    FIRST_SUPERUSER_PASSWORD: str = "admin123"
-    FIRST_SUPERUSER_EMAIL: str = "admin@propdataguard.com"
+    FIRST_SUPERUSER_PASSWORD: str = "Zhidi@n2023"
+    FIRST_SUPERUSER_EMAIL: str = "admin@datapo.com"

     class Config:
         env_file = ".env"
+3 -12
@@ -39,19 +39,9 @@ async def log_requests(request: Request, call_next):
         return response

     from app.core.database import SessionLocal
+    db = None
     try:
         db = SessionLocal()
-        body_bytes = b""
-        if request.method in ["POST", "PUT", "PATCH"]:
-            try:
-                body_bytes = await request.body()
-                # Re-assign body for downstream
-                async def receive():
-                    return {"type": "http.request", "body": body_bytes}
-                request._receive = receive
-            except Exception:
-                pass
         log_entry = log_models.OperationLog(
             module=request.url.path.split("/")[2] if len(request.url.path.split("/")) > 2 else "",
             action=request.url.path,
@@ -66,7 +56,8 @@ async def log_requests(request: Request, call_next):
     except Exception:
         pass
     finally:
-        db.close()
+        if db:
+            db.close()
     return response
+15
@@ -2,6 +2,14 @@ from app.models.user import User, Role, Dept, UserRole
 from app.models.metadata import DataSource, Database, DataTable, DataColumn, UnstructuredFile
 from app.models.classification import Category, DataLevel, RecognitionRule, ClassificationTemplate
 from app.models.project import ClassificationProject, ClassificationTask, ClassificationResult, ClassificationChange
+from app.models.ml import MLModelVersion
+from app.models.masking import MaskingRule
+from app.models.watermark import WatermarkLog
+from app.models.schema_change import SchemaChangeLog
+from app.models.risk import RiskAssessment
+from app.models.compliance import ComplianceRule, ComplianceIssue
+from app.models.alert import AlertRule, AlertRecord, WorkOrder
+from app.models.api_asset import APIAsset, APIEndpoint
 from app.models.log import OperationLog

 __all__ = [
@@ -9,5 +17,12 @@ __all__ = [
     "DataSource", "Database", "DataTable", "DataColumn", "UnstructuredFile",
     "Category", "DataLevel", "RecognitionRule", "ClassificationTemplate",
     "ClassificationProject", "ClassificationTask", "ClassificationResult", "ClassificationChange",
+    "MLModelVersion",
+    "MaskingRule",
+    "WatermarkLog",
+    "SchemaChangeLog",
+    "RiskAssessment",
+    "ComplianceRule",
+    "ComplianceIssue",
     "OperationLog",
 ]
+46
@@ -0,0 +1,46 @@
from datetime import datetime
from sqlalchemy import Column, Integer, String, Text, DateTime, Boolean, ForeignKey, JSON
from sqlalchemy.orm import relationship
from app.core.database import Base
class AlertRule(Base):
__tablename__ = "alert_rule"
id = Column(Integer, primary_key=True, index=True)
name = Column(String(200), nullable=False)
trigger_condition = Column(String(50), nullable=False) # l5_count, risk_score, schema_change
threshold = Column(Integer, default=0)
severity = Column(String(20), default="medium")
is_active = Column(Boolean, default=True)
created_at = Column(DateTime, default=datetime.utcnow)
class AlertRecord(Base):
__tablename__ = "alert_record"
id = Column(Integer, primary_key=True, index=True)
rule_id = Column(Integer, ForeignKey("alert_rule.id"), nullable=False)
title = Column(String(200), nullable=False)
content = Column(Text)
severity = Column(String(20), default="medium")
status = Column(String(20), default="open") # open, acknowledged, resolved
created_at = Column(DateTime, default=datetime.utcnow)
rule = relationship("AlertRule")
class WorkOrder(Base):
__tablename__ = "work_order"
id = Column(Integer, primary_key=True, index=True)
alert_id = Column(Integer, ForeignKey("alert_record.id"), nullable=True)
title = Column(String(200), nullable=False)
description = Column(Text)
assignee_id = Column(Integer, ForeignKey("sys_user.id"), nullable=True)
status = Column(String(20), default="open") # open, in_progress, resolved
resolution = Column(Text)
created_at = Column(DateTime, default=datetime.utcnow)
resolved_at = Column(DateTime, nullable=True)
assignee = relationship("User")
+41
@@ -0,0 +1,41 @@
from sqlalchemy import Column, Integer, String, Text, DateTime, Boolean, JSON, ForeignKey, BigInteger
from sqlalchemy.orm import relationship
from app.core.database import Base
from datetime import datetime
class APIAsset(Base):
__tablename__ = "api_asset"
id = Column(Integer, primary_key=True, index=True)
name = Column(String(200), nullable=False)
base_url = Column(String(500), nullable=False)
swagger_url = Column(String(500), nullable=True)
auth_type = Column(String(50), default="none") # none, bearer, api_key, basic
headers = Column(JSON, default=dict)
description = Column(Text, nullable=True)
scan_status = Column(String(20), default="idle") # idle, scanning, completed, failed
total_endpoints = Column(Integer, default=0)
sensitive_endpoints = Column(Integer, default=0)
created_by = Column(Integer, ForeignKey("sys_user.id"), nullable=True)
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
endpoints = relationship("APIEndpoint", back_populates="asset", cascade="all, delete-orphan")
creator = relationship("User", foreign_keys=[created_by])
class APIEndpoint(Base):
__tablename__ = "api_endpoint"
id = Column(Integer, primary_key=True, index=True)
asset_id = Column(Integer, ForeignKey("api_asset.id"), nullable=False)
method = Column(String(10), nullable=False) # GET, POST, PUT, DELETE, etc.
path = Column(String(500), nullable=False)
summary = Column(String(500), nullable=True)
tags = Column(JSON, default=list)
parameters = Column(JSON, default=list)
request_body_schema = Column(JSON, nullable=True)
response_schema = Column(JSON, nullable=True)
sensitive_fields = Column(JSON, default=list) # detected PII fields
risk_level = Column(String(20), default="low") # low, medium, high, critical
is_active = Column(Boolean, default=True)
created_at = Column(DateTime, default=datetime.utcnow)
asset = relationship("APIAsset", back_populates="endpoints")
+33
@@ -0,0 +1,33 @@
from datetime import datetime
from sqlalchemy import Column, Integer, String, Text, DateTime, Boolean, JSON
from app.core.database import Base
class ComplianceRule(Base):
__tablename__ = "compliance_rule"
id = Column(Integer, primary_key=True, index=True)
name = Column(String(200), nullable=False)
standard = Column(String(50), nullable=False) # dengbao, pipl, gdpr
description = Column(Text)
check_logic = Column(String(50), nullable=False) # check_masking, check_encryption, check_audit, check_level
severity = Column(String(20), default="medium") # low, medium, high, critical
is_active = Column(Boolean, default=True)
created_at = Column(DateTime, default=datetime.utcnow)
class ComplianceIssue(Base):
__tablename__ = "compliance_issue"
id = Column(Integer, primary_key=True, index=True)
rule_id = Column(Integer, nullable=False)
project_id = Column(Integer, nullable=True)
entity_type = Column(String(20), nullable=False) # project, source, column
entity_id = Column(Integer, nullable=False)
entity_name = Column(String(200))
severity = Column(String(20), default="medium")
description = Column(Text)
suggestion = Column(Text)
status = Column(String(20), default="open") # open, resolved, ignored
resolved_at = Column(DateTime, nullable=True)
created_at = Column(DateTime, default=datetime.utcnow)
+16
@@ -0,0 +1,16 @@
from datetime import datetime
from sqlalchemy import Column, Integer, String, Text, DateTime
from app.core.database import Base
class DataLineage(Base):
__tablename__ = "data_lineage"
id = Column(Integer, primary_key=True, index=True)
source_table = Column(String(200), nullable=False)
source_column = Column(String(200), nullable=True)
target_table = Column(String(200), nullable=False)
target_column = Column(String(200), nullable=True)
relation_type = Column(String(20), default="direct") # direct, derived, lookup
script_content = Column(Text)
created_at = Column(DateTime, default=datetime.utcnow)
+22
@@ -0,0 +1,22 @@
from datetime import datetime
from sqlalchemy import Column, Integer, String, Boolean, DateTime, ForeignKey, JSON, Text
from sqlalchemy.orm import relationship
from app.core.database import Base
class MaskingRule(Base):
__tablename__ = "masking_rule"
id = Column(Integer, primary_key=True, index=True)
name = Column(String(100), nullable=False)
level_id = Column(Integer, ForeignKey("data_level.id"), nullable=True)
category_id = Column(Integer, ForeignKey("category.id"), nullable=True)
algorithm = Column(String(20), nullable=False) # mask, truncate, hash, generalize, replace
params = Column(JSON, default=dict) # algorithm-specific params
is_active = Column(Boolean, default=True)
description = Column(Text)
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
level = relationship("DataLevel")
category = relationship("Category")
+14 -1
@@ -1,5 +1,5 @@
 from datetime import datetime
-from sqlalchemy import Column, Integer, String, Boolean, DateTime, ForeignKey, Text, BigInteger
+from sqlalchemy import Column, Integer, String, Boolean, DateTime, ForeignKey, Text, BigInteger, JSON
 from sqlalchemy.orm import relationship
 from app.core.database import Base
@@ -36,6 +36,10 @@ class Database(Base):
     charset = Column(String(50))
     table_count = Column(Integer, default=0)
     created_at = Column(DateTime, default=datetime.utcnow)
+    last_scanned_at = Column(DateTime, nullable=True)
+    checksum = Column(String(64), nullable=True)
+    is_deleted = Column(Boolean, default=False)
+    deleted_at = Column(DateTime, nullable=True)

     source = relationship("DataSource", back_populates="databases")
     tables = relationship("DataTable", back_populates="database", cascade="all, delete-orphan")
@@ -51,6 +55,10 @@ class DataTable(Base):
     row_count = Column(BigInteger, default=0)
     column_count = Column(Integer, default=0)
     created_at = Column(DateTime, default=datetime.utcnow)
+    last_scanned_at = Column(DateTime, nullable=True)
+    checksum = Column(String(64), nullable=True)
+    is_deleted = Column(Boolean, default=False)
+    deleted_at = Column(DateTime, nullable=True)

     database = relationship("Database", back_populates="tables")
     columns = relationship("DataColumn", back_populates="table", cascade="all, delete-orphan")
@@ -68,6 +76,10 @@ class DataColumn(Base):
     is_nullable = Column(Boolean, default=True)
     sample_data = Column(Text)  # JSON array of sample values
     created_at = Column(DateTime, default=datetime.utcnow)
+    last_scanned_at = Column(DateTime, nullable=True)
+    checksum = Column(String(64), nullable=True)
+    is_deleted = Column(Boolean, default=False)
+    deleted_at = Column(DateTime, nullable=True)

     table = relationship("DataTable", back_populates="columns")
@@ -81,6 +93,7 @@ class UnstructuredFile(Base):
     file_size = Column(BigInteger)
     storage_path = Column(String(500))
     extracted_text = Column(Text)
+    analysis_result = Column(JSON, nullable=True)  # JSON: {matches: [{rule_name, category, level, snippet}]}
     status = Column(String(20), default="pending")  # pending, processed, error
     created_by = Column(Integer, ForeignKey("sys_user.id"))
     created_at = Column(DateTime, default=datetime.utcnow)
+18
@@ -0,0 +1,18 @@
from datetime import datetime
from sqlalchemy import Column, Integer, String, Float, DateTime, Boolean, Text
from app.core.database import Base
class MLModelVersion(Base):
__tablename__ = "ml_model_version"
id = Column(Integer, primary_key=True, index=True)
name = Column(String(100), nullable=False)
model_path = Column(String(500), nullable=False) # joblib dump path
vectorizer_path = Column(String(500), nullable=False) # tfidf vectorizer path
accuracy = Column(Float, default=0.0)
train_samples = Column(Integer, default=0)
train_date = Column(DateTime, default=datetime.utcnow)
is_active = Column(Boolean, default=False)
description = Column(Text)
created_at = Column(DateTime, default=datetime.utcnow)
+4
@@ -48,6 +48,10 @@ class ClassificationProject(Base):
     created_at = Column(DateTime, default=datetime.utcnow)
     updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
+    # Async classification tracking
+    celery_task_id = Column(String(100), nullable=True)
+    scan_progress = Column(Text, nullable=True)  # JSON: {"scanned": 0, "matched": 0, "total": 0}

     template = relationship("ClassificationTemplate")
     tasks = relationship("ClassificationTask", back_populates="project", cascade="all, delete-orphan")
     results = relationship("ClassificationResult", back_populates="project", cascade="all, delete-orphan")
+20
@@ -0,0 +1,20 @@
from datetime import datetime
from sqlalchemy import Column, Integer, String, Float, DateTime, ForeignKey, JSON
from sqlalchemy.orm import relationship
from app.core.database import Base
class RiskAssessment(Base):
__tablename__ = "risk_assessment"
id = Column(Integer, primary_key=True, index=True)
entity_type = Column(String(20), nullable=False) # project, source, table, field
entity_id = Column(Integer, nullable=False)
entity_name = Column(String(200))
risk_score = Column(Float, default=0.0) # 0-100
sensitivity_score = Column(Float, default=0.0)
exposure_score = Column(Float, default=0.0)
protection_score = Column(Float, default=0.0)
details = Column(JSON, default=dict)
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
+23
@@ -0,0 +1,23 @@
from datetime import datetime
from sqlalchemy import Column, Integer, String, Text, DateTime, ForeignKey
from sqlalchemy.orm import relationship
from app.core.database import Base
class SchemaChangeLog(Base):
__tablename__ = "schema_change_log"
id = Column(Integer, primary_key=True, index=True)
source_id = Column(Integer, ForeignKey("data_source.id"), nullable=False)
database_id = Column(Integer, ForeignKey("meta_database.id"), nullable=True)
table_id = Column(Integer, ForeignKey("meta_table.id"), nullable=True)
column_id = Column(Integer, ForeignKey("meta_column.id"), nullable=True)
change_type = Column(String(20), nullable=False) # add_table, drop_table, add_column, drop_column, change_type, change_comment
old_value = Column(Text)
new_value = Column(Text)
detected_at = Column(DateTime, default=datetime.utcnow)
source = relationship("DataSource")
database = relationship("Database")
table = relationship("DataTable")
column = relationship("DataColumn")
+17
@@ -0,0 +1,17 @@
from datetime import datetime
from sqlalchemy import Column, Integer, String, Text, DateTime, ForeignKey
from sqlalchemy.orm import relationship
from app.core.database import Base
class WatermarkLog(Base):
__tablename__ = "watermark_log"
id = Column(Integer, primary_key=True, index=True)
user_id = Column(Integer, ForeignKey("sys_user.id"), nullable=False)
export_type = Column(String(20), default="csv") # csv, excel, txt
data_scope = Column(Text) # JSON: {source_id, table_name, row_count}
watermark_key = Column(String(64), nullable=False) # random key for this export
created_at = Column(DateTime, default=datetime.utcnow)
user = relationship("User")
+92
@@ -0,0 +1,92 @@
from typing import List, Optional
from sqlalchemy.orm import Session
from datetime import datetime
from app.models.alert import AlertRule, AlertRecord, WorkOrder
from app.models.project import ClassificationProject, ClassificationResult
from app.models.risk import RiskAssessment
def init_builtin_alert_rules(db: Session):
if db.query(AlertRule).first():
return
rules = [
AlertRule(name="L5字段数量突增", trigger_condition="l5_count", threshold=5, severity="high"),
AlertRule(name="项目风险分过高", trigger_condition="risk_score", threshold=80, severity="critical"),
AlertRule(name="Schema新增敏感字段", trigger_condition="schema_change", threshold=1, severity="medium"),
]
for r in rules:
db.add(r)
db.commit()
def check_alerts(db: Session) -> List[AlertRecord]:
"""Run alert checks and create records."""
rules = db.query(AlertRule).filter(AlertRule.is_active == True).all()
records = []
for rule in rules:
if rule.trigger_condition == "l5_count":
projects = db.query(ClassificationProject).all()
for p in projects:
l5_count = db.query(ClassificationResult).filter(
ClassificationResult.project_id == p.id,
ClassificationResult.level_id.isnot(None),
).join(ClassificationResult.level).filter(
ClassificationResult.level.has(code="L5")
).count()
if l5_count >= rule.threshold:
rec = AlertRecord(
rule_id=rule.id,
title=f"项目 {p.name} L5字段数量达到 {l5_count}",
content=f"阈值: {rule.threshold}",
severity=rule.severity,
)
db.add(rec)
records.append(rec)
elif rule.trigger_condition == "risk_score":
risks = db.query(RiskAssessment).filter(
RiskAssessment.entity_type == "project",
RiskAssessment.risk_score >= rule.threshold,
).all()
for rsk in risks:
rec = AlertRecord(
rule_id=rule.id,
title=f"项目 {rsk.entity_name} 风险分 {rsk.risk_score}",
content=f"阈值: {rule.threshold}",
severity=rule.severity,
)
db.add(rec)
records.append(rec)
db.commit()
return records
def create_work_order(db: Session, alert_id: int, title: str, description: str, assignee_id: Optional[int] = None) -> WorkOrder:
wo = WorkOrder(
alert_id=alert_id,
title=title,
description=description,
assignee_id=assignee_id,
)
db.add(wo)
db.commit()
db.refresh(wo)
return wo
def update_work_order_status(db: Session, wo_id: int, status: str, resolution: str = None) -> WorkOrder:
wo = db.query(WorkOrder).filter(WorkOrder.id == wo_id).first()
if wo:
wo.status = status
if resolution:
wo.resolution = resolution
if status == "resolved":
wo.resolved_at = datetime.utcnow()
# Also resolve linked alert
if wo.alert_id:
alert = db.query(AlertRecord).filter(AlertRecord.id == wo.alert_id).first()
if alert:
alert.status = "resolved"
db.commit()
db.refresh(wo)
return wo
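No scheduler is wired up in this excerpt; a one-off invocation of the service (e.g. from a management shell or cron-driven script) would look roughly like:

from app.core.database import SessionLocal
from app.services.alert_service import init_builtin_alert_rules, check_alerts

db = SessionLocal()
init_builtin_alert_rules(db)    # seeds the three built-in rules on an empty table
new_records = check_alerts(db)  # evaluates active rules and persists AlertRecord rows
print(f"{len(new_records)} alert(s) raised")
db.close()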
+174
@@ -0,0 +1,174 @@
import requests, json
from typing import Optional
from sqlalchemy.orm import Session
from app.models.api_asset import APIAsset, APIEndpoint
from app.models.metadata import DataColumn
from app.services.classification_engine import match_rule
# Simple sensitive keywords for API field detection
SENSITIVE_KEYWORDS = [
"password", "pwd", "passwd", "secret", "token", "credit_card", "card_no",
"bank_account", "bank_card", "id_card", "id_number", "phone", "mobile",
"email", "address", "name", "age", "gender", "salary", "income",
"health", "medical", "biometric", "fingerprint", "face",
]
def _is_sensitive_field(name: str, schema: dict) -> tuple[bool, str]:
low = name.lower()
for kw in SENSITIVE_KEYWORDS:
if kw in low:
return True, f"keyword:{kw}"
# Check description / format hints
desc = str(schema.get("description", "")).lower()
fmt = str(schema.get("format", "")).lower()
if "email" in fmt or "email" in desc:
return True, "format:email"
if "uuid" in fmt and "user" in low:
return True, "format:user-uuid"
return False, ""
def _extract_fields(schema: dict, prefix: str = "") -> list[dict]:
fields = []
if not isinstance(schema, dict):
return fields
props = schema.get("properties", {})
for k, v in props.items():
full_name = f"{prefix}.{k}" if prefix else k
sensitive, reason = _is_sensitive_field(k, v)
if sensitive:
fields.append({"name": full_name, "type": v.get("type", "unknown"), "reason": reason})
# nested object
if v.get("type") == "object" and "properties" in v:
fields.extend(_extract_fields(v, full_name))
# array items
if v.get("type") == "array" and isinstance(v.get("items"), dict):
fields.extend(_extract_fields(v["items"], full_name + "[]"))
return fields
def _risk_level_from_fields(fields: list[dict]) -> str:
if not fields:
return "low"
high_keywords = {"password", "secret", "token", "credit_card", "bank_account", "biometric", "fingerprint", "face"}
for f in fields:
for kw in high_keywords:
if kw in f["name"].lower():
return "critical" if kw in {"password", "secret", "token", "biometric"} else "high"
return "medium"
def scan_swagger(db: Session, asset_id: int) -> dict:
asset = db.query(APIAsset).filter(APIAsset.id == asset_id).first()
if not asset:
return {"success": False, "error": "Asset not found"}
if not asset.swagger_url:
return {"success": False, "error": "No swagger_url configured"}
asset.scan_status = "scanning"
db.commit()
try:
headers = dict(asset.headers or {})
resp = requests.get(asset.swagger_url, headers=headers, timeout=30)
resp.raise_for_status()
spec = resp.json()
# Clear previous endpoints
db.query(APIEndpoint).filter(APIEndpoint.asset_id == asset_id).delete()
paths = spec.get("paths", {})
total = 0
sensitive_total = 0
for path, methods in paths.items():
for method, detail in methods.items():
if method.lower() not in {"get","post","put","patch","delete","head","options"}:
continue
total += 1
parameters = []
for p in detail.get("parameters", []):
parameters.append({"name": p.get("name"), "in": p.get("in"), "required": p.get("required", False), "type": p.get("schema",{}).get("type","string")})
req_schema = detail.get("requestBody", {}).get("content", {}).get("application/json", {}).get("schema")
resp_schema = None
for code, resp_detail in (detail.get("responses", {}).get("200", {}).get("content", {}) or {}).items():
if isinstance(resp_detail, dict) and "schema" in resp_detail:
resp_schema = resp_detail["schema"]
break
# Also try generic 200
if resp_schema is None:
ok = detail.get("responses", {}).get("200", {})
for ct, cd in ok.get("content", {}).items():
if isinstance(cd, dict) and "schema" in cd:
resp_schema = cd["schema"]
break
fields = []
if req_schema:
fields.extend(_extract_fields(req_schema))
if resp_schema:
fields.extend(_extract_fields(resp_schema))
# dedup
seen = set()
unique_fields = []
for f in fields:
if f["name"] not in seen:
seen.add(f["name"])
unique_fields.append(f)
risk = _risk_level_from_fields(unique_fields)
ep = APIEndpoint(
asset_id=asset_id,
method=method.upper(),
path=path,
summary=detail.get("summary", ""),
tags=detail.get("tags", []),
parameters=parameters,
request_body_schema=req_schema,
response_schema=resp_schema,
sensitive_fields=unique_fields,
risk_level=risk,
)
db.add(ep)
if unique_fields:
sensitive_total += 1
asset.scan_status = "completed"
asset.total_endpoints = total
asset.sensitive_endpoints = sensitive_total
asset.updated_at = __import__('datetime').datetime.utcnow()
db.commit()
return {"success": True, "total": total, "sensitive": sensitive_total}
except Exception as e:
asset.scan_status = "failed"
db.commit()
return {"success": False, "error": str(e)}
def create_asset(db: Session, data: dict, user_id: Optional[int] = None) -> APIAsset:
asset = APIAsset(
name=data["name"],
base_url=data["base_url"],
swagger_url=data.get("swagger_url"),
auth_type=data.get("auth_type", "none"),
headers=data.get("headers"),
description=data.get("description"),
created_by=user_id,
)
db.add(asset)
db.commit()
db.refresh(asset)
return asset
def update_asset(db: Session, asset_id: int, data: dict) -> Optional[APIAsset]:
asset = db.query(APIAsset).filter(APIAsset.id == asset_id).first()
if not asset:
return None
for k, v in data.items():
if hasattr(asset, k):
setattr(asset, k, v)
db.commit()
db.refresh(asset)
return asset
def delete_asset(db: Session, asset_id: int) -> bool:
asset = db.query(APIAsset).filter(APIAsset.id == asset_id).first()
if not asset:
return False
db.delete(asset)
db.commit()
return True
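To make the keyword heuristic concrete, a worked example of _extract_fields on a small OpenAPI-style schema (invented data). Note that username is flagged too, because the substring match on "name" trades precision for recall:

schema = {
    "properties": {
        "username": {"type": "string"},
        "password": {"type": "string"},
        "contact": {
            "type": "object",
            "properties": {"phone": {"type": "string"}},
        },
    }
}
# _extract_fields(schema) returns:
# [{"name": "username", "type": "string", "reason": "keyword:name"},
#  {"name": "password", "type": "string", "reason": "keyword:password"},
#  {"name": "contact.phone", "type": "string", "reason": "keyword:phone"}]
# _risk_level_from_fields(...) on that list returns "critical" (the "password" hit).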
+44 -5
@@ -51,11 +51,39 @@ def match_rule(rule: RecognitionRule, column: DataColumn) -> Tuple[bool, float]:
             if t.strip().lower() in enums:
                 return True, 0.90
+    elif rule.rule_type == "similarity":
+        benchmarks = [b.strip().lower() for b in rule.rule_content.split(",") if b.strip()]
+        if not benchmarks:
+            return False, 0.0
+        from sklearn.feature_extraction.text import TfidfVectorizer
+        from sklearn.metrics.pairwise import cosine_similarity
+        texts = [t.lower() for t in targets] + benchmarks
+        try:
+            vectorizer = TfidfVectorizer(analyzer="char_wb", ngram_range=(2, 3))
+            tfidf = vectorizer.fit_transform(texts)
+            target_vecs = tfidf[:len(targets)]
+            bench_vecs = tfidf[len(targets):]
+            sim_matrix = cosine_similarity(target_vecs, bench_vecs)
+            max_sim = float(sim_matrix.max())
+            if max_sim >= 0.75:
+                return True, round(min(max_sim, 0.99), 4)
+        except Exception:
+            pass
     return False, 0.0

-def run_auto_classification(db: Session, project_id: int, source_ids: Optional[List[int]] = None) -> dict:
-    """Run automatic classification for a project."""
+def run_auto_classification(
+    db: Session,
+    project_id: int,
+    source_ids: Optional[List[int]] = None,
+    progress_callback=None,
+) -> dict:
+    """Run automatic classification for a project.
+
+    Args:
+        progress_callback: Optional callable(scanned, matched, total) to report progress.
+    """
     project = db.query(ClassificationProject).filter(ClassificationProject.id == project_id).first()
     if not project:
         return {"success": False, "message": "项目不存在"}
@@ -82,7 +110,10 @@ def run_auto_classification(db: Session, project_id: int, source_ids: Optional[L
     columns = columns_query.all()
     matched_count = 0
-    for col in columns:
+    total = len(columns)
+    report_interval = max(1, total // 20)  # report ~20 times
+    for idx, col in enumerate(columns):
         # Check if already has a result for this project
         existing = db.query(ClassificationResult).filter(
             ClassificationResult.project_id == project_id,
@@ -121,12 +152,20 @@ def run_auto_classification(db: Session, project_id: int, source_ids: Optional[L
         # Increment hit count
         best_rule.hit_count = (best_rule.hit_count or 0) + 1

+        # Report progress periodically
+        if progress_callback and (idx + 1) % report_interval == 0:
+            progress_callback(scanned=idx + 1, matched=matched_count, total=total)
+
     db.commit()
+
+    # Final progress report
+    if progress_callback:
+        progress_callback(scanned=total, matched=matched_count, total=total)
+
     return {
         "success": True,
-        "message": f"自动分类完成,共扫描 {len(columns)} 个字段,命中 {matched_count}",
-        "scanned": len(columns),
+        "message": f"自动分类完成,共扫描 {total} 个字段,命中 {matched_count}",
+        "scanned": total,
         "matched": matched_count,
     }
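The progress_callback is presumably wired up inside auto_classify_task (app/tasks/classification_tasks.py, which this range does not show). A hedged sketch of that wiring; the celery_app module path is an assumption, and update_state's PROGRESS meta is what the status endpoint reads back through AsyncResult.info:

# Hypothetical sketch -- the real task module is not part of this excerpt.
import json
from app.core.celery_app import celery_app   # assumed module path
from app.core.database import SessionLocal
from app.services.classification_engine import run_auto_classification

@celery_app.task(bind=True)
def auto_classify_task(self, project_id: int):
    db = SessionLocal()
    try:
        def report(scanned, matched, total):
            meta = {"scanned": scanned, "matched": matched, "total": total}
            self.update_state(state="PROGRESS", meta=meta)
            # Persisting json.dumps(meta) to ClassificationProject.scan_progress
            # would let the status endpoint's fallback survive worker restarts.
        return run_auto_classification(db, project_id, progress_callback=report)
    finally:
        db.close()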
+140
@@ -0,0 +1,140 @@
from datetime import datetime
from typing import List, Optional, Set, Tuple
from sqlalchemy.orm import Session, joinedload
from app.models.compliance import ComplianceRule, ComplianceIssue
from app.models.project import ClassificationResult
from app.models.masking import MaskingRule
def init_builtin_rules(db: Session):
"""Initialize built-in compliance rules."""
if db.query(ComplianceRule).first():
return
rules = [
ComplianceRule(name="L4/L5字段未配置脱敏", standard="dengbao", description="等保2.0要求:四级及以上数据应进行脱敏处理", check_logic="check_masking", severity="high"),
ComplianceRule(name="L5字段缺乏加密存储措施", standard="dengbao", description="等保2.0要求:五级数据应加密存储", check_logic="check_encryption", severity="critical"),
ComplianceRule(name="个人敏感信息处理未授权", standard="pipl", description="个人信息保护法:处理敏感个人信息应取得单独同意", check_logic="check_level", severity="high"),
ComplianceRule(name="数据跨境传输未评估", standard="gdpr", description="GDPR:个人数据跨境传输需进行影响评估", check_logic="check_audit", severity="medium"),
]
for r in rules:
db.add(r)
db.commit()
def scan_compliance(db: Session, project_id: Optional[int] = None) -> List[ComplianceIssue]:
"""Run compliance scan and generate issues."""
rules = db.query(ComplianceRule).filter(ComplianceRule.is_active == True).all()
if not rules:
return []
# Get masking rules for check_masking logic
masking_rules = db.query(MaskingRule).filter(MaskingRule.is_active == True).all()
masking_level_ids = {r.level_id for r in masking_rules if r.level_id}
# Build result filter and determine project ids
result_filter = [ClassificationResult.level_id.isnot(None)]
project_ids: List[int] = []
if project_id:
result_filter.append(ClassificationResult.project_id == project_id)
project_ids = [project_id]
else:
project_ids = [
r[0] for r in db.query(ClassificationResult.project_id).distinct().all()
]
if project_ids:
result_filter.append(ClassificationResult.project_id.in_(project_ids))
else:
return []
# Pre-load all results with level and column to avoid N+1 queries
results = db.query(ClassificationResult).options(
joinedload(ClassificationResult.level),
joinedload(ClassificationResult.column),
).filter(*result_filter).all()
if not results:
return []
# Batch query existing open issues
existing_issues = db.query(ComplianceIssue).filter(
ComplianceIssue.project_id.in_(project_ids),
ComplianceIssue.status == "open",
).all()
existing_set: Set[Tuple[int, int, str, int]] = {
(i.rule_id, i.project_id, i.entity_type, i.entity_id) for i in existing_issues
}
issues = []
for r in results:
if not r.level:
continue
level_code = r.level.code
for rule in rules:
matched = False
desc = ""
suggestion = ""
if rule.check_logic == "check_masking" and level_code in ("L4", "L5"):
if r.level_id not in masking_level_ids:
matched = True
desc = f"字段 '{r.column.name if r.column else '未知'}'{level_code} 级,但未配置脱敏规则"
suggestion = "请在【数据脱敏】模块为该分级配置脱敏策略"
elif rule.check_logic == "check_encryption" and level_code == "L5":
# Placeholder: no encryption check in MVP, always flag
matched = True
desc = f"字段 '{r.column.name if r.column else '未知'}' 为 L5 级核心数据,建议确认是否加密存储"
suggestion = "请确认该字段在数据库中已加密存储"
elif rule.check_logic == "check_level" and level_code in ("L4", "L5"):
if r.source == "auto":
matched = True
desc = f"个人敏感字段 '{r.column.name if r.column else '未知'}' 目前为自动识别,建议人工复核并确认授权"
suggestion = "请人工确认该字段的处理已取得合法授权"
elif rule.check_logic == "check_audit":
# Placeholder for cross-border check
pass
if matched:
key = (rule.id, r.project_id, "column", r.column_id or 0)
if key not in existing_set:
issue = ComplianceIssue(
rule_id=rule.id,
project_id=r.project_id,
entity_type="column",
entity_id=r.column_id or 0,
entity_name=r.column.name if r.column else "未知",
severity=rule.severity,
description=desc,
suggestion=suggestion,
)
db.add(issue)
issues.append(issue)
existing_set.add(key)
if issues:
db.commit()
return issues
def list_issues(db: Session, project_id: Optional[int] = None, status: Optional[str] = None, page: int = 1, page_size: int = 20):
query = db.query(ComplianceIssue)
if project_id:
query = query.filter(ComplianceIssue.project_id == project_id)
if status:
query = query.filter(ComplianceIssue.status == status)
total = query.count()
items = query.order_by(ComplianceIssue.created_at.desc()).offset((page - 1) * page_size).limit(page_size).all()
return items, total
def resolve_issue(db: Session, issue_id: int):
issue = db.query(ComplianceIssue).filter(ComplianceIssue.id == issue_id).first()
if issue:
issue.status = "resolved"
issue.resolved_at = datetime.utcnow()
db.commit()
return issue
+25 -3
@@ -1,3 +1,6 @@
+import base64
+import hashlib
+import logging
 from typing import Optional, List, Tuple
 from sqlalchemy.orm import Session
 from fastapi import HTTPException, status
@@ -7,9 +10,28 @@ from app.models.metadata import DataSource
 from app.schemas.datasource import DataSourceCreate, DataSourceUpdate, DataSourceTest
 from app.core.config import settings

-# Simple AES-like symmetric encryption for DB passwords
-# In production, use a proper KMS
-_fernet = Fernet(Fernet.generate_key())
+logger = logging.getLogger(__name__)
+
+
+def _get_fernet() -> Fernet:
+    """Initialize Fernet with a stable key.
+
+    If DB_ENCRYPTION_KEY is set, use it directly.
+    Otherwise derive deterministically from SECRET_KEY for backward compatibility.
+    """
+    if settings.DB_ENCRYPTION_KEY:
+        key = settings.DB_ENCRYPTION_KEY.encode()
+    else:
+        logger.warning(
+            "DB_ENCRYPTION_KEY is not set. Deriving encryption key from SECRET_KEY. "
+            "Please set DB_ENCRYPTION_KEY explicitly via environment variable or .env file."
+        )
+        digest = hashlib.sha256(settings.SECRET_KEY.encode()).digest()
+        key = base64.urlsafe_b64encode(digest)
+    return Fernet(key)
+
+
+_fernet = _get_fernet()

 def _encrypt_password(password: str) -> str:
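The point of the derivation fallback is that the key is now stable across process restarts; the old Fernet.generate_key() call minted a fresh key on every startup, leaving previously stored passwords undecryptable. A quick round-trip check of the derived-key path (mirrors _get_fernet's else-branch):

import base64, hashlib
from cryptography.fernet import Fernet

secret_key = "prop-data-guard-super-secret-key-change-in-production"
key = base64.urlsafe_b64encode(hashlib.sha256(secret_key.encode()).digest())
f = Fernet(key)
token = f.encrypt(b"db-password")
assert f.decrypt(token) == b"db-password"  # same SECRET_KEY -> same key -> decryptable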
+65
@@ -0,0 +1,65 @@
import re
from typing import List, Optional
from sqlalchemy.orm import Session
from app.models.lineage import DataLineage
def _extract_tables(sql: str) -> List[str]:
"""Extract table names from SQL using regex (simple heuristic)."""
# Normalize SQL
sql = re.sub(r"--.*?\n", " ", sql)
sql = re.sub(r"/\*.*?\*/", " ", sql, flags=re.DOTALL)
sql = sql.lower()
tables = set()
# FROM / JOIN / INTO
for pattern in [r"\bfrom\s+([a-z_][a-z0-9_]*)", r"\bjoin\s+([a-z_][a-z0-9_]*)"]:
for m in re.finditer(pattern, sql):
tables.add(m.group(1))
return sorted(tables)
def parse_sql_lineage(db: Session, sql: str, target_table: str) -> List[DataLineage]:
"""Parse SQL and create lineage records pointing to target_table."""
source_tables = _extract_tables(sql)
records = []
for st in source_tables:
if st == target_table:
continue
existing = db.query(DataLineage).filter(
DataLineage.source_table == st,
DataLineage.target_table == target_table,
).first()
if not existing:
rec = DataLineage(
source_table=st,
target_table=target_table,
relation_type="direct",
script_content=sql[:2000],
)
db.add(rec)
records.append(rec)
db.commit()
return records
def get_lineage_graph(db: Session, table_name: Optional[str] = None) -> dict:
"""Build graph data for ECharts."""
query = db.query(DataLineage)
if table_name:
query = query.filter(
(DataLineage.source_table == table_name) | (DataLineage.target_table == table_name)
)
items = query.limit(500).all()
nodes = {}
links = []
for item in items:
nodes[item.source_table] = {"name": item.source_table, "category": 0}
nodes[item.target_table] = {"name": item.target_table, "category": 1}
links.append({"source": item.source_table, "target": item.target_table, "value": item.relation_type})
return {
"nodes": list(nodes.values()),
"links": links,
}
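Note that despite the "# FROM / JOIN / INTO" comment, only FROM and JOIN are actually matched; the write side is supplied explicitly via target_table. An example of the regex heuristic on a simple statement:

sql = """
INSERT INTO dw_policy_fact
SELECT a.policy_no, b.customer_name
FROM ods_policy a
JOIN ods_customer b ON a.customer_id = b.id  -- line comment is stripped first
"""
# _extract_tables(sql) -> ["ods_customer", "ods_policy"]
# parse_sql_lineage(db, sql, target_table="dw_policy_fact") then creates two rows:
#   ods_policy -> dw_policy_fact and ods_customer -> dw_policy_fact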
+195
@@ -0,0 +1,195 @@
import hashlib
from typing import Optional, Dict
from sqlalchemy.orm import Session
from fastapi import HTTPException, status
from app.models.metadata import DataSource, Database, DataTable, DataColumn
from app.models.project import ClassificationResult
from app.models.masking import MaskingRule
from app.services.datasource_service import get_datasource, _decrypt_password
def get_masking_rule(db: Session, rule_id: int):
return db.query(MaskingRule).filter(MaskingRule.id == rule_id).first()
def list_masking_rules(db: Session, level_id=None, category_id=None, page=1, page_size=20):
query = db.query(MaskingRule).filter(MaskingRule.is_active == True)
if level_id:
query = query.filter(MaskingRule.level_id == level_id)
if category_id:
query = query.filter(MaskingRule.category_id == category_id)
total = query.count()
items = query.offset((page - 1) * page_size).limit(page_size).all()
return items, total
def create_masking_rule(db: Session, data: dict):
db_obj = MaskingRule(**data)
db.add(db_obj)
db.commit()
db.refresh(db_obj)
return db_obj
def update_masking_rule(db: Session, db_obj: MaskingRule, data: dict):
for k, v in data.items():
if v is not None:
setattr(db_obj, k, v)
db.commit()
db.refresh(db_obj)
return db_obj
def delete_masking_rule(db: Session, rule_id: int):
db_obj = get_masking_rule(db, rule_id)
if not db_obj:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="规则不存在")
db.delete(db_obj)
db.commit()
def _apply_mask(value, params):
if not value:
return value
keep_prefix = params.get("keep_prefix", 3)
keep_suffix = params.get("keep_suffix", 4)
mask_char = params.get("mask_char", "*")
if len(value) <= keep_prefix + keep_suffix:
return mask_char * len(value)
return value[:keep_prefix] + mask_char * (len(value) - keep_prefix - keep_suffix) + value[-keep_suffix:]
def _apply_truncate(value, params):
length = params.get("length", 3)
suffix = params.get("suffix", "...")
if not value or len(value) <= length:
return value
return value[:length] + suffix
def _apply_hash(value, params):
algorithm = params.get("algorithm", "sha256")
if algorithm == "md5":
return hashlib.md5(str(value).encode()).hexdigest()[:16]
return hashlib.sha256(str(value).encode()).hexdigest()[:32]
def _apply_generalize(value, params):
try:
step = params.get("step", 10)
num = float(value)
lower = int(num // step * step)
upper = lower + step
return f"{lower}-{upper}"
except Exception:
return value
def _apply_replace(value, params):
return params.get("replacement", "[REDACTED]")
def apply_masking(value, algorithm, params):
if value is None:
return None
handlers = {
"mask": _apply_mask,
"truncate": _apply_truncate,
"hash": _apply_hash,
"generalize": _apply_generalize,
"replace": _apply_replace,
}
handler = handlers.get(algorithm)
if not handler:
return value
return handler(str(value), params or {})
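Illustrative outputs of apply_masking, using assumed inputs and the parameter defaults above:

apply_masking("13812345678", "mask", {"keep_prefix": 3, "keep_suffix": 4})  # '138****5678'
apply_masking("张三丰", "truncate", {"length": 1, "suffix": "**"})  # '张**'
apply_masking(6230.5, "generalize", {"step": 100})  # '6200-6300'
apply_masking("whatever", "replace", {"replacement": "[REDACTED]"})  # '[REDACTED]'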
def _get_column_rules(db: Session, table_id: int, project_id=None):
columns = db.query(DataColumn).filter(DataColumn.table_id == table_id).all()
col_rules = {}
results = {}
if project_id:
res_list = db.query(ClassificationResult).filter(
ClassificationResult.project_id == project_id,
ClassificationResult.column_id.in_([c.id for c in columns]),
).all()
results = {r.column_id: r for r in res_list}
rules = db.query(MaskingRule).filter(MaskingRule.is_active == True).all()
rule_map = {}
for r in rules:
key = (r.level_id, r.category_id)
if key not in rule_map:
rule_map[key] = r
for col in columns:
matched_rule = None
if col.id in results:
r = results[col.id]
matched_rule = rule_map.get((r.level_id, r.category_id))
if not matched_rule:
matched_rule = rule_map.get((r.level_id, None))
if not matched_rule:
matched_rule = rule_map.get((None, r.category_id))
col_rules[col.id] = matched_rule
return col_rules
def preview_masking(db: Session, source_id: int, table_name: str, project_id=None, limit=20):
source = get_datasource(db, source_id)
if not source:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="数据源不存在")
table = (
db.query(DataTable)
.join(Database)
.filter(Database.source_id == source_id, DataTable.name == table_name)
.first()
)
if not table:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="表不存在")
col_rules = _get_column_rules(db, table.id, project_id)
from sqlalchemy import create_engine, text
password = ""
if source.encrypted_password:
try:
password = _decrypt_password(source.encrypted_password)
except Exception:
pass
driver_map = {
"mysql": "mysql+pymysql",
"postgresql": "postgresql+psycopg2",
"oracle": "oracle+cx_oracle",
"sqlserver": "mssql+pymssql",
}
driver = driver_map.get(source.source_type, source.source_type)
url = f"{driver}://{source.username}:{password}@{source.host}:{source.port}/{source.database_name}"
engine = create_engine(url, pool_pre_ping=True)
columns = db.query(DataColumn).filter(DataColumn.table_id == table.id).all()
rows_raw = []
try:
with engine.connect() as conn:
result = conn.execute(text(f'SELECT * FROM "{table_name}" LIMIT {limit}'))
rows_raw = [dict(row._mapping) for row in result]
except Exception:
try:
with engine.connect() as conn:
result = conn.execute(text(f"SELECT * FROM {table_name} LIMIT {limit}"))
rows_raw = [dict(row._mapping) for row in result]
except Exception as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"查询失败: {e}")
masked_rows = []
for raw in rows_raw:
masked = {}
for col in columns:
val = raw.get(col.name)
rule = col_rules.get(col.id)
if rule:
masked[col.name] = apply_masking(val, rule.algorithm, rule.params or {})
else:
masked[col.name] = val
masked_rows.append(masked)
return {
"success": True,
"columns": [{"name": c.name, "data_type": c.data_type, "has_rule": col_rules.get(c.id) is not None} for c in columns],
"rows": masked_rows,
"total_rows": len(masked_rows),
}
+115 -17
@@ -3,9 +3,23 @@ from sqlalchemy.orm import Session
from fastapi import HTTPException, status
from app.models.metadata import DataSource, Database, DataTable, DataColumn
from app.models.schema_change import SchemaChangeLog
from app.services.datasource_service import get_datasource, _decrypt_password
def _log_schema_change(db: Session, source_id: int, change_type: str, database_id: int = None, table_id: int = None, column_id: int = None, old_value: str = None, new_value: str = None):
log = SchemaChangeLog(
source_id=source_id,
database_id=database_id,
table_id=table_id,
column_id=column_id,
change_type=change_type,
old_value=old_value,
new_value=new_value,
)
db.add(log)
def get_database(db: Session, db_id: int) -> Optional[Database]:
return db.query(Database).filter(Database.id == db_id).first()
@@ -19,14 +33,14 @@ def get_column(db: Session, column_id: int) -> Optional[DataColumn]:
def list_databases(db: Session, source_id: Optional[int] = None) -> List[Database]:
query = db.query(Database).filter(Database.is_deleted == False)
if source_id:
query = query.filter(Database.source_id == source_id)
return query.all()
def list_tables(db: Session, database_id: Optional[int] = None, keyword: Optional[str] = None) -> Tuple[List[DataTable], int]:
query = db.query(DataTable).filter(DataTable.is_deleted == False)
if database_id:
query = query.filter(DataTable.database_id == database_id)
if keyword:
@@ -37,7 +51,7 @@ def list_tables(db: Session, database_id: Optional[int] = None, keyword: Optiona
def list_columns(db: Session, table_id: Optional[int] = None, keyword: Optional[str] = None, page: int = 1, page_size: int = 50) -> Tuple[List[DataColumn], int]:
query = db.query(DataColumn).filter(DataColumn.is_deleted == False)
if table_id:
query = query.filter(DataColumn.table_id == table_id)
if keyword:
@@ -49,7 +63,7 @@ def list_columns(db: Session, table_id: Optional[int] = None, keyword: Optional[
return items, total
def build_tree(db: Session, source_id: Optional[int] = None, include_deleted: bool = False) -> List[dict]:
sources = db.query(DataSource)
if source_id:
sources = sources.filter(DataSource.id == source_id)
@@ -65,20 +79,24 @@ def build_tree(db: Session, source_id: Optional[int] = None) -> List[dict]:
"meta": {"source_type": s.source_type, "status": s.status},
}
for d in s.databases:
if not include_deleted and d.is_deleted:
continue
db_node = {
"id": d.id,
"name": d.name,
"type": "database",
"children": [],
"meta": {"charset": d.charset, "table_count": d.table_count, "is_deleted": d.is_deleted},
}
for t in d.tables:
if not include_deleted and t.is_deleted:
continue
table_node = {
"id": t.id,
"name": t.name,
"type": "table",
"children": [],
"meta": {"comment": t.comment, "row_count": t.row_count, "column_count": t.column_count, "is_deleted": t.is_deleted},
}
db_node["children"].append(table_node)
source_node["children"].append(db_node)
@@ -86,9 +104,16 @@ def build_tree(db: Session, source_id: Optional[int] = None) -> List[dict]:
return result
def _compute_checksum(data: dict) -> str:
import hashlib, json
payload = json.dumps(data, sort_keys=True, ensure_ascii=False, default=str)
return hashlib.sha256(payload.encode()).hexdigest()[:32]
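Because the payload is serialized with sort_keys=True, the checksum is key-order independent, e.g. (illustrative):

_compute_checksum({"name": "users", "comment": None}) == _compute_checksum({"comment": None, "name": "users"})  # True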
def sync_metadata(db: Session, source_id: int, user_id: int) -> dict:
from sqlalchemy import create_engine, inspect, text
import json
from datetime import datetime
source = get_datasource(db, source_id)
if not source:
@@ -118,29 +143,56 @@ def sync_metadata(db: Session, source_id: int, user_id: int) -> dict:
inspector = inspect(engine)
db_names = inspector.get_schema_names() or [source.database_name]
scan_time = datetime.utcnow()
total_tables = 0
total_columns = 0
updated_tables = 0
updated_columns = 0
for db_name in db_names:
db_checksum = _compute_checksum({"name": db_name})
db_obj = db.query(Database).filter(
Database.source_id == source.id, Database.name == db_name
).first()
if not db_obj:
db_obj = Database(source_id=source.id, name=db_name, checksum=db_checksum, last_scanned_at=scan_time)
db.add(db_obj)
db.flush()  # make sure db_obj.id is assigned for the table lookups below
else:
db_obj.checksum = db_checksum
db_obj.last_scanned_at = scan_time
db_obj.is_deleted = False
db_obj.deleted_at = None
table_names = inspector.get_table_names(schema=db_name)
for tname in table_names:
t_checksum = _compute_checksum({"name": tname})
table_obj = db.query(DataTable).filter(
DataTable.database_id == db_obj.id, DataTable.name == tname
).first()
if not table_obj:
table_obj = DataTable(database_id=db_obj.id, name=tname, checksum=t_checksum, last_scanned_at=scan_time)
db.add(table_obj)
db.flush()  # make sure table_obj.id is assigned before it is logged
_log_schema_change(db, source.id, "add_table", database_id=db_obj.id, table_id=table_obj.id, new_value=tname)
else:
if table_obj.checksum != t_checksum:
table_obj.checksum = t_checksum
updated_tables += 1
table_obj.last_scanned_at = scan_time
table_obj.is_deleted = False
table_obj.deleted_at = None
columns = inspector.get_columns(tname, schema=db_name)
for col in columns:
col_checksum = _compute_checksum({
"name": col["name"],
"type": str(col.get("type", "")),
"max_length": col.get("max_length"),
"comment": col.get("comment"),
"nullable": col.get("nullable", True),
})
col_obj = db.query(DataColumn).filter(
DataColumn.table_id == table_obj.id, DataColumn.name == col["name"]
).first()
if not col_obj:
sample = None
try:
@@ -150,7 +202,6 @@ def sync_metadata(db: Session, source_id: int, user_id: int) -> dict:
sample = json.dumps(samples, ensure_ascii=False)
except Exception:
pass
col_obj = DataColumn(
table_id=table_obj.id,
name=col["name"],
@@ -159,13 +210,58 @@ def sync_metadata(db: Session, source_id: int, user_id: int) -> dict:
comment=col.get("comment"),
is_nullable=col.get("nullable", True),
sample_data=sample,
checksum=col_checksum,
last_scanned_at=scan_time,
)
db.add(col_obj)
db.flush()  # make sure col_obj.id is assigned before it is logged
total_columns += 1
_log_schema_change(db, source.id, "add_column", database_id=db_obj.id, table_id=table_obj.id, column_id=col_obj.id, new_value=col["name"])
else:
if col_obj.checksum != col_checksum:
old_val = f"type={col_obj.data_type}, len={col_obj.length}, comment={col_obj.comment}"
new_val = f"type={str(col.get('type', ''))}, len={col.get('max_length')}, comment={col.get('comment')}"
_log_schema_change(db, source.id, "change_type", database_id=db_obj.id, table_id=table_obj.id, column_id=col_obj.id, old_value=old_val, new_value=new_val)
col_obj.checksum = col_checksum
col_obj.data_type = str(col.get("type", ""))
col_obj.length = col.get("max_length")
col_obj.comment = col.get("comment")
col_obj.is_nullable = col.get("nullable", True)
updated_columns += 1
col_obj.last_scanned_at = scan_time
col_obj.is_deleted = False
col_obj.deleted_at = None
total_tables += 1
# Soft-delete objects not seen in this scan and log changes
deleted_dbs = db.query(Database).filter(
Database.source_id == source.id,
Database.last_scanned_at < scan_time,
).all()
for d in deleted_dbs:
_log_schema_change(db, source.id, "drop_database", database_id=d.id, old_value=d.name)
d.is_deleted = True
d.deleted_at = scan_time
for db_obj in db.query(Database).filter(Database.source_id == source.id).all():
deleted_tables = db.query(DataTable).filter(
DataTable.database_id == db_obj.id,
DataTable.last_scanned_at < scan_time,
).all()
for t in deleted_tables:
_log_schema_change(db, source.id, "drop_table", database_id=db_obj.id, table_id=t.id, old_value=t.name)
t.is_deleted = True
t.deleted_at = scan_time
for table_obj in db.query(DataTable).filter(DataTable.database_id == db_obj.id).all():
deleted_cols = db.query(DataColumn).filter(
DataColumn.table_id == table_obj.id,
DataColumn.last_scanned_at < scan_time,
).all()
for c in deleted_cols:
_log_schema_change(db, source.id, "drop_column", database_id=db_obj.id, table_id=table_obj.id, column_id=c.id, old_value=c.name)
c.is_deleted = True
c.deleted_at = scan_time
source.status = "active"
db.commit()
@@ -176,6 +272,8 @@ def sync_metadata(db: Session, source_id: int, user_id: int) -> dict:
"databases": len(db_names), "databases": len(db_names),
"tables": total_tables, "tables": total_tables,
"columns": total_columns, "columns": total_columns,
"updated_tables": updated_tables,
"updated_columns": updated_columns,
}
except Exception as e:
source.status = "error"
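The drop detection above hinges on one invariant; a condensed sketch of it:

# scan_time is fixed once before the walk, and every object still present in the
# source gets last_scanned_at = scan_time, so afterwards:
#   last_scanned_at < scan_time  <=>  the object vanished from the source
stale_tables = db.query(DataTable).filter(DataTable.last_scanned_at < scan_time).all()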
@@ -0,0 +1,183 @@
from typing import Optional, List, Tuple
from sqlalchemy.orm import Session
from fastapi import HTTPException, status
from app.models.metadata import DataSource, Database, DataTable, DataColumn
from app.services.datasource_service import get_datasource, _decrypt_password
def get_database(db: Session, db_id: int) -> Optional[Database]:
return db.query(Database).filter(Database.id == db_id).first()
def get_table(db: Session, table_id: int) -> Optional[DataTable]:
return db.query(DataTable).filter(DataTable.id == table_id).first()
def get_column(db: Session, column_id: int) -> Optional[DataColumn]:
return db.query(DataColumn).filter(DataColumn.id == column_id).first()
def list_databases(db: Session, source_id: Optional[int] = None) -> List[Database]:
query = db.query(Database)
if source_id:
query = query.filter(Database.source_id == source_id)
return query.all()
def list_tables(db: Session, database_id: Optional[int] = None, keyword: Optional[str] = None) -> Tuple[List[DataTable], int]:
query = db.query(DataTable)
if database_id:
query = query.filter(DataTable.database_id == database_id)
if keyword:
query = query.filter(
(DataTable.name.contains(keyword)) | (DataTable.comment.contains(keyword))
)
return query.all(), query.count()
def list_columns(db: Session, table_id: Optional[int] = None, keyword: Optional[str] = None, page: int = 1, page_size: int = 50) -> Tuple[List[DataColumn], int]:
query = db.query(DataColumn)
if table_id:
query = query.filter(DataColumn.table_id == table_id)
if keyword:
query = query.filter(
(DataColumn.name.contains(keyword)) | (DataColumn.comment.contains(keyword))
)
total = query.count()
items = query.offset((page - 1) * page_size).limit(page_size).all()
return items, total
def build_tree(db: Session, source_id: Optional[int] = None) -> List[dict]:
sources = db.query(DataSource)
if source_id:
sources = sources.filter(DataSource.id == source_id)
sources = sources.all()
result = []
for s in sources:
source_node = {
"id": s.id,
"name": s.name,
"type": "source",
"children": [],
"meta": {"source_type": s.source_type, "status": s.status},
}
for d in s.databases:
db_node = {
"id": d.id,
"name": d.name,
"type": "database",
"children": [],
"meta": {"charset": d.charset, "table_count": d.table_count},
}
for t in d.tables:
table_node = {
"id": t.id,
"name": t.name,
"type": "table",
"children": [],
"meta": {"comment": t.comment, "row_count": t.row_count, "column_count": t.column_count},
}
db_node["children"].append(table_node)
source_node["children"].append(db_node)
result.append(source_node)
return result
def sync_metadata(db: Session, source_id: int, user_id: int) -> dict:
from sqlalchemy import create_engine, inspect, text
import json
source = get_datasource(db, source_id)
if not source:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="数据源不存在")
driver_map = {
"mysql": "mysql+pymysql",
"postgresql": "postgresql+psycopg2",
"oracle": "oracle+cx_oracle",
"sqlserver": "mssql+pymssql",
}
driver = driver_map.get(source.source_type, source.source_type)
if source.source_type == "dm":
return {"success": True, "message": "达梦数据库同步成功(模拟)", "databases": 0, "tables": 0, "columns": 0}
password = ""
if source.encrypted_password:
try:
password = _decrypt_password(source.encrypted_password)
except Exception:
pass
try:
url = f"{driver}://{source.username}:{password}@{source.host}:{source.port}/{source.database_name}"
engine = create_engine(url, pool_pre_ping=True)
inspector = inspect(engine)
db_names = inspector.get_schema_names() or [source.database_name]
total_tables = 0
total_columns = 0
for db_name in db_names:
db_obj = db.query(Database).filter(Database.source_id == source.id, Database.name == db_name).first()
if not db_obj:
db_obj = Database(source_id=source.id, name=db_name)
db.add(db_obj)
db.commit()
db.refresh(db_obj)
table_names = inspector.get_table_names(schema=db_name)
for tname in table_names:
table_obj = db.query(DataTable).filter(DataTable.database_id == db_obj.id, DataTable.name == tname).first()
if not table_obj:
table_obj = DataTable(database_id=db_obj.id, name=tname)
db.add(table_obj)
db.commit()
db.refresh(table_obj)
columns = inspector.get_columns(tname, schema=db_name)
for col in columns:
col_obj = db.query(DataColumn).filter(DataColumn.table_id == table_obj.id, DataColumn.name == col["name"]).first()
if not col_obj:
sample = None
try:
with engine.connect() as conn:
result = conn.execute(text(f'SELECT "{col["name"]}" FROM "{db_name}"."{tname}" LIMIT 5'))
samples = [str(r[0]) for r in result if r[0] is not None]
sample = json.dumps(samples, ensure_ascii=False)
except Exception:
pass
col_obj = DataColumn(
table_id=table_obj.id,
name=col["name"],
data_type=str(col.get("type", "")),
length=col.get("max_length"),
comment=col.get("comment"),
is_nullable=col.get("nullable", True),
sample_data=sample,
)
db.add(col_obj)
total_columns += 1
total_tables += 1
db.commit()
source.status = "active"
db.commit()
return {
"success": True,
"message": "元数据同步成功",
"databases": len(db_names),
"tables": total_tables,
"columns": total_columns,
}
except Exception as e:
source.status = "error"
db.commit()
return {"success": False, "message": f"同步失败: {str(e)}", "databases": 0, "tables": 0, "columns": 0}
+195
@@ -0,0 +1,195 @@
import os
import json
import logging
from typing import List, Optional, Tuple
from datetime import datetime
import joblib
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sqlalchemy.orm import Session
from app.models.project import ClassificationResult
from app.models.classification import Category
from app.models.ml import MLModelVersion
logger = logging.getLogger(__name__)
MODELS_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "ml_models")
os.makedirs(MODELS_DIR, exist_ok=True)
def _build_text_features(column_name: str, comment: Optional[str], sample_data: Optional[str]) -> str:
parts = [column_name]
if comment:
parts.append(comment)
if sample_data:
try:
samples = json.loads(sample_data)
if isinstance(samples, list):
parts.extend([str(s) for s in samples[:5]])
except Exception:
parts.append(sample_data)
return " ".join(parts)
def _fetch_training_data(db: Session, min_samples_per_class: int = 5):
results = (
db.query(ClassificationResult)
.filter(ClassificationResult.source == "manual")
.filter(ClassificationResult.category_id.isnot(None))
.all()
)
texts = []
labels = []
for r in results:
if r.column:
text = _build_text_features(r.column.name, r.column.comment, r.column.sample_data)
texts.append(text)
labels.append(r.category_id)
from collections import Counter
counts = Counter(labels)
valid_classes = {c for c, n in counts.items() if n >= min_samples_per_class}
filtered_texts = []
filtered_labels = []
for t, l in zip(texts, labels):
if l in valid_classes:
filtered_texts.append(t)
filtered_labels.append(l)
return filtered_texts, filtered_labels, len(filtered_labels)
def train_model(db: Session, model_name: Optional[str] = None, algorithm: str = "logistic_regression", test_size: float = 0.2):
texts, labels, total = _fetch_training_data(db)
if total < 20:
logger.warning("Not enough training data (need >= 20, got %d)", total)
return None
X_train, X_test, y_train, y_test = train_test_split(
texts, labels, test_size=test_size, random_state=42, stratify=labels
)
vectorizer = TfidfVectorizer(analyzer="char_wb", ngram_range=(2, 4), max_features=5000)
X_train_vec = vectorizer.fit_transform(X_train)
X_test_vec = vectorizer.transform(X_test)
if algorithm == "logistic_regression":
clf = LogisticRegression(max_iter=1000, multi_class="multinomial", solver="lbfgs")
elif algorithm == "random_forest":
clf = RandomForestClassifier(n_estimators=100, random_state=42, n_jobs=-1)
else:
clf = LogisticRegression(max_iter=1000, multi_class="multinomial", solver="lbfgs")
clf.fit(X_train_vec, y_train)
y_pred = clf.predict(X_test_vec)
acc = accuracy_score(y_test, y_pred)
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
name = model_name or f"model_{timestamp}"
model_path = os.path.join(MODELS_DIR, f"{name}_clf.joblib")
vec_path = os.path.join(MODELS_DIR, f"{name}_tfidf.joblib")
joblib.dump(clf, model_path)
joblib.dump(vectorizer, vec_path)
db.query(MLModelVersion).filter(MLModelVersion.is_active == True).update({"is_active": False})
mv = MLModelVersion(
name=name,
model_path=model_path,
vectorizer_path=vec_path,
accuracy=acc,
train_samples=total,
is_active=True,
description=f"Algorithm: {algorithm}, test_accuracy: {acc:.4f}",
)
db.add(mv)
db.commit()
db.refresh(mv)
logger.info("Trained model %s with accuracy %.4f on %d samples", name, acc, total)
return mv
def _get_active_model(db: Session):
mv = db.query(MLModelVersion).filter(MLModelVersion.is_active == True).first()
if not mv or not os.path.exists(mv.model_path) or not os.path.exists(mv.vectorizer_path):
return None
clf = joblib.load(mv.model_path)
vectorizer = joblib.load(mv.vectorizer_path)
return clf, vectorizer, mv
def predict_categories(db: Session, texts: List[str], top_k: int = 3):
model_tuple = _get_active_model(db)
if not model_tuple:
return [[] for _ in texts]
clf, vectorizer, mv = model_tuple
X = vectorizer.transform(texts)
if hasattr(clf, "predict_proba"):
probs = clf.predict_proba(X)
else:
preds = clf.predict(X)
return [[{"category_id": int(p), "confidence": 1.0}] for p in preds]
classes = [int(c) for c in clf.classes_]
results = []
for prob in probs:
top_idx = np.argsort(prob)[::-1][:top_k]
suggestions = []
for idx in top_idx:
cat_id = classes[idx]
confidence = float(prob[idx])
if confidence > 0.01:
suggestions.append({"category_id": cat_id, "confidence": round(confidence, 4)})
results.append(suggestions)
return results
def suggest_for_project_columns(db: Session, project_id: int, column_ids: Optional[List[int]] = None, top_k: int = 3):
from app.models.project import ClassificationProject
from app.models.metadata import DataColumn
project = db.query(ClassificationProject).filter(ClassificationProject.id == project_id).first()
if not project:
return {"success": False, "message": "项目不存在"}
query = db.query(DataColumn).join(
ClassificationResult,
(ClassificationResult.column_id == DataColumn.id) & (ClassificationResult.project_id == project_id),
isouter=True,
)
if column_ids:
query = query.filter(DataColumn.id.in_(column_ids))
columns = query.all()
texts = []
col_map = []
for col in columns:
texts.append(_build_text_features(col.name, col.comment, col.sample_data))
col_map.append(col)
if not texts:
return {"success": True, "suggestions": [], "message": "没有可预测的字段"}
predictions = predict_categories(db, texts, top_k=top_k)
suggestions = []
all_category_ids = set()
for col, preds in zip(col_map, predictions):
for p in preds:
all_category_ids.add(p["category_id"])
categories = {c.id: c for c in db.query(Category).filter(Category.id.in_(list(all_category_ids))).all()}
for col, preds in zip(col_map, predictions):
item = {
"column_id": col.id,
"column_name": col.name,
"table_name": col.table.name if col.table else None,
"suggestions": [],
}
for p in preds:
cat = categories.get(p["category_id"])
item["suggestions"].append({
"category_id": p["category_id"],
"category_name": cat.name if cat else None,
"category_code": cat.code if cat else None,
"confidence": p["confidence"],
})
suggestions.append(item)
return {"success": True, "suggestions": suggestions}
+6 -1
@@ -12,11 +12,13 @@ def get_project(db: Session, project_id: int) -> Optional[ClassificationProject]
def list_projects(
db: Session, keyword: Optional[str] = None, page: int = 1, page_size: int = 20, created_by: Optional[int] = None
) -> Tuple[List[ClassificationProject], int]:
query = db.query(ClassificationProject)
if keyword:
query = query.filter(ClassificationProject.name.contains(keyword))
if created_by:
query = query.filter(ClassificationProject.created_by == created_by)
total = query.count()
items = query.order_by(ClassificationProject.created_at.desc()).offset((page - 1) * page_size).limit(page_size).all()
return items, total
@@ -82,10 +84,13 @@ def list_results(
keyword: Optional[str] = None,
page: int = 1,
page_size: int = 50,
project_ids: Optional[List[int]] = None,
) -> Tuple[List[ClassificationResult], int]:
query = db.query(ClassificationResult)
if project_id:
query = query.filter(ClassificationResult.project_id == project_id)
if project_ids:
query = query.filter(ClassificationResult.project_id.in_(project_ids))
if table_id:
query = query.join(DataColumn).filter(DataColumn.table_id == table_id)
if status:
+152
@@ -94,3 +94,155 @@ def generate_classification_report(db: Session, project_id: int) -> bytes:
doc.save(buffer)
buffer.seek(0)
return buffer.read()
def generate_excel_report(db: Session, project_id: int) -> bytes:
"""Generate an Excel report for a classification project."""
from openpyxl import Workbook
from openpyxl.styles import Font, PatternFill, Alignment
from openpyxl.utils import get_column_letter
project = db.query(ClassificationProject).filter(ClassificationProject.id == project_id).first()
if not project:
raise ValueError("项目不存在")
wb = Workbook()
ws = wb.active
ws.title = "报告概览"
# Title
ws.merge_cells('A1:D1')
ws['A1'] = '数据分类分级项目报告'
ws['A1'].font = Font(size=18, bold=True)
ws['A1'].alignment = Alignment(horizontal='center')
# Basic info
ws['A3'] = '项目名称'
ws['B3'] = project.name
ws['A4'] = '报告生成时间'
ws['B4'] = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
ws['A5'] = '项目状态'
ws['B5'] = project.status
ws['A6'] = '模板版本'
ws['B6'] = project.template.version if project.template else 'N/A'
# Statistics
results = db.query(ClassificationResult).filter(ClassificationResult.project_id == project_id).all()
total = len(results)
auto_count = sum(1 for r in results if r.source == 'auto')
manual_count = sum(1 for r in results if r.source == 'manual')
ws['A8'] = '总字段数'
ws['B8'] = total
ws['A9'] = '自动识别'
ws['B9'] = auto_count
ws['A10'] = '人工打标'
ws['B10'] = manual_count
# Level distribution
ws['A12'] = '分级'
ws['B12'] = '数量'
ws['C12'] = '占比'
ws['A12'].font = Font(bold=True)
ws['B12'].font = Font(bold=True)
ws['C12'].font = Font(bold=True)
level_stats = {}
for r in results:
if r.level:
level_stats[r.level.name] = level_stats.get(r.level.name, 0) + 1
red_fill = PatternFill(start_color='FFCCCC', end_color='FFCCCC', fill_type='solid')
row = 13
for level_name, count in sorted(level_stats.items(), key=lambda x: -x[1]):
ws.cell(row=row, column=1, value=level_name)
ws.cell(row=row, column=2, value=count)
pct = f'{count / total * 100:.1f}%' if total > 0 else '0%'
ws.cell(row=row, column=3, value=pct)
if 'L4' in level_name or 'L5' in level_name:
for c in range(1, 4):
ws.cell(row=row, column=c).fill = red_fill
row += 1
# High risk sheet
ws2 = wb.create_sheet("高敏感数据清单")
ws2.append(['字段名', '所属表', '分类', '分级', '来源', '置信度'])
for cell in ws2[1]:
cell.font = Font(bold=True)
cell.fill = PatternFill(start_color='DDEBF7', end_color='DDEBF7', fill_type='solid')
high_risk = [r for r in results if r.level and r.level.code in ('L4', 'L5')]
for r in high_risk[:500]:
ws2.append([
r.column.name if r.column else 'N/A',
r.column.table.name if r.column and r.column.table else 'N/A',
r.category.name if r.category else 'N/A',
r.level.name if r.level else 'N/A',
'自动' if r.source == 'auto' else '人工',
r.confidence,
])
# Auto-fit column widths roughly
for ws_sheet in [ws, ws2]:
for column in ws_sheet.columns:
# column[0] can be a MergedCell (merged title row) without .column_letter, so derive the letter from the index
column_letter = get_column_letter(column[0].column)
max_length = 0
for cell in column:
try:
if len(str(cell.value)) > max_length:
max_length = len(str(cell.value))
except Exception:
pass
adjusted_width = min(max_length + 2, 50)
ws_sheet.column_dimensions[column_letter].width = adjusted_width
buffer = BytesIO()
wb.save(buffer)
buffer.seek(0)
return buffer.read()
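Quick way to inspect the output locally (sketch; the path is illustrative):

with open("/tmp/report.xlsx", "wb") as f:
    f.write(generate_excel_report(db, project_id=1))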
def get_report_summary(db: Session, project_id: int) -> dict:
"""Get aggregated report data for PDF generation (frontend)."""
project = db.query(ClassificationProject).filter(ClassificationProject.id == project_id).first()
if not project:
raise ValueError("项目不存在")
results = db.query(ClassificationResult).filter(ClassificationResult.project_id == project_id).all()
total = len(results)
auto_count = sum(1 for r in results if r.source == 'auto')
manual_count = sum(1 for r in results if r.source == 'manual')
level_stats = {}
for r in results:
if r.level:
level_stats[r.level.name] = level_stats.get(r.level.name, 0) + 1
high_risk = []
for r in results:
if r.level and r.level.code in ('L4', 'L5'):
high_risk.append({
"column_name": r.column.name if r.column else 'N/A',
"table_name": r.column.table.name if r.column and r.column.table else 'N/A',
"category_name": r.category.name if r.category else 'N/A',
"level_name": r.level.name if r.level else 'N/A',
"source": '自动' if r.source == 'auto' else '人工',
"confidence": r.confidence,
})
return {
"project_name": project.name,
"status": project.status,
"template_version": project.template.version if project.template else 'N/A',
"generated_at": datetime.now().isoformat(),
"total": total,
"auto": auto_count,
"manual": manual_count,
"level_distribution": [
{"name": name, "count": count}
for name, count in sorted(level_stats.items(), key=lambda x: -x[1])
],
"high_risk": high_risk[:100],
}
+125
@@ -0,0 +1,125 @@
from typing import List, Optional
from sqlalchemy.orm import Session
from datetime import datetime
from app.models.project import ClassificationProject, ClassificationResult
from app.models.classification import DataLevel
from app.models.metadata import DataSource, Database, DataTable, DataColumn
from app.models.masking import MaskingRule
from app.models.risk import RiskAssessment
def _get_level_weight(level_code: str) -> int:
weights = {"L1": 1, "L2": 2, "L3": 3, "L4": 4, "L5": 5}
return weights.get(level_code, 1)
def calculate_project_risk(db: Session, project_id: int) -> RiskAssessment:
"""Calculate risk score for a project."""
project = db.query(ClassificationProject).filter(ClassificationProject.id == project_id).first()
if not project:
return None
results = db.query(ClassificationResult).filter(
ClassificationResult.project_id == project_id,
ClassificationResult.level_id.isnot(None),
).all()
total_risk = 0.0
total_sensitivity = 0.0
total_exposure = 0.0
total_protection = 0.0
detail_items = []
# Get all active masking rules for quick lookup
rules = db.query(MaskingRule).filter(MaskingRule.is_active == True).all()
rule_level_ids = {r.level_id for r in rules if r.level_id}
rule_cat_ids = {r.category_id for r in rules if r.category_id}
for r in results:
if not r.level:
continue
level_weight = _get_level_weight(r.level.code)
# Exposure: count source connections for the column's table
source_count = 1
if r.column and r.column.table and r.column.table.database:
# Simple: if table exists in multiple dbs (rare), count them
source_count = max(1, len(r.column.table.database.source.databases or []))
exposure_factor = 1 + source_count * 0.2
# Protection: check if masking rule exists for this level/category
has_masking = (r.level_id in rule_level_ids) or (r.category_id in rule_cat_ids)
protection_rate = 0.3 if has_masking else 0.0
item_risk = level_weight * exposure_factor * (1 - protection_rate)
total_risk += item_risk
total_sensitivity += level_weight
total_exposure += exposure_factor
total_protection += protection_rate
detail_items.append({
"column_id": r.column_id,
"column_name": r.column.name if r.column else None,
"level": r.level.code if r.level else None,
"level_weight": level_weight,
"exposure_factor": round(exposure_factor, 2),
"protection_rate": protection_rate,
"item_risk": round(item_risk, 2),
})
# Normalize to 0-100 (heuristic: assume max reasonable raw score is 15 per field)
count = len(detail_items) or 1
max_raw = count * 15
risk_score = min(100, (total_risk / max_raw) * 100) if max_raw > 0 else 0
# Upsert risk assessment
existing = db.query(RiskAssessment).filter(
RiskAssessment.entity_type == "project",
RiskAssessment.entity_id == project_id,
).first()
if existing:
existing.risk_score = round(risk_score, 2)
existing.sensitivity_score = round(total_sensitivity / count, 2)
existing.exposure_score = round(total_exposure / count, 2)
existing.protection_score = round(total_protection / count, 2)
existing.details = {"items": detail_items[:100], "total_items": len(detail_items)}
existing.updated_at = datetime.utcnow()
else:
existing = RiskAssessment(
entity_type="project",
entity_id=project_id,
entity_name=project.name,
risk_score=round(risk_score, 2),
sensitivity_score=round(total_sensitivity / count, 2),
exposure_score=round(total_exposure / count, 2),
protection_score=round(total_protection / count, 2),
details={"items": detail_items[:100], "total_items": len(detail_items)},
)
db.add(existing)
db.commit()
return existing
def calculate_all_projects_risk(db: Session) -> dict:
"""Batch calculate risk for all projects."""
projects = db.query(ClassificationProject).all()
updated = 0
for p in projects:
try:
calculate_project_risk(db, p.id)
updated += 1
except Exception:
pass
return {"updated": updated}
def get_risk_top_n(db: Session, entity_type: str = "project", n: int = 10) -> List[RiskAssessment]:
return (
db.query(RiskAssessment)
.filter(RiskAssessment.entity_type == entity_type)
.order_by(RiskAssessment.risk_score.desc())
.limit(n)
.all()
)
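A worked instance of the per-field formula above, with assumed values:

level_weight = 4                # an L4 field
exposure_factor = 1 + 1 * 0.2   # one source connection -> 1.2
protection_rate = 0.3           # a masking rule matched
item_risk = level_weight * exposure_factor * (1 - protection_rate)  # 4 * 1.2 * 0.7 = 3.36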
@@ -0,0 +1,99 @@
import os
import re
import json
from typing import Optional, List
from sqlalchemy.orm import Session
from fastapi import HTTPException, status
from app.models.metadata import UnstructuredFile
from app.core.events import minio_client
from app.core.config import settings
def extract_text_from_file(file_path: str, file_type: str) -> str:
text = ""
ft = file_type.lower()
if ft in ("word", "docx"):
try:
from docx import Document
doc = Document(file_path)
text = "\n".join([p.text for p in doc.paragraphs if p.text])
except Exception as e:
raise ValueError(f"解析Word失败: {e}")
elif ft in ("excel", "xlsx", "xls"):
try:
from openpyxl import load_workbook
wb = load_workbook(file_path, data_only=True)
parts = []
for sheet in wb.worksheets:
for row in sheet.iter_rows(values_only=True):
parts.append(" ".join([str(c) for c in row if c is not None]))
text = "\n".join(parts)
except Exception as e:
raise ValueError(f"解析Excel失败: {e}")
elif ft == "pdf":
try:
import pdfplumber
with pdfplumber.open(file_path) as pdf:
text = "\n".join([page.extract_text() or "" for page in pdf.pages])
except Exception as e:
raise ValueError(f"解析PDF失败: {e}")
elif ft == "txt":
with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
text = f.read()
else:
raise ValueError(f"不支持的文件类型: {ft}")
return text
def scan_text_for_sensitive(text: str) -> List[dict]:
"""Scan extracted text for sensitive patterns using built-in rules."""
matches = []
# ID card
id_pattern = re.compile(r"(?<!\d)\d{17}[\dXx](?!\d)")
for m in id_pattern.finditer(text):
snippet = text[max(0, m.start()-10):min(len(text), m.end()+10)]
matches.append({"rule_name": "身份证号", "category_code": "CUST_PERSONAL", "level_code": "L4", "snippet": snippet, "position": m.start()})
# Phone
phone_pattern = re.compile(r"(?<!\d)1[3-9]\d{9}(?!\d)")
for m in phone_pattern.finditer(text):
snippet = text[max(0, m.start()-10):min(len(text), m.end()+10)]
matches.append({"rule_name": "手机号", "category_code": "CUST_PERSONAL", "level_code": "L4", "snippet": snippet, "position": m.start()})
# Bank card (simple 16-19 digits)
bank_pattern = re.compile(r"(?<!\d)\d{16,19}(?!\d)")
for m in bank_pattern.finditer(text):
snippet = text[max(0, m.start()-10):min(len(text), m.end()+10)]
matches.append({"rule_name": "银行卡号", "category_code": "FIN_PAYMENT", "level_code": "L4", "snippet": snippet, "position": m.start()})
# Amount
amount_pattern = re.compile(r"(?<!\d)\d{1,3}(,\d{3})*\.\d{2}(?!\d)")
for m in amount_pattern.finditer(text):
snippet = text[max(0, m.start()-10):min(len(text), m.end()+10)]
matches.append({"rule_name": "金额", "category_code": "FIN_PAYMENT", "level_code": "L3", "snippet": snippet, "position": m.start()})
return matches
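Illustrative scan over fabricated text (matches the phone and amount rules):

hits = scan_text_for_sensitive("联系人手机号 13812345678, 金额 1,234.56 元")
for h in hits:
    print(h["rule_name"], repr(h["snippet"]))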
def process_unstructured_file(db: Session, file_id: int) -> dict:
file_obj = db.query(UnstructuredFile).filter(UnstructuredFile.id == file_id).first()
if not file_obj:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="文件不存在")
if not file_obj.storage_path:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="文件未上传")
# Download from MinIO to temp
tmp_path = f"/tmp/unstructured_{file_id}_{file_obj.original_name}"
try:
minio_client.fget_object(settings.MINIO_BUCKET_NAME, file_obj.storage_path, tmp_path)
text = extract_text_from_file(tmp_path, file_obj.file_type or "")
file_obj.extracted_text = text[:50000] # limit storage
matches = scan_text_for_sensitive(text)
file_obj.analysis_result = {"matches": matches, "total_chars": len(text)}
file_obj.status = "processed"
db.commit()
return {"success": True, "matches": matches, "total_chars": len(text)}
except Exception as e:
file_obj.status = "error"
db.commit()
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
finally:
if os.path.exists(tmp_path):
os.remove(tmp_path)
+97
@@ -0,0 +1,97 @@
import secrets
from typing import Optional, Tuple
from sqlalchemy.orm import Session
from app.models.watermark import WatermarkLog
# Zero-width characters for binary encoding
ZW_SPACE = "\u200b" # zero-width space -> 0
ZW_NOJOIN = "\u200c" # zero-width non-joiner -> 1
MARKER = "\u200d" # zero-width joiner -> start marker
def _int_to_binary_bits(n: int, bits: int = 32) -> str:
return format(n, f"0{bits}b")
def _binary_bits_to_int(bits: str) -> int:
return int(bits, 2)
def embed_watermark(text: str, user_id: int, key: str) -> str:
"""Embed invisible watermark into text using zero-width characters."""
# Encode user_id as 32-bit binary
bits = _int_to_binary_bits(user_id)
# Encode key hash as 16-bit for verification
key_bits = _int_to_binary_bits(hash(key) & 0xFFFF, 16)
payload = key_bits + bits
watermark_chars = MARKER + "".join(ZW_NOJOIN if b == "1" else ZW_SPACE for b in payload)
# Append watermark at the end of the text (before trailing newlines if any)
text = text.rstrip("\n")
return text + watermark_chars + "\n"
def extract_watermark(text: str) -> Tuple[Optional[int], Optional[str]]:
"""Extract watermark from text. Returns (user_id, key_hash_bits) or (None, None)."""
if MARKER not in text:
return None, None
idx = text.index(MARKER)
payload = text[idx + len(MARKER):]
bits = ""
for ch in payload:
if ch == ZW_SPACE:
bits += "0"
elif ch == ZW_NOJOIN:
bits += "1"
else:
# Stop at first non-watermark character
break
if len(bits) < 16:
return None, None
key_bits = bits[:16]
user_bits = bits[16:48]
try:
user_id = _binary_bits_to_int(user_bits)
return user_id, key_bits
except Exception:
return None, None
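Round-trip sketch (payload = 16 key-hash bits + 32 user-id bits). Note that Python's hash() on str is salted per process, so the key bits only verify within a single interpreter run:

marked = embed_watermark("导出的合同文本", user_id=42, key="demo-key")
uid, key_bits = extract_watermark(marked)
assert uid == 42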
def apply_watermark_to_lines(lines: list, user_id: int, key: str) -> list:
"""Apply watermark to each line of CSV/TXT."""
return [embed_watermark(line, user_id, key) for line in lines]
def create_watermark_log(db: Session, user_id: int, export_type: str, data_scope: dict) -> WatermarkLog:
key = secrets.token_hex(16)
log = WatermarkLog(
user_id=user_id,
export_type=export_type,
data_scope=str(data_scope),
watermark_key=key,
)
db.add(log)
db.commit()
db.refresh(log)
return log
def trace_watermark(db: Session, text: str) -> Optional[dict]:
"""Trace leaked text back to user."""
user_id, _ = extract_watermark(text)
if user_id is None:
return None
log = (
db.query(WatermarkLog)
.filter(WatermarkLog.user_id == user_id)
.order_by(WatermarkLog.created_at.desc())
.first()
)
if not log:
return None
return {
"user_id": log.user_id,
"username": log.user.username if log.user else None,
"export_type": log.export_type,
"data_scope": log.data_scope,
"created_at": log.created_at.isoformat() if log.created_at else None,
}
+39 -9
@@ -1,3 +1,4 @@
import json
from app.tasks.worker import celery_app
@@ -5,12 +6,10 @@ from app.tasks.worker import celery_app
def auto_classify_task(self, project_id: int, source_ids: list = None):
"""
Async task to run automatic classification on metadata.
"""
from app.core.database import SessionLocal
from app.models.project import ClassificationProject
from app.services.classification_engine import run_auto_classification
db = SessionLocal()
try:
@@ -18,15 +17,46 @@ def auto_classify_task(self, project_id: int, source_ids: list = None):
if not project:
return {"status": "failed", "reason": "project not found"}
def progress_callback(scanned, matched, total):
percent = int(scanned / total * 100) if total else 0
meta = {
"scanned": scanned,
"matched": matched,
"total": total,
"percent": percent,
}
self.update_state(state="PROGRESS", meta=meta)
# Persist lightweight progress to DB for UI polling
project.scan_progress = json.dumps(meta)
db.commit()
# Initialize
project.status = "scanning"
project.scan_progress = json.dumps({"scanned": 0, "matched": 0, "total": 0, "percent": 0})
db.commit()
result = run_auto_classification(
db,
project_id,
source_ids=source_ids,
progress_callback=progress_callback,
)
if result.get("success"):
project.status = "assigning"
else:
project.status = "created"
project.celery_task_id = None
db.commit()
return {"status": "completed", "project_id": project_id, "result": result}
except Exception as e:
db.rollback()
project = db.query(ClassificationProject).filter(ClassificationProject.id == project_id).first()
if project:
project.status = "created"
project.celery_task_id = None
db.commit()
return {"status": "failed", "reason": str(e)}
finally:
db.close()
+26
@@ -0,0 +1,26 @@
from app.tasks.worker import celery_app
@celery_app.task(bind=True)
def train_ml_model_task(self, model_name: str = None, algorithm: str = "logistic_regression"):
from app.core.database import SessionLocal
from app.services.ml_service import train_model
db = SessionLocal()
try:
self.update_state(state="PROGRESS", meta={"message": "Fetching training data"})
mv = train_model(db, model_name=model_name, algorithm=algorithm)
if mv:
return {
"status": "completed",
"model_id": mv.id,
"name": mv.name,
"accuracy": mv.accuracy,
"train_samples": mv.train_samples,
}
else:
return {"status": "failed", "reason": "Not enough training data (need >= 20 samples)"}
except Exception as e:
return {"status": "failed", "reason": str(e)}
finally:
db.close()
+2 -2
@@ -2,10 +2,10 @@ from celery import Celery
from app.core.config import settings
celery_app = Celery(
"data_pointer",
broker=settings.REDIS_URL,
backend=settings.REDIS_URL,
include=["app.tasks.classification_tasks", "app.tasks.ml_tasks"],
)
celery_app.conf.update(
+2 -1
@@ -20,4 +20,5 @@ openpyxl==3.1.2
python-docx==1.1.2
scikit-learn==1.5.0
numpy==1.26.4
pandas==2.2.2
requests==2.32.3
+126
@@ -0,0 +1,126 @@
import sys, requests
BASE = "http://localhost:8000"
API = f"{BASE}/api/v1"
errors, passed = [], []
def check(name, ok, detail=""):
if ok:
passed.append(name); print(f"✓ {name}")
else:
errors.append((name, detail)); print(f"✗ {name}: {detail}")
def get_items(resp):
d = resp.json().get("data", [])
if isinstance(d, list):
return d
return d.get("items", [])
def get_total(resp):
return resp.json().get("total", 0)
def main():
print("\n[1/15] Health")
r = requests.get(f"{BASE}/health")
check("health", r.status_code == 200 and r.json().get("status") == "ok")
print("\n[2/15] Auth")
r = requests.post(f"{API}/auth/login", json={"username": "admin", "password": "Zhidi@n2023"})
check("login", r.status_code == 200)
token = r.json().get("data", {}).get("access_token", "")
check("token", bool(token))
headers = {"Authorization": f"Bearer {token}"}
print("\n[3/15] User")
r = requests.get(f"{API}/users/me", headers=headers)
check("me", r.status_code == 200 and r.json()["data"]["username"] == "admin")
r = requests.get(f"{API}/users?page_size=100", headers=headers)
check("users", r.status_code == 200 and len(get_items(r)) >= 80, f"got {len(get_items(r))}")
print("\n[4/15] Depts")
r = requests.get(f"{API}/users/depts", headers=headers)
check("depts", r.status_code == 200 and len(r.json().get("data", [])) >= 12, f"got {len(r.json().get('data', []))}")
print("\n[5/15] DataSources")
r = requests.get(f"{API}/datasources", headers=headers)
check("datasources", r.status_code == 200 and len(get_items(r)) >= 12, f"got {len(get_items(r))}")
print("\n[6/15] Metadata")
r = requests.get(f"{API}/metadata/databases", headers=headers)
check("databases", r.status_code == 200 and len(get_items(r)) >= 31, f"got {len(get_items(r))}")
r = requests.get(f"{API}/metadata/tables", headers=headers)
check("tables", r.status_code == 200 and len(get_items(r)) >= 800, f"got {len(get_items(r))}")
r = requests.get(f"{API}/metadata/columns", headers=headers)
check("columns", r.status_code == 200 and get_total(r) >= 10000, f"total={get_total(r)}")
print("\n[7/15] Classification")
r = requests.get(f"{API}/classifications/levels", headers=headers)
check("levels", r.status_code == 200 and len(r.json().get("data", [])) == 5)
r = requests.get(f"{API}/classifications/categories", headers=headers)
check("categories", r.status_code == 200 and len(r.json().get("data", [])) >= 20, f"got {len(r.json().get('data', []))}")
r = requests.get(f"{API}/classifications/results", headers=headers)
check("results", r.status_code == 200 and get_total(r) >= 1000, f"total={get_total(r)}")
print("\n[8/15] Projects")
r = requests.get(f"{API}/projects", headers=headers)
check("projects", r.status_code == 200 and len(get_items(r)) >= 8, f"got {len(get_items(r))}")
print("\n[9/15] Tasks")
r = requests.get(f"{API}/tasks/my-tasks", headers=headers)
check("tasks", r.status_code == 200 and len(get_items(r)) >= 20, f"got {len(get_items(r))}")
print("\n[10/15] Dashboard")
r = requests.get(f"{API}/dashboard/stats", headers=headers)
check("stats", r.status_code == 200)
stats = r.json().get("data", {})
check("stats.data_sources", stats.get("data_sources", 0) >= 12, f"got {stats.get('data_sources')}")
check("stats.tables", stats.get("tables", 0) >= 800, f"got {stats.get('tables')}")
check("stats.columns", stats.get("columns", 0) >= 10000, f"got {stats.get('columns')}")
check("stats.labeled", stats.get("labeled", 0) >= 10000, f"got {stats.get('labeled')}")
r = requests.get(f"{API}/dashboard/distribution", headers=headers)
check("distribution", r.status_code == 200 and "level_distribution" in r.json().get("data", {}))
print("\n[11/15] Reports")
r = requests.get(f"{API}/reports/stats", headers=headers)
check("report stats", r.status_code == 200)
print("\n[12/15] Masking")
r = requests.get(f"{API}/masking/rules", headers=headers)
check("masking rules", r.status_code == 200)
print("\n[13/15] Watermark")
r = requests.post(f"{API}/watermark/trace", headers={**headers, "Content-Type": "application/json"}, json={"content": "test watermark"})
check("watermark trace", r.status_code == 200)
print("\n[14/15] Risk")
r = requests.get(f"{API}/risk/top", headers=headers)
check("risk top", r.status_code == 200)
print("\n[15/15] Compliance")
r = requests.get(f"{API}/compliance/issues", headers=headers)
check("compliance issues", r.status_code == 200)
print("\n[Bonus] Additional modules")
r = requests.get(f"{API}/lineage/graph", headers=headers)
check("lineage graph", r.status_code == 200 and "nodes" in r.json().get("data", {}))
r = requests.get(f"{API}/alerts/records", headers=headers)
check("alert records", r.status_code == 200)
r = requests.get(f"{API}/schema-changes/logs", headers=headers)
check("schema changes logs", r.status_code == 200)
r = requests.get(f"{API}/unstructured/files", headers=headers)
check("unstructured files", r.status_code == 200)
r = requests.get(f"{API}/api-assets", headers=headers)
check("api assets", r.status_code == 200)
print("\n" + "="*60)
print(f"Results: {len(passed)} passed, {len(errors)} failed")
print("="*60)
if errors:
for n, d in errors: print(f"✗ {n}: {d}")
sys.exit(1)
else:
print("🎉 All integration tests passed!")
sys.exit(0)
if __name__ == "__main__":
main()
@@ -0,0 +1,59 @@
"""
Generate synthetic manual-labeled data for ML model training/demo.
Run this script after metadata has been scanned so there are columns to label.
"""
import random
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from app.core.database import SessionLocal
from app.models.metadata import DataColumn
from app.models.classification import Category
from app.models.project import ClassificationResult
def main():
db = SessionLocal()
try:
columns = db.query(DataColumn).limit(300).all()
if not columns:
print("No columns found in database. Please scan a data source first.")
return
categories = db.query(Category).filter(Category.level == 2).all()
if not categories:
print("No sub-categories found.")
return
# Clear old manual labels to avoid duplicates
db.query(ClassificationResult).filter(ClassificationResult.source == "manual").delete()
db.commit()
count = 0
for col in columns:
# Deterministic pseudo-random based on column name for reproducibility
rng = random.Random(col.name)
cat = rng.choice(categories)
# Create a fake manual result (project_id=1 assumed to exist or None)
result = ClassificationResult(
project_id=None,
column_id=col.id,
category_id=cat.id,
level_id=cat.parent.level if cat.parent else 3, # fallback
source="manual",
confidence=1.0,
status="manual",
)
db.add(result)
count += 1
db.commit()
print(f"Generated {count} manual labels across {len(categories)} categories.")
finally:
db.close()
if __name__ == "__main__":
main()
+2 -2
@@ -1,5 +1,5 @@
""" """
Generate test data for PropDataGuard system. Generate test data for DataPointer system.
Targets: 10000+ records across all tables. Targets: 10000+ records across all tables.
""" """
import sys import sys
@@ -110,7 +110,7 @@ for i in range(80):
username = f"user{i+2:03d}" username = f"user{i+2:03d}"
user = User( user = User(
username=username, username=username,
email=f"{username}@propdataguard.com", email=f"{username}@datapo.com",
hashed_password=get_password_hash("password123"), hashed_password=get_password_hash("password123"),
real_name=real, real_name=real,
phone=random_phone(), phone=random_phone(),
+542
@@ -0,0 +1,542 @@
"""
Generate test data for DataPointer system.
Targets: 10000+ records across all tables.
"""
import sys
sys.path.insert(0, '/app')
import random
import string
import json
from datetime import datetime, timedelta
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from app.core.database import Base
from app.models.user import User, Role, Dept, UserRole
from app.models.metadata import DataSource, Database, DataTable, DataColumn
from app.models.classification import Category, DataLevel, RecognitionRule, ClassificationTemplate
from app.models.project import ClassificationProject, ClassificationTask, ClassificationResult, ResultStatus
from app.models.log import OperationLog
from app.core.security import get_password_hash
# Database connection
DATABASE_URL = "postgresql+psycopg2://pdg:pdg_secret_2024@db:5432/prop_data_guard"
engine = create_engine(DATABASE_URL)
SessionLocal = sessionmaker(bind=engine)
db = SessionLocal()
# Clear existing test data (preserve admin user and built-in data)
print("Clearing existing test data...")
db.query(ClassificationResult).delete(synchronize_session=False)
db.query(ClassificationTask).delete(synchronize_session=False)
db.query(ClassificationProject).delete(synchronize_session=False)
db.query(DataColumn).delete(synchronize_session=False)
db.query(DataTable).delete(synchronize_session=False)
db.query(Database).delete(synchronize_session=False)
db.query(UserRole).filter(UserRole.user_id > 1).delete(synchronize_session=False)
db.query(User).filter(User.id > 1).delete(synchronize_session=False)
db.query(Dept).filter(Dept.id > 1).delete(synchronize_session=False)
db.query(OperationLog).delete(synchronize_session=False)
db.commit()
# Reset all sequences to avoid ID conflicts
from sqlalchemy import text
sequences = [
"sys_dept_id_seq", "sys_user_id_seq", "sys_user_role_id_seq",
"data_source_id_seq", "meta_database_id_seq", "meta_table_id_seq", "meta_column_id_seq",
"classification_project_id_seq", "classification_task_id_seq", "classification_result_id_seq",
"classification_change_id_seq", "sys_operation_log_id_seq",
]
for seq in sequences:
db.execute(text(f"ALTER SEQUENCE {seq} RESTART WITH 100"))
db.commit()
print(" Sequences reset")
random.seed(42)
# ============================================================
# 1. Departments
# ============================================================
print("Generating departments...")
root_dept_names = ["数据安全部", "合规管理部", "信息技术部"]
root_depts = []
for name in root_dept_names:
d = Dept(name=name, parent_id=None, sort_order=len(root_depts))
db.add(d)
root_depts.append(d)
db.commit()
for d in root_depts:
db.refresh(d)
# Map root depts to 1-based keys: 1=数据安全部, 2=合规管理部, 3=信息技术部
root_id_map = {i+1: d.id for i, d in enumerate(root_depts)}
child_dept_defs = [
("业务一部", root_id_map[1]), ("业务二部", root_id_map[1]),
("车险事业部", root_id_map[3]), ("非车险事业部", root_id_map[3]), ("理赔服务部", root_id_map[3]),
("财务部", root_id_map[2]), ("精算部", root_id_map[2]),
("客户服务部", root_id_map[1]), ("渠道管理部", root_id_map[1]),
]
depts = root_depts[:]
for name, pid in child_dept_defs:
d = Dept(name=name, parent_id=pid, sort_order=len(depts))
db.add(d)
depts.append(d)
db.commit()
for d in depts[len(root_depts):]:
db.refresh(d)
print(f" Created {len(depts)} departments")
# ============================================================
# 2. Users
# ============================================================
print("Generating users...")
roles = db.query(Role).all()
role_map = {r.code: r.id for r in roles}
first_names = ["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", ""]
last_names = ["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "秀英", "", ""]
def random_name():
return random.choice(first_names) + random.choice(last_names)
def random_phone():
return "1" + random.choice(["3","4","5","6","7","8","9"]) + "".join(random.choices(string.digits, k=9))
users = []
for i in range(120):
real = random_name()
username = f"user{i+2:03d}"
user = User(
username=username,
email=f"{username}@datapo.com",
hashed_password=get_password_hash("password123"),
real_name=real,
phone=random_phone(),
is_active=random.random() > 0.05,
is_superuser=False,
dept_id=random.choice(depts).id,
)
db.add(user)
users.append(user)
db.commit()
for u in users:
db.refresh(u)
# Assign roles
role_list = list(roles)
for u in users:
assigned_roles = random.sample(role_list, k=random.randint(1, 2))
for r in assigned_roles:
db.add(UserRole(user_id=u.id, role_id=r.id))
db.commit()
print(f" Created {len(users)} users")
# ============================================================
# 3. Data Sources
# ============================================================
print("Generating data sources...")
source_types = ["postgresql", "mysql", "oracle", "sqlserver", "dm"]
source_configs = [
("核心保单数据库", "postgresql", "db-core-prod", 5432, "core_policy"),
("理赔系统数据库", "mysql", "db-claim-prod", 3306, "claim_db"),
("财务数据仓库", "postgresql", "db-finance-dw", 5432, "finance_dw"),
("客户信息主库", "mysql", "db-cust-master", 3306, "customer_master"),
("渠道管理系统", "oracle", "db-channel-ora", 1521, "CHANNEL"),
("精算分析平台", "postgresql", "db-actuary-ana", 5432, "actuary_analytics"),
("监管报送库", "mysql", "db-regulatory", 3306, "regulatory_report"),
("车辆信息库", "postgresql", "db-vehicle", 5432, "vehicle_db"),
("非车险业务库", "sqlserver", "db-nonauto", 1433, "NonAutoDB"),
("历史归档库", "postgresql", "db-archive", 5432, "archive_db"),
("测试环境核心库", "postgresql", "db-core-test", 5432, "core_test"),
("达梦国产数据库", "dm", "db-dameng-prod", 5236, "DAMENG"),
]
sources = []
for name, stype, host, port, dbname in source_configs:
ds = DataSource(
name=name,
source_type=stype,
host=f"{host}.internal.company.com",
port=port,
database_name=dbname,
username=f"{stype}_admin",
encrypted_password=None,
status="active" if random.random() > 0.1 else "error",
dept_id=random.choice(depts).id,
created_by=random.choice(users).id,
)
db.add(ds)
sources.append(ds)
db.commit()
for s in sources:
db.refresh(s)
print(f" Created {len(sources)} data sources")
# ============================================================
# 4. Databases
# ============================================================
print("Generating databases...")
databases = []
for source in sources:
num_dbs = random.randint(1, 3)
for i in range(num_dbs):
d = Database(
source_id=source.id,
name=f"{source.database_name}_{i+1}" if num_dbs > 1 else source.database_name,
charset="UTF8" if source.source_type != "sqlserver" else "Chinese_PRC_CI_AS",
table_count=0,
)
db.add(d)
databases.append(d)
db.commit()
for d in databases:
db.refresh(d)
print(f" Created {len(databases)} databases")
# ============================================================
# 5. Data Tables & Columns (the big one)
# ============================================================
print("Generating tables and columns...")
table_prefixes = {
"policy": ["t_policy", "t_policy_detail", "t_policy_extension", "t_policy_history", "t_endorsement"],
"claim": ["t_claim", "t_claim_detail", "t_claim_payment", "t_claim_document", "t_survey"],
"customer": ["t_customer", "t_customer_contact", "t_customer_identity", "t_customer_vehicle", "t_customer_preference"],
"finance": ["t_payment", "t_receipt", "t_invoice", "t_commission", "t_reserve"],
"channel": ["t_agent", "t_agent_contract", "t_partner", "t_broker", "t_sales_record"],
"actuary": ["t_pricing_model", "t_risk_factor", "t_loss_ratio", "t_reserve_calc", "t_solvency"],
"regulatory": ["t_report_cbrc", "t_report_circ", "t_stat_premium", "t_stat_claim", "t_stat_channel"],
"vehicle": ["t_vehicle", "t_vehicle_model", "t_vehicle_usage", "t_vehicle_accident", "t_vehicle_maintenance"],
"system": ["t_user", "t_role", "t_permission", "t_log", "t_config", "t_dict"],
"archive": ["t_archive_policy", "t_archive_claim", "t_archive_customer", "t_archive_finance"],
}
column_templates = [
("id", "BIGINT", "主键ID", "system", 2),
("created_at", "TIMESTAMP", "创建时间", "system", 2),
("updated_at", "TIMESTAMP", "更新时间", "system", 2),
("is_deleted", "BOOLEAN", "是否删除", "system", 2),
("created_by", "BIGINT", "创建人", "system", 2),
("customer_name", "VARCHAR", "客户姓名", "customer", 4),
("customer_id_no", "VARCHAR", "客户身份证号", "customer", 4),
("mobile_phone", "VARCHAR", "手机号码", "customer", 4),
("email", "VARCHAR", "电子邮箱", "customer", 3),
("address", "VARCHAR", "联系地址", "customer", 3),
("bank_account", "VARCHAR", "银行账户", "finance", 4),
("bank_card_no", "VARCHAR", "银行卡号", "finance", 4),
("policy_no", "VARCHAR", "保单号", "policy", 3),
("policy_status", "VARCHAR", "保单状态", "policy", 2),
("premium_amount", "DECIMAL", "保费金额", "finance", 3),
("claim_no", "VARCHAR", "理赔号", "claim", 3),
("claim_amount", "DECIMAL", "理赔金额", "claim", 4),
("loss_description", "TEXT", "损失描述", "claim", 3),
("accident_location", "VARCHAR", "出险地点", "claim", 3),
("vehicle_plate", "VARCHAR", "车牌号", "vehicle", 3),
("vin_code", "VARCHAR", "车辆识别代码VIN", "vehicle", 4),
("agent_name", "VARCHAR", "代理人姓名", "channel", 3),
("agent_license", "VARCHAR", "代理人执业证号", "channel", 3),
("commission_rate", "DECIMAL", "佣金比例", "finance", 3),
("reserve_amount", "DECIMAL", "准备金金额", "finance", 5),
("solvency_ratio", "DECIMAL", "偿付能力充足率", "finance", 5),
("password_hash", "VARCHAR", "密码哈希", "system", 5),
("api_secret", "VARCHAR", "API密钥", "system", 5),
("session_token", "VARCHAR", "会话令牌", "system", 4),
("gps_location", "VARCHAR", "GPS定位信息", "vehicle", 4),
("driving_record", "TEXT", "行驶记录", "vehicle", 4),
("medical_record", "TEXT", "医疗记录", "claim", 4),
("income_info", "DECIMAL", "收入信息", "customer", 4),
("credit_score", "INT", "信用评分", "customer", 4),
("family_member", "VARCHAR", "家庭成员信息", "customer", 3),
("emergency_contact", "VARCHAR", "紧急联系人", "customer", 3),
("beneficiary_name", "VARCHAR", "受益人姓名", "policy", 4),
("beneficiary_id_no", "VARCHAR", "受益人身份证号", "policy", 4),
("underwriting_decision", "VARCHAR", "核保结论", "policy", 3),
("risk_score", "DECIMAL", "风险评分", "actuary", 3),
("fraud_flag", "BOOLEAN", "欺诈标记", "claim", 3),
("audit_comment", "TEXT", "审计意见", "system", 3),
("report_period", "VARCHAR", "报表期间", "regulatory", 2),
("regulatory_code", "VARCHAR", "监管编码", "regulatory", 2),
]
all_tables = []
all_columns = []
for database in databases:
prefix_key = "system"
for k in table_prefixes:
if k in database.name.lower() or k in database.source.name.lower():
prefix_key = k
break
prefix_list = table_prefixes.get(prefix_key, table_prefixes["system"])
num_tables = random.randint(25, 60)
for tidx in range(num_tables):
table_name = f"{random.choice(prefix_list)}_{tidx+1:03d}"
tbl = DataTable(
database_id=database.id,
name=table_name,
comment=f"{table_name}数据表",
row_count=random.randint(10000, 10000000),
column_count=0,
)
db.add(tbl)
all_tables.append(tbl)
db.commit()
for t in all_tables:
db.refresh(t)
print(f" Created {len(all_tables)} tables")
# Now generate columns
print(" Generating columns (this may take a moment)...")
levels = db.query(DataLevel).all()
level_map = {l.code: l.id for l in levels}
categories = db.query(Category).all()
cat_map = {}
for c in categories:
if c.code.startswith("CUST") and "customer" not in cat_map:
cat_map["customer"] = c.id
elif c.code.startswith("POLICY") and "policy" not in cat_map:
cat_map["policy"] = c.id
elif c.code.startswith("CLAIM") and "claim" not in cat_map:
cat_map["claim"] = c.id
elif c.code.startswith("FIN") and "finance" not in cat_map:
cat_map["finance"] = c.id
elif c.code.startswith("CHANNEL") and "channel" not in cat_map:
cat_map["channel"] = c.id
elif c.code.startswith("REG") and "regulatory" not in cat_map:
cat_map["regulatory"] = c.id
elif c.code.startswith("INT") and "system" not in cat_map:
cat_map["system"] = c.id
elif c.code.startswith("SUB") and "vehicle" not in cat_map:
cat_map["vehicle"] = c.id
sample_values = {
"customer_name": ["张三", "李四", "王五", "赵六", "钱七"],
"customer_id_no": ["110101199001011234", "310101198502023456", "440106197803034567"],
"mobile_phone": ["13800138000", "13900139000", "13700137000"],
"email": ["user1@example.com", "user2@test.com", "contact@company.com"],
"bank_card_no": ["6222021234567890123", "6228481234567890123"],
"vin_code": ["LSVAG2180E2100001", "LFV3A28K8A3000001"],
"vehicle_plate": ["京A12345", "沪B67890", "粤C11111"],
"policy_no": ["PICC2024000001", "PICC2024000002", "PICC2024000003"],
"claim_no": ["CLM2024000001", "CLM2024000002", "CLM2024000003"],
"address": ["北京市海淀区xxx路1号", "上海市浦东新区xxx路2号"],
}
batch_size = 500
column_batch = []
for tbl in all_tables:
num_cols = random.randint(15, 35)
selected_templates = random.sample(column_templates, k=min(num_cols, len(column_templates)))
for cidx, (col_name, col_type, comment, cat_hint, lvl_hint) in enumerate(selected_templates):
actual_name = col_name if cidx == 0 else f"{col_name}_{cidx}"
samples = None
if col_name in sample_values:
samples = json.dumps(random.sample(sample_values[col_name], k=min(3, len(sample_values[col_name]))), ensure_ascii=False)
col = DataColumn(
table_id=tbl.id,
name=actual_name,
data_type=col_type,
length=random.choice([20, 50, 100, 200, 500]) if "VARCHAR" in col_type else None,
comment=comment,
is_nullable=random.random() > 0.2,
sample_data=samples,
)
column_batch.append(col)
if len(column_batch) >= batch_size:
db.bulk_save_objects(column_batch)
db.commit()
all_columns.extend(column_batch)
column_batch = []
if column_batch:
db.bulk_save_objects(column_batch)
db.commit()
all_columns.extend(column_batch)
print(f" Created {len(all_columns)} columns")
# Update table counts
for tbl in all_tables:
tbl.column_count = db.query(DataColumn).filter(DataColumn.table_id == tbl.id).count()
db.add(tbl)
db.commit()
for database in databases:
database.table_count = db.query(DataTable).filter(DataTable.database_id == database.id).count()
db.add(database)
db.commit()
# ============================================================
# 6. Classification Projects
# ============================================================
print("Generating classification projects...")
templates = db.query(ClassificationTemplate).all()
projects = []
project_names = [
"2024年度数据分类分级专项",
"核心系统敏感数据梳理",
"新核心上线数据定级",
"客户个人信息保护专项",
"财务数据安全治理",
"理赔数据合规检查",
"渠道数据梳理项目",
"监管报送数据定级",
]
for i, name in enumerate(project_names):
p = ClassificationProject(
name=name,
template_id=random.choice(templates).id,
description=f"{name} - 数据分类分级治理项目",
status=random.choice(["created", "scanning", "labeling", "reviewing", "published"]),
target_source_ids=",".join(str(s.id) for s in random.sample(sources, k=random.randint(2, 5))),
planned_start=datetime.now() - timedelta(days=random.randint(10, 60)),
planned_end=datetime.now() + timedelta(days=random.randint(10, 90)),
created_by=random.choice(users).id,
)
db.add(p)
projects.append(p)
db.commit()
for p in projects:
db.refresh(p)
print(f" Created {len(projects)} projects")
# ============================================================
# 7. Classification Results (the critical mass)
# ============================================================
print("Generating classification results...")
# Re-fetch column IDs from DB since bulk_save_objects doesn't populate object IDs
col_rows = db.query(DataColumn.id).all()
all_col_ids = [c[0] for c in col_rows]
random.shuffle(all_col_ids)
result_batch = []
total_results_target = 20000
results_per_project = total_results_target // len(projects)
for proj in projects:
assigned_cols = random.sample(all_col_ids, k=min(results_per_project, len(all_col_ids)))
for col_id in assigned_cols:
source_type = random.choices(["auto", "manual"], weights=[0.7, 0.3])[0]
status_val = "auto" if source_type == "auto" else random.choice(["manual", "reviewed"])
cat = random.choice(categories)
lvl = random.choice(levels)
conf = round(random.uniform(0.3, 0.98), 2)
r = ClassificationResult(
project_id=proj.id,
column_id=col_id,
category_id=cat.id,
level_id=lvl.id,
source=source_type,
confidence=conf,
status=status_val,
labeler_id=random.choice(users).id if source_type == "manual" else None,
)
result_batch.append(r)
if len(result_batch) >= batch_size:
db.bulk_save_objects(result_batch)
db.commit()
result_batch = []
if result_batch:
db.bulk_save_objects(result_batch)
db.commit()
total_results = db.query(ClassificationResult).count()
print(f" Created {total_results} classification results")
# ============================================================
# 8. Classification Tasks
# ============================================================
print("Generating classification tasks...")
tasks = []
for proj in projects:
num_tasks = random.randint(2, 5)
for tidx in range(num_tasks):
task = ClassificationTask(
project_id=proj.id,
name=f"{proj.name}-任务{tidx+1}",
assigner_id=random.choice(users).id,
assignee_id=random.choice(users).id,
target_type="column",
status=random.choice(["pending", "in_progress", "completed"]),
deadline=datetime.now() + timedelta(days=random.randint(5, 30)),
)
db.add(task)
tasks.append(task)
db.commit()
print(f" Created {len(tasks)} tasks")
# ============================================================
# 9. Operation Logs
# ============================================================
print("Generating operation logs...")
log_actions = ["登录", "查询数据源", "创建项目", "自动分类", "人工打标", "导出报告", "修改规则", "删除任务"]
log_modules = ["auth", "datasource", "project", "classification", "task", "report", "rule", "system"]
log_batch = []
for i in range(8000):
log = OperationLog(
user_id=random.choice([None] + [u.id for u in users]),
username=random.choice(["admin"] + [u.username for u in users]),
module=random.choice(log_modules),
action=random.choice(log_actions),
method=random.choice(["GET", "POST", "PUT", "DELETE"]),
path=f"/api/v1/{random.choice(log_modules)}/{random.randint(1, 100)}",
ip=f"10.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(0,255)}",
status_code=random.choice([200, 200, 200, 201, 400, 401, 404, 500]),
duration_ms=random.randint(10, 2000),
created_at=datetime.now() - timedelta(days=random.randint(0, 30), hours=random.randint(0, 23)),
)
log_batch.append(log)
if len(log_batch) >= batch_size:
db.bulk_save_objects(log_batch)
db.commit()
log_batch = []
if log_batch:
db.bulk_save_objects(log_batch)
db.commit()
total_logs = db.query(OperationLog).count()
print(f" Created {total_logs} operation logs")
# ============================================================
# Summary
# ============================================================
print("\n" + "="*60)
print("Test data generation complete!")
print("="*60)
print(f" Departments: {db.query(Dept).count()}")
print(f" Users: {db.query(User).count()}")
print(f" Data Sources: {db.query(DataSource).count()}")
print(f" Databases: {db.query(Database).count()}")
print(f" Tables: {db.query(DataTable).count()}")
print(f" Columns: {db.query(DataColumn).count()}")
print(f" Categories: {db.query(Category).count()}")
print(f" Data Levels: {db.query(DataLevel).count()}")
print(f" Rules: {db.query(RecognitionRule).count()}")
print(f" Templates: {db.query(ClassificationTemplate).count()}")
print(f" Projects: {db.query(ClassificationProject).count()}")
print(f" Tasks: {db.query(ClassificationTask).count()}")
print(f" Results: {db.query(ClassificationResult).count()}")
print(f" Operation Logs: {db.query(OperationLog).count()}")
print("="*60)
db.close()
+176
@@ -0,0 +1,176 @@
import sys
with open(sys.argv[1], 'r') as f:
content = f.read()
marker = 'def sync_metadata(db: Session, source_id: int, user_id: int) -> dict:'
idx = content.find(marker)
if idx == -1:
print('Marker not found')
sys.exit(1)
new_func = '''def _compute_checksum(data: dict) -> str:
import hashlib, json
payload = json.dumps(data, sort_keys=True, ensure_ascii=False, default=str)
return hashlib.sha256(payload.encode()).hexdigest()[:32]
def sync_metadata(db: Session, source_id: int, user_id: int) -> dict:
from sqlalchemy import create_engine, inspect, text
import json
from datetime import datetime
source = get_datasource(db, source_id)
if not source:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="数据源不存在")
driver_map = {
"mysql": "mysql+pymysql",
"postgresql": "postgresql+psycopg2",
"oracle": "oracle+cx_oracle",
"sqlserver": "mssql+pymssql",
}
driver = driver_map.get(source.source_type, source.source_type)
if source.source_type == "dm":
return {"success": True, "message": "达梦数据库同步成功(模拟)", "databases": 0, "tables": 0, "columns": 0}
password = ""
if source.encrypted_password:
try:
password = _decrypt_password(source.encrypted_password)
except Exception:
pass
try:
url = f"{driver}://{source.username}:{password}@{source.host}:{source.port}/{source.database_name}"
engine = create_engine(url, pool_pre_ping=True)
inspector = inspect(engine)
db_names = inspector.get_schema_names() or [source.database_name]
scan_time = datetime.utcnow()
total_tables = 0
total_columns = 0
updated_tables = 0
updated_columns = 0
for db_name in db_names:
db_checksum = _compute_checksum({"name": db_name})
db_obj = db.query(Database).filter(
Database.source_id == source.id, Database.name == db_name
).first()
if not db_obj:
db_obj = Database(source_id=source.id, name=db_name, checksum=db_checksum, last_scanned_at=scan_time)
db.add(db_obj)
else:
db_obj.checksum = db_checksum
db_obj.last_scanned_at = scan_time
db_obj.is_deleted = False
db_obj.deleted_at = None
table_names = inspector.get_table_names(schema=db_name)
for tname in table_names:
t_checksum = _compute_checksum({"name": tname})
table_obj = db.query(DataTable).filter(
DataTable.database_id == db_obj.id, DataTable.name == tname
).first()
if not table_obj:
table_obj = DataTable(database_id=db_obj.id, name=tname, checksum=t_checksum, last_scanned_at=scan_time)
db.add(table_obj)
else:
if table_obj.checksum != t_checksum:
table_obj.checksum = t_checksum
updated_tables += 1
table_obj.last_scanned_at = scan_time
table_obj.is_deleted = False
table_obj.deleted_at = None
columns = inspector.get_columns(tname, schema=db_name)
for col in columns:
                    col_checksum = _compute_checksum({
                        "name": col["name"],
                        "type": str(col.get("type", "")),
                        # SQLAlchemy inspectors carry length on the type object; there is no "max_length" key
                        "max_length": getattr(col.get("type"), "length", None),
                        "comment": col.get("comment"),
                        "nullable": col.get("nullable", True),
                    })
col_obj = db.query(DataColumn).filter(
DataColumn.table_id == table_obj.id, DataColumn.name == col["name"]
).first()
if not col_obj:
sample = None
                        try:
                            with engine.connect() as conn:
                                # Double-quoted identifiers and LIMIT fit PostgreSQL; other dialects simply fall into the except below
                                result = conn.execute(text(f'SELECT "{col["name"]}" FROM "{db_name}"."{tname}" LIMIT 5'))
samples = [str(r[0]) for r in result if r[0] is not None]
sample = json.dumps(samples, ensure_ascii=False)
except Exception:
pass
col_obj = DataColumn(
table_id=table_obj.id,
name=col["name"],
data_type=str(col.get("type", "")),
                            length=getattr(col.get("type"), "length", None),
comment=col.get("comment"),
is_nullable=col.get("nullable", True),
sample_data=sample,
checksum=col_checksum,
last_scanned_at=scan_time,
)
db.add(col_obj)
total_columns += 1
else:
if col_obj.checksum != col_checksum:
col_obj.checksum = col_checksum
col_obj.data_type = str(col.get("type", ""))
                            col_obj.length = getattr(col.get("type"), "length", None)
col_obj.comment = col.get("comment")
col_obj.is_nullable = col.get("nullable", True)
updated_columns += 1
col_obj.last_scanned_at = scan_time
col_obj.is_deleted = False
col_obj.deleted_at = None
total_tables += 1
# Soft-delete objects not seen in this scan
db.query(Database).filter(
Database.source_id == source.id,
Database.last_scanned_at < scan_time,
).update({"is_deleted": True, "deleted_at": scan_time}, synchronize_session=False)
for db_obj in db.query(Database).filter(Database.source_id == source.id).all():
db.query(DataTable).filter(
DataTable.database_id == db_obj.id,
DataTable.last_scanned_at < scan_time,
).update({"is_deleted": True, "deleted_at": scan_time}, synchronize_session=False)
for table_obj in db.query(DataTable).filter(DataTable.database_id == db_obj.id).all():
db.query(DataColumn).filter(
DataColumn.table_id == table_obj.id,
DataColumn.last_scanned_at < scan_time,
).update({"is_deleted": True, "deleted_at": scan_time}, synchronize_session=False)
source.status = "active"
db.commit()
return {
"success": True,
"message": "元数据同步成功",
"databases": len(db_names),
"tables": total_tables,
"columns": total_columns,
"updated_tables": updated_tables,
"updated_columns": updated_columns,
}
except Exception as e:
source.status = "error"
db.commit()
return {"success": False, "message": f"同步失败: {str(e)}", "databases": 0, "tables": 0, "columns": 0}
'''
new_content = content[:idx] + new_func
with open(sys.argv[1], 'w') as f:
f.write(new_content)
print('Patched successfully')
+1 -1
@@ -47,7 +47,7 @@ def test_health_check():
 def test_login():
-    response = client.post("/api/v1/auth/login", json={"username": "admin", "password": "admin123"})
+    response = client.post("/api/v1/auth/login", json={"username": "admin", "password": "Zhidi@n2023"})
     assert response.status_code == 200
     data = response.json()
     assert data["code"] == 200
+135
@@ -0,0 +1,135 @@
version: "3.8"
services:
db:
image: postgres:16-alpine
container_name: pdg-postgres
environment:
POSTGRES_USER: pdg
POSTGRES_PASSWORD: pdg_secret_2024
POSTGRES_DB: prop_data_guard
volumes:
- pg_data:/var/lib/postgresql/data
expose:
- "5432"
healthcheck:
test: ["CMD-SHELL", "pg_isready -U pdg -d prop_data_guard"]
interval: 5s
timeout: 5s
retries: 5
restart: unless-stopped
redis:
image: redis:7-alpine
container_name: pdg-redis
expose:
- "6379"
volumes:
- redis_data:/data
healthcheck:
test: ["CMD", "redis-cli", "ping"]
interval: 5s
timeout: 5s
retries: 5
restart: unless-stopped
minio:
image: minio/minio:RELEASE.2024-05-10T01-41-38Z
container_name: pdg-minio
environment:
MINIO_ROOT_USER: pdgminio
MINIO_ROOT_PASSWORD: pdgminio_secret_2024
command: server /data --console-address ":9001"
volumes:
- minio_data:/data
expose:
- "9000"
- "9001"
restart: unless-stopped
backend:
build: ./backend
container_name: pdg-backend
environment:
- DATABASE_URL=postgresql+psycopg2://pdg:pdg_secret_2024@db:5432/prop_data_guard
- REDIS_URL=redis://redis:6379/0
- MINIO_ENDPOINT=minio:9000
- MINIO_ACCESS_KEY=pdgminio
- MINIO_SECRET_KEY=pdgminio_secret_2024
- SECRET_KEY=${SECRET_KEY:-prop-data-guard-production-secret-key}
- DB_ENCRYPTION_KEY=${DB_ENCRYPTION_KEY:-}
- ACCESS_TOKEN_EXPIRE_MINUTES=30
- REFRESH_TOKEN_EXPIRE_DAYS=7
expose:
- "8000"
depends_on:
db:
condition: service_healthy
redis:
condition: service_healthy
command: >
sh -c "alembic upgrade head && uvicorn app.main:app --host 0.0.0.0 --port 8000"
restart: unless-stopped
frontend:
build:
context: ./frontend
dockerfile: Dockerfile.prod
container_name: pdg-frontend
ports:
- "80:80"
- "443:443"
volumes:
- ./ssl:/etc/nginx/ssl:ro
depends_on:
- backend
restart: unless-stopped
celery_worker:
build: ./backend
container_name: pdg-celery-worker
environment:
- DATABASE_URL=postgresql+psycopg2://pdg:pdg_secret_2024@db:5432/prop_data_guard
- REDIS_URL=redis://redis:6379/0
- SECRET_KEY=${SECRET_KEY:-prop-data-guard-production-secret-key}
- DB_ENCRYPTION_KEY=${DB_ENCRYPTION_KEY:-}
depends_on:
- db
- redis
command: >
sh -c "celery -A app.tasks.worker worker --loglevel=info --concurrency=2"
restart: unless-stopped
celery_beat:
build: ./backend
container_name: pdg-celery-beat
environment:
- DATABASE_URL=postgresql+psycopg2://pdg:pdg_secret_2024@db:5432/prop_data_guard
- REDIS_URL=redis://redis:6379/0
- SECRET_KEY=${SECRET_KEY:-prop-data-guard-production-secret-key}
- DB_ENCRYPTION_KEY=${DB_ENCRYPTION_KEY:-}
depends_on:
- db
- redis
command: >
sh -c "celery -A app.tasks.worker beat --loglevel=info"
restart: unless-stopped
flower:
build: ./backend
container_name: pdg-flower
environment:
- REDIS_URL=redis://redis:6379/0
ports:
- "5555:5555"
depends_on:
- redis
- celery_worker
command: >
sh -c "celery -A app.tasks.worker flower --port=5555"
restart: unless-stopped
volumes:
pg_data:
redis_data:
minio_data:
+3
@@ -54,6 +54,7 @@ services:
       - MINIO_ACCESS_KEY=pdgminio
       - MINIO_SECRET_KEY=pdgminio_secret_2024
       - SECRET_KEY=prop-data-guard-super-secret-key-change-in-production
+      - DB_ENCRYPTION_KEY=${DB_ENCRYPTION_KEY:-}
       - ACCESS_TOKEN_EXPIRE_MINUTES=30
       - REFRESH_TOKEN_EXPIRE_DAYS=7
     volumes:
@@ -88,6 +89,7 @@ services:
       - DATABASE_URL=postgresql+psycopg2://pdg:pdg_secret_2024@db:5432/prop_data_guard
       - REDIS_URL=redis://redis:6379/0
       - SECRET_KEY=prop-data-guard-super-secret-key-change-in-production
+      - DB_ENCRYPTION_KEY=${DB_ENCRYPTION_KEY:-}
     volumes:
       - ./backend:/app
     depends_on:
@@ -103,6 +105,7 @@ services:
       - DATABASE_URL=postgresql+psycopg2://pdg:pdg_secret_2024@db:5432/prop_data_guard
       - REDIS_URL=redis://redis:6379/0
       - SECRET_KEY=prop-data-guard-super-secret-key-change-in-production
+      - DB_ENCRYPTION_KEY=${DB_ENCRYPTION_KEY:-}
     volumes:
       - ./backend:/app
     depends_on:
+354
@@ -0,0 +1,354 @@
# DataPointer Functional Architecture Diagrams
## 1. Overall Functional Architecture
```mermaid
flowchart TB
    %% Style definitions
    classDef userLayer fill:#e1f5fe,stroke:#01579b,stroke-width:2px,color:#000
    classDef appLayer fill:#e8f5e9,stroke:#2e7d32,stroke-width:2px,color:#000
    classDef securityLayer fill:#fff3e0,stroke:#ef6c00,stroke-width:2px,color:#000
    classDef engineLayer fill:#f3e5f5,stroke:#6a1b9a,stroke-width:2px,color:#000
    classDef dataLayer fill:#eceff1,stroke:#455a64,stroke-width:2px,color:#000
    classDef infraLayer fill:#fce4ec,stroke:#c2185b,stroke-width:2px,color:#000
    %% User layer
    subgraph L5["👤 User Layer"]
        U1[Data Security Admin]
        U2[Compliance Auditor]
        U3[Labeling Operator]
        U4[System Admin]
        U5[Business Analyst]
    end
    %% Presentation layer
    subgraph L4["🖥️ Presentation Layer"]
        V1[Data Asset Dashboard]
        V2[Classification Workbench]
        V3[Security Risk Cockpit]
        V4[Compliance Report Center]
        V5[Task & Ticket Board]
    end
    %% Business application layer
    subgraph L3["📦 Business Application Layer"]
        direction TB
        subgraph M1["Asset Management"]
            A1[Data Source Management]
            A2[Metadata Catalog]
            A3[Schema Change Tracking]
        end
        subgraph M2["Classification & Grading"]
            A4[Classification Projects]
            A5[Labeling Task Center]
            A6[Rules & Templates]
        end
        subgraph M3["Security Operations"]
            A7[Alert Rules]
            A8[Ticket Workflow]
            A9[API Asset Management]
        end
        subgraph M4["Reporting Center"]
            A10[Statistical Reports]
            A11[Report Export]
            A12[Operation Audit]
        end
    end
    %% Security capability layer
    subgraph L2["🔒 Security Capability Layer"]
        direction TB
        subgraph S1["Data Protection"]
            B1[Static Masking Engine]
            B2[Digital Watermark Tracing]
        end
        subgraph S2["Risk & Compliance"]
            B3[Risk Scoring Model]
            B4[Compliance Check Engine]
        end
        subgraph S3["Lineage Tracing"]
            B5[SQL Lineage Parsing]
            B6[Lineage Graph]
        end
    end
    %% Core engine layer
    subgraph L1["⚙️ Core Engine Layer"]
        direction TB
        subgraph E1["Collection & Recognition"]
            C1[Metadata Collection Engine]
            C2[Incremental Scan Engine]
            C3[Schema Diff Engine]
        end
        subgraph E2["Intelligent Classification"]
            C4[Rule Matching Engine]
            C5[Semantic Similarity Engine]
            C6[ML Classification Model]
        end
        subgraph E3["File Parsing"]
            C7[Word/Excel Parsing]
            C8[PDF Text Extraction]
            C9[Unstructured Recognition]
        end
    end
    %% Data source layer
    subgraph L0["🗄️ Data Source Layer"]
        D1[(PostgreSQL)]
        D2[(MySQL)]
        D3[(Oracle)]
        D4[(SQLServer)]
        D5[(Dameng DM)]
        D6[MinIO Object Storage]
        D7[Swagger/OpenAPI]
    end
    %% Infrastructure layer
    subgraph INF["☁️ Infrastructure Layer"]
        I1[(PostgreSQL 16)]
        I2[(Redis 7)]
        I3[MinIO]
        I4[Celery Worker]
        I5[Celery Beat]
    end
    %% Connections
    U1 --> V1 & V3
    U2 --> V4
    U3 --> V2 & V5
    U4 --> V1 & V4 & V5
    U5 --> V1 & V2
    V1 --> A1 & A2 & A10
    V2 --> A4 & A5 & A6
    V3 --> A7 & A8 & A10
    V4 --> A11 & A12
    V5 --> A5 & A8
    A1 --> C1
    A2 --> C1 & C2
    A3 --> C3
    A4 --> C4 & C5 & C6
    A5 --> C4
    A6 --> C4 & C5
    A7 --> B3 & B4
    A8 --> B3
    A9 --> C9
    A10 --> B3 & B4
    A11 --> B1 & B2
    B1 --> C1
    B2 --> C7 & C8
    B3 --> C4 & C6
    B4 --> C4 & C6
    B5 --> C1
    B6 --> B5
    C1 --> D1 & D2 & D3 & D4 & D5
    C7 --> D6
    C8 --> D6
    C9 --> D6 & D7
    C1 --> I1
    C2 --> I2
    C4 --> I4
    C6 --> I4
    B3 --> I5
    B4 --> I5
    %% Apply styles
    class U1,U2,U3,U4,U5 userLayer
    class V1,V2,V3,V4,V5 userLayer
    class A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11,A12 appLayer
    class B1,B2,B3,B4,B5,B6 securityLayer
    class C1,C2,C3,C4,C5,C6,C7,C8,C9 engineLayer
    class D1,D2,D3,D4,D5,D6,D7 dataLayer
    class I1,I2,I3,I4,I5 infraLayer
```
---
## 2. Data Flow Architecture
```mermaid
flowchart LR
    subgraph Input["📥 Data Input"]
        IN1[Relational Databases]
        IN2[Unstructured Files]
        IN3[API Documentation]
        IN4[SQL Scripts]
    end
    subgraph Processing["🔧 Core Processing"]
        P1[Metadata Collection]
        P2[Rule/ML Classification]
        P3[Risk Scoring]
        P4[Compliance Scanning]
    end
    subgraph Storage["💾 Data Storage"]
        S1[(PostgreSQL<br/>Business Data)]
        S2[(Redis<br/>Cache/Queue)]
        S3[MinIO<br/>File Storage]
    end
    subgraph Output["📤 Data Output"]
        OUT1[Classification Result Store]
        OUT2[Risk Score Reports]
        OUT3[Compliance Gap Analysis]
        OUT4[Lineage Graphs]
        OUT5[Masked/Watermarked Data]
    end
    IN1 --> P1 --> S1
    IN2 --> P2 --> S3
    IN3 --> P2 --> S1
    IN4 --> P4 --> S1
    S1 --> P2 --> OUT1
    S1 --> P3 --> OUT2
    S1 --> P4 --> OUT3
    S1 --> P4 --> OUT4
    S3 --> P4 --> OUT5
    S2 -.-> P1 & P2 & P3 & P4
```
---
## 3. Security Capability Closed Loop
```mermaid
flowchart TD
    A[Data Discovery] --> B[Automatic Classification & Grading]
    B --> C[Quantitative Risk Assessment]
    C --> D{Risk level?}
    D -->|High| E[Trigger Alert]
    D -->|Medium| F[Create Ticket]
    D -->|Low| G[Continuous Monitoring]
    E --> H[Data Masking]
    F --> H
    H --> I[Digital Watermarking]
    I --> J[Compliance Check]
    J --> K{Compliant?}
    K -->|No| L[Remediate & Optimize]
    L --> B
    K -->|Yes| M[Export Report]
    G --> N[Schema Change Tracking]
    N --> C
```
---
## 4. System Deployment Architecture
```mermaid
flowchart TB
    subgraph External["External Users"]
        USER[Browser/Client]
    end
    subgraph Access["Access Layer"]
        NGINX[Nginx<br/>Reverse Proxy / Load Balancer]
    end
    subgraph App["Application Layer"]
        FE[Vue 3 Frontend<br/>Static Assets]
        BE1[FastAPI Instance 1]
        BE2[FastAPI Instance 2]
        WORKER1[Celery Worker 1]
        WORKER2[Celery Worker 2]
        BEAT[Celery Beat]
        FLOWER[Flower Monitoring]
    end
    subgraph Data["Data Layer"]
        PG[(PostgreSQL<br/>Primary/Replica Cluster)]
        REDIS[(Redis<br/>Sentinel)]
        MINIO[MinIO<br/>Distributed Object Storage]
    end
    USER --> NGINX
    NGINX --> FE
    NGINX --> BE1 & BE2
    BE1 & BE2 --> PG & REDIS & MINIO
    BEAT --> REDIS
    WORKER1 & WORKER2 --> REDIS & PG & MINIO
    FLOWER --> REDIS
```
---
## 5. Core Business Flows
### 5.1 Data Classification and Grading Flow
```mermaid
sequenceDiagram
    actor Admin
    participant FE as Frontend
    participant API as Backend API
    participant Celery
    participant DB as Database
    Admin->>FE: Create classification project
    FE->>API: POST /projects
    API->>DB: Save project configuration
    DB-->>API: Return project ID
    API-->>FE: Created
    Admin->>FE: Start auto-classification
    FE->>API: POST /auto-classify/{id}
    API->>Celery: Enqueue async task
    Celery-->>API: Task ID
    API-->>FE: Started
    loop Progress polling
        FE->>API: GET /auto-classify-status
        API-->>FE: Current progress
    end
    Celery->>DB: Write classification results
    Celery-->>API: Task complete
    Admin->>FE: Review results, relabel manually
    FE->>API: PUT /results/{id}/label
    API->>DB: Update result status
```
### 5.2 Risk Alert and Ticket Flow
```mermaid
sequenceDiagram
    participant Timer
    participant Beat as CeleryBeat
    participant Engine as Alert Engine
    participant DB as Database
    actor Admin as Security Admin
    participant Ticket as Ticket System
    Timer->>Beat: Trigger daily scan
    Beat->>Engine: Recompute risk
    Engine->>DB: Query sensitive-data distribution
    Engine->>Engine: Compare against threshold rules
    alt Alert condition met
        Engine->>DB: Write alert record
        Engine-->>Admin: In-app notification
        Admin->>Ticket: Convert alert to ticket
        Ticket->>DB: Create ticket record
        Ticket-->>Admin: Assign handler
    else Not triggered
        Engine->>DB: Update risk score
    end
    Admin->>Ticket: Mark handled
    Ticket->>DB: Update ticket status
```
---
*DataPointer Functional Architecture Diagrams v1.0 | Prepared by the DataPointer project team*
+362
@@ -0,0 +1,362 @@
# DataPointer Data Security Grading and Risk Management Platform
## Product Whitepaper
> **Version**: v1.0
> **Date**: 2026-04-25
> **Positioning**: data classification/grading and data security governance platform for the property insurance industry
---
## 1. Product Overview
DataPointer is a data security classification, grading, and risk management platform built for the property insurance industry. Taking the Data Security Law, the Personal Information Protection Law (PIPL), and insurance regulatory requirements as its compliance baseline, it combines automated metadata collection, an intelligent classification and grading engine, a multi-dimensional quantitative risk model, and end-to-end data security capabilities to make an insurer's data assets visible, manageable, and controllable.
### 1.1 Core Value
| Value Dimension | Concrete Benefit |
|---------|---------|
| **Regulatory compliance** | Benchmarks against MLPS 2.0, PIPL, GDPR, and banking/insurance regulators' data requirements; generates compliance gap reports automatically |
| **Asset visibility** | Unified management of heterogeneous data sources; database/table/column lineage and sensitive-data distribution at a glance |
| **Intelligent classification** | Dual rule + ML engines grade tens of thousands of columns within minutes, at ≥ 70% accuracy |
| **Risk quantification** | Dynamic risk scoring from exposure and protection measures, with real-time awareness of sensitive-data changes |
| **Closed-loop security** | Masking, watermarking, alerting, and ticketing combine into a closed data security governance loop |
### 1.2 Typical Scenarios
- **Classification and grading governance**: sorting and grading sensitive data in core business, customer information, and finance systems
- **Regulatory compliance checks**: MLPS assessments, personal information protection audits, and pre-submission self-checks for regulatory reporting
- **Cross-border data assessment**: identifying sensitive and core data in outbound datasets and evaluating protection strength
- **Sensitive API governance**: scanning Swagger/OpenAPI to find endpoints that expose sensitive fields
- **Unstructured file control**: detecting sensitive information in Word, Excel, and PDF contracts and policies
---
## 2. Functional Architecture Overview
DataPointer adopts a "four layers, two domains" functional architecture:
- **Four layers**: data collection layer → core engine layer → security capability layer → business application layer
- **Two domains**: a management domain (configuration, approval, audit) plus an operations domain (classification, masking, risk control)
```mermaid
flowchart TB
    subgraph Presentation["Presentation Layer"]
        A1[Data Asset Dashboard]
        A2[Classification Workbench]
        A3[Security Risk Cockpit]
        A4[Compliance Report Center]
    end
    subgraph Application["Business Application Layer"]
        B1[Data Source Management]
        B2[Classification Projects]
        B3[Labeling Task Center]
        B4[Reports]
        B5[Alerts & Tickets]
        B6[API Assets]
    end
    subgraph Security["Security Capability Layer"]
        C1[Data Masking]
        C2[Digital Watermarking]
        C3[Compliance Checks]
        C4[Risk Scoring]
        C5[Lineage Analysis]
    end
    subgraph Engine["Core Engine Layer"]
        D1[Metadata Collection Engine]
        D2[Rule Classification Engine]
        D3[ML-Assisted Classification]
        D4[Schema Change Detection]
        D5[Unstructured Recognition]
    end
    subgraph Collection["Data Collection Layer"]
        E1[(PostgreSQL)]
        E2[(MySQL)]
        E3[(Oracle)]
        E4[(SQLServer)]
        E5[(Dameng DM)]
        E6[MinIO File Storage]
        E7[Swagger/OpenAPI]
    end
    Presentation --> Application
    Application --> Security
    Security --> Engine
    Engine --> Collection
```
---
## 3. Feature Modules in Detail
### 3.1 Data Asset Management
**Data source management**
- Registers data sources of many types: PostgreSQL, MySQL, Oracle, SQLServer, Dameng (DM), and more
- Connection passwords are stored Fernet-encrypted with an externally injected key, so they remain decryptable after restarts
- One-click connectivity test and real-time status monitoring (active/error)
**Metadata collection**
- Automatically collects metadata at three levels (database, table, column), including data type, length, comment, and sample data
- Incremental collection: syncs only what changed, based on `checksum` and `last_scanned_at`, reducing load on source databases
- Supports both full manual sync and scheduled automatic sync
**Schema change tracking**
- Incremental scans diff against historical metadata to detect added/dropped/modified columns
- Produces schema change logs; newly added sensitive columns are flagged automatically
- Filterable by data source and change type (ADD/MODIFY/DROP)
### 3.2 Data Classification and Grading
**Classification standard management**
- Built-in property insurance taxonomy: customer, policy, claim, finance, channel, regulatory reporting, and internal systems
- Five security levels: Public (L1), Internal (L2), Secret (L3), Confidential (L4), Core (L5)
- Custom category trees with color coding
**Recognition rule engine** (see the sketch below)
- Rule types: regex match, keyword containment, enumerated values, and semantic similarity (cosine similarity ≥ 0.75)
- Each rule binds a category and a level; multiple rules can combine on a hit
- Hot rule reload: changes take effect without a restart
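A minimal sketch of how such a rule engine can evaluate one column. The rule dicts and field names here are illustrative, not DataPointer's actual rule schema:
```python
import re

# Hypothetical rule records: regex rules fire on sample data, keyword rules on name/comment.
RULES = [
    {"name": "ID card number", "type": "regex", "pattern": r"^\d{17}[\dXx]$", "category": "customer", "level": "L4"},
    {"name": "Phone field", "type": "keyword", "keywords": ["phone", "mobile", "手机"], "category": "customer", "level": "L4"},
]

def match_column(name: str, comment: str, samples: list[str]) -> list[dict]:
    """Return every rule that hits this column, matching name, comment, and sample data."""
    hits = []
    text = f"{name} {comment}".lower()
    for rule in RULES:
        if rule["type"] == "regex" and any(re.match(rule["pattern"], s) for s in samples):
            hits.append(rule)
        elif rule["type"] == "keyword" and any(k.lower() in text for k in rule["keywords"]):
            hits.append(rule)
    return hits

print(match_column("mobile_phone", "手机号码", ["13800138000"]))  # keyword rule fires; the engine takes the highest level among hits
```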
**Automatic classification engine**
- Rule-based auto-labeling that matches on column name, comment, and sample data
- Executes asynchronously on Celery, so classifying tens of thousands of columns never blocks the HTTP API
- Real-time progress feedback: a frontend progress bar polls the backend `scan_progress`
**ML-assisted classification** (a minimal training sketch follows this list)
- TF-IDF feature engineering over column name / comment / sample_data
- Supports LogisticRegression / RandomForest training
- A `ml-suggest` endpoint lets the frontend adopt the recommended label and confidence in one click
- Versioned models (MLModelVersion) with rollback and A/B comparison
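The TF-IDF plus classifier approach can be sketched in a few lines of scikit-learn; the training texts and labels below are made up for illustration:
```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline

# Each sample concatenates column name, comment, and sample data; labels come from manual labeling.
texts = [
    "customer_name 客户姓名 张三",
    "mobile_phone 手机号码 13800138000",
    "premium_amount 保费金额 1280.50",
    "claim_amount 理赔金额 5000.00",
]
labels = ["customer", "customer", "finance", "finance"]

# Character n-grams handle mixed Chinese/English column metadata better than word tokens.
model = make_pipeline(
    TfidfVectorizer(analyzer="char_wb", ngram_range=(2, 4)),
    LogisticRegression(max_iter=1000),
)
model.fit(texts, labels)

# Rank categories by probability to produce top-k suggestions with confidence scores.
probs = model.predict_proba(["beneficiary_name 受益人姓名 李四"])[0]
for category, p in sorted(zip(model.classes_, probs), key=lambda x: -x[1]):
    print(category, round(p, 2))
```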
**Manual labeling and review**
- Project-based task assignment: create task → assign labeler → submit results → reviewer approval
- Supports single-labeler and multi-labeler cross-annotation modes
- Result status flow: auto → manual → reviewed → published
### 3.3 Data Security Protection
**Static data masking** (strategies illustrated below)
- Masking strategies: mask, truncate, hash, generalize, replace
- Policies bind to sensitivity levels and categories and can be applied in bulk
- Masking preview: side-by-side comparison of original and masked output before export
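The five strategies reduce to small transformations. This sketch uses hypothetical parameter defaults; the production engine derives them from the bound level and category:
```python
import hashlib

def mask(value: str, keep_head: int = 3, keep_tail: int = 4) -> str:
    """Keep head/tail characters and star out the middle: 13800138000 -> 138****8000."""
    if len(value) <= keep_head + keep_tail:
        return "*" * len(value)
    return value[:keep_head] + "*" * (len(value) - keep_head - keep_tail) + value[-keep_tail:]

def truncate(value: str, n: int = 6) -> str:
    return value[:n]

def hash_value(value: str) -> str:
    """One-way digest; preserves joinability without revealing the raw value."""
    return hashlib.sha256(value.encode()).hexdigest()[:16]

def generalize_age(age: int, bucket: int = 10) -> str:
    lo = age // bucket * bucket
    return f"{lo}-{lo + bucket - 1}"

def replace(value: str, placeholder: str = "<REDACTED>") -> str:
    return placeholder

print(mask("13800138000"))   # 138****8000
print(generalize_age(37))    # 30-39
```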
**Digital watermark tracing** (embedding sketch below)
- Text watermarks embed the user ID using zero-width spaces
- The watermark is invisible, does not affect readability, and survives copy-and-paste
- A tracing API extracts the watermark to pinpoint the source of a leak
- WatermarkLog records every export, forming an audit trail
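A minimal illustration of the zero-width-character idea: encode the user ID as bits, map them onto two invisible code points, and append them to the text. The actual payload layout and embedding positions in DataPointer may differ:
```python
ZW0, ZW1 = "\u200b", "\u200c"  # zero-width space / zero-width non-joiner

def embed(text: str, user_id: int) -> str:
    """Append a 32-bit user ID encoded as invisible characters."""
    bits = format(user_id, "032b")
    return text + "".join(ZW1 if b == "1" else ZW0 for b in bits)

def extract(text: str) -> int | None:
    """Recover the ID from any copy of the marked text."""
    bits = "".join("1" if ch == ZW1 else "0" for ch in text if ch in (ZW0, ZW1))
    return int(bits, 2) if len(bits) == 32 else None

marked = embed("保单号 PICC2024000001", user_id=42)
print(marked == "保单号 PICC2024000001", extract(marked))  # False (payload present), 42
```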
### 3.4 Risk Management and Compliance
**Risk scoring model** (worked example below)
- Scoring formula: `RiskScore = Σ(Lᵢ × exposure × (1 − protection_rate))`
- Four aggregation levels: global risk → data source risk → database risk → table risk
- Top N risk ranking, with risk trends shown on the dashboard in real time
- Celery Beat recomputes daily; unmasked sensitive columns automatically push the score up
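A worked example of the formula. The level weights and exposure values here are illustrative assumptions, not the platform's calibrated parameters:
```python
# Illustrative weights: each level contributes exponentially more risk than the one below it.
LEVEL_WEIGHT = {"L1": 1, "L2": 2, "L3": 4, "L4": 8, "L5": 16}

def table_risk(columns: list[dict]) -> float:
    """RiskScore = sum(Li * exposure * (1 - protection_rate)) over a table's sensitive columns."""
    return sum(
        LEVEL_WEIGHT[c["level"]] * c["exposure"] * (1 - c["protection_rate"])
        for c in columns
    )

columns = [
    {"level": "L5", "exposure": 1.0, "protection_rate": 0.0},  # unmasked core data -> 16.0
    {"level": "L4", "exposure": 0.5, "protection_rate": 0.8},  # partially protected -> 0.8
]
print(table_risk(columns))  # 16.8: a single unmasked L5 column dominates the table's score
```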
**Compliance check engine**
- Built-in rule libraries: MLPS 2.0, PIPL (Personal Information Protection Law), and GDPR
- A pluggable checker base class supports custom compliance rules
- Automatic scans produce an issue list: unmasked L5 columns, missing classifications, unapproved schema changes, and so on
- Issues can be exported as a compliance gap analysis report
**Intelligent alerts and tickets** (see the evaluation sketch after this list)
- Alert rule configuration: count of newly added sensitive columns, risk score thresholds, schema change types
- Alert record management with an unread/read/handled status flow
- One-click conversion: alert → ticket → assignment → handling → closure
- Ticket statuses: open → in_progress → resolved, with resolution notes
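Threshold-style alert evaluation reduces to comparing scan metrics against configured rules; the rule and metric shapes below are assumptions for illustration:
```python
from dataclasses import dataclass

@dataclass
class AlertRule:
    metric: str        # e.g. "new_sensitive_columns" or "risk_score"
    threshold: float
    message: str

def evaluate(rules: list[AlertRule], metrics: dict[str, float]) -> list[str]:
    """Return the messages of every rule whose metric reached its threshold."""
    return [r.message for r in rules if metrics.get(r.metric, 0) >= r.threshold]

rules = [
    AlertRule("new_sensitive_columns", 10, "10+ new sensitive columns since the last scan"),
    AlertRule("risk_score", 80, "data source risk score exceeded 80"),
]
print(evaluate(rules, {"new_sensitive_columns": 12, "risk_score": 55}))
# ['10+ new sensitive columns since the last scan']
```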
### 3.5 Data Lineage Analysis
- Parses SQL scripts (ETL jobs, stored procedures) with `sqlparse` to extract table-level lineage
- Covers common syntax such as INSERT, CREATE TABLE AS, and MERGE
- Frontend ECharts relationship graph with up to three levels of upstream/downstream expansion
- Lineage records are persisted and queryable by table name for full-chain impact analysis (a simplified extraction sketch follows)
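A deliberately simplified `sqlparse` walk that recovers lineage from a plain INSERT ... SELECT; qualified names, subqueries, and MERGE need more care than this sketch takes:
```python
import sqlparse
from sqlparse.tokens import Keyword, Name

def table_lineage(sql: str) -> tuple[str | None, list[str]]:
    """Return (target_table, source_tables) for a simple INSERT ... SELECT statement."""
    target, sources, mode = None, [], None
    for tok in sqlparse.parse(sql)[0].flatten():
        if tok.ttype is Keyword and tok.normalized in ("INTO", "TABLE"):
            mode = "target"
        elif tok.ttype is Keyword and (tok.normalized == "FROM" or tok.normalized.endswith("JOIN")):
            mode = "source"
        elif tok.ttype is Name and mode:
            if mode == "target" and target is None:
                target = tok.value
            elif mode == "source":
                sources.append(tok.value)
            mode = None  # only take the first name after each keyword
    return target, sources

sql = "INSERT INTO t_policy_agg SELECT * FROM t_policy p JOIN t_customer c ON p.cust_id = c.id"
print(table_lineage(sql))  # ('t_policy_agg', ['t_policy', 't_customer'])
```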
### 3.6 API Asset Security Management
- API asset registration: name, base URL, Swagger address, authentication method
- Automatic Swagger/OpenAPI scanning that parses endpoints, parameters, and response schemas
- A rule engine flags sensitive endpoints: interfaces exposing fields such as phone, idCard, or bankCard are highlighted automatically
- Endpoint-level risk grading and list export (scan sketch below)
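The spec scan boils down to walking the parsed OpenAPI document and testing parameter names against a sensitive-keyword list; both the keyword set and the spec fragment here are illustrative:
```python
SENSITIVE_KEYS = {"phone", "mobile", "idcard", "id_card", "bankcard", "bank_card"}

def scan_openapi(spec: dict) -> list[tuple[str, str, str]]:
    """Return (METHOD, path, parameter) triples whose parameter name looks sensitive."""
    findings = []
    for path, methods in spec.get("paths", {}).items():
        for method, operation in methods.items():
            for param in operation.get("parameters", []):
                name = param.get("name", "").lower()
                if any(key in name for key in SENSITIVE_KEYS):
                    findings.append((method.upper(), path, param["name"]))
    return findings

spec = {"paths": {"/customers": {"get": {"parameters": [{"name": "idCard", "in": "query"}]}}}}
print(scan_openapi(spec))  # [('GET', '/customers', 'idCard')]
```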
### 3.7 Unstructured Data Recognition
- Upload of Word, Excel, PDF, and TXT files
- Files are stored in MinIO; the extracted text is fed into the rule engine
- Recognition output: matched rule name, sensitive category, security level, and text-fragment location
- Reprocessing and result review are supported (see the Word-parsing sketch below)
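For Word files the text-extraction step can use python-docx, as sketched here with a single regex standing in for the full rule engine; PDF and Excel each need their own extractor:
```python
import re
from docx import Document  # pip install python-docx

ID_CARD = re.compile(r"\d{17}[\dXx]")  # mainland ID number, one example rule

def scan_docx(path: str) -> list[tuple[int, str]]:
    """Extract paragraph text from a .docx file and report rule hits with paragraph indexes."""
    hits = []
    for i, paragraph in enumerate(Document(path).paragraphs):
        for match in ID_CARD.finditer(paragraph.text):
            hits.append((i, match.group()))
    return hits

# Hypothetical usage: scan_docx("contract.docx") -> [(12, '110101199001011234'), ...]
```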
### 3.8 Data Asset Dashboard
**Key metric cards**
- Totals for data sources, tables, and columns
- Classified columns, sensitive columns, and total projects
**Visualizations**
- Level distribution pie chart: share of L1–L5 columns
- Top 8 categories horizontal bar chart
- Data source × level heat matrix
- Project progress Gantt chart
- Risk trend line chart
**Report export**
- Word, Excel, and PDF formats
- Reports cover the project overview, level distribution, Top 20 sensitive-data list, and a compliance summary
---
## 4. Technical Architecture
### 4.1 Overall Stack
```mermaid
flowchart LR
    subgraph FE["Frontend"]
        F1[Vue 3]
        F2[Vite]
        F3[Element Plus]
        F4[ECharts]
        F5[Pinia]
    end
    subgraph BE["Backend"]
        B1[FastAPI]
        B2[SQLAlchemy 2.0]
        B3[Pydantic v2]
        B4[Celery]
    end
    subgraph DATA["Data Layer"]
        D1[(PostgreSQL 16)]
        D2[(Redis 7)]
        D3[MinIO]
    end
    subgraph ALGO["Algorithm Layer"]
        A1[scikit-learn]
        A2[sqlparse]
        A3[TfidfVectorizer]
    end
    FE -->|HTTP /api/v1| BE
    BE -->|SQL| DATA
    BE -->|Task Queue| D2
    BE -->|Object Storage| D3
    BE -->|ML / Parsing| ALGO
```
### 4.2 Backend Services
| Service | Technology | Responsibility |
|------|------|------|
| Web API | FastAPI + Uvicorn | RESTful API service with JWT authentication |
| Celery Worker | Celery + Redis | Async classification, ML training, risk recomputation |
| Celery Beat | Celery + Redis | Scheduled jobs (risk recomputation, compliance scans) |
| Flower | Celery Monitor | Task monitoring and visualization |
### 4.3 Data Model Design
The platform spans **31 business tables**; the core entity relationships are:
```mermaid
erDiagram
    DATA_SOURCE ||--o{ META_DATABASE : contains
    META_DATABASE ||--o{ META_TABLE : contains
    META_TABLE ||--o{ META_COLUMN : contains
    CLASSIFICATION_PROJECT ||--o{ CLASSIFICATION_TASK : has
    CLASSIFICATION_PROJECT ||--o{ CLASSIFICATION_RESULT : produces
    META_COLUMN ||--o{ CLASSIFICATION_RESULT : classified_as
    CATEGORY ||--o{ CLASSIFICATION_RESULT : belongs_to
    DATA_LEVEL ||--o{ CLASSIFICATION_RESULT : rated_as
    SYS_USER ||--o{ CLASSIFICATION_TASK : assigned_to
    DATA_SOURCE ||--o{ RISK_ASSESSMENT : assessed
    CLASSIFICATION_RESULT ||--o{ COMPLIANCE_ISSUE : generates
    ALERT_RECORD ||--o{ WORK_ORDER : converts_to
    API_ASSET ||--o{ API_ENDPOINT : has
```
---
## 5. Deployment
### 5.1 Single-Node Docker Compose (recommended for trials)
```yaml
# One command brings up all 8 services (illustrative shorthand, not a literal compose file)
services:
  db: postgres:16-alpine
  redis: redis:7-alpine
  minio: minio/minio
  backend: FastAPI + automatic Alembic migrations
  frontend: Vue 3 Vite dev server
  celery_worker: Celery worker (concurrency=2)
  celery_beat: Celery Beat scheduler
  flower: Celery monitoring dashboard
```
### 5.2 Minimum Production Sizing
| Resource | Minimum | Recommended |
|------|---------|---------|
| CPU | 4 cores | 8 cores |
| RAM | 8 GB | 16 GB |
| Disk | 100 GB SSD | 500 GB SSD |
| Network | 5 Mbps | 10 Mbps |
### 5.3 High-Availability Recommendations
- **Database**: PostgreSQL primary/replica plus connection pooling (PgBouncer)
- **Cache/queue**: Redis Sentinel or Redis Cluster
- **Object storage**: distributed MinIO cluster
- **Application tier**: multiple FastAPI instances behind Nginx load balancing
- **Frontend**: static assets hosted on a CDN
---
## 6. Security Design
### 6.1 Authentication and Authorization
- Dual-token JWT scheme: access token plus refresh token (see the sketch below)
- RBAC role control: superadmin / admin / project manager / labeler / reviewer / guest
- Data isolation: non-admins see only the projects and tasks they created or participate in
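How a dual-token issue/verify cycle looks with PyJWT, as a sketch; the claim names, lifetimes, and secret are placeholders rather than the platform's actual configuration:
```python
from datetime import datetime, timedelta, timezone

import jwt  # pip install PyJWT

SECRET = "change-me"  # placeholder; injected via environment variable in practice

def issue_tokens(user_id: int) -> dict:
    """Short-lived access token for API calls, long-lived refresh token to mint new ones."""
    now = datetime.now(timezone.utc)

    def make(kind: str, ttl: timedelta) -> str:
        return jwt.encode(
            {"sub": str(user_id), "type": kind, "exp": now + ttl}, SECRET, algorithm="HS256"
        )

    return {
        "access_token": make("access", timedelta(minutes=30)),
        "refresh_token": make("refresh", timedelta(days=7)),
    }

def verify(token: str) -> dict:
    # Raises jwt.ExpiredSignatureError / jwt.InvalidTokenError on failure.
    return jwt.decode(token, SECRET, algorithms=["HS256"])

tokens = issue_tokens(42)
print(verify(tokens["access_token"])["sub"])  # "42"
```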
### 6.2 Data Security
- Data source passwords stored Fernet-encrypted, with the key injected externally (`DB_ENCRYPTION_KEY`); usage sketched below
- SSL support for database connections
- Operation audit logging: user, module, action, IP, and latency per request
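The Fernet flow with the `cryptography` package looks roughly like this; generating a fallback key inline is only for the sketch, since production reads `DB_ENCRYPTION_KEY` from the environment:
```python
import os

from cryptography.fernet import Fernet

# Generate the key once with Fernet.generate_key(), keep it out of the repository,
# and inject it through the DB_ENCRYPTION_KEY environment variable.
key = os.environ.get("DB_ENCRYPTION_KEY") or Fernet.generate_key()
fernet = Fernet(key)

token = fernet.encrypt(b"datasource-password")  # stored in the encrypted_password column
print(fernet.decrypt(token))                    # b'datasource-password' after restart, same key
```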
### 6.3 Deployment Security
- Sensitive settings isolated in `.env` environment variables and kept out of the code repository
- Minimal Docker images (python:3.12-slim, node:20-alpine)
- CORS whitelist restricted to the frontend domain
---
## 7. Milestones and Roadmap
| Phase | Duration | Core Goal | Key Deliverables |
|------|------|---------|---------|
| Phase 1 | 4 weeks | Core engine hardening + intelligence | Password encryption fix, Celery async classification, ML-assist prototype, semantic similarity, incremental collection |
| Phase 2 | 5 weeks | Security capability build-out + UX upgrade | Static masking, digital watermarking, Excel/PDF reports, Dameng driver, unstructured recognition, schema change tracking |
| Phase 3 | 6 weeks | Risk management + compliance + lineage | Risk scoring, compliance engine, lineage analysis, alerts and tickets, API asset scanning, dark mode |
**About 89 person-days in total**; two developers working in parallel can compress this to roughly 2 months.
---
## 8. Summary
Starting from data classification and grading, DataPointer builds an end-to-end data security governance platform covering **data collection → intelligent classification → security protection → risk and compliance → lineage tracing**. Built on a modern frontend/backend stack with one-command containerized deployment, it is highly extensible and customizable, and meets the data security compliance needs of property insurers undergoing digital transformation.
---
*DataPointer Product Whitepaper v1.0 | Prepared by the DataPointer project team*
+21
@@ -0,0 +1,21 @@
# Build stage
FROM node:20-alpine AS builder
WORKDIR /app
RUN npm install -g npm@10
COPY package.json .
RUN npm install
COPY . .
RUN npm run build
# Production stage
FROM nginx:alpine
COPY --from=builder /app/dist /usr/share/nginx/html
COPY nginx.conf /etc/nginx/conf.d/default.conf
EXPOSE 80
CMD ["nginx", "-g", "daemon off;"]
-1
@@ -1 +0,0 @@
.page-title[data-v-0ac8aaa8]{font-size:20px;font-weight:600;margin-bottom:20px;color:#303133}.section[data-v-0ac8aaa8]{margin-bottom:24px}.section .section-header[data-v-0ac8aaa8]{display:flex;align-items:center;justify-content:space-between;margin-bottom:12px}.section .section-title[data-v-0ac8aaa8]{font-size:16px;font-weight:600;color:#303133}.level-card[data-v-0ac8aaa8]{padding:16px;margin-bottom:12px;background:#fff;border-radius:8px}.level-card .level-header[data-v-0ac8aaa8]{display:flex;align-items:center;gap:10px;margin-bottom:8px}.level-card .level-name[data-v-0ac8aaa8]{font-size:15px;font-weight:600;color:#303133}.level-card .level-desc[data-v-0ac8aaa8]{font-size:13px;color:#606266;line-height:1.5;margin-bottom:8px}.level-card .level-ctrl .ctrl-item[data-v-0ac8aaa8]{font-size:12px;color:#909399;margin-bottom:2px}.level-card .level-ctrl .ctrl-item .ctrl-key[data-v-0ac8aaa8]{font-weight:500;color:#606266}.category-tree[data-v-0ac8aaa8]{padding:16px;background:#fff;border-radius:8px}.category-tree .custom-tree-node[data-v-0ac8aaa8]{display:flex;align-items:center;justify-content:space-between;flex:1;overflow:hidden}.category-tree .custom-tree-node .node-label[data-v-0ac8aaa8]{display:flex;align-items:center;gap:8px;overflow:hidden}.category-tree .custom-tree-node .node-label .code-tag[data-v-0ac8aaa8]{font-size:11px;flex-shrink:0}.category-tree .custom-tree-node .node-actions[data-v-0ac8aaa8]{flex-shrink:0}.table-card[data-v-0ac8aaa8]{padding:16px;background:#fff;border-radius:8px}
-1
@@ -1 +0,0 @@
.page-header[data-v-e577ddaa]{display:flex;align-items:center;justify-content:space-between;margin-bottom:16px}.page-header .page-title[data-v-e577ddaa]{font-size:20px;font-weight:600;color:#303133}.search-bar[data-v-e577ddaa]{padding:16px;margin-bottom:16px}.table-card[data-v-e577ddaa]{padding:16px}.pagination-bar[data-v-e577ddaa]{display:flex;justify-content:flex-end;margin-top:16px}
-1
@@ -1 +0,0 @@
.layout-container[data-v-6b05a74f]{height:100vh}.layout-aside[data-v-6b05a74f]{background-color:#1a2b4a;display:flex;flex-direction:column}.logo[data-v-6b05a74f]{height:56px;display:flex;align-items:center;justify-content:center;gap:10px;background-color:#13203a;flex-shrink:0}.logo .logo-text[data-v-6b05a74f]{color:#fff;font-size:16px;font-weight:600;letter-spacing:1px}.layout-menu[data-v-6b05a74f]{border-right:none;flex:1}.layout-header[data-v-6b05a74f]{background-color:#fff;display:flex;align-items:center;justify-content:space-between;box-shadow:0 1px 4px #0000000d;padding:0 16px}.layout-header .header-left[data-v-6b05a74f]{display:flex;align-items:center;gap:12px}.layout-header .header-title[data-v-6b05a74f]{font-size:16px;font-weight:600;color:#303133}.layout-header .header-right .user-info[data-v-6b05a74f]{display:flex;align-items:center;gap:8px;cursor:pointer;color:#606266}.layout-header .header-right .user-info .username[data-v-6b05a74f]{font-size:14px;max-width:100px;overflow:hidden;text-overflow:ellipsis;white-space:nowrap}.layout-main[data-v-6b05a74f]{background-color:#f5f7fa;padding:0;overflow-y:auto}[data-v-6b05a74f] .mobile-drawer .el-drawer__body{padding:0;background-color:#1a2b4a;display:flex;flex-direction:column}
-1
@@ -1 +0,0 @@
.login-page[data-v-8c2e034d]{min-height:100vh;display:flex;align-items:center;justify-content:center;background:linear-gradient(135deg,#1a2b4a,#2d4a7c);padding:16px}.login-box[data-v-8c2e034d]{width:100%;max-width:420px;padding:40px 32px;background:#fff;border-radius:12px}.login-header[data-v-8c2e034d]{text-align:center;margin-bottom:32px}.login-header .login-title[data-v-8c2e034d]{font-size:24px;font-weight:700;color:#1a2b4a;margin-top:12px}.login-header .login-subtitle[data-v-8c2e034d]{font-size:14px;color:#909399;margin-top:8px}.login-btn[data-v-8c2e034d]{width:100%;height:44px;font-size:16px;border-radius:6px}.login-footer[data-v-8c2e034d]{margin-top:24px;text-align:center;font-size:12px;color:#c0c4cc}
-1
@@ -1 +0,0 @@
.metadata-page .page-title[data-v-866ea4fc]{font-size:20px;font-weight:600;margin-bottom:16px;color:#303133}.content-row .el-col[data-v-866ea4fc]{margin-bottom:16px}.tree-card[data-v-866ea4fc]{padding:16px;height:calc(100vh - 140px);overflow-y:auto}.tree-card .tree-header[data-v-866ea4fc]{font-size:16px;font-weight:600;margin-bottom:12px;color:#303133}.tree-card .tree-search[data-v-866ea4fc]{margin-bottom:12px}.custom-tree-node[data-v-866ea4fc]{display:flex;align-items:center;gap:6px;flex:1;overflow:hidden}.custom-tree-node .node-label[data-v-866ea4fc]{flex:1;overflow:hidden;text-overflow:ellipsis;white-space:nowrap}.custom-tree-node .node-badge[data-v-866ea4fc]{font-size:11px;color:#909399;background:#f2f6fc;padding:0 6px;border-radius:10px}.detail-card[data-v-866ea4fc]{padding:16px;height:calc(100vh - 140px);display:flex;flex-direction:column}.detail-card .detail-header[data-v-866ea4fc]{display:flex;align-items:center;justify-content:space-between;margin-bottom:16px;flex-wrap:wrap;gap:12px}.detail-card .detail-header .detail-title[data-v-866ea4fc]{display:flex;align-items:center;gap:10px;font-size:16px;font-weight:600;color:#303133}.detail-card .detail-header .detail-title .placeholder[data-v-866ea4fc]{color:#909399;font-weight:400}.sample-text[data-v-866ea4fc]{color:#909399;font-size:12px}
-1
@@ -1 +0,0 @@
.page-header[data-v-5fcafe59]{display:flex;align-items:center;justify-content:space-between;margin-bottom:16px}.page-header .page-title[data-v-5fcafe59]{font-size:20px;font-weight:600;color:#303133}.search-bar[data-v-5fcafe59]{padding:16px;margin-bottom:16px}.project-list .el-col[data-v-5fcafe59]{margin-bottom:16px}.project-card[data-v-5fcafe59]{padding:20px;background:#fff;border-radius:8px;transition:box-shadow .2s}.project-card[data-v-5fcafe59]:hover{box-shadow:0 4px 16px #00000014}.project-card .project-header[data-v-5fcafe59]{display:flex;align-items:center;justify-content:space-between;margin-bottom:8px}.project-card .project-header .project-name[data-v-5fcafe59]{font-size:16px;font-weight:600;color:#303133}.project-card .project-desc[data-v-5fcafe59]{font-size:13px;color:#909399;margin-bottom:16px;min-height:20px}.project-card .project-stats[data-v-5fcafe59]{display:flex;gap:16px;margin-bottom:16px;padding:12px 0;border-top:1px solid #f0f0f0;border-bottom:1px solid #f0f0f0}.project-card .project-stats .stat-item[data-v-5fcafe59]{text-align:center;flex:1}.project-card .project-stats .stat-item .stat-num[data-v-5fcafe59]{font-size:18px;font-weight:700;color:#303133}.project-card .project-stats .stat-item .stat-label[data-v-5fcafe59]{font-size:12px;color:#909399;margin-top:2px}.project-card .project-actions[data-v-5fcafe59]{display:flex;justify-content:flex-end;gap:8px}
-1
@@ -1 +0,0 @@
.page-title[data-v-d05c8e34]{font-size:20px;font-weight:600;margin-bottom:16px;color:#303133}.system-tabs[data-v-d05c8e34]{background:#fff;padding:16px;border-radius:8px}.table-card[data-v-d05c8e34]{padding:16px}.table-card .table-header[data-v-d05c8e34]{display:flex;align-items:center;gap:12px;margin-bottom:16px}
File diff suppressed because one or more lines are too long
+3 -3
@@ -4,9 +4,9 @@
   <meta charset="UTF-8" />
   <link rel="icon" type="image/svg+xml" href="/vite.svg" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=no" />
-  <title>PropDataGuard - 财险数据分级分类平台</title>
-  <script type="module" crossorigin src="/assets/index-DveMB2K5.js"></script>
-  <link rel="stylesheet" crossorigin href="/assets/index-s_XEM0GP.css">
+  <title>DataPointer - 数据分类分级管理平台</title>
+  <script type="module" crossorigin src="/assets/index-B2ZsjZSQ.js"></script>
+  <link rel="stylesheet" crossorigin href="/assets/index-CdImMPt_.css">
 </head>
 <body>
   <div id="app"></div>
+1 -1
@@ -4,7 +4,7 @@
   <meta charset="UTF-8" />
   <link rel="icon" type="image/svg+xml" href="/vite.svg" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=no" />
-  <title>PropDataGuard - 财险数据分级分类平台</title>
+  <title>DataPointer - 数据分类分级管理平台</title>
 </head>
 <body>
   <div id="app"></div>
+56
@@ -0,0 +1,56 @@
# HTTP redirect to HTTPS
server {
listen 80;
server_name datapointer.cnroc.cn localhost _;
return 301 https://$host$request_uri;
}
server {
listen 443 ssl;
server_name datapointer.cnroc.cn localhost _;
root /usr/share/nginx/html;
index index.html;
# SSL certificates
ssl_certificate /etc/nginx/ssl/fullchain.pem;
ssl_certificate_key /etc/nginx/ssl/privkey.pem;
ssl_protocols TLSv1.2 TLSv1.3;
ssl_ciphers HIGH:!aNULL:!MD5;
ssl_prefer_server_ciphers on;
# Gzip compression
gzip on;
gzip_vary on;
gzip_min_length 1024;
gzip_types text/plain text/css application/json application/javascript text/xml application/xml application/xml+rss text/javascript;
# Frontend static files
location / {
try_files $uri $uri/ /index.html;
}
# API proxy to backend
location /api/ {
proxy_pass http://backend:8000/api/;
proxy_http_version 1.1;
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
proxy_connect_timeout 30s;
proxy_send_timeout 30s;
proxy_read_timeout 30s;
}
# Health check endpoint
location /health {
proxy_pass http://backend:8000/health;
}
# Cache static assets + CORS for crossorigin attribute
location ~* \.(js|css|png|jpg|jpeg|gif|ico|svg|woff|woff2|ttf|eot)$ {
expires 30d;
add_header Cache-Control "public, immutable";
add_header Access-Control-Allow-Origin * always;
}
}
+2 -2
@@ -1,11 +1,11 @@
 {
-  "name": "prop-data-guard-frontend",
+  "name": "data-pointer-frontend",
   "version": "0.1.0",
   "lockfileVersion": 3,
   "requires": true,
   "packages": {
     "": {
-      "name": "prop-data-guard-frontend",
+      "name": "data-pointer-frontend",
       "version": "0.1.0",
       "dependencies": {
         "@element-plus/icons-vue": "^2.3.1",
+1 -1
@@ -1,5 +1,5 @@
 {
-  "name": "prop-data-guard-frontend",
+  "name": "data-pointer-frontend",
   "private": true,
   "version": "0.1.0",
   "type": "module",
+69
@@ -0,0 +1,69 @@
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<title>报告预览</title>
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; margin: 40px; color: #333; background: #f5f7fa; }
.page { max-width: 800px; margin: 0 auto; background: #fff; padding: 48px; box-shadow: 0 2px 12px rgba(0,0,0,0.1); }
h1 { text-align: center; font-size: 24px; margin-bottom: 32px; }
h2 { font-size: 16px; border-left: 4px solid #409eff; padding-left: 8px; margin-top: 24px; }
table { width: 100%; border-collapse: collapse; margin-top: 12px; }
th, td { border: 1px solid #dcdfe6; padding: 8px 12px; text-align: left; }
th { background: #f5f7fa; }
.highlight { background: #fde2e2; }
.info { color: #909399; font-size: 12px; margin-top: 4px; }
.print-btn { position: fixed; top: 20px; right: 20px; padding: 8px 16px; background: #409eff; color: #fff; border: none; border-radius: 4px; cursor: pointer; }
@media print { .print-btn { display: none; } body { margin: 0; background: #fff; } .page { box-shadow: none; } }
</style>
</head>
<body>
<button class="print-btn" onclick="window.print()">打印 / 保存为PDF</button>
<div class="page">
<h1>数据分类分级项目报告</h1>
<div id="content">加载中...</div>
</div>
<script>
const params = new URLSearchParams(location.search);
const projectId = params.get('projectId');
const apiBase = '/api/v1';
async function load() {
if (!projectId) { document.getElementById('content').innerText = '缺少项目ID'; return; }
try {
const res = await fetch(`${apiBase}/reports/projects/${projectId}/summary`);
const json = await res.json();
const d = json.data;
let html = '';
html += '<h2>一、项目基本信息</h2>';
html += '<table><tr><th>项目名称</th><td>' + (d.project_name || '') + '</td></tr>';
html += '<tr><th>报告生成时间</th><td>' + (d.generated_at || '').slice(0,19).replace('T',' ') + '</td></tr>';
html += '<tr><th>项目状态</th><td>' + (d.status || '') + '</td></tr>';
html += '<tr><th>模板版本</th><td>' + (d.template_version || '') + '</td></tr></table>';
html += '<h2>二、分类分级统计</h2>';
html += '<p>总字段数: ' + d.total + ' | 自动识别: ' + d.auto + ' | 人工打标: ' + d.manual + '</p>';
html += '<h2>三、分级分布</h2><table><tr><th>分级</th><th>数量</th><th>占比</th></tr>';
d.level_distribution.forEach(item => {
const pct = d.total ? (item.count / d.total * 100).toFixed(1) + '%' : '0%';
const cls = item.name.includes('L4') || item.name.includes('L5') ? 'highlight' : '';
html += '<tr class="' + cls + '"><td>' + item.name + '</td><td>' + item.count + '</td><td>' + pct + '</td></tr>';
});
html += '</table>';
html += '<h2>四、高敏感数据清单(L4/L5)</h2>';
if (d.high_risk && d.high_risk.length) {
html += '<table><tr><th>字段名</th><th>所属表</th><th>分类</th><th>分级</th><th>来源</th></tr>';
d.high_risk.forEach(r => {
html += '<tr class="highlight"><td>' + r.column_name + '</td><td>' + r.table_name + '</td><td>' + r.category_name + '</td><td>' + r.level_name + '</td><td>' + r.source + '</td></tr>';
});
html += '</table>';
} else {
html += '<p>暂无L4/L5级高敏感数据。</p>';
}
document.getElementById('content').innerHTML = html;
} catch (e) {
document.getElementById('content').innerText = '加载失败: ' + e.message;
}
}
load();
</script>
</body>
</html>
+24
@@ -107,3 +107,27 @@ export function getClassificationResults(params: {
}) {
  return request.get('/classifications/results', { params })
}
export interface MLSuggestion {
column_id: number
column_name: string
table_name?: string
suggestions: {
category_id: number
category_name?: string
category_code?: string
confidence: number
}[]
}
export function getMLSuggestions(project_id: number, column_ids?: number[], top_k: number = 3) {
const params: any = { top_k }
if (column_ids && column_ids.length) {
params.column_ids = column_ids.join(',')
}
return request.get(`/classifications/ml-suggest/${project_id}`, { params })
}
export function trainMLModel(background: boolean = true, model_name?: string, algorithm: string = 'logistic_regression') {
return request.post('/classifications/ml-train', null, { params: { background, model_name, algorithm } })
}
+17
@@ -0,0 +1,17 @@
import request from './request'
export function initComplianceRules() {
return request.post('/compliance/init-rules')
}
export function scanCompliance(projectId?: number) {
return request.post('/compliance/scan', null, { params: projectId ? { project_id: projectId } : undefined })
}
export function getComplianceIssues(params?: { project_id?: number; status?: string; page?: number; page_size?: number }) {
return request.get('/compliance/issues', { params })
}
export function resolveIssue(id: number) {
return request.post(`/compliance/issues/${id}/resolve`)
}
+9
@@ -0,0 +1,9 @@
import request from './request'
export function parseLineage(sql: string, targetTable: string) {
return request.post('/lineage/parse', null, { params: { sql, target_table: targetTable } })
}
export function getLineageGraph(tableName?: string) {
return request.get('/lineage/graph', { params: tableName ? { table_name: tableName } : undefined })
}
+34
@@ -0,0 +1,34 @@
+import request from './request'
+
+export interface MaskingRuleItem {
+  id: number
+  name: string
+  level_id?: number
+  category_id?: number
+  algorithm: string
+  params?: Record<string, any>
+  is_active: boolean
+  description?: string
+  level_name?: string
+  category_name?: string
+}
+
+export function getMaskingRules(params?: { level_id?: number; category_id?: number; page?: number; page_size?: number }) {
+  return request.get('/masking/rules', { params })
+}
+
+export function createMaskingRule(data: Partial<MaskingRuleItem>) {
+  return request.post('/masking/rules', data)
+}
+
+export function updateMaskingRule(id: number, data: Partial<MaskingRuleItem>) {
+  return request.put(`/masking/rules/${id}`, data)
+}
+
+export function deleteMaskingRule(id: number) {
+  return request.delete(`/masking/rules/${id}`)
+}
+
+export function previewMasking(source_id: number, table_name: string, project_id?: number, limit: number = 20) {
+  return request.post('/masking/preview', null, { params: { source_id, table_name, project_id, limit } })
+}
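A sketch of creating a rule and previewing its effect. The algorithm name 'mask_middle', its params, the table name, and the module path are hypothetical; only the signatures and MaskingRuleItem fields above are confirmed:

// Hypothetical sketch: create a masking rule, then preview it against sample rows.
import { createMaskingRule, previewMasking, type MaskingRuleItem } from './masking'

async function maskPhoneNumbers(levelId: number, sourceId: number) {
  const rule: Partial<MaskingRuleItem> = {
    name: '手机号脱敏',                 // mobile-number masking
    level_id: levelId,
    algorithm: 'mask_middle',           // algorithm name is an assumption
    params: { keep_head: 3, keep_tail: 4 },
    is_active: true,
  }
  await createMaskingRule(rule)
  const preview: any = await previewMasking(sourceId, 'customer', undefined, 20)
  console.log(preview)                  // masked sample rows (shape assumed)
}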
+6 -2
@@ -34,6 +34,10 @@ export function deleteProject(id: number) {
   return request.delete(`/projects/${id}`)
 }
 
-export function autoClassifyProject(id: number) {
-  return request.post(`/projects/${id}/auto-classify`)
+export function autoClassifyProject(id: number, background: boolean = true) {
+  return request.post(`/projects/${id}/auto-classify`, undefined, { params: { background } })
+}
+
+export function getAutoClassifyStatus(id: number) {
+  return request.get(`/projects/${id}/auto-classify-status`)
 }
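The new background flag pairs naturally with the status endpoint; a simple polling loop might look like the sketch below. The status field names and values are assumptions, as is the module path:

// Hypothetical polling loop for the background auto-classify task.
import { autoClassifyProject, getAutoClassifyStatus } from './project'

async function classifyAndWait(projectId: number) {
  await autoClassifyProject(projectId, true)        // enqueue as a background task
  for (;;) {
    const res: any = await getAutoClassifyStatus(projectId)
    const status = res.data?.status ?? res.status   // 'running' | 'done' | 'failed' — assumed values
    if (status !== 'running') return status
    await new Promise(r => setTimeout(r, 2000))     // poll every 2 seconds
  }
}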
+8 -4
@@ -16,12 +16,12 @@ export function getReportStats() {
   return request.get('/reports/stats')
 }
 
-export function downloadReport(projectId: number) {
-  const token = localStorage.getItem('pdg_token')
-  const url = `/api/v1/reports/projects/${projectId}/download`
+export function downloadReport(projectId: number, format: string = 'docx') {
+  const token = localStorage.getItem('dp_token')
+  const url = `/api/v1/reports/projects/${projectId}/download?format=${format}`
   const a = document.createElement('a')
   a.href = url
-  a.download = `report_project_${projectId}.docx`
+  a.download = `report_project_${projectId}.${format === 'excel' ? 'xlsx' : 'docx'}`
   if (token) {
     a.setAttribute('data-token', token)
   }
@@ -29,3 +29,7 @@ export function downloadReport(projectId: number) {
   a.click()
   document.body.removeChild(a)
 }
+
+export function getReportSummary(projectId: number) {
+  return request.get(`/reports/projects/${projectId}/summary`)
+}
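Note that `data-token` is a plain HTML data attribute, so the anchor click sends no Authorization header; the download endpoint presumably authenticates via the URL or a cookie, which this diff does not show. A usage sketch, with module path and summary shape assumed (getReportSummary likely feeds the HTML report page above):

// Hypothetical usage of the extended download and the new summary endpoint.
import { downloadReport, getReportSummary } from './report'

async function exportProject(projectId: number) {
  const summary: any = await getReportSummary(projectId)  // payload shape assumed
  console.log(summary)
  downloadReport(projectId, 'excel')   // saves report_project_<id>.xlsx
}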
+8 -4
@@ -8,7 +8,7 @@ const request = axios.create({
 request.interceptors.request.use(
   (config: InternalAxiosRequestConfig) => {
-    const token = localStorage.getItem('pdg_token')
+    const token = localStorage.getItem('dp_token')
     if (token && config.headers) {
       config.headers.Authorization = `Bearer ${token}`
     }
@@ -32,11 +32,15 @@
     const status = error.response?.status
     if (status === 401) {
       ElMessage.error('登录已过期,请重新登录')
-      localStorage.removeItem('pdg_token')
-      localStorage.removeItem('pdg_refresh')
+      localStorage.removeItem('dp_token')
+      localStorage.removeItem('dp_refresh')
       window.location.href = '/login'
     } else {
-      ElMessage.error((error.response?.data as any)?.message || '网络错误')
+      const data = error.response?.data as any
+      const detail = Array.isArray(data?.detail)
+        ? data.detail.map((d: any) => d.msg || JSON.stringify(d)).join(', ')
+        : data?.detail
+      ElMessage.error(detail || data?.message || error.message || '网络错误')
     }
     return Promise.reject(error)
   }
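The new branch handles FastAPI-style validation errors, where detail is an array of objects carrying a msg field, as well as plain-string detail. A sketch of the mapping, with illustrative payloads only:

// Illustrative payloads showing how the interceptor's detail handling maps errors to a message.
const validationError = { detail: [{ loc: ['body', 'name'], msg: 'field required', type: 'value_error.missing' }] }
const plainError = { detail: '项目不存在' }  // "project does not exist"

function toMessage(data: any, fallback: string): string {
  const detail = Array.isArray(data?.detail)
    ? data.detail.map((d: any) => d.msg || JSON.stringify(d)).join(', ')
    : data?.detail
  return detail || data?.message || fallback
}

console.log(toMessage(validationError, '网络错误'))  // "field required"
console.log(toMessage(plainError, '网络错误'))       // "项目不存在"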
+13
@@ -0,0 +1,13 @@
+import request from './request'
+
+export function recalculateRisk(projectId?: number) {
+  return request.post('/risk/recalculate', null, { params: projectId ? { project_id: projectId } : undefined })
+}
+
+export function getRiskTop(n: number = 10) {
+  return request.get('/risk/top', { params: { n } })
+}
+
+export function getProjectRisk(projectId: number) {
+  return request.get(`/risk/projects/${projectId}`)
+}
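A brief usage sketch; the module path, the example project id, and the response shapes are assumptions:

// Hypothetical usage: recompute risk scores, then read the leaderboard and one project's detail.
import { recalculateRisk, getRiskTop, getProjectRisk } from './risk'

async function refreshRiskBoard() {
  await recalculateRisk()                      // omit project_id to recompute all projects
  const top: any = await getRiskTop(10)        // ten highest-risk entries (shape assumed)
  console.log(top)
  const detail: any = await getProjectRisk(1)  // project id 1 used as an example
  console.log(detail)
}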

Some files were not shown because too many files have changed in this diff.