feat: 全量功能模块开发与集成测试修复
- 新增后端模块:Alert、APIAsset、Compliance、Lineage、Masking、Risk、SchemaChange、Unstructured、Watermark
- 新增前端模块页面与API接口
- 新增Alembic迁移脚本(002-014)覆盖全量业务表
- 新增测试数据生成脚本与集成测试脚本
- 修复metadata模型JSON类型导入缺失导致启动失败的问题
- 修复前端Alert/APIAsset页面request模块路径错误
- 更新docker-compose与开发计划文档
This commit is contained in:
@@ -0,0 +1,29 @@
|
||||
# Database
|
||||
DATABASE_URL=postgresql+psycopg2://pdg:pdg_secret_2024@localhost:5432/prop_data_guard
|
||||
REDIS_URL=redis://localhost:6379/0
|
||||
|
||||
# Security
|
||||
SECRET_KEY=prop-data-guard-super-secret-key-change-in-production
|
||||
# Fernet-compatible encryption key for database passwords (32 bytes, base64 url-safe).
|
||||
# Generate one with: python -c "from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())"
|
||||
# IMPORTANT: Keep this key safe and consistent across all backend instances.
|
||||
DB_ENCRYPTION_KEY=
|
||||
|
||||
# MinIO
|
||||
MINIO_ENDPOINT=localhost:9000
|
||||
MINIO_ACCESS_KEY=pdgminio
|
||||
MINIO_SECRET_KEY=pdgminio_secret_2024
|
||||
MINIO_SECURE=false
|
||||
MINIO_BUCKET_NAME=pdg-files
|
||||
|
||||
# CORS
|
||||
CORS_ORIGINS=["http://localhost:5173", "http://127.0.0.1:5173"]
|
||||
|
||||
# Auth
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES=30
|
||||
REFRESH_TOKEN_EXPIRE_DAYS=7
|
||||
|
||||
# Default superuser (created on first startup) — change these credentials before any non-local deployment
|
||||
FIRST_SUPERUSER_USERNAME=admin
|
||||
FIRST_SUPERUSER_PASSWORD=admin123
|
||||
FIRST_SUPERUSER_EMAIL=admin@datapo.com
|
||||
@@ -0,0 +1,41 @@
|
||||
"""Fix datasource password encryption stability
|
||||
|
||||
Revision ID: 002
|
||||
Revises: 001
|
||||
Create Date: 2026-04-23 14:00:00.000000
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "002"
|
||||
down_revision: Union[str, None] = "001"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Historical encrypted_password values are irrecoverable because
|
||||
# the old implementation generated a random Fernet key on every startup.
|
||||
# We clear the passwords and mark sources as inactive so admins re-enter them
|
||||
# with the new stable key derived from DB_ENCRYPTION_KEY / SECRET_KEY.
|
||||
op.add_column(
|
||||
"data_source",
|
||||
sa.Column("password_reset_required", sa.Boolean(), nullable=False, server_default=sa.text("false")),
|
||||
)
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE data_source
|
||||
SET encrypted_password = NULL,
|
||||
status = 'inactive',
|
||||
password_reset_required = true
|
||||
WHERE encrypted_password IS NOT NULL
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("data_source", "password_reset_required")
|
||||
@@ -0,0 +1,27 @@
|
||||
"""Add celery_task_id and scan_progress to classification_project
|
||||
|
||||
Revision ID: 003
|
||||
Revises: 002
|
||||
Create Date: 2026-04-23 14:30:00.000000
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "003"
|
||||
down_revision: Union[str, None] = "002"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("classification_project", sa.Column("celery_task_id", sa.String(100), nullable=True))
|
||||
op.add_column("classification_project", sa.Column("scan_progress", sa.Text(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("classification_project", "scan_progress")
|
||||
op.drop_column("classification_project", "celery_task_id")
|
||||
@@ -0,0 +1,37 @@
|
||||
"""Add ml_model_version table
|
||||
|
||||
Revision ID: 004
|
||||
Revises: 003
|
||||
Create Date: 2026-04-23 15:00:00.000000
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "004"
|
||||
down_revision: Union[str, None] = "003"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"ml_model_version",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, index=True),
|
||||
sa.Column("name", sa.String(100), nullable=False),
|
||||
sa.Column("model_path", sa.String(500), nullable=False),
|
||||
sa.Column("vectorizer_path", sa.String(500), nullable=False),
|
||||
sa.Column("accuracy", sa.Float(), default=0.0),
|
||||
sa.Column("train_samples", sa.Integer(), default=0),
|
||||
sa.Column("train_date", sa.DateTime(), default=sa.func.now()),
|
||||
sa.Column("is_active", sa.Boolean(), default=False),
|
||||
sa.Column("description", sa.Text(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(), default=sa.func.now()),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("ml_model_version")
|
||||
@@ -0,0 +1,33 @@
|
||||
"""Add incremental scan fields to meta tables
|
||||
|
||||
Revision ID: 005
|
||||
Revises: 004
|
||||
Create Date: 2026-04-23 16:00:00.000000
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "005"
|
||||
down_revision: Union[str, None] = "004"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
for table in ["meta_database", "meta_table", "meta_column"]:
|
||||
op.add_column(table, sa.Column("last_scanned_at", sa.DateTime(), nullable=True))
|
||||
op.add_column(table, sa.Column("checksum", sa.String(64), nullable=True))
|
||||
op.add_column(table, sa.Column("is_deleted", sa.Boolean(), nullable=False, server_default=sa.text("false")))
|
||||
op.add_column(table, sa.Column("deleted_at", sa.DateTime(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
for table in ["meta_database", "meta_table", "meta_column"]:
|
||||
op.drop_column(table, "deleted_at")
|
||||
op.drop_column(table, "is_deleted")
|
||||
op.drop_column(table, "checksum")
|
||||
op.drop_column(table, "last_scanned_at")
|
||||
@@ -0,0 +1,37 @@
|
||||
"""Add masking_rule table
|
||||
|
||||
Revision ID: 006
|
||||
Revises: 005
|
||||
Create Date: 2026-04-23 17:00:00.000000
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "006"
|
||||
down_revision: Union[str, None] = "005"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"masking_rule",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, index=True),
|
||||
sa.Column("name", sa.String(100), nullable=False),
|
||||
sa.Column("level_id", sa.Integer(), sa.ForeignKey("data_level.id"), nullable=True),
|
||||
sa.Column("category_id", sa.Integer(), sa.ForeignKey("category.id"), nullable=True),
|
||||
sa.Column("algorithm", sa.String(20), nullable=False),
|
||||
sa.Column("params", sa.JSON(), nullable=True),
|
||||
sa.Column("is_active", sa.Boolean(), default=True),
|
||||
sa.Column("description", sa.Text(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(), default=sa.func.now()),
|
||||
sa.Column("updated_at", sa.DateTime(), default=sa.func.now(), onupdate=sa.func.now()),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("masking_rule")
|
||||
@@ -0,0 +1,33 @@
|
||||
"""Add watermark_log table
|
||||
|
||||
Revision ID: 007
|
||||
Revises: 006
|
||||
Create Date: 2026-04-23 18:00:00.000000
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "007"
|
||||
down_revision: Union[str, None] = "006"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"watermark_log",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, index=True),
|
||||
sa.Column("user_id", sa.Integer(), sa.ForeignKey("sys_user.id"), nullable=False),
|
||||
sa.Column("export_type", sa.String(20), default="csv"),
|
||||
sa.Column("data_scope", sa.Text(), nullable=True),
|
||||
sa.Column("watermark_key", sa.String(64), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(), default=sa.func.now()),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("watermark_log")
|
||||
@@ -0,0 +1,25 @@
|
||||
"""Add analysis_result to unstructured_file
|
||||
|
||||
Revision ID: 008
|
||||
Revises: 007
|
||||
Create Date: 2026-04-23 19:00:00.000000
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "008"
|
||||
down_revision: Union[str, None] = "007"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("unstructured_file", sa.Column("analysis_result", sa.JSON(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("unstructured_file", "analysis_result")
|
||||
@@ -0,0 +1,36 @@
|
||||
"""Add schema_change_log table
|
||||
|
||||
Revision ID: 009
|
||||
Revises: 008
|
||||
Create Date: 2026-04-23 20:00:00.000000
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "009"
|
||||
down_revision: Union[str, None] = "008"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"schema_change_log",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, index=True),
|
||||
sa.Column("source_id", sa.Integer(), sa.ForeignKey("data_source.id"), nullable=False),
|
||||
sa.Column("database_id", sa.Integer(), sa.ForeignKey("meta_database.id"), nullable=True),
|
||||
sa.Column("table_id", sa.Integer(), sa.ForeignKey("meta_table.id"), nullable=True),
|
||||
sa.Column("column_id", sa.Integer(), sa.ForeignKey("meta_column.id"), nullable=True),
|
||||
sa.Column("change_type", sa.String(20), nullable=False),
|
||||
sa.Column("old_value", sa.Text(), nullable=True),
|
||||
sa.Column("new_value", sa.Text(), nullable=True),
|
||||
sa.Column("detected_at", sa.DateTime(), default=sa.func.now()),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("schema_change_log")
|
||||
@@ -0,0 +1,40 @@
|
||||
"""Add risk_assessment table
|
||||
|
||||
Revision ID: 010
|
||||
Revises: 009
|
||||
Create Date: 2026-04-23 21:00:00.000000
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "010"
|
||||
down_revision: Union[str, None] = "009"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"risk_assessment",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, index=True),
|
||||
sa.Column("entity_type", sa.String(20), nullable=False),
|
||||
sa.Column("entity_id", sa.Integer(), nullable=False),
|
||||
sa.Column("entity_name", sa.String(200), nullable=True),
|
||||
sa.Column("risk_score", sa.Float(), default=0.0),
|
||||
sa.Column("sensitivity_score", sa.Float(), default=0.0),
|
||||
sa.Column("exposure_score", sa.Float(), default=0.0),
|
||||
sa.Column("protection_score", sa.Float(), default=0.0),
|
||||
sa.Column("details", sa.JSON(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(), default=sa.func.now()),
|
||||
sa.Column("updated_at", sa.DateTime(), default=sa.func.now(), onupdate=sa.func.now()),
|
||||
)
|
||||
op.create_index("idx_risk_entity", "risk_assessment", ["entity_type", "entity_id"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("idx_risk_entity", table_name="risk_assessment")
|
||||
op.drop_table("risk_assessment")
|
||||
@@ -0,0 +1,51 @@
|
||||
"""Add compliance_rule and compliance_issue tables
|
||||
|
||||
Revision ID: 011
|
||||
Revises: 010
|
||||
Create Date: 2026-04-23 22:00:00.000000
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "011"
|
||||
down_revision: Union[str, None] = "010"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"compliance_rule",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, index=True),
|
||||
sa.Column("name", sa.String(200), nullable=False),
|
||||
sa.Column("standard", sa.String(50), nullable=False),
|
||||
sa.Column("description", sa.Text(), nullable=True),
|
||||
sa.Column("check_logic", sa.String(50), nullable=False),
|
||||
sa.Column("severity", sa.String(20), default="medium"),
|
||||
sa.Column("is_active", sa.Boolean(), default=True),
|
||||
sa.Column("created_at", sa.DateTime(), default=sa.func.now()),
|
||||
)
|
||||
op.create_table(
|
||||
"compliance_issue",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, index=True),
|
||||
sa.Column("rule_id", sa.Integer(), nullable=False),
|
||||
sa.Column("project_id", sa.Integer(), nullable=True),
|
||||
sa.Column("entity_type", sa.String(20), nullable=False),
|
||||
sa.Column("entity_id", sa.Integer(), nullable=False),
|
||||
sa.Column("entity_name", sa.String(200), nullable=True),
|
||||
sa.Column("severity", sa.String(20), default="medium"),
|
||||
sa.Column("description", sa.Text(), nullable=True),
|
||||
sa.Column("suggestion", sa.Text(), nullable=True),
|
||||
sa.Column("status", sa.String(20), default="open"),
|
||||
sa.Column("resolved_at", sa.DateTime(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(), default=sa.func.now()),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("compliance_issue")
|
||||
op.drop_table("compliance_rule")
|
||||
@@ -0,0 +1,35 @@
|
||||
"""Add data_lineage table
|
||||
|
||||
Revision ID: 012
|
||||
Revises: 011
|
||||
Create Date: 2026-04-23 23:00:00.000000
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "012"
|
||||
down_revision: Union[str, None] = "011"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"data_lineage",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, index=True),
|
||||
sa.Column("source_table", sa.String(200), nullable=False),
|
||||
sa.Column("source_column", sa.String(200), nullable=True),
|
||||
sa.Column("target_table", sa.String(200), nullable=False),
|
||||
sa.Column("target_column", sa.String(200), nullable=True),
|
||||
sa.Column("relation_type", sa.String(20), default="direct"),
|
||||
sa.Column("script_content", sa.Text(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(), default=sa.func.now()),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("data_lineage")
|
||||
@@ -0,0 +1,58 @@
|
||||
"""Add alert and work_order tables
|
||||
|
||||
Revision ID: 013
|
||||
Revises: 012
|
||||
Create Date: 2026-04-24 00:00:00.000000
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "013"
|
||||
down_revision: Union[str, None] = "012"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"alert_rule",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, index=True),
|
||||
sa.Column("name", sa.String(200), nullable=False),
|
||||
sa.Column("trigger_condition", sa.String(50), nullable=False),
|
||||
sa.Column("threshold", sa.Integer(), default=0),
|
||||
sa.Column("severity", sa.String(20), default="medium"),
|
||||
sa.Column("is_active", sa.Boolean(), default=True),
|
||||
sa.Column("created_at", sa.DateTime(), default=sa.func.now()),
|
||||
)
|
||||
op.create_table(
|
||||
"alert_record",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, index=True),
|
||||
sa.Column("rule_id", sa.Integer(), sa.ForeignKey("alert_rule.id"), nullable=False),
|
||||
sa.Column("title", sa.String(200), nullable=False),
|
||||
sa.Column("content", sa.Text(), nullable=True),
|
||||
sa.Column("severity", sa.String(20), default="medium"),
|
||||
sa.Column("status", sa.String(20), default="open"),
|
||||
sa.Column("created_at", sa.DateTime(), default=sa.func.now()),
|
||||
)
|
||||
op.create_table(
|
||||
"work_order",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, index=True),
|
||||
sa.Column("alert_id", sa.Integer(), sa.ForeignKey("alert_record.id"), nullable=True),
|
||||
sa.Column("title", sa.String(200), nullable=False),
|
||||
sa.Column("description", sa.Text(), nullable=True),
|
||||
sa.Column("assignee_id", sa.Integer(), sa.ForeignKey("sys_user.id"), nullable=True),
|
||||
sa.Column("status", sa.String(20), default="open"),
|
||||
sa.Column("resolution", sa.Text(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(), default=sa.func.now()),
|
||||
sa.Column("resolved_at", sa.DateTime(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("work_order")
|
||||
op.drop_table("alert_record")
|
||||
op.drop_table("alert_rule")
|
||||
@@ -0,0 +1,55 @@
|
||||
"""Add api_asset and api_endpoint tables
|
||||
|
||||
Revision ID: 014
|
||||
Revises: 013
|
||||
Create Date: 2026-04-24 00:00:00.000000
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision: str = "014"
|
||||
down_revision: Union[str, None] = "013"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"api_asset",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, index=True),
|
||||
sa.Column("name", sa.String(200), nullable=False),
|
||||
sa.Column("base_url", sa.String(500), nullable=False),
|
||||
sa.Column("swagger_url", sa.String(500), nullable=True),
|
||||
sa.Column("auth_type", sa.String(50), default="none"),
|
||||
sa.Column("headers", sa.JSON(), default=dict),
|
||||
sa.Column("description", sa.Text(), nullable=True),
|
||||
sa.Column("scan_status", sa.String(20), default="idle"),
|
||||
sa.Column("total_endpoints", sa.Integer(), default=0),
|
||||
sa.Column("sensitive_endpoints", sa.Integer(), default=0),
|
||||
sa.Column("created_by", sa.Integer(), sa.ForeignKey("sys_user.id"), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(), default=sa.func.now()),
|
||||
sa.Column("updated_at", sa.DateTime(), default=sa.func.now(), onupdate=sa.func.now()),
|
||||
)
|
||||
op.create_table(
|
||||
"api_endpoint",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, index=True),
|
||||
sa.Column("asset_id", sa.Integer(), sa.ForeignKey("api_asset.id"), nullable=False),
|
||||
sa.Column("method", sa.String(10), nullable=False),
|
||||
sa.Column("path", sa.String(500), nullable=False),
|
||||
sa.Column("summary", sa.String(500), nullable=True),
|
||||
sa.Column("tags", sa.JSON(), default=list),
|
||||
sa.Column("parameters", sa.JSON(), default=list),
|
||||
sa.Column("request_body_schema", sa.JSON(), nullable=True),
|
||||
sa.Column("response_schema", sa.JSON(), nullable=True),
|
||||
sa.Column("sensitive_fields", sa.JSON(), default=list),
|
||||
sa.Column("risk_level", sa.String(20), default="low"),
|
||||
sa.Column("is_active", sa.Boolean(), default=True),
|
||||
sa.Column("created_at", sa.DateTime(), default=sa.func.now()),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("api_endpoint")
|
||||
op.drop_table("api_asset")
|
||||
@@ -1,6 +1,6 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.api.v1 import auth, user, datasource, metadata, classification, project, task, report, dashboard
|
||||
from app.api.v1 import auth, user, datasource, metadata, classification, project, task, report, dashboard, masking, watermark, unstructured, schema_change, risk, compliance, lineage, alert, api_asset
|
||||
|
||||
api_router = APIRouter()
|
||||
api_router.include_router(auth.router, prefix="/auth", tags=["认证"])
|
||||
@@ -12,3 +12,12 @@ api_router.include_router(project.router, prefix="/projects", tags=["项目管
|
||||
api_router.include_router(task.router, prefix="/tasks", tags=["任务管理"])
|
||||
api_router.include_router(report.router, prefix="/reports", tags=["报告管理"])
|
||||
api_router.include_router(dashboard.router, prefix="/dashboard", tags=["仪表盘"])
|
||||
api_router.include_router(masking.router, prefix="/masking", tags=["数据脱敏"])
|
||||
api_router.include_router(watermark.router, prefix="/watermark", tags=["数据水印"])
|
||||
api_router.include_router(unstructured.router, prefix="/unstructured", tags=["非结构化文件"])
|
||||
api_router.include_router(schema_change.router, prefix="/schema-changes", tags=["Schema变更"])
|
||||
api_router.include_router(risk.router, prefix="/risk", tags=["风险评估"])
|
||||
api_router.include_router(compliance.router, prefix="/compliance", tags=["合规检查"])
|
||||
api_router.include_router(lineage.router, prefix="/lineage", tags=["数据血缘"])
|
||||
api_router.include_router(alert.router, prefix="/alerts", tags=["告警与工单"])
|
||||
api_router.include_router(api_asset.router, prefix="/api-assets", tags=["API资产"])
|
||||
|
||||
@@ -0,0 +1,115 @@
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.models.user import User
|
||||
from app.schemas.common import ResponseModel, ListResponse
|
||||
from app.services import alert_service
|
||||
from app.api.deps import get_current_user, require_admin
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/init-rules")
|
||||
def init_alert_rules(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_admin),
|
||||
):
|
||||
alert_service.init_builtin_alert_rules(db)
|
||||
return ResponseModel(message="初始化完成")
|
||||
|
||||
|
||||
@router.post("/check")
|
||||
def check_alerts(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_admin),
|
||||
):
|
||||
records = alert_service.check_alerts(db)
|
||||
return ResponseModel(data={"alerts_created": len(records)})
|
||||
|
||||
|
||||
@router.get("/records")
|
||||
def list_alert_records(
|
||||
status: Optional[str] = Query(None),
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(20, ge=1, le=500),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
query = db.query(alert_service.AlertRecord)
|
||||
if status:
|
||||
query = query.filter(alert_service.AlertRecord.status == status)
|
||||
total = query.count()
|
||||
items = query.order_by(alert_service.AlertRecord.created_at.desc()).offset((page - 1) * page_size).limit(page_size).all()
|
||||
return ListResponse(
|
||||
data=[{
|
||||
"id": r.id,
|
||||
"rule_id": r.rule_id,
|
||||
"title": r.title,
|
||||
"content": r.content,
|
||||
"severity": r.severity,
|
||||
"status": r.status,
|
||||
"created_at": r.created_at.isoformat() if r.created_at else None,
|
||||
} for r in items],
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/work-orders")
|
||||
def create_work_order(
|
||||
alert_id: int,
|
||||
title: str,
|
||||
description: str = "",
|
||||
assignee_id: Optional[int] = Query(None),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
wo = alert_service.create_work_order(db, alert_id, title, description, assignee_id)
|
||||
return ResponseModel(data={"id": wo.id})
|
||||
|
||||
|
||||
@router.get("/work-orders")
|
||||
def list_work_orders(
|
||||
status: Optional[str] = Query(None),
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(20, ge=1, le=500),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
from app.models.alert import WorkOrder
|
||||
query = db.query(WorkOrder)
|
||||
if status:
|
||||
query = query.filter(WorkOrder.status == status)
|
||||
total = query.count()
|
||||
items = query.order_by(WorkOrder.created_at.desc()).offset((page - 1) * page_size).limit(page_size).all()
|
||||
return ListResponse(
|
||||
data=[{
|
||||
"id": w.id,
|
||||
"alert_id": w.alert_id,
|
||||
"title": w.title,
|
||||
"status": w.status,
|
||||
"assignee_name": w.assignee.username if w.assignee else None,
|
||||
"created_at": w.created_at.isoformat() if w.created_at else None,
|
||||
} for w in items],
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/work-orders/{wo_id}/status")
|
||||
def update_work_order(
|
||||
wo_id: int,
|
||||
status: str,
|
||||
resolution: str = "",
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
wo = alert_service.update_work_order_status(db, wo_id, status, resolution or None)
|
||||
if not wo:
|
||||
from fastapi import HTTPException, status
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="工单不存在")
|
||||
return ResponseModel(data={"id": wo.id, "status": wo.status})
|
||||
@@ -0,0 +1,131 @@
|
||||
from typing import Optional, List
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.models.user import User
|
||||
from app.schemas.common import ResponseModel, ListResponse
|
||||
from app.services import api_asset_service
|
||||
from app.api.deps import get_current_user
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
class APIAssetCreate(BaseModel):
    """Request body for registering a new API asset."""

    name: str
    base_url: str
    # Optional OpenAPI/Swagger document URL used by the endpoint scanner.
    swagger_url: Optional[str] = None
    auth_type: Optional[str] = "none"
    # Extra HTTP headers to send when scanning (e.g. auth tokens).
    headers: Optional[dict] = None
    description: Optional[str] = None
|
||||
|
||||
class APIAssetUpdate(BaseModel):
    """Partial-update body for an API asset; unset fields are left unchanged."""

    name: Optional[str] = None
    base_url: Optional[str] = None
    swagger_url: Optional[str] = None
    auth_type: Optional[str] = None
    headers: Optional[dict] = None
    description: Optional[str] = None
|
||||
|
||||
@router.post("")
|
||||
def create_asset(
|
||||
body: APIAssetCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
asset = api_asset_service.create_asset(db, body.dict(), current_user.id)
|
||||
return ResponseModel(data={"id": asset.id})
|
||||
|
||||
@router.get("")
|
||||
def list_assets(
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(20, ge=1, le=500),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
from app.models.api_asset import APIAsset
|
||||
query = db.query(APIAsset)
|
||||
total = query.count()
|
||||
items = query.order_by(APIAsset.created_at.desc()).offset((page - 1) * page_size).limit(page_size).all()
|
||||
return ListResponse(
|
||||
data=[{
|
||||
"id": a.id,
|
||||
"name": a.name,
|
||||
"base_url": a.base_url,
|
||||
"swagger_url": a.swagger_url,
|
||||
"auth_type": a.auth_type,
|
||||
"scan_status": a.scan_status,
|
||||
"total_endpoints": a.total_endpoints,
|
||||
"sensitive_endpoints": a.sensitive_endpoints,
|
||||
"created_at": a.created_at.isoformat() if a.created_at else None,
|
||||
} for a in items],
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
@router.put("/{asset_id}")
|
||||
def update_asset(
|
||||
asset_id: int,
|
||||
body: APIAssetUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
asset = api_asset_service.update_asset(db, asset_id, body.dict(exclude_unset=True))
|
||||
if not asset:
|
||||
from fastapi import HTTPException, status
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="资产不存在")
|
||||
return ResponseModel(data={"id": asset.id})
|
||||
|
||||
@router.delete("/{asset_id}")
|
||||
def delete_asset(
|
||||
asset_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
ok = api_asset_service.delete_asset(db, asset_id)
|
||||
if not ok:
|
||||
from fastapi import HTTPException, status
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="资产不存在")
|
||||
return ResponseModel()
|
||||
|
||||
@router.post("/{asset_id}/scan")
|
||||
def scan_asset(
|
||||
asset_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
result = api_asset_service.scan_swagger(db, asset_id)
|
||||
return ResponseModel(data=result)
|
||||
|
||||
@router.get("/{asset_id}/endpoints")
|
||||
def list_endpoints(
|
||||
asset_id: int,
|
||||
risk_level: Optional[str] = Query(None),
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(20, ge=1, le=500),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
from app.models.api_asset import APIEndpoint
|
||||
query = db.query(APIEndpoint).filter(APIEndpoint.asset_id == asset_id)
|
||||
if risk_level:
|
||||
query = query.filter(APIEndpoint.risk_level == risk_level)
|
||||
total = query.count()
|
||||
items = query.order_by(APIEndpoint.id.asc()).offset((page - 1) * page_size).limit(page_size).all()
|
||||
return ListResponse(
|
||||
data=[{
|
||||
"id": e.id,
|
||||
"method": e.method,
|
||||
"path": e.path,
|
||||
"summary": e.summary,
|
||||
"tags": e.tags,
|
||||
"parameters": e.parameters,
|
||||
"sensitive_fields": e.sensitive_fields,
|
||||
"risk_level": e.risk_level,
|
||||
"is_active": e.is_active,
|
||||
} for e in items],
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
@@ -238,3 +238,43 @@ def auto_classify(
|
||||
):
|
||||
result = classification_engine.run_auto_classification(db, project_id)
|
||||
return ResponseModel(data=result)
|
||||
|
||||
|
||||
@router.post("/ml-train")
def ml_train(
    background: bool = True,
    model_name: Optional[str] = None,
    algorithm: str = "logistic_regression",
    db: Session = Depends(get_db),
    current_user: User = Depends(require_admin),
):
    """Train an ML classification model, either asynchronously via Celery or inline (admin only)."""
    from app.tasks.ml_tasks import train_ml_model_task
    from app.services.ml_service import train_model

    if not background:
        # Synchronous path: train in-process and report metrics right away.
        version = train_model(db, model_name=model_name, algorithm=algorithm)
        if version:
            return ResponseModel(data={"model_id": version.id, "accuracy": version.accuracy, "train_samples": version.train_samples})
        return ResponseModel(message="训练失败:样本不足或发生错误")

    # Asynchronous path: hand the work to the Celery worker and return the task handle.
    task = train_ml_model_task.delay(model_name=model_name, algorithm=algorithm)
    return ResponseModel(data={"task_id": task.id, "status": task.state})
|
||||
|
||||
|
||||
@router.get("/ml-suggest/{project_id}")
def ml_suggest(
    project_id: int,
    column_ids: Optional[str] = Query(None),
    top_k: int = Query(3, ge=1, le=5),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return top-k ML classification suggestions for columns of a project."""
    from app.services.ml_service import suggest_for_project_columns

    parsed_ids = None
    if column_ids:
        # Accept a comma-separated id list; non-numeric entries are silently dropped.
        parsed_ids = [int(part) for part in column_ids.split(",") if part.strip().isdigit()]
    outcome = suggest_for_project_columns(db, project_id, column_ids=parsed_ids, top_k=top_k)
    if not outcome.get("success"):
        from fastapi import HTTPException, status
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=outcome.get("message"))
    return ResponseModel(data=outcome["suggestions"])
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.models.user import User
|
||||
from app.schemas.common import ResponseModel, ListResponse
|
||||
from app.services import compliance_service
|
||||
from app.api.deps import get_current_user, require_admin
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/init-rules")
def init_rules(
    db: Session = Depends(get_db),
    current_user: User = Depends(require_admin),
):
    """Seed the built-in compliance rules (admin only)."""
    compliance_service.init_builtin_rules(db)
    return ResponseModel(message="初始化完成")
|
||||
|
||||
|
||||
@router.post("/scan")
def scan_compliance(
    project_id: Optional[int] = Query(None),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Run a compliance scan (optionally scoped to one project) and report how many issues were found."""
    found = compliance_service.scan_compliance(db, project_id=project_id)
    return ResponseModel(data={"issues_found": len(found)})
|
||||
|
||||
|
||||
@router.get("/issues")
def list_issues(
    project_id: Optional[int] = Query(None),
    status: Optional[str] = Query(None),
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=500),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Paginated compliance issues, optionally filtered by project and status."""
    rows, total = compliance_service.list_issues(
        db, project_id=project_id, status=status, page=page, page_size=page_size
    )
    payload = []
    for issue in rows:
        payload.append({
            "id": issue.id,
            "rule_id": issue.rule_id,
            "project_id": issue.project_id,
            "entity_type": issue.entity_type,
            "entity_name": issue.entity_name,
            "severity": issue.severity,
            "description": issue.description,
            "suggestion": issue.suggestion,
            "status": issue.status,
            "created_at": issue.created_at.isoformat() if issue.created_at else None,
        })
    return ListResponse(data=payload, total=total, page=page, page_size=page_size)
|
||||
|
||||
|
||||
@router.post("/issues/{issue_id}/resolve")
def resolve_issue(
    issue_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Mark a compliance issue as resolved; 404 when it does not exist."""
    resolved = compliance_service.resolve_issue(db, issue_id)
    if not resolved:
        from fastapi import HTTPException, status
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="问题不存在")
    return ResponseModel(message="已标记为已解决")
|
||||
@@ -0,0 +1,32 @@
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.models.user import User
|
||||
from app.schemas.common import ResponseModel
|
||||
from app.services import lineage_service
|
||||
from app.api.deps import get_current_user
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/parse")
def parse_lineage(
    sql: str,
    target_table: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Parse a SQL statement into lineage records pointing at the given target table."""
    created = lineage_service.parse_sql_lineage(db, sql, target_table)
    return ResponseModel(data={"records_created": len(created)})
|
||||
|
||||
|
||||
@router.get("/graph")
def get_graph(
    table_name: Optional[str] = Query(None),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return the lineage graph, optionally centered on a single table."""
    lineage_graph = lineage_service.get_lineage_graph(db, table_name=table_name)
    return ResponseModel(data=lineage_graph)
|
||||
@@ -0,0 +1,88 @@
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.models.user import User
|
||||
from app.schemas.common import ResponseModel, ListResponse
|
||||
from app.services import masking_service
|
||||
from app.api.deps import get_current_user, require_admin
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/rules")
def list_masking_rules(
    level_id: Optional[int] = Query(None),
    category_id: Optional[int] = Query(None),
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=500),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Paginated masking rules, optionally filtered by data level and/or category."""
    rules, total = masking_service.list_masking_rules(
        db, level_id=level_id, category_id=category_id, page=page, page_size=page_size
    )

    def _serialize(rule):
        # Include the related level/category display names for convenience.
        return {
            "id": rule.id,
            "name": rule.name,
            "level_id": rule.level_id,
            "category_id": rule.category_id,
            "algorithm": rule.algorithm,
            "params": rule.params,
            "is_active": rule.is_active,
            "description": rule.description,
            "level_name": rule.level.name if rule.level else None,
            "category_name": rule.category.name if rule.category else None,
        }

    return ListResponse(
        data=[_serialize(rule) for rule in rules],
        total=total,
        page=page,
        page_size=page_size,
    )
|
||||
|
||||
|
||||
@router.post("/rules")
def create_masking_rule(
    req: dict,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_admin),
):
    """Create a masking rule from the raw request payload (admin only)."""
    created = masking_service.create_masking_rule(db, req)
    return ResponseModel(data={"id": created.id})
|
||||
|
||||
|
||||
@router.put("/rules/{rule_id}")
def update_masking_rule(
    rule_id: int,
    req: dict,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_admin),
):
    """Update an existing masking rule; 404 when the rule is missing (admin only)."""
    existing = masking_service.get_masking_rule(db, rule_id)
    if not existing:
        from fastapi import HTTPException, status
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="规则不存在")
    updated = masking_service.update_masking_rule(db, existing, req)
    return ResponseModel(data={"id": updated.id})
|
||||
|
||||
|
||||
@router.delete("/rules/{rule_id}")
def delete_masking_rule(
    rule_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_admin),
):
    """Delete a masking rule by id (admin only)."""
    masking_service.delete_masking_rule(db, rule_id)
    return ResponseModel(message="删除成功")
|
||||
|
||||
|
||||
@router.post("/preview")
def preview_masking(
    source_id: int,
    table_name: str,
    project_id: Optional[int] = None,
    limit: int = Query(20, ge=1, le=100),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Preview masked sample rows for one table of a data source."""
    preview = masking_service.preview_masking(db, source_id, table_name, project_id=project_id, limit=limit)
    return ResponseModel(data=preview)
|
||||
@@ -101,9 +101,73 @@ def delete_project(
|
||||
@router.post("/{project_id}/auto-classify")
def project_auto_classify(
    project_id: int,
    background: bool = True,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_manager),
):
    """Start automatic classification for a project.

    With ``background=True`` (default) the work is dispatched to a Celery
    task and the task id is returned; otherwise classification runs
    synchronously and the project status is advanced/rolled back according
    to the outcome.

    Fix: removed a stale unconditional ``return`` at the top of the body
    (left over from an earlier revision) that made the entire
    background/sync dispatch logic below it unreachable.

    Raises:
        HTTPException 404: when the project does not exist.
    """
    from app.tasks.classification_tasks import auto_classify_task
    from celery.result import AsyncResult

    project = project_service.get_project(db, project_id)
    if not project:
        from fastapi import HTTPException, status
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="项目不存在")

    if background:
        # Avoid double-dispatch: if a previous task is still in flight, return its handle.
        if project.celery_task_id:
            existing = AsyncResult(project.celery_task_id)
            if existing.state in ("PENDING", "PROGRESS", "STARTED"):
                return ResponseModel(data={"task_id": project.celery_task_id, "status": existing.state})

        task = auto_classify_task.delay(project_id)
        project.celery_task_id = task.id
        project.status = "scanning"
        db.commit()
        return ResponseModel(data={"task_id": task.id, "status": task.state})
    else:
        from app.services.classification_engine import run_auto_classification
        project.status = "scanning"
        db.commit()
        result = run_auto_classification(db, project_id)
        # Advance to the next workflow stage on success, otherwise roll back to "created".
        if result.get("success"):
            project.status = "assigning"
        else:
            project.status = "created"
        db.commit()
        return ResponseModel(data=result)
|
||||
|
||||
|
||||
@router.get("/{project_id}/auto-classify-status")
def project_auto_classify_status(
    project_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Report the progress of a project's auto-classification task.

    Returns the Celery task state plus the freshest progress snapshot:
    live task metadata when the task reports PROGRESS, otherwise the
    JSON progress persisted on the project row.

    Raises:
        HTTPException 404: when the project does not exist.
    """
    from celery.result import AsyncResult
    import json

    project = project_service.get_project(db, project_id)
    if not project:
        from fastapi import HTTPException, status
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="项目不存在")

    task_id = project.celery_task_id
    if not task_id:
        # No task was ever dispatched: fall back to whatever progress was persisted.
        progress = json.loads(project.scan_progress) if project.scan_progress else None
        return ResponseModel(data={"status": project.status, "progress": progress})

    result = AsyncResult(task_id)
    progress = None
    if result.state == "PROGRESS" and result.info:
        # Live progress from the running worker takes precedence.
        progress = result.info
    elif project.scan_progress:
        # Otherwise use the last snapshot persisted on the project row.
        progress = json.loads(project.scan_progress)

    return ResponseModel(data={
        "status": result.state,
        "task_id": task_id,
        "progress": progress,
        "project_status": project.status,
    })
|
||||
|
||||
@@ -44,12 +44,30 @@ def get_report_stats(
|
||||
@router.get("/projects/{project_id}/download")
def download_report(
    project_id: int,
    format: str = "docx",
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Download the project classification report: Excel when format == "excel", Word otherwise."""
    if format == "excel":
        body = report_service.generate_excel_report(db, project_id)
        media_type = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
        disposition = f"attachment; filename=report_project_{project_id}.xlsx"
    else:
        body = report_service.generate_classification_report(db, project_id)
        media_type = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
        disposition = f"attachment; filename=report_project_{project_id}.docx"
    return Response(
        content=body,
        media_type=media_type,
        headers={"Content-Disposition": disposition},
    )
|
||||
|
||||
|
||||
@router.get("/projects/{project_id}/summary")
def report_summary(
    project_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return the aggregated report summary for a project."""
    summary = report_service.get_report_summary(db, project_id)
    return ResponseModel(data=summary)
|
||||
|
||||
@@ -0,0 +1,73 @@
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.models.user import User
|
||||
from app.schemas.common import ResponseModel, ListResponse
|
||||
from app.services import risk_service
|
||||
from app.api.deps import get_current_user
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/recalculate")
def recalculate_risk(
    project_id: Optional[int] = Query(None),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Recalculate risk scores: one project when project_id is given, otherwise all projects."""
    if project_id:
        assessment = risk_service.calculate_project_risk(db, project_id)
        score = assessment.risk_score if assessment else 0
        return ResponseModel(data={"project_id": project_id, "risk_score": score})
    return ResponseModel(data=risk_service.calculate_all_projects_risk(db))
|
||||
|
||||
|
||||
@router.get("/top")
def risk_top(
    entity_type: str = Query("project"),
    n: int = Query(10, ge=1, le=100),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return the top-N highest-risk entities of the requested type."""
    assessments = risk_service.get_risk_top_n(db, entity_type=entity_type, n=n)
    rows = []
    for item in assessments:
        rows.append({
            "id": item.id,
            "entity_type": item.entity_type,
            "entity_id": item.entity_id,
            "entity_name": item.entity_name,
            "risk_score": item.risk_score,
            "sensitivity_score": item.sensitivity_score,
            "exposure_score": item.exposure_score,
            "protection_score": item.protection_score,
            "updated_at": item.updated_at.isoformat() if item.updated_at else None,
        })
    return ListResponse(data=rows, total=len(rows), page=1, page_size=n)
|
||||
|
||||
|
||||
@router.get("/projects/{project_id}")
def project_risk(
    project_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return the stored risk assessment for a project, or null data when none exists."""
    from app.models.risk import RiskAssessment

    assessment = (
        db.query(RiskAssessment)
        .filter(
            RiskAssessment.entity_type == "project",
            RiskAssessment.entity_id == project_id,
        )
        .first()
    )
    if assessment is None:
        return ResponseModel(data=None)
    return ResponseModel(data={
        "risk_score": assessment.risk_score,
        "sensitivity_score": assessment.sensitivity_score,
        "exposure_score": assessment.exposure_score,
        "protection_score": assessment.protection_score,
        "details": assessment.details,
        "updated_at": assessment.updated_at.isoformat() if assessment.updated_at else None,
    })
|
||||
@@ -0,0 +1,45 @@
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.models.user import User
|
||||
from app.schemas.common import ResponseModel, ListResponse
|
||||
from app.models.schema_change import SchemaChangeLog
|
||||
from app.api.deps import get_current_user
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/logs")
def list_schema_changes(
    source_id: Optional[int] = Query(None),
    change_type: Optional[str] = Query(None),
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=500),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Paginated schema-change log, newest first, filterable by source and change type."""
    q = db.query(SchemaChangeLog)
    if source_id:
        q = q.filter(SchemaChangeLog.source_id == source_id)
    if change_type:
        q = q.filter(SchemaChangeLog.change_type == change_type)
    total = q.count()
    records = (
        q.order_by(SchemaChangeLog.detected_at.desc())
        .offset((page - 1) * page_size)
        .limit(page_size)
        .all()
    )
    rows = [
        {
            "id": rec.id,
            "source_id": rec.source_id,
            "database_id": rec.database_id,
            "table_id": rec.table_id,
            "column_id": rec.column_id,
            "change_type": rec.change_type,
            "old_value": rec.old_value,
            "new_value": rec.new_value,
            "detected_at": rec.detected_at.isoformat() if rec.detected_at else None,
        }
        for rec in records
    ]
    return ListResponse(data=rows, total=total, page=page, page_size=page_size)
|
||||
@@ -0,0 +1,108 @@
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Depends, Query, UploadFile, File
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.models.user import User
|
||||
from app.schemas.common import ResponseModel, ListResponse
|
||||
from app.services import unstructured_service
|
||||
from app.api.deps import get_current_user
|
||||
from app.core.events import minio_client
|
||||
from app.core.config import settings
|
||||
from app.models.metadata import UnstructuredFile
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/upload")
def upload_file(
    file: UploadFile = File(...),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Upload an unstructured file to MinIO, register it, and run sensitive-data analysis.

    Fix: the MinIO object key now embeds a uuid plus the original filename.
    Previously the key was a constant per user, so every upload by the same
    user silently overwrote the previous object.
    """
    import uuid

    # Classify the file by its extension; anything unrecognized is "unknown".
    filename = file.filename or "unknown"
    ext = filename.split(".")[-1].lower() if "." in filename else ""
    type_map = {
        "docx": "word", "doc": "word",
        "xlsx": "excel", "xls": "excel",
        "pdf": "pdf",
        "txt": "txt",
    }
    file_type = type_map.get(ext, "unknown")

    # Upload to MinIO under a collision-free per-user key.
    storage_path = f"unstructured/{current_user.id}/{uuid.uuid4().hex}_{filename}"
    try:
        data = file.file.read()
        # NOTE(review): the stock minio client's put_object expects a stream, but raw
        # bytes are passed here — presumably minio_client is a wrapper accepting bytes; confirm.
        minio_client.put_object(
            settings.MINIO_BUCKET_NAME,
            storage_path,
            data=data,
            length=len(data),
            content_type=file.content_type or "application/octet-stream",
        )
    except Exception as e:
        # Best-effort contract: report the upload failure instead of raising.
        return ResponseModel(message=f"上传失败: {e}")

    db_obj = UnstructuredFile(
        original_name=filename,
        file_type=file_type,
        file_size=len(data),
        storage_path=storage_path,
        status="pending",
        created_by=current_user.id,
    )
    db.add(db_obj)
    db.commit()
    db.refresh(db_obj)

    # Run analysis synchronously; surface failures without losing the DB record.
    try:
        result = unstructured_service.process_unstructured_file(db, db_obj.id)
        return ResponseModel(data={"id": db_obj.id, "matches": result.get("matches", []), "status": "processed"})
    except Exception as e:
        return ResponseModel(data={"id": db_obj.id, "status": "error", "error": str(e)})
|
||||
|
||||
|
||||
@router.get("/files")
def list_files(
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=500),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Paginated listing of the current user's uploaded unstructured files, newest first."""
    q = db.query(UnstructuredFile).filter(UnstructuredFile.created_by == current_user.id)
    total = q.count()
    records = (
        q.order_by(UnstructuredFile.created_at.desc())
        .offset((page - 1) * page_size)
        .limit(page_size)
        .all()
    )
    rows = [
        {
            "id": rec.id,
            "original_name": rec.original_name,
            "file_type": rec.file_type,
            "file_size": rec.file_size,
            "status": rec.status,
            "analysis_result": rec.analysis_result,
            "created_at": rec.created_at.isoformat() if rec.created_at else None,
        }
        for rec in records
    ]
    return ListResponse(data=rows, total=total, page=page, page_size=page_size)
|
||||
|
||||
|
||||
@router.post("/files/{file_id}/reprocess")
def reprocess_file(
    file_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Re-run analysis on one of the current user's files; 404 when not found or not owned."""
    owned = db.query(UnstructuredFile).filter(
        UnstructuredFile.id == file_id,
        UnstructuredFile.created_by == current_user.id,
    ).first()
    if not owned:
        from fastapi import HTTPException, status
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="文件不存在")
    result = unstructured_service.process_unstructured_file(db, file_id)
    return ResponseModel(data={"matches": result.get("matches", []), "status": "processed"})
|
||||
@@ -0,0 +1,23 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.models.user import User
|
||||
from app.schemas.common import ResponseModel
|
||||
from app.services import watermark_service
|
||||
from app.api.deps import get_current_user
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/trace")
def trace_watermark(
    req: dict,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Attempt to extract an embedded watermark from submitted text and trace its origin."""
    text = req.get("text", "")
    trace = watermark_service.trace_watermark(db, text)
    if not trace:
        return ResponseModel(data=None, message="未检测到水印")
    return ResponseModel(data=trace)
|
||||
@@ -10,6 +10,11 @@ class Settings(BaseSettings):
|
||||
DATABASE_URL: str = "postgresql+psycopg2://pdg:pdg_secret_2024@localhost:5432/prop_data_guard"
|
||||
REDIS_URL: str = "redis://localhost:6379/0"
|
||||
|
||||
# Database password encryption key (Fernet-compatible base64, 32 bytes)
|
||||
# If empty, will be derived from SECRET_KEY for backward compatibility.
|
||||
# STRONGLY recommended to set this explicitly in production.
|
||||
DB_ENCRYPTION_KEY: str = ""
|
||||
|
||||
SECRET_KEY: str = "prop-data-guard-super-secret-key-change-in-production"
|
||||
ALGORITHM: str = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
|
||||
|
||||
@@ -2,6 +2,14 @@ from app.models.user import User, Role, Dept, UserRole
|
||||
from app.models.metadata import DataSource, Database, DataTable, DataColumn, UnstructuredFile
|
||||
from app.models.classification import Category, DataLevel, RecognitionRule, ClassificationTemplate
|
||||
from app.models.project import ClassificationProject, ClassificationTask, ClassificationResult, ClassificationChange
|
||||
from app.models.ml import MLModelVersion
|
||||
from app.models.masking import MaskingRule
|
||||
from app.models.watermark import WatermarkLog
|
||||
from app.models.schema_change import SchemaChangeLog
|
||||
from app.models.risk import RiskAssessment
|
||||
from app.models.compliance import ComplianceRule, ComplianceIssue
|
||||
from app.models.alert import AlertRule, AlertRecord, WorkOrder
|
||||
from app.models.api_asset import APIAsset, APIEndpoint
|
||||
from app.models.log import OperationLog
|
||||
|
||||
__all__ = [
|
||||
@@ -9,5 +17,12 @@ __all__ = [
|
||||
"DataSource", "Database", "DataTable", "DataColumn", "UnstructuredFile",
|
||||
"Category", "DataLevel", "RecognitionRule", "ClassificationTemplate",
|
||||
"ClassificationProject", "ClassificationTask", "ClassificationResult", "ClassificationChange",
|
||||
"MLModelVersion",
|
||||
"MaskingRule",
|
||||
"WatermarkLog",
|
||||
"SchemaChangeLog",
|
||||
"RiskAssessment",
|
||||
"ComplianceRule",
|
||||
"ComplianceIssue",
|
||||
"OperationLog",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
from datetime import datetime
|
||||
from sqlalchemy import Column, Integer, String, Text, DateTime, Boolean, ForeignKey, JSON
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.core.database import Base
|
||||
|
||||
|
||||
class AlertRule(Base):
    """Alert rule definition: a named trigger condition with a threshold and severity."""
    __tablename__ = "alert_rule"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String(200), nullable=False)
    trigger_condition = Column(String(50), nullable=False)  # l5_count, risk_score, schema_change
    threshold = Column(Integer, default=0)
    severity = Column(String(20), default="medium")
    is_active = Column(Boolean, default=True)  # inactive rules are presumably skipped by evaluation — confirm in service
    created_at = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
|
||||
class AlertRecord(Base):
    """A triggered alert instance, linked to the AlertRule that produced it."""
    __tablename__ = "alert_record"

    id = Column(Integer, primary_key=True, index=True)
    rule_id = Column(Integer, ForeignKey("alert_rule.id"), nullable=False)
    title = Column(String(200), nullable=False)
    content = Column(Text)
    severity = Column(String(20), default="medium")
    status = Column(String(20), default="open")  # open, acknowledged, resolved
    created_at = Column(DateTime, default=datetime.utcnow)

    rule = relationship("AlertRule")
|
||||
|
||||
|
||||
class WorkOrder(Base):
    """Remediation work order, optionally created from an alert and assigned to a user."""
    __tablename__ = "work_order"

    id = Column(Integer, primary_key=True, index=True)
    alert_id = Column(Integer, ForeignKey("alert_record.id"), nullable=True)  # nullable: work orders can exist without a source alert
    title = Column(String(200), nullable=False)
    description = Column(Text)
    assignee_id = Column(Integer, ForeignKey("sys_user.id"), nullable=True)
    status = Column(String(20), default="open")  # open, in_progress, resolved
    resolution = Column(Text)
    created_at = Column(DateTime, default=datetime.utcnow)
    resolved_at = Column(DateTime, nullable=True)

    assignee = relationship("User")
|
||||
@@ -0,0 +1,41 @@
|
||||
from sqlalchemy import Column, Integer, String, Text, DateTime, Boolean, JSON, ForeignKey, BigInteger
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.core.database import Base
|
||||
from datetime import datetime
|
||||
|
||||
class APIAsset(Base):
    """A registered API service (base URL plus optional Swagger doc) with scan summary counters."""
    __tablename__ = "api_asset"
    id = Column(Integer, primary_key=True, index=True)
    name = Column(String(200), nullable=False)
    base_url = Column(String(500), nullable=False)
    swagger_url = Column(String(500), nullable=True)
    auth_type = Column(String(50), default="none")  # none, bearer, api_key, basic
    headers = Column(JSON, default=dict)  # extra HTTP headers to send when scanning
    description = Column(Text, nullable=True)
    scan_status = Column(String(20), default="idle")  # idle, scanning, completed, failed
    total_endpoints = Column(Integer, default=0)  # denormalized counters maintained by the scanner
    sensitive_endpoints = Column(Integer, default=0)
    created_by = Column(Integer, ForeignKey("sys_user.id"), nullable=True)
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    # Deleting an asset removes its endpoints too (delete-orphan cascade).
    endpoints = relationship("APIEndpoint", back_populates="asset", cascade="all, delete-orphan")
    creator = relationship("User", foreign_keys=[created_by])
|
||||
|
||||
class APIEndpoint(Base):
    """A single endpoint discovered under an APIAsset, with detected sensitive fields and a risk level."""
    __tablename__ = "api_endpoint"
    id = Column(Integer, primary_key=True, index=True)
    asset_id = Column(Integer, ForeignKey("api_asset.id"), nullable=False)
    method = Column(String(10), nullable=False)  # GET, POST, PUT, DELETE, etc.
    path = Column(String(500), nullable=False)
    summary = Column(String(500), nullable=True)
    tags = Column(JSON, default=list)
    parameters = Column(JSON, default=list)
    request_body_schema = Column(JSON, nullable=True)
    response_schema = Column(JSON, nullable=True)
    sensitive_fields = Column(JSON, default=list)  # detected PII fields
    risk_level = Column(String(20), default="low")  # low, medium, high, critical
    is_active = Column(Boolean, default=True)
    created_at = Column(DateTime, default=datetime.utcnow)

    asset = relationship("APIAsset", back_populates="endpoints")
|
||||
@@ -0,0 +1,33 @@
|
||||
from datetime import datetime
|
||||
from sqlalchemy import Column, Integer, String, Text, DateTime, Boolean, JSON
|
||||
from app.core.database import Base
|
||||
|
||||
|
||||
class ComplianceRule(Base):
    """A compliance check definition tied to a regulatory standard."""
    __tablename__ = "compliance_rule"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String(200), nullable=False)
    standard = Column(String(50), nullable=False)  # dengbao, pipl, gdpr
    description = Column(Text)
    check_logic = Column(String(50), nullable=False)  # check_masking, check_encryption, check_audit, check_level
    severity = Column(String(20), default="medium")  # low, medium, high, critical
    is_active = Column(Boolean, default=True)
    created_at = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
|
||||
class ComplianceIssue(Base):
    """A violation found by a compliance scan against a specific entity."""
    __tablename__ = "compliance_issue"

    id = Column(Integer, primary_key=True, index=True)
    rule_id = Column(Integer, nullable=False)  # NOTE(review): plain Integer, no FK to compliance_rule — confirm this is intentional
    project_id = Column(Integer, nullable=True)
    entity_type = Column(String(20), nullable=False)  # project, source, column
    entity_id = Column(Integer, nullable=False)
    entity_name = Column(String(200))
    severity = Column(String(20), default="medium")
    description = Column(Text)
    suggestion = Column(Text)
    status = Column(String(20), default="open")  # open, resolved, ignored
    resolved_at = Column(DateTime, nullable=True)
    created_at = Column(DateTime, default=datetime.utcnow)
|
||||
@@ -0,0 +1,16 @@
|
||||
from datetime import datetime
|
||||
from sqlalchemy import Column, Integer, String, Text, DateTime
|
||||
from app.core.database import Base
|
||||
|
||||
|
||||
class DataLineage(Base):
    """A lineage edge from a source table/column to a target table/column."""
    __tablename__ = "data_lineage"

    id = Column(Integer, primary_key=True, index=True)
    source_table = Column(String(200), nullable=False)
    source_column = Column(String(200), nullable=True)  # null => table-level edge
    target_table = Column(String(200), nullable=False)
    target_column = Column(String(200), nullable=True)
    relation_type = Column(String(20), default="direct")  # direct, derived, lookup
    script_content = Column(Text)  # the SQL/script this edge was parsed from
    created_at = Column(DateTime, default=datetime.utcnow)
|
||||
@@ -0,0 +1,22 @@
|
||||
from datetime import datetime
|
||||
from sqlalchemy import Column, Integer, String, Boolean, DateTime, ForeignKey, JSON, Text
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.core.database import Base
|
||||
|
||||
|
||||
class MaskingRule(Base):
    """Masking policy: which algorithm (and params) applies to data of a given level/category."""
    __tablename__ = "masking_rule"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String(100), nullable=False)
    level_id = Column(Integer, ForeignKey("data_level.id"), nullable=True)  # nullable: rule may match by category only
    category_id = Column(Integer, ForeignKey("category.id"), nullable=True)
    algorithm = Column(String(20), nullable=False)  # mask, truncate, hash, generalize, replace
    params = Column(JSON, default=dict)  # algorithm-specific params
    is_active = Column(Boolean, default=True)
    description = Column(Text)
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    level = relationship("DataLevel")
    category = relationship("Category")
|
||||
@@ -1,5 +1,5 @@
|
||||
from datetime import datetime
|
||||
from sqlalchemy import Column, Integer, String, Boolean, DateTime, ForeignKey, Text, BigInteger
|
||||
from sqlalchemy import Column, Integer, String, Boolean, DateTime, ForeignKey, Text, BigInteger, JSON
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.core.database import Base
|
||||
@@ -36,6 +36,10 @@ class Database(Base):
|
||||
charset = Column(String(50))
|
||||
table_count = Column(Integer, default=0)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
last_scanned_at = Column(DateTime, nullable=True)
|
||||
checksum = Column(String(64), nullable=True)
|
||||
is_deleted = Column(Boolean, default=False)
|
||||
deleted_at = Column(DateTime, nullable=True)
|
||||
|
||||
source = relationship("DataSource", back_populates="databases")
|
||||
tables = relationship("DataTable", back_populates="database", cascade="all, delete-orphan")
|
||||
@@ -51,6 +55,10 @@ class DataTable(Base):
|
||||
row_count = Column(BigInteger, default=0)
|
||||
column_count = Column(Integer, default=0)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
last_scanned_at = Column(DateTime, nullable=True)
|
||||
checksum = Column(String(64), nullable=True)
|
||||
is_deleted = Column(Boolean, default=False)
|
||||
deleted_at = Column(DateTime, nullable=True)
|
||||
|
||||
database = relationship("Database", back_populates="tables")
|
||||
columns = relationship("DataColumn", back_populates="table", cascade="all, delete-orphan")
|
||||
@@ -68,6 +76,10 @@ class DataColumn(Base):
|
||||
is_nullable = Column(Boolean, default=True)
|
||||
sample_data = Column(Text) # JSON array of sample values
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
last_scanned_at = Column(DateTime, nullable=True)
|
||||
checksum = Column(String(64), nullable=True)
|
||||
is_deleted = Column(Boolean, default=False)
|
||||
deleted_at = Column(DateTime, nullable=True)
|
||||
|
||||
table = relationship("DataTable", back_populates="columns")
|
||||
|
||||
@@ -81,6 +93,7 @@ class UnstructuredFile(Base):
|
||||
file_size = Column(BigInteger)
|
||||
storage_path = Column(String(500))
|
||||
extracted_text = Column(Text)
|
||||
analysis_result = Column(JSON, nullable=True) # JSON: {matches: [{rule_name, category, level, snippet}]}
|
||||
status = Column(String(20), default="pending") # pending, processed, error
|
||||
created_by = Column(Integer, ForeignKey("sys_user.id"))
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
from datetime import datetime
|
||||
from sqlalchemy import Column, Integer, String, Float, DateTime, Boolean, Text
|
||||
from app.core.database import Base
|
||||
|
||||
|
||||
class MLModelVersion(Base):
    """A trained classifier version: model + TF-IDF vectorizer stored on disk."""

    __tablename__ = "ml_model_version"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String(100), nullable=False)
    model_path = Column(String(500), nullable=False)  # joblib dump path
    vectorizer_path = Column(String(500), nullable=False)  # tfidf vectorizer path
    # Training metrics recorded at train time.
    accuracy = Column(Float, default=0.0)
    train_samples = Column(Integer, default=0)
    train_date = Column(DateTime, default=datetime.utcnow)
    # NOTE(review): presumably only one version is active at a time — verify
    # the training service enforces this.
    is_active = Column(Boolean, default=False)
    description = Column(Text)
    created_at = Column(DateTime, default=datetime.utcnow)
||||
@@ -48,6 +48,10 @@ class ClassificationProject(Base):
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
# Async classification tracking
|
||||
celery_task_id = Column(String(100), nullable=True)
|
||||
scan_progress = Column(Text, nullable=True) # JSON: {"scanned": 0, "matched": 0, "total": 0}
|
||||
|
||||
template = relationship("ClassificationTemplate")
|
||||
tasks = relationship("ClassificationTask", back_populates="project", cascade="all, delete-orphan")
|
||||
results = relationship("ClassificationResult", back_populates="project", cascade="all, delete-orphan")
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
from datetime import datetime
|
||||
from sqlalchemy import Column, Integer, String, Float, DateTime, ForeignKey, JSON
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.core.database import Base
|
||||
|
||||
|
||||
class RiskAssessment(Base):
    """Risk score snapshot for an entity (project / source / table / field)."""

    __tablename__ = "risk_assessment"

    id = Column(Integer, primary_key=True, index=True)
    entity_type = Column(String(20), nullable=False)  # project, source, table, field
    entity_id = Column(Integer, nullable=False)
    entity_name = Column(String(200))
    risk_score = Column(Float, default=0.0)  # 0-100
    # Component scores; how they combine into risk_score is computed by the
    # producing service, not stored here — TODO confirm the weighting.
    sensitivity_score = Column(Float, default=0.0)
    exposure_score = Column(Float, default=0.0)
    protection_score = Column(Float, default=0.0)
    details = Column(JSON, default=dict)
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
||||
@@ -0,0 +1,23 @@
|
||||
from datetime import datetime
|
||||
from sqlalchemy import Column, Integer, String, Text, DateTime, ForeignKey
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.core.database import Base
|
||||
|
||||
|
||||
class SchemaChangeLog(Base):
    """Audit record of one detected schema change on a scanned data source.

    database_id / table_id / column_id are all nullable; which ones are
    populated narrows the scope of the change (e.g. a column-level change
    carries a column_id, a table-level one does not) — confirm against the
    schema-diff producer.
    """

    __tablename__ = "schema_change_log"

    id = Column(Integer, primary_key=True, index=True)
    source_id = Column(Integer, ForeignKey("data_source.id"), nullable=False)
    database_id = Column(Integer, ForeignKey("meta_database.id"), nullable=True)
    table_id = Column(Integer, ForeignKey("meta_table.id"), nullable=True)
    column_id = Column(Integer, ForeignKey("meta_column.id"), nullable=True)
    change_type = Column(String(20), nullable=False)  # add_table, drop_table, add_column, drop_column, change_type, change_comment
    # Before/after values for value-style changes (type or comment edits).
    old_value = Column(Text)
    new_value = Column(Text)
    detected_at = Column(DateTime, default=datetime.utcnow)

    source = relationship("DataSource")
    database = relationship("Database")
    table = relationship("DataTable")
    column = relationship("DataColumn")
||||
@@ -0,0 +1,17 @@
|
||||
from datetime import datetime
|
||||
from sqlalchemy import Column, Integer, String, Text, DateTime, ForeignKey
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.core.database import Base
|
||||
|
||||
|
||||
class WatermarkLog(Base):
    """Audit record of a watermarked data export.

    ``watermark_key`` is a per-export random key, presumably used to trace a
    leaked file back to this export — confirm against the watermark service.
    """

    __tablename__ = "watermark_log"

    id = Column(Integer, primary_key=True, index=True)
    user_id = Column(Integer, ForeignKey("sys_user.id"), nullable=False)
    export_type = Column(String(20), default="csv")  # csv, excel, txt
    data_scope = Column(Text)  # JSON: {source_id, table_name, row_count}
    watermark_key = Column(String(64), nullable=False)  # random key for this export
    created_at = Column(DateTime, default=datetime.utcnow)

    user = relationship("User")
|
||||
@@ -0,0 +1,92 @@
|
||||
from typing import List, Optional
|
||||
from sqlalchemy.orm import Session
|
||||
from datetime import datetime
|
||||
|
||||
from app.models.alert import AlertRule, AlertRecord, WorkOrder
|
||||
from app.models.project import ClassificationProject, ClassificationResult
|
||||
from app.models.risk import RiskAssessment
|
||||
|
||||
|
||||
def init_builtin_alert_rules(db: Session):
    """Seed the built-in alert rules once; no-op if any rule already exists."""
    already_seeded = db.query(AlertRule).first() is not None
    if already_seeded:
        return
    db.add_all(
        [
            AlertRule(name="L5字段数量突增", trigger_condition="l5_count", threshold=5, severity="high"),
            AlertRule(name="项目风险分过高", trigger_condition="risk_score", threshold=80, severity="critical"),
            AlertRule(name="Schema新增敏感字段", trigger_condition="schema_change", threshold=1, severity="medium"),
        ]
    )
    db.commit()
|
||||
|
||||
|
||||
def check_alerts(db: Session) -> List[AlertRecord]:
    """Evaluate every active alert rule and persist triggered AlertRecords.

    Supported trigger conditions:
      * "l5_count"   — per project, count of L5-classified results >= threshold.
      * "risk_score" — project-level RiskAssessment rows >= threshold.
    Other trigger conditions are ignored here. All created records are
    committed in one transaction and returned.
    """
    rules = db.query(AlertRule).filter(AlertRule.is_active == True).all()
    records = []
    for rule in rules:
        if rule.trigger_condition == "l5_count":
            projects = db.query(ClassificationProject).all()
            for p in projects:
                # NOTE(review): joining on .level AND filtering with
                # .level.has(code="L5") looks redundant — the .has() subquery
                # alone should suffice; verify the generated SQL counts
                # correctly and does not double-join.
                l5_count = db.query(ClassificationResult).filter(
                    ClassificationResult.project_id == p.id,
                    ClassificationResult.level_id.isnot(None),
                ).join(ClassificationResult.level).filter(
                    ClassificationResult.level.has(code="L5")
                ).count()
                if l5_count >= rule.threshold:
                    rec = AlertRecord(
                        rule_id=rule.id,
                        title=f"项目 {p.name} L5字段数量达到 {l5_count}",
                        content=f"阈值: {rule.threshold}",
                        severity=rule.severity,
                    )
                    db.add(rec)
                    records.append(rec)
        elif rule.trigger_condition == "risk_score":
            risks = db.query(RiskAssessment).filter(
                RiskAssessment.entity_type == "project",
                RiskAssessment.risk_score >= rule.threshold,
            ).all()
            for rsk in risks:
                rec = AlertRecord(
                    rule_id=rule.id,
                    title=f"项目 {rsk.entity_name} 风险分 {rsk.risk_score}",
                    content=f"阈值: {rule.threshold}",
                    severity=rule.severity,
                )
                db.add(rec)
                records.append(rec)
    db.commit()
    return records
|
||||
|
||||
|
||||
def create_work_order(db: Session, alert_id: int, title: str, description: str, assignee_id: Optional[int] = None) -> WorkOrder:
    """Create and persist a work order linked to an alert record."""
    order = WorkOrder(
        alert_id=alert_id,
        title=title,
        description=description,
        assignee_id=assignee_id,
    )
    db.add(order)
    db.commit()
    db.refresh(order)
    return order
|
||||
|
||||
|
||||
def update_work_order_status(db: Session, wo_id: int, status: str, resolution: Optional[str] = None) -> Optional[WorkOrder]:
    """Update a work order's status and, optionally, its resolution text.

    When the order becomes "resolved", resolved_at is stamped and the linked
    alert record (if any) is marked resolved as well, keeping the two in sync.

    Returns the updated order, or None when ``wo_id`` does not exist.
    (Fix: the annotations previously claimed ``resolution: str`` and a
    non-optional return, but both may legitimately be None.)
    """
    wo = db.query(WorkOrder).filter(WorkOrder.id == wo_id).first()
    if wo is None:
        return None
    wo.status = status
    if resolution:
        wo.resolution = resolution
    if status == "resolved":
        wo.resolved_at = datetime.utcnow()
        # Also resolve linked alert
        if wo.alert_id:
            alert = db.query(AlertRecord).filter(AlertRecord.id == wo.alert_id).first()
            if alert:
                alert.status = "resolved"
    db.commit()
    db.refresh(wo)
    return wo
|
||||
@@ -0,0 +1,174 @@
|
||||
import requests, json
|
||||
from typing import Optional
|
||||
from sqlalchemy.orm import Session
|
||||
from app.models.api_asset import APIAsset, APIEndpoint
|
||||
from app.models.metadata import DataColumn
|
||||
from app.services.classification_engine import match_rule
|
||||
|
||||
# Simple sensitive keywords for API field detection
|
||||
SENSITIVE_KEYWORDS = [
|
||||
"password", "pwd", "passwd", "secret", "token", "credit_card", "card_no",
|
||||
"bank_account", "bank_card", "id_card", "id_number", "phone", "mobile",
|
||||
"email", "address", "name", "age", "gender", "salary", "income",
|
||||
"health", "medical", "biometric", "fingerprint", "face",
|
||||
]
|
||||
|
||||
def _is_sensitive_field(name: str, schema: dict) -> tuple[bool, str]:
    """Decide whether a schema property looks sensitive.

    Checks the field name against SENSITIVE_KEYWORDS first, then falls back
    to the schema's description/format hints. Returns (is_sensitive, reason).
    """
    lowered = name.lower()
    hit = next((kw for kw in SENSITIVE_KEYWORDS if kw in lowered), None)
    if hit is not None:
        return True, f"keyword:{hit}"
    # Fall back to description / format hints from the schema itself.
    description = str(schema.get("description", "")).lower()
    format_hint = str(schema.get("format", "")).lower()
    if "email" in description or "email" in format_hint:
        return True, "format:email"
    if "uuid" in format_hint and "user" in lowered:
        return True, "format:user-uuid"
    return False, ""
|
||||
|
||||
def _extract_fields(schema: dict, prefix: str = "") -> list[dict]:
    """Walk a JSON-schema ``properties`` tree and collect sensitive leaves.

    Each hit is {"name": dotted path, "type": schema type, "reason": why}.
    Nested objects recurse with a dotted prefix; array items recurse with a
    trailing "[]" on the path.
    """
    if not isinstance(schema, dict):
        return []
    collected = []
    for prop_name, prop_schema in schema.get("properties", {}).items():
        path = f"{prefix}.{prop_name}" if prefix else prop_name
        hit, why = _is_sensitive_field(prop_name, prop_schema)
        if hit:
            collected.append({
                "name": path,
                "type": prop_schema.get("type", "unknown"),
                "reason": why,
            })
        kind = prop_schema.get("type")
        if kind == "object" and "properties" in prop_schema:
            collected.extend(_extract_fields(prop_schema, path))
        if kind == "array" and isinstance(prop_schema.get("items"), dict):
            collected.extend(_extract_fields(prop_schema["items"], path + "[]"))
    return collected
|
||||
|
||||
def _risk_level_from_fields(fields: list[dict]) -> str:
|
||||
if not fields:
|
||||
return "low"
|
||||
high_keywords = {"password", "secret", "token", "credit_card", "bank_account", "biometric", "fingerprint", "face"}
|
||||
for f in fields:
|
||||
for kw in high_keywords:
|
||||
if kw in f["name"].lower():
|
||||
return "critical" if kw in {"password", "secret", "token", "biometric"} else "high"
|
||||
return "medium"
|
||||
|
||||
def scan_swagger(db: Session, asset_id: int) -> dict:
    """Fetch an asset's OpenAPI/Swagger spec and rebuild its endpoint inventory.

    For every HTTP operation in the spec, request/response schemas are walked
    for sensitive fields, a risk level is derived, and an APIEndpoint row is
    (re)created. The asset's scan_status transitions scanning -> completed or
    failed; errors are reported in the returned dict rather than raised.

    Fixes: the 200-response schema was previously scanned by two identical
    loops over the same ``content`` dict (the second could never find
    anything new); the duplicate is removed. ``__import__('datetime')`` is
    replaced with a normal local import.
    """
    from datetime import datetime

    asset = db.query(APIAsset).filter(APIAsset.id == asset_id).first()
    if not asset:
        return {"success": False, "error": "Asset not found"}
    if not asset.swagger_url:
        return {"success": False, "error": "No swagger_url configured"}

    asset.scan_status = "scanning"
    db.commit()
    try:
        headers = dict(asset.headers or {})
        resp = requests.get(asset.swagger_url, headers=headers, timeout=30)
        resp.raise_for_status()
        spec = resp.json()

        # Rebuild from scratch: drop the endpoints from the previous scan.
        db.query(APIEndpoint).filter(APIEndpoint.asset_id == asset_id).delete()

        paths = spec.get("paths", {})
        total = 0
        sensitive_total = 0
        for path, methods in paths.items():
            for method, detail in methods.items():
                # Skip non-operation keys (e.g. path-level "parameters").
                if method.lower() not in {"get", "post", "put", "patch", "delete", "head", "options"}:
                    continue
                total += 1
                parameters = []
                for p in detail.get("parameters", []):
                    parameters.append({
                        "name": p.get("name"),
                        "in": p.get("in"),
                        "required": p.get("required", False),
                        "type": p.get("schema", {}).get("type", "string"),
                    })
                req_schema = detail.get("requestBody", {}).get("content", {}).get("application/json", {}).get("schema")
                # First media type under the 200 response carrying a schema wins.
                resp_schema = None
                content = detail.get("responses", {}).get("200", {}).get("content", {}) or {}
                for media_type, media_detail in content.items():
                    if isinstance(media_detail, dict) and "schema" in media_detail:
                        resp_schema = media_detail["schema"]
                        break

                fields = []
                if req_schema:
                    fields.extend(_extract_fields(req_schema))
                if resp_schema:
                    fields.extend(_extract_fields(resp_schema))
                # De-duplicate by dotted field name, keeping first occurrence.
                seen = set()
                unique_fields = []
                for f in fields:
                    if f["name"] not in seen:
                        seen.add(f["name"])
                        unique_fields.append(f)

                risk = _risk_level_from_fields(unique_fields)
                ep = APIEndpoint(
                    asset_id=asset_id,
                    method=method.upper(),
                    path=path,
                    summary=detail.get("summary", ""),
                    tags=detail.get("tags", []),
                    parameters=parameters,
                    request_body_schema=req_schema,
                    response_schema=resp_schema,
                    sensitive_fields=unique_fields,
                    risk_level=risk,
                )
                db.add(ep)
                if unique_fields:
                    sensitive_total += 1

        asset.scan_status = "completed"
        asset.total_endpoints = total
        asset.sensitive_endpoints = sensitive_total
        asset.updated_at = datetime.utcnow()
        db.commit()
        return {"success": True, "total": total, "sensitive": sensitive_total}
    except Exception as e:
        asset.scan_status = "failed"
        db.commit()
        return {"success": False, "error": str(e)}
|
||||
|
||||
def create_asset(db: Session, data: dict, user_id: Optional[int] = None) -> APIAsset:
    """Create and persist a new APIAsset from request payload *data*."""
    fields = {
        "name": data["name"],
        "base_url": data["base_url"],
        "swagger_url": data.get("swagger_url"),
        "auth_type": data.get("auth_type", "none"),
        "headers": data.get("headers"),
        "description": data.get("description"),
        "created_by": user_id,
    }
    asset = APIAsset(**fields)
    db.add(asset)
    db.commit()
    db.refresh(asset)
    return asset
|
||||
|
||||
def update_asset(db: Session, asset_id: int, data: dict) -> Optional[APIAsset]:
    """Patch attributes of an asset in place; returns None when it is missing.

    Keys in *data* that do not match an existing attribute are silently ignored.
    """
    asset = db.query(APIAsset).filter(APIAsset.id == asset_id).first()
    if asset is None:
        return None
    for field, new_value in data.items():
        if hasattr(asset, field):
            setattr(asset, field, new_value)
    db.commit()
    db.refresh(asset)
    return asset
|
||||
|
||||
def delete_asset(db: Session, asset_id: int) -> bool:
    """Hard-delete an asset row; True on success, False when not found."""
    asset = db.query(APIAsset).filter(APIAsset.id == asset_id).first()
    if asset is None:
        return False
    db.delete(asset)
    db.commit()
    return True
|
||||
@@ -51,11 +51,39 @@ def match_rule(rule: RecognitionRule, column: DataColumn) -> Tuple[bool, float]:
|
||||
if t.strip().lower() in enums:
|
||||
return True, 0.90
|
||||
|
||||
elif rule.rule_type == "similarity":
|
||||
benchmarks = [b.strip().lower() for b in rule.rule_content.split(",") if b.strip()]
|
||||
if not benchmarks:
|
||||
return False, 0.0
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
texts = [t.lower() for t in targets] + benchmarks
|
||||
try:
|
||||
vectorizer = TfidfVectorizer(analyzer="char_wb", ngram_range=(2, 3))
|
||||
tfidf = vectorizer.fit_transform(texts)
|
||||
target_vecs = tfidf[:len(targets)]
|
||||
bench_vecs = tfidf[len(targets):]
|
||||
sim_matrix = cosine_similarity(target_vecs, bench_vecs)
|
||||
max_sim = float(sim_matrix.max())
|
||||
if max_sim >= 0.75:
|
||||
return True, round(min(max_sim, 0.99), 4)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return False, 0.0
|
||||
|
||||
|
||||
def run_auto_classification(db: Session, project_id: int, source_ids: Optional[List[int]] = None) -> dict:
|
||||
"""Run automatic classification for a project."""
|
||||
def run_auto_classification(
|
||||
db: Session,
|
||||
project_id: int,
|
||||
source_ids: Optional[List[int]] = None,
|
||||
progress_callback=None,
|
||||
) -> dict:
|
||||
"""Run automatic classification for a project.
|
||||
|
||||
Args:
|
||||
progress_callback: Optional callable(scanned, matched, total) to report progress.
|
||||
"""
|
||||
project = db.query(ClassificationProject).filter(ClassificationProject.id == project_id).first()
|
||||
if not project:
|
||||
return {"success": False, "message": "项目不存在"}
|
||||
@@ -82,7 +110,10 @@ def run_auto_classification(db: Session, project_id: int, source_ids: Optional[L
|
||||
columns = columns_query.all()
|
||||
|
||||
matched_count = 0
|
||||
for col in columns:
|
||||
total = len(columns)
|
||||
report_interval = max(1, total // 20) # report ~20 times
|
||||
|
||||
for idx, col in enumerate(columns):
|
||||
# Check if already has a result for this project
|
||||
existing = db.query(ClassificationResult).filter(
|
||||
ClassificationResult.project_id == project_id,
|
||||
@@ -121,12 +152,20 @@ def run_auto_classification(db: Session, project_id: int, source_ids: Optional[L
|
||||
# Increment hit count
|
||||
best_rule.hit_count = (best_rule.hit_count or 0) + 1
|
||||
|
||||
# Report progress periodically
|
||||
if progress_callback and (idx + 1) % report_interval == 0:
|
||||
progress_callback(scanned=idx + 1, matched=matched_count, total=total)
|
||||
|
||||
db.commit()
|
||||
|
||||
# Final progress report
|
||||
if progress_callback:
|
||||
progress_callback(scanned=total, matched=matched_count, total=total)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"自动分类完成,共扫描 {len(columns)} 个字段,命中 {matched_count} 个",
|
||||
"scanned": len(columns),
|
||||
"message": f"自动分类完成,共扫描 {total} 个字段,命中 {matched_count} 个",
|
||||
"scanned": total,
|
||||
"matched": matched_count,
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,122 @@
|
||||
from typing import List, Optional
|
||||
from sqlalchemy.orm import Session
|
||||
from datetime import datetime
|
||||
|
||||
from app.models.compliance import ComplianceRule, ComplianceIssue
|
||||
from app.models.project import ClassificationProject, ClassificationResult
|
||||
from app.models.classification import DataLevel
|
||||
from app.models.masking import MaskingRule
|
||||
|
||||
|
||||
def init_builtin_rules(db: Session):
    """Seed the built-in compliance rules once; no-op if any rule exists."""
    if db.query(ComplianceRule).first() is not None:
        return
    db.add_all([
        ComplianceRule(name="L4/L5字段未配置脱敏", standard="dengbao", description="等保2.0要求:四级及以上数据应进行脱敏处理", check_logic="check_masking", severity="high"),
        ComplianceRule(name="L5字段缺乏加密存储措施", standard="dengbao", description="等保2.0要求:五级数据应加密存储", check_logic="check_encryption", severity="critical"),
        ComplianceRule(name="个人敏感信息处理未授权", standard="pipl", description="个人信息保护法:处理敏感个人信息应取得单独同意", check_logic="check_level", severity="high"),
        ComplianceRule(name="数据跨境传输未评估", standard="gdpr", description="GDPR:个人数据跨境传输需进行影响评估", check_logic="check_audit", severity="medium"),
    ])
    db.commit()
|
||||
|
||||
|
||||
def scan_compliance(db: Session, project_id: Optional[int] = None) -> List[ComplianceIssue]:
    """Run compliance scan and generate issues.

    For every classified result (level assigned) of every project — or only
    the given project — each active ComplianceRule is evaluated via its
    ``check_logic`` tag. A new ComplianceIssue is created only when no open
    issue already exists for the same (rule, project, column). Returns the
    newly created issues; everything is committed in one transaction.
    """
    rules = db.query(ComplianceRule).filter(ComplianceRule.is_active == True).all()
    issues = []

    # Get masking rules for check_masking logic
    # Levels that have at least one active masking rule attached.
    masking_rules = db.query(MaskingRule).filter(MaskingRule.is_active == True).all()
    masking_level_ids = {r.level_id for r in masking_rules if r.level_id}

    query = db.query(ClassificationProject)
    if project_id:
        query = query.filter(ClassificationProject.id == project_id)
    projects = query.all()

    for project in projects:
        results = db.query(ClassificationResult).filter(
            ClassificationResult.project_id == project.id,
            ClassificationResult.level_id.isnot(None),
        ).all()

        for r in results:
            if not r.level:
                continue
            level_code = r.level.code

            for rule in rules:
                matched = False
                desc = ""
                suggestion = ""

                # L4/L5 columns must be covered by a masking rule.
                if rule.check_logic == "check_masking" and level_code in ("L4", "L5"):
                    if r.level_id not in masking_level_ids:
                        matched = True
                        desc = f"字段 '{r.column.name if r.column else '未知'}' 为 {level_code} 级,但未配置脱敏规则"
                        suggestion = "请在【数据脱敏】模块为该分级配置脱敏策略"

                elif rule.check_logic == "check_encryption" and level_code == "L5":
                    # Placeholder: no encryption check in MVP, always flag
                    matched = True
                    desc = f"字段 '{r.column.name if r.column else '未知'}' 为 L5 级核心数据,建议确认是否加密存储"
                    suggestion = "请确认该字段在数据库中已加密存储"

                # Auto-classified sensitive fields need human confirmation.
                elif rule.check_logic == "check_level" and level_code in ("L4", "L5"):
                    if r.source == "auto":
                        matched = True
                        desc = f"个人敏感字段 '{r.column.name if r.column else '未知'}' 目前为自动识别,建议人工复核并确认授权"
                        suggestion = "请人工确认该字段的处理已取得合法授权"

                elif rule.check_logic == "check_audit":
                    # Placeholder for cross-border check
                    pass

                if matched:
                    # Check if open issue already exists
                    existing = db.query(ComplianceIssue).filter(
                        ComplianceIssue.rule_id == rule.id,
                        ComplianceIssue.project_id == project.id,
                        ComplianceIssue.entity_type == "column",
                        ComplianceIssue.entity_id == (r.column_id or 0),
                        ComplianceIssue.status == "open",
                    ).first()
                    if not existing:
                        issue = ComplianceIssue(
                            rule_id=rule.id,
                            project_id=project.id,
                            entity_type="column",
                            entity_id=r.column_id or 0,
                            entity_name=r.column.name if r.column else "未知",
                            severity=rule.severity,
                            description=desc,
                            suggestion=suggestion,
                        )
                        db.add(issue)
                        issues.append(issue)

    db.commit()
    return issues
|
||||
|
||||
|
||||
def list_issues(db: Session, project_id: Optional[int] = None, status: Optional[str] = None, page: int = 1, page_size: int = 20):
    """Page through compliance issues, newest first; returns (items, total)."""
    q = db.query(ComplianceIssue)
    if project_id:
        q = q.filter(ComplianceIssue.project_id == project_id)
    if status:
        q = q.filter(ComplianceIssue.status == status)
    total = q.count()
    page_items = (
        q.order_by(ComplianceIssue.created_at.desc())
        .offset((page - 1) * page_size)
        .limit(page_size)
        .all()
    )
    return page_items, total
|
||||
|
||||
|
||||
def resolve_issue(db: Session, issue_id: int):
    """Mark an issue resolved and stamp resolved_at; no-op if id is unknown."""
    issue = db.query(ComplianceIssue).filter(ComplianceIssue.id == issue_id).first()
    if issue is None:
        return None
    issue.status = "resolved"
    issue.resolved_at = datetime.utcnow()
    db.commit()
    return issue
|
||||
@@ -1,3 +1,6 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import logging
|
||||
from typing import Optional, List, Tuple
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi import HTTPException, status
|
||||
@@ -7,9 +10,28 @@ from app.models.metadata import DataSource
|
||||
from app.schemas.datasource import DataSourceCreate, DataSourceUpdate, DataSourceTest
|
||||
from app.core.config import settings
|
||||
|
||||
# Simple AES-like symmetric encryption for DB passwords
|
||||
# In production, use a proper KMS
|
||||
_fernet = Fernet(Fernet.generate_key())
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_fernet() -> Fernet:
    """Build the Fernet instance used to (de)crypt stored DB passwords.

    The key is stable across restarts and instances: DB_ENCRYPTION_KEY is
    used verbatim when set; otherwise a urlsafe-base64 SHA-256 digest of
    SECRET_KEY is derived as a deterministic fallback.
    """
    configured = settings.DB_ENCRYPTION_KEY
    if configured:
        return Fernet(configured.encode())
    logger.warning(
        "DB_ENCRYPTION_KEY is not set. Deriving encryption key from SECRET_KEY. "
        "Please set DB_ENCRYPTION_KEY explicitly via environment variable or .env file."
    )
    derived = base64.urlsafe_b64encode(hashlib.sha256(settings.SECRET_KEY.encode()).digest())
    return Fernet(derived)
|
||||
|
||||
|
||||
_fernet = _get_fernet()
|
||||
|
||||
|
||||
def _encrypt_password(password: str) -> str:
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
import re
|
||||
from typing import List, Optional
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.lineage import DataLineage
|
||||
|
||||
|
||||
def _extract_tables(sql: str) -> List[str]:
|
||||
"""Extract table names from SQL using regex (simple heuristic)."""
|
||||
# Normalize SQL
|
||||
sql = re.sub(r"--.*?\n", " ", sql)
|
||||
sql = re.sub(r"/\*.*?\*/", " ", sql, flags=re.DOTALL)
|
||||
sql = sql.lower()
|
||||
tables = set()
|
||||
# FROM / JOIN / INTO
|
||||
for pattern in [r"\bfrom\s+([a-z_][a-z0-9_]*)", r"\bjoin\s+([a-z_][a-z0-9_]*)"]:
|
||||
for m in re.finditer(pattern, sql):
|
||||
tables.add(m.group(1))
|
||||
return sorted(tables)
|
||||
|
||||
|
||||
def parse_sql_lineage(db: Session, sql: str, target_table: str) -> List[DataLineage]:
    """Derive source tables from *sql* and persist lineage edges into *target_table*.

    Self-references and already-known (source, target) pairs are skipped.
    Returns only the newly created lineage rows (all committed at once).
    """
    created = []
    for src in _extract_tables(sql):
        if src == target_table:
            continue  # ignore self-loops
        duplicate = db.query(DataLineage).filter(
            DataLineage.source_table == src,
            DataLineage.target_table == target_table,
        ).first()
        if duplicate:
            continue
        edge = DataLineage(
            source_table=src,
            target_table=target_table,
            relation_type="direct",
            script_content=sql[:2000],  # keep a truncated copy of the script
        )
        db.add(edge)
        created.append(edge)
    db.commit()
    return created
|
||||
|
||||
|
||||
def get_lineage_graph(db: Session, table_name: Optional[str] = None) -> dict:
    """Assemble {nodes, links} graph data for ECharts (capped at 500 edges).

    When *table_name* is given, only edges touching that table are included.
    A table seen as both source and target keeps the category of its most
    recent assignment (target wins within a single edge).
    """
    q = db.query(DataLineage)
    if table_name:
        touches = (DataLineage.source_table == table_name) | (DataLineage.target_table == table_name)
        q = q.filter(touches)
    edges = q.limit(500).all()

    node_index = {}
    links = []
    for edge in edges:
        node_index[edge.source_table] = {"name": edge.source_table, "category": 0}
        node_index[edge.target_table] = {"name": edge.target_table, "category": 1}
        links.append({"source": edge.source_table, "target": edge.target_table, "value": edge.relation_type})

    return {"nodes": list(node_index.values()), "links": links}
|
||||
@@ -0,0 +1,195 @@
|
||||
import hashlib
|
||||
from typing import Optional, Dict
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi import HTTPException, status
|
||||
from app.models.metadata import DataSource, Database, DataTable, DataColumn
|
||||
from app.models.project import ClassificationResult
|
||||
from app.models.masking import MaskingRule
|
||||
from app.services.datasource_service import get_datasource, _decrypt_password
|
||||
|
||||
|
||||
def get_masking_rule(db: Session, rule_id: int):
    """Fetch a single MaskingRule by primary key, or None if absent."""
    query = db.query(MaskingRule).filter(MaskingRule.id == rule_id)
    return query.first()
|
||||
|
||||
|
||||
def list_masking_rules(db: Session, level_id=None, category_id=None, page=1, page_size=20):
    """Page through active masking rules, optionally filtered by level/category.

    Returns (items, total_count) for the filtered query.
    """
    q = db.query(MaskingRule).filter(MaskingRule.is_active == True)
    if level_id:
        q = q.filter(MaskingRule.level_id == level_id)
    if category_id:
        q = q.filter(MaskingRule.category_id == category_id)
    offset = (page - 1) * page_size
    return q.offset(offset).limit(page_size).all(), q.count()
|
||||
|
||||
|
||||
def create_masking_rule(db: Session, data: dict):
    """Insert a new MaskingRule built from a plain dict of column values."""
    rule = MaskingRule(**data)
    db.add(rule)
    db.commit()
    db.refresh(rule)
    return rule
|
||||
|
||||
|
||||
def update_masking_rule(db: Session, db_obj: MaskingRule, data: dict):
    """Apply non-None entries of *data* onto *db_obj* and persist.

    None values are skipped, so fields cannot be cleared through this path.
    """
    for field, new_value in data.items():
        if new_value is None:
            continue
        setattr(db_obj, field, new_value)
    db.commit()
    db.refresh(db_obj)
    return db_obj
|
||||
|
||||
|
||||
def delete_masking_rule(db: Session, rule_id: int):
    """Hard-delete a masking rule; raises 404 when the id is unknown."""
    rule = get_masking_rule(db, rule_id)
    if rule is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="规则不存在")
    db.delete(rule)
    db.commit()
|
||||
|
||||
|
||||
def _apply_mask(value, params):
|
||||
if not value:
|
||||
return value
|
||||
keep_prefix = params.get("keep_prefix", 3)
|
||||
keep_suffix = params.get("keep_suffix", 4)
|
||||
mask_char = params.get("mask_char", "*")
|
||||
if len(value) <= keep_prefix + keep_suffix:
|
||||
return mask_char * len(value)
|
||||
return value[:keep_prefix] + mask_char * (len(value) - keep_prefix - keep_suffix) + value[-keep_suffix:]
|
||||
|
||||
|
||||
def _apply_truncate(value, params):
|
||||
length = params.get("length", 3)
|
||||
suffix = params.get("suffix", "...")
|
||||
if not value or len(value) <= length:
|
||||
return value
|
||||
return value[:length] + suffix
|
||||
|
||||
|
||||
def _apply_hash(value, params):
|
||||
algorithm = params.get("algorithm", "sha256")
|
||||
if algorithm == "md5":
|
||||
return hashlib.md5(str(value).encode()).hexdigest()[:16]
|
||||
return hashlib.sha256(str(value).encode()).hexdigest()[:32]
|
||||
|
||||
|
||||
def _apply_generalize(value, params):
|
||||
try:
|
||||
step = params.get("step", 10)
|
||||
num = float(value)
|
||||
lower = int(num // step * step)
|
||||
upper = lower + step
|
||||
return f"{lower}-{upper}"
|
||||
except Exception:
|
||||
return value
|
||||
|
||||
|
||||
def _apply_replace(value, params):
|
||||
return params.get("replacement", "[REDACTED]")
|
||||
|
||||
|
||||
def apply_masking(value, algorithm, params):
    """Dispatch *value* to the masking handler for *algorithm*.

    None passes through as None; an unknown algorithm returns the value
    untouched. The value is stringified before masking.
    """
    if value is None:
        return None
    dispatch = {
        "mask": _apply_mask,
        "truncate": _apply_truncate,
        "hash": _apply_hash,
        "generalize": _apply_generalize,
        "replace": _apply_replace,
    }
    try:
        handler = dispatch[algorithm]
    except KeyError:
        return value
    return handler(str(value), params or {})
|
||||
|
||||
|
||||
def _get_column_rules(db: Session, table_id: int, project_id=None):
    """Map each column of *table_id* to its best-matching active MaskingRule.

    Matching is driven by the column's ClassificationResult within
    *project_id*; without a project_id no results are loaded, so every column
    maps to None. Fallback order for a classified column:
    exact (level, category) -> (level, None) -> (None, category).
    Returns {column_id: MaskingRule-or-None}.
    """
    columns = db.query(DataColumn).filter(DataColumn.table_id == table_id).all()
    col_rules = {}
    results = {}
    if project_id:
        res_list = db.query(ClassificationResult).filter(
            ClassificationResult.project_id == project_id,
            ClassificationResult.column_id.in_([c.id for c in columns]),
        ).all()
        results = {r.column_id: r for r in res_list}
    rules = db.query(MaskingRule).filter(MaskingRule.is_active == True).all()
    # First rule wins for each (level_id, category_id) pair.
    rule_map = {}
    for r in rules:
        key = (r.level_id, r.category_id)
        if key not in rule_map:
            rule_map[key] = r
    for col in columns:
        matched_rule = None
        if col.id in results:
            r = results[col.id]
            # Exact match first, then level-only, then category-only wildcard.
            matched_rule = rule_map.get((r.level_id, r.category_id))
            if not matched_rule:
                matched_rule = rule_map.get((r.level_id, None))
            if not matched_rule:
                matched_rule = rule_map.get((None, r.category_id))
        col_rules[col.id] = matched_rule
    return col_rules
|
||||
|
||||
|
||||
def preview_masking(db: Session, source_id: int, table_name: str, project_id=None, limit=20):
    """Fetch up to ``limit`` rows from a live table and apply masking rules.

    ``table_name`` must already exist in the metadata catalog for this
    datasource (404 otherwise); that pre-check also keeps arbitrary
    identifiers out of the SQL interpolated below.
    Returns a dict with per-column rule flags, masked rows and a row count.
    Raises HTTPException 404 for unknown source/table, 400 if both query
    attempts fail.
    """
    source = get_datasource(db, source_id)
    if not source:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="数据源不存在")
    table = (
        db.query(DataTable)
        .join(Database)
        .filter(Database.source_id == source_id, DataTable.name == table_name)
        .first()
    )
    if not table:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="表不存在")
    col_rules = _get_column_rules(db, table.id, project_id)
    from sqlalchemy import create_engine, text
    # Best-effort decryption: an undecryptable secret falls back to "".
    password = ""
    if source.encrypted_password:
        try:
            password = _decrypt_password(source.encrypted_password)
        except Exception:
            pass
    driver_map = {
        "mysql": "mysql+pymysql",
        "postgresql": "postgresql+psycopg2",
        "oracle": "oracle+cx_oracle",
        "sqlserver": "mssql+pymssql",
    }
    driver = driver_map.get(source.source_type, source.source_type)
    url = f"{driver}://{source.username}:{password}@{source.host}:{source.port}/{source.database_name}"
    engine = create_engine(url, pool_pre_ping=True)
    columns = db.query(DataColumn).filter(DataColumn.table_id == table.id).all()
    rows_raw = []
    # Try ANSI double-quoted identifier first (PostgreSQL), then the bare
    # form (MySQL).  NOTE(review): LIMIT is not valid syntax on
    # SQL Server/Oracle — confirm preview targets MySQL/PostgreSQL only.
    try:
        with engine.connect() as conn:
            result = conn.execute(text(f'SELECT * FROM "{table_name}" LIMIT {limit}'))
            rows_raw = [dict(row._mapping) for row in result]
    except Exception:
        try:
            with engine.connect() as conn:
                result = conn.execute(text(f"SELECT * FROM {table_name} LIMIT {limit}"))
                rows_raw = [dict(row._mapping) for row in result]
        except Exception as e:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"查询失败: {e}")
    # Apply each column's matched rule; unruled columns pass through raw.
    masked_rows = []
    for raw in rows_raw:
        masked = {}
        for col in columns:
            val = raw.get(col.name)
            rule = col_rules.get(col.id)
            if rule:
                masked[col.name] = apply_masking(val, rule.algorithm, rule.params or {})
            else:
                masked[col.name] = val
        masked_rows.append(masked)
    return {
        "success": True,
        "columns": [{"name": c.name, "data_type": c.data_type, "has_rule": col_rules.get(c.id) is not None} for c in columns],
        "rows": masked_rows,
        "total_rows": len(masked_rows),
    }
|
||||
@@ -3,9 +3,23 @@ from sqlalchemy.orm import Session
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
from app.models.metadata import DataSource, Database, DataTable, DataColumn
|
||||
from app.models.schema_change import SchemaChangeLog
|
||||
from app.services.datasource_service import get_datasource, _decrypt_password
|
||||
|
||||
|
||||
def _log_schema_change(db: Session, source_id: int, change_type: str, database_id: int = None, table_id: int = None, column_id: int = None, old_value: str = None, new_value: str = None):
    """Queue a schema-change audit record on the session (caller commits)."""
    entry = SchemaChangeLog(
        source_id=source_id,
        database_id=database_id,
        table_id=table_id,
        column_id=column_id,
        change_type=change_type,
        old_value=old_value,
        new_value=new_value,
    )
    db.add(entry)
|
||||
|
||||
|
||||
def get_database(db: Session, db_id: int) -> Optional[Database]:
|
||||
return db.query(Database).filter(Database.id == db_id).first()
|
||||
|
||||
@@ -19,14 +33,14 @@ def get_column(db: Session, column_id: int) -> Optional[DataColumn]:
|
||||
|
||||
|
||||
def list_databases(db: Session, source_id: Optional[int] = None) -> List[Database]:
|
||||
query = db.query(Database)
|
||||
query = db.query(Database).filter(Database.is_deleted == False)
|
||||
if source_id:
|
||||
query = query.filter(Database.source_id == source_id)
|
||||
return query.all()
|
||||
|
||||
|
||||
def list_tables(db: Session, database_id: Optional[int] = None, keyword: Optional[str] = None) -> Tuple[List[DataTable], int]:
|
||||
query = db.query(DataTable)
|
||||
query = db.query(DataTable).filter(DataTable.is_deleted == False)
|
||||
if database_id:
|
||||
query = query.filter(DataTable.database_id == database_id)
|
||||
if keyword:
|
||||
@@ -37,7 +51,7 @@ def list_tables(db: Session, database_id: Optional[int] = None, keyword: Optiona
|
||||
|
||||
|
||||
def list_columns(db: Session, table_id: Optional[int] = None, keyword: Optional[str] = None, page: int = 1, page_size: int = 50) -> Tuple[List[DataColumn], int]:
|
||||
query = db.query(DataColumn)
|
||||
query = db.query(DataColumn).filter(DataColumn.is_deleted == False)
|
||||
if table_id:
|
||||
query = query.filter(DataColumn.table_id == table_id)
|
||||
if keyword:
|
||||
@@ -49,7 +63,7 @@ def list_columns(db: Session, table_id: Optional[int] = None, keyword: Optional[
|
||||
return items, total
|
||||
|
||||
|
||||
def build_tree(db: Session, source_id: Optional[int] = None) -> List[dict]:
|
||||
def build_tree(db: Session, source_id: Optional[int] = None, include_deleted: bool = False) -> List[dict]:
|
||||
sources = db.query(DataSource)
|
||||
if source_id:
|
||||
sources = sources.filter(DataSource.id == source_id)
|
||||
@@ -65,20 +79,24 @@ def build_tree(db: Session, source_id: Optional[int] = None) -> List[dict]:
|
||||
"meta": {"source_type": s.source_type, "status": s.status},
|
||||
}
|
||||
for d in s.databases:
|
||||
if not include_deleted and d.is_deleted:
|
||||
continue
|
||||
db_node = {
|
||||
"id": d.id,
|
||||
"name": d.name,
|
||||
"type": "database",
|
||||
"children": [],
|
||||
"meta": {"charset": d.charset, "table_count": d.table_count},
|
||||
"meta": {"charset": d.charset, "table_count": d.table_count, "is_deleted": d.is_deleted},
|
||||
}
|
||||
for t in d.tables:
|
||||
if not include_deleted and t.is_deleted:
|
||||
continue
|
||||
table_node = {
|
||||
"id": t.id,
|
||||
"name": t.name,
|
||||
"type": "table",
|
||||
"children": [],
|
||||
"meta": {"comment": t.comment, "row_count": t.row_count, "column_count": t.column_count},
|
||||
"meta": {"comment": t.comment, "row_count": t.row_count, "column_count": t.column_count, "is_deleted": t.is_deleted},
|
||||
}
|
||||
db_node["children"].append(table_node)
|
||||
source_node["children"].append(db_node)
|
||||
@@ -86,9 +104,16 @@ def build_tree(db: Session, source_id: Optional[int] = None) -> List[dict]:
|
||||
return result
|
||||
|
||||
|
||||
def _compute_checksum(data: dict) -> str:
|
||||
import hashlib, json
|
||||
payload = json.dumps(data, sort_keys=True, ensure_ascii=False, default=str)
|
||||
return hashlib.sha256(payload.encode()).hexdigest()[:32]
|
||||
|
||||
|
||||
def sync_metadata(db: Session, source_id: int, user_id: int) -> dict:
|
||||
from sqlalchemy import create_engine, inspect, text
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
source = get_datasource(db, source_id)
|
||||
if not source:
|
||||
@@ -118,29 +143,56 @@ def sync_metadata(db: Session, source_id: int, user_id: int) -> dict:
|
||||
inspector = inspect(engine)
|
||||
|
||||
db_names = inspector.get_schema_names() or [source.database_name]
|
||||
scan_time = datetime.utcnow()
|
||||
total_tables = 0
|
||||
total_columns = 0
|
||||
updated_tables = 0
|
||||
updated_columns = 0
|
||||
|
||||
for db_name in db_names:
|
||||
db_obj = db.query(Database).filter(Database.source_id == source.id, Database.name == db_name).first()
|
||||
db_checksum = _compute_checksum({"name": db_name})
|
||||
db_obj = db.query(Database).filter(
|
||||
Database.source_id == source.id, Database.name == db_name
|
||||
).first()
|
||||
if not db_obj:
|
||||
db_obj = Database(source_id=source.id, name=db_name)
|
||||
db_obj = Database(source_id=source.id, name=db_name, checksum=db_checksum, last_scanned_at=scan_time)
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
else:
|
||||
db_obj.checksum = db_checksum
|
||||
db_obj.last_scanned_at = scan_time
|
||||
db_obj.is_deleted = False
|
||||
db_obj.deleted_at = None
|
||||
|
||||
table_names = inspector.get_table_names(schema=db_name)
|
||||
for tname in table_names:
|
||||
table_obj = db.query(DataTable).filter(DataTable.database_id == db_obj.id, DataTable.name == tname).first()
|
||||
t_checksum = _compute_checksum({"name": tname})
|
||||
table_obj = db.query(DataTable).filter(
|
||||
DataTable.database_id == db_obj.id, DataTable.name == tname
|
||||
).first()
|
||||
if not table_obj:
|
||||
table_obj = DataTable(database_id=db_obj.id, name=tname)
|
||||
table_obj = DataTable(database_id=db_obj.id, name=tname, checksum=t_checksum, last_scanned_at=scan_time)
|
||||
db.add(table_obj)
|
||||
db.commit()
|
||||
db.refresh(table_obj)
|
||||
_log_schema_change(db, source.id, "add_table", database_id=db_obj.id, table_id=table_obj.id, new_value=tname)
|
||||
else:
|
||||
if table_obj.checksum != t_checksum:
|
||||
table_obj.checksum = t_checksum
|
||||
updated_tables += 1
|
||||
table_obj.last_scanned_at = scan_time
|
||||
table_obj.is_deleted = False
|
||||
table_obj.deleted_at = None
|
||||
|
||||
columns = inspector.get_columns(tname, schema=db_name)
|
||||
for col in columns:
|
||||
col_obj = db.query(DataColumn).filter(DataColumn.table_id == table_obj.id, DataColumn.name == col["name"]).first()
|
||||
col_checksum = _compute_checksum({
|
||||
"name": col["name"],
|
||||
"type": str(col.get("type", "")),
|
||||
"max_length": col.get("max_length"),
|
||||
"comment": col.get("comment"),
|
||||
"nullable": col.get("nullable", True),
|
||||
})
|
||||
col_obj = db.query(DataColumn).filter(
|
||||
DataColumn.table_id == table_obj.id, DataColumn.name == col["name"]
|
||||
).first()
|
||||
if not col_obj:
|
||||
sample = None
|
||||
try:
|
||||
@@ -150,7 +202,6 @@ def sync_metadata(db: Session, source_id: int, user_id: int) -> dict:
|
||||
sample = json.dumps(samples, ensure_ascii=False)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
col_obj = DataColumn(
|
||||
table_id=table_obj.id,
|
||||
name=col["name"],
|
||||
@@ -159,13 +210,58 @@ def sync_metadata(db: Session, source_id: int, user_id: int) -> dict:
|
||||
comment=col.get("comment"),
|
||||
is_nullable=col.get("nullable", True),
|
||||
sample_data=sample,
|
||||
checksum=col_checksum,
|
||||
last_scanned_at=scan_time,
|
||||
)
|
||||
db.add(col_obj)
|
||||
total_columns += 1
|
||||
_log_schema_change(db, source.id, "add_column", database_id=db_obj.id, table_id=table_obj.id, column_id=col_obj.id, new_value=col["name"])
|
||||
else:
|
||||
if col_obj.checksum != col_checksum:
|
||||
old_val = f"type={col_obj.data_type}, len={col_obj.length}, comment={col_obj.comment}"
|
||||
new_val = f"type={str(col.get('type', ''))}, len={col.get('max_length')}, comment={col.get('comment')}"
|
||||
_log_schema_change(db, source.id, "change_type", database_id=db_obj.id, table_id=table_obj.id, column_id=col_obj.id, old_value=old_val, new_value=new_val)
|
||||
col_obj.checksum = col_checksum
|
||||
col_obj.data_type = str(col.get("type", ""))
|
||||
col_obj.length = col.get("max_length")
|
||||
col_obj.comment = col.get("comment")
|
||||
col_obj.is_nullable = col.get("nullable", True)
|
||||
updated_columns += 1
|
||||
col_obj.last_scanned_at = scan_time
|
||||
col_obj.is_deleted = False
|
||||
col_obj.deleted_at = None
|
||||
|
||||
total_tables += 1
|
||||
|
||||
db.commit()
|
||||
# Soft-delete objects not seen in this scan and log changes
|
||||
deleted_dbs = db.query(Database).filter(
|
||||
Database.source_id == source.id,
|
||||
Database.last_scanned_at < scan_time,
|
||||
).all()
|
||||
for d in deleted_dbs:
|
||||
_log_schema_change(db, source.id, "drop_database", database_id=d.id, old_value=d.name)
|
||||
d.is_deleted = True
|
||||
d.deleted_at = scan_time
|
||||
|
||||
for db_obj in db.query(Database).filter(Database.source_id == source.id).all():
|
||||
deleted_tables = db.query(DataTable).filter(
|
||||
DataTable.database_id == db_obj.id,
|
||||
DataTable.last_scanned_at < scan_time,
|
||||
).all()
|
||||
for t in deleted_tables:
|
||||
_log_schema_change(db, source.id, "drop_table", database_id=db_obj.id, table_id=t.id, old_value=t.name)
|
||||
t.is_deleted = True
|
||||
t.deleted_at = scan_time
|
||||
|
||||
for table_obj in db.query(DataTable).filter(DataTable.database_id == db_obj.id).all():
|
||||
deleted_cols = db.query(DataColumn).filter(
|
||||
DataColumn.table_id == table_obj.id,
|
||||
DataColumn.last_scanned_at < scan_time,
|
||||
).all()
|
||||
for c in deleted_cols:
|
||||
_log_schema_change(db, source.id, "drop_column", database_id=db_obj.id, table_id=table_obj.id, column_id=c.id, old_value=c.name)
|
||||
c.is_deleted = True
|
||||
c.deleted_at = scan_time
|
||||
|
||||
source.status = "active"
|
||||
db.commit()
|
||||
@@ -176,6 +272,8 @@ def sync_metadata(db: Session, source_id: int, user_id: int) -> dict:
|
||||
"databases": len(db_names),
|
||||
"tables": total_tables,
|
||||
"columns": total_columns,
|
||||
"updated_tables": updated_tables,
|
||||
"updated_columns": updated_columns,
|
||||
}
|
||||
except Exception as e:
|
||||
source.status = "error"
|
||||
|
||||
@@ -0,0 +1,183 @@
|
||||
from typing import Optional, List, Tuple
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
from app.models.metadata import DataSource, Database, DataTable, DataColumn
|
||||
from app.services.datasource_service import get_datasource, _decrypt_password
|
||||
|
||||
|
||||
def get_database(db: Session, db_id: int) -> Optional[Database]:
    """Fetch a Database row by primary key, or None if absent."""
    match = db.query(Database).filter(Database.id == db_id)
    return match.first()
|
||||
|
||||
|
||||
def get_table(db: Session, table_id: int) -> Optional[DataTable]:
    """Fetch a DataTable row by primary key, or None if absent."""
    match = db.query(DataTable).filter(DataTable.id == table_id)
    return match.first()
|
||||
|
||||
|
||||
def get_column(db: Session, column_id: int) -> Optional[DataColumn]:
    """Fetch a DataColumn row by primary key, or None if absent."""
    match = db.query(DataColumn).filter(DataColumn.id == column_id)
    return match.first()
|
||||
|
||||
|
||||
def list_databases(db: Session, source_id: Optional[int] = None) -> List[Database]:
    """List Database rows, optionally restricted to one datasource."""
    q = db.query(Database)
    if source_id:
        q = q.filter(Database.source_id == source_id)
    return q.all()
|
||||
|
||||
|
||||
def list_tables(db: Session, database_id: Optional[int] = None, keyword: Optional[str] = None) -> Tuple[List[DataTable], int]:
    """Return (tables, count) filtered by database and name/comment keyword."""
    q = db.query(DataTable)
    if database_id:
        q = q.filter(DataTable.database_id == database_id)
    if keyword:
        keyword_match = DataTable.name.contains(keyword) | DataTable.comment.contains(keyword)
        q = q.filter(keyword_match)
    return q.all(), q.count()
|
||||
|
||||
|
||||
def list_columns(db: Session, table_id: Optional[int] = None, keyword: Optional[str] = None, page: int = 1, page_size: int = 50) -> Tuple[List[DataColumn], int]:
    """Return one page of columns plus the total match count.

    ``page`` is 1-based; keyword matches column name or comment.
    """
    q = db.query(DataColumn)
    if table_id:
        q = q.filter(DataColumn.table_id == table_id)
    if keyword:
        q = q.filter(DataColumn.name.contains(keyword) | DataColumn.comment.contains(keyword))
    total = q.count()
    offset = (page - 1) * page_size
    return q.offset(offset).limit(page_size).all(), total
|
||||
|
||||
|
||||
def build_tree(db: Session, source_id: Optional[int] = None) -> List[dict]:
    """Assemble the source → database → table hierarchy as nested dicts.

    Each node carries id/name/type/children plus a type-specific meta dict.
    """
    source_query = db.query(DataSource)
    if source_id:
        source_query = source_query.filter(DataSource.id == source_id)

    tree = []
    for src in source_query.all():
        db_nodes = []
        for database in src.databases:
            table_nodes = [
                {
                    "id": tbl.id,
                    "name": tbl.name,
                    "type": "table",
                    "children": [],
                    "meta": {"comment": tbl.comment, "row_count": tbl.row_count, "column_count": tbl.column_count},
                }
                for tbl in database.tables
            ]
            db_nodes.append({
                "id": database.id,
                "name": database.name,
                "type": "database",
                "children": table_nodes,
                "meta": {"charset": database.charset, "table_count": database.table_count},
            })
        tree.append({
            "id": src.id,
            "name": src.name,
            "type": "source",
            "children": db_nodes,
            "meta": {"source_type": src.source_type, "status": src.status},
        })
    return tree
|
||||
|
||||
|
||||
def sync_metadata(db: Session, source_id: int, user_id: int) -> dict:
    """Crawl a datasource's live schema via SQLAlchemy inspection and insert
    any schemas/tables/columns not yet present in the local metadata catalog.

    Existing rows are left untouched (insert-only sync). On success the
    source status is set to "active"; on any crawl error it is set to
    "error" and a failure summary dict is returned instead of raising.
    NOTE(review): ``user_id`` is accepted but unused here — confirm intent.
    """
    from sqlalchemy import create_engine, inspect, text
    import json

    source = get_datasource(db, source_id)
    if not source:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="数据源不存在")

    driver_map = {
        "mysql": "mysql+pymysql",
        "postgresql": "postgresql+psycopg2",
        "oracle": "oracle+cx_oracle",
        "sqlserver": "mssql+pymssql",
    }
    driver = driver_map.get(source.source_type, source.source_type)

    # DM (达梦) has no driver wired up: return a simulated success.
    if source.source_type == "dm":
        return {"success": True, "message": "达梦数据库同步成功(模拟)", "databases": 0, "tables": 0, "columns": 0}

    # Best-effort decryption: an undecryptable secret falls back to "".
    password = ""
    if source.encrypted_password:
        try:
            password = _decrypt_password(source.encrypted_password)
        except Exception:
            pass

    try:
        url = f"{driver}://{source.username}:{password}@{source.host}:{source.port}/{source.database_name}"
        engine = create_engine(url, pool_pre_ping=True)
        inspector = inspect(engine)

        # Some backends report no schemas; fall back to the configured DB name.
        db_names = inspector.get_schema_names() or [source.database_name]
        total_tables = 0
        total_columns = 0

        for db_name in db_names:
            db_obj = db.query(Database).filter(Database.source_id == source.id, Database.name == db_name).first()
            if not db_obj:
                # Commit immediately so child rows can reference db_obj.id.
                db_obj = Database(source_id=source.id, name=db_name)
                db.add(db_obj)
                db.commit()
                db.refresh(db_obj)

            table_names = inspector.get_table_names(schema=db_name)
            for tname in table_names:
                table_obj = db.query(DataTable).filter(DataTable.database_id == db_obj.id, DataTable.name == tname).first()
                if not table_obj:
                    table_obj = DataTable(database_id=db_obj.id, name=tname)
                    db.add(table_obj)
                    db.commit()
                    db.refresh(table_obj)

                columns = inspector.get_columns(tname, schema=db_name)
                for col in columns:
                    col_obj = db.query(DataColumn).filter(DataColumn.table_id == table_obj.id, DataColumn.name == col["name"]).first()
                    if not col_obj:
                        # Grab up to five non-null sample values for later
                        # classification; failures leave sample_data None.
                        sample = None
                        try:
                            with engine.connect() as conn:
                                result = conn.execute(text(f'SELECT "{col["name"]}" FROM "{db_name}"."{tname}" LIMIT 5'))
                                samples = [str(r[0]) for r in result if r[0] is not None]
                                sample = json.dumps(samples, ensure_ascii=False)
                        except Exception:
                            pass

                        col_obj = DataColumn(
                            table_id=table_obj.id,
                            name=col["name"],
                            data_type=str(col.get("type", "")),
                            length=col.get("max_length"),
                            comment=col.get("comment"),
                            is_nullable=col.get("nullable", True),
                            sample_data=sample,
                        )
                        db.add(col_obj)
                        total_columns += 1

                # Counts every scanned table, not just newly created ones.
                total_tables += 1

        db.commit()

        source.status = "active"
        db.commit()

        return {
            "success": True,
            "message": "元数据同步成功",
            "databases": len(db_names),
            "tables": total_tables,
            "columns": total_columns,
        }
    except Exception as e:
        # Record the failure on the source and report it without raising.
        source.status = "error"
        db.commit()
        return {"success": False, "message": f"同步失败: {str(e)}", "databases": 0, "tables": 0, "columns": 0}
|
||||
@@ -0,0 +1,195 @@
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
from typing import List, Optional, Tuple
|
||||
from datetime import datetime
|
||||
|
||||
import joblib
|
||||
import numpy as np
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.metrics import accuracy_score
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.project import ClassificationResult
|
||||
from app.models.classification import Category
|
||||
from app.models.ml import MLModelVersion
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MODELS_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "ml_models")
|
||||
os.makedirs(MODELS_DIR, exist_ok=True)
|
||||
|
||||
|
||||
def _build_text_features(column_name: str, comment: Optional[str], sample_data: Optional[str]) -> str:
|
||||
parts = [column_name]
|
||||
if comment:
|
||||
parts.append(comment)
|
||||
if sample_data:
|
||||
try:
|
||||
samples = json.loads(sample_data)
|
||||
if isinstance(samples, list):
|
||||
parts.extend([str(s) for s in samples[:5]])
|
||||
except Exception:
|
||||
parts.append(sample_data)
|
||||
return " ".join(parts)
|
||||
|
||||
|
||||
def _fetch_training_data(db: Session, min_samples_per_class: int = 5):
    """Collect manually-labelled column texts for training.

    Only results with source == "manual" and a category are used; classes
    with fewer than ``min_samples_per_class`` examples are dropped so
    stratified splitting stays possible.
    Returns (texts, labels, sample_count).
    """
    labelled = (
        db.query(ClassificationResult)
        .filter(ClassificationResult.source == "manual")
        .filter(ClassificationResult.category_id.isnot(None))
        .all()
    )
    pairs = [
        (_build_text_features(r.column.name, r.column.comment, r.column.sample_data), r.category_id)
        for r in labelled
        if r.column
    ]
    from collections import Counter
    class_counts = Counter(label for _, label in pairs)
    keep = {cls for cls, cnt in class_counts.items() if cnt >= min_samples_per_class}
    kept = [(text, label) for text, label in pairs if label in keep]
    texts = [text for text, _ in kept]
    labels = [label for _, label in kept]
    return texts, labels, len(labels)
|
||||
|
||||
|
||||
def train_model(db: Session, model_name: Optional[str] = None, algorithm: str = "logistic_regression", test_size: float = 0.2):
    """Train a column-classification model from manually labelled data.

    Fits a char-ngram TF-IDF vectorizer plus a classifier, persists both
    with joblib under MODELS_DIR, deactivates all previous model versions
    and records the new one as the single active ``MLModelVersion``.

    Returns the new MLModelVersion, or None when fewer than 20 usable
    training samples exist.
    """
    texts, labels, total = _fetch_training_data(db)
    if total < 20:
        # Too little data for a meaningful train/test split.
        logger.warning("Not enough training data (need >= 20, got %d)", total)
        return None
    # Stratify keeps per-class proportions in the held-out set.
    X_train, X_test, y_train, y_test = train_test_split(
        texts, labels, test_size=test_size, random_state=42, stratify=labels
    )
    # Character n-grams (2-4, word-boundary aware) handle mixed CN/EN names.
    vectorizer = TfidfVectorizer(analyzer="char_wb", ngram_range=(2, 4), max_features=5000)
    X_train_vec = vectorizer.fit_transform(X_train)
    X_test_vec = vectorizer.transform(X_test)
    # Unknown algorithm names silently fall back to logistic regression.
    # NOTE(review): multi_class="multinomial" is deprecated in recent
    # scikit-learn releases — confirm the pinned version accepts it.
    if algorithm == "logistic_regression":
        clf = LogisticRegression(max_iter=1000, multi_class="multinomial", solver="lbfgs")
    elif algorithm == "random_forest":
        clf = RandomForestClassifier(n_estimators=100, random_state=42, n_jobs=-1)
    else:
        clf = LogisticRegression(max_iter=1000, multi_class="multinomial", solver="lbfgs")
    clf.fit(X_train_vec, y_train)
    y_pred = clf.predict(X_test_vec)
    acc = accuracy_score(y_test, y_pred)
    timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
    name = model_name or f"model_{timestamp}"
    model_path = os.path.join(MODELS_DIR, f"{name}_clf.joblib")
    vec_path = os.path.join(MODELS_DIR, f"{name}_tfidf.joblib")
    joblib.dump(clf, model_path)
    joblib.dump(vectorizer, vec_path)
    # Exactly one version may be active: deactivate all others first.
    db.query(MLModelVersion).filter(MLModelVersion.is_active == True).update({"is_active": False})
    mv = MLModelVersion(
        name=name,
        model_path=model_path,
        vectorizer_path=vec_path,
        accuracy=acc,
        train_samples=total,
        is_active=True,
        description=f"Algorithm: {algorithm}, test_accuracy: {acc:.4f}",
    )
    db.add(mv)
    db.commit()
    db.refresh(mv)
    logger.info("Trained model %s with accuracy %.4f on %d samples", name, acc, total)
    return mv
|
||||
|
||||
|
||||
def _get_active_model(db: Session):
    """Load the active model version from disk.

    Returns (classifier, vectorizer, version_row), or None when no active
    version exists or its artifact files are missing.
    """
    active = db.query(MLModelVersion).filter(MLModelVersion.is_active == True).first()
    if active is None:
        return None
    if not (os.path.exists(active.model_path) and os.path.exists(active.vectorizer_path)):
        return None
    return joblib.load(active.model_path), joblib.load(active.vectorizer_path), active
|
||||
|
||||
|
||||
def predict_categories(db: Session, texts: List[str], top_k: int = 3):
    """Return per-text top-k category suggestions from the active ML model.

    Each suggestion is {"category_id": int, "confidence": float}; entries
    at or below 0.01 confidence are dropped. With no usable model, every
    text yields an empty list.
    """
    loaded = _get_active_model(db)
    if not loaded:
        return [[] for _ in texts]
    clf, vectorizer, _version = loaded
    features = vectorizer.transform(texts)
    if not hasattr(clf, "predict_proba"):
        # Hard classifier: one prediction per text at full confidence.
        return [[{"category_id": int(pred), "confidence": 1.0}] for pred in clf.predict(features)]
    probabilities = clf.predict_proba(features)
    class_ids = [int(c) for c in clf.classes_]
    output = []
    for row in probabilities:
        ranked = np.argsort(row)[::-1][:top_k]
        output.append([
            {"category_id": class_ids[i], "confidence": round(float(row[i]), 4)}
            for i in ranked
            if float(row[i]) > 0.01
        ])
    return output
|
||||
|
||||
|
||||
def suggest_for_project_columns(db: Session, project_id: int, column_ids: Optional[List[int]] = None, top_k: int = 3):
    """Run the active ML model over a project's columns and return labelled
    top-k category suggestions for each one.

    ``column_ids`` optionally restricts the set of columns; otherwise all
    columns joined (outer) against the project's classification results are
    scored. Returns a dict with "success" and a "suggestions" list.
    """
    from app.models.project import ClassificationProject
    from app.models.metadata import DataColumn

    project = db.query(ClassificationProject).filter(ClassificationProject.id == project_id).first()
    if not project:
        return {"success": False, "message": "项目不存在"}

    # Outer join keeps columns that have no classification result yet.
    query = db.query(DataColumn).join(
        ClassificationResult,
        (ClassificationResult.column_id == DataColumn.id) & (ClassificationResult.project_id == project_id),
        isouter=True,
    )
    if column_ids:
        query = query.filter(DataColumn.id.in_(column_ids))

    columns = query.all()
    texts = []
    col_map = []
    for col in columns:
        texts.append(_build_text_features(col.name, col.comment, col.sample_data))
        col_map.append(col)

    if not texts:
        return {"success": True, "suggestions": [], "message": "没有可预测的字段"}

    predictions = predict_categories(db, texts, top_k=top_k)
    suggestions = []
    # Collect every predicted category id so names/codes load in one query.
    all_category_ids = set()
    for col, preds in zip(col_map, predictions):
        for p in preds:
            all_category_ids.add(p["category_id"])

    categories = {c.id: c for c in db.query(Category).filter(Category.id.in_(list(all_category_ids))).all()}

    for col, preds in zip(col_map, predictions):
        item = {
            "column_id": col.id,
            "column_name": col.name,
            "table_name": col.table.name if col.table else None,
            "suggestions": [],
        }
        for p in preds:
            cat = categories.get(p["category_id"])
            item["suggestions"].append({
                "category_id": p["category_id"],
                "category_name": cat.name if cat else None,
                "category_code": cat.code if cat else None,
                "confidence": p["confidence"],
            })
        suggestions.append(item)

    return {"success": True, "suggestions": suggestions}
|
||||
@@ -94,3 +94,155 @@ def generate_classification_report(db: Session, project_id: int) -> bytes:
|
||||
doc.save(buffer)
|
||||
buffer.seek(0)
|
||||
return buffer.read()
|
||||
|
||||
|
||||
def generate_excel_report(db: Session, project_id: int) -> bytes:
    """Generate an Excel (.xlsx) report for a classification project.

    The workbook has two sheets: an overview (project info, auto/manual
    counts, level distribution with L4/L5 rows highlighted) and a listing
    of up to 500 high-sensitivity (L4/L5) fields.

    Returns the workbook bytes.
    Raises:
        ValueError: if the project does not exist.
    """
    # Fixed: removed unused imports (PieChart/Reference, Border/Side, func).
    from openpyxl import Workbook
    from openpyxl.styles import Font, PatternFill, Alignment

    project = db.query(ClassificationProject).filter(ClassificationProject.id == project_id).first()
    if not project:
        raise ValueError("项目不存在")

    wb = Workbook()
    ws = wb.active
    ws.title = "报告概览"

    # Title
    ws.merge_cells('A1:D1')
    ws['A1'] = '数据分类分级项目报告'
    ws['A1'].font = Font(size=18, bold=True)
    ws['A1'].alignment = Alignment(horizontal='center')

    # Basic info
    ws['A3'] = '项目名称'
    ws['B3'] = project.name
    ws['A4'] = '报告生成时间'
    ws['B4'] = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    ws['A5'] = '项目状态'
    ws['B5'] = project.status
    ws['A6'] = '模板版本'
    ws['B6'] = project.template.version if project.template else 'N/A'

    # Statistics
    results = db.query(ClassificationResult).filter(ClassificationResult.project_id == project_id).all()
    total = len(results)
    auto_count = sum(1 for r in results if r.source == 'auto')
    manual_count = sum(1 for r in results if r.source == 'manual')

    ws['A8'] = '总字段数'
    ws['B8'] = total
    ws['A9'] = '自动识别'
    ws['B9'] = auto_count
    ws['A10'] = '人工打标'
    ws['B10'] = manual_count

    # Level distribution table header
    ws['A12'] = '分级'
    ws['B12'] = '数量'
    ws['C12'] = '占比'
    for header_ref in ('A12', 'B12', 'C12'):
        ws[header_ref].font = Font(bold=True)

    level_stats = {}
    for r in results:
        if r.level:
            level_stats[r.level.name] = level_stats.get(r.level.name, 0) + 1

    # High-sensitivity levels (L4/L5) are tinted red in the table.
    red_fill = PatternFill(start_color='FFCCCC', end_color='FFCCCC', fill_type='solid')
    row = 13
    for level_name, count in sorted(level_stats.items(), key=lambda x: -x[1]):
        ws.cell(row=row, column=1, value=level_name)
        ws.cell(row=row, column=2, value=count)
        pct = f'{count / total * 100:.1f}%' if total > 0 else '0%'
        ws.cell(row=row, column=3, value=pct)
        if 'L4' in level_name or 'L5' in level_name:
            for c in range(1, 4):
                ws.cell(row=row, column=c).fill = red_fill
        row += 1

    # High-risk sheet, capped at 500 rows to keep the file small.
    ws2 = wb.create_sheet("高敏感数据清单")
    ws2.append(['字段名', '所属表', '分类', '分级', '来源', '置信度'])
    for cell in ws2[1]:
        cell.font = Font(bold=True)
        cell.fill = PatternFill(start_color='DDEBF7', end_color='DDEBF7', fill_type='solid')

    high_risk = [r for r in results if r.level and r.level.code in ('L4', 'L5')]
    for r in high_risk[:500]:
        ws2.append([
            r.column.name if r.column else 'N/A',
            r.column.table.name if r.column and r.column.table else 'N/A',
            r.category.name if r.category else 'N/A',
            r.level.name if r.level else 'N/A',
            '自动' if r.source == 'auto' else '人工',
            r.confidence,
        ])

    # Auto-fit column widths roughly (capped at 50 chars).
    # NOTE(review): ws.columns may yield merged cells (A1:D1 above) at
    # column[0] — confirm MergedCell exposes column_letter on the pinned
    # openpyxl version.
    for ws_sheet in [ws, ws2]:
        for column in ws_sheet.columns:
            max_length = 0
            column_letter = column[0].column_letter
            for cell in column:
                try:
                    if len(str(cell.value)) > max_length:
                        max_length = len(str(cell.value))
                except Exception:  # fixed: was a bare `except:` trapping SystemExit/KeyboardInterrupt
                    pass
            adjusted_width = min(max_length + 2, 50)
            ws_sheet.column_dimensions[column_letter].width = adjusted_width

    buffer = BytesIO()
    wb.save(buffer)
    buffer.seek(0)
    return buffer.read()
|
||||
|
||||
|
||||
def get_report_summary(db: Session, project_id: int) -> dict:
    """Aggregate classification statistics for frontend PDF generation.

    Raises ValueError when the project does not exist.
    """
    project = db.query(ClassificationProject).filter(ClassificationProject.id == project_id).first()
    if not project:
        raise ValueError("项目不存在")

    results = db.query(ClassificationResult).filter(ClassificationResult.project_id == project_id).all()
    auto_total = sum(1 for item in results if item.source == 'auto')
    manual_total = sum(1 for item in results if item.source == 'manual')

    distribution = {}
    for item in results:
        if item.level:
            distribution[item.level.name] = distribution.get(item.level.name, 0) + 1

    # Only L4/L5 results count as high risk; cap the listing at 100 rows.
    sensitive = [
        {
            "column_name": item.column.name if item.column else 'N/A',
            "table_name": item.column.table.name if item.column and item.column.table else 'N/A',
            "category_name": item.category.name if item.category else 'N/A',
            "level_name": item.level.name if item.level else 'N/A',
            "source": '自动' if item.source == 'auto' else '人工',
            "confidence": item.confidence,
        }
        for item in results
        if item.level and item.level.code in ('L4', 'L5')
    ]

    return {
        "project_name": project.name,
        "status": project.status,
        "template_version": project.template.version if project.template else 'N/A',
        "generated_at": datetime.now().isoformat(),
        "total": len(results),
        "auto": auto_total,
        "manual": manual_total,
        "level_distribution": [
            {"name": name, "count": count}
            for name, count in sorted(distribution.items(), key=lambda x: -x[1])
        ],
        "high_risk": sensitive[:100],
    }
|
||||
|
||||
@@ -0,0 +1,125 @@
|
||||
from typing import List, Optional
|
||||
from sqlalchemy.orm import Session
|
||||
from datetime import datetime
|
||||
|
||||
from app.models.project import ClassificationProject, ClassificationResult
|
||||
from app.models.classification import DataLevel
|
||||
from app.models.metadata import DataSource, Database, DataTable, DataColumn
|
||||
from app.models.masking import MaskingRule
|
||||
from app.models.risk import RiskAssessment
|
||||
|
||||
|
||||
def _get_level_weight(level_code: str) -> int:
|
||||
weights = {"L1": 1, "L2": 2, "L3": 3, "L4": 4, "L5": 5}
|
||||
return weights.get(level_code, 1)
|
||||
|
||||
|
||||
def calculate_project_risk(db: Session, project_id: int) -> Optional[RiskAssessment]:
    """Calculate risk score for a project.

    Risk model (per classified field):
        item_risk = level_weight * exposure_factor * (1 - protection_rate)
    Per-field scores are summed, normalized to 0-100, and upserted into a
    single RiskAssessment row keyed by (entity_type="project", entity_id).

    Args:
        db: Active SQLAlchemy session; commits the upsert on success.
        project_id: Target ClassificationProject id.

    Returns:
        The upserted RiskAssessment, or None when the project does not exist.
    """
    project = db.query(ClassificationProject).filter(ClassificationProject.id == project_id).first()
    if not project:
        return None

    # Only results that already carry a sensitivity level participate.
    results = db.query(ClassificationResult).filter(
        ClassificationResult.project_id == project_id,
        ClassificationResult.level_id.isnot(None),
    ).all()

    total_risk = 0.0
    total_sensitivity = 0.0
    total_exposure = 0.0
    total_protection = 0.0
    detail_items = []

    # Get all active masking rules for quick lookup
    rules = db.query(MaskingRule).filter(MaskingRule.is_active == True).all()
    rule_level_ids = {r.level_id for r in rules if r.level_id}
    rule_cat_ids = {r.category_id for r in rules if r.category_id}

    for r in results:
        if not r.level:
            continue
        # Sensitivity: L1..L5 mapped to weights 1..5.
        level_weight = _get_level_weight(r.level.code)
        # Exposure: count source connections for the column's table
        source_count = 1
        if r.column and r.column.table and r.column.table.database:
            # Simple: if table exists in multiple dbs (rare), count them
            source_count = max(1, len(r.column.table.database.source.databases or []))
        exposure_factor = 1 + source_count * 0.2

        # Protection: check if masking rule exists for this level/category
        has_masking = (r.level_id in rule_level_ids) or (r.category_id in rule_cat_ids)
        protection_rate = 0.3 if has_masking else 0.0

        item_risk = level_weight * exposure_factor * (1 - protection_rate)
        total_risk += item_risk
        total_sensitivity += level_weight
        total_exposure += exposure_factor
        total_protection += protection_rate

        detail_items.append({
            "column_id": r.column_id,
            "column_name": r.column.name if r.column else None,
            "level": r.level.code if r.level else None,
            "level_weight": level_weight,
            "exposure_factor": round(exposure_factor, 2),
            "protection_rate": protection_rate,
            "item_risk": round(item_risk, 2),
        })

    # Normalize to 0-100 (heuristic: assume max reasonable raw score is 15 per field)
    count = len(detail_items) or 1  # guard against division by zero for empty projects
    max_raw = count * 15
    risk_score = min(100, (total_risk / max_raw) * 100) if max_raw > 0 else 0

    # Upsert risk assessment
    existing = db.query(RiskAssessment).filter(
        RiskAssessment.entity_type == "project",
        RiskAssessment.entity_id == project_id,
    ).first()

    if existing:
        existing.risk_score = round(risk_score, 2)
        existing.sensitivity_score = round(total_sensitivity / count, 2)
        existing.exposure_score = round(total_exposure / count, 2)
        existing.protection_score = round(total_protection / count, 2)
        # Details capped at 100 items to bound the JSON column size.
        existing.details = {"items": detail_items[:100], "total_items": len(detail_items)}
        existing.updated_at = datetime.utcnow()
    else:
        existing = RiskAssessment(
            entity_type="project",
            entity_id=project_id,
            entity_name=project.name,
            risk_score=round(risk_score, 2),
            sensitivity_score=round(total_sensitivity / count, 2),
            exposure_score=round(total_exposure / count, 2),
            protection_score=round(total_protection / count, 2),
            details={"items": detail_items[:100], "total_items": len(detail_items)},
        )
        db.add(existing)

    db.commit()
    return existing
|
||||
|
||||
|
||||
def calculate_all_projects_risk(db: Session) -> dict:
    """Batch-recalculate risk for every project.

    Best-effort: one failing project must not abort the batch. Unlike the
    previous bare swallow, a failure now rolls the session back (so a dirty
    transaction cannot poison subsequent iterations) and is counted.

    Args:
        db: Active SQLAlchemy session.

    Returns:
        {"updated": <n succeeded>, "failed": <n failed>} — the ``failed`` key
        is additive and backward-compatible with existing callers.
    """
    projects = db.query(ClassificationProject).all()
    updated = 0
    failed = 0
    for p in projects:
        try:
            calculate_project_risk(db, p.id)
            updated += 1
        except Exception:
            # Keep the session usable for the next project.
            db.rollback()
            failed += 1
    return {"updated": updated, "failed": failed}
|
||||
|
||||
|
||||
def get_risk_top_n(db: Session, entity_type: str = "project", n: int = 10) -> List[RiskAssessment]:
    """Return the ``n`` highest-scoring risk assessments for an entity type."""
    query = db.query(RiskAssessment).filter(RiskAssessment.entity_type == entity_type)
    query = query.order_by(RiskAssessment.risk_score.desc()).limit(n)
    return query.all()
|
||||
@@ -0,0 +1,99 @@
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
from typing import Optional, List
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
from app.models.metadata import UnstructuredFile
|
||||
from app.core.events import minio_client
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
def extract_text_from_file(file_path: str, file_type: str) -> str:
    """Extract plain text from a Word / Excel / PDF / TXT file.

    Args:
        file_path: Local filesystem path of the file to parse.
        file_type: Logical type name ("word"/"docx", "excel"/"xlsx"/"xls",
            "pdf", "txt"); case-insensitive.

    Returns:
        The extracted text (may be empty).

    Raises:
        ValueError: when parsing fails or the file type is unsupported.
    """
    kind = file_type.lower()

    if kind in ("word", "docx"):
        try:
            from docx import Document
            document = Document(file_path)
            return "\n".join(p.text for p in document.paragraphs if p.text)
        except Exception as e:
            raise ValueError(f"解析Word失败: {e}")

    if kind in ("excel", "xlsx", "xls"):
        try:
            from openpyxl import load_workbook
            workbook = load_workbook(file_path, data_only=True)
            rows = []
            for sheet in workbook.worksheets:
                for values in sheet.iter_rows(values_only=True):
                    rows.append(" ".join(str(v) for v in values if v is not None))
            return "\n".join(rows)
        except Exception as e:
            raise ValueError(f"解析Excel失败: {e}")

    if kind == "pdf":
        try:
            import pdfplumber
            with pdfplumber.open(file_path) as pdf:
                return "\n".join(page.extract_text() or "" for page in pdf.pages)
        except Exception as e:
            raise ValueError(f"解析PDF失败: {e}")

    if kind == "txt":
        with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
            return f.read()

    raise ValueError(f"不支持的文件类型: {kind}")
|
||||
|
||||
|
||||
def scan_text_for_sensitive(text: str) -> List[dict]:
    """Scan extracted text for sensitive patterns using built-in rules.

    Rules are applied in a fixed order (ID card, phone, bank card, amount);
    each match records a +/-10 character context snippet and its offset.
    """
    rules = [
        ("身份证号", "CUST_PERSONAL", "L4", r"(?<!\d)\d{17}[\dXx](?!\d)"),
        ("手机号", "CUST_PERSONAL", "L4", r"(?<!\d)1[3-9]\d{9}(?!\d)"),
        ("银行卡号", "FIN_PAYMENT", "L4", r"(?<!\d)\d{16,19}(?!\d)"),
        ("金额", "FIN_PAYMENT", "L3", r"(?<!\d)\d{1,3}(,\d{3})*\.\d{2}(?!\d)"),
    ]
    matches = []
    for rule_name, category_code, level_code, pattern in rules:
        for m in re.finditer(pattern, text):
            start, end = m.start(), m.end()
            snippet = text[max(0, start - 10):min(len(text), end + 10)]
            matches.append({
                "rule_name": rule_name,
                "category_code": category_code,
                "level_code": level_code,
                "snippet": snippet,
                "position": start,
            })
    return matches
|
||||
|
||||
|
||||
def process_unstructured_file(db: Session, file_id: int) -> dict:
    """Download an uploaded file from MinIO, extract its text, scan it for
    sensitive patterns, and persist the analysis on the record.

    Side effects: mutates ``extracted_text`` / ``analysis_result`` / ``status``
    on the UnstructuredFile row and commits. On failure the status is set to
    "error" (committed) and the error re-raised as HTTP 500. The temp file is
    always removed.

    Raises:
        HTTPException: 404 if the record is missing, 400 if it has no stored
            object, 500 on any download/parse failure.
    """
    file_obj = db.query(UnstructuredFile).filter(UnstructuredFile.id == file_id).first()
    if not file_obj:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="文件不存在")
    if not file_obj.storage_path:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="文件未上传")

    # Download from MinIO to temp
    tmp_path = f"/tmp/unstructured_{file_id}_{file_obj.original_name}"
    try:
        minio_client.fget_object(settings.MINIO_BUCKET_NAME, file_obj.storage_path, tmp_path)
        text = extract_text_from_file(tmp_path, file_obj.file_type or "")
        file_obj.extracted_text = text[:50000]  # limit storage
        matches = scan_text_for_sensitive(text)
        file_obj.analysis_result = {"matches": matches, "total_chars": len(text)}
        file_obj.status = "processed"
        db.commit()
        return {"success": True, "matches": matches, "total_chars": len(text)}
    except Exception as e:
        file_obj.status = "error"
        db.commit()
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
    finally:
        # Best-effort cleanup: fget_object may have failed before creating the file.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
|
||||
@@ -0,0 +1,97 @@
|
||||
import secrets
|
||||
from typing import Optional, Tuple
|
||||
from sqlalchemy.orm import Session
|
||||
from app.models.watermark import WatermarkLog
|
||||
|
||||
# Zero-width characters for binary encoding
|
||||
ZW_SPACE = "\u200b" # zero-width space -> 0
|
||||
ZW_NOJOIN = "\u200c" # zero-width non-joiner -> 1
|
||||
MARKER = "\u200d" # zero-width joiner -> start marker
|
||||
|
||||
|
||||
def _int_to_binary_bits(n: int, bits: int = 32) -> str:
|
||||
return format(n, f"0{bits}b")
|
||||
|
||||
|
||||
def _binary_bits_to_int(bits: str) -> int:
|
||||
return int(bits, 2)
|
||||
|
||||
|
||||
def embed_watermark(text: str, user_id: int, key: str) -> str:
    """Embed an invisible watermark into text using zero-width characters.

    Payload layout after MARKER: 16 bits of key checksum followed by the
    32-bit user id, encoded as zero-width space (0) / zero-width non-joiner (1).

    Args:
        text: Plain text to watermark.
        user_id: Identity to embed (must fit in 32 bits).
        key: Per-export secret; only a 16-bit checksum of it is embedded.

    Returns:
        The text with the watermark appended and a single trailing newline.
    """
    import zlib

    # Encode user_id as 32-bit binary
    bits = _int_to_binary_bits(user_id)
    # 16-bit key checksum for verification. CRC32 is deterministic across
    # interpreter runs, unlike the previous built-in hash(), which is salted
    # per process (PYTHONHASHSEED) and made old watermarks unverifiable.
    key_bits = _int_to_binary_bits(zlib.crc32(key.encode("utf-8")) & 0xFFFF, 16)
    payload = key_bits + bits
    watermark_chars = MARKER + "".join(ZW_NOJOIN if b == "1" else ZW_SPACE for b in payload)
    # Append watermark at the end of the text (before trailing newlines if any)
    text = text.rstrip("\n")
    return text + watermark_chars + "\n"
|
||||
|
||||
|
||||
def extract_watermark(text: str) -> Tuple[Optional[int], Optional[str]]:
    """Extract a watermark embedded by ``embed_watermark``.

    Returns:
        (user_id, key_checksum_bits) on success, or (None, None) when no
        valid watermark is present.
    """
    if MARKER not in text:
        return None, None
    idx = text.index(MARKER)
    payload = text[idx + len(MARKER):]
    bits = ""
    for ch in payload:
        if ch == ZW_SPACE:
            bits += "0"
        elif ch == ZW_NOJOIN:
            bits += "1"
        else:
            # Stop at first non-watermark character
            break
    # A full payload is 16 key bits + 32 user-id bits = 48 bits. The previous
    # ``< 16`` check let a truncated payload decode to a bogus user id;
    # reject anything shorter than the full payload instead.
    if len(bits) < 48:
        return None, None
    key_bits = bits[:16]
    user_bits = bits[16:48]
    try:
        user_id = _binary_bits_to_int(user_bits)
        return user_id, key_bits
    except Exception:
        return None, None
|
||||
|
||||
|
||||
def apply_watermark_to_lines(lines: list, user_id: int, key: str) -> list:
    """Watermark every line of a CSV/TXT export individually."""
    watermarked = []
    for line in lines:
        watermarked.append(embed_watermark(line, user_id, key))
    return watermarked
|
||||
|
||||
|
||||
def create_watermark_log(db: Session, user_id: int, export_type: str, data_scope: dict) -> WatermarkLog:
    """Record an export event together with its freshly generated watermark key."""
    entry = WatermarkLog(
        user_id=user_id,
        export_type=export_type,
        data_scope=str(data_scope),
        watermark_key=secrets.token_hex(16),
    )
    db.add(entry)
    db.commit()
    db.refresh(entry)
    return entry
|
||||
|
||||
|
||||
def trace_watermark(db: Session, text: str) -> Optional[dict]:
    """Trace leaked text back to the user whose id is embedded in it.

    Returns the most recent watermark log entry for the decoded user id,
    or None when no watermark / no log entry is found.
    """
    user_id, _ = extract_watermark(text)
    if user_id is None:
        return None
    latest = (
        db.query(WatermarkLog)
        .filter(WatermarkLog.user_id == user_id)
        .order_by(WatermarkLog.created_at.desc())
        .first()
    )
    if latest is None:
        return None
    return {
        "user_id": latest.user_id,
        "username": latest.user.username if latest.user else None,
        "export_type": latest.export_type,
        "data_scope": latest.data_scope,
        "created_at": latest.created_at.isoformat() if latest.created_at else None,
    }
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
from app.tasks.worker import celery_app
|
||||
|
||||
|
||||
@@ -5,12 +6,10 @@ from app.tasks.worker import celery_app
|
||||
def auto_classify_task(self, project_id: int, source_ids: list = None):
|
||||
"""
|
||||
Async task to run automatic classification on metadata.
|
||||
Phase 1 placeholder.
|
||||
"""
|
||||
from app.core.database import SessionLocal
|
||||
from app.models.project import ClassificationProject, ClassificationResult, ResultStatus
|
||||
from app.models.classification import RecognitionRule
|
||||
from app.models.metadata import DataColumn
|
||||
from app.models.project import ClassificationProject
|
||||
from app.services.classification_engine import run_auto_classification
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
@@ -18,15 +17,46 @@ def auto_classify_task(self, project_id: int, source_ids: list = None):
|
||||
if not project:
|
||||
return {"status": "failed", "reason": "project not found"}
|
||||
|
||||
# Update project status
|
||||
def progress_callback(scanned, matched, total):
|
||||
percent = int(scanned / total * 100) if total else 0
|
||||
meta = {
|
||||
"scanned": scanned,
|
||||
"matched": matched,
|
||||
"total": total,
|
||||
"percent": percent,
|
||||
}
|
||||
self.update_state(state="PROGRESS", meta=meta)
|
||||
# Persist lightweight progress to DB for UI polling
|
||||
project.scan_progress = json.dumps(meta)
|
||||
db.commit()
|
||||
|
||||
# Initialize
|
||||
project.status = "scanning"
|
||||
project.scan_progress = json.dumps({"scanned": 0, "matched": 0, "total": 0, "percent": 0})
|
||||
db.commit()
|
||||
|
||||
rules = db.query(RecognitionRule).filter(RecognitionRule.is_active == True).all()
|
||||
# TODO: implement rule matching logic in Phase 2
|
||||
result = run_auto_classification(
|
||||
db,
|
||||
project_id,
|
||||
source_ids=source_ids,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
project.status = "assigning"
|
||||
if result.get("success"):
|
||||
project.status = "assigning"
|
||||
else:
|
||||
project.status = "created"
|
||||
|
||||
project.celery_task_id = None
|
||||
db.commit()
|
||||
return {"status": "completed", "project_id": project_id, "matched": 0}
|
||||
return {"status": "completed", "project_id": project_id, "result": result}
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
project = db.query(ClassificationProject).filter(ClassificationProject.id == project_id).first()
|
||||
if project:
|
||||
project.status = "created"
|
||||
project.celery_task_id = None
|
||||
db.commit()
|
||||
return {"status": "failed", "reason": str(e)}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
from app.tasks.worker import celery_app
|
||||
|
||||
|
||||
@celery_app.task(bind=True)
def train_ml_model_task(self, model_name: str = None, algorithm: str = "logistic_regression"):
    """Celery task: train an ML classification model on manually labeled data.

    Args:
        model_name: Optional display name for the new model version.
        algorithm: Training algorithm key understood by ml_service.train_model.

    Returns:
        dict with status "completed" plus model metadata, or "failed" + reason.
    """
    # Deferred imports keep worker registration cheap and avoid pulling the
    # DB/ML stack in at module import time.
    from app.core.database import SessionLocal
    from app.services.ml_service import train_model

    db = SessionLocal()
    try:
        self.update_state(state="PROGRESS", meta={"message": "Fetching training data"})
        mv = train_model(db, model_name=model_name, algorithm=algorithm)
        if mv:
            return {
                "status": "completed",
                "model_id": mv.id,
                "name": mv.name,
                "accuracy": mv.accuracy,
                "train_samples": mv.train_samples,
            }
        else:
            # train_model returns a falsy value when there are too few labels.
            return {"status": "failed", "reason": "Not enough training data (need >= 20 samples)"}
    except Exception as e:
        # Failures are reported in-band so the API can surface the reason.
        return {"status": "failed", "reason": str(e)}
    finally:
        db.close()
|
||||
@@ -5,7 +5,7 @@ celery_app = Celery(
|
||||
"data_pointer",
|
||||
broker=settings.REDIS_URL,
|
||||
backend=settings.REDIS_URL,
|
||||
include=["app.tasks.classification_tasks"],
|
||||
include=["app.tasks.classification_tasks", "app.tasks.ml_tasks"],
|
||||
)
|
||||
|
||||
celery_app.conf.update(
|
||||
|
||||
@@ -0,0 +1,124 @@
|
||||
# Integration smoke test: exercises every backend module against a locally
# running, demo-data-seeded stack (backend expected on :8000).
import sys, requests
BASE = "http://localhost:8000"
API = f"{BASE}/api/v1"
# Collected outcomes; check() appends here and the summary at the bottom
# exits non-zero when anything failed.
errors, passed = [], []
|
||||
|
||||
def check(name, ok, detail=""):
    """Record a named pass/fail outcome in the module lists and print it."""
    if not ok:
        errors.append((name, detail))
        print(f" ❌ {name}: {detail}")
    else:
        passed.append(name)
        print(f" ✅ {name}")
|
||||
|
||||
def get_items(resp):
    """Normalize a list response: unwrap {"data": [...]} or {"data": {"items": [...]}}."""
    data = resp.json().get("data", [])
    if isinstance(data, list):
        return data
    return data.get("items", []) if isinstance(data, dict) else []
|
||||
|
||||
def get_total(resp):
    """Read the top-level ``total`` field of a paginated response (0 if absent)."""
    body = resp.json()
    return body.get("total", 0)
|
||||
|
||||
# --- End-to-end smoke run -------------------------------------------------
# Walks every backend module in order. The numeric thresholds below
# (>= 80 users, >= 800 tables, ...) presumably match the volumes produced by
# the demo data-generation scripts — verify against those seeds if they change.
print("\n[1/15] Health")
r = requests.get(f"{BASE}/health")
check("health", r.status_code == 200 and r.json().get("status") == "ok")

print("\n[2/15] Auth")
r = requests.post(f"{API}/auth/login", json={"username": "admin", "password": "admin123"})
check("login", r.status_code == 200)
token = r.json().get("data", {}).get("access_token", "")
check("token", bool(token))
# Every subsequent request reuses this bearer token.
headers = {"Authorization": f"Bearer {token}"}

print("\n[3/15] User")
r = requests.get(f"{API}/users/me", headers=headers)
check("me", r.status_code == 200 and r.json()["data"]["username"] == "admin")
r = requests.get(f"{API}/users?page_size=100", headers=headers)
check("users", r.status_code == 200 and len(get_items(r)) >= 80, f"got {len(get_items(r))}")

print("\n[4/15] Depts")
r = requests.get(f"{API}/users/depts", headers=headers)
check("depts", r.status_code == 200 and len(r.json().get("data", [])) >= 12, f"got {len(r.json().get('data', []))}")

print("\n[5/15] DataSources")
r = requests.get(f"{API}/datasources", headers=headers)
check("datasources", r.status_code == 200 and len(get_items(r)) >= 12, f"got {len(get_items(r))}")

print("\n[6/15] Metadata")
r = requests.get(f"{API}/metadata/databases", headers=headers)
check("databases", r.status_code == 200 and len(get_items(r)) >= 31, f"got {len(get_items(r))}")
r = requests.get(f"{API}/metadata/tables", headers=headers)
check("tables", r.status_code == 200 and len(get_items(r)) >= 800, f"got {len(get_items(r))}")
r = requests.get(f"{API}/metadata/columns", headers=headers)
check("columns", r.status_code == 200 and get_total(r) >= 10000, f"total={get_total(r)}")

print("\n[7/15] Classification")
r = requests.get(f"{API}/classifications/levels", headers=headers)
check("levels", r.status_code == 200 and len(r.json().get("data", [])) == 5)
r = requests.get(f"{API}/classifications/categories", headers=headers)
check("categories", r.status_code == 200 and len(r.json().get("data", [])) >= 20, f"got {len(r.json().get('data', []))}")
r = requests.get(f"{API}/classifications/results", headers=headers)
check("results", r.status_code == 200 and get_total(r) >= 1000, f"total={get_total(r)}")

print("\n[8/15] Projects")
r = requests.get(f"{API}/projects", headers=headers)
check("projects", r.status_code == 200 and len(get_items(r)) >= 8, f"got {len(get_items(r))}")

print("\n[9/15] Tasks")
r = requests.get(f"{API}/tasks/my-tasks", headers=headers)
check("tasks", r.status_code == 200 and len(get_items(r)) >= 20, f"got {len(get_items(r))}")

print("\n[10/15] Dashboard")
r = requests.get(f"{API}/dashboard/stats", headers=headers)
check("stats", r.status_code == 200)
stats = r.json().get("data", {})
check("stats.data_sources", stats.get("data_sources", 0) >= 12, f"got {stats.get('data_sources')}")
check("stats.tables", stats.get("tables", 0) >= 800, f"got {stats.get('tables')}")
check("stats.columns", stats.get("columns", 0) >= 10000, f"got {stats.get('columns')}")
check("stats.labeled", stats.get("labeled", 0) >= 10000, f"got {stats.get('labeled')}")
r = requests.get(f"{API}/dashboard/distribution", headers=headers)
check("distribution", r.status_code == 200 and "level_distribution" in r.json().get("data", {}))

print("\n[11/15] Reports")
r = requests.get(f"{API}/reports/stats", headers=headers)
check("report stats", r.status_code == 200)

print("\n[12/15] Masking")
r = requests.get(f"{API}/masking/rules", headers=headers)
check("masking rules", r.status_code == 200)

print("\n[13/15] Watermark")
r = requests.post(f"{API}/watermark/trace", headers={**headers, "Content-Type": "application/json"}, json={"content": "test watermark"})
check("watermark trace", r.status_code == 200)

print("\n[14/15] Risk")
r = requests.get(f"{API}/risk/top", headers=headers)
check("risk top", r.status_code == 200)

print("\n[15/15] Compliance")
r = requests.get(f"{API}/compliance/issues", headers=headers)
check("compliance issues", r.status_code == 200)

# Additional modules
print("\n[Bonus] Additional modules")
r = requests.get(f"{API}/lineage/graph", headers=headers)
check("lineage graph", r.status_code == 200 and "nodes" in r.json().get("data", {}))
r = requests.get(f"{API}/alerts/records", headers=headers)
check("alert records", r.status_code == 200)
r = requests.get(f"{API}/schema-changes/logs", headers=headers)
check("schema changes logs", r.status_code == 200)
r = requests.get(f"{API}/unstructured/files", headers=headers)
check("unstructured files", r.status_code == 200)
r = requests.get(f"{API}/api-assets", headers=headers)
check("api assets", r.status_code == 200)

# Summary: non-zero exit on any failure so CI can gate on this script.
print("\n" + "="*60)
print(f"Results: {len(passed)} passed, {len(errors)} failed")
print("="*60)
if errors:
    for n, d in errors: print(f" ❌ {n}: {d}")
    sys.exit(1)
else:
    print("🎉 All integration tests passed!")
    sys.exit(0)
|
||||
@@ -0,0 +1,59 @@
|
||||
"""
|
||||
Generate synthetic manual-labeled data for ML model training/demo.
|
||||
Run this script after metadata has been scanned so there are columns to label.
|
||||
"""
|
||||
import random
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
|
||||
from app.core.database import SessionLocal
|
||||
from app.models.metadata import DataColumn
|
||||
from app.models.classification import Category
|
||||
from app.models.project import ClassificationResult
|
||||
|
||||
|
||||
def main():
    """Generate synthetic manual classification labels for ML training/demo.

    Deletes all existing ``source == "manual"`` results, then labels the
    first 300 columns with deterministic pseudo-random sub-categories
    (seeded by column name, so reruns are reproducible).
    """
    db = SessionLocal()
    try:
        columns = db.query(DataColumn).limit(300).all()
        if not columns:
            print("No columns found in database. Please scan a data source first.")
            return

        # Category.level here is the category-tree depth (2 = sub-category).
        categories = db.query(Category).filter(Category.level == 2).all()
        if not categories:
            print("No sub-categories found.")
            return

        # Clear old manual labels to avoid duplicates
        db.query(ClassificationResult).filter(ClassificationResult.source == "manual").delete()
        db.commit()

        count = 0
        for col in columns:
            # Deterministic pseudo-random based on column name for reproducibility
            rng = random.Random(col.name)
            cat = rng.choice(categories)
            # Create a fake manual result (project_id=1 assumed to exist or None)
            result = ClassificationResult(
                project_id=None,
                column_id=col.id,
                category_id=cat.id,
                # NOTE(review): this assigns the parent category's *tree depth*
                # to level_id, which elsewhere appears to be a DataLevel foreign
                # key — confirm this is not a mix-up of "hierarchy level" vs
                # "sensitivity level id".
                level_id=cat.parent.level if cat.parent else 3,  # fallback
                source="manual",
                confidence=1.0,
                status="manual",
            )
            db.add(result)
            count += 1

        db.commit()
        print(f"Generated {count} manual labels across {len(categories)} categories.")
    finally:
        db.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,176 @@
|
||||
"""One-shot source patcher: swap ``sync_metadata`` in a target module for a
checksum-aware implementation.

Usage::

    python patch_sync_metadata.py <path-to-module>

CAVEAT: everything from the marker line to the END of the target file is
replaced by ``NEW_FUNC`` — the original ``sync_metadata`` must therefore be
the last definition in the module, or trailing code will be dropped.
"""
import sys

# First line of the function to be replaced; located with str.find.
MARKER = 'def sync_metadata(db: Session, source_id: int, user_id: int) -> dict:'

# Replacement source written verbatim into the target file.  It references
# names (Session, get_datasource, HTTPException, Database, DataTable,
# DataColumn, _decrypt_password, ...) that must already be in scope in the
# patched module.
NEW_FUNC = '''def _compute_checksum(data: dict) -> str:
    import hashlib, json
    payload = json.dumps(data, sort_keys=True, ensure_ascii=False, default=str)
    return hashlib.sha256(payload.encode()).hexdigest()[:32]


def sync_metadata(db: Session, source_id: int, user_id: int) -> dict:
    from sqlalchemy import create_engine, inspect, text
    import json
    from datetime import datetime

    source = get_datasource(db, source_id)
    if not source:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="数据源不存在")

    driver_map = {
        "mysql": "mysql+pymysql",
        "postgresql": "postgresql+psycopg2",
        "oracle": "oracle+cx_oracle",
        "sqlserver": "mssql+pymssql",
    }
    driver = driver_map.get(source.source_type, source.source_type)

    if source.source_type == "dm":
        return {"success": True, "message": "达梦数据库同步成功(模拟)", "databases": 0, "tables": 0, "columns": 0}

    password = ""
    if source.encrypted_password:
        try:
            password = _decrypt_password(source.encrypted_password)
        except Exception:
            pass

    try:
        url = f"{driver}://{source.username}:{password}@{source.host}:{source.port}/{source.database_name}"
        engine = create_engine(url, pool_pre_ping=True)
        inspector = inspect(engine)

        db_names = inspector.get_schema_names() or [source.database_name]
        scan_time = datetime.utcnow()
        total_tables = 0
        total_columns = 0
        updated_tables = 0
        updated_columns = 0

        for db_name in db_names:
            db_checksum = _compute_checksum({"name": db_name})
            db_obj = db.query(Database).filter(
                Database.source_id == source.id, Database.name == db_name
            ).first()
            if not db_obj:
                db_obj = Database(source_id=source.id, name=db_name, checksum=db_checksum, last_scanned_at=scan_time)
                db.add(db_obj)
            else:
                db_obj.checksum = db_checksum
                db_obj.last_scanned_at = scan_time
                db_obj.is_deleted = False
                db_obj.deleted_at = None

            table_names = inspector.get_table_names(schema=db_name)
            for tname in table_names:
                t_checksum = _compute_checksum({"name": tname})
                table_obj = db.query(DataTable).filter(
                    DataTable.database_id == db_obj.id, DataTable.name == tname
                ).first()
                if not table_obj:
                    table_obj = DataTable(database_id=db_obj.id, name=tname, checksum=t_checksum, last_scanned_at=scan_time)
                    db.add(table_obj)
                else:
                    if table_obj.checksum != t_checksum:
                        table_obj.checksum = t_checksum
                        updated_tables += 1
                    table_obj.last_scanned_at = scan_time
                    table_obj.is_deleted = False
                    table_obj.deleted_at = None

                columns = inspector.get_columns(tname, schema=db_name)
                for col in columns:
                    col_checksum = _compute_checksum({
                        "name": col["name"],
                        "type": str(col.get("type", "")),
                        "max_length": col.get("max_length"),
                        "comment": col.get("comment"),
                        "nullable": col.get("nullable", True),
                    })
                    col_obj = db.query(DataColumn).filter(
                        DataColumn.table_id == table_obj.id, DataColumn.name == col["name"]
                    ).first()
                    if not col_obj:
                        sample = None
                        try:
                            with engine.connect() as conn:
                                result = conn.execute(text(f'SELECT "{col["name"]}" FROM "{db_name}"."{tname}" LIMIT 5'))
                                samples = [str(r[0]) for r in result if r[0] is not None]
                                sample = json.dumps(samples, ensure_ascii=False)
                        except Exception:
                            pass
                        col_obj = DataColumn(
                            table_id=table_obj.id,
                            name=col["name"],
                            data_type=str(col.get("type", "")),
                            length=col.get("max_length"),
                            comment=col.get("comment"),
                            is_nullable=col.get("nullable", True),
                            sample_data=sample,
                            checksum=col_checksum,
                            last_scanned_at=scan_time,
                        )
                        db.add(col_obj)
                        total_columns += 1
                    else:
                        if col_obj.checksum != col_checksum:
                            col_obj.checksum = col_checksum
                            col_obj.data_type = str(col.get("type", ""))
                            col_obj.length = col.get("max_length")
                            col_obj.comment = col.get("comment")
                            col_obj.is_nullable = col.get("nullable", True)
                            updated_columns += 1
                        col_obj.last_scanned_at = scan_time
                        col_obj.is_deleted = False
                        col_obj.deleted_at = None

                total_tables += 1

        # Soft-delete objects not seen in this scan
        db.query(Database).filter(
            Database.source_id == source.id,
            Database.last_scanned_at < scan_time,
        ).update({"is_deleted": True, "deleted_at": scan_time}, synchronize_session=False)

        for db_obj in db.query(Database).filter(Database.source_id == source.id).all():
            db.query(DataTable).filter(
                DataTable.database_id == db_obj.id,
                DataTable.last_scanned_at < scan_time,
            ).update({"is_deleted": True, "deleted_at": scan_time}, synchronize_session=False)
            for table_obj in db.query(DataTable).filter(DataTable.database_id == db_obj.id).all():
                db.query(DataColumn).filter(
                    DataColumn.table_id == table_obj.id,
                    DataColumn.last_scanned_at < scan_time,
                ).update({"is_deleted": True, "deleted_at": scan_time}, synchronize_session=False)

        source.status = "active"
        db.commit()

        return {
            "success": True,
            "message": "元数据同步成功",
            "databases": len(db_names),
            "tables": total_tables,
            "columns": total_columns,
            "updated_tables": updated_tables,
            "updated_columns": updated_columns,
        }
    except Exception as e:
        source.status = "error"
        db.commit()
        return {"success": False, "message": f"同步失败: {str(e)}", "databases": 0, "tables": 0, "columns": 0}
'''


def patch_file(path):
    """Rewrite *path* in place, replacing everything from MARKER onward
    with NEW_FUNC.

    Returns True on success, False when the marker is absent (file left
    untouched in that case).
    """
    # Read/write explicitly as UTF-8: NEW_FUNC contains Chinese literals,
    # and the platform default encoding (e.g. cp936/cp1252) may not be UTF-8.
    with open(path, 'r', encoding='utf-8') as f:
        content = f.read()

    idx = content.find(MARKER)
    if idx == -1:
        return False

    with open(path, 'w', encoding='utf-8') as f:
        f.write(content[:idx] + NEW_FUNC)
    return True


def main(argv):
    """CLI entry point; returns a process exit code (0 ok, 1 failure)."""
    if len(argv) < 2:
        # Guard against a missing path argument instead of an IndexError.
        print('Usage: python patch_sync_metadata.py <path-to-module>')
        return 1
    if not patch_file(argv[1]):
        print('Marker not found')
        return 1
    print('Patched successfully')
    return 0


if __name__ == "__main__":
    sys.exit(main(sys.argv))
|
||||
Reference in New Issue
Block a user