Files
prop-data-guard/backend/scripts/generate_test_data.py
T
hiderfong 9d38180745 rebrand: PropDataGuard → DataPointer
- App name: PropDataGuard → DataPointer
- Frontend title: 财险数据分级分类平台 → 数据分类分级管理平台
- LocalStorage keys: pdg_token/pdg_refresh → dp_token/dp_refresh
- Package name: prop-data-guard-frontend → data-pointer-frontend
- Project config: admin@propdataguard.com → admin@datapo.com
- Celery app name: prop_data_guard → data_pointer
- Layout logo, login title, page title all updated
2026-04-23 11:26:28 +08:00

543 lines
22 KiB
Python

"""
Generate test data for DataPointer system.
Targets: 10000+ records across all tables.
"""
import sys
sys.path.insert(0, '/Users/nathan/Work/DataPointer/prop-data-guard/backend')
import random
import string
import json
from datetime import datetime, timedelta
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from app.core.database import Base
from app.models.user import User, Role, Dept, UserRole
from app.models.metadata import DataSource, Database, DataTable, DataColumn
from app.models.classification import Category, DataLevel, RecognitionRule, ClassificationTemplate
from app.models.project import ClassificationProject, ClassificationTask, ClassificationResult, ResultStatus
from app.models.log import OperationLog
from app.core.security import get_password_hash
# Database connection
DATABASE_URL = "postgresql+psycopg2://pdg:pdg_secret_2024@localhost:5432/prop_data_guard"
engine = create_engine(DATABASE_URL)
SessionLocal = sessionmaker(bind=engine)
db = SessionLocal()
# Clear existing test data (preserve admin user and built-in data).
# Rows are removed children-first so nothing is left pointing at a deleted
# parent. Data sources are cleared as well: they reference users/departments
# (created_by / dept_id) that are deleted below, and data_source_id_seq is
# restarted further down, so surviving sources from a previous run would
# collide with the ids handed out to the new ones.
print("Clearing existing test data...")
db.query(ClassificationResult).delete(synchronize_session=False)
db.query(ClassificationTask).delete(synchronize_session=False)
db.query(ClassificationProject).delete(synchronize_session=False)
db.query(DataColumn).delete(synchronize_session=False)
db.query(DataTable).delete(synchronize_session=False)
db.query(Database).delete(synchronize_session=False)
db.query(DataSource).delete(synchronize_session=False)
# Logs store a user_id (possibly a foreign key), so drop them before users.
db.query(OperationLog).delete(synchronize_session=False)
db.query(UserRole).filter(UserRole.user_id > 1).delete(synchronize_session=False)
db.query(User).filter(User.id > 1).delete(synchronize_session=False)
db.query(Dept).filter(Dept.id > 1).delete(synchronize_session=False)
db.commit()

# Reset all sequences to avoid ID conflicts. The names are fixed constants,
# so building the SQL with an f-string is safe here.
# NOTE(review): classification_change rows are not deleted by this script,
# yet their sequence is restarted; confirm that table is empty in practice.
from sqlalchemy import text
sequences = [
    "sys_dept_id_seq", "sys_user_id_seq", "sys_user_role_id_seq",
    "data_source_id_seq", "meta_database_id_seq", "meta_table_id_seq", "meta_column_id_seq",
    "classification_project_id_seq", "classification_task_id_seq", "classification_result_id_seq",
    "classification_change_id_seq", "sys_operation_log_id_seq",
]
for seq in sequences:
    db.execute(text(f"ALTER SEQUENCE {seq} RESTART WITH 100"))
db.commit()
print(" Sequences reset")
# Seed the module-level RNG so every run produces the same dataset.
random.seed(42)
# ============================================================
# 1. Departments
# ============================================================
print("Generating departments...")
root_dept_names = ["数据安全部", "合规管理部", "信息技术部"]
root_depts = [
    Dept(name=dept_name, parent_id=None, sort_order=order)
    for order, dept_name in enumerate(root_dept_names)
]
db.add_all(root_depts)
db.commit()
for root in root_depts:
    db.refresh(root)
# Root department ids keyed by 1-based position:
# 1 -> root_dept_names[0], 2 -> root_dept_names[1], 3 -> root_dept_names[2].
root_id_map = {pos: root.id for pos, root in enumerate(root_depts, start=1)}
child_dept_defs = [
    ("业务一部", root_id_map[1]), ("业务二部", root_id_map[1]),
    ("车险事业部", root_id_map[3]), ("非车险事业部", root_id_map[3]), ("理赔服务部", root_id_map[3]),
    ("财务部", root_id_map[2]), ("精算部", root_id_map[2]),
    ("客户服务部", root_id_map[1]), ("渠道管理部", root_id_map[1]),
]
depts = list(root_depts)
# Child sort_order keeps counting after the roots (= position in depts).
for dept_name, parent_id in child_dept_defs:
    child = Dept(name=dept_name, parent_id=parent_id, sort_order=len(depts))
    db.add(child)
    depts.append(child)
db.commit()
for child in depts[len(root_depts):]:
    db.refresh(child)
print(f" Created {len(depts)} departments")
# ============================================================
# 2. Users
# ============================================================
print("Generating users...")
# Every role currently defined, plus a role-code -> id lookup.
# NOTE(review): role_map is not read anywhere else in the visible script.
roles = db.query(Role).all()
role_map = {role.code: role.id for role in roles}
# Surnames (written first in Chinese names) and given names used to build
# random display names such as "张伟" or "王秀英". Both pools hold 20 entries.
first_names = ["张", "王", "李", "赵", "刘", "陈", "杨", "黄", "周", "吴",
               "徐", "孙", "马", "朱", "胡", "郭", "何", "高", "林", "罗"]
last_names = ["伟", "芳", "娜", "敏", "静", "丽", "强", "磊", "军", "洋",
              "勇", "艳", "杰", "涛", "明", "超", "霞", "秀英", "华", "平"]


def random_name():
    """Return a random full name: one surname followed by one given name."""
    return random.choice(first_names) + random.choice(last_names)


def random_phone():
    """Return a random 11-digit mobile-style number matching ``1[3-9]\\d{9}``."""
    return "1" + random.choice(["3", "4", "5", "6", "7", "8", "9"]) + "".join(random.choices(string.digits, k=9))
# Create 80 regular users named user002..user081 (numbering starts at 2).
users = []
# Every test user gets the same password, so hash it once. Password hashing
# is typically deliberately slow (presumably bcrypt; confirm in
# app.core.security), and hashing per user dominated this section's runtime.
default_password_hash = get_password_hash("password123")
for i in range(80):
    real = random_name()
    username = f"user{i+2:03d}"
    user = User(
        username=username,
        email=f"{username}@datapo.com",
        hashed_password=default_password_hash,
        real_name=real,
        phone=random_phone(),
        is_active=random.random() > 0.05,
        is_superuser=False,
        dept_id=random.choice(depts).id,
    )
    db.add(user)
    users.append(user)
db.commit()
for u in users:
    db.refresh(u)
# Assign roles: 1-2 random roles per user, clamped to the number of roles that
# exist so random.sample never asks for more than are available.
role_list = list(roles)
for u in users:
    wanted = min(random.randint(1, 2), len(role_list))
    for r in random.sample(role_list, k=wanted):
        db.add(UserRole(user_id=u.id, role_id=r.id))
db.commit()
print(f" Created {len(users)} users")
# ============================================================
# 3. Data Sources
# ============================================================
print("Generating data sources...")
source_types = ["postgresql", "mysql", "oracle", "sqlserver", "dm"]
# (display name, engine type, host prefix, port, default database name)
source_configs = [
    ("核心保单数据库", "postgresql", "db-core-prod", 5432, "core_policy"),
    ("理赔系统数据库", "mysql", "db-claim-prod", 3306, "claim_db"),
    ("财务数据仓库", "postgresql", "db-finance-dw", 5432, "finance_dw"),
    ("客户信息主库", "mysql", "db-cust-master", 3306, "customer_master"),
    ("渠道管理系统", "oracle", "db-channel-ora", 1521, "CHANNEL"),
    ("精算分析平台", "postgresql", "db-actuary-ana", 5432, "actuary_analytics"),
    ("监管报送库", "mysql", "db-regulatory", 3306, "regulatory_report"),
    ("车辆信息库", "postgresql", "db-vehicle", 5432, "vehicle_db"),
    ("非车险业务库", "sqlserver", "db-nonauto", 1433, "NonAutoDB"),
    ("历史归档库", "postgresql", "db-archive", 5432, "archive_db"),
    ("测试环境核心库", "postgresql", "db-core-test", 5432, "core_test"),
    ("达梦国产数据库", "dm", "db-dameng-prod", 5236, "DAMENG"),
]
sources = []
for title, engine_kind, host_prefix, port_no, default_db in source_configs:
    # The RNG draw order (status, then dept, then creator) keeps the seeded
    # dataset reproducible.
    health = "active" if random.random() > 0.1 else "error"
    owner_dept_id = random.choice(depts).id
    creator_id = random.choice(users).id
    data_source = DataSource(
        name=title,
        source_type=engine_kind,
        host=f"{host_prefix}.internal.company.com",
        port=port_no,
        database_name=default_db,
        username=f"{engine_kind}_admin",
        encrypted_password=None,
        status=health,
        dept_id=owner_dept_id,
        created_by=creator_id,
    )
    db.add(data_source)
    sources.append(data_source)
db.commit()
for data_source in sources:
    db.refresh(data_source)
print(f" Created {len(sources)} data sources")
# ============================================================
# 4. Databases
# ============================================================
print("Generating databases...")
databases = []
for src in sources:
    db_total = random.randint(1, 3)
    # SQL Server sources use the Chinese_PRC_CI_AS collation; all others UTF8.
    charset = "Chinese_PRC_CI_AS" if src.source_type == "sqlserver" else "UTF8"
    for ordinal in range(1, db_total + 1):
        # A lone database keeps the source's name; siblings get _1, _2, ...
        schema_name = src.database_name if db_total == 1 else f"{src.database_name}_{ordinal}"
        catalog = Database(
            source_id=src.id,
            name=schema_name,
            charset=charset,
            table_count=0,
        )
        db.add(catalog)
        databases.append(catalog)
db.commit()
for catalog in databases:
    db.refresh(catalog)
print(f" Created {len(databases)} databases")
# ============================================================
# 5. Data Tables & Columns (the big one)
# ============================================================
print("Generating tables and columns...")
# Candidate table names per business domain. The table-generation loop picks
# the first key found (case-insensitively) in the database or source name;
# databases matching no key fall back to the "system" list.
table_prefixes = {
"policy": ["t_policy", "t_policy_detail", "t_policy_extension", "t_policy_history", "t_endorsement"],
"claim": ["t_claim", "t_claim_detail", "t_claim_payment", "t_claim_document", "t_survey"],
"customer": ["t_customer", "t_customer_contact", "t_customer_identity", "t_customer_vehicle", "t_customer_preference"],
"finance": ["t_payment", "t_receipt", "t_invoice", "t_commission", "t_reserve"],
"channel": ["t_agent", "t_agent_contract", "t_partner", "t_broker", "t_sales_record"],
"actuary": ["t_pricing_model", "t_risk_factor", "t_loss_ratio", "t_reserve_calc", "t_solvency"],
"regulatory": ["t_report_cbrc", "t_report_circ", "t_stat_premium", "t_stat_claim", "t_stat_channel"],
"vehicle": ["t_vehicle", "t_vehicle_model", "t_vehicle_usage", "t_vehicle_accident", "t_vehicle_maintenance"],
"system": ["t_user", "t_role", "t_permission", "t_log", "t_config", "t_dict"],
"archive": ["t_archive_policy", "t_archive_claim", "t_archive_customer", "t_archive_finance"],
}
# Column blueprints: (column name, SQL type, column comment, domain hint,
# level hint). Each table gets a random subset. The column generator unpacks
# them as (col_name, col_type, comment, cat_hint, lvl_hint).
# NOTE(review): the level hint looks like an intended sensitivity level (2-5),
# but neither hint is currently used when building columns or results, which
# pick categories/levels at random; confirm whether they should be wired in.
column_templates = [
("id", "BIGINT", "主键ID", "system", 2),
("created_at", "TIMESTAMP", "创建时间", "system", 2),
("updated_at", "TIMESTAMP", "更新时间", "system", 2),
("is_deleted", "BOOLEAN", "是否删除", "system", 2),
("created_by", "BIGINT", "创建人", "system", 2),
("customer_name", "VARCHAR", "客户姓名", "customer", 4),
("customer_id_no", "VARCHAR", "客户身份证号", "customer", 4),
("mobile_phone", "VARCHAR", "手机号码", "customer", 4),
("email", "VARCHAR", "电子邮箱", "customer", 3),
("address", "VARCHAR", "联系地址", "customer", 3),
("bank_account", "VARCHAR", "银行账户", "finance", 4),
("bank_card_no", "VARCHAR", "银行卡号", "finance", 4),
("policy_no", "VARCHAR", "保单号", "policy", 3),
("policy_status", "VARCHAR", "保单状态", "policy", 2),
("premium_amount", "DECIMAL", "保费金额", "finance", 3),
("claim_no", "VARCHAR", "理赔号", "claim", 3),
("claim_amount", "DECIMAL", "理赔金额", "claim", 4),
("loss_description", "TEXT", "损失描述", "claim", 3),
("accident_location", "VARCHAR", "出险地点", "claim", 3),
("vehicle_plate", "VARCHAR", "车牌号", "vehicle", 3),
("vin_code", "VARCHAR", "车辆识别代码VIN", "vehicle", 4),
("agent_name", "VARCHAR", "代理人姓名", "channel", 3),
("agent_license", "VARCHAR", "代理人执业证号", "channel", 3),
("commission_rate", "DECIMAL", "佣金比例", "finance", 3),
("reserve_amount", "DECIMAL", "准备金金额", "finance", 5),
("solvency_ratio", "DECIMAL", "偿付能力充足率", "finance", 5),
("password_hash", "VARCHAR", "密码哈希", "system", 5),
("api_secret", "VARCHAR", "API密钥", "system", 5),
("session_token", "VARCHAR", "会话令牌", "system", 4),
("gps_location", "VARCHAR", "GPS定位信息", "vehicle", 4),
("driving_record", "TEXT", "行驶记录", "vehicle", 4),
("medical_record", "TEXT", "医疗记录", "claim", 4),
("income_info", "DECIMAL", "收入信息", "customer", 4),
("credit_score", "INT", "信用评分", "customer", 4),
("family_member", "VARCHAR", "家庭成员信息", "customer", 3),
("emergency_contact", "VARCHAR", "紧急联系人", "customer", 3),
("beneficiary_name", "VARCHAR", "受益人姓名", "policy", 4),
("beneficiary_id_no", "VARCHAR", "受益人身份证号", "policy", 4),
("underwriting_decision", "VARCHAR", "核保结论", "policy", 3),
("risk_score", "DECIMAL", "风险评分", "actuary", 3),
("fraud_flag", "BOOLEAN", "欺诈标记", "claim", 3),
("audit_comment", "TEXT", "审计意见", "system", 3),
("report_period", "VARCHAR", "报表期间", "regulatory", 2),
("regulatory_code", "VARCHAR", "监管编码", "regulatory", 2),
]
all_tables = []
all_columns = []


def _domain_for(database):
    """Return the first table_prefixes key contained in the lower-cased name of
    ``database`` or of its source, or "system" when no key matches."""
    for domain in table_prefixes:
        if domain in database.name.lower() or domain in database.source.name.lower():
            return domain
    return "system"


for database in databases:
    name_pool = table_prefixes.get(_domain_for(database), table_prefixes["system"])
    table_total = random.randint(15, 40)
    for ordinal in range(1, table_total + 1):
        # The numeric suffix keeps names unique within one database even when
        # the same prefix is drawn more than once.
        table_name = f"{random.choice(name_pool)}_{ordinal:03d}"
        tbl = DataTable(
            database_id=database.id,
            name=table_name,
            comment=f"{table_name}数据表",
            row_count=random.randint(10000, 10000000),
            column_count=0,
        )
        db.add(tbl)
        all_tables.append(tbl)
db.commit()
for tbl in all_tables:
    db.refresh(tbl)
print(f" Created {len(all_tables)} tables")
# Now generate columns
print(" Generating columns (this may take a moment)...")
levels = db.query(DataLevel).all()
level_map = {level.code: level.id for level in levels}
categories = db.query(Category).all()
# Category-code prefix -> domain key. For each domain, cat_map keeps the id of
# the first category (in query order) whose code starts with its prefix.
_CATEGORY_PREFIX_DOMAINS = (
    ("CUST", "customer"),
    ("POLICY", "policy"),
    ("CLAIM", "claim"),
    ("FIN", "finance"),
    ("CHANNEL", "channel"),
    ("REG", "regulatory"),
    ("INT", "system"),
    ("SUB", "vehicle"),
)
cat_map = {}
for category in categories:
    for code_prefix, domain in _CATEGORY_PREFIX_DOMAINS:
        if category.code.startswith(code_prefix):
            cat_map.setdefault(domain, category.id)
            break
# Example values keyed by template column name. A column built from one of
# these templates stores up to three of them, as JSON, in sample_data.
sample_values = {
    "customer_name": ["张三", "李四", "王五", "赵六", "钱七"],
    "customer_id_no": ["110101199001011234", "310101198502023456", "440106197803034567"],
    "mobile_phone": ["13800138000", "13900139000", "13700137000"],
    "email": ["user1@example.com", "user2@test.com", "contact@company.com"],
    "bank_card_no": ["6222021234567890123", "6228481234567890123"],
    "vin_code": ["LSVAG2180E2100001", "LFV3A28K8A3000001"],
    "vehicle_plate": ["京A12345", "沪B67890", "粤C11111"],
    "policy_no": ["PICC2024000001", "PICC2024000002", "PICC2024000003"],
    "claim_no": ["CLM2024000001", "CLM2024000002", "CLM2024000003"],
    "address": ["北京市海淀区xxx路1号", "上海市浦东新区xxx路2号"],
}
batch_size = 500
column_batch = []


def _make_columns(table):
    """Yield randomly chosen DataColumn objects for ``table``.

    Draws 12-28 distinct blueprints from column_templates. The first column
    keeps its template name and every later one gets an ``_<position>`` suffix.
    VARCHAR columns get a random length; roughly 80% are nullable.
    """
    picked = random.sample(
        column_templates, k=min(random.randint(12, 28), len(column_templates))
    )
    for position, (col_name, col_type, comment, _cat_hint, _lvl_hint) in enumerate(picked):
        examples = sample_values.get(col_name)
        sample_json = None
        if examples is not None:
            sample_json = json.dumps(
                random.sample(examples, k=min(3, len(examples))), ensure_ascii=False
            )
        yield DataColumn(
            table_id=table.id,
            name=f"{col_name}_{position}" if position else col_name,
            data_type=col_type,
            length=random.choice([20, 50, 100, 200, 500]) if "VARCHAR" in col_type else None,
            comment=comment,
            is_nullable=random.random() > 0.2,
            sample_data=sample_json,
        )


for tbl in all_tables:
    for column in _make_columns(tbl):
        column_batch.append(column)
        if len(column_batch) < batch_size:
            continue
        # bulk_save_objects() does not assign ids to these Python objects;
        # all_columns only feeds the count printed below.
        db.bulk_save_objects(column_batch)
        db.commit()
        all_columns.extend(column_batch)
        column_batch = []
if column_batch:
    db.bulk_save_objects(column_batch)
    db.commit()
    all_columns.extend(column_batch)
print(f" Created {len(all_columns)} columns")
# Update table counts: derive column_count / table_count from one grouped
# COUNT query each, instead of one COUNT query per table and per database.
from sqlalchemy import func

columns_per_table = dict(
    db.query(DataColumn.table_id, func.count(DataColumn.id))
    .group_by(DataColumn.table_id)
    .all()
)
for tbl in all_tables:
    # Tables without columns do not appear in the grouped result.
    tbl.column_count = columns_per_table.get(tbl.id, 0)
db.commit()
tables_per_database = dict(
    db.query(DataTable.database_id, func.count(DataTable.id))
    .group_by(DataTable.database_id)
    .all()
)
for database in databases:
    database.table_count = tables_per_database.get(database.id, 0)
db.commit()
# ============================================================
# 6. Classification Projects
# ============================================================
print("Generating classification projects...")
templates = db.query(ClassificationTemplate).all()
projects = []
project_names = [
    "2024年度数据分类分级专项",
    "核心系统敏感数据梳理",
    "新核心上线数据定级",
    "客户个人信息保护专项",
    "财务数据安全治理",
    "理赔数据合规检查",
    "渠道数据梳理项目",
    "监管报送数据定级",
]
for title in project_names:
    # Random draws happen in a fixed order so the seeded dataset is stable.
    template_id = random.choice(templates).id
    phase = random.choice(["created", "scanning", "labeling", "reviewing", "published"])
    chosen_sources = random.sample(sources, k=random.randint(2, 5))
    start = datetime.now() - timedelta(days=random.randint(10, 60))
    finish = datetime.now() + timedelta(days=random.randint(10, 90))
    owner_id = random.choice(users).id
    project = ClassificationProject(
        name=title,
        template_id=template_id,
        description=f"{title} - 数据分类分级治理项目",
        status=phase,
        # Comma-separated list of the 2-5 data source ids in scope.
        target_source_ids=",".join(str(src.id) for src in chosen_sources),
        planned_start=start,
        planned_end=finish,
        created_by=owner_id,
    )
    db.add(project)
    projects.append(project)
db.commit()
for project in projects:
    db.refresh(project)
print(f" Created {len(projects)} projects")
# ============================================================
# 7. Classification Results (the critical mass)
# ============================================================
print("Generating classification results...")
# bulk_save_objects() does not populate ids on the saved column objects, so
# the column ids are read back from the database.
all_col_ids = [row[0] for row in db.query(DataColumn.id).all()]
random.shuffle(all_col_ids)
result_batch = []
total_results_target = 12000
results_per_project = total_results_target // len(projects)
for proj in projects:
    picked_ids = random.sample(all_col_ids, k=min(results_per_project, len(all_col_ids)))
    for col_id in picked_ids:
        # ~70% auto-classified; manual results are "manual" or "reviewed" and
        # get a labeler.
        origin = random.choices(["auto", "manual"], weights=[0.7, 0.3])[0]
        is_manual = origin == "manual"
        state = random.choice(["manual", "reviewed"]) if is_manual else "auto"
        category = random.choice(categories)
        level = random.choice(levels)
        confidence = round(random.uniform(0.3, 0.98), 2)
        labeler_id = random.choice(users).id if is_manual else None
        result_batch.append(
            ClassificationResult(
                project_id=proj.id,
                column_id=col_id,
                category_id=category.id,
                level_id=level.id,
                source=origin,
                confidence=confidence,
                status=state,
                labeler_id=labeler_id,
            )
        )
        if len(result_batch) < batch_size:
            continue
        db.bulk_save_objects(result_batch)
        db.commit()
        result_batch = []
if result_batch:
    db.bulk_save_objects(result_batch)
    db.commit()
total_results = db.query(ClassificationResult).count()
print(f" Created {total_results} classification results")
# ============================================================
# 8. Classification Tasks
# ============================================================
print("Generating classification tasks...")
tasks = []
for proj in projects:
    # 2-5 column-level tasks per project, numbered from 1.
    for ordinal in range(1, random.randint(2, 5) + 1):
        assigner_id = random.choice(users).id
        assignee_id = random.choice(users).id
        state = random.choice(["pending", "in_progress", "completed"])
        due = datetime.now() + timedelta(days=random.randint(5, 30))
        task = ClassificationTask(
            project_id=proj.id,
            name=f"{proj.name}-任务{ordinal}",
            assigner_id=assigner_id,
            assignee_id=assignee_id,
            target_type="column",
            status=state,
            deadline=due,
        )
        db.add(task)
        tasks.append(task)
db.commit()
print(f" Created {len(tasks)} tasks")
# ============================================================
# 9. Operation Logs
# ============================================================
print("Generating operation logs...")
log_actions = ["登录", "查询数据源", "创建项目", "自动分类", "人工打标", "导出报告", "修改规则", "删除任务"]
log_modules = ["auth", "datasource", "project", "classification", "task", "report", "rule", "system"]
# Loop-invariant choice pools, built once. Rebuilding them for each of the
# 5000 entries re-read every user's id/username each time; after every batch
# commit (the session expires loaded objects on commit by default) those
# reads turned into per-user SELECTs. The random draws are unchanged.
log_user_ids = [None] + [u.id for u in users]
log_usernames = ["admin"] + [u.username for u in users]
log_batch = []
for _ in range(5000):
    # NOTE(review): user_id and username are drawn independently, so a log's
    # username need not belong to its user_id; pair them if that matters.
    log = OperationLog(
        user_id=random.choice(log_user_ids),
        username=random.choice(log_usernames),
        module=random.choice(log_modules),
        action=random.choice(log_actions),
        method=random.choice(["GET", "POST", "PUT", "DELETE"]),
        path=f"/api/v1/{random.choice(log_modules)}/{random.randint(1, 100)}",
        ip=f"10.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(0,255)}",
        status_code=random.choice([200, 200, 200, 201, 400, 401, 404, 500]),
        duration_ms=random.randint(10, 2000),
        created_at=datetime.now() - timedelta(days=random.randint(0, 30), hours=random.randint(0, 23)),
    )
    log_batch.append(log)
    if len(log_batch) >= batch_size:
        db.bulk_save_objects(log_batch)
        db.commit()
        log_batch = []
if log_batch:
    db.bulk_save_objects(log_batch)
    db.commit()
total_logs = db.query(OperationLog).count()
print(f" Created {total_logs} operation logs")
# ============================================================
# Summary
# ============================================================
# Print the final row count of every table the script touches or relies on,
# then release the session.
banner = "=" * 60
print("\n" + banner)
print("Test data generation complete!")
print(banner)
summary_rows = (
    ("Departments", Dept),
    ("Users", User),
    ("Data Sources", DataSource),
    ("Databases", Database),
    ("Tables", DataTable),
    ("Columns", DataColumn),
    ("Categories", Category),
    ("Data Levels", DataLevel),
    ("Rules", RecognitionRule),
    ("Templates", ClassificationTemplate),
    ("Projects", ClassificationProject),
    ("Tasks", ClassificationTask),
    ("Results", ClassificationResult),
    ("Operation Logs", OperationLog),
)
for label, model in summary_rows:
    print(f" {label}: {db.query(model).count()}")
print(banner)
db.close()