From 4b08bb50577f9376163c62196e455ddcd4772119 Mon Sep 17 00:00:00 2001 From: hiderfong Date: Wed, 22 Apr 2026 18:16:44 +0800 Subject: [PATCH] feat: add test data generator script Generate 35,000+ realistic test records: - 12 data sources (PostgreSQL, MySQL, Oracle, SQL Server, DM) - 31 databases, 863 tables, 17,152 columns - 81 users across 13 departments - 8 classification projects, 24 tasks - 12,000 classification results with varied confidence levels - 5,000 operation logs Covers all insurance domains: policy, claim, customer, finance, channel, actuary, regulatory, vehicle --- backend/scripts/check_sequences.py | 9 + backend/scripts/generate_test_data.py | 540 ++++++++++++++++++++++++++ 2 files changed, 549 insertions(+) create mode 100644 backend/scripts/check_sequences.py create mode 100644 backend/scripts/generate_test_data.py diff --git a/backend/scripts/check_sequences.py b/backend/scripts/check_sequences.py new file mode 100644 index 00000000..b75ba765 --- /dev/null +++ b/backend/scripts/check_sequences.py @@ -0,0 +1,9 @@ +import sys +sys.path.insert(0, '/Users/nathan/Work/DataPointer/prop-data-guard/backend') +from app.core.database import engine +from sqlalchemy import text + +with engine.connect() as conn: + result = conn.execute(text("SELECT sequencename FROM pg_sequences WHERE schemaname = 'public'")) + for row in result: + print(row[0]) diff --git a/backend/scripts/generate_test_data.py b/backend/scripts/generate_test_data.py new file mode 100644 index 00000000..621d5a00 --- /dev/null +++ b/backend/scripts/generate_test_data.py @@ -0,0 +1,540 @@ +""" +Generate test data for PropDataGuard system. +Targets: 10000+ records across all tables. +""" +import sys +sys.path.insert(0, '/Users/nathan/Work/DataPointer/prop-data-guard/backend') + +import random +import string +import json +from datetime import datetime, timedelta +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +from app.core.database import Base +from app.models.user import User, Role, Dept, UserRole +from app.models.metadata import DataSource, Database, DataTable, DataColumn +from app.models.classification import Category, DataLevel, RecognitionRule, ClassificationTemplate +from app.models.project import ClassificationProject, ClassificationTask, ClassificationResult, ResultStatus +from app.models.log import OperationLog +from app.core.security import get_password_hash + +# Database connection +DATABASE_URL = "postgresql+psycopg2://pdg:pdg_secret_2024@localhost:5432/prop_data_guard" +engine = create_engine(DATABASE_URL) +SessionLocal = sessionmaker(bind=engine) +db = SessionLocal() + +# Clear existing test data (preserve admin user and built-in data) +print("Clearing existing test data...") +db.query(ClassificationResult).delete(synchronize_session=False) +db.query(ClassificationTask).delete(synchronize_session=False) +db.query(ClassificationProject).delete(synchronize_session=False) +db.query(DataColumn).delete(synchronize_session=False) +db.query(DataTable).delete(synchronize_session=False) +db.query(Database).delete(synchronize_session=False) +db.query(UserRole).filter(UserRole.user_id > 1).delete(synchronize_session=False) +db.query(User).filter(User.id > 1).delete(synchronize_session=False) +db.query(Dept).filter(Dept.id > 1).delete(synchronize_session=False) +db.query(OperationLog).delete(synchronize_session=False) +db.commit() + +# Reset all sequences to avoid ID conflicts +from sqlalchemy import text +sequences = [ + "sys_dept_id_seq", "sys_user_id_seq", "sys_user_role_id_seq", + "data_source_id_seq", "meta_database_id_seq", "meta_table_id_seq", "meta_column_id_seq", + "classification_project_id_seq", "classification_task_id_seq", "classification_result_id_seq", + "classification_change_id_seq", "sys_operation_log_id_seq", +] +for seq in sequences: + db.execute(text(f"ALTER SEQUENCE {seq} RESTART WITH 100")) +db.commit() +print(" Sequences reset") + +random.seed(42) + +# ============================================================ +# 1. Departments +# ============================================================ +print("Generating departments...") +root_dept_names = ["数据安全部", "合规管理部", "信息技术部"] +root_depts = [] +for name in root_dept_names: + d = Dept(name=name, parent_id=None, sort_order=len(root_depts)) + db.add(d) + root_depts.append(d) +db.commit() +for d in root_depts: + db.refresh(d) + +# Map root depts by index: 0=数据安全部, 1=合规管理部, 2=信息技术部 +root_id_map = {i+1: d.id for i, d in enumerate(root_depts)} + +child_dept_defs = [ + ("业务一部", root_id_map[1]), ("业务二部", root_id_map[1]), + ("车险事业部", root_id_map[3]), ("非车险事业部", root_id_map[3]), ("理赔服务部", root_id_map[3]), + ("财务部", root_id_map[2]), ("精算部", root_id_map[2]), + ("客户服务部", root_id_map[1]), ("渠道管理部", root_id_map[1]), +] +depts = root_depts[:] +for name, pid in child_dept_defs: + d = Dept(name=name, parent_id=pid, sort_order=len(depts)) + db.add(d) + depts.append(d) +db.commit() +for d in depts[len(root_depts):]: + db.refresh(d) +print(f" Created {len(depts)} departments") + +# ============================================================ +# 2. Users +# ============================================================ +print("Generating users...") +roles = db.query(Role).all() +role_map = {r.code: r.id for r in roles} + +first_names = ["王", "李", "张", "刘", "陈", "杨", "赵", "黄", "周", "吴", "徐", "孙", "马", "朱", "胡", "郭", "林", "何", "高", "罗"] +last_names = ["伟", "芳", "娜", "敏", "静", "丽", "强", "磊", "军", "洋", "勇", "艳", "杰", "娟", "涛", "明", "超", "秀英", "华", "平"] + +def random_name(): + return random.choice(first_names) + random.choice(last_names) + +def random_phone(): + return "1" + random.choice(["3","4","5","6","7","8","9"]) + "".join(random.choices(string.digits, k=9)) + +users = [] +for i in range(80): + real = random_name() + username = f"user{i+2:03d}" + user = User( + username=username, + email=f"{username}@propdataguard.com", + hashed_password=get_password_hash("password123"), + real_name=real, + phone=random_phone(), + is_active=random.random() > 0.05, + is_superuser=False, + dept_id=random.choice(depts).id, + ) + db.add(user) + users.append(user) +db.commit() +for u in users: + db.refresh(u) + +# Assign roles +role_list = list(roles) +for u in users: + assigned_roles = random.sample(role_list, k=random.randint(1, 2)) + for r in assigned_roles: + db.add(UserRole(user_id=u.id, role_id=r.id)) +db.commit() +print(f" Created {len(users)} users") + +# ============================================================ +# 3. Data Sources +# ============================================================ +print("Generating data sources...") +source_types = ["postgresql", "mysql", "oracle", "sqlserver", "dm"] +source_configs = [ + ("核心保单数据库", "postgresql", "db-core-prod", 5432, "core_policy"), + ("理赔系统数据库", "mysql", "db-claim-prod", 3306, "claim_db"), + ("财务数据仓库", "postgresql", "db-finance-dw", 5432, "finance_dw"), + ("客户信息主库", "mysql", "db-cust-master", 3306, "customer_master"), + ("渠道管理系统", "oracle", "db-channel-ora", 1521, "CHANNEL"), + ("精算分析平台", "postgresql", "db-actuary-ana", 5432, "actuary_analytics"), + ("监管报送库", "mysql", "db-regulatory", 3306, "regulatory_report"), + ("车辆信息库", "postgresql", "db-vehicle", 5432, "vehicle_db"), + ("非车险业务库", "sqlserver", "db-nonauto", 1433, "NonAutoDB"), + ("历史归档库", "postgresql", "db-archive", 5432, "archive_db"), + ("测试环境核心库", "postgresql", "db-core-test", 5432, "core_test"), + ("达梦国产数据库", "dm", "db-dameng-prod", 5236, "DAMENG"), +] + +sources = [] +for name, stype, host, port, dbname in source_configs: + ds = DataSource( + name=name, + source_type=stype, + host=f"{host}.internal.company.com", + port=port, + database_name=dbname, + username=f"{stype}_admin", + encrypted_password=None, + status="active" if random.random() > 0.1 else "error", + dept_id=random.choice(depts).id, + created_by=random.choice(users).id, + ) + db.add(ds) + sources.append(ds) +db.commit() +for s in sources: + db.refresh(s) +print(f" Created {len(sources)} data sources") + +# ============================================================ +# 4. Databases +# ============================================================ +print("Generating databases...") +databases = [] +for source in sources: + num_dbs = random.randint(1, 3) + for i in range(num_dbs): + d = Database( + source_id=source.id, + name=f"{source.database_name}_{i+1}" if num_dbs > 1 else source.database_name, + charset="UTF8" if source.source_type != "sqlserver" else "Chinese_PRC_CI_AS", + table_count=0, + ) + db.add(d) + databases.append(d) +db.commit() +for d in databases: + db.refresh(d) +print(f" Created {len(databases)} databases") + +# ============================================================ +# 5. Data Tables & Columns (the big one) +# ============================================================ +print("Generating tables and columns...") + +table_prefixes = { + "policy": ["t_policy", "t_policy_detail", "t_policy_extension", "t_policy_history", "t_endorsement"], + "claim": ["t_claim", "t_claim_detail", "t_claim_payment", "t_claim_document", "t_survey"], + "customer": ["t_customer", "t_customer_contact", "t_customer_identity", "t_customer_vehicle", "t_customer_preference"], + "finance": ["t_payment", "t_receipt", "t_invoice", "t_commission", "t_reserve"], + "channel": ["t_agent", "t_agent_contract", "t_partner", "t_broker", "t_sales_record"], + "actuary": ["t_pricing_model", "t_risk_factor", "t_loss_ratio", "t_reserve_calc", "t_solvency"], + "regulatory": ["t_report_cbrc", "t_report_circ", "t_stat_premium", "t_stat_claim", "t_stat_channel"], + "vehicle": ["t_vehicle", "t_vehicle_model", "t_vehicle_usage", "t_vehicle_accident", "t_vehicle_maintenance"], + "system": ["t_user", "t_role", "t_permission", "t_log", "t_config", "t_dict"], + "archive": ["t_archive_policy", "t_archive_claim", "t_archive_customer", "t_archive_finance"], +} + +column_templates = [ + ("id", "BIGINT", "主键ID", "system", 2), + ("created_at", "TIMESTAMP", "创建时间", "system", 2), + ("updated_at", "TIMESTAMP", "更新时间", "system", 2), + ("is_deleted", "BOOLEAN", "是否删除", "system", 2), + ("created_by", "BIGINT", "创建人", "system", 2), + ("customer_name", "VARCHAR", "客户姓名", "customer", 4), + ("customer_id_no", "VARCHAR", "客户身份证号", "customer", 4), + ("mobile_phone", "VARCHAR", "手机号码", "customer", 4), + ("email", "VARCHAR", "电子邮箱", "customer", 3), + ("address", "VARCHAR", "联系地址", "customer", 3), + ("bank_account", "VARCHAR", "银行账户", "finance", 4), + ("bank_card_no", "VARCHAR", "银行卡号", "finance", 4), + ("policy_no", "VARCHAR", "保单号", "policy", 3), + ("policy_status", "VARCHAR", "保单状态", "policy", 2), + ("premium_amount", "DECIMAL", "保费金额", "finance", 3), + ("claim_no", "VARCHAR", "理赔号", "claim", 3), + ("claim_amount", "DECIMAL", "理赔金额", "claim", 4), + ("loss_description", "TEXT", "损失描述", "claim", 3), + ("accident_location", "VARCHAR", "出险地点", "claim", 3), + ("vehicle_plate", "VARCHAR", "车牌号", "vehicle", 3), + ("vin_code", "VARCHAR", "车辆识别代码VIN", "vehicle", 4), + ("agent_name", "VARCHAR", "代理人姓名", "channel", 3), + ("agent_license", "VARCHAR", "代理人执业证号", "channel", 3), + ("commission_rate", "DECIMAL", "佣金比例", "finance", 3), + ("reserve_amount", "DECIMAL", "准备金金额", "finance", 5), + ("solvency_ratio", "DECIMAL", "偿付能力充足率", "finance", 5), + ("password_hash", "VARCHAR", "密码哈希", "system", 5), + ("api_secret", "VARCHAR", "API密钥", "system", 5), + ("session_token", "VARCHAR", "会话令牌", "system", 4), + ("gps_location", "VARCHAR", "GPS定位信息", "vehicle", 4), + ("driving_record", "TEXT", "行驶记录", "vehicle", 4), + ("medical_record", "TEXT", "医疗记录", "claim", 4), + ("income_info", "DECIMAL", "收入信息", "customer", 4), + ("credit_score", "INT", "信用评分", "customer", 4), + ("family_member", "VARCHAR", "家庭成员信息", "customer", 3), + ("emergency_contact", "VARCHAR", "紧急联系人", "customer", 3), + ("beneficiary_name", "VARCHAR", "受益人姓名", "policy", 4), + ("beneficiary_id_no", "VARCHAR", "受益人身份证号", "policy", 4), + ("underwriting_decision", "VARCHAR", "核保结论", "policy", 3), + ("risk_score", "DECIMAL", "风险评分", "actuary", 3), + ("fraud_flag", "BOOLEAN", "欺诈标记", "claim", 3), + ("audit_comment", "TEXT", "审计意见", "system", 3), + ("report_period", "VARCHAR", "报表期间", "regulatory", 2), + ("regulatory_code", "VARCHAR", "监管编码", "regulatory", 2), +] + +all_tables = [] +all_columns = [] + +for database in databases: + prefix_key = "system" + for k in table_prefixes: + if k in database.name.lower() or k in database.source.name.lower(): + prefix_key = k + break + + prefix_list = table_prefixes.get(prefix_key, table_prefixes["system"]) + num_tables = random.randint(15, 40) + + for tidx in range(num_tables): + table_name = f"{random.choice(prefix_list)}_{tidx+1:03d}" + tbl = DataTable( + database_id=database.id, + name=table_name, + comment=f"{table_name}数据表", + row_count=random.randint(10000, 10000000), + column_count=0, + ) + db.add(tbl) + all_tables.append(tbl) + +db.commit() +for t in all_tables: + db.refresh(t) + +print(f" Created {len(all_tables)} tables") + +# Now generate columns +print(" Generating columns (this may take a moment)...") +levels = db.query(DataLevel).all() +level_map = {l.code: l.id for l in levels} + +categories = db.query(Category).all() +cat_map = {} +for c in categories: + if c.code.startswith("CUST") and "customer" not in cat_map: + cat_map["customer"] = c.id + elif c.code.startswith("POLICY") and "policy" not in cat_map: + cat_map["policy"] = c.id + elif c.code.startswith("CLAIM") and "claim" not in cat_map: + cat_map["claim"] = c.id + elif c.code.startswith("FIN") and "finance" not in cat_map: + cat_map["finance"] = c.id + elif c.code.startswith("CHANNEL") and "channel" not in cat_map: + cat_map["channel"] = c.id + elif c.code.startswith("REG") and "regulatory" not in cat_map: + cat_map["regulatory"] = c.id + elif c.code.startswith("INT") and "system" not in cat_map: + cat_map["system"] = c.id + elif c.code.startswith("SUB") and "vehicle" not in cat_map: + cat_map["vehicle"] = c.id + +sample_values = { + "customer_name": ["张三", "李四", "王五", "赵六", "钱七"], + "customer_id_no": ["110101199001011234", "310101198502023456", "440106197803034567"], + "mobile_phone": ["13800138000", "13900139000", "13700137000"], + "email": ["user1@example.com", "user2@test.com", "contact@company.com"], + "bank_card_no": ["6222021234567890123", "6228481234567890123"], + "vin_code": ["LSVAG2180E2100001", "LFV3A28K8A3000001"], + "vehicle_plate": ["京A12345", "沪B67890", "粤C11111"], + "policy_no": ["PICC2024000001", "PICC2024000002", "PICC2024000003"], + "claim_no": ["CLM2024000001", "CLM2024000002", "CLM2024000003"], + "address": ["北京市海淀区xxx路1号", "上海市浦东新区xxx路2号"], +} + +batch_size = 500 +column_batch = [] + +for tbl in all_tables: + num_cols = random.randint(12, 28) + selected_templates = random.sample(column_templates, k=min(num_cols, len(column_templates))) + + for cidx, (col_name, col_type, comment, cat_hint, lvl_hint) in enumerate(selected_templates): + actual_name = col_name if cidx == 0 else f"{col_name}_{cidx}" + samples = None + if col_name in sample_values: + samples = json.dumps(random.sample(sample_values[col_name], k=min(3, len(sample_values[col_name]))), ensure_ascii=False) + + col = DataColumn( + table_id=tbl.id, + name=actual_name, + data_type=col_type, + length=random.choice([20, 50, 100, 200, 500]) if "VARCHAR" in col_type else None, + comment=comment, + is_nullable=random.random() > 0.2, + sample_data=samples, + ) + column_batch.append(col) + + if len(column_batch) >= batch_size: + db.bulk_save_objects(column_batch) + db.commit() + all_columns.extend(column_batch) + column_batch = [] + +if column_batch: + db.bulk_save_objects(column_batch) + db.commit() + all_columns.extend(column_batch) + +print(f" Created {len(all_columns)} columns") + +# Update table counts +for tbl in all_tables: + tbl.column_count = db.query(DataColumn).filter(DataColumn.table_id == tbl.id).count() + db.add(tbl) +db.commit() + +for database in databases: + database.table_count = db.query(DataTable).filter(DataTable.database_id == database.id).count() + db.add(database) +db.commit() + +# ============================================================ +# 6. Classification Projects +# ============================================================ +print("Generating classification projects...") +templates = db.query(ClassificationTemplate).all() +projects = [] +project_names = [ + "2024年度数据分类分级专项", + "核心系统敏感数据梳理", + "新核心上线数据定级", + "客户个人信息保护专项", + "财务数据安全治理", + "理赔数据合规检查", + "渠道数据梳理项目", + "监管报送数据定级", +] + +for i, name in enumerate(project_names): + p = ClassificationProject( + name=name, + template_id=random.choice(templates).id, + description=f"{name} - 数据分类分级治理项目", + status=random.choice(["created", "scanning", "labeling", "reviewing", "published"]), + target_source_ids=",".join(str(s.id) for s in random.sample(sources, k=random.randint(2, 5))), + planned_start=datetime.now() - timedelta(days=random.randint(10, 60)), + planned_end=datetime.now() + timedelta(days=random.randint(10, 90)), + created_by=random.choice(users).id, + ) + db.add(p) + projects.append(p) +db.commit() +for p in projects: + db.refresh(p) +print(f" Created {len(projects)} projects") + +# ============================================================ +# 7. Classification Results (the critical mass) +# ============================================================ +print("Generating classification results...") + +all_col_ids = [c.id for c in all_columns] +random.shuffle(all_col_ids) + +result_batch = [] +total_results_target = 12000 +results_per_project = total_results_target // len(projects) + +for proj in projects: + assigned_cols = random.sample(all_col_ids, k=min(results_per_project, len(all_col_ids))) + + for col_id in assigned_cols: + source_type = random.choices(["auto", "manual"], weights=[0.7, 0.3])[0] + status_val = "auto" if source_type == "auto" else random.choice(["manual", "reviewed"]) + + cat = random.choice(categories) + lvl = random.choice(levels) + conf = round(random.uniform(0.3, 0.98), 2) + + r = ClassificationResult( + project_id=proj.id, + column_id=col_id, + category_id=cat.id, + level_id=lvl.id, + source=source_type, + confidence=conf, + status=status_val, + labeler_id=random.choice(users).id if source_type == "manual" else None, + ) + result_batch.append(r) + + if len(result_batch) >= batch_size: + db.bulk_save_objects(result_batch) + db.commit() + result_batch = [] + +if result_batch: + db.bulk_save_objects(result_batch) + db.commit() + +total_results = db.query(ClassificationResult).count() +print(f" Created {total_results} classification results") + +# ============================================================ +# 8. Classification Tasks +# ============================================================ +print("Generating classification tasks...") +tasks = [] +for proj in projects: + num_tasks = random.randint(2, 5) + for tidx in range(num_tasks): + task = ClassificationTask( + project_id=proj.id, + name=f"{proj.name}-任务{tidx+1}", + assigner_id=random.choice(users).id, + assignee_id=random.choice(users).id, + target_type="column", + status=random.choice(["pending", "in_progress", "completed"]), + deadline=datetime.now() + timedelta(days=random.randint(5, 30)), + ) + db.add(task) + tasks.append(task) +db.commit() +print(f" Created {len(tasks)} tasks") + +# ============================================================ +# 9. Operation Logs +# ============================================================ +print("Generating operation logs...") +log_actions = ["登录", "查询数据源", "创建项目", "自动分类", "人工打标", "导出报告", "修改规则", "删除任务"] +log_modules = ["auth", "datasource", "project", "classification", "task", "report", "rule", "system"] + +log_batch = [] +for i in range(5000): + log = OperationLog( + user_id=random.choice([None] + [u.id for u in users]), + username=random.choice(["admin"] + [u.username for u in users]), + module=random.choice(log_modules), + action=random.choice(log_actions), + method=random.choice(["GET", "POST", "PUT", "DELETE"]), + path=f"/api/v1/{random.choice(log_modules)}/{random.randint(1, 100)}", + ip=f"10.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(0,255)}", + status_code=random.choice([200, 200, 200, 201, 400, 401, 404, 500]), + duration_ms=random.randint(10, 2000), + created_at=datetime.now() - timedelta(days=random.randint(0, 30), hours=random.randint(0, 23)), + ) + log_batch.append(log) + if len(log_batch) >= batch_size: + db.bulk_save_objects(log_batch) + db.commit() + log_batch = [] + +if log_batch: + db.bulk_save_objects(log_batch) + db.commit() + +total_logs = db.query(OperationLog).count() +print(f" Created {total_logs} operation logs") + +# ============================================================ +# Summary +# ============================================================ +print("\n" + "="*60) +print("Test data generation complete!") +print("="*60) +print(f" Departments: {db.query(Dept).count()}") +print(f" Users: {db.query(User).count()}") +print(f" Data Sources: {db.query(DataSource).count()}") +print(f" Databases: {db.query(Database).count()}") +print(f" Tables: {db.query(DataTable).count()}") +print(f" Columns: {db.query(DataColumn).count()}") +print(f" Categories: {db.query(Category).count()}") +print(f" Data Levels: {db.query(DataLevel).count()}") +print(f" Rules: {db.query(RecognitionRule).count()}") +print(f" Templates: {db.query(ClassificationTemplate).count()}") +print(f" Projects: {db.query(ClassificationProject).count()}") +print(f" Tasks: {db.query(ClassificationTask).count()}") +print(f" Results: {db.query(ClassificationResult).count()}") +print(f" Operation Logs: {db.query(OperationLog).count()}") +print("="*60) + +db.close()