Files
prop-data-guard/backend/app/services/ml_service.py
T
hiderfong 6d70520e79 feat: 全量功能模块开发与集成测试修复
- 新增后端模块:Alert、APIAsset、Compliance、Lineage、Masking、Risk、SchemaChange、Unstructured、Watermark
- 新增前端模块页面与API接口
- 新增Alembic迁移脚本(002-014)覆盖全量业务表
- 新增测试数据生成脚本与集成测试脚本
- 修复metadata模型JSON类型导入缺失导致启动失败的问题
- 修复前端Alert/APIAsset页面request模块路径错误
- 更新docker-compose与开发计划文档
2026-04-25 08:51:38 +08:00

196 lines
7.1 KiB
Python

import json
import logging
import os
from collections import Counter
from datetime import datetime, timezone
from typing import List, Optional, Tuple

import joblib
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sqlalchemy.orm import Session

from app.models.classification import Category
from app.models.ml import MLModelVersion
from app.models.project import ClassificationResult
logger = logging.getLogger(__name__)
MODELS_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "ml_models")
os.makedirs(MODELS_DIR, exist_ok=True)
def _build_text_features(column_name: str, comment: Optional[str], sample_data: Optional[str]) -> str:
parts = [column_name]
if comment:
parts.append(comment)
if sample_data:
try:
samples = json.loads(sample_data)
if isinstance(samples, list):
parts.extend([str(s) for s in samples[:5]])
except Exception:
parts.append(sample_data)
return " ".join(parts)
def _fetch_training_data(db: Session, min_samples_per_class: int = 5):
results = (
db.query(ClassificationResult)
.filter(ClassificationResult.source == "manual")
.filter(ClassificationResult.category_id.isnot(None))
.all()
)
texts = []
labels = []
for r in results:
if r.column:
text = _build_text_features(r.column.name, r.column.comment, r.column.sample_data)
texts.append(text)
labels.append(r.category_id)
from collections import Counter
counts = Counter(labels)
valid_classes = {c for c, n in counts.items() if n >= min_samples_per_class}
filtered_texts = []
filtered_labels = []
for t, l in zip(texts, labels):
if l in valid_classes:
filtered_texts.append(t)
filtered_labels.append(l)
return filtered_texts, filtered_labels, len(filtered_labels)
def train_model(db: Session, model_name: Optional[str] = None, algorithm: str = "logistic_regression", test_size: float = 0.2):
texts, labels, total = _fetch_training_data(db)
if total < 20:
logger.warning("Not enough training data (need >= 20, got %d)", total)
return None
X_train, X_test, y_train, y_test = train_test_split(
texts, labels, test_size=test_size, random_state=42, stratify=labels
)
vectorizer = TfidfVectorizer(analyzer="char_wb", ngram_range=(2, 4), max_features=5000)
X_train_vec = vectorizer.fit_transform(X_train)
X_test_vec = vectorizer.transform(X_test)
if algorithm == "logistic_regression":
clf = LogisticRegression(max_iter=1000, multi_class="multinomial", solver="lbfgs")
elif algorithm == "random_forest":
clf = RandomForestClassifier(n_estimators=100, random_state=42, n_jobs=-1)
else:
clf = LogisticRegression(max_iter=1000, multi_class="multinomial", solver="lbfgs")
clf.fit(X_train_vec, y_train)
y_pred = clf.predict(X_test_vec)
acc = accuracy_score(y_test, y_pred)
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
name = model_name or f"model_{timestamp}"
model_path = os.path.join(MODELS_DIR, f"{name}_clf.joblib")
vec_path = os.path.join(MODELS_DIR, f"{name}_tfidf.joblib")
joblib.dump(clf, model_path)
joblib.dump(vectorizer, vec_path)
db.query(MLModelVersion).filter(MLModelVersion.is_active == True).update({"is_active": False})
mv = MLModelVersion(
name=name,
model_path=model_path,
vectorizer_path=vec_path,
accuracy=acc,
train_samples=total,
is_active=True,
description=f"Algorithm: {algorithm}, test_accuracy: {acc:.4f}",
)
db.add(mv)
db.commit()
db.refresh(mv)
logger.info("Trained model %s with accuracy %.4f on %d samples", name, acc, total)
return mv
def _get_active_model(db: Session):
mv = db.query(MLModelVersion).filter(MLModelVersion.is_active == True).first()
if not mv or not os.path.exists(mv.model_path) or not os.path.exists(mv.vectorizer_path):
return None
clf = joblib.load(mv.model_path)
vectorizer = joblib.load(mv.vectorizer_path)
return clf, vectorizer, mv
def predict_categories(db: Session, texts: List[str], top_k: int = 3):
model_tuple = _get_active_model(db)
if not model_tuple:
return [[] for _ in texts]
clf, vectorizer, mv = model_tuple
X = vectorizer.transform(texts)
if hasattr(clf, "predict_proba"):
probs = clf.predict_proba(X)
else:
preds = clf.predict(X)
return [[{"category_id": int(p), "confidence": 1.0}] for p in preds]
classes = [int(c) for c in clf.classes_]
results = []
for prob in probs:
top_idx = np.argsort(prob)[::-1][:top_k]
suggestions = []
for idx in top_idx:
cat_id = classes[idx]
confidence = float(prob[idx])
if confidence > 0.01:
suggestions.append({"category_id": cat_id, "confidence": round(confidence, 4)})
results.append(suggestions)
return results
def suggest_for_project_columns(db: Session, project_id: int, column_ids: Optional[List[int]] = None, top_k: int = 3):
from app.models.project import ClassificationProject
from app.models.metadata import DataColumn
project = db.query(ClassificationProject).filter(ClassificationProject.id == project_id).first()
if not project:
return {"success": False, "message": "项目不存在"}
query = db.query(DataColumn).join(
ClassificationResult,
(ClassificationResult.column_id == DataColumn.id) & (ClassificationResult.project_id == project_id),
isouter=True,
)
if column_ids:
query = query.filter(DataColumn.id.in_(column_ids))
columns = query.all()
texts = []
col_map = []
for col in columns:
texts.append(_build_text_features(col.name, col.comment, col.sample_data))
col_map.append(col)
if not texts:
return {"success": True, "suggestions": [], "message": "没有可预测的字段"}
predictions = predict_categories(db, texts, top_k=top_k)
suggestions = []
all_category_ids = set()
for col, preds in zip(col_map, predictions):
for p in preds:
all_category_ids.add(p["category_id"])
categories = {c.id: c for c in db.query(Category).filter(Category.id.in_(list(all_category_ids))).all()}
for col, preds in zip(col_map, predictions):
item = {
"column_id": col.id,
"column_name": col.name,
"table_name": col.table.name if col.table else None,
"suggestions": [],
}
for p in preds:
cat = categories.get(p["category_id"])
item["suggestions"].append({
"category_id": p["category_id"],
"category_name": cat.name if cat else None,
"category_code": cat.code if cat else None,
"confidence": p["confidence"],
})
suggestions.append(item)
return {"success": True, "suggestions": suggestions}