import json
import re
from typing import List, Optional, Tuple

from sqlalchemy.orm import Session

import app.models.metadata
from app.models.classification import RecognitionRule, Category, DataLevel
from app.models.metadata import DataColumn, DataTable
from app.models.project import ClassificationProject, ClassificationResult, ResultStatus


def match_rule(rule: RecognitionRule, column: DataColumn) -> Tuple[bool, float]:
    """Match a single recognition rule against a column.

    Args:
        rule: Rule with target_field ("column_name" | "comment" | "sample_data"),
            rule_type ("regex" | "keyword" | "enum" | "similarity") and
            comma-separated (or regex) rule_content.
        column: Column whose name / comment / sample_data is tested.

    Returns:
        (matched, confidence). Confidence is a fixed score per rule type
        (regex 0.85, keyword 0.75, enum 0.90) or, for "similarity" rules,
        the max TF-IDF cosine similarity (threshold 0.75, capped at 0.99).
    """
    # Pick the texts to test based on the rule's target field.
    targets: List[str] = []
    if rule.target_field == "column_name":
        targets = [column.name]
    elif rule.target_field == "comment":
        targets = [column.comment or ""]
    elif rule.target_field == "sample_data":
        if column.sample_data:
            try:
                samples = json.loads(column.sample_data)
                if isinstance(samples, list):
                    targets = [str(s) for s in samples]
            except Exception:
                # Not valid JSON — fall back to the raw stored string.
                targets = [column.sample_data]
    if not targets:
        return False, 0.0

    # Guard against NULL rule content (would crash .split() / re.compile).
    content = rule.rule_content or ""

    if rule.rule_type == "regex":
        try:
            pattern = re.compile(content)
        except re.error:
            # Malformed pattern in the rule table — treat as no match.
            return False, 0.0
        for t in targets:
            if pattern.search(t):
                return True, 0.85

    elif rule.rule_type == "keyword":
        # Drop empty entries: a trailing comma in rule_content would otherwise
        # produce "", which is a substring of everything and matches all columns.
        keywords = [k.strip().lower() for k in content.split(",") if k.strip()]
        for t in targets:
            t_lower = t.lower()
            if any(kw in t_lower for kw in keywords):
                return True, 0.75

    elif rule.rule_type == "enum":
        # Same empty-entry guard as keywords; set gives O(1) membership tests.
        enums = {e.strip().lower() for e in content.split(",") if e.strip()}
        for t in targets:
            if t.strip().lower() in enums:
                return True, 0.90

    elif rule.rule_type == "similarity":
        benchmarks = [b.strip().lower() for b in content.split(",") if b.strip()]
        if not benchmarks:
            return False, 0.0
        # Lazy import: the common rule types stay fast and do not require
        # sklearn to be installed.
        from sklearn.feature_extraction.text import TfidfVectorizer
        from sklearn.metrics.pairwise import cosine_similarity

        texts = [t.lower() for t in targets] + benchmarks
        try:
            vectorizer = TfidfVectorizer(analyzer="char_wb", ngram_range=(2, 3))
            tfidf = vectorizer.fit_transform(texts)
            target_vecs = tfidf[: len(targets)]
            bench_vecs = tfidf[len(targets):]
            max_sim = float(cosine_similarity(target_vecs, bench_vecs).max())
            if max_sim >= 0.75:
                return True, round(min(max_sim, 0.99), 4)
        except Exception:
            # Vectorization can fail on degenerate input (e.g. all-empty
            # strings); treat as no match rather than aborting the scan.
            pass

    return False, 0.0


def run_auto_classification(
    db: Session,
    project_id: int,
    source_ids: Optional[List[int]] = None,
    progress_callback=None,
) -> dict:
    """Run automatic classification for a project.

    Scans all candidate columns, applies the project's active recognition
    rules, and writes/updates one ClassificationResult per matched column.

    Args:
        db: Active SQLAlchemy session; committed once after the scan.
        project_id: ClassificationProject to run.
        source_ids: Optional data-source filter; when omitted, falls back to
            the project's target_source_ids (comma-separated id string).
        progress_callback: Optional callable(scanned=, matched=, total=)
            invoked roughly 20 times during the scan plus once at the end.

    Returns:
        dict with a success flag, a human-readable message, and the
        scanned/matched counters.
    """
    project = (
        db.query(ClassificationProject)
        .filter(ClassificationProject.id == project_id)
        .first()
    )
    if not project:
        return {"success": False, "message": "项目不存在"}

    # Active rules from the project's template, ordered by the priority column.
    rules = (
        db.query(RecognitionRule)
        .filter(
            RecognitionRule.is_active == True,  # noqa: E712 — SQLAlchemy needs ==
            RecognitionRule.template_id == project.template_id,
        )
        .order_by(RecognitionRule.priority)
        .all()
    )
    if not rules:
        return {"success": False, "message": "没有可用的识别规则"}

    # Columns to classify: optionally restricted to specific data sources,
    # either via the explicit argument or the project's configured targets.
    Database = app.models.metadata.Database
    columns_query = db.query(DataColumn).join(DataTable).join(Database)
    if source_ids:
        columns_query = columns_query.filter(Database.source_id.in_(source_ids))
    elif project.target_source_ids:
        sids = [int(x) for x in project.target_source_ids.split(",") if x]
        columns_query = columns_query.filter(Database.source_id.in_(sids))
    columns = columns_query.all()

    matched_count = 0
    total = len(columns)
    report_interval = max(1, total // 20)  # report progress ~20 times

    for idx, col in enumerate(columns):
        # Existing result for this (project, column), if any.
        existing = (
            db.query(ClassificationResult)
            .filter(
                ClassificationResult.project_id == project_id,
                ClassificationResult.column_id == col.id,
            )
            .first()
        )

        # Evaluate every rule and keep the highest-confidence match.
        best_rule = None
        best_confidence = 0.0
        for rule in rules:
            matched, confidence = match_rule(rule, col)
            if matched and confidence > best_confidence:
                best_confidence = confidence
                best_rule = rule

        if best_rule:
            matched_count += 1
            if existing:
                # Overwrites the previous result regardless of its source —
                # NOTE(review): confirm manually-reviewed results should be
                # clobbered by a re-run.
                existing.category_id = best_rule.category_id
                existing.level_id = best_rule.level_id
                existing.confidence = best_confidence
                existing.source = "auto"
                existing.status = ResultStatus.AUTO.value
            else:
                db.add(
                    ClassificationResult(
                        project_id=project_id,
                        column_id=col.id,
                        category_id=best_rule.category_id,
                        level_id=best_rule.level_id,
                        source="auto",
                        confidence=best_confidence,
                        status=ResultStatus.AUTO.value,
                    )
                )
            # Track how often each rule fires (hit_count may start as NULL).
            best_rule.hit_count = (best_rule.hit_count or 0) + 1

        if progress_callback and (idx + 1) % report_interval == 0:
            progress_callback(scanned=idx + 1, matched=matched_count, total=total)

    db.commit()

    # Final progress report so callers always see 100%.
    if progress_callback:
        progress_callback(scanned=total, matched=matched_count, total=total)

    return {
        "success": True,
        "message": f"自动分类完成,共扫描 {total} 个字段,命中 {matched_count} 个",
        "scanned": total,
        "matched": matched_count,
    }