Files
prop-data-guard/backend/app/services/classification_engine.py
T
hiderfong 6d70520e79 feat: 全量功能模块开发与集成测试修复
- 新增后端模块:Alert、APIAsset、Compliance、Lineage、Masking、Risk、SchemaChange、Unstructured、Watermark
- 新增前端模块页面与API接口
- 新增Alembic迁移脚本(002-014)覆盖全量业务表
- 新增测试数据生成脚本与集成测试脚本
- 修复metadata模型JSON类型导入缺失导致启动失败的问题
- 修复前端Alert/APIAsset页面request模块路径错误
- 更新docker-compose与开发计划文档
2026-04-25 08:51:38 +08:00

174 lines
6.1 KiB
Python

import re
import json
from typing import List, Optional, Tuple
from sqlalchemy.orm import Session
from app.models.classification import RecognitionRule, Category, DataLevel
from app.models.metadata import DataColumn, DataTable
from app.models.project import ClassificationProject, ClassificationResult, ResultStatus
def match_rule(rule: "RecognitionRule", column: "DataColumn") -> Tuple[bool, float]:
    """Match a single recognition rule against a column.

    Args:
        rule: Rule whose ``rule_type`` is one of ``regex``, ``keyword``,
            ``enum`` or ``similarity``; ``rule_content`` holds the regex
            pattern or a comma-separated value list; ``target_field``
            selects which column attribute is inspected.
        column: Column metadata providing ``name``, ``comment`` and
            ``sample_data`` (expected to be a JSON-encoded list of values).

    Returns:
        ``(matched, confidence)``; ``(False, 0.0)`` when nothing matches
        or the rule cannot be evaluated.
    """
    # Collect the texts to test, depending on which field the rule targets.
    targets: List[str] = []
    if rule.target_field == "column_name":
        targets = [column.name]
    elif rule.target_field == "comment":
        targets = [column.comment or ""]
    elif rule.target_field == "sample_data":
        if column.sample_data:
            try:
                samples = json.loads(column.sample_data)
                if isinstance(samples, list):
                    targets = [str(s) for s in samples]
            except Exception:
                # Not valid JSON: fall back to matching the raw string.
                targets = [column.sample_data]
    if not targets:
        return False, 0.0

    if rule.rule_type == "regex":
        try:
            pattern = re.compile(rule.rule_content)
        except re.error:
            # A malformed pattern simply never matches instead of crashing.
            return False, 0.0
        for text in targets:
            if pattern.search(text):
                return True, 0.85
    elif rule.rule_type == "keyword":
        # Filter out empty entries: a trailing/double comma would otherwise
        # produce "" which is a substring of everything and would make this
        # rule match every single column.
        keywords = [k.strip().lower() for k in rule.rule_content.split(",") if k.strip()]
        for text in targets:
            lowered = text.lower()
            if any(kw in lowered for kw in keywords):
                return True, 0.75
    elif rule.rule_type == "enum":
        # Same guard against blank entries from stray commas.
        enums = {e.strip().lower() for e in rule.rule_content.split(",") if e.strip()}
        for text in targets:
            if text.strip().lower() in enums:
                return True, 0.90
    elif rule.rule_type == "similarity":
        benchmarks = [b.strip().lower() for b in rule.rule_content.split(",") if b.strip()]
        if not benchmarks:
            return False, 0.0
        # Imported lazily so deployments without scikit-learn can still use
        # the other rule types.
        from sklearn.feature_extraction.text import TfidfVectorizer
        from sklearn.metrics.pairwise import cosine_similarity
        texts = [t.lower() for t in targets] + benchmarks
        try:
            vectorizer = TfidfVectorizer(analyzer="char_wb", ngram_range=(2, 3))
            tfidf = vectorizer.fit_transform(texts)
            target_vecs = tfidf[:len(targets)]
            bench_vecs = tfidf[len(targets):]
            sim_matrix = cosine_similarity(target_vecs, bench_vecs)
            max_sim = float(sim_matrix.max())
            if max_sim >= 0.75:
                return True, round(min(max_sim, 0.99), 4)
        except Exception:
            # Vectorization can fail on degenerate input (e.g. empty vocabulary);
            # treat that as "no match" rather than aborting the whole scan.
            pass
    return False, 0.0
def run_auto_classification(
    db: Session,
    project_id: int,
    source_ids: Optional[List[int]] = None,
    progress_callback=None,
) -> dict:
    """Run automatic classification over all columns in a project's scope.

    Args:
        db: Active SQLAlchemy session; results are committed before returning.
        project_id: Id of the ClassificationProject to classify for.
        source_ids: Optional explicit data-source filter; when omitted, falls
            back to the project's own comma-separated ``target_source_ids``.
        progress_callback: Optional callable(scanned=, matched=, total=)
            invoked roughly 20 times during the scan and once at the end.

    Returns:
        dict with ``success`` and ``message``; on success also ``scanned``
        and ``matched`` counts.
    """
    # Imported locally (instead of relying on a module-level
    # ``import app.models.metadata``) so this function is self-contained.
    from app.models.metadata import Database

    project = db.query(ClassificationProject).filter(
        ClassificationProject.id == project_id
    ).first()
    if not project:
        return {"success": False, "message": "项目不存在"}

    # Active rules of the project's template, highest priority first.
    rules = (
        db.query(RecognitionRule)
        .filter(
            RecognitionRule.is_active == True,  # noqa: E712 — SQLAlchemy expression, not a Python comparison
            RecognitionRule.template_id == project.template_id,
        )
        .order_by(RecognitionRule.priority)
        .all()
    )
    if not rules:
        return {"success": False, "message": "没有可用的识别规则"}

    # Columns in scope: an explicit source_ids argument wins over the
    # project's configured target sources.
    columns_query = db.query(DataColumn).join(DataTable).join(Database)
    if source_ids:
        columns_query = columns_query.filter(Database.source_id.in_(source_ids))
    elif project.target_source_ids:
        sids = [int(x) for x in project.target_source_ids.split(",") if x]
        columns_query = columns_query.filter(Database.source_id.in_(sids))
    columns = columns_query.all()

    matched_count = 0
    total = len(columns)
    report_interval = max(1, total // 20)  # report progress ~20 times

    for idx, col in enumerate(columns):
        # Reuse any previous result for this (project, column) pair.
        existing = db.query(ClassificationResult).filter(
            ClassificationResult.project_id == project_id,
            ClassificationResult.column_id == col.id,
        ).first()

        # Pick the highest-confidence matching rule; ties keep the earlier
        # (higher-priority) rule because the comparison is strict.
        best_rule = None
        best_confidence = 0.0
        for rule in rules:
            matched, confidence = match_rule(rule, col)
            if matched and confidence > best_confidence:
                best_confidence = confidence
                best_rule = rule

        if best_rule:
            matched_count += 1
            if existing:
                existing.category_id = best_rule.category_id
                existing.level_id = best_rule.level_id
                existing.confidence = best_confidence
                existing.source = "auto"
                existing.status = ResultStatus.AUTO.value
            else:
                db.add(ClassificationResult(
                    project_id=project_id,
                    column_id=col.id,
                    category_id=best_rule.category_id,
                    level_id=best_rule.level_id,
                    source="auto",
                    confidence=best_confidence,
                    status=ResultStatus.AUTO.value,
                ))
            # Track rule usage for rule analytics.
            best_rule.hit_count = (best_rule.hit_count or 0) + 1

        # Report progress periodically.
        if progress_callback and (idx + 1) % report_interval == 0:
            progress_callback(scanned=idx + 1, matched=matched_count, total=total)

    db.commit()

    # Final progress report so the caller always sees 100%.
    if progress_callback:
        progress_callback(scanned=total, matched=matched_count, total=total)

    return {
        "success": True,
        "message": f"自动分类完成,共扫描 {total} 个字段,命中 {matched_count}",
        "scanned": total,
        "matched": matched_count,
    }
import app.models.metadata