feat: Phase 3-5 - workflow, labeling, reports, dashboard enhancement, tests
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.api.v1 import auth, user, datasource, metadata, classification, project, task
|
||||
from app.api.v1 import auth, user, datasource, metadata, classification, project, task, report
|
||||
|
||||
api_router = APIRouter()
|
||||
api_router.include_router(auth.router, prefix="/auth", tags=["认证"])
|
||||
@@ -10,3 +10,4 @@ api_router.include_router(metadata.router, prefix="/metadata", tags=["元数据
|
||||
api_router.include_router(classification.router, prefix="/classifications", tags=["分类分级标准"])
|
||||
api_router.include_router(project.router, prefix="/projects", tags=["项目管理"])
|
||||
api_router.include_router(task.router, prefix="/tasks", tags=["任务管理"])
|
||||
api_router.include_router(report.router, prefix="/reports", tags=["报告管理"])
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
from fastapi import APIRouter, Depends, Response
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.models.user import User
|
||||
from app.services import report_service
|
||||
from app.api.deps import get_current_user
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/projects/{project_id}/download")
|
||||
def download_report(
|
||||
project_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
content = report_service.generate_classification_report(db, project_id)
|
||||
return Response(
|
||||
content=content,
|
||||
media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
headers={"Content-Disposition": f"attachment; filename=report_project_{project_id}.docx"},
|
||||
)
|
||||
+71
-41
@@ -6,6 +6,7 @@ from app.core.database import get_db
|
||||
from app.models.user import User
|
||||
from app.schemas.common import ResponseModel, ListResponse
|
||||
from app.api.deps import get_current_user
|
||||
from app.services import task_service, project_service
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -16,17 +17,15 @@ def my_tasks(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
from app.models.project import ClassificationTask
|
||||
query = db.query(ClassificationTask).filter(ClassificationTask.assignee_id == current_user.id)
|
||||
if status:
|
||||
query = query.filter(ClassificationTask.status == status)
|
||||
items = query.order_by(ClassificationTask.created_at.desc()).all()
|
||||
items, _ = task_service.list_tasks(db, assignee_id=current_user.id, status=status)
|
||||
data = []
|
||||
for t in items:
|
||||
project = project_service.get_project(db, t.project_id)
|
||||
data.append({
|
||||
"id": t.id,
|
||||
"name": t.name,
|
||||
"name": t.name or (project.name if project else f"任务#{t.id}"),
|
||||
"project_id": t.project_id,
|
||||
"project_name": project.name if project else None,
|
||||
"status": t.status,
|
||||
"deadline": t.deadline.isoformat() if t.deadline else None,
|
||||
"created_at": t.created_at.isoformat() if t.created_at else None,
|
||||
@@ -34,47 +33,78 @@ def my_tasks(
|
||||
return ResponseModel(data=data)
|
||||
|
||||
|
||||
@router.get("/my-tasks/{task_id}/items")
|
||||
def task_items(
|
||||
@router.post("/my-tasks/{task_id}/start")
|
||||
def start_task(
|
||||
task_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
from app.models.project import ClassificationTask, ClassificationResult
|
||||
from app.models.metadata import DataColumn, DataTable, Database as MetaDatabase, DataSource
|
||||
from app.models.classification import Category, DataLevel
|
||||
|
||||
task = db.query(ClassificationTask).filter(ClassificationTask.id == task_id).first()
|
||||
task = task_service.get_task(db, task_id)
|
||||
if not task:
|
||||
from fastapi import HTTPException, status
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="任务不存在")
|
||||
task = task_service.update_task_status(db, task, "in_progress")
|
||||
return ResponseModel(data={"id": task.id, "status": task.status})
|
||||
|
||||
results = db.query(ClassificationResult).filter(
|
||||
ClassificationResult.project_id == task.project_id,
|
||||
).join(DataColumn).all()
|
||||
|
||||
data = []
|
||||
for r in results:
|
||||
col = r.column
|
||||
table = col.table if col else None
|
||||
database = table.database if table else None
|
||||
source = database.source if database else None
|
||||
data.append({
|
||||
"result_id": r.id,
|
||||
"column_id": col.id if col else None,
|
||||
"column_name": col.name if col else None,
|
||||
"data_type": col.data_type if col else None,
|
||||
"comment": col.comment if col else None,
|
||||
"table_name": table.name if table else None,
|
||||
"database_name": database.name if database else None,
|
||||
"source_name": source.name if source else None,
|
||||
"category_id": r.category_id,
|
||||
"category_name": r.category.name if r.category else None,
|
||||
"level_id": r.level_id,
|
||||
"level_name": r.level.name if r.level else None,
|
||||
"level_color": r.level.color if r.level else None,
|
||||
"source": r.source,
|
||||
"confidence": r.confidence,
|
||||
"status": r.status,
|
||||
})
|
||||
return ResponseModel(data=data)
|
||||
@router.post("/my-tasks/{task_id}/complete")
|
||||
def complete_task(
|
||||
task_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
task = task_service.get_task(db, task_id)
|
||||
if not task:
|
||||
from fastapi import HTTPException, status
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="任务不存在")
|
||||
task = task_service.update_task_status(db, task, "completed")
|
||||
return ResponseModel(data={"id": task.id, "status": task.status})
|
||||
|
||||
|
||||
@router.get("/my-tasks/{task_id}/items")
|
||||
def task_items(
|
||||
task_id: int,
|
||||
keyword: Optional[str] = Query(None),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
task = task_service.get_task(db, task_id)
|
||||
if not task:
|
||||
from fastapi import HTTPException, status
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="任务不存在")
|
||||
items = task_service.get_task_label_items(db, task.project_id, keyword=keyword)
|
||||
return ResponseModel(data=items)
|
||||
|
||||
|
||||
@router.post("/results/{result_id}/label")
|
||||
def label_result(
|
||||
result_id: int,
|
||||
category_id: int,
|
||||
level_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
result = project_service.update_result_label(db, result_id, category_id, level_id, current_user.id)
|
||||
return ResponseModel(data={
|
||||
"result_id": result.id,
|
||||
"category_id": result.category_id,
|
||||
"level_id": result.level_id,
|
||||
"status": result.status,
|
||||
})
|
||||
|
||||
|
||||
@router.post("/projects/{project_id}/create-task")
|
||||
def create_task_for_project(
|
||||
project_id: int,
|
||||
name: str,
|
||||
assignee_id: int,
|
||||
target_type: str = Query("column"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
task = task_service.create_task(
|
||||
db, project_id=project_id, name=name,
|
||||
assigner_id=current_user.id, assignee_id=assignee_id,
|
||||
target_type=target_type,
|
||||
)
|
||||
return ResponseModel(data={"id": task.id, "name": task.name})
|
||||
|
||||
@@ -0,0 +1,96 @@
|
||||
from io import BytesIO
|
||||
from typing import Optional
|
||||
from sqlalchemy.orm import Session
|
||||
from datetime import datetime
|
||||
|
||||
from docx import Document
|
||||
from docx.shared import Inches, Pt, RGBColor
|
||||
from docx.enum.text import WD_ALIGN_PARAGRAPH
|
||||
|
||||
from app.models.project import ClassificationProject, ClassificationResult
|
||||
from app.models.classification import Category, DataLevel
|
||||
|
||||
|
||||
def generate_classification_report(db: Session, project_id: int) -> bytes:
|
||||
"""Generate a Word report for a classification project."""
|
||||
project = db.query(ClassificationProject).filter(ClassificationProject.id == project_id).first()
|
||||
if not project:
|
||||
raise ValueError("项目不存在")
|
||||
|
||||
doc = Document()
|
||||
|
||||
# Title
|
||||
title = doc.add_heading('数据分类分级项目报告', 0)
|
||||
title.alignment = WD_ALIGN_PARAGRAPH.CENTER
|
||||
|
||||
# Basic info
|
||||
doc.add_heading('一、项目基本信息', level=1)
|
||||
info_table = doc.add_table(rows=4, cols=2)
|
||||
info_table.style = 'Light Grid Accent 1'
|
||||
info_data = [
|
||||
('项目名称', project.name),
|
||||
('报告生成时间', datetime.now().strftime('%Y-%m-%d %H:%M:%S')),
|
||||
('项目状态', project.status),
|
||||
('模板版本', project.template.version if project.template else 'N/A'),
|
||||
]
|
||||
for i, (k, v) in enumerate(info_data):
|
||||
info_table.rows[i].cells[0].text = k
|
||||
info_table.rows[i].cells[1].text = str(v)
|
||||
|
||||
# Statistics
|
||||
doc.add_heading('二、分类分级统计', level=1)
|
||||
results = db.query(ClassificationResult).filter(ClassificationResult.project_id == project_id).all()
|
||||
|
||||
total = len(results)
|
||||
auto_count = sum(1 for r in results if r.source == 'auto')
|
||||
manual_count = sum(1 for r in results if r.source == 'manual')
|
||||
|
||||
level_stats = {}
|
||||
for r in results:
|
||||
if r.level:
|
||||
level_stats[r.level.name] = level_stats.get(r.level.name, 0) + 1
|
||||
|
||||
doc.add_paragraph(f'总字段数: {total}')
|
||||
doc.add_paragraph(f'自动识别: {auto_count}')
|
||||
doc.add_paragraph(f'人工打标: {manual_count}')
|
||||
|
||||
doc.add_heading('三、分级分布', level=1)
|
||||
level_table = doc.add_table(rows=1, cols=3)
|
||||
level_table.style = 'Light Grid Accent 1'
|
||||
hdr_cells = level_table.rows[0].cells
|
||||
hdr_cells[0].text = '分级'
|
||||
hdr_cells[1].text = '数量'
|
||||
hdr_cells[2].text = '占比'
|
||||
for level_name, count in sorted(level_stats.items(), key=lambda x: -x[1]):
|
||||
row_cells = level_table.add_row().cells
|
||||
row_cells[0].text = level_name
|
||||
row_cells[1].text = str(count)
|
||||
row_cells[2].text = f'{count / total * 100:.1f}%' if total > 0 else '0%'
|
||||
|
||||
# High risk data
|
||||
doc.add_heading('四、高敏感数据清单(L4/L5)', level=1)
|
||||
high_risk = [r for r in results if r.level and r.level.code in ('L4', 'L5')]
|
||||
if high_risk:
|
||||
risk_table = doc.add_table(rows=1, cols=5)
|
||||
risk_table.style = 'Light Grid Accent 1'
|
||||
hdr = risk_table.rows[0].cells
|
||||
hdr[0].text = '字段名'
|
||||
hdr[1].text = '所属表'
|
||||
hdr[2].text = '分类'
|
||||
hdr[3].text = '分级'
|
||||
hdr[4].text = '来源'
|
||||
for r in high_risk[:100]: # limit to 100 rows
|
||||
row = risk_table.add_row().cells
|
||||
row[0].text = r.column.name if r.column else 'N/A'
|
||||
row[1].text = r.column.table.name if r.column and r.column.table else 'N/A'
|
||||
row[2].text = r.category.name if r.category else 'N/A'
|
||||
row[3].text = r.level.name if r.level else 'N/A'
|
||||
row[4].text = '自动' if r.source == 'auto' else '人工'
|
||||
else:
|
||||
doc.add_paragraph('暂无L4/L5级高敏感数据。')
|
||||
|
||||
# Save to bytes
|
||||
buffer = BytesIO()
|
||||
doc.save(buffer)
|
||||
buffer.seek(0)
|
||||
return buffer.read()
|
||||
@@ -0,0 +1,122 @@
|
||||
from typing import Optional, List, Tuple
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
from app.models.project import ClassificationTask, ClassificationProject, ClassificationResult, TaskStatus, ResultStatus
|
||||
from app.models.metadata import DataColumn, DataTable, Database as MetaDatabase
|
||||
|
||||
|
||||
def get_task(db: Session, task_id: int) -> Optional[ClassificationTask]:
|
||||
return db.query(ClassificationTask).filter(ClassificationTask.id == task_id).first()
|
||||
|
||||
|
||||
def list_tasks(
|
||||
db: Session,
|
||||
project_id: Optional[int] = None,
|
||||
assignee_id: Optional[int] = None,
|
||||
status: Optional[str] = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
) -> Tuple[List[ClassificationTask], int]:
|
||||
query = db.query(ClassificationTask)
|
||||
if project_id:
|
||||
query = query.filter(ClassificationTask.project_id == project_id)
|
||||
if assignee_id:
|
||||
query = query.filter(ClassificationTask.assignee_id == assignee_id)
|
||||
if status:
|
||||
query = query.filter(ClassificationTask.status == status)
|
||||
total = query.count()
|
||||
items = query.order_by(ClassificationTask.created_at.desc()).offset((page - 1) * page_size).limit(page_size).all()
|
||||
return items, total
|
||||
|
||||
|
||||
def create_task(
|
||||
db: Session,
|
||||
project_id: int,
|
||||
name: str,
|
||||
assigner_id: int,
|
||||
assignee_id: int,
|
||||
target_type: str = "column",
|
||||
target_ids: Optional[str] = None,
|
||||
deadline: Optional[str] = None,
|
||||
) -> ClassificationTask:
|
||||
from datetime import datetime
|
||||
db_obj = ClassificationTask(
|
||||
project_id=project_id,
|
||||
name=name,
|
||||
assigner_id=assigner_id,
|
||||
assignee_id=assignee_id,
|
||||
target_type=target_type,
|
||||
target_ids=target_ids,
|
||||
status=TaskStatus.PENDING.value,
|
||||
deadline=datetime.fromisoformat(deadline) if deadline else None,
|
||||
)
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
return db_obj
|
||||
|
||||
|
||||
def update_task_status(db: Session, task: ClassificationTask, status: str) -> ClassificationTask:
|
||||
task.status = status
|
||||
if status == TaskStatus.COMPLETED.value:
|
||||
from datetime import datetime
|
||||
task.completed_at = datetime.utcnow()
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
return task
|
||||
|
||||
|
||||
def assign_columns_to_task(db: Session, project_id: int, task_id: int, column_ids: List[int]) -> None:
|
||||
"""Assign columns to a task by creating/updating classification results."""
|
||||
from app.services.project_service import list_results
|
||||
for col_id in column_ids:
|
||||
result = db.query(ClassificationResult).filter(
|
||||
ClassificationResult.project_id == project_id,
|
||||
ClassificationResult.column_id == col_id,
|
||||
).first()
|
||||
if not result:
|
||||
result = ClassificationResult(
|
||||
project_id=project_id,
|
||||
column_id=col_id,
|
||||
status=ResultStatus.AUTO.value,
|
||||
source="auto",
|
||||
confidence=0.0,
|
||||
)
|
||||
db.add(result)
|
||||
db.commit()
|
||||
|
||||
|
||||
def get_task_label_items(db: Session, project_id: int, keyword: Optional[str] = None) -> List[dict]:
|
||||
"""Get all label items for a project (used in task labeling view)."""
|
||||
query = db.query(ClassificationResult).filter(ClassificationResult.project_id == project_id)
|
||||
results = query.all()
|
||||
|
||||
items = []
|
||||
for r in results:
|
||||
col = r.column
|
||||
if not col:
|
||||
continue
|
||||
table = col.table
|
||||
database = table.database if table else None
|
||||
source = database.source if database else None
|
||||
|
||||
items.append({
|
||||
"result_id": r.id,
|
||||
"column_id": col.id,
|
||||
"column_name": col.name,
|
||||
"data_type": col.data_type,
|
||||
"comment": col.comment,
|
||||
"table_name": table.name if table else None,
|
||||
"database_name": database.name if database else None,
|
||||
"source_name": source.name if source else None,
|
||||
"category_id": r.category_id,
|
||||
"category_name": r.category.name if r.category else None,
|
||||
"level_id": r.level_id,
|
||||
"level_name": r.level.name if r.level else None,
|
||||
"level_color": r.level.color if r.level else None,
|
||||
"source": r.source,
|
||||
"confidence": r.confidence,
|
||||
"status": r.status,
|
||||
})
|
||||
return items
|
||||
@@ -388,3 +388,65 @@ INFO: 127.0.0.1:63058 - "GET /api/v1/classifications/categories/tree HTTP/1.
|
||||
INFO: 127.0.0.1:63065 - "GET /api/v1/projects?page=1&page_size=100 HTTP/1.1" 200 OK
|
||||
INFO: 127.0.0.1:63067 - "GET /api/v1/classifications/levels HTTP/1.1" 200 OK
|
||||
INFO: 127.0.0.1:63069 - "GET /api/v1/projects?page=1&page_size=50 HTTP/1.1" 200 OK
|
||||
WARNING: WatchFiles detected changes in 'app/services/task_service.py'. Reloading...
|
||||
INFO: Shutting down
|
||||
INFO: Waiting for application shutdown.
|
||||
INFO: Application shutdown complete.
|
||||
INFO: Finished server process [27900]
|
||||
/Users/nathan/Work/DataPointer/prop-data-guard/backend/.venv/lib/python3.9/site-packages/urllib3/__init__.py:35: NotOpenSSLWarning: urllib3 v2 only supports OpenSSL 1.1.1+, currently the 'ssl' module is compiled with 'LibreSSL 2.8.3'. See: https://github.com/urllib3/urllib3/issues/3020
|
||||
warnings.warn(
|
||||
INFO: Started server process [32965]
|
||||
INFO: Waiting for application startup.
|
||||
INFO: Application startup complete.
|
||||
WARNING: WatchFiles detected changes in 'app/api/v1/task.py'. Reloading...
|
||||
INFO: Shutting down
|
||||
INFO: Waiting for application shutdown.
|
||||
INFO: Application shutdown complete.
|
||||
INFO: Finished server process [32965]
|
||||
/Users/nathan/Work/DataPointer/prop-data-guard/backend/.venv/lib/python3.9/site-packages/urllib3/__init__.py:35: NotOpenSSLWarning: urllib3 v2 only supports OpenSSL 1.1.1+, currently the 'ssl' module is compiled with 'LibreSSL 2.8.3'. See: https://github.com/urllib3/urllib3/issues/3020
|
||||
warnings.warn(
|
||||
INFO: Started server process [33004]
|
||||
INFO: Waiting for application startup.
|
||||
INFO: Application startup complete.
|
||||
WARNING: WatchFiles detected changes in 'app/services/report_service.py'. Reloading...
|
||||
INFO: Shutting down
|
||||
INFO: Waiting for application shutdown.
|
||||
INFO: Application shutdown complete.
|
||||
INFO: Finished server process [33004]
|
||||
/Users/nathan/Work/DataPointer/prop-data-guard/backend/.venv/lib/python3.9/site-packages/urllib3/__init__.py:35: NotOpenSSLWarning: urllib3 v2 only supports OpenSSL 1.1.1+, currently the 'ssl' module is compiled with 'LibreSSL 2.8.3'. See: https://github.com/urllib3/urllib3/issues/3020
|
||||
warnings.warn(
|
||||
INFO: Started server process [33876]
|
||||
INFO: Waiting for application startup.
|
||||
INFO: Application startup complete.
|
||||
WARNING: WatchFiles detected changes in 'app/api/v1/report.py'. Reloading...
|
||||
INFO: Shutting down
|
||||
INFO: Waiting for application shutdown.
|
||||
INFO: Application shutdown complete.
|
||||
INFO: Finished server process [33876]
|
||||
/Users/nathan/Work/DataPointer/prop-data-guard/backend/.venv/lib/python3.9/site-packages/urllib3/__init__.py:35: NotOpenSSLWarning: urllib3 v2 only supports OpenSSL 1.1.1+, currently the 'ssl' module is compiled with 'LibreSSL 2.8.3'. See: https://github.com/urllib3/urllib3/issues/3020
|
||||
warnings.warn(
|
||||
INFO: Started server process [33886]
|
||||
INFO: Waiting for application startup.
|
||||
INFO: Application startup complete.
|
||||
WARNING: WatchFiles detected changes in 'app/api/v1/__init__.py'. Reloading...
|
||||
INFO: Shutting down
|
||||
INFO: Waiting for application shutdown.
|
||||
INFO: Application shutdown complete.
|
||||
INFO: Finished server process [33886]
|
||||
/Users/nathan/Work/DataPointer/prop-data-guard/backend/.venv/lib/python3.9/site-packages/urllib3/__init__.py:35: NotOpenSSLWarning: urllib3 v2 only supports OpenSSL 1.1.1+, currently the 'ssl' module is compiled with 'LibreSSL 2.8.3'. See: https://github.com/urllib3/urllib3/issues/3020
|
||||
warnings.warn(
|
||||
INFO: Started server process [34876]
|
||||
INFO: Waiting for application startup.
|
||||
INFO: Application startup complete.
|
||||
WARNING: WatchFiles detected changes in 'tests/test_auth.py'. Reloading...
|
||||
INFO: Shutting down
|
||||
INFO: Waiting for application shutdown.
|
||||
INFO: Application shutdown complete.
|
||||
INFO: Finished server process [34876]
|
||||
/Users/nathan/Work/DataPointer/prop-data-guard/backend/.venv/lib/python3.9/site-packages/urllib3/__init__.py:35: NotOpenSSLWarning: urllib3 v2 only supports OpenSSL 1.1.1+, currently the 'ssl' module is compiled with 'LibreSSL 2.8.3'. See: https://github.com/urllib3/urllib3/issues/3020
|
||||
warnings.warn(
|
||||
INFO: Started server process [35542]
|
||||
INFO: Waiting for application startup.
|
||||
INFO: Application startup complete.
|
||||
MinIO init warning: HTTPConnectionPool(host='localhost', port=9000): Max retries exceeded with url: /pdg-files?location= (Caused by NewConnectionError("HTTPConnection(host='localhost', port=9000): Failed to establish a new connection: [Errno 61] Connection refused"))
|
||||
INFO: 127.0.0.1:50646 - "GET /health HTTP/1.1" 200 OK
|
||||
|
||||
@@ -0,0 +1,63 @@
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from app.main import app
|
||||
from app.core.database import Base, get_db
|
||||
from app.core.config import settings
|
||||
|
||||
# Use SQLite for testing
|
||||
SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db"
|
||||
engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False})
|
||||
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
|
||||
def override_get_db():
|
||||
db = TestingSessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def setup_db():
|
||||
Base.metadata.create_all(bind=engine)
|
||||
yield
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
|
||||
|
||||
def test_health_check():
|
||||
response = client.get("/health")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["status"] == "ok"
|
||||
|
||||
|
||||
def test_login():
|
||||
response = client.post("/api/v1/auth/login", json={"username": "admin", "password": "admin123"})
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["code"] == 200
|
||||
assert "access_token" in data["data"]
|
||||
return data["data"]["access_token"]
|
||||
|
||||
|
||||
def test_get_me():
|
||||
token = test_login()
|
||||
response = client.get("/api/v1/users/me", headers={"Authorization": f"Bearer {token}"})
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["data"]["username"] == "admin"
|
||||
|
||||
|
||||
def test_list_levels():
|
||||
token = test_login()
|
||||
response = client.get("/api/v1/classifications/levels", headers={"Authorization": f"Bearer {token}"})
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["data"]) == 5
|
||||
Reference in New Issue
Block a user