feat: initial commit - Phase 1 & 2 core features

This commit is contained in:
hiderfong
2026-04-22 17:07:33 +08:00
commit 1773bda06b
25005 changed files with 6252106 additions and 0 deletions
View File
+33
View File
@@ -0,0 +1,33 @@
from sqlalchemy.orm import Session
from fastapi import HTTPException, status
from app.models.user import User
from app.core.security import verify_password, create_token_pair
def authenticate_user(db: Session, username: str, password: str):
    """Check a username/password pair.

    Returns the User on success and None when either the user is unknown or
    the password is wrong. A correct password for a disabled account raises
    403 instead of returning.
    """
    user = db.query(User).filter_by(username=username).first()
    if user is None or not verify_password(password, user.hashed_password):
        return None
    if not user.is_active:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="用户已被禁用")
    return user
def login(db: Session, username: str, password: str):
    """Authenticate and issue an access/refresh token pair.

    Raises 401 on unknown user or wrong password; a disabled account raises
    403 from authenticate_user before we get here.
    """
    user = authenticate_user(db, username, password)
    if user is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="用户名或密码错误",
            headers={"WWW-Authenticate": "Bearer"},
        )
    # expires_in mirrors the 30-minute access-token lifetime.
    payload = dict(zip(("access_token", "refresh_token"), create_token_pair(user.id, user.username)))
    payload.update(token_type="bearer", expires_in=30 * 60)
    return payload
@@ -0,0 +1,134 @@
import re
import json
from typing import List, Optional, Tuple
from sqlalchemy.orm import Session
from app.models.classification import RecognitionRule, Category, DataLevel
from app.models.metadata import DataColumn, DataTable
from app.models.project import ClassificationProject, ClassificationResult, ResultStatus
def match_rule(rule: RecognitionRule, column: DataColumn) -> Tuple[bool, float]:
"""Match a single rule against a column. Returns (matched, confidence)."""
targets = []
if rule.target_field == "column_name":
targets = [column.name]
elif rule.target_field == "comment":
targets = [column.comment or ""]
elif rule.target_field == "sample_data":
targets = []
if column.sample_data:
try:
samples = json.loads(column.sample_data)
if isinstance(samples, list):
targets = [str(s) for s in samples]
except Exception:
targets = [column.sample_data]
if not targets:
return False, 0.0
if rule.rule_type == "regex":
try:
pattern = re.compile(rule.rule_content)
for t in targets:
if pattern.search(t):
return True, 0.85
except re.error:
return False, 0.0
elif rule.rule_type == "keyword":
keywords = [k.strip().lower() for k in rule.rule_content.split(",")]
for t in targets:
t_lower = t.lower()
for kw in keywords:
if kw in t_lower:
return True, 0.75
elif rule.rule_type == "enum":
enums = [e.strip().lower() for e in rule.rule_content.split(",")]
for t in targets:
if t.strip().lower() in enums:
return True, 0.90
return False, 0.0
def run_auto_classification(db: Session, project_id: int, source_ids: Optional[List[int]] = None) -> dict:
    """Run rule-based automatic classification for every column in scope.

    Args:
        db: active SQLAlchemy session.
        project_id: the ClassificationProject to classify for.
        source_ids: optional explicit data-source filter; when unset, falls
            back to the project's own target_source_ids (comma-separated).

    Returns:
        dict with success flag, human-readable message and scanned/matched
        counts. Existing results (including manual ones) are overwritten by
        new auto hits.
    """
    # Local import instead of depending on the module-bottom
    # `import app.models.metadata` statement.
    from app.models.metadata import Database

    project = db.query(ClassificationProject).filter(ClassificationProject.id == project_id).first()
    if not project:
        return {"success": False, "message": "项目不存在"}
    # Active rules of the project's template, lowest priority value first.
    rules = db.query(RecognitionRule).filter(
        RecognitionRule.is_active == True,
        RecognitionRule.template_id == project.template_id,
    ).order_by(RecognitionRule.priority).all()
    if not rules:
        return {"success": False, "message": "没有可用的识别规则"}
    # Columns in scope: join up to Database so we can filter by data source.
    columns_query = db.query(DataColumn).join(DataTable).join(Database)
    if source_ids:
        columns_query = columns_query.filter(Database.source_id.in_(source_ids))
    elif project.target_source_ids:
        sids = [int(x) for x in project.target_source_ids.split(",") if x]
        columns_query = columns_query.filter(Database.source_id.in_(sids))
    columns = columns_query.all()
    matched_count = 0
    for col in columns:
        existing = db.query(ClassificationResult).filter(
            ClassificationResult.project_id == project_id,
            ClassificationResult.column_id == col.id,
        ).first()
        # Pick the highest-confidence matching rule for this column.
        best_rule = None
        best_confidence = 0.0
        for rule in rules:
            matched, confidence = match_rule(rule, col)
            if matched and confidence > best_confidence:
                best_confidence = confidence
                best_rule = rule
        if best_rule:
            matched_count += 1
            if existing:
                existing.category_id = best_rule.category_id
                existing.level_id = best_rule.level_id
                existing.confidence = best_confidence
                existing.source = "auto"
                existing.status = ResultStatus.AUTO.value
            else:
                db.add(ClassificationResult(
                    project_id=project_id,
                    column_id=col.id,
                    category_id=best_rule.category_id,
                    level_id=best_rule.level_id,
                    source="auto",
                    confidence=best_confidence,
                    status=ResultStatus.AUTO.value,
                ))
            # Track how often each rule fires.
            best_rule.hit_count = (best_rule.hit_count or 0) + 1
    db.commit()
    return {
        "success": True,
        "message": f"自动分类完成,共扫描 {len(columns)} 个字段,命中 {matched_count}",
        "scanned": len(columns),
        "matched": matched_count,
    }
import app.models.metadata
@@ -0,0 +1,268 @@
from typing import Optional, List, Tuple
from sqlalchemy.orm import Session
from fastapi import HTTPException, status
from app.models.classification import Category, DataLevel, RecognitionRule, ClassificationTemplate
from app.schemas.classification import CategoryCreate, CategoryUpdate, RecognitionRuleCreate, RecognitionRuleUpdate
def get_category(db: Session, category_id: int) -> Optional[Category]:
    """Look up one Category by primary key; None when no row matches."""
    return db.query(Category).filter_by(id=category_id).first()
def list_categories(db: Session, parent_id: Optional[int] = None) -> List[Category]:
    """Categories ordered by sort_order, optionally restricted to one parent.

    Note: parent_id=None means "no filter", so root rows (parent_id IS NULL)
    cannot be selected specifically through this function.
    """
    q = db.query(Category)
    if parent_id is not None:
        q = q.filter_by(parent_id=parent_id)
    return q.order_by(Category.sort_order).all()
def build_category_tree(db: Session) -> List[dict]:
    """Assemble the full category hierarchy as nested dicts, children recursive."""
    def _children_of(pid: Optional[int]) -> List[dict]:
        rows = (
            db.query(Category)
            .filter(Category.parent_id == pid)
            .order_by(Category.sort_order)
            .all()
        )
        return [
            {
                "id": row.id,
                "parent_id": row.parent_id,
                "level": row.level,
                "code": row.code,
                "name": row.name,
                "description": row.description,
                "sort_order": row.sort_order,
                "created_at": row.created_at,
                "children": _children_of(row.id),
            }
            for row in rows
        ]
    return _children_of(None)
def create_category(db: Session, obj_in: CategoryCreate) -> Category:
    """Persist a new Category from the create schema; return the fresh row."""
    category = Category(**obj_in.model_dump())
    db.add(category)
    db.commit()
    db.refresh(category)
    return category
def update_category(db: Session, db_obj: Category, obj_in: CategoryUpdate) -> Category:
    """Apply the fields explicitly set on the update schema and save."""
    for attr, value in obj_in.model_dump(exclude_unset=True).items():
        setattr(db_obj, attr, value)
    db.commit()
    db.refresh(db_obj)
    return db_obj
def delete_category(db: Session, category_id: int) -> None:
    """Delete a leaf category; 404 when missing, 400 when children exist."""
    target = get_category(db, category_id)
    if target is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="分类不存在")
    has_children = db.query(Category).filter(Category.parent_id == category_id).first() is not None
    if has_children:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="存在子分类,无法删除")
    db.delete(target)
    db.commit()
def list_data_levels(db: Session) -> List[DataLevel]:
    """All data levels, ordered by sort_order ascending."""
    levels_query = db.query(DataLevel).order_by(DataLevel.sort_order)
    return levels_query.all()
def get_data_level(db: Session, level_id: int) -> Optional[DataLevel]:
    """Look up one DataLevel by primary key; None when no row matches."""
    return db.query(DataLevel).filter_by(id=level_id).first()
def create_data_level(db: Session, code: str, name: str, description: str, color: str, sort_order: int = 0, control_requirements: Optional[dict] = None) -> DataLevel:
    """Create and persist one DataLevel row, returning the refreshed instance."""
    level = DataLevel(
        code=code,
        name=name,
        description=description,
        color=color,
        sort_order=sort_order,
        control_requirements=control_requirements,
    )
    db.add(level)
    db.commit()
    db.refresh(level)
    return level
def get_rule(db: Session, rule_id: int) -> Optional[RecognitionRule]:
    """Look up one RecognitionRule by primary key; None when no row matches."""
    return db.query(RecognitionRule).filter_by(id=rule_id).first()
def list_rules(db: Session, template_id: Optional[int] = None, keyword: Optional[str] = None, page: int = 1, page_size: int = 20) -> Tuple[List[RecognitionRule], int]:
    """Paginated rule listing; keyword searches rule_name and rule_content."""
    q = db.query(RecognitionRule)
    if template_id:
        q = q.filter(RecognitionRule.template_id == template_id)
    if keyword:
        keyword_match = (
            RecognitionRule.rule_name.contains(keyword)
            | RecognitionRule.rule_content.contains(keyword)
        )
        q = q.filter(keyword_match)
    total = q.count()
    rows = q.offset((page - 1) * page_size).limit(page_size).all()
    return rows, total
def create_rule(db: Session, obj_in: RecognitionRuleCreate) -> RecognitionRule:
    """Persist a new RecognitionRule from the create schema; return the row."""
    rule = RecognitionRule(**obj_in.model_dump())
    db.add(rule)
    db.commit()
    db.refresh(rule)
    return rule
def update_rule(db: Session, db_obj: RecognitionRule, obj_in: RecognitionRuleUpdate) -> RecognitionRule:
    """Apply the fields explicitly set on the update schema and save."""
    for attr, value in obj_in.model_dump(exclude_unset=True).items():
        setattr(db_obj, attr, value)
    db.commit()
    db.refresh(db_obj)
    return db_obj
def delete_rule(db: Session, rule_id: int) -> None:
    """Delete a recognition rule; 404 when it does not exist."""
    rule = get_rule(db, rule_id)
    if rule is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="规则不存在")
    db.delete(rule)
    db.commit()
def get_template(db: Session, template_id: int) -> Optional[ClassificationTemplate]:
    """Look up one ClassificationTemplate by primary key; None when missing."""
    return db.query(ClassificationTemplate).filter_by(id=template_id).first()
def list_templates(db: Session) -> List[ClassificationTemplate]:
    """All classification templates, ordered by id."""
    templates_query = db.query(ClassificationTemplate).order_by(ClassificationTemplate.id)
    return templates_query.all()
def init_builtin_data(db: Session):
    """Seed built-in classification data: levels, category tree, template, rules.

    Each section is guarded by an emptiness check on its own table, so
    repeated startup calls are idempotent per section (they never re-seed a
    table that already has any row).
    """
    # Data Levels: five-tier L1..L5 grading scheme, only when the table is empty.
    if not db.query(DataLevel).first():
        levels = [
            ("L1", "公开级", "可对外公开发布", "#67c23a", 1, {"storage": "无特殊要求", "access": "公开访问"}),
            ("L2", "内部级", "公司内部共享使用", "#409eff", 2, {"storage": "内部环境", "access": "内部员工"}),
            ("L3", "敏感级", "部门/授权人员访问,外部分享需审批", "#e6a23c", 3, {"storage": "加密存储", "access": "授权访问"}),
            ("L4", "重要级", "严格授权管理,外部分享需严格审批", "#f56c6c", 4, {"storage": "强加密", "access": "最小权限"}),
            ("L5", "核心级", "禁止对外共享", "#909399", 5, {"storage": "物理隔离", "access": "核心人员"}),
        ]
        for code, name, desc, color, sort, ctrl in levels:
            create_data_level(db, code, name, desc, color, sort, ctrl)
    # Categories: two-level tree; roots are created first so children can
    # resolve their parent ids through cat_map.
    if not db.query(Category).first():
        categories = [
            # Level 1 (root) categories.
            {"code": "CUST", "name": "客户数据", "level": 1, "sort_order": 1},
            {"code": "POLICY", "name": "保单数据", "level": 1, "sort_order": 2},
            {"code": "CLAIM", "name": "理赔数据", "level": 1, "sort_order": 3},
            {"code": "FIN", "name": "财务数据", "level": 1, "sort_order": 4},
            {"code": "CHANNEL", "name": "渠道数据", "level": 1, "sort_order": 5},
            {"code": "REG", "name": "监管报送数据", "level": 1, "sort_order": 6},
            {"code": "INTERNAL", "name": "内部管理数据", "level": 1, "sort_order": 7},
            {"code": "SUBJECT", "name": "车辆/财产标的数据", "level": 1, "sort_order": 8},
        ]
        cat_map = {}  # category code -> generated id, for linking level-2 rows
        for c in categories:
            obj = Category(parent_id=None, level=c["level"], code=c["code"], name=c["name"], sort_order=c["sort_order"])
            db.add(obj)
            # Commit + refresh per row so obj.id is populated for cat_map.
            db.commit()
            db.refresh(obj)
            cat_map[c["code"]] = obj.id
        # Level 2 (child) categories, linked to parents via parent_code.
        sub_categories = [
            {"parent_code": "CUST", "code": "CUST_PERSONAL", "name": "个人客户信息", "sort_order": 1},
            {"parent_code": "CUST", "code": "CUST_ENTERPRISE", "name": "企业客户信息", "sort_order": 2},
            {"parent_code": "CUST", "code": "CUST_BENEFICIARY", "name": "受益人信息", "sort_order": 3},
            {"parent_code": "POLICY", "code": "POLICY_APPLY", "name": "投保信息", "sort_order": 1},
            {"parent_code": "POLICY", "code": "POLICY_UNDERWRITE", "name": "承保信息", "sort_order": 2},
            {"parent_code": "POLICY", "code": "POLICY_RENEW", "name": "续保信息", "sort_order": 3},
            {"parent_code": "CLAIM", "code": "CLAIM_REPORT", "name": "报案信息", "sort_order": 1},
            {"parent_code": "CLAIM", "code": "CLAIM_SURVEY", "name": "查勘定损信息", "sort_order": 2},
            {"parent_code": "CLAIM", "code": "CLAIM_PAY", "name": "赔付信息", "sort_order": 3},
            {"parent_code": "FIN", "code": "FIN_PAYMENT", "name": "收付费数据", "sort_order": 1},
            {"parent_code": "FIN", "code": "FIN_RESERVE", "name": "准备金数据", "sort_order": 2},
            {"parent_code": "FIN", "code": "FIN_INVEST", "name": "投资数据", "sort_order": 3},
            {"parent_code": "CHANNEL", "code": "CHN_AGENT", "name": "代理人/经纪人信息", "sort_order": 1},
            {"parent_code": "CHANNEL", "code": "CHN_PARTNER", "name": "第三方合作方", "sort_order": 2},
            {"parent_code": "REG", "code": "REG_SOLVENCY", "name": "偿付能力数据", "sort_order": 1},
            {"parent_code": "REG", "code": "REG_STAT", "name": "统计报表数据", "sort_order": 2},
            {"parent_code": "INTERNAL", "code": "INT_EMPLOYEE", "name": "员工信息", "sort_order": 1},
            {"parent_code": "INTERNAL", "code": "INT_OPS", "name": "系统运维数据", "sort_order": 2},
            {"parent_code": "SUBJECT", "code": "SUB_VEHICLE", "name": "车辆信息", "sort_order": 1},
            {"parent_code": "SUBJECT", "code": "SUB_PROPERTY", "name": "财产标的", "sort_order": 2},
        ]
        for sc in sub_categories:
            parent_id = cat_map.get(sc["parent_code"])
            # Unknown parent codes are silently skipped.
            if parent_id:
                obj = Category(parent_id=parent_id, level=2, code=sc["code"], name=sc["name"], sort_order=sc["sort_order"])
                db.add(obj)
        db.commit()
    # Template: built-in property-insurance template plus a few sample rules.
    if not db.query(ClassificationTemplate).first():
        tpl = ClassificationTemplate(
            name="财产保险行业分类分级模板",
            industry_type="insurance_property",
            version="1.0",
            description="基于《金融数据安全 数据安全分级指南》及保险行业特点制定的分类分级模板",
            is_builtin=True,
            is_active=True,
        )
        db.add(tpl)
        db.commit()
        db.refresh(tpl)
        # Create some sample rules bound to the new template.
        # NOTE(review): level_l3 is looked up but never used below — confirm
        # whether an L3 sample rule was intended.
        level_l4 = db.query(DataLevel).filter(DataLevel.code == "L4").first()
        level_l3 = db.query(DataLevel).filter(DataLevel.code == "L3").first()
        level_l5 = db.query(DataLevel).filter(DataLevel.code == "L5").first()
        cat_cust_personal = db.query(Category).filter(Category.code == "CUST_PERSONAL").first()
        cat_fin_reserve = db.query(Category).filter(Category.code == "FIN_RESERVE").first()
        cat_int_ops = db.query(Category).filter(Category.code == "INT_OPS").first()
        rules = []
        if cat_cust_personal and level_l4:
            rules.append(RecognitionRule(
                template_id=tpl.id,
                category_id=cat_cust_personal.id,
                level_id=level_l4.id,
                rule_type="regex",
                rule_name="身份证号识别",
                rule_content=r"(\d{15}|\d{18}|\d{17}[xX])",
                target_field="sample_data",
                priority=10,
            ))
            rules.append(RecognitionRule(
                template_id=tpl.id,
                category_id=cat_cust_personal.id,
                level_id=level_l4.id,
                rule_type="keyword",
                rule_name="手机号字段识别",
                rule_content="手机,mobile,phone,telephone,tel",
                target_field="column_name",
                priority=20,
            ))
        if cat_fin_reserve and level_l5:
            rules.append(RecognitionRule(
                template_id=tpl.id,
                category_id=cat_fin_reserve.id,
                level_id=level_l5.id,
                rule_type="keyword",
                rule_name="精算模型识别",
                rule_content="精算,actuarial,准备金,reserve,偿付能力,solvency",
                target_field="column_name",
                priority=10,
            ))
        if cat_int_ops and level_l5:
            rules.append(RecognitionRule(
                template_id=tpl.id,
                category_id=cat_int_ops.id,
                level_id=level_l5.id,
                rule_type="keyword",
                rule_name="密码密钥识别",
                rule_content="password,secret,key,token,密钥,密码",
                target_field="column_name",
                priority=5,
            ))
        for r in rules:
            db.add(r)
        db.commit()
+121
View File
@@ -0,0 +1,121 @@
from typing import Optional, List, Tuple
from sqlalchemy.orm import Session
from fastapi import HTTPException, status
from cryptography.fernet import Fernet
from app.models.metadata import DataSource
from app.schemas.datasource import DataSourceCreate, DataSourceUpdate, DataSourceTest
from app.core.config import settings
import os

# Symmetric (Fernet) encryption for stored DB passwords.
# Prefer a stable key from the environment so encrypted passwords survive a
# process restart; with the generated-key fallback (the original behavior),
# anything encrypted before a restart becomes permanently unreadable.
# NOTE(review): in production this key belongs in a proper KMS / settings secret.
_env_key = os.environ.get("DATASOURCE_FERNET_KEY")
_fernet = Fernet(_env_key) if _env_key else Fernet(Fernet.generate_key())


def _encrypt_password(password: str) -> str:
    """Encrypt a plaintext password for storage (returns a str token)."""
    return _fernet.encrypt(password.encode()).decode()


def _decrypt_password(encrypted: str) -> str:
    """Reverse _encrypt_password; raises InvalidToken on key mismatch."""
    return _fernet.decrypt(encrypted.encode()).decode()
def get_datasource(db: Session, source_id: int) -> Optional[DataSource]:
    """Look up one DataSource by primary key; None when no row matches."""
    return db.query(DataSource).filter_by(id=source_id).first()
def list_datasources(
    db: Session, keyword: Optional[str] = None, page: int = 1, page_size: int = 20
) -> Tuple[List[DataSource], int]:
    """Paginated data-source listing; keyword searches name and host."""
    q = db.query(DataSource)
    if keyword:
        keyword_match = DataSource.name.contains(keyword) | DataSource.host.contains(keyword)
        q = q.filter(keyword_match)
    total = q.count()
    rows = q.offset((page - 1) * page_size).limit(page_size).all()
    return rows, total
def create_datasource(db: Session, obj_in: DataSourceCreate, user_id: int) -> DataSource:
    """Persist a new data source; the plaintext password is encrypted at rest."""
    encrypted = _encrypt_password(obj_in.password) if obj_in.password else None
    source = DataSource(
        name=obj_in.name,
        source_type=obj_in.source_type,
        host=obj_in.host,
        port=obj_in.port,
        database_name=obj_in.database_name,
        username=obj_in.username,
        encrypted_password=encrypted,
        extra_params=obj_in.extra_params,
        status=obj_in.status or "active",
        dept_id=obj_in.dept_id,
        created_by=user_id,
    )
    db.add(source)
    db.commit()
    db.refresh(source)
    return source
def update_datasource(db: Session, db_obj: DataSource, obj_in: DataSourceUpdate) -> DataSource:
    """Apply explicitly-set fields; a non-empty new password is re-encrypted.

    An empty/None password in the payload is discarded, leaving the stored
    encrypted password untouched.
    """
    changes = obj_in.model_dump(exclude_unset=True)
    plaintext = changes.pop("password", None)
    if plaintext:
        changes["encrypted_password"] = _encrypt_password(plaintext)
    for attr, value in changes.items():
        setattr(db_obj, attr, value)
    db.commit()
    db.refresh(db_obj)
    return db_obj
def delete_datasource(db: Session, source_id: int) -> None:
    """Delete a data source; 404 when it does not exist."""
    source = get_datasource(db, source_id)
    if source is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="数据源不存在")
    db.delete(source)
    db.commit()
def test_connection(obj_in: DataSourceTest) -> dict:
    """Attempt a live connection to the described data source.

    Returns {"success": bool, "message": str}; never raises — connection
    failures are reported in the message instead.
    """
    from sqlalchemy import create_engine, text

    # SQLAlchemy URL driver prefix per supported source type.
    driver_map = {
        "mysql": "mysql+pymysql",
        "postgresql": "postgresql+psycopg2",
        "oracle": "oracle+cx_oracle",
        "sqlserver": "mssql+pymssql",
        "dm": "dm-python",  # placeholder
    }
    if obj_in.source_type == "dm":
        # For MVP there is no Dameng driver wired in; report a mocked success.
        return {"success": True, "message": "达梦数据库连接测试通过(模拟)"}
    driver = driver_map.get(obj_in.source_type, obj_in.source_type)
    host = obj_in.host or "localhost"
    port = obj_in.port or 5432
    database = obj_in.database_name or ""
    username = obj_in.username or ""
    password = obj_in.password or ""
    try:
        # The original per-dialect branches all built the identical URL, so a
        # single format string covers every supported source type.
        url = f"{driver}://{username}:{password}@{host}:{port}/{database}"
        engine = create_engine(url, pool_pre_ping=True, connect_args={"connect_timeout": 5})
        with engine.connect() as conn:
            conn.execute(text("SELECT 1"))
        return {"success": True, "message": "连接测试通过"}
    except Exception as e:
        return {"success": False, "message": f"连接失败: {str(e)}"}
+183
View File
@@ -0,0 +1,183 @@
from typing import Optional, List, Tuple
from sqlalchemy.orm import Session
from fastapi import HTTPException, status
from app.models.metadata import DataSource, Database, DataTable, DataColumn
from app.services.datasource_service import get_datasource, _decrypt_password
def get_database(db: Session, db_id: int) -> Optional[Database]:
    """Look up one Database by primary key; None when no row matches."""
    return db.query(Database).filter_by(id=db_id).first()
def get_table(db: Session, table_id: int) -> Optional[DataTable]:
    """Look up one DataTable by primary key; None when no row matches."""
    return db.query(DataTable).filter_by(id=table_id).first()
def get_column(db: Session, column_id: int) -> Optional[DataColumn]:
    """Look up one DataColumn by primary key; None when no row matches."""
    return db.query(DataColumn).filter_by(id=column_id).first()
def list_databases(db: Session, source_id: Optional[int] = None) -> List[Database]:
    """Databases, optionally restricted to one data source."""
    q = db.query(Database)
    return (q.filter(Database.source_id == source_id) if source_id else q).all()
def list_tables(db: Session, database_id: Optional[int] = None, keyword: Optional[str] = None) -> Tuple[List[DataTable], int]:
    """List tables, optionally filtered by database and name/comment keyword.

    Returns (items, total). No pagination is applied, so total == len(items);
    computing it locally avoids the second COUNT round trip the original made
    after already fetching all rows.
    """
    query = db.query(DataTable)
    if database_id:
        query = query.filter(DataTable.database_id == database_id)
    if keyword:
        query = query.filter(
            (DataTable.name.contains(keyword)) | (DataTable.comment.contains(keyword))
        )
    items = query.all()
    return items, len(items)
def list_columns(db: Session, table_id: Optional[int] = None, keyword: Optional[str] = None, page: int = 1, page_size: int = 50) -> Tuple[List[DataColumn], int]:
    """Paginated column listing; keyword searches name and comment."""
    q = db.query(DataColumn)
    if table_id:
        q = q.filter(DataColumn.table_id == table_id)
    if keyword:
        keyword_match = DataColumn.name.contains(keyword) | DataColumn.comment.contains(keyword)
        q = q.filter(keyword_match)
    total = q.count()
    rows = q.offset((page - 1) * page_size).limit(page_size).all()
    return rows, total
def build_tree(db: Session, source_id: Optional[int] = None) -> List[dict]:
    """Build the datasource -> database -> table navigation tree as nested dicts."""
    q = db.query(DataSource)
    if source_id:
        q = q.filter(DataSource.id == source_id)
    tree = []
    for src in q.all():
        tree.append({
            "id": src.id,
            "name": src.name,
            "type": "source",
            "children": [
                {
                    "id": database.id,
                    "name": database.name,
                    "type": "database",
                    "children": [
                        {
                            "id": tbl.id,
                            "name": tbl.name,
                            "type": "table",
                            "children": [],
                            "meta": {"comment": tbl.comment, "row_count": tbl.row_count, "column_count": tbl.column_count},
                        }
                        for tbl in database.tables
                    ],
                    "meta": {"charset": database.charset, "table_count": database.table_count},
                }
                for database in src.databases
            ],
            "meta": {"source_type": src.source_type, "status": src.status},
        })
    return tree
def sync_metadata(db: Session, source_id: int, user_id: int) -> dict:
    """Introspect a live data source and mirror its schemas/tables/columns.

    Insert-only: existing Database/DataTable/DataColumn rows are left as-is;
    only missing ones are created. Up to 5 sample values are captured per new
    column into sample_data for later rule matching.

    Note: `user_id` is currently unused — presumably intended for auditing;
    confirm with the caller.

    Returns a summary dict; on failure the source status is set to "error"
    and success=False is returned instead of raising.
    """
    from sqlalchemy import create_engine, inspect, text
    import json
    source = get_datasource(db, source_id)
    if not source:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="数据源不存在")
    # SQLAlchemy URL driver prefix per supported source type.
    driver_map = {
        "mysql": "mysql+pymysql",
        "postgresql": "postgresql+psycopg2",
        "oracle": "oracle+cx_oracle",
        "sqlserver": "mssql+pymssql",
    }
    driver = driver_map.get(source.source_type, source.source_type)
    if source.source_type == "dm":
        # No Dameng driver in the MVP: report a mocked successful sync.
        return {"success": True, "message": "达梦数据库同步成功(模拟)", "databases": 0, "tables": 0, "columns": 0}
    password = ""
    if source.encrypted_password:
        try:
            password = _decrypt_password(source.encrypted_password)
        except Exception:
            # Undecryptable (e.g. encryption key changed): try empty password.
            pass
    try:
        url = f"{driver}://{source.username}:{password}@{source.host}:{source.port}/{source.database_name}"
        engine = create_engine(url, pool_pre_ping=True)
        inspector = inspect(engine)
        # Schemas double as "databases" here; fall back to the configured name.
        db_names = inspector.get_schema_names() or [source.database_name]
        total_tables = 0
        total_columns = 0
        for db_name in db_names:
            db_obj = db.query(Database).filter(Database.source_id == source.id, Database.name == db_name).first()
            if not db_obj:
                db_obj = Database(source_id=source.id, name=db_name)
                db.add(db_obj)
                # Commit immediately so db_obj.id exists for the tables below.
                db.commit()
                db.refresh(db_obj)
            table_names = inspector.get_table_names(schema=db_name)
            for tname in table_names:
                table_obj = db.query(DataTable).filter(DataTable.database_id == db_obj.id, DataTable.name == tname).first()
                if not table_obj:
                    table_obj = DataTable(database_id=db_obj.id, name=tname)
                    db.add(table_obj)
                    db.commit()
                    db.refresh(table_obj)
                columns = inspector.get_columns(tname, schema=db_name)
                for col in columns:
                    col_obj = db.query(DataColumn).filter(DataColumn.table_id == table_obj.id, DataColumn.name == col["name"]).first()
                    if not col_obj:
                        sample = None
                        try:
                            # Best-effort sample grab. Identifiers are quoted,
                            # but double-quote quoting fits PostgreSQL-style
                            # dialects — NOTE(review): confirm for MySQL/Oracle.
                            with engine.connect() as conn:
                                result = conn.execute(text(f'SELECT "{col["name"]}" FROM "{db_name}"."{tname}" LIMIT 5'))
                                samples = [str(r[0]) for r in result if r[0] is not None]
                                sample = json.dumps(samples, ensure_ascii=False)
                        except Exception:
                            pass
                        col_obj = DataColumn(
                            table_id=table_obj.id,
                            name=col["name"],
                            data_type=str(col.get("type", "")),
                            # NOTE(review): SQLAlchemy inspectors do not emit a
                            # "max_length" key, so length is likely always None.
                            length=col.get("max_length"),
                            comment=col.get("comment"),
                            is_nullable=col.get("nullable", True),
                            sample_data=sample,
                        )
                        db.add(col_obj)
                        total_columns += 1
                total_tables += 1
            # Flush all new column rows accumulated for this schema.
            db.commit()
        source.status = "active"
        db.commit()
        return {
            "success": True,
            "message": "元数据同步成功",
            "databases": len(db_names),
            "tables": total_tables,
            "columns": total_columns,
        }
    except Exception as e:
        # Mark the source unhealthy and surface the error in the message.
        source.status = "error"
        db.commit()
        return {"success": False, "message": f"同步失败: {str(e)}", "databases": 0, "tables": 0, "columns": 0}
+120
View File
@@ -0,0 +1,120 @@
from typing import Optional, List, Tuple
from sqlalchemy.orm import Session
from fastapi import HTTPException, status
from app.models.project import ClassificationProject, ClassificationTask, ClassificationResult
from app.models.classification import Category, DataLevel
from app.models.metadata import DataColumn, DataTable, Database as MetaDatabase
def get_project(db: Session, project_id: int) -> Optional[ClassificationProject]:
    """Look up one ClassificationProject by primary key; None when missing."""
    return db.query(ClassificationProject).filter_by(id=project_id).first()
def list_projects(
    db: Session, keyword: Optional[str] = None, page: int = 1, page_size: int = 20
) -> Tuple[List[ClassificationProject], int]:
    """Newest-first paginated project listing, optionally name-filtered."""
    q = db.query(ClassificationProject)
    if keyword:
        q = q.filter(ClassificationProject.name.contains(keyword))
    total = q.count()
    offset = (page - 1) * page_size
    rows = (
        q.order_by(ClassificationProject.created_at.desc())
        .offset(offset)
        .limit(page_size)
        .all()
    )
    return rows, total
def create_project(db: Session, name: str, template_id: int, created_by: int, **kwargs) -> ClassificationProject:
    """Create and persist a project; extra kwargs map directly to model columns."""
    project = ClassificationProject(name=name, template_id=template_id, created_by=created_by, **kwargs)
    db.add(project)
    db.commit()
    db.refresh(project)
    return project
def update_project(db: Session, db_obj: ClassificationProject, **kwargs) -> ClassificationProject:
    """Assign every non-None kwarg onto the project and persist.

    Note: None values are skipped, so a field cannot be cleared through here.
    """
    for attr, value in kwargs.items():
        if value is not None:
            setattr(db_obj, attr, value)
    db.commit()
    db.refresh(db_obj)
    return db_obj
def delete_project(db: Session, project_id: int) -> None:
    """Delete a project; 404 when it does not exist."""
    project = get_project(db, project_id)
    if project is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="项目不存在")
    db.delete(project)
    db.commit()
def get_project_stats(db: Session, project_id: int) -> dict:
    """Summarize result counts (total / auto / manual / reviewed) for a project."""
    def _count(*extra):
        q = db.query(ClassificationResult).filter(ClassificationResult.project_id == project_id)
        for cond in extra:
            q = q.filter(cond)
        return q.count()
    return {
        "total": _count(),
        "auto": _count(ClassificationResult.source == "auto"),
        "manual": _count(ClassificationResult.source == "manual"),
        "reviewed": _count(ClassificationResult.status == "reviewed"),
    }
def list_results(
    db: Session,
    project_id: Optional[int] = None,
    table_id: Optional[int] = None,
    status: Optional[str] = None,
    keyword: Optional[str] = None,
    page: int = 1,
    page_size: int = 50,
) -> Tuple[List[ClassificationResult], int]:
    """Paginated listing of classification results with optional filters.

    Note: the `status` parameter shadows the imported fastapi `status` module
    inside this function; the name is part of the public interface and kept.
    """
    query = db.query(ClassificationResult)
    if project_id:
        query = query.filter(ClassificationResult.project_id == project_id)
    # Join DataColumn exactly once even when both column-based filters are
    # requested; the original joined twice, which SQLAlchemy rejects.
    if table_id or keyword:
        query = query.join(DataColumn)
    if table_id:
        query = query.filter(DataColumn.table_id == table_id)
    if status:
        query = query.filter(ClassificationResult.status == status)
    if keyword:
        query = query.filter(
            (DataColumn.name.contains(keyword)) | (DataColumn.comment.contains(keyword))
        )
    total = query.count()
    items = query.offset((page - 1) * page_size).limit(page_size).all()
    return items, total
def update_result_label(
    db: Session,
    result_id: int,
    category_id: int,
    level_id: int,
    labeler_id: int,
) -> ClassificationResult:
    """Manually (re)label one classification result.

    Raises:
        HTTPException 404: when the result does not exist.
    """
    from datetime import datetime

    result = db.query(ClassificationResult).filter(ClassificationResult.id == result_id).first()
    if not result:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="结果不存在")
    result.category_id = category_id
    result.level_id = level_id
    result.labeler_id = labeler_id
    result.source = "manual"
    result.status = "manual"
    # Naive UTC timestamp — same value the original produced via the
    # `__import__('datetime')` hack, now with a normal import.
    result.label_time = datetime.utcnow()
    db.commit()
    db.refresh(result)
    return result
+127
View File
@@ -0,0 +1,127 @@
from typing import Optional, List
from sqlalchemy.orm import Session
from fastapi import HTTPException, status
from app.models.user import User, Role, Dept, UserRole
from app.schemas.user import UserCreate, UserUpdate
from app.core.security import get_password_hash
def get_user_by_id(db: Session, user_id: int) -> Optional[User]:
    """Look up one User by primary key; None when no row matches."""
    return db.query(User).filter_by(id=user_id).first()
def get_user_by_username(db: Session, username: str) -> Optional[User]:
    """Look up one User by username; None when no row matches."""
    return db.query(User).filter_by(username=username).first()
def create_user(db: Session, obj_in: UserCreate) -> User:
    """Create a user (username must be unique) and attach any requested roles."""
    if get_user_by_username(db, obj_in.username) is not None:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名已存在")
    new_user = User(
        username=obj_in.username,
        email=obj_in.email,
        hashed_password=get_password_hash(obj_in.password),
        real_name=obj_in.real_name,
        phone=obj_in.phone,
        dept_id=obj_in.dept_id,
        is_active=obj_in.is_active,
    )
    db.add(new_user)
    db.commit()
    db.refresh(new_user)
    if obj_in.role_ids:
        # Only link roles that actually exist; unknown ids are silently skipped.
        for rid in obj_in.role_ids:
            if db.query(Role).filter(Role.id == rid).first():
                db.add(UserRole(user_id=new_user.id, role_id=rid))
        db.commit()
        db.refresh(new_user)
    return new_user
def update_user(db: Session, db_obj: User, obj_in: UserUpdate) -> User:
    """Apply explicitly-set fields; role_ids, when given, replaces all role links."""
    changes = obj_in.model_dump(exclude_unset=True)
    role_ids = changes.pop("role_ids", None)
    for attr, value in changes.items():
        # NOTE(review): values are assigned as-is; assumes UserUpdate never
        # carries a raw password field (that would bypass hashing) — confirm.
        setattr(db_obj, attr, value)
    if role_ids is not None:
        db.query(UserRole).filter(UserRole.user_id == db_obj.id).delete()
        for rid in role_ids:
            if db.query(Role).filter(Role.id == rid).first():
                db.add(UserRole(user_id=db_obj.id, role_id=rid))
    db.commit()
    db.refresh(db_obj)
    return db_obj
def delete_user(db: Session, user_id: int) -> None:
    """Delete a user; 404 when missing, 400 for the superuser account."""
    target = get_user_by_id(db, user_id)
    if target is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="用户不存在")
    if target.is_superuser:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="不能删除超级管理员")
    db.delete(target)
    db.commit()
def list_users(db: Session, keyword: Optional[str] = None, page: int = 1, page_size: int = 20):
    """Paginated user listing; keyword searches username, real name and email."""
    q = db.query(User)
    if keyword:
        keyword_match = (
            User.username.contains(keyword)
            | User.real_name.contains(keyword)
            | User.email.contains(keyword)
        )
        q = q.filter(keyword_match)
    total = q.count()
    rows = q.offset((page - 1) * page_size).limit(page_size).all()
    return rows, total
def create_initial_data(db: Session):
    """Seed the default roles, the root department and the first superuser.

    Idempotent: every insert is guarded by an existence check, so this can be
    called unconditionally at application startup.
    """
    # Create default roles (skip any whose code already exists).
    default_roles = [
        {"name": "超级管理员", "code": "superadmin", "description": "系统超级管理员"},
        {"name": "管理员", "code": "admin", "description": "系统管理员"},
        {"name": "项目负责人", "code": "project_manager", "description": "分类分级项目负责人"},
        {"name": "打标员", "code": "labeler", "description": "数据打标人员"},
        {"name": "审核员", "code": "reviewer", "description": "结果审核人员"},
        {"name": "访客", "code": "guest", "description": "只读访客"},
    ]
    for r in default_roles:
        if not db.query(Role).filter(Role.code == r["code"]).first():
            db.add(Role(**r))
    # Create root dept with fixed id=1 so the superuser below can reference it.
    if not db.query(Dept).filter(Dept.id == 1).first():
        db.add(Dept(id=1, name="根部门", parent_id=None, sort_order=0))
    db.commit()  # persist roles and root department together
    # Create superuser from settings and grant it the superadmin role.
    from app.core.config import settings
    if not get_user_by_username(db, settings.FIRST_SUPERUSER_USERNAME):
        superuser = User(
            username=settings.FIRST_SUPERUSER_USERNAME,
            email=settings.FIRST_SUPERUSER_EMAIL,
            hashed_password=get_password_hash(settings.FIRST_SUPERUSER_PASSWORD),
            real_name="超级管理员",
            is_active=True,
            is_superuser=True,
            dept_id=1,
        )
        db.add(superuser)
        db.commit()
        db.refresh(superuser)
        superadmin_role = db.query(Role).filter(Role.code == "superadmin").first()
        if superadmin_role:
            db.add(UserRole(user_id=superuser.id, role_id=superadmin_role.id))
            db.commit()