143 lines
4.6 KiB
Python
143 lines
4.6 KiB
Python
"""试卷上传 + 处理管线"""
|
||
|
||
import asyncio
|
||
import threading
|
||
from fastapi import APIRouter, UploadFile, File, Form, HTTPException, Depends
|
||
from app.services.supabase_client import get_supabase
|
||
from app.services.text_extractor import extract_pdf, get_full_text
|
||
from app.services.paper_processor import process_paper
|
||
from app.dependencies.auth import get_current_user_id
|
||
|
||
# Router on which the paper endpoints in this module are registered.
router = APIRouter()
|
||
|
||
|
||
def _upload_and_process_sync(
    paper_id: str,
    storage_path: str,
    paper_bytes: bytes,
    answer_bytes: bytes | None,
):
    """Upload the PDFs to Storage, then run ``process_paper`` for the paper.

    Meant to be used as a ``threading.Thread`` target: everything here is
    synchronous, and the async ``process_paper`` coroutine is driven by a
    fresh event loop created with ``asyncio.run``.

    The Storage upload is best-effort. If it (or the follow-up row update)
    fails, the error is logged and processing still runs, because
    ``process_paper`` receives the in-memory bytes rather than the stored
    files.

    Args:
        paper_id: ``id`` of the ``papers`` row to update and process.
        storage_path: Folder prefix inside the ``papers`` bucket; files are
            written as ``<storage_path>/paper.pdf`` and
            ``<storage_path>/answer.pdf``.
        paper_bytes: Raw content of the exam paper PDF.
        answer_bytes: Raw content of the answer PDF; ``None`` or empty bytes
            skips the answer upload.
    """
    import logging  # function-scope import so the file needs no new top-level import

    log = logging.getLogger(__name__)
    sb = get_supabase()
    try:
        bucket = sb.storage.from_("papers")
        pdf_options = {"content-type": "application/pdf", "upsert": "true"}

        paper_storage_path = f"{storage_path}/paper.pdf"
        bucket.upload(paper_storage_path, paper_bytes, file_options=pdf_options)
        update_data: dict = {
            "paper_file_url": bucket.get_public_url(paper_storage_path),
        }

        if answer_bytes:
            answer_storage_path = f"{storage_path}/answer.pdf"
            bucket.upload(answer_storage_path, answer_bytes, file_options=pdf_options)
            update_data["answer_file_url"] = bucket.get_public_url(answer_storage_path)

        sb.table("papers").update(update_data).eq("id", paper_id).execute()
    except Exception:
        # Previously swallowed with `pass`; keep the best-effort behavior but
        # leave a traceback so missing file URLs can be diagnosed.
        log.exception("Storage upload failed for paper %s", paper_id)

    # process_paper is async, so run it in a new event loop owned by this thread.
    asyncio.run(process_paper(paper_id, paper_bytes, answer_bytes))
|
||
|
||
|
||
@router.get("/")
async def list_papers():
    """List papers, newest first (public assets shared by all users)."""
    columns = "id, course_code, year, term, exam_type, status, question_count, total_score, difficulty_level, processing_step, processing_progress, processing_total, created_at"
    query = get_supabase().table("papers").select(columns).order("created_at", desc=True)
    response = query.execute()
    return response.data
|
||
|
||
|
||
@router.get("/mine")
async def my_papers(user_id: str = Depends(get_current_user_id)):
    """List the papers uploaded by the current user, newest first.

    The selected columns include the processing status fields.
    """
    columns = "id, course_code, year, term, exam_type, part_label, status, question_count, processing_step, processing_progress, processing_total, created_at"
    own_papers = get_supabase().table("papers").select(columns).eq("user_id", user_id)
    response = own_papers.order("created_at", desc=True).execute()
    return response.data
|
||
|
||
|
||
@router.post("/upload")
async def upload_paper(
    paper_file: UploadFile = File(...),
    answer_file: UploadFile | None = File(None),
    course_code: str = Form(...),
    year: int = Form(...),
    term: str = Form(...),
    exam_type: str = Form(...),
    user_id: str = Depends(get_current_user_id),
):
    """Accept an exam paper PDF (plus optional answer PDF) and start processing.

    A ``papers`` row is inserted with ``status="processing"`` and the response
    is returned right away; the Storage upload and the processing pipeline run
    in a background thread.

    Args:
        paper_file: The exam paper upload (required).
        answer_file: Optional answer upload.
        course_code: Course code; stored upper-cased.
        year: Exam year.
        term: Exam term.
        exam_type: Exam type label.
        user_id: Id of the authenticated uploader, resolved by the auth dependency.

    Returns:
        A dict with ``paper_id``, ``status`` and a user-facing ``message``.

    Raises:
        HTTPException: 400 if the paper file is empty; 500 if the insert
            returned no row.
    """
    # 1. Read the uploads into memory.
    paper_bytes = await paper_file.read()
    if not paper_bytes:
        # Reject before touching the DB so no row is left stuck in "processing".
        raise HTTPException(status_code=400, detail="Paper file is empty")
    answer_bytes = await answer_file.read() if answer_file else None

    # 2. Create the record immediately (status=processing) so we can return fast.
    normalized_code = course_code.upper()
    storage_path = f"{normalized_code}/{year}_{term}_{exam_type}"
    sb = get_supabase()
    paper_record = sb.table("papers").insert({
        "user_id": user_id,
        "course_code": normalized_code,
        "year": year,
        "term": term,
        "exam_type": exam_type,
        "paper_file_url": "",  # filled in by the background thread after upload
        "answer_file_url": None,
        "status": "processing",
    }).execute()

    if not paper_record.data:
        # Without a returned row there is no id to hand to the worker thread.
        raise HTTPException(status_code=500, detail="Failed to create paper record")
    paper_id = paper_record.data[0]["id"]

    # 3. Run upload + processing in a separate thread so the event loop is not blocked.
    threading.Thread(
        target=_upload_and_process_sync,
        args=(paper_id, storage_path, paper_bytes, answer_bytes),
        daemon=True,
    ).start()

    return {
        "paper_id": paper_id,
        "status": "processing",
        "message": "试卷已上传,正在处理中...",
    }
|
||
|
||
|
||
@router.get("/{paper_id}")
async def get_paper(paper_id: str):
    """Return the full ``papers`` row for ``paper_id``, including its processing status.

    Raises:
        HTTPException: 404 when no row matches ``paper_id``.
    """
    rows = get_supabase().table("papers").select("*").eq("id", paper_id).execute().data
    if rows:
        return rows[0]
    raise HTTPException(status_code=404, detail="Paper not found")
|
||
|
||
|
||
@router.get("/{paper_id}/questions")
async def get_questions(paper_id: str):
    """Return every ``paper_questions`` row of a paper, ordered by ``display_order``.

    All columns are selected; per the original note this includes the
    AI-generated fields attached to each question.
    """
    questions = get_supabase().table("paper_questions").select("*")
    ordered = questions.eq("paper_id", paper_id).order("display_order")
    return ordered.execute().data
|