import json
import logging
import os
from collections import Counter
from datetime import datetime
from typing import List, Optional, Tuple

import joblib
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sqlalchemy.orm import Session

from app.models.classification import Category
from app.models.ml import MLModelVersion
from app.models.project import ClassificationResult

logger = logging.getLogger(__name__)

MODELS_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "ml_models")
os.makedirs(MODELS_DIR, exist_ok=True)


def _build_text_features(column_name: str, comment: Optional[str], sample_data: Optional[str]) -> str:
    """Concatenate a column's name, comment, and up to five sample values into one text blob."""
    parts = [column_name]
    if comment:
        parts.append(comment)
    if sample_data:
        try:
            samples = json.loads(sample_data)
            if isinstance(samples, list):
                parts.extend(str(s) for s in samples[:5])
        except (ValueError, TypeError):
            # Not valid JSON: fall back to the raw string.
            parts.append(sample_data)
    return " ".join(parts)


def _fetch_training_data(db: Session, min_samples_per_class: int = 5) -> Tuple[List[str], List[int], int]:
    """Collect manually labeled columns, dropping classes with too few examples."""
    results = (
        db.query(ClassificationResult)
        .filter(ClassificationResult.source == "manual")
        .filter(ClassificationResult.category_id.isnot(None))
        .all()
    )
    texts: List[str] = []
    labels: List[int] = []
    for r in results:
        if r.column:
            texts.append(_build_text_features(r.column.name, r.column.comment, r.column.sample_data))
            labels.append(r.category_id)

    counts = Counter(labels)
    valid_classes = {c for c, n in counts.items() if n >= min_samples_per_class}
    filtered_texts = []
    filtered_labels = []
    for text, label in zip(texts, labels):
        if label in valid_classes:
            filtered_texts.append(text)
            filtered_labels.append(label)
    return filtered_texts, filtered_labels, len(filtered_labels)


def train_model(
    db: Session,
    model_name: Optional[str] = None,
    algorithm: str = "logistic_regression",
    test_size: float = 0.2,
) -> Optional[MLModelVersion]:
    """Train a classifier on manually labeled columns and register it as the active model."""
    texts, labels, total = _fetch_training_data(db)
    if total < 20:
        logger.warning("Not enough training data (need >= 20, got %d)", total)
        return None

    X_train, X_test, y_train, y_test = train_test_split(
        texts, labels, test_size=test_size, random_state=42, stratify=labels
    )

    # Character n-grams within word boundaries are robust to the naming
    # conventions (snake_case, abbreviations) common in column names.
    vectorizer = TfidfVectorizer(analyzer="char_wb", ngram_range=(2, 4), max_features=5000)
    X_train_vec = vectorizer.fit_transform(X_train)
    X_test_vec = vectorizer.transform(X_test)

    if algorithm == "random_forest":
        clf = RandomForestClassifier(n_estimators=100, random_state=42, n_jobs=-1)
    else:
        # Default (and fallback for unknown algorithm names): logistic regression.
        # The lbfgs solver uses a multinomial loss for multiclass targets.
        clf = LogisticRegression(max_iter=1000, solver="lbfgs")

    clf.fit(X_train_vec, y_train)
    acc = accuracy_score(y_test, clf.predict(X_test_vec))

    timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
    name = model_name or f"model_{timestamp}"
    model_path = os.path.join(MODELS_DIR, f"{name}_clf.joblib")
    vec_path = os.path.join(MODELS_DIR, f"{name}_tfidf.joblib")
    joblib.dump(clf, model_path)
    joblib.dump(vectorizer, vec_path)

    # Deactivate any previously active model before registering the new one.
    db.query(MLModelVersion).filter(MLModelVersion.is_active.is_(True)).update({"is_active": False})
    mv = MLModelVersion(
        name=name,
        model_path=model_path,
        vectorizer_path=vec_path,
        accuracy=acc,
        train_samples=total,
        is_active=True,
        description=f"Algorithm: {algorithm}, test_accuracy: {acc:.4f}",
    )
    db.add(mv)
    db.commit()
    db.refresh(mv)
    logger.info("Trained model %s with accuracy %.4f on %d samples", name, acc, total)
    return mv


def _get_active_model(db: Session):
    """Load the active model and vectorizer from disk; return None if unavailable."""
    mv = db.query(MLModelVersion).filter(MLModelVersion.is_active.is_(True)).first()
    if not mv or not os.path.exists(mv.model_path) or not os.path.exists(mv.vectorizer_path):
        return None
    clf = joblib.load(mv.model_path)
    vectorizer = joblib.load(mv.vectorizer_path)
    return clf, vectorizer, mv


def predict_categories(db: Session, texts: List[str], top_k: int = 3) -> List[List[dict]]:
    """Return up to top_k category suggestions (with confidences) for each text."""
    model_tuple = _get_active_model(db)
    if not model_tuple:
        return [[] for _ in texts]
    clf, vectorizer, _mv = model_tuple
    X = vectorizer.transform(texts)

    if not hasattr(clf, "predict_proba"):
        # No probability estimates available: return the single hard prediction.
        preds = clf.predict(X)
        return [[{"category_id": int(p), "confidence": 1.0}] for p in preds]

    probs = clf.predict_proba(X)
    classes = [int(c) for c in clf.classes_]
    results = []
    for prob in probs:
        top_idx = np.argsort(prob)[::-1][:top_k]
        suggestions = []
        for idx in top_idx:
            confidence = float(prob[idx])
            if confidence > 0.01:
                suggestions.append({"category_id": classes[idx], "confidence": round(confidence, 4)})
        results.append(suggestions)
    return results


def suggest_for_project_columns(
    db: Session,
    project_id: int,
    column_ids: Optional[List[int]] = None,
    top_k: int = 3,
) -> dict:
    """Predict category suggestions for the columns of a project."""
    # Imported here to avoid a circular import at module load time.
    from app.models.metadata import DataColumn
    from app.models.project import ClassificationProject

    project = db.query(ClassificationProject).filter(ClassificationProject.id == project_id).first()
    if not project:
        return {"success": False, "message": "Project not found"}

    query = db.query(DataColumn).join(
        ClassificationResult,
        (ClassificationResult.column_id == DataColumn.id)
        & (ClassificationResult.project_id == project_id),
        isouter=True,
    )
    if column_ids:
        query = query.filter(DataColumn.id.in_(column_ids))
    columns = query.all()

    texts = []
    col_map = []
    for col in columns:
        texts.append(_build_text_features(col.name, col.comment, col.sample_data))
        col_map.append(col)
    if not texts:
        return {"success": True, "suggestions": [], "message": "No columns available for prediction"}

    predictions = predict_categories(db, texts, top_k=top_k)

    # Resolve all referenced category ids in a single query.
    all_category_ids = {p["category_id"] for preds in predictions for p in preds}
    categories = {
        c.id: c
        for c in db.query(Category).filter(Category.id.in_(list(all_category_ids))).all()
    }

    suggestions = []
    for col, preds in zip(col_map, predictions):
        item = {
            "column_id": col.id,
            "column_name": col.name,
            "table_name": col.table.name if col.table else None,
            "suggestions": [],
        }
        for p in preds:
            cat = categories.get(p["category_id"])
            item["suggestions"].append(
                {
                    "category_id": p["category_id"],
                    "category_name": cat.name if cat else None,
                    "category_code": cat.code if cat else None,
                    "confidence": p["confidence"],
                }
            )
        suggestions.append(item)
    return {"success": True, "suggestions": suggestions}