6d70520e79
- 新增后端模块:Alert、APIAsset、Compliance、Lineage、Masking、Risk、SchemaChange、Unstructured、Watermark - 新增前端模块页面与API接口 - 新增Alembic迁移脚本(002-014)覆盖全量业务表 - 新增测试数据生成脚本与集成测试脚本 - 修复metadata模型JSON类型导入缺失导致启动失败的问题 - 修复前端Alert/APIAsset页面request模块路径错误 - 更新docker-compose与开发计划文档
196 lines
7.1 KiB
Python
196 lines
7.1 KiB
Python
import json
import logging
import os
from datetime import datetime, timezone
from typing import List, Optional, Tuple

import joblib
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sqlalchemy.orm import Session

from app.models.classification import Category
from app.models.ml import MLModelVersion
from app.models.project import ClassificationResult
# Module-level logger, following the standard one-logger-per-module convention.
logger = logging.getLogger(__name__)

# Directory where trained classifier/vectorizer artifacts are persisted;
# resolves to a sibling "ml_models" directory two levels above this file.
MODELS_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "ml_models")
# Created eagerly at import time so later joblib.dump calls cannot fail on a
# missing directory. NOTE: this is a module-import side effect.
os.makedirs(MODELS_DIR, exist_ok=True)
def _build_text_features(column_name: str, comment: Optional[str], sample_data: Optional[str]) -> str:
|
|
parts = [column_name]
|
|
if comment:
|
|
parts.append(comment)
|
|
if sample_data:
|
|
try:
|
|
samples = json.loads(sample_data)
|
|
if isinstance(samples, list):
|
|
parts.extend([str(s) for s in samples[:5]])
|
|
except Exception:
|
|
parts.append(sample_data)
|
|
return " ".join(parts)
|
|
|
|
|
|
def _fetch_training_data(db: Session, min_samples_per_class: int = 5):
    """Collect manually labelled columns as parallel (texts, labels) lists.

    Only results with ``source == "manual"`` and a non-null category are
    used, and only classes with at least ``min_samples_per_class`` examples
    are kept (so a stratified train/test split is possible).

    Returns:
        ``(texts, labels, sample_count)`` where ``sample_count`` equals
        ``len(labels)`` after filtering.
    """
    rows = (
        db.query(ClassificationResult)
        .filter(ClassificationResult.source == "manual")
        .filter(ClassificationResult.category_id.isnot(None))
        .all()
    )

    texts = []
    labels = []
    for row in rows:
        # Skip orphaned results whose column relation is gone.
        if row.column:
            texts.append(_build_text_features(row.column.name, row.column.comment, row.column.sample_data))
            labels.append(row.category_id)

    from collections import Counter

    class_counts = Counter(labels)
    keep = {cls for cls, count in class_counts.items() if count >= min_samples_per_class}
    kept_pairs = [(text, label) for text, label in zip(texts, labels) if label in keep]
    filtered_texts = [text for text, _ in kept_pairs]
    filtered_labels = [label for _, label in kept_pairs]
    return filtered_texts, filtered_labels, len(filtered_labels)
def train_model(db: Session, model_name: Optional[str] = None, algorithm: str = "logistic_regression", test_size: float = 0.2):
    """Train a column-classification model on manual labels and persist it.

    Builds TF-IDF character n-gram features from the labelled data gathered
    by ``_fetch_training_data``, fits the requested classifier, saves the
    classifier and vectorizer under ``MODELS_DIR`` and records a new
    ``MLModelVersion`` row, deactivating any previously active version.

    Args:
        db: Database session.
        model_name: Optional explicit model name; defaults to a timestamped
            ``model_YYYYmmdd_HHMMSS`` name.
        algorithm: ``"logistic_regression"`` (default, also the fallback for
            unrecognized values) or ``"random_forest"``.
        test_size: Fraction of samples held out for the accuracy estimate.

    Returns:
        The persisted ``MLModelVersion``, or ``None`` when fewer than 20
        usable training samples exist.
    """
    texts, labels, total = _fetch_training_data(db)
    if total < 20:
        logger.warning("Not enough training data (need >= 20, got %d)", total)
        return None

    X_train, X_test, y_train, y_test = train_test_split(
        texts, labels, test_size=test_size, random_state=42, stratify=labels
    )
    # Character n-grams are robust for short, mixed-language column names.
    vectorizer = TfidfVectorizer(analyzer="char_wb", ngram_range=(2, 4), max_features=5000)
    X_train_vec = vectorizer.fit_transform(X_train)
    X_test_vec = vectorizer.transform(X_test)

    if algorithm == "random_forest":
        clf = RandomForestClassifier(n_estimators=100, random_state=42, n_jobs=-1)
    else:
        # Default, and fallback for unknown algorithm names. The deprecated
        # multi_class="multinomial" argument was dropped: multinomial is
        # already the behavior for the lbfgs solver, and the parameter was
        # removed in scikit-learn 1.7.
        clf = LogisticRegression(max_iter=1000, solver="lbfgs")
    clf.fit(X_train_vec, y_train)
    acc = accuracy_score(y_test, clf.predict(X_test_vec))

    # Persist model artifacts; datetime.utcnow() was replaced with the
    # timezone-aware equivalent (utcnow is deprecated since Python 3.12);
    # the formatted timestamp is identical.
    timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
    name = model_name or f"model_{timestamp}"
    model_path = os.path.join(MODELS_DIR, f"{name}_clf.joblib")
    vec_path = os.path.join(MODELS_DIR, f"{name}_tfidf.joblib")
    joblib.dump(clf, model_path)
    joblib.dump(vectorizer, vec_path)

    # Only one model version may be active at a time.
    db.query(MLModelVersion).filter(MLModelVersion.is_active.is_(True)).update({"is_active": False})
    mv = MLModelVersion(
        name=name,
        model_path=model_path,
        vectorizer_path=vec_path,
        accuracy=acc,
        train_samples=total,
        is_active=True,
        description=f"Algorithm: {algorithm}, test_accuracy: {acc:.4f}",
    )
    db.add(mv)
    db.commit()
    db.refresh(mv)
    logger.info("Trained model %s with accuracy %.4f on %d samples", name, acc, total)
    return mv
def _get_active_model(db: Session):
    """Load the currently active model version's artifacts from disk.

    Returns ``(classifier, vectorizer, model_version)``, or ``None`` when no
    version is active or either artifact file is missing on disk.
    """
    active = db.query(MLModelVersion).filter(MLModelVersion.is_active == True).first()
    if active is None:
        return None
    if not (os.path.exists(active.model_path) and os.path.exists(active.vectorizer_path)):
        return None
    classifier = joblib.load(active.model_path)
    vectorizer = joblib.load(active.vectorizer_path)
    return classifier, vectorizer, active
def predict_categories(db: Session, texts: List[str], top_k: int = 3):
    """Predict up to ``top_k`` category suggestions for each input text.

    Each suggestion is ``{"category_id": int, "confidence": float}`` with the
    confidence rounded to 4 decimals; suggestions with confidence <= 0.01
    are dropped. When no active model is available, an empty suggestion list
    is returned for every text.
    """
    loaded = _get_active_model(db)
    if not loaded:
        return [[] for _ in texts]
    clf, vectorizer, _version = loaded

    features = vectorizer.transform(texts)
    if not hasattr(clf, "predict_proba"):
        # Classifier without probability support: one hard prediction each.
        return [[{"category_id": int(label), "confidence": 1.0}] for label in clf.predict(features)]

    probabilities = clf.predict_proba(features)
    class_ids = [int(c) for c in clf.classes_]

    results = []
    for row in probabilities:
        # Indices of the top_k highest probabilities, descending.
        ranked = np.argsort(row)[::-1][:top_k]
        suggestions = []
        for idx in ranked:
            score = float(row[idx])
            if score > 0.01:
                suggestions.append({"category_id": class_ids[idx], "confidence": round(score, 4)})
        results.append(suggestions)
    return results
def suggest_for_project_columns(db: Session, project_id: int, column_ids: Optional[List[int]] = None, top_k: int = 3):
    """Generate ML category suggestions for the columns of a project.

    Args:
        db: Database session.
        project_id: Target classification project id.
        column_ids: Optional subset of column ids to restrict prediction to.
        top_k: Maximum number of suggestions per column.

    Returns:
        ``{"success": False, "message": ...}`` when the project does not
        exist; otherwise ``{"success": True, "suggestions": [...]}`` with one
        item per column, each suggestion enriched with the category's name
        and code.
    """
    # Local imports kept as in the original — presumably to avoid a circular
    # import between model modules; confirm before hoisting to module level.
    from app.models.project import ClassificationProject
    from app.models.metadata import DataColumn

    project = db.query(ClassificationProject).filter(ClassificationProject.id == project_id).first()
    if not project:
        return {"success": False, "message": "项目不存在"}

    # NOTE(review): this outer join neither filters nor selects anything and
    # can duplicate columns that have multiple results in the project —
    # confirm the intent (it may have been meant to exclude already-labelled
    # columns).
    query = db.query(DataColumn).join(
        ClassificationResult,
        (ClassificationResult.column_id == DataColumn.id) & (ClassificationResult.project_id == project_id),
        isouter=True,
    )
    if column_ids:
        query = query.filter(DataColumn.id.in_(column_ids))

    columns = query.all()
    texts = []
    col_map = []
    for col in columns:
        texts.append(_build_text_features(col.name, col.comment, col.sample_data))
        col_map.append(col)

    if not texts:
        return {"success": True, "suggestions": [], "message": "没有可预测的字段"}

    predictions = predict_categories(db, texts, top_k=top_k)

    # Batch-load every referenced category in one query. (Fixed: the
    # original collected ids by iterating zip(col_map, predictions) and
    # never used the column half of the pair.)
    all_category_ids = {p["category_id"] for preds in predictions for p in preds}
    categories = {c.id: c for c in db.query(Category).filter(Category.id.in_(list(all_category_ids))).all()}

    suggestions = []
    for col, preds in zip(col_map, predictions):
        item = {
            "column_id": col.id,
            "column_name": col.name,
            "table_name": col.table.name if col.table else None,
            "suggestions": [],
        }
        for p in preds:
            cat = categories.get(p["category_id"])
            item["suggestions"].append({
                "category_id": p["category_id"],
                "category_name": cat.name if cat else None,
                "category_code": cat.code if cat else None,
                "confidence": p["confidence"],
            })
        suggestions.append(item)

    return {"success": True, "suggestions": suggestions}