feat: 全量功能模块开发与集成测试修复
- 新增后端模块:Alert、APIAsset、Compliance、Lineage、Masking、Risk、SchemaChange、Unstructured、Watermark - 新增前端模块页面与API接口 - 新增Alembic迁移脚本(002-014)覆盖全量业务表 - 新增测试数据生成脚本与集成测试脚本 - 修复metadata模型JSON类型导入缺失导致启动失败的问题 - 修复前端Alert/APIAsset页面request模块路径错误 - 更新docker-compose与开发计划文档
This commit is contained in:
@@ -0,0 +1,195 @@
|
||||
import json
import logging
import os
from datetime import datetime, timezone
from typing import List, Optional, Tuple

import joblib
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sqlalchemy.orm import Session

from app.models.classification import Category
from app.models.ml import MLModelVersion
from app.models.project import ClassificationResult
|
||||
|
||||
logger = logging.getLogger(__name__)

# Directory for persisted model artifacts (classifier + vectorizer pairs),
# i.e. "<package root>/ml_models". Created eagerly at import time so later
# joblib.dump calls never fail on a missing directory.
MODELS_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "ml_models")
os.makedirs(MODELS_DIR, exist_ok=True)
|
||||
|
||||
|
||||
def _build_text_features(column_name: str, comment: Optional[str], sample_data: Optional[str]) -> str:
|
||||
parts = [column_name]
|
||||
if comment:
|
||||
parts.append(comment)
|
||||
if sample_data:
|
||||
try:
|
||||
samples = json.loads(sample_data)
|
||||
if isinstance(samples, list):
|
||||
parts.extend([str(s) for s in samples[:5]])
|
||||
except Exception:
|
||||
parts.append(sample_data)
|
||||
return " ".join(parts)
|
||||
|
||||
|
||||
def _fetch_training_data(db: Session, min_samples_per_class: int = 5):
    """Collect manually-labelled training texts, dropping under-represented classes.

    Only results with source == "manual" and a non-null category are used, and
    only classes with at least ``min_samples_per_class`` examples survive.

    Returns:
        (texts, labels, sample_count) for the filtered data.
    """
    rows = (
        db.query(ClassificationResult)
        .filter(ClassificationResult.source == "manual")
        .filter(ClassificationResult.category_id.isnot(None))
        .all()
    )

    texts: List[str] = []
    labels: List[int] = []
    for row in rows:
        # Results whose column was deleted carry no text features; skip them.
        if not row.column:
            continue
        texts.append(_build_text_features(row.column.name, row.column.comment, row.column.sample_data))
        labels.append(row.category_id)

    from collections import Counter

    class_counts = Counter(labels)
    keep = {cat for cat, count in class_counts.items() if count >= min_samples_per_class}

    kept = [(text, label) for text, label in zip(texts, labels) if label in keep]
    kept_texts = [text for text, _ in kept]
    kept_labels = [label for _, label in kept]
    return kept_texts, kept_labels, len(kept_labels)
|
||||
|
||||
|
||||
def train_model(db: Session, model_name: Optional[str] = None, algorithm: str = "logistic_regression", test_size: float = 0.2):
    """Train a text classifier on manually-labelled columns and persist it as the active model.

    Args:
        db: Active database session.
        model_name: Optional name for the saved artifacts; defaults to a
            timestamped name.
        algorithm: "logistic_regression" (default) or "random_forest"; unknown
            values fall back to logistic regression.
        test_size: Fraction of samples held out for the accuracy estimate.

    Returns:
        The newly created (and activated) MLModelVersion, or None when fewer
        than 20 usable training samples exist.
    """
    texts, labels, total = _fetch_training_data(db)
    if total < 20:
        logger.warning("Not enough training data (need >= 20, got %d)", total)
        return None

    # Stratified split keeps per-class ratios; _fetch_training_data guarantees
    # at least 5 samples per class, so stratification cannot fail here.
    X_train, X_test, y_train, y_test = train_test_split(
        texts, labels, test_size=test_size, random_state=42, stratify=labels
    )

    # Character n-grams are robust to mixed-language column names and comments.
    vectorizer = TfidfVectorizer(analyzer="char_wb", ngram_range=(2, 4), max_features=5000)
    X_train_vec = vectorizer.fit_transform(X_train)
    X_test_vec = vectorizer.transform(X_test)

    if algorithm == "random_forest":
        clf = RandomForestClassifier(n_estimators=100, random_state=42, n_jobs=-1)
    else:
        # Default, and fallback for unknown algorithm values. The lbfgs solver
        # already uses a multinomial loss for multiclass targets, so the
        # deprecated multi_class="multinomial" argument is omitted on purpose.
        clf = LogisticRegression(max_iter=1000, solver="lbfgs")

    clf.fit(X_train_vec, y_train)
    y_pred = clf.predict(X_test_vec)
    acc = accuracy_score(y_test, y_pred)

    # Persist classifier and vectorizer side by side so they can be reloaded together.
    timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
    name = model_name or f"model_{timestamp}"
    model_path = os.path.join(MODELS_DIR, f"{name}_clf.joblib")
    vec_path = os.path.join(MODELS_DIR, f"{name}_tfidf.joblib")
    joblib.dump(clf, model_path)
    joblib.dump(vectorizer, vec_path)

    # Exactly one model version may be active at a time: deactivate the rest.
    db.query(MLModelVersion).filter(MLModelVersion.is_active.is_(True)).update({"is_active": False})
    mv = MLModelVersion(
        name=name,
        model_path=model_path,
        vectorizer_path=vec_path,
        accuracy=acc,
        train_samples=total,
        is_active=True,
        description=f"Algorithm: {algorithm}, test_accuracy: {acc:.4f}",
    )
    db.add(mv)
    db.commit()
    db.refresh(mv)
    logger.info("Trained model %s with accuracy %.4f on %d samples", name, acc, total)
    return mv
|
||||
|
||||
|
||||
def _get_active_model(db: Session):
    """Load the active model version's classifier and vectorizer from disk.

    Returns:
        (classifier, vectorizer, model_version), or None when no active version
        exists or either of its artifact files is missing.
    """
    active = db.query(MLModelVersion).filter(MLModelVersion.is_active == True).first()  # noqa: E712
    if active is None:
        return None
    # Artifacts may have been deleted from disk after the DB row was written.
    if not (os.path.exists(active.model_path) and os.path.exists(active.vectorizer_path)):
        return None
    classifier = joblib.load(active.model_path)
    tfidf = joblib.load(active.vectorizer_path)
    return classifier, tfidf, active
|
||||
|
||||
|
||||
def predict_categories(db: Session, texts: List[str], top_k: int = 3):
    """Return up to ``top_k`` category suggestions per input text.

    Each suggestion is a dict {"category_id": int, "confidence": float}. When
    no usable model is active, every text gets an empty suggestion list.
    """
    loaded = _get_active_model(db)
    if not loaded:
        return [[] for _ in texts]
    clf, vectorizer, _version = loaded

    features = vectorizer.transform(texts)
    if not hasattr(clf, "predict_proba"):
        # No probability support: return the single hard prediction at full confidence.
        return [[{"category_id": int(label), "confidence": 1.0}] for label in clf.predict(features)]

    probs = clf.predict_proba(features)
    class_ids = [int(c) for c in clf.classes_]
    results = []
    for row in probs:
        # Highest-probability classes first; near-zero scores are dropped.
        ranked = np.argsort(row)[::-1][:top_k]
        suggestions = [
            {"category_id": class_ids[idx], "confidence": round(float(row[idx]), 4)}
            for idx in ranked
            if float(row[idx]) > 0.01
        ]
        results.append(suggestions)
    return results
|
||||
|
||||
|
||||
def suggest_for_project_columns(db: Session, project_id: int, column_ids: Optional[List[int]] = None, top_k: int = 3):
    """Predict category suggestions for a project's columns using the active ML model.

    Args:
        db: Active database session.
        project_id: Classification project to produce suggestions for.
        column_ids: Optional subset of column ids to restrict prediction to.
        top_k: Maximum number of suggestions per column.

    Returns:
        {"success": bool, "suggestions": [...]} with an extra "message" key on
        failure or when there is nothing to predict.
    """
    # Local imports mirror the original layout — presumably to avoid an import
    # cycle between the model modules; TODO confirm before hoisting to module level.
    from app.models.project import ClassificationProject
    from app.models.metadata import DataColumn

    project = db.query(ClassificationProject).filter(ClassificationProject.id == project_id).first()
    if not project:
        return {"success": False, "message": "项目不存在"}

    # The outer join can fan out one row per existing ClassificationResult for a
    # column, which previously produced duplicate suggestion entries; distinct()
    # collapses each column back to a single row.
    query = (
        db.query(DataColumn)
        .join(
            ClassificationResult,
            (ClassificationResult.column_id == DataColumn.id) & (ClassificationResult.project_id == project_id),
            isouter=True,
        )
        .distinct()
    )
    if column_ids:
        query = query.filter(DataColumn.id.in_(column_ids))

    columns = query.all()
    texts = [_build_text_features(col.name, col.comment, col.sample_data) for col in columns]

    if not texts:
        return {"success": True, "suggestions": [], "message": "没有可预测的字段"}

    predictions = predict_categories(db, texts, top_k=top_k)

    # Resolve every referenced category in a single query; skip the query
    # entirely when the model produced no suggestions at all.
    all_category_ids = {p["category_id"] for preds in predictions for p in preds}
    categories = {}
    if all_category_ids:
        categories = {c.id: c for c in db.query(Category).filter(Category.id.in_(list(all_category_ids))).all()}

    suggestions = []
    for col, preds in zip(columns, predictions):
        item = {
            "column_id": col.id,
            "column_name": col.name,
            "table_name": col.table.name if col.table else None,
            "suggestions": [],
        }
        for p in preds:
            cat = categories.get(p["category_id"])
            item["suggestions"].append({
                "category_id": p["category_id"],
                "category_name": cat.name if cat else None,
                "category_code": cat.code if cat else None,
                "confidence": p["confidence"],
            })
        suggestions.append(item)

    return {"success": True, "suggestions": suggestions}
|
||||
Reference in New Issue
Block a user