feat: initial commit - Phase 1 & 2 core features

This commit is contained in:
hiderfong
2026-04-22 17:07:33 +08:00
commit 1773bda06b
25005 changed files with 6252106 additions and 0 deletions
View File
View File
+49
View File
@@ -0,0 +1,49 @@
from fastapi import Depends, HTTPException, status, Request
from sqlalchemy.orm import Session
from jose import JWTError
from app.core.database import get_db
from app.core.security import decode_token
from app.models.user import User
from app.services import user_service
def get_token_from_request(request: Request) -> str:
    """Extract a bearer token from the Authorization header, falling back to
    the ``token`` query parameter; raise 401 when neither is present."""
    header_value = request.headers.get("Authorization", "")
    if header_value.startswith("Bearer "):
        return header_value[len("Bearer "):]
    # NOTE(review): tokens in query strings leak into access logs and
    # referrers; kept for compatibility — confirm the fallback is required.
    query_token = request.query_params.get("token")
    if not query_token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="缺少认证令牌",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return query_token
def get_current_user(
    request: Request,
    db: Session = Depends(get_db),
) -> User:
    """Resolve the authenticated user from the request's bearer token.

    Raises:
        HTTPException 401: missing/invalid/expired token, a non-access token,
            or a token whose ``sub`` claim is absent or not an integer.
        HTTPException 404: token is valid but the user no longer exists.
        HTTPException 403: the user account is disabled.
    """
    token = get_token_from_request(request)
    payload = decode_token(token)
    # Fix: also require type == "access". Access and refresh tokens are signed
    # with the same key (see create_token_pair), so without this check a
    # long-lived refresh token could be used directly as an access token.
    if not payload or payload.get("type") != "access":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="无效的认证令牌",
            headers={"WWW-Authenticate": "Bearer"},
        )
    # Fix: int(payload.get("sub")) raised TypeError/ValueError (HTTP 500) on a
    # missing or malformed subject; reject such tokens with 401 instead.
    try:
        user_id = int(payload.get("sub"))
    except (TypeError, ValueError):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="无效的认证令牌",
            headers={"WWW-Authenticate": "Bearer"},
        )
    user = user_service.get_user_by_id(db, user_id)
    if not user:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="用户不存在")
    if not user.is_active:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="用户已被禁用")
    return user
def get_current_active_user(
    current_user: User = Depends(get_current_user),
) -> User:
    """Alias dependency for an authenticated, active user.

    ``get_current_user`` already rejects inactive accounts (403), so this
    simply forwards the resolved user; it exists so routes can express the
    "active user required" intent explicitly.
    """
    return current_user
+12
View File
@@ -0,0 +1,12 @@
from fastapi import APIRouter
from app.api.v1 import auth, user, datasource, metadata, classification, project, task
# Aggregate v1 router: each feature router is mounted under its own prefix;
# tags group the endpoints in the generated OpenAPI docs.
api_router = APIRouter()
api_router.include_router(auth.router, prefix="/auth", tags=["认证"])
api_router.include_router(user.router, prefix="/users", tags=["用户管理"])
api_router.include_router(datasource.router, prefix="/datasources", tags=["数据源管理"])
api_router.include_router(metadata.router, prefix="/metadata", tags=["元数据管理"])
api_router.include_router(classification.router, prefix="/classifications", tags=["分类分级标准"])
api_router.include_router(project.router, prefix="/projects", tags=["项目管理"])
api_router.include_router(task.router, prefix="/tasks", tags=["任务管理"])
+35
View File
@@ -0,0 +1,35 @@
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.core.security import decode_token, create_token_pair
from app.schemas.auth import LoginRequest, Token, TokenRefresh
from app.schemas.common import ResponseModel
from app.services.auth_service import login
router = APIRouter()
@router.post("/login", response_model=ResponseModel[Token])
def api_login(req: LoginRequest, db: Session = Depends(get_db)):
token_data = login(db, req.username, req.password)
return ResponseModel(data=Token(**token_data))
@router.post("/refresh", response_model=ResponseModel[Token])
def api_refresh(req: TokenRefresh):
payload = decode_token(req.refresh_token)
if not payload or payload.get("type") != "refresh":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="无效的刷新令牌",
)
user_id = int(payload.get("sub"))
username = payload.get("username")
access_token, refresh_token = create_token_pair(user_id, username)
return ResponseModel(data=Token(
access_token=access_token,
refresh_token=refresh_token,
token_type="bearer",
expires_in=30 * 60,
))
+161
View File
@@ -0,0 +1,161 @@
from typing import Optional
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.models.user import User
from app.schemas.classification import (
CategoryCreate, CategoryUpdate, CategoryOut, CategoryTree,
DataLevelOut, RecognitionRuleCreate, RecognitionRuleUpdate, RecognitionRuleOut,
TemplateOut,
)
from app.schemas.common import ResponseModel, ListResponse
from app.services import classification_service, classification_engine
from app.api.deps import get_current_user
router = APIRouter()
@router.get("/categories/tree", response_model=ResponseModel[list])
def get_category_tree(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
tree = classification_service.build_category_tree(db)
return ResponseModel(data=tree)
@router.get("/categories", response_model=ResponseModel[list[CategoryOut]])
def list_categories(
parent_id: Optional[int] = Query(None),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
items = classification_service.list_categories(db, parent_id=parent_id)
return ResponseModel(data=[CategoryOut.model_validate(i) for i in items])
@router.post("/categories", response_model=ResponseModel[CategoryOut])
def create_category(
req: CategoryCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
item = classification_service.create_category(db, req)
return ResponseModel(data=CategoryOut.model_validate(item))
@router.put("/categories/{category_id}", response_model=ResponseModel[CategoryOut])
def update_category(
category_id: int,
req: CategoryUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
db_obj = classification_service.get_category(db, category_id)
if not db_obj:
from fastapi import HTTPException, status
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="分类不存在")
item = classification_service.update_category(db, db_obj, req)
return ResponseModel(data=CategoryOut.model_validate(item))
@router.delete("/categories/{category_id}")
def delete_category(
category_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
classification_service.delete_category(db, category_id)
return ResponseModel(message="删除成功")
@router.get("/levels", response_model=ResponseModel[list[DataLevelOut]])
def list_levels(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
items = classification_service.list_data_levels(db)
return ResponseModel(data=[DataLevelOut.model_validate(i) for i in items])
@router.get("/rules", response_model=ListResponse[RecognitionRuleOut])
def list_rules(
template_id: Optional[int] = Query(None),
keyword: Optional[str] = Query(None),
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=500),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
items, total = classification_service.list_rules(db, template_id=template_id, keyword=keyword, page=page, page_size=page_size)
out = []
for i in items:
data = RecognitionRuleOut.model_validate(i)
data.category_name = i.category.name if i.category else None
data.level_name = i.level.name if i.level else None
data.level_color = i.level.color if i.level else None
out.append(data)
return ListResponse(data=out, total=total, page=page, page_size=page_size)
@router.post("/rules", response_model=ResponseModel[RecognitionRuleOut])
def create_rule(
req: RecognitionRuleCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
item = classification_service.create_rule(db, req)
data = RecognitionRuleOut.model_validate(item)
data.category_name = item.category.name if item.category else None
data.level_name = item.level.name if item.level else None
data.level_color = item.level.color if item.level else None
return ResponseModel(data=data)
@router.put("/rules/{rule_id}", response_model=ResponseModel[RecognitionRuleOut])
def update_rule(
rule_id: int,
req: RecognitionRuleUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
db_obj = classification_service.get_rule(db, rule_id)
if not db_obj:
from fastapi import HTTPException, status
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="规则不存在")
item = classification_service.update_rule(db, db_obj, req)
data = RecognitionRuleOut.model_validate(item)
data.category_name = item.category.name if item.category else None
data.level_name = item.level.name if item.level else None
data.level_color = item.level.color if item.level else None
return ResponseModel(data=data)
@router.delete("/rules/{rule_id}")
def delete_rule(
rule_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
classification_service.delete_rule(db, rule_id)
return ResponseModel(message="删除成功")
@router.get("/templates", response_model=ResponseModel[list[TemplateOut]])
def list_templates(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
items = classification_service.list_templates(db)
return ResponseModel(data=[TemplateOut.model_validate(i) for i in items])
@router.post("/auto-classify/{project_id}")
def auto_classify(
project_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
result = classification_engine.run_auto_classification(db, project_id)
return ResponseModel(data=result)
+81
View File
@@ -0,0 +1,81 @@
from typing import Optional
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.models.user import User
from app.schemas.datasource import DataSourceCreate, DataSourceUpdate, DataSourceOut, DataSourceTest
from app.schemas.common import ResponseModel, ListResponse
from app.services import datasource_service
from app.api.deps import get_current_user
router = APIRouter()
@router.get("", response_model=ListResponse[DataSourceOut])
def list_datasources(
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=500),
keyword: Optional[str] = Query(None),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
items, total = datasource_service.list_datasources(db, keyword=keyword, page=page, page_size=page_size)
return ListResponse(data=[DataSourceOut.model_validate(i) for i in items], total=total, page=page, page_size=page_size)
@router.get("/{source_id}", response_model=ResponseModel[DataSourceOut])
def get_datasource(
source_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
item = datasource_service.get_datasource(db, source_id)
if not item:
from fastapi import HTTPException, status
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="数据源不存在")
return ResponseModel(data=DataSourceOut.model_validate(item))
@router.post("", response_model=ResponseModel[DataSourceOut])
def create_datasource(
req: DataSourceCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
item = datasource_service.create_datasource(db, req, current_user.id)
return ResponseModel(data=DataSourceOut.model_validate(item))
@router.put("/{source_id}", response_model=ResponseModel[DataSourceOut])
def update_datasource(
source_id: int,
req: DataSourceUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
db_obj = datasource_service.get_datasource(db, source_id)
if not db_obj:
from fastapi import HTTPException, status
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="数据源不存在")
item = datasource_service.update_datasource(db, db_obj, req)
return ResponseModel(data=DataSourceOut.model_validate(item))
@router.delete("/{source_id}")
def delete_datasource(
source_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
datasource_service.delete_datasource(db, source_id)
return ResponseModel(message="删除成功")
@router.post("/test-connection")
def test_connection(
req: DataSourceTest,
current_user: User = Depends(get_current_user),
):
result = datasource_service.test_connection(req)
return ResponseModel(data=result)
+66
View File
@@ -0,0 +1,66 @@
from typing import Optional
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.models.user import User
from app.schemas.metadata import DatabaseOut, DataTableOut, DataColumnOut
from app.schemas.common import ResponseModel, ListResponse
from app.services import metadata_service
from app.api.deps import get_current_user
router = APIRouter()
@router.get("/tree")
def get_tree(
source_id: Optional[int] = Query(None),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
tree = metadata_service.build_tree(db, source_id=source_id)
return ResponseModel(data=tree)
@router.get("/databases")
def list_databases(
source_id: Optional[int] = Query(None),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
items = metadata_service.list_databases(db, source_id=source_id)
return ResponseModel(data=[DatabaseOut.model_validate(i) for i in items])
@router.get("/tables")
def list_tables(
database_id: Optional[int] = Query(None),
keyword: Optional[str] = Query(None),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
items, total = metadata_service.list_tables(db, database_id=database_id, keyword=keyword)
return ListResponse(data=[DataTableOut.model_validate(i) for i in items], total=total, page=1, page_size=len(items))
@router.get("/columns")
def list_columns(
table_id: Optional[int] = Query(None),
keyword: Optional[str] = Query(None),
page: int = Query(1, ge=1),
page_size: int = Query(50, ge=1, le=500),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
items, total = metadata_service.list_columns(db, table_id=table_id, keyword=keyword, page=page, page_size=page_size)
return ListResponse(data=[DataColumnOut.model_validate(i) for i in items], total=total, page=page, page_size=page_size)
@router.post("/sync/{source_id}")
def sync_metadata(
source_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
result = metadata_service.sync_metadata(db, source_id, current_user.id)
return ResponseModel(data=result)
+100
View File
@@ -0,0 +1,100 @@
from typing import Optional
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.models.user import User
from app.schemas.common import ResponseModel, ListResponse
from app.services import project_service
from app.api.deps import get_current_user
router = APIRouter()
@router.get("")
def list_projects(
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=500),
keyword: Optional[str] = Query(None),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
items, total = project_service.list_projects(db, keyword=keyword, page=page, page_size=page_size)
data = []
for p in items:
stats = project_service.get_project_stats(db, p.id)
data.append({
"id": p.id,
"name": p.name,
"template_id": p.template_id,
"status": p.status,
"planned_start": p.planned_start.isoformat() if p.planned_start else None,
"planned_end": p.planned_end.isoformat() if p.planned_end else None,
"created_at": p.created_at.isoformat() if p.created_at else None,
"stats": stats,
})
return ListResponse(data=data, total=total, page=page, page_size=page_size)
@router.post("")
def create_project(
name: str,
template_id: int,
target_source_ids: Optional[str] = None,
description: Optional[str] = None,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
item = project_service.create_project(
db, name=name, template_id=template_id,
created_by=current_user.id,
target_source_ids=target_source_ids,
description=description,
)
return ResponseModel(data={"id": item.id, "name": item.name})
@router.get("/{project_id}")
def get_project(
project_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
p = project_service.get_project(db, project_id)
if not p:
from fastapi import HTTPException, status
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="项目不存在")
stats = project_service.get_project_stats(db, p.id)
return ResponseModel(data={
"id": p.id,
"name": p.name,
"template_id": p.template_id,
"status": p.status,
"description": p.description,
"target_source_ids": p.target_source_ids,
"planned_start": p.planned_start.isoformat() if p.planned_start else None,
"planned_end": p.planned_end.isoformat() if p.planned_end else None,
"created_at": p.created_at.isoformat() if p.created_at else None,
"stats": stats,
})
@router.delete("/{project_id}")
def delete_project(
project_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
project_service.delete_project(db, project_id)
return ResponseModel(message="删除成功")
@router.post("/{project_id}/auto-classify")
def project_auto_classify(
project_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
from app.services.classification_engine import run_auto_classification
result = run_auto_classification(db, project_id)
return ResponseModel(data=result)
+80
View File
@@ -0,0 +1,80 @@
from typing import Optional
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.models.user import User
from app.schemas.common import ResponseModel, ListResponse
from app.api.deps import get_current_user
router = APIRouter()
@router.get("/my-tasks")
def my_tasks(
status: Optional[str] = Query(None),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
from app.models.project import ClassificationTask
query = db.query(ClassificationTask).filter(ClassificationTask.assignee_id == current_user.id)
if status:
query = query.filter(ClassificationTask.status == status)
items = query.order_by(ClassificationTask.created_at.desc()).all()
data = []
for t in items:
data.append({
"id": t.id,
"name": t.name,
"project_id": t.project_id,
"status": t.status,
"deadline": t.deadline.isoformat() if t.deadline else None,
"created_at": t.created_at.isoformat() if t.created_at else None,
})
return ResponseModel(data=data)
@router.get("/my-tasks/{task_id}/items")
def task_items(
task_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
from app.models.project import ClassificationTask, ClassificationResult
from app.models.metadata import DataColumn, DataTable, Database as MetaDatabase, DataSource
from app.models.classification import Category, DataLevel
task = db.query(ClassificationTask).filter(ClassificationTask.id == task_id).first()
if not task:
from fastapi import HTTPException, status
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="任务不存在")
results = db.query(ClassificationResult).filter(
ClassificationResult.project_id == task.project_id,
).join(DataColumn).all()
data = []
for r in results:
col = r.column
table = col.table if col else None
database = table.database if table else None
source = database.source if database else None
data.append({
"result_id": r.id,
"column_id": col.id if col else None,
"column_name": col.name if col else None,
"data_type": col.data_type if col else None,
"comment": col.comment if col else None,
"table_name": table.name if table else None,
"database_name": database.name if database else None,
"source_name": source.name if source else None,
"category_id": r.category_id,
"category_name": r.category.name if r.category else None,
"level_id": r.level_id,
"level_name": r.level.name if r.level else None,
"level_color": r.level.color if r.level else None,
"source": r.source,
"confidence": r.confidence,
"status": r.status,
})
return ResponseModel(data=data)
+64
View File
@@ -0,0 +1,64 @@
from typing import Optional, List
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.models.user import User
from app.schemas.user import UserCreate, UserUpdate, UserOut
from app.schemas.common import ResponseModel, ListResponse, PageParams
from app.services import user_service
from app.api.deps import get_current_user
router = APIRouter()
@router.get("/me", response_model=ResponseModel[UserOut])
def read_me(current_user: User = Depends(get_current_user)):
return ResponseModel(data=UserOut.model_validate(current_user))
@router.get("", response_model=ListResponse[UserOut])
def list_users(
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=500),
keyword: Optional[str] = Query(None),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
items, total = user_service.list_users(db, keyword=keyword, page=page, page_size=page_size)
return ListResponse(data=[UserOut.model_validate(u) for u in items], total=total, page=page, page_size=page_size)
@router.post("", response_model=ResponseModel[UserOut])
def create_user(
req: UserCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
user = user_service.create_user(db, req)
return ResponseModel(data=UserOut.model_validate(user))
@router.put("/{user_id}", response_model=ResponseModel[UserOut])
def update_user(
user_id: int,
req: UserUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
user = user_service.get_user_by_id(db, user_id)
if not user:
from fastapi import HTTPException, status
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="用户不存在")
user = user_service.update_user(db, user, req)
return ResponseModel(data=UserOut.model_validate(user))
@router.delete("/{user_id}")
def delete_user(
user_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
user_service.delete_user(db, user_id)
return ResponseModel(message="删除成功")
View File
+35
View File
@@ -0,0 +1,35 @@
from pydantic_settings import BaseSettings
from typing import List
class Settings(BaseSettings):
    """Central application configuration.

    Every field can be overridden through the environment or a ``.env`` file
    (see ``Config`` below); the literals here are development defaults.
    """
    PROJECT_NAME: str = "PropDataGuard"
    VERSION: str = "0.1.0"
    DESCRIPTION: str = "财产保险行业数据分级分类管理平台"
    # SECURITY(review): the DB password, SECRET_KEY, MinIO keys and superuser
    # password below are committed defaults — they MUST be overridden via the
    # environment before any non-local deployment.
    DATABASE_URL: str = "postgresql+psycopg2://pdg:pdg_secret_2024@localhost:5432/prop_data_guard"
    REDIS_URL: str = "redis://localhost:6379/0"
    # JWT signing configuration (used by app.core.security).
    SECRET_KEY: str = "prop-data-guard-super-secret-key-change-in-production"
    ALGORITHM: str = "HS256"
    ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
    REFRESH_TOKEN_EXPIRE_DAYS: int = 7
    # MinIO object storage (used by app.core.events).
    MINIO_ENDPOINT: str = "localhost:9000"
    MINIO_ACCESS_KEY: str = "pdgminio"
    MINIO_SECRET_KEY: str = "pdgminio_secret_2024"
    MINIO_SECURE: bool = False
    MINIO_BUCKET_NAME: str = "pdg-files"
    # Dev frontend origins allowed by the CORS middleware.
    CORS_ORIGINS: List[str] = ["http://localhost:5173", "http://127.0.0.1:5173"]
    # Seed superuser created at startup (see create_initial_data).
    FIRST_SUPERUSER_USERNAME: str = "admin"
    FIRST_SUPERUSER_PASSWORD: str = "admin123"
    FIRST_SUPERUSER_EMAIL: str = "admin@propdataguard.com"
    class Config:
        env_file = ".env"
        case_sensitive = True
# Module-level singleton imported across the application.
settings = Settings()
+17
View File
@@ -0,0 +1,17 @@
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, declarative_base
from app.core.config import settings
# pool_pre_ping validates pooled connections before use, transparently
# replacing ones the server has dropped.
engine = create_engine(settings.DATABASE_URL, pool_pre_ping=True)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# Declarative base that all ORM models inherit from.
Base = declarative_base()
def get_db():
    """FastAPI dependency: yield a request-scoped session, always closing it."""
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()
+20
View File
@@ -0,0 +1,20 @@
from app.core.config import settings
import redis
import minio
# Shared Redis client; decode_responses=True yields str instead of bytes.
redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True)
# Shared MinIO client for object storage.
minio_client = minio.Minio(
    settings.MINIO_ENDPOINT,
    access_key=settings.MINIO_ACCESS_KEY,
    secret_key=settings.MINIO_SECRET_KEY,
    secure=settings.MINIO_SECURE,
)
def init_minio_bucket():
    """Ensure the configured MinIO bucket exists (best-effort).

    Failures are only printed so that a missing or unreachable MinIO does not
    prevent the API process from starting.
    """
    try:
        if not minio_client.bucket_exists(settings.MINIO_BUCKET_NAME):
            minio_client.make_bucket(settings.MINIO_BUCKET_NAME)
    except Exception as e:
        print(f"MinIO init warning: {e}")
+50
View File
@@ -0,0 +1,50 @@
from datetime import datetime, timedelta, timezone
from typing import Optional, Tuple
from jose import jwt, JWTError
from passlib.context import CryptContext
from app.core.config import settings
# Argon2 password-hashing context shared by the helpers below.
pwd_context = CryptContext(schemes=["argon2"], deprecated="auto")


def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check *plain_password* against a stored Argon2 hash."""
    return pwd_context.verify(plain_password, hashed_password)


def get_password_hash(password: str) -> str:
    """Hash a password with the configured Argon2 context."""
    return pwd_context.hash(password)


def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Sign a short-lived access JWT; *expires_delta* overrides the default TTL."""
    claims = dict(data)
    ttl = expires_delta or timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    claims.update({"exp": datetime.now(timezone.utc) + ttl, "type": "access"})
    return jwt.encode(claims, settings.SECRET_KEY, algorithm=settings.ALGORITHM)


def create_refresh_token(data: dict) -> str:
    """Sign a long-lived refresh JWT."""
    claims = dict(data)
    claims.update({
        "exp": datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS),
        "type": "refresh",
    })
    return jwt.encode(claims, settings.SECRET_KEY, algorithm=settings.ALGORITHM)


def decode_token(token: str) -> Optional[dict]:
    """Decode and verify a JWT; return its claims, or None when invalid/expired."""
    try:
        return jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
    except JWTError:
        return None


def create_token_pair(user_id: int, username: str) -> Tuple[str, str]:
    """Return an (access, refresh) token pair for the given user identity."""
    identity = {"sub": str(user_id), "username": username}
    return create_access_token(identity), create_refresh_token(identity)
+93
View File
@@ -0,0 +1,93 @@
from fastapi import FastAPI, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
import time
import json
from app.core.config import settings
from app.core.database import engine, Base
from app.core.events import init_minio_bucket
from app.api.v1 import api_router
from app.models import log as log_models
# Create tables (dev convenience; use Alembic in production).
Base.metadata.create_all(bind=engine)
# FastAPI application; the metadata feeds the generated OpenAPI docs.
app = FastAPI(
    title=settings.PROJECT_NAME,
    version=settings.VERSION,
    description=settings.DESCRIPTION,
)
# Allow the configured dev frontends to call the API with credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.CORS_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.middleware("http")
async def log_requests(request: Request, call_next):
start_time = time.time()
response = await call_next(request)
duration = int((time.time() - start_time) * 1000)
# Skip health checks
if request.url.path in ["/docs", "/openapi.json", "/redoc"]:
return response
from app.core.database import SessionLocal
try:
db = SessionLocal()
body_bytes = b""
if request.method in ["POST", "PUT", "PATCH"]:
try:
body_bytes = await request.body()
# Re-assign body for downstream
async def receive():
return {"type": "http.request", "body": body_bytes}
request._receive = receive
except Exception:
pass
log_entry = log_models.OperationLog(
module=request.url.path.split("/")[2] if len(request.url.path.split("/")) > 2 else "",
action=request.url.path,
method=request.method,
path=str(request.url),
ip=request.client.host if request.client else None,
status_code=response.status_code,
duration_ms=duration,
)
db.add(log_entry)
db.commit()
except Exception:
pass
finally:
db.close()
return response
@app.on_event("startup")
async def startup_event():
init_minio_bucket()
from app.core.database import SessionLocal
from app.services.user_service import create_initial_data
from app.services.classification_service import init_builtin_data
db = SessionLocal()
try:
create_initial_data(db)
init_builtin_data(db)
finally:
db.close()
@app.get("/health")
def health_check():
return {"status": "ok"}
app.include_router(api_router, prefix="/api/v1")
+13
View File
@@ -0,0 +1,13 @@
from app.models.user import User, Role, Dept, UserRole
from app.models.metadata import DataSource, Database, DataTable, DataColumn, UnstructuredFile
from app.models.classification import Category, DataLevel, RecognitionRule, ClassificationTemplate
from app.models.project import ClassificationProject, ClassificationTask, ClassificationResult, ClassificationChange
from app.models.log import OperationLog
# Public re-exports: importing app.models pulls in every ORM model module,
# ensuring Base.metadata knows all tables before create_all() runs.
__all__ = [
    "User", "Role", "Dept", "UserRole",
    "DataSource", "Database", "DataTable", "DataColumn", "UnstructuredFile",
    "Category", "DataLevel", "RecognitionRule", "ClassificationTemplate",
    "ClassificationProject", "ClassificationTask", "ClassificationResult", "ClassificationChange",
    "OperationLog",
]
+68
View File
@@ -0,0 +1,68 @@
from datetime import datetime
from sqlalchemy import Column, Integer, String, Text, ForeignKey, DateTime, JSON, Float, Boolean
from sqlalchemy.orm import relationship
from app.core.database import Base
class Category(Base):
    """Node in the self-referential data-classification category tree."""
    __tablename__ = "category"
    id = Column(Integer, primary_key=True, index=True)
    parent_id = Column(Integer, ForeignKey("category.id"), nullable=True)  # NULL for roots
    level = Column(Integer, default=1)  # tree depth: 1, 2, 3
    code = Column(String(50), unique=True, nullable=False)
    name = Column(String(100), nullable=False)
    description = Column(Text)
    sort_order = Column(Integer, default=0)  # manual ordering among siblings
    # NOTE(review): datetime.utcnow stores naive UTC and is deprecated in 3.12.
    created_at = Column(DateTime, default=datetime.utcnow)
    parent = relationship("Category", remote_side=[id], backref="children")
class DataLevel(Base):
    """Sensitivity level assigned to classified data (codes L1–L5)."""
    __tablename__ = "data_level"
    id = Column(Integer, primary_key=True, index=True)
    code = Column(String(20), unique=True, nullable=False)  # L1, L2, L3, L4, L5
    name = Column(String(50), nullable=False)
    description = Column(Text)
    color = Column(String(20), default="#999999")  # UI badge color
    control_requirements = Column(JSON)  # structured control measures per level
    sort_order = Column(Integer, default=0)
class RecognitionRule(Base):
    """Rule mapping column metadata to a category/level during auto-classification."""
    __tablename__ = "recognition_rule"
    id = Column(Integer, primary_key=True, index=True)
    template_id = Column(Integer, ForeignKey("classification_template.id"), nullable=False)
    category_id = Column(Integer, ForeignKey("category.id"), nullable=True)
    level_id = Column(Integer, ForeignKey("data_level.id"), nullable=True)
    rule_type = Column(String(20), nullable=False)  # regex, keyword, enum, ml
    rule_name = Column(String(100))
    rule_content = Column(Text, nullable=False)  # regex pattern / keyword list / enum values
    target_field = Column(String(20), default="column_name")  # column_name, comment, sample_data
    priority = Column(Integer, default=100)  # ordering semantics defined by the engine — confirm
    hit_count = Column(Integer, default=0)  # number of matches recorded for this rule
    is_active = Column(Boolean, default=True)
    created_at = Column(DateTime, default=datetime.utcnow)
    template = relationship("ClassificationTemplate", back_populates="rules")
    category = relationship("Category")
    level = relationship("DataLevel")
class ClassificationTemplate(Base):
    """Named bundle of recognition rules for one industry standard version."""
    __tablename__ = "classification_template"
    id = Column(Integer, primary_key=True, index=True)
    name = Column(String(100), nullable=False)
    industry_type = Column(String(50), default="insurance_property")
    version = Column(String(20), default="1.0")
    description = Column(Text)
    is_builtin = Column(Boolean, default=False)  # built-in templates are seeded at startup
    is_active = Column(Boolean, default=True)
    created_at = Column(DateTime, default=datetime.utcnow)
    # Deleting a template removes its rules (delete-orphan cascade).
    rules = relationship("RecognitionRule", back_populates="template", cascade="all, delete-orphan")
+22
View File
@@ -0,0 +1,22 @@
from datetime import datetime
from sqlalchemy import Column, Integer, String, Text, DateTime, ForeignKey
from app.core.database import Base
class OperationLog(Base):
    """Audit row written by the request-logging middleware (one per API request)."""
    __tablename__ = "sys_operation_log"
    id = Column(Integer, primary_key=True, index=True)
    user_id = Column(Integer, ForeignKey("sys_user.id"), nullable=True)  # NULL when not resolved
    username = Column(String(50))
    module = Column(String(50))  # feature-module tag derived from the URL path
    action = Column(String(50))  # request path (see middleware)
    method = Column(String(10))  # HTTP verb
    path = Column(String(500))  # full URL including query string
    ip = Column(String(50))  # client address
    request_body = Column(Text)
    response_body = Column(Text)
    status_code = Column(Integer)
    duration_ms = Column(Integer)  # wall-clock handling time
    created_at = Column(DateTime, default=datetime.utcnow)
+86
View File
@@ -0,0 +1,86 @@
from datetime import datetime
from sqlalchemy import Column, Integer, String, Boolean, DateTime, ForeignKey, Text, BigInteger
from sqlalchemy.orm import relationship
from app.core.database import Base
class DataSource(Base):
    """Registered external database connection whose metadata gets harvested."""
    __tablename__ = "data_source"
    id = Column(Integer, primary_key=True, index=True)
    name = Column(String(100), nullable=False)
    source_type = Column(String(50), nullable=False)  # mysql, postgresql, oracle, dm, etc.
    host = Column(String(200))
    port = Column(Integer)
    database_name = Column(String(100))
    username = Column(String(100))
    # Name implies service-layer encryption — verify before relying on it.
    encrypted_password = Column(Text)
    extra_params = Column(Text)  # JSON string
    status = Column(String(20), default="active")  # active, inactive, error
    dept_id = Column(Integer, ForeignKey("sys_dept.id"), nullable=True)
    created_by = Column(Integer, ForeignKey("sys_user.id"))
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
    # Deleting a source drops its harvested databases (and transitively tables/columns).
    databases = relationship("Database", back_populates="source", cascade="all, delete-orphan")
    creator = relationship("User")
class Database(Base):
    """A database discovered inside a DataSource during metadata scanning."""
    __tablename__ = "meta_database"
    id = Column(Integer, primary_key=True, index=True)
    source_id = Column(Integer, ForeignKey("data_source.id"), nullable=False)
    name = Column(String(100), nullable=False)
    charset = Column(String(50))
    table_count = Column(Integer, default=0)  # denormalized count kept by the scanner
    created_at = Column(DateTime, default=datetime.utcnow)
    source = relationship("DataSource", back_populates="databases")
    tables = relationship("DataTable", back_populates="database", cascade="all, delete-orphan")
class DataTable(Base):
    """A table discovered inside a scanned database."""
    __tablename__ = "meta_table"
    id = Column(Integer, primary_key=True, index=True)
    database_id = Column(Integer, ForeignKey("meta_database.id"), nullable=False)
    name = Column(String(200), nullable=False)
    comment = Column(String(500))
    row_count = Column(BigInteger, default=0)  # BigInteger: tables can exceed 2^31 rows
    column_count = Column(Integer, default=0)
    created_at = Column(DateTime, default=datetime.utcnow)
    database = relationship("Database", back_populates="tables")
    columns = relationship("DataColumn", back_populates="table", cascade="all, delete-orphan")
class DataColumn(Base):
    """A column of a scanned table; the unit that classification rules match against."""
    __tablename__ = "meta_column"
    id = Column(Integer, primary_key=True, index=True)
    table_id = Column(Integer, ForeignKey("meta_table.id"), nullable=False)
    name = Column(String(200), nullable=False)
    data_type = Column(String(100))
    length = Column(Integer)
    comment = Column(String(500))
    is_nullable = Column(Boolean, default=True)
    sample_data = Column(Text)  # JSON array of sample values
    created_at = Column(DateTime, default=datetime.utcnow)
    table = relationship("DataTable", back_populates="columns")
class UnstructuredFile(Base):
    """An uploaded document whose extracted text can be classified."""
    __tablename__ = "unstructured_file"
    id = Column(Integer, primary_key=True, index=True)
    original_name = Column(String(255), nullable=False)
    file_type = Column(String(50))  # word, pdf, txt, excel
    file_size = Column(BigInteger)
    storage_path = Column(String(500))
    # Plain text pulled out of the document for rule matching.
    extracted_text = Column(Text)
    status = Column(String(20), default="pending")  # pending, processed, error
    created_by = Column(Integer, ForeignKey("sys_user.id"))
    created_at = Column(DateTime, default=datetime.utcnow)
+114
View File
@@ -0,0 +1,114 @@
from datetime import datetime
from sqlalchemy import Column, Integer, String, Text, ForeignKey, DateTime, Float, Enum as SAEnum
from sqlalchemy.orm import relationship
import enum
from app.core.database import Base
class ProjectStatus(str, enum.Enum):
    """Lifecycle states of a classification project, listed in workflow order."""
    CREATED = "created"
    SCANNING = "scanning"
    ASSIGNING = "assigning"
    LABELING = "labeling"
    REVIEWING = "reviewing"
    ACCEPTING = "accepting"
    PUBLISHED = "published"
class TaskStatus(str, enum.Enum):
    """States of an individual labeling task assigned to a user."""
    PENDING = "pending"
    IN_PROGRESS = "in_progress"
    COMPLETED = "completed"
    REJECTED = "rejected"
class ResultStatus(str, enum.Enum):
    """States of a classification result as it moves from auto-match to publication."""
    AUTO = "auto"
    MANUAL = "manual"
    REVIEWED = "reviewed"
    PUBLISHED = "published"
    CONFLICT = "conflict"
class ClassificationProject(Base):
    """A classification/grading campaign over a chosen set of data sources."""
    __tablename__ = "classification_project"
    id = Column(Integer, primary_key=True, index=True)
    name = Column(String(200), nullable=False)
    template_id = Column(Integer, ForeignKey("classification_template.id"))
    description = Column(Text)
    status = Column(String(20), default=ProjectStatus.CREATED.value)
    # Target scope stored as comma-separated id lists rather than join tables.
    target_source_ids = Column(Text)  # comma separated source ids
    target_database_ids = Column(Text)
    target_table_ids = Column(Text)
    planned_start = Column(DateTime)
    planned_end = Column(DateTime)
    created_by = Column(Integer, ForeignKey("sys_user.id"), nullable=False)
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
    tasks = relationship("ClassificationTask", back_populates="project", cascade="all, delete-orphan")
    results = relationship("ClassificationResult", back_populates="project", cascade="all, delete-orphan")
class ClassificationTask(Base):
    """A labeling work item within a project, assigned by one user to another."""
    __tablename__ = "classification_task"
    id = Column(Integer, primary_key=True, index=True)
    project_id = Column(Integer, ForeignKey("classification_project.id"), nullable=False)
    name = Column(String(200))
    assigner_id = Column(Integer, ForeignKey("sys_user.id"))
    assignee_id = Column(Integer, ForeignKey("sys_user.id"))
    target_type = Column(String(20), default="table")  # table, column, file
    target_ids = Column(Text)  # comma separated ids
    status = Column(String(20), default=TaskStatus.PENDING.value)
    deadline = Column(DateTime)
    completed_at = Column(DateTime)
    created_at = Column(DateTime, default=datetime.utcnow)
    project = relationship("ClassificationProject", back_populates="tasks")
    # Two FKs point at sys_user, so each relationship must name its key explicitly.
    assigner = relationship("User", foreign_keys=[assigner_id])
    assignee = relationship("User", foreign_keys=[assignee_id])
class ClassificationResult(Base):
    """Category/level assigned to one column or file within a project."""
    __tablename__ = "classification_result"
    id = Column(Integer, primary_key=True, index=True)
    project_id = Column(Integer, ForeignKey("classification_project.id"), nullable=False)
    # Exactly one of column_id / file_id is expected to be set per result
    # (structured vs. unstructured target) -- not enforced at the DB level.
    column_id = Column(Integer, ForeignKey("meta_column.id"), nullable=True)
    file_id = Column(Integer, ForeignKey("unstructured_file.id"), nullable=True)
    category_id = Column(Integer, ForeignKey("category.id"))
    level_id = Column(Integer, ForeignKey("data_level.id"))
    source = Column(String(20), default="auto")  # auto, manual, ml
    confidence = Column(Float, default=0.0)  # 0-1
    labeler_id = Column(Integer, ForeignKey("sys_user.id"))
    reviewer_id = Column(Integer, ForeignKey("sys_user.id"))
    status = Column(String(20), default=ResultStatus.AUTO.value)
    label_time = Column(DateTime)
    review_time = Column(DateTime)
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
    project = relationship("ClassificationProject", back_populates="results")
    category = relationship("Category")
    level = relationship("DataLevel")
class ClassificationChange(Base):
    """An approval-gated request to change a result's category and/or level."""
    __tablename__ = "classification_change"
    id = Column(Integer, primary_key=True, index=True)
    result_id = Column(Integer, ForeignKey("classification_result.id"), nullable=False)
    change_type = Column(String(20), nullable=False)  # category, level, both
    # Old/new pairs keep an audit trail of what the change replaces.
    old_category_id = Column(Integer, ForeignKey("category.id"))
    new_category_id = Column(Integer, ForeignKey("category.id"))
    old_level_id = Column(Integer, ForeignKey("data_level.id"))
    new_level_id = Column(Integer, ForeignKey("data_level.id"))
    reason = Column(Text)
    applicant_id = Column(Integer, ForeignKey("sys_user.id"))
    approver_id = Column(Integer, ForeignKey("sys_user.id"))
    approval_status = Column(String(20), default="pending")  # pending, approved, rejected
    approval_comment = Column(Text)
    created_at = Column(DateTime, default=datetime.utcnow)
+54
View File
@@ -0,0 +1,54 @@
from datetime import datetime
from sqlalchemy import Column, Integer, String, Boolean, DateTime, ForeignKey, Text
from sqlalchemy.orm import relationship
from app.core.database import Base
class Dept(Base):
    """Organizational department; self-referencing to form a tree."""
    __tablename__ = "sys_dept"
    id = Column(Integer, primary_key=True, index=True)
    name = Column(String(100), nullable=False)
    parent_id = Column(Integer, ForeignKey("sys_dept.id"), nullable=True)  # NULL = root dept
    sort_order = Column(Integer, default=0)
    created_at = Column(DateTime, default=datetime.utcnow)
    # remote_side resolves the self-join direction: parent is the "one" side.
    parent = relationship("Dept", remote_side=[id], backref="children")
class Role(Base):
    """An authorization role assignable to users (many-to-many via sys_user_role)."""
    __tablename__ = "sys_role"
    id = Column(Integer, primary_key=True, index=True)
    name = Column(String(50), unique=True, nullable=False)
    code = Column(String(50), unique=True, nullable=False)  # stable machine identifier
    description = Column(String(200))
    created_at = Column(DateTime, default=datetime.utcnow)
class User(Base):
    """An application account with department membership and roles."""
    __tablename__ = "sys_user"
    id = Column(Integer, primary_key=True, index=True)
    username = Column(String(50), unique=True, nullable=False, index=True)
    email = Column(String(100), unique=True, nullable=True)
    # Password hash only; plaintext is never persisted.
    hashed_password = Column(String(255), nullable=False)
    real_name = Column(String(50))
    phone = Column(String(20))
    is_active = Column(Boolean, default=True)
    is_superuser = Column(Boolean, default=False)
    dept_id = Column(Integer, ForeignKey("sys_dept.id"), nullable=True)
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
    dept = relationship("Dept")
    roles = relationship("Role", secondary="sys_user_role", backref="users")
class UserRole(Base):
    """Association row linking users to roles (secondary table of User.roles)."""
    __tablename__ = "sys_user_role"
    id = Column(Integer, primary_key=True, index=True)
    user_id = Column(Integer, ForeignKey("sys_user.id"), nullable=False)
    role_id = Column(Integer, ForeignKey("sys_role.id"), nullable=False)
+3
View File
@@ -0,0 +1,3 @@
from app.schemas.auth import Token, TokenRefresh, LoginRequest
from app.schemas.user import UserCreate, UserUpdate, UserOut, RoleOut, DeptOut, DeptTree
from app.schemas.common import ResponseModel, ListResponse, PageParams
+17
View File
@@ -0,0 +1,17 @@
from pydantic import BaseModel
class LoginRequest(BaseModel):
    """Request body for the password-login endpoint."""
    username: str
    password: str
class Token(BaseModel):
    """Issued token pair; expires_in is the access-token lifetime in seconds."""
    access_token: str
    refresh_token: str
    token_type: str = "bearer"
    expires_in: int
class TokenRefresh(BaseModel):
    """Request body for exchanging a refresh token for a new token pair."""
    refresh_token: str
+102
View File
@@ -0,0 +1,102 @@
from typing import Optional, List
from pydantic import BaseModel, Field
from datetime import datetime
class CategoryBase(BaseModel):
    """Shared category fields; level is the tree depth (1 = top level)."""
    parent_id: Optional[int] = None
    level: int = 1
    code: str = Field(..., max_length=50)
    name: str = Field(..., max_length=100)
    description: Optional[str] = None
    sort_order: int = 0
class CategoryCreate(CategoryBase):
    """Payload for creating a category; identical to the base fields."""
    pass
class CategoryUpdate(BaseModel):
    """Partial-update payload; only explicitly-set fields are applied."""
    parent_id: Optional[int] = None
    code: Optional[str] = Field(None, max_length=50)
    name: Optional[str] = Field(None, max_length=100)
    description: Optional[str] = None
    sort_order: Optional[int] = None
class CategoryOut(CategoryBase):
    """Category as returned by the API (hydrated from the ORM model)."""
    id: int
    created_at: datetime

    class Config:
        from_attributes = True
class CategoryTree(CategoryOut):
    """Category with recursively nested children, for tree endpoints."""
    children: List["CategoryTree"] = []
class DataLevelOut(BaseModel):
    """Security level (e.g. L1-L5) as returned by the API."""
    id: int
    code: str
    name: str
    description: Optional[str] = None
    color: str  # UI display color, hex string
    control_requirements: Optional[dict] = None
    sort_order: int

    class Config:
        from_attributes = True
class RecognitionRuleBase(BaseModel):
    """Shared fields of an auto-recognition rule.

    rule_content semantics depend on rule_type: a regex pattern, or a
    comma-separated keyword/enum list.
    """
    template_id: int
    category_id: Optional[int] = None
    level_id: Optional[int] = None
    rule_type: str = Field(..., max_length=20)  # regex, keyword, enum
    rule_name: Optional[str] = Field(None, max_length=100)
    rule_content: str
    target_field: str = Field(default="column_name", max_length=20)
    priority: int = 100  # lower number = evaluated earlier
    is_active: bool = True
class RecognitionRuleCreate(RecognitionRuleBase):
    """Payload for creating a rule; identical to the base fields."""
    pass
class RecognitionRuleUpdate(BaseModel):
    """Partial-update payload for a rule; only explicitly-set fields are applied."""
    category_id: Optional[int] = None
    level_id: Optional[int] = None
    rule_type: Optional[str] = Field(None, max_length=20)
    rule_name: Optional[str] = Field(None, max_length=100)
    rule_content: Optional[str] = None
    target_field: Optional[str] = Field(None, max_length=20)
    priority: Optional[int] = None
    is_active: Optional[bool] = None
class RecognitionRuleOut(RecognitionRuleBase):
    """Rule as returned by the API, with denormalized display names."""
    id: int
    hit_count: int
    created_at: datetime
    # Filled by the service layer for display; not ORM attributes of the rule itself.
    category_name: Optional[str] = None
    level_name: Optional[str] = None
    level_color: Optional[str] = None

    class Config:
        from_attributes = True
class TemplateOut(BaseModel):
    """Classification template as returned by the API."""
    id: int
    name: str
    industry_type: str
    version: str
    description: Optional[str] = None
    is_builtin: bool  # built-in templates are seeded and not user-deletable
    is_active: bool
    created_at: datetime

    class Config:
        from_attributes = True
+25
View File
@@ -0,0 +1,25 @@
from typing import Generic, TypeVar, Optional, List
from pydantic import BaseModel, Field
T = TypeVar("T")
class ResponseModel(BaseModel, Generic[T]):
    """Uniform single-item API envelope: code/message plus typed payload."""
    code: int = 200
    message: str = "success"
    data: Optional[T] = None
class ListResponse(BaseModel, Generic[T]):
    """Uniform paginated-list API envelope with total/page bookkeeping."""
    code: int = 200
    message: str = "success"
    data: List[T] = []  # pydantic deep-copies mutable defaults per instance
    total: int = 0
    page: int = 1
    page_size: int = 20
class PageParams(BaseModel):
    """Common query parameters for paginated list endpoints."""
    page: int = Field(1, ge=1)
    page_size: int = Field(20, ge=1, le=500)  # capped to bound query cost
    keyword: Optional[str] = None
+51
View File
@@ -0,0 +1,51 @@
from typing import Optional
from pydantic import BaseModel, Field
from datetime import datetime
class DataSourceBase(BaseModel):
    """Shared data-source fields (connection coordinates, no password)."""
    name: str = Field(..., max_length=100)
    source_type: str = Field(..., max_length=50)
    host: Optional[str] = Field(None, max_length=200)
    port: Optional[int] = None
    database_name: Optional[str] = Field(None, max_length=100)
    username: Optional[str] = Field(None, max_length=100)
    extra_params: Optional[str] = None  # JSON string of driver-specific options
    status: Optional[str] = "active"
    dept_id: Optional[int] = None
class DataSourceCreate(DataSourceBase):
    """Create payload; plaintext password is accepted here and encrypted at rest."""
    password: Optional[str] = None
class DataSourceUpdate(BaseModel):
    """Partial-update payload; a non-empty password triggers re-encryption."""
    name: Optional[str] = Field(None, max_length=100)
    source_type: Optional[str] = Field(None, max_length=50)
    host: Optional[str] = Field(None, max_length=200)
    port: Optional[int] = None
    database_name: Optional[str] = Field(None, max_length=100)
    username: Optional[str] = Field(None, max_length=100)
    password: Optional[str] = None
    extra_params: Optional[str] = None
    dept_id: Optional[int] = None
class DataSourceOut(DataSourceBase):
    """Data source as returned by the API; password is deliberately excluded."""
    id: int
    created_by: Optional[int] = None
    created_at: datetime
    updated_at: datetime

    class Config:
        from_attributes = True
class DataSourceTest(BaseModel):
    """Ad-hoc connection parameters for the connection-test endpoint."""
    source_type: str
    host: Optional[str] = None
    port: Optional[int] = None
    database_name: Optional[str] = None
    username: Optional[str] = None
    password: Optional[str] = None
    extra_params: Optional[str] = None
+51
View File
@@ -0,0 +1,51 @@
from typing import Optional, List
from pydantic import BaseModel
from datetime import datetime
class DatabaseOut(BaseModel):
    """Scanned database as returned by the metadata API."""
    id: int
    source_id: int
    name: str
    charset: Optional[str] = None
    table_count: int = 0
    created_at: datetime

    class Config:
        from_attributes = True
class DataTableOut(BaseModel):
    """Scanned table as returned by the metadata API."""
    id: int
    database_id: int
    name: str
    comment: Optional[str] = None
    row_count: int = 0
    column_count: int = 0
    created_at: datetime

    class Config:
        from_attributes = True
class DataColumnOut(BaseModel):
    """Scanned column as returned by the metadata API."""
    id: int
    table_id: int
    name: str
    data_type: Optional[str] = None
    length: Optional[int] = None
    comment: Optional[str] = None
    is_nullable: bool = True
    sample_data: Optional[str] = None  # JSON array string of sampled values
    created_at: datetime

    class Config:
        from_attributes = True
class MetadataTreeNode(BaseModel):
    """Generic node of the source -> database -> table -> column browse tree."""
    id: int
    name: str
    type: str  # source, database, table, column
    children: Optional[List["MetadataTreeNode"]] = None
    meta: Optional[dict] = None  # extra per-node display attributes
+61
View File
@@ -0,0 +1,61 @@
from typing import Optional, List
from pydantic import BaseModel, EmailStr
from datetime import datetime
class RoleOut(BaseModel):
    """Role as returned by the API."""
    id: int
    name: str
    code: str
    description: Optional[str] = None

    class Config:
        from_attributes = True
class DeptOut(BaseModel):
    """Department as returned by the API (flat, no children)."""
    id: int
    name: str
    parent_id: Optional[int] = None
    sort_order: int = 0

    class Config:
        from_attributes = True
class DeptTree(DeptOut):
    """Department with recursively nested children, for tree endpoints."""
    children: List["DeptTree"] = []
class UserBase(BaseModel):
    """Shared user fields (no credentials)."""
    username: str
    email: Optional[str] = None
    real_name: Optional[str] = None
    phone: Optional[str] = None
    dept_id: Optional[int] = None
    is_active: bool = True
class UserCreate(UserBase):
    """Create payload: plaintext password plus the role ids to grant."""
    password: str
    role_ids: List[int] = []
class UserUpdate(BaseModel):
    """Partial-update payload; role_ids, when given, replaces the role set."""
    email: Optional[str] = None
    real_name: Optional[str] = None
    phone: Optional[str] = None
    dept_id: Optional[int] = None
    is_active: Optional[bool] = None
    role_ids: Optional[List[int]] = None
class UserOut(UserBase):
    """User as returned by the API, with embedded department and roles."""
    id: int
    is_superuser: bool
    created_at: datetime
    dept: Optional[DeptOut] = None
    roles: List[RoleOut] = []

    class Config:
        from_attributes = True
View File
+33
View File
@@ -0,0 +1,33 @@
from sqlalchemy.orm import Session
from fastapi import HTTPException, status
from app.models.user import User
from app.core.security import verify_password, create_token_pair
def authenticate_user(db: Session, username: str, password: str):
    """Verify a username/password pair.

    Returns the matching User on success, None for an unknown username or a
    wrong password; raises HTTP 403 when the account exists but is disabled.
    """
    user = db.query(User).filter_by(username=username).first()
    if user is None or not verify_password(password, user.hashed_password):
        return None
    if not user.is_active:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="用户已被禁用")
    return user
def login(db: Session, username: str, password: str):
    """Authenticate and issue a token pair shaped like the Token schema.

    Raises HTTP 401 for bad credentials; authenticate_user itself raises 403
    for disabled accounts.
    """
    user = authenticate_user(db, username, password)
    if user is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="用户名或密码错误",
            headers={"WWW-Authenticate": "Bearer"},
        )
    tokens = dict(zip(("access_token", "refresh_token"), create_token_pair(user.id, user.username)))
    # 30-minute access-token lifetime, reported in seconds.
    tokens.update(token_type="bearer", expires_in=30 * 60)
    return tokens
@@ -0,0 +1,134 @@
import re
import json
from typing import List, Optional, Tuple
from sqlalchemy.orm import Session
from app.models.classification import RecognitionRule, Category, DataLevel
from app.models.metadata import DataColumn, DataTable
from app.models.project import ClassificationProject, ClassificationResult, ResultStatus
def match_rule(rule: "RecognitionRule", column: "DataColumn") -> Tuple[bool, float]:
    """Match a single recognition rule against a column's metadata.

    Depending on ``rule.target_field`` the rule is tested against the column
    name, its comment, or its sampled values (stored as a JSON array string).

    Args:
        rule: rule carrying ``rule_type`` (regex/keyword/enum), ``rule_content``
            (pattern or comma-separated list) and ``target_field``.
        column: column carrying ``name``, ``comment`` and ``sample_data``.

    Returns:
        (matched, confidence) with a fixed heuristic confidence per rule type
        (regex 0.85, keyword 0.75, enum 0.90); (False, 0.0) when no match.
    """
    targets: List[str] = []
    if rule.target_field == "column_name":
        targets = [column.name]
    elif rule.target_field == "comment":
        targets = [column.comment or ""]
    elif rule.target_field == "sample_data":
        if column.sample_data:
            try:
                samples = json.loads(column.sample_data)
                if isinstance(samples, list):
                    targets = [str(s) for s in samples]
            except Exception:
                # Not valid JSON: fall back to matching the raw stored text.
                targets = [column.sample_data]
    if not targets:
        return False, 0.0
    if rule.rule_type == "regex":
        try:
            # Compile once, outside the target loop.
            pattern = re.compile(rule.rule_content)
        except re.error:
            # Malformed pattern: treat as non-matching rather than abort a scan.
            return False, 0.0
        for t in targets:
            if pattern.search(t):
                return True, 0.85
    elif rule.rule_type == "keyword":
        # Drop empty entries so trailing/duplicate commas ("a,b,") do not
        # produce an empty keyword that would match every target.
        keywords = [k for k in (p.strip().lower() for p in rule.rule_content.split(",")) if k]
        for t in targets:
            t_lower = t.lower()
            if any(kw in t_lower for kw in keywords):
                return True, 0.75
    elif rule.rule_type == "enum":
        # Same empty-entry filtering; comparison is exact after strip/casefold.
        enums = {e for e in (p.strip().lower() for p in rule.rule_content.split(",")) if e}
        for t in targets:
            if t.strip().lower() in enums:
                return True, 0.90
    return False, 0.0
def run_auto_classification(db: Session, project_id: int, source_ids: Optional[List[int]] = None) -> dict:
    """Run automatic classification for a project.

    Applies every active rule of the project's template to the candidate
    columns and upserts one ClassificationResult per matched column, keeping
    the highest-confidence rule and bumping that rule's hit counter.

    Args:
        db: active SQLAlchemy session (committed once at the end).
        project_id: id of the ClassificationProject to scan.
        source_ids: optional explicit data-source filter; when omitted the
            project's own comma-separated ``target_source_ids`` is used.

    Returns:
        dict with ``success``/``message`` and, on success, ``scanned`` and
        ``matched`` counts.
    """
    # Imported locally (the original relied on a module-bottom
    # ``import app.models.metadata``); kept function-scoped to avoid any
    # circular-import issue at module load time.
    from app.models.metadata import Database

    project = db.query(ClassificationProject).filter(ClassificationProject.id == project_id).first()
    if not project:
        return {"success": False, "message": "项目不存在"}
    # Active rules of the project's template, lowest priority number first.
    rules = db.query(RecognitionRule).filter(
        RecognitionRule.is_active == True,
        RecognitionRule.template_id == project.template_id,
    ).order_by(RecognitionRule.priority).all()
    if not rules:
        return {"success": False, "message": "没有可用的识别规则"}
    # Candidate columns, optionally narrowed to specific data sources.
    columns_query = db.query(DataColumn).join(DataTable).join(Database)
    if source_ids:
        columns_query = columns_query.filter(Database.source_id.in_(source_ids))
    elif project.target_source_ids:
        sids = [int(x) for x in project.target_source_ids.split(",") if x]
        columns_query = columns_query.filter(Database.source_id.in_(sids))
    columns = columns_query.all()
    matched_count = 0
    for col in columns:
        # Reuse an existing result for this (project, column) pair if present.
        existing = db.query(ClassificationResult).filter(
            ClassificationResult.project_id == project_id,
            ClassificationResult.column_id == col.id,
        ).first()
        # Keep the highest-confidence matching rule across all rules.
        best_rule = None
        best_confidence = 0.0
        for rule in rules:
            matched, confidence = match_rule(rule, col)
            if matched and confidence > best_confidence:
                best_confidence = confidence
                best_rule = rule
        if best_rule:
            matched_count += 1
            if existing:
                existing.category_id = best_rule.category_id
                existing.level_id = best_rule.level_id
                existing.confidence = best_confidence
                existing.source = "auto"
                existing.status = ResultStatus.AUTO.value
            else:
                db.add(ClassificationResult(
                    project_id=project_id,
                    column_id=col.id,
                    category_id=best_rule.category_id,
                    level_id=best_rule.level_id,
                    source="auto",
                    confidence=best_confidence,
                    status=ResultStatus.AUTO.value,
                ))
            # Track how often each rule fires.
            best_rule.hit_count = (best_rule.hit_count or 0) + 1
    db.commit()
    return {
        "success": True,
        "message": f"自动分类完成,共扫描 {len(columns)} 个字段,命中 {matched_count}",
        "scanned": len(columns),
        "matched": matched_count,
    }
import app.models.metadata
@@ -0,0 +1,268 @@
from typing import Optional, List, Tuple
from sqlalchemy.orm import Session
from fastapi import HTTPException, status
from app.models.classification import Category, DataLevel, RecognitionRule, ClassificationTemplate
from app.schemas.classification import CategoryCreate, CategoryUpdate, RecognitionRuleCreate, RecognitionRuleUpdate
def get_category(db: Session, category_id: int) -> Optional[Category]:
    """Fetch a Category by primary key, or None when absent."""
    return db.query(Category).filter_by(id=category_id).first()
def list_categories(db: Session, parent_id: Optional[int] = None) -> List[Category]:
    """List categories ordered by sort_order, optionally only one parent's children."""
    query = db.query(Category)
    if parent_id is not None:
        query = query.filter_by(parent_id=parent_id)
    return query.order_by(Category.sort_order).all()
def build_category_tree(db: Session) -> List[dict]:
    """Return the full category hierarchy as nested dicts (CategoryTree shape)."""
    def _children(parent_id: Optional[int]) -> List[dict]:
        # parent_id None selects the roots (IS NULL).
        rows = (
            db.query(Category)
            .filter(Category.parent_id == parent_id)
            .order_by(Category.sort_order)
            .all()
        )
        return [
            {
                "id": row.id,
                "parent_id": row.parent_id,
                "level": row.level,
                "code": row.code,
                "name": row.name,
                "description": row.description,
                "sort_order": row.sort_order,
                "created_at": row.created_at,
                "children": _children(row.id),
            }
            for row in rows
        ]
    return _children(None)
def create_category(db: Session, obj_in: CategoryCreate) -> Category:
    """Persist a new category from the create schema and return it refreshed."""
    category = Category(**obj_in.model_dump())
    db.add(category)
    db.commit()
    db.refresh(category)
    return category
def update_category(db: Session, db_obj: Category, obj_in: CategoryUpdate) -> Category:
    """Apply only the fields explicitly set on the update schema, then commit."""
    for field, value in obj_in.model_dump(exclude_unset=True).items():
        setattr(db_obj, field, value)
    db.commit()
    db.refresh(db_obj)
    return db_obj
def delete_category(db: Session, category_id: int) -> None:
    """Delete a category; HTTP 404 when missing, 400 when it still has children."""
    category = get_category(db, category_id)
    if category is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="分类不存在")
    has_children = db.query(Category).filter(Category.parent_id == category_id).first() is not None
    if has_children:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="存在子分类,无法删除")
    db.delete(category)
    db.commit()
def list_data_levels(db: Session) -> List[DataLevel]:
    """All security levels ordered by sort_order."""
    levels = db.query(DataLevel).order_by(DataLevel.sort_order)
    return levels.all()
def get_data_level(db: Session, level_id: int) -> Optional[DataLevel]:
    """Fetch a DataLevel by primary key, or None when absent."""
    return db.query(DataLevel).filter_by(id=level_id).first()
def create_data_level(db: Session, code: str, name: str, description: str, color: str, sort_order: int = 0, control_requirements: Optional[dict] = None) -> DataLevel:
    """Persist a new data security level and return it refreshed."""
    level = DataLevel(
        code=code,
        name=name,
        description=description,
        color=color,
        sort_order=sort_order,
        control_requirements=control_requirements,
    )
    db.add(level)
    db.commit()
    db.refresh(level)
    return level
def get_rule(db: Session, rule_id: int) -> Optional[RecognitionRule]:
    """Fetch a RecognitionRule by primary key, or None when absent."""
    return db.query(RecognitionRule).filter_by(id=rule_id).first()
def list_rules(db: Session, template_id: Optional[int] = None, keyword: Optional[str] = None, page: int = 1, page_size: int = 20) -> Tuple[List[RecognitionRule], int]:
    """Paginated rule listing, optionally filtered by template and by a
    keyword matched against rule name or content.

    Returns (items, total) where total is the pre-pagination count.
    """
    query = db.query(RecognitionRule)
    if template_id:
        query = query.filter(RecognitionRule.template_id == template_id)
    if keyword:
        by_name_or_content = (
            RecognitionRule.rule_name.contains(keyword)
            | RecognitionRule.rule_content.contains(keyword)
        )
        query = query.filter(by_name_or_content)
    total = query.count()
    offset = (page - 1) * page_size
    return query.offset(offset).limit(page_size).all(), total
def create_rule(db: Session, obj_in: RecognitionRuleCreate) -> RecognitionRule:
    """Persist a new recognition rule from the create schema and return it refreshed."""
    rule = RecognitionRule(**obj_in.model_dump())
    db.add(rule)
    db.commit()
    db.refresh(rule)
    return rule
def update_rule(db: Session, db_obj: RecognitionRule, obj_in: RecognitionRuleUpdate) -> RecognitionRule:
    """Apply only the fields explicitly set on the update schema, then commit."""
    for field, value in obj_in.model_dump(exclude_unset=True).items():
        setattr(db_obj, field, value)
    db.commit()
    db.refresh(db_obj)
    return db_obj
def delete_rule(db: Session, rule_id: int) -> None:
    """Delete a recognition rule; HTTP 404 when it does not exist."""
    rule = get_rule(db, rule_id)
    if rule is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="规则不存在")
    db.delete(rule)
    db.commit()
def get_template(db: Session, template_id: int) -> Optional[ClassificationTemplate]:
    """Fetch a ClassificationTemplate by primary key, or None when absent."""
    return db.query(ClassificationTemplate).filter_by(id=template_id).first()
def list_templates(db: Session) -> List[ClassificationTemplate]:
    """All classification templates in id order."""
    templates = db.query(ClassificationTemplate).order_by(ClassificationTemplate.id)
    return templates.all()
def _seed_data_levels(db: Session) -> None:
    """Seed the built-in L1-L5 security levels; no-op when any level exists."""
    if db.query(DataLevel).first():
        return
    levels = [
        ("L1", "公开级", "可对外公开发布", "#67c23a", 1, {"storage": "无特殊要求", "access": "公开访问"}),
        ("L2", "内部级", "公司内部共享使用", "#409eff", 2, {"storage": "内部环境", "access": "内部员工"}),
        ("L3", "敏感级", "部门/授权人员访问,外部分享需审批", "#e6a23c", 3, {"storage": "加密存储", "access": "授权访问"}),
        ("L4", "重要级", "严格授权管理,外部分享需严格审批", "#f56c6c", 4, {"storage": "强加密", "access": "最小权限"}),
        ("L5", "核心级", "禁止对外共享", "#909399", 5, {"storage": "物理隔离", "access": "核心人员"}),
    ]
    for code, name, desc, color, sort, ctrl in levels:
        create_data_level(db, code, name, desc, color, sort, ctrl)


def _seed_categories(db: Session) -> None:
    """Seed the two-level built-in category tree; no-op when any category exists."""
    if db.query(Category).first():
        return
    # Level 1
    categories = [
        {"code": "CUST", "name": "客户数据", "level": 1, "sort_order": 1},
        {"code": "POLICY", "name": "保单数据", "level": 1, "sort_order": 2},
        {"code": "CLAIM", "name": "理赔数据", "level": 1, "sort_order": 3},
        {"code": "FIN", "name": "财务数据", "level": 1, "sort_order": 4},
        {"code": "CHANNEL", "name": "渠道数据", "level": 1, "sort_order": 5},
        {"code": "REG", "name": "监管报送数据", "level": 1, "sort_order": 6},
        {"code": "INTERNAL", "name": "内部管理数据", "level": 1, "sort_order": 7},
        {"code": "SUBJECT", "name": "车辆/财产标的数据", "level": 1, "sort_order": 8},
    ]
    cat_map = {}
    for c in categories:
        obj = Category(parent_id=None, level=c["level"], code=c["code"], name=c["name"], sort_order=c["sort_order"])
        db.add(obj)
        # Commit per root so obj.id is available for the children below.
        db.commit()
        db.refresh(obj)
        cat_map[c["code"]] = obj.id
    # Level 2
    sub_categories = [
        {"parent_code": "CUST", "code": "CUST_PERSONAL", "name": "个人客户信息", "sort_order": 1},
        {"parent_code": "CUST", "code": "CUST_ENTERPRISE", "name": "企业客户信息", "sort_order": 2},
        {"parent_code": "CUST", "code": "CUST_BENEFICIARY", "name": "受益人信息", "sort_order": 3},
        {"parent_code": "POLICY", "code": "POLICY_APPLY", "name": "投保信息", "sort_order": 1},
        {"parent_code": "POLICY", "code": "POLICY_UNDERWRITE", "name": "承保信息", "sort_order": 2},
        {"parent_code": "POLICY", "code": "POLICY_RENEW", "name": "续保信息", "sort_order": 3},
        {"parent_code": "CLAIM", "code": "CLAIM_REPORT", "name": "报案信息", "sort_order": 1},
        {"parent_code": "CLAIM", "code": "CLAIM_SURVEY", "name": "查勘定损信息", "sort_order": 2},
        {"parent_code": "CLAIM", "code": "CLAIM_PAY", "name": "赔付信息", "sort_order": 3},
        {"parent_code": "FIN", "code": "FIN_PAYMENT", "name": "收付费数据", "sort_order": 1},
        {"parent_code": "FIN", "code": "FIN_RESERVE", "name": "准备金数据", "sort_order": 2},
        {"parent_code": "FIN", "code": "FIN_INVEST", "name": "投资数据", "sort_order": 3},
        {"parent_code": "CHANNEL", "code": "CHN_AGENT", "name": "代理人/经纪人信息", "sort_order": 1},
        {"parent_code": "CHANNEL", "code": "CHN_PARTNER", "name": "第三方合作方", "sort_order": 2},
        {"parent_code": "REG", "code": "REG_SOLVENCY", "name": "偿付能力数据", "sort_order": 1},
        {"parent_code": "REG", "code": "REG_STAT", "name": "统计报表数据", "sort_order": 2},
        {"parent_code": "INTERNAL", "code": "INT_EMPLOYEE", "name": "员工信息", "sort_order": 1},
        {"parent_code": "INTERNAL", "code": "INT_OPS", "name": "系统运维数据", "sort_order": 2},
        {"parent_code": "SUBJECT", "code": "SUB_VEHICLE", "name": "车辆信息", "sort_order": 1},
        {"parent_code": "SUBJECT", "code": "SUB_PROPERTY", "name": "财产标的", "sort_order": 2},
    ]
    for sc in sub_categories:
        parent_id = cat_map.get(sc["parent_code"])
        if parent_id:
            obj = Category(parent_id=parent_id, level=2, code=sc["code"], name=sc["name"], sort_order=sc["sort_order"])
            db.add(obj)
    db.commit()


def _seed_template_and_rules(db: Session) -> None:
    """Seed the built-in insurance template plus sample recognition rules;
    no-op when any template exists."""
    if db.query(ClassificationTemplate).first():
        return
    tpl = ClassificationTemplate(
        name="财产保险行业分类分级模板",
        industry_type="insurance_property",
        version="1.0",
        description="基于《金融数据安全 数据安全分级指南》及保险行业特点制定的分类分级模板",
        is_builtin=True,
        is_active=True,
    )
    db.add(tpl)
    db.commit()
    db.refresh(tpl)
    # Sample rules reference levels/categories seeded above; each is guarded
    # so a partially-seeded database does not break initialization.
    level_l4 = db.query(DataLevel).filter(DataLevel.code == "L4").first()
    level_l5 = db.query(DataLevel).filter(DataLevel.code == "L5").first()
    cat_cust_personal = db.query(Category).filter(Category.code == "CUST_PERSONAL").first()
    cat_fin_reserve = db.query(Category).filter(Category.code == "FIN_RESERVE").first()
    cat_int_ops = db.query(Category).filter(Category.code == "INT_OPS").first()
    rules = []
    if cat_cust_personal and level_l4:
        rules.append(RecognitionRule(
            template_id=tpl.id,
            category_id=cat_cust_personal.id,
            level_id=level_l4.id,
            rule_type="regex",
            rule_name="身份证号识别",
            rule_content=r"(\d{15}|\d{18}|\d{17}[xX])",
            target_field="sample_data",
            priority=10,
        ))
        rules.append(RecognitionRule(
            template_id=tpl.id,
            category_id=cat_cust_personal.id,
            level_id=level_l4.id,
            rule_type="keyword",
            rule_name="手机号字段识别",
            rule_content="手机,mobile,phone,telephone,tel",
            target_field="column_name",
            priority=20,
        ))
    if cat_fin_reserve and level_l5:
        rules.append(RecognitionRule(
            template_id=tpl.id,
            category_id=cat_fin_reserve.id,
            level_id=level_l5.id,
            rule_type="keyword",
            rule_name="精算模型识别",
            rule_content="精算,actuarial,准备金,reserve,偿付能力,solvency",
            target_field="column_name",
            priority=10,
        ))
    if cat_int_ops and level_l5:
        rules.append(RecognitionRule(
            template_id=tpl.id,
            category_id=cat_int_ops.id,
            level_id=level_l5.id,
            rule_type="keyword",
            rule_name="密码密钥识别",
            rule_content="password,secret,key,token,密钥,密码",
            target_field="column_name",
            priority=5,
        ))
    for r in rules:
        db.add(r)
    db.commit()


def init_builtin_data(db: Session):
    """Idempotently seed built-in levels, categories, the default template
    and its sample recognition rules. Safe to call on every startup: each
    seeding step is skipped when its table already has rows."""
    _seed_data_levels(db)
    _seed_categories(db)
    _seed_template_and_rules(db)
+121
View File
@@ -0,0 +1,121 @@
from typing import Optional, List, Tuple
from sqlalchemy.orm import Session
from fastapi import HTTPException, status
from cryptography.fernet import Fernet
from app.models.metadata import DataSource
from app.schemas.datasource import DataSourceCreate, DataSourceUpdate, DataSourceTest
from app.core.config import settings
# Symmetric (Fernet) encryption for stored DB passwords.
# In production, use a proper KMS.
def _build_fernet() -> Fernet:
    """Derive a stable Fernet key from the application secret.

    The original implementation used ``Fernet.generate_key()`` at import time,
    so every process restart produced a new key and all previously stored
    passwords became undecryptable. Deriving the key from the configured
    secret keeps ciphertexts readable across restarts.

    Falls back to a random per-process key only when no secret is configured
    (stored passwords then survive only for the process lifetime).
    """
    import base64
    import hashlib

    # NOTE(review): assumes settings exposes a SECRET_KEY string — confirm.
    secret = getattr(settings, "SECRET_KEY", None)
    if secret:
        # Fernet requires a urlsafe-base64-encoded 32-byte key.
        key = base64.urlsafe_b64encode(hashlib.sha256(secret.encode()).digest())
        return Fernet(key)
    return Fernet(Fernet.generate_key())


_fernet = _build_fernet()


def _encrypt_password(password: str) -> str:
    """Encrypt a plaintext password for storage."""
    return _fernet.encrypt(password.encode()).decode()


def _decrypt_password(encrypted: str) -> str:
    """Decrypt a password previously produced by _encrypt_password."""
    return _fernet.decrypt(encrypted.encode()).decode()
def get_datasource(db: Session, source_id: int) -> Optional[DataSource]:
    """Fetch a DataSource by primary key, or None when absent."""
    return db.query(DataSource).filter_by(id=source_id).first()
def list_datasources(
    db: Session, keyword: Optional[str] = None, page: int = 1, page_size: int = 20
) -> Tuple[List[DataSource], int]:
    """Paginated data-source listing; keyword matches name or host.

    Returns (items, total) where total is the pre-pagination count.
    """
    query = db.query(DataSource)
    if keyword:
        by_name_or_host = DataSource.name.contains(keyword) | DataSource.host.contains(keyword)
        query = query.filter(by_name_or_host)
    total = query.count()
    offset = (page - 1) * page_size
    return query.offset(offset).limit(page_size).all(), total
def create_datasource(db: Session, obj_in: DataSourceCreate, user_id: int) -> DataSource:
    """Create a data source owned by *user_id*.

    The plaintext password (if any) is encrypted before persisting.
    """
    encrypted = _encrypt_password(obj_in.password) if obj_in.password else None
    source = DataSource(
        name=obj_in.name,
        source_type=obj_in.source_type,
        host=obj_in.host,
        port=obj_in.port,
        database_name=obj_in.database_name,
        username=obj_in.username,
        encrypted_password=encrypted,
        extra_params=obj_in.extra_params,
        status=obj_in.status or "active",
        dept_id=obj_in.dept_id,
        created_by=user_id,
    )
    db.add(source)
    db.commit()
    db.refresh(source)
    return source
def update_datasource(db: Session, db_obj: DataSource, obj_in: DataSourceUpdate) -> DataSource:
    """Apply the explicitly-set fields; a non-empty password is re-encrypted."""
    data = obj_in.model_dump(exclude_unset=True)
    # The plaintext never reaches the model: swap it for the encrypted form,
    # or drop it entirely when empty/None.
    plain = data.pop("password", None)
    if plain:
        data["encrypted_password"] = _encrypt_password(plain)
    for field, value in data.items():
        setattr(db_obj, field, value)
    db.commit()
    db.refresh(db_obj)
    return db_obj
def delete_datasource(db: Session, source_id: int) -> None:
    """Delete a data source; HTTP 404 when it does not exist."""
    source = get_datasource(db, source_id)
    if source is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="数据源不存在")
    db.delete(source)
    db.commit()
def test_connection(obj_in: DataSourceTest) -> dict:
    """Attempt a real connection to the described database and run SELECT 1.

    Returns a dict with ``success`` and a human-readable ``message``;
    never raises — connection errors are reported in the message.
    """
    from sqlalchemy import create_engine, text
    from urllib.parse import quote_plus

    driver_map = {
        "mysql": "mysql+pymysql",
        "postgresql": "postgresql+psycopg2",
        "oracle": "oracle+cx_oracle",
        "sqlserver": "mssql+pymssql",
        "dm": "dm-python",  # placeholder
    }
    if obj_in.source_type == "dm":
        # For MVP, mock test for Dameng
        return {"success": True, "message": "达梦数据库连接测试通过(模拟)"}
    driver = driver_map.get(obj_in.source_type, obj_in.source_type)
    host = obj_in.host or "localhost"
    # NOTE(review): 5432 is used as the fallback even for non-PostgreSQL
    # types (kept from the original) — confirm intended.
    port = obj_in.port or 5432
    database = obj_in.database_name or ""
    # URL-escape credentials so passwords containing @ : / % do not break the DSN.
    username = quote_plus(obj_in.username or "")
    password = quote_plus(obj_in.password or "")
    try:
        # All supported drivers share the same DSN shape, so one format suffices.
        url = f"{driver}://{username}:{password}@{host}:{port}/{database}"
        engine = create_engine(url, pool_pre_ping=True, connect_args={"connect_timeout": 5})
        with engine.connect() as conn:
            conn.execute(text("SELECT 1"))
        return {"success": True, "message": "连接测试通过"}
    except Exception as e:
        return {"success": False, "message": f"连接失败: {str(e)}"}
+183
View File
@@ -0,0 +1,183 @@
from typing import Optional, List, Tuple
from sqlalchemy.orm import Session
from fastapi import HTTPException, status
from app.models.metadata import DataSource, Database, DataTable, DataColumn
from app.services.datasource_service import get_datasource, _decrypt_password
def get_database(db: Session, db_id: int) -> Optional[Database]:
    """Fetch one Database row by primary key, or None."""
    return db.query(Database).filter_by(id=db_id).first()
def get_table(db: Session, table_id: int) -> Optional[DataTable]:
    """Fetch one DataTable row by primary key, or None."""
    return db.query(DataTable).filter_by(id=table_id).first()
def get_column(db: Session, column_id: int) -> Optional[DataColumn]:
    """Fetch one DataColumn row by primary key, or None."""
    return db.query(DataColumn).filter_by(id=column_id).first()
def list_databases(db: Session, source_id: Optional[int] = None) -> List[Database]:
    """List Database rows, optionally restricted to one data source."""
    if source_id:
        return db.query(Database).filter(Database.source_id == source_id).all()
    return db.query(Database).all()
def list_tables(db: Session, database_id: Optional[int] = None, keyword: Optional[str] = None) -> Tuple[List[DataTable], int]:
    """List tables, optionally filtered by parent database and by a
    name/comment keyword.

    Returns ``(items, total)``. No pagination is applied, so the total is
    simply ``len(items)``; computing it from the fetched list avoids the
    extra COUNT round trip the previous implementation issued.
    """
    query = db.query(DataTable)
    if database_id:
        query = query.filter(DataTable.database_id == database_id)
    if keyword:
        query = query.filter(
            (DataTable.name.contains(keyword)) | (DataTable.comment.contains(keyword))
        )
    items = query.all()
    return items, len(items)
def list_columns(db: Session, table_id: Optional[int] = None, keyword: Optional[str] = None, page: int = 1, page_size: int = 50) -> Tuple[List[DataColumn], int]:
    """Return one page of columns plus the total match count.

    *keyword* is matched as a substring against the column name or comment.
    """
    query = db.query(DataColumn)
    if table_id:
        query = query.filter(DataColumn.table_id == table_id)
    if keyword:
        by_name = DataColumn.name.contains(keyword)
        by_comment = DataColumn.comment.contains(keyword)
        query = query.filter(by_name | by_comment)
    total = query.count()
    start = (page - 1) * page_size
    return query.offset(start).limit(page_size).all(), total
def build_tree(db: Session, source_id: Optional[int] = None) -> List[dict]:
    """Assemble the nested source -> database -> table tree for the UI.

    Each node carries ``id``/``name``/``type``/``children`` plus a
    type-specific ``meta`` dict; table nodes always have empty children.
    """

    def _table_node(t) -> dict:
        return {
            "id": t.id,
            "name": t.name,
            "type": "table",
            "children": [],
            "meta": {"comment": t.comment, "row_count": t.row_count, "column_count": t.column_count},
        }

    def _database_node(d) -> dict:
        return {
            "id": d.id,
            "name": d.name,
            "type": "database",
            "children": [_table_node(t) for t in d.tables],
            "meta": {"charset": d.charset, "table_count": d.table_count},
        }

    def _source_node(s) -> dict:
        return {
            "id": s.id,
            "name": s.name,
            "type": "source",
            "children": [_database_node(d) for d in s.databases],
            "meta": {"source_type": s.source_type, "status": s.status},
        }

    query = db.query(DataSource)
    if source_id:
        query = query.filter(DataSource.id == source_id)
    return [_source_node(s) for s in query.all()]
def sync_metadata(db: Session, source_id: int, user_id: int) -> dict:
    """Introspect the remote database and upsert Database/DataTable/DataColumn rows.

    Returns a summary dict ``{success, message, databases, tables, columns}``.
    Connection/introspection failures are never raised: the source is marked
    "error" and the failure is reported in the message. Existing rows are
    matched by name and left untouched; only missing rows are created.

    NOTE(review): ``user_id`` is currently unused -- presumably intended for
    audit fields; confirm intent.
    """
    from sqlalchemy import create_engine, inspect, text
    import json
    source = get_datasource(db, source_id)
    if not source:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="数据源不存在")
    driver_map = {
        "mysql": "mysql+pymysql",
        "postgresql": "postgresql+psycopg2",
        "oracle": "oracle+cx_oracle",
        "sqlserver": "mssql+pymssql",
    }
    # Unknown source types fall through with the raw type string as dialect.
    driver = driver_map.get(source.source_type, source.source_type)
    if source.source_type == "dm":
        # Dameng support is mocked for the MVP.
        return {"success": True, "message": "达梦数据库同步成功(模拟)", "databases": 0, "tables": 0, "columns": 0}
    password = ""
    if source.encrypted_password:
        try:
            password = _decrypt_password(source.encrypted_password)
        except Exception:
            # Key mismatch / corrupt ciphertext: proceed with an empty
            # password and let the connection attempt fail with a clear
            # driver-level error instead.
            pass
    try:
        url = f"{driver}://{source.username}:{password}@{source.host}:{source.port}/{source.database_name}"
        engine = create_engine(url, pool_pre_ping=True)
        inspector = inspect(engine)
        # Some backends return no schemas; fall back to the configured database name.
        db_names = inspector.get_schema_names() or [source.database_name]
        total_tables = 0
        total_columns = 0
        for db_name in db_names:
            # Upsert the Database row (matched on source + name).
            db_obj = db.query(Database).filter(Database.source_id == source.id, Database.name == db_name).first()
            if not db_obj:
                db_obj = Database(source_id=source.id, name=db_name)
                db.add(db_obj)
                db.commit()
                db.refresh(db_obj)
            table_names = inspector.get_table_names(schema=db_name)
            for tname in table_names:
                # Upsert the DataTable row (matched on database + name).
                table_obj = db.query(DataTable).filter(DataTable.database_id == db_obj.id, DataTable.name == tname).first()
                if not table_obj:
                    table_obj = DataTable(database_id=db_obj.id, name=tname)
                    db.add(table_obj)
                    db.commit()
                    db.refresh(table_obj)
                columns = inspector.get_columns(tname, schema=db_name)
                for col in columns:
                    col_obj = db.query(DataColumn).filter(DataColumn.table_id == table_obj.id, DataColumn.name == col["name"]).first()
                    if not col_obj:
                        # Best-effort sampling of up to 5 non-NULL values,
                        # stored as a JSON array string; failures are ignored.
                        sample = None
                        try:
                            with engine.connect() as conn:
                                # NOTE(review): double-quoted identifiers are
                                # ANSI/PostgreSQL style -- this will not work
                                # on MySQL (backticks) or SQL Server; confirm
                                # per-dialect quoting is acceptable here.
                                result = conn.execute(text(f'SELECT "{col["name"]}" FROM "{db_name}"."{tname}" LIMIT 5'))
                                samples = [str(r[0]) for r in result if r[0] is not None]
                                sample = json.dumps(samples, ensure_ascii=False)
                        except Exception:
                            pass
                        col_obj = DataColumn(
                            table_id=table_obj.id,
                            name=col["name"],
                            data_type=str(col.get("type", "")),
                            length=col.get("max_length"),
                            comment=col.get("comment"),
                            is_nullable=col.get("nullable", True),
                            sample_data=sample,
                        )
                        db.add(col_obj)
                        total_columns += 1
                total_tables += 1
            # Flush this schema's newly added columns before moving on.
            db.commit()
        source.status = "active"
        db.commit()
        return {
            "success": True,
            "message": "元数据同步成功",
            "databases": len(db_names),
            "tables": total_tables,
            "columns": total_columns,
        }
    except Exception as e:
        # Any failure marks the source as errored but still returns a summary.
        source.status = "error"
        db.commit()
        return {"success": False, "message": f"同步失败: {str(e)}", "databases": 0, "tables": 0, "columns": 0}
+120
View File
@@ -0,0 +1,120 @@
from typing import Optional, List, Tuple
from sqlalchemy.orm import Session
from fastapi import HTTPException, status
from app.models.project import ClassificationProject, ClassificationTask, ClassificationResult
from app.models.classification import Category, DataLevel
from app.models.metadata import DataColumn, DataTable, Database as MetaDatabase
def get_project(db: Session, project_id: int) -> Optional[ClassificationProject]:
    """Fetch one ClassificationProject by primary key, or None."""
    return db.query(ClassificationProject).filter_by(id=project_id).first()
def list_projects(
    db: Session, keyword: Optional[str] = None, page: int = 1, page_size: int = 20
) -> Tuple[List[ClassificationProject], int]:
    """Return one page of projects (newest first) plus the total match count."""
    query = db.query(ClassificationProject)
    if keyword:
        query = query.filter(ClassificationProject.name.contains(keyword))
    total = query.count()
    ordered = query.order_by(ClassificationProject.created_at.desc())
    start = (page - 1) * page_size
    return ordered.offset(start).limit(page_size).all(), total
def create_project(db: Session, name: str, template_id: int, created_by: int, **kwargs) -> ClassificationProject:
    """Create and persist a classification project.

    Any additional model fields can be supplied via **kwargs and are passed
    straight to the ClassificationProject constructor.
    """
    project = ClassificationProject(name=name, template_id=template_id, created_by=created_by, **kwargs)
    db.add(project)
    db.commit()
    db.refresh(project)
    return project
def update_project(db: Session, db_obj: ClassificationProject, **kwargs) -> ClassificationProject:
    """Update the given project in place.

    None values are skipped, so a field cannot be cleared through this
    helper -- it can only be overwritten with a non-None value.
    """
    updates = {field: value for field, value in kwargs.items() if value is not None}
    for field, value in updates.items():
        setattr(db_obj, field, value)
    db.commit()
    db.refresh(db_obj)
    return db_obj
def delete_project(db: Session, project_id: int) -> None:
    """Delete a project by id; raises 404 when it does not exist."""
    target = get_project(db, project_id)
    if target is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="项目不存在")
    db.delete(target)
    db.commit()
def get_project_stats(db: Session, project_id: int) -> dict:
    """Aggregate labeling-progress counters for one project.

    Chaining extra filters onto the shared base query produces new Query
    objects, so each counter is an independent COUNT -- identical SQL to
    building four queries from scratch, without the duplication.
    """
    base = db.query(ClassificationResult).filter(ClassificationResult.project_id == project_id)
    return {
        "total": base.count(),
        "auto": base.filter(ClassificationResult.source == "auto").count(),
        "manual": base.filter(ClassificationResult.source == "manual").count(),
        "reviewed": base.filter(ClassificationResult.status == "reviewed").count(),
    }
def list_results(
    db: Session,
    project_id: Optional[int] = None,
    table_id: Optional[int] = None,
    status: Optional[str] = None,
    keyword: Optional[str] = None,
    page: int = 1,
    page_size: int = 50,
) -> Tuple[List[ClassificationResult], int]:
    """Page through classification results with optional filters.

    *table_id* and *keyword* both require joining DataColumn; the join is
    applied at most once (the previous code joined twice when both were
    supplied, producing a duplicate join).

    NOTE(review): the *status* parameter name shadows the imported fastapi
    ``status`` module inside this function; kept for API compatibility.
    """
    query = db.query(ClassificationResult)
    if project_id:
        query = query.filter(ClassificationResult.project_id == project_id)
    if table_id or keyword:
        query = query.join(DataColumn)
    if table_id:
        query = query.filter(DataColumn.table_id == table_id)
    if status:
        query = query.filter(ClassificationResult.status == status)
    if keyword:
        query = query.filter(
            (DataColumn.name.contains(keyword)) | (DataColumn.comment.contains(keyword))
        )
    total = query.count()
    items = query.offset((page - 1) * page_size).limit(page_size).all()
    return items, total
def update_result_label(
    db: Session,
    result_id: int,
    category_id: int,
    level_id: int,
    labeler_id: int,
) -> ClassificationResult:
    """Manually (re)label one classification result.

    Overwrites category/level, records who labeled it and when, and marks
    the result as manually sourced. Raises 404 when the result is missing.
    """
    from datetime import datetime

    result = db.query(ClassificationResult).filter(ClassificationResult.id == result_id).first()
    if not result:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="结果不存在")
    result.category_id = category_id
    result.level_id = level_id
    result.labeler_id = labeler_id
    result.source = "manual"
    result.status = "manual"
    # Plain local import replaces the previous `__import__('datetime')` hack.
    # NOTE(review): utcnow() is naive and deprecated in Python 3.12+;
    # datetime.now(timezone.utc) would store a tz-aware value -- confirm the
    # column type before switching.
    result.label_time = datetime.utcnow()
    db.commit()
    db.refresh(result)
    return result
+127
View File
@@ -0,0 +1,127 @@
from typing import Optional, List
from sqlalchemy.orm import Session
from fastapi import HTTPException, status
from app.models.user import User, Role, Dept, UserRole
from app.schemas.user import UserCreate, UserUpdate
from app.core.security import get_password_hash
def get_user_by_id(db: Session, user_id: int) -> Optional[User]:
    """Fetch one User by primary key, or None."""
    return db.query(User).filter_by(id=user_id).first()
def get_user_by_username(db: Session, username: str) -> Optional[User]:
    """Fetch one User by exact username, or None."""
    return db.query(User).filter_by(username=username).first()
def create_user(db: Session, obj_in: UserCreate) -> User:
    """Create a user (400 if the username is taken) and attach requested roles.

    The plaintext password is hashed before storage; role ids that do not
    correspond to an existing Role are silently ignored.
    """
    if get_user_by_username(db, obj_in.username) is not None:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名已存在")
    user = User(
        username=obj_in.username,
        email=obj_in.email,
        hashed_password=get_password_hash(obj_in.password),
        real_name=obj_in.real_name,
        phone=obj_in.phone,
        dept_id=obj_in.dept_id,
        is_active=obj_in.is_active,
    )
    db.add(user)
    db.commit()
    db.refresh(user)
    if obj_in.role_ids:
        for role_id in obj_in.role_ids:
            # Link only roles that actually exist.
            if db.query(Role).filter(Role.id == role_id).first():
                db.add(UserRole(user_id=user.id, role_id=role_id))
        db.commit()
        db.refresh(user)
    return user
def update_user(db: Session, db_obj: User, obj_in: UserUpdate) -> User:
    """Partially update a user; ``role_ids``, when provided, replaces the
    user's role set wholesale (an empty list removes all roles).

    NOTE(review): remaining fields are assigned verbatim -- if UserUpdate can
    carry a plaintext password it would be stored unhashed; confirm the schema.
    """
    changes = obj_in.model_dump(exclude_unset=True)
    new_role_ids = changes.pop("role_ids", None)
    for name, value in changes.items():
        setattr(db_obj, name, value)
    if new_role_ids is not None:
        # Replace all existing role links, keeping only ids that exist.
        db.query(UserRole).filter(UserRole.user_id == db_obj.id).delete()
        for role_id in new_role_ids:
            if db.query(Role).filter(Role.id == role_id).first():
                db.add(UserRole(user_id=db_obj.id, role_id=role_id))
    db.commit()
    db.refresh(db_obj)
    return db_obj
def delete_user(db: Session, user_id: int) -> None:
    """Delete a user; raises 404 when missing and 400 for superusers."""
    target = get_user_by_id(db, user_id)
    if target is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="用户不存在")
    if target.is_superuser:
        # Superusers are protected from deletion.
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="不能删除超级管理员")
    db.delete(target)
    db.commit()
def list_users(db: Session, keyword: Optional[str] = None, page: int = 1, page_size: int = 20):
    """Return one page of users plus the total match count.

    *keyword* is matched as a substring against username, real name or email.
    """
    query = db.query(User)
    if keyword:
        match = (
            User.username.contains(keyword)
            | User.real_name.contains(keyword)
            | User.email.contains(keyword)
        )
        query = query.filter(match)
    total = query.count()
    start = (page - 1) * page_size
    return query.offset(start).limit(page_size).all(), total
def create_initial_data(db: Session):
    """Idempotent bootstrap: seed default roles, the root department and the
    first superuser (credentials taken from settings).

    Safe to call on every startup -- each piece is only created when missing.
    """
    # Create default roles
    default_roles = [
        {"name": "超级管理员", "code": "superadmin", "description": "系统超级管理员"},
        {"name": "管理员", "code": "admin", "description": "系统管理员"},
        {"name": "项目负责人", "code": "project_manager", "description": "分类分级项目负责人"},
        {"name": "打标员", "code": "labeler", "description": "数据打标人员"},
        {"name": "审核员", "code": "reviewer", "description": "结果审核人员"},
        {"name": "访客", "code": "guest", "description": "只读访客"},
    ]
    for r in default_roles:
        # Roles are keyed by their unique code; skip ones that already exist.
        if not db.query(Role).filter(Role.code == r["code"]).first():
            db.add(Role(**r))
    # Create root dept
    if not db.query(Dept).filter(Dept.id == 1).first():
        db.add(Dept(id=1, name="根部门", parent_id=None, sort_order=0))
    db.commit()
    # Create superuser
    # Local import -- presumably to avoid a circular import at module load; confirm.
    from app.core.config import settings
    if not get_user_by_username(db, settings.FIRST_SUPERUSER_USERNAME):
        superuser = User(
            username=settings.FIRST_SUPERUSER_USERNAME,
            email=settings.FIRST_SUPERUSER_EMAIL,
            hashed_password=get_password_hash(settings.FIRST_SUPERUSER_PASSWORD),
            real_name="超级管理员",
            is_active=True,
            is_superuser=True,
            dept_id=1,
        )
        db.add(superuser)
        db.commit()
        db.refresh(superuser)
        # Link the new superuser to the "superadmin" role committed above.
        superadmin_role = db.query(Role).filter(Role.code == "superadmin").first()
        if superadmin_role:
            db.add(UserRole(user_id=superuser.id, role_id=superadmin_role.id))
            db.commit()
View File
+32
View File
@@ -0,0 +1,32 @@
from app.tasks.worker import celery_app
@celery_app.task(bind=True)
def auto_classify_task(self, project_id: int, source_ids: list = None):
    """
    Async task to run automatic classification on metadata.
    Phase 1 placeholder.

    Moves the project through "scanning" -> "assigning" and returns a small
    status dict; the actual rule-matching is not implemented yet, so
    ``matched`` is always 0 and ``source_ids`` is currently unused.

    NOTE(review): there is no except clause -- if an exception occurs
    mid-task the project is left in "scanning"; consider recording a
    failure status.
    """
    # Imports are local -- presumably to keep worker module import light /
    # avoid circular imports; confirm.
    from app.core.database import SessionLocal
    from app.models.project import ClassificationProject, ClassificationResult, ResultStatus
    from app.models.classification import RecognitionRule
    from app.models.metadata import DataColumn
    db = SessionLocal()
    try:
        project = db.query(ClassificationProject).filter(ClassificationProject.id == project_id).first()
        if not project:
            return {"status": "failed", "reason": "project not found"}
        # Update project status
        project.status = "scanning"
        db.commit()
        # Active rules are loaded but not yet applied (placeholder).
        rules = db.query(RecognitionRule).filter(RecognitionRule.is_active == True).all()
        # TODO: implement rule matching logic in Phase 2
        project.status = "assigning"
        db.commit()
        return {"status": "completed", "project_id": project_id, "matched": 0}
    finally:
        db.close()
+20
View File
@@ -0,0 +1,20 @@
from celery import Celery
from app.core.config import settings
# Celery application shared by the API process (to enqueue tasks) and the
# worker (to execute them). Redis serves as both broker and result backend.
celery_app = Celery(
    "prop_data_guard",
    broker=settings.REDIS_URL,
    backend=settings.REDIS_URL,
    include=["app.tasks.classification_tasks"],  # task modules to auto-import
)
celery_app.conf.update(
    task_serializer="json",
    accept_content=["json"],
    result_serializer="json",
    timezone="Asia/Shanghai",  # local timezone for schedules/display
    enable_utc=True,  # internal timestamps kept in UTC
    task_track_started=True,  # report STARTED state, not just PENDING/SUCCESS
    task_time_limit=3600,  # hard limit: kill tasks after 1 hour
    worker_prefetch_multiplier=1,  # fetch one task at a time -- fairer for long-running jobs
)
View File