Initial commit: PastPaper Master full stack
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
0
backend/app/__init__.py
Normal file
0
backend/app/__init__.py
Normal file
36
backend/app/config.py
Normal file
36
backend/app/config.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from pydantic_settings import BaseSettings
|
||||
from functools import lru_cache
|
||||
import os
|
||||
|
||||
|
||||
class Settings(BaseSettings):
    """Application configuration loaded from the environment / ``.env`` file.

    The three Supabase fields carry no default; every LLM provider key
    defaults to an empty string and every provider base URL to its public
    endpoint.
    """

    # --- Supabase ---
    supabase_url: str
    supabase_anon_key: str
    supabase_service_role_key: str

    # --- LLM: laozhang (gpt-4o, gpt-4o-mini) ---
    laozhang_base_url: str = "https://api.laozhang.ai/v1"
    laozhang_api_key: str = ""

    # --- LLM: DashScope (qwen-plus) ---
    dashscope_base_url: str = "https://dashscope.aliyuncs.com/compatible-mode/v1"
    dashscope_api_key: str = ""

    # --- LLM: DeepSeek ---
    deepseek_base_url: str = "https://api.deepseek.com/v1"
    deepseek_api_key: str = ""

    # --- Google Gemini (official) ---
    google_gemini_api_key: str = ""

    # The .env file is looked up two directories above this module;
    # unknown variables in it are ignored.
    model_config = dict(
        env_file=os.path.join(os.path.dirname(__file__), "../../.env"),
        env_file_encoding="utf-8",
        extra="ignore",
    )
|
||||
|
||||
|
||||
@lru_cache
def get_settings() -> Settings:
    """Return the shared ``Settings`` object.

    ``lru_cache`` memoizes the first instance, so later calls return the
    same object instead of building a new one.
    """
    settings = Settings()
    return settings
|
||||
0
backend/app/dependencies/__init__.py
Normal file
0
backend/app/dependencies/__init__.py
Normal file
34
backend/app/dependencies/auth.py
Normal file
34
backend/app/dependencies/auth.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""Auth dependency: validate Supabase JWT and return user_id"""
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from app.services.supabase_client import get_supabase
|
||||
|
||||
bearer_scheme = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
async def get_current_user_id(
    credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme),
) -> str:
    """Validate the Bearer token against Supabase and return the user's id.

    Args:
        credentials: Parsed ``Authorization: Bearer ...`` header, or ``None``
            when the header is absent (``bearer_scheme`` has
            ``auto_error=False``).

    Returns:
        The ``id`` of the user Supabase resolved for the token.

    Raises:
        HTTPException: 401 "Not authenticated" when no credentials were sent,
            401 "Invalid or expired token" when the Supabase lookup fails,
            401 "Invalid token" when the lookup succeeds but yields no user.
    """
    if not credentials:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Not authenticated",
        )

    token = credentials.credentials
    sb = get_supabase()

    # Only the remote call lives inside the try block, so the HTTPException
    # raised below is not caught and rewritten by the generic handler.
    try:
        result = sb.auth.get_user(token)
    except Exception as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired token",
        ) from exc

    # ``result`` itself may be None; treat that the same as "no user".
    user = getattr(result, "user", None)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token",
        )
    return user.id
|
||||
59
backend/app/main.py
Normal file
59
backend/app/main.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import asyncio
|
||||
import threading
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from app.routers import analytics, papers, attempts, questions
|
||||
|
||||
|
||||
def _resume_stale_papers():
    """On startup, find papers stuck in ``processing`` and resume their AI pipeline.

    Each stale paper is re-submitted to ``process_paper`` with empty paper
    bytes and no answer bytes, on its own daemon thread running a private
    event loop. Any failure while looking up or scheduling the papers is
    printed and otherwise ignored, so startup is never blocked by it.
    """
    try:
        # Imported lazily, inside the try, so an import failure is also
        # reported as "Resume skipped" instead of aborting startup.
        from app.services.supabase_client import get_supabase
        from app.services.paper_processor import process_paper

        client = get_supabase()
        response = client.table("papers").select("id").eq("status", "processing").execute()
        pending = response.data
        if not pending:
            return

        for row in pending:
            paper_id = row["id"]
            print(f"[STARTUP] Resuming processing for paper {paper_id[:8]}...")

            # Bind the id as a default argument so each thread keeps its own
            # value rather than the loop variable's final one.
            def _worker(pid=paper_id):
                asyncio.run(process_paper(pid, b"", None))

            worker_thread = threading.Thread(target=_worker, daemon=True)
            worker_thread.start()
    except Exception as e:
        print(f"[STARTUP] Resume skipped: {e}")
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook.

    Before the app starts serving, stale papers are re-queued; there is no
    shutdown work.
    """
    _resume_stale_papers()  # startup phase
    yield
    # Shutdown phase: nothing to clean up.
|
||||
|
||||
|
||||
app = FastAPI(title="PastPaper Master API", version="0.1.0", lifespan=lifespan)

# CORS is left wide open for the development phase; tighten before release.
# NOTE(review): a wildcard origin combined with allow_credentials=True is
# worth revisiting when the origin list is tightened.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)

# Mount every router under /api/<name>, tagged with the same name.
# Registration order is kept: papers, attempts, questions, analytics.
for _route_name, _route_module in (
    ("papers", papers),
    ("attempts", attempts),
    ("questions", questions),
    ("analytics", analytics),
):
    app.include_router(
        _route_module.router,
        prefix=f"/api/{_route_name}",
        tags=[_route_name],
    )
|
||||
|
||||
|
||||
@app.get("/health")
def health():
    """Liveness probe: always answers with a static ``ok`` status."""
    payload = {"status": "ok"}
    return payload
|
||||
0
backend/app/routers/__init__.py
Normal file
0
backend/app/routers/__init__.py
Normal file
285
backend/app/routers/analytics.py
Normal file
285
backend/app/routers/analytics.py
Normal file
@@ -0,0 +1,285 @@
|
||||
"""Course-level analytics endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import Counter, defaultdict
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.services.supabase_client import get_supabase
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# Numeric weight of each difficulty level (used to average difficulty).
DIFFICULTY_SCORE = dict(easy=1, medium=2, hard=3)
# Reverse lookup: weight -> capitalized display label.
DIFFICULTY_LABEL = {
    weight: level.capitalize() for level, weight in DIFFICULTY_SCORE.items()
}
|
||||
|
||||
# ── Topic normalization ──────────────────────────────────────
# Map variant spellings (lower-cased) to their canonical label.
_TOPIC_ALIASES: dict[str, str] = {
    "numpy": "NumPy",
    "naïve bayes": "Naive Bayes",
    "naïve bayes classifier": "Naive Bayes",
    "naive bayes classifier": "Naive Bayes",
    "bayes classifier": "Naive Bayes",
    "bayes model": "Naive Bayes",
    "bayes' theorem": "Naive Bayes",
    "bayes' rule": "Naive Bayes",
    "k-nearest neighbors": "K-Nearest Neighbors (KNN)",
    "knn": "K-Nearest Neighbors (KNN)",
    "k-means clustering": "K-Means Clustering",
    "k-means": "K-Means Clustering",
    "k means": "K-Means Clustering",
    "multilayer perceptron": "Multilayer Perceptron (MLP)",
    "multi-layer perceptron": "Multilayer Perceptron (MLP)",
    "multi-layer perceptron (mlp)": "Multilayer Perceptron (MLP)",
    "mlp": "Multilayer Perceptron (MLP)",
    "single layer perceptron": "Perceptron",
    "convolutional neural network": "CNN",
    "convolutional neural network (cnn)": "CNN",
    "convolutional neural networks": "CNN",
    "cnn architecture": "CNN",
    "cnn properties": "CNN",
    "python fundamentals": "Python",
    "python programming": "Python",
    "python implementation": "Python",
    "advanced python programming": "Python",
    "python programming: convolutional neural network": "CNN",
    "cross-validation": "Cross Validation",
    "model evaluation implementation": "Model Evaluation",
    "digital image processing": "Image Processing",
    "computer vision": "Image Processing",
    "array slicing": "Array Slicing",
    "slicing": "Array Slicing",
    "array indexing": "Array Slicing",
    "array reshaping": "Reshape",
    "array views": "Array Slicing",
    "view vs copy": "Array Slicing",
    "boolean indexing": "Array Slicing",
    "arange": "NumPy",
    "newaxis": "NumPy",
    "expand dims": "NumPy",
    "transpose": "NumPy",
    "type casting": "NumPy",
    "element-wise operation": "NumPy",
    "array reduction": "NumPy",
    "multi-dimensional array": "NumPy",
    "dot product": "NumPy",
    "vectorization": "NumPy",
    "activation functions": "Activation Function",
    "linear activation function": "Activation Function",
    "neural network architecture": "Neural Networks",
    "hidden layer": "Neural Networks",
    "deep learning": "Neural Networks",
    "deep learning frameworks": "Neural Networks",
    "alpha-beta pruning": "Alpha-Beta Pruning",
    "minimax algorithm": "Minimax",
    "ethics of ai": "AI Ethics",
    "ethics": "AI Ethics",
    "cosine distance": "Cosine Similarity",
    "distance calculation": "Distance Metrics",
    "euclidean distance": "Distance Metrics",
    "manhattan distance": "Distance Metrics",
    "hamming distance": "Distance Metrics",
    "precision": "Model Evaluation",
    "recall": "Model Evaluation",
    "f1 score": "Model Evaluation",
    "macro f1 score": "Model Evaluation",
    "accuracy": "Model Evaluation",
    "classification accuracy": "Model Evaluation",
    "confusion matrix": "Model Evaluation",
    "convolution operation": "Convolution",
    "dilated convolution": "Convolution",
    "3d convolution": "Convolution",
    "gaussian likelihood": "Probability",
    "gaussian distribution": "Probability",
    "categorical likelihood": "Probability",
    "conditional probability": "Probability",
    "total probability theorem": "Probability",
    "probability assumptions": "Probability",
    "tensorflow": "Keras",
    "model summary": "Keras",
    "model construction": "Keras",
    "trainable parameters": "Parameter Calculation",
    "parameter reduction": "Parameter Calculation",
    "output shape calculation": "Parameter Calculation",
    "shape calculation": "Parameter Calculation",
}

# Lower-cased canonical label -> canonical label, so that case variants of a
# canonical name ("naive bayes", "cnn") land in the same bucket as the
# canonical spelling even without an explicit alias entry.
_CANONICAL_TOPICS: dict[str, str] = {
    canonical.lower(): canonical for canonical in _TOPIC_ALIASES.values()
}


def normalize_topic(label: str) -> str:
    """Map a raw topic label to its canonical form.

    Lookup is case-insensitive and ignores surrounding whitespace. Explicit
    aliases win; otherwise a case variant of a known canonical label is
    mapped to that canonical label; otherwise the label is returned with
    surrounding whitespace removed (so ``" Foo "`` and ``"Foo"`` are not
    counted as different topics).

    Args:
        label: Raw topic label.

    Returns:
        The canonical label, or the stripped input when nothing matches.
    """
    cleaned = label.strip()
    key = cleaned.lower()
    alias = _TOPIC_ALIASES.get(key)
    if alias is not None:
        return alias
    return _CANONICAL_TOPICS.get(key, cleaned)
|
||||
|
||||
|
||||
def extract_topic_labels(question: dict) -> list[str]:
    """Collect the normalized, de-duplicated topic labels of a question.

    Sources are read in priority order: ``analytics_topic`` first, then
    ``topic_tags``. The legacy ``topics`` field is consulted only when the
    first two yield nothing. Each raw label goes through ``normalize_topic``
    and the first occurrence of each normalized label is kept, in order.
    """
    candidates: list[str] = []

    def _collect(values) -> None:
        # Append truthy values that have not been collected yet.
        for value in values or []:
            if value and value not in candidates:
                candidates.append(value)

    primary = question.get("analytics_topic")
    if primary:
        candidates.append(primary)
    _collect(question.get("topic_tags"))

    if not candidates:
        _collect(question.get("topics"))

    # dict keys keep insertion order, giving an order-preserving de-dup.
    return list(dict.fromkeys(normalize_topic(raw) for raw in candidates))
|
||||
|
||||
|
||||
def extract_question_family(question: dict) -> str:
    """Return the question's family: its format, else its type, else ``"unknown"``.

    A field counts only when its value is truthy.
    """
    for field in ("question_format", "question_type"):
        value = question.get(field)
        if value:
            return value
    return "unknown"
|
||||
|
||||
|
||||
@router.get("/courses")
async def list_courses():
    """Return the sorted course codes that have at least one paper in ``ready`` status."""
    client = get_supabase()
    query = client.table("papers").select("course_code").eq("status", "ready")
    ready_rows = query.execute().data

    unique_codes = set()
    for row in ready_rows:
        code = row.get("course_code")
        if code:  # skip rows with a missing/empty course code
            unique_codes.add(code)
    return sorted(unique_codes)
|
||||
|
||||
|
||||
def _paper_source_label(paper: dict) -> str:
    """Build the display label of a paper row, e.g. ``2023 Fall Final Part A``.

    Missing or ``None`` fields render as empty text instead of raising
    (``None.title()``) or printing the literal ``None``.
    """
    year = paper.get("year")
    year_text = "" if year is None else year
    term_text = (paper.get("term") or "").title()
    exam_text = (paper.get("exam_type") or "").title()
    label = f"{year_text} {term_text} {exam_text}".strip()
    if paper.get("part_label"):
        label = f"{label} Part {paper['part_label']}"
    return label


def _share_pct(count: int, total: int) -> int:
    """Return ``count`` as a rounded percentage of ``total`` (0 when ``total`` is 0)."""
    return round((count / total) * 100) if total else 0


@router.get("/course/{course_code}")
async def get_course_analytics(course_code: str):
    """Aggregate analytics over all ``ready`` papers of one course.

    Args:
        course_code: Course code; matched upper-cased.

    Returns:
        A dict with ``course_code``, ``kpi`` (paper/question/topic counts and
        average difficulty label), ``topic_frequency`` (per-topic count,
        percentage of all questions, and the matching question items),
        ``question_types``, ``all_questions``, ``difficulty_distribution``
        and ``high_yield_topics`` (the five most frequent topics). The same
        keys are present, with empty values, when the course has no ready
        papers.
    """
    sb = get_supabase()
    code = course_code.upper()

    papers = (
        sb.table("papers")
        .select("id, course_code, year, term, exam_type, part_label, status")
        .eq("course_code", code)
        .eq("status", "ready")
        .order("year", desc=True)
        .execute()
        .data
    )
    if not papers:
        # Same shape as the populated response, including ``all_questions``.
        return {
            "course_code": code,
            "kpi": {"papers": 0, "questions": 0, "topics": 0, "difficulty": "N/A"},
            "topic_frequency": [],
            "question_types": [],
            "all_questions": [],
            "difficulty_distribution": {"easy": 0, "medium": 0, "hard": 0},
            "high_yield_topics": [],
        }

    paper_ids = [paper["id"] for paper in papers]
    questions = (
        sb.table("paper_questions")
        .select(
            "id, paper_id, question_number, question_type, question_format, "
            "question_text, score, topics, analytics_topic, topic_tags, difficulty"
        )
        .in_("paper_id", paper_ids)
        .order("display_order")
        .execute()
        .data
    )

    papers_by_id = {paper["id"]: paper for paper in papers}
    total_questions = len(questions)
    topic_counter: Counter[str] = Counter()
    type_counter: Counter[str] = Counter()
    difficulty_counter: Counter[str] = Counter()
    topic_examples: dict[str, list[dict]] = defaultdict(list)
    difficulty_scores: list[int] = []
    all_question_items: list[dict] = []

    for question in questions:
        question_type = extract_question_family(question)
        type_counter[question_type] += 1

        # Only recognised difficulty levels contribute to the statistics.
        difficulty = question.get("difficulty")
        if difficulty in DIFFICULTY_SCORE:
            difficulty_counter[difficulty] += 1
            difficulty_scores.append(DIFFICULTY_SCORE[difficulty])

        paper = papers_by_id.get(question.get("paper_id"), {})
        topics = extract_topic_labels(question)
        q_item = {
            "paper_id": paper.get("id"),
            "source": _paper_source_label(paper),
            "question_number": question.get("question_number"),
            # question_text may be NULL in the database; preview is then empty.
            "preview": (question.get("question_text") or "")[:220],
            "difficulty": difficulty,
            "question_type": question_type,
            "year": paper.get("year"),
            "term": paper.get("term"),
            "exam_type": paper.get("exam_type"),
            "topics": topics,
        }
        all_question_items.append(q_item)

        for topic in topics:
            topic_counter[topic] += 1
            topic_examples[topic].append(q_item)

    avg_difficulty = "N/A"
    if difficulty_scores:
        rounded = round(sum(difficulty_scores) / len(difficulty_scores))
        avg_difficulty = DIFFICULTY_LABEL.get(rounded, "Medium")

    topic_frequency = [
        {
            "label": topic,
            "count": count,
            "pct": _share_pct(count, total_questions),
            "questions": topic_examples[topic],
        }
        for topic, count in topic_counter.most_common()
    ]
    question_types = [
        {"label": label, "count": count, "pct": _share_pct(count, total_questions)}
        for label, count in type_counter.most_common()
    ]

    return {
        "course_code": code,
        "kpi": {
            "papers": len(papers),
            "questions": total_questions,
            "topics": len(topic_counter),
            "difficulty": avg_difficulty,
        },
        "topic_frequency": topic_frequency,
        "question_types": question_types,
        "all_questions": all_question_items,
        "difficulty_distribution": {
            "easy": difficulty_counter.get("easy", 0),
            "medium": difficulty_counter.get("medium", 0),
            "hard": difficulty_counter.get("hard", 0),
        },
        "high_yield_topics": [topic for topic, _ in topic_counter.most_common(5)],
    }
|
||||
208
backend/app/routers/attempts.py
Normal file
208
backend/app/routers/attempts.py
Normal file
@@ -0,0 +1,208 @@
|
||||
"""用户答题记录 + 拍照批改 + 错题本"""
|
||||
|
||||
import asyncio
|
||||
from fastapi import APIRouter, UploadFile, File, Form, HTTPException, Depends
|
||||
from pydantic import BaseModel
|
||||
from app.services.supabase_client import get_supabase
|
||||
from app.services.grader import ocr_photo, grade_answer
|
||||
from app.dependencies.auth import get_current_user_id
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class AttemptCreate(BaseModel):
    """Request body for recording one answer attempt."""

    question_id: str
    # Expected to be "select", "input" or "photo"; typed as a plain string,
    # so the model itself does not restrict the value.
    attempt_type: str
    user_answer: str | None = None
    # None means the attempt has not been graded.
    is_correct: bool | None = None
|
||||
|
||||
|
||||
class AttemptUpdate(BaseModel):
    """Partial update of an attempt's error-book flags; ``None`` leaves a flag untouched."""

    in_error_book: bool | None = None
    mastered: bool | None = None
|
||||
|
||||
|
||||
@router.post("/")
async def create_attempt(data: AttemptCreate, user_id: str = Depends(get_current_user_id)):
    """Record one answer attempt for the authenticated user and return the stored row."""
    client = get_supabase()
    payload = dict(
        user_id=user_id,
        question_id=data.question_id,
        attempt_type=data.attempt_type,
        user_answer=data.user_answer,
        is_correct=data.is_correct,
    )
    # Only an explicitly wrong answer is auto-added to the error book;
    # an ungraded attempt (None) is not.
    if data.is_correct is False:
        payload["in_error_book"] = True

    inserted = client.table("user_attempts").insert(payload).execute()
    return inserted.data[0]
|
||||
|
||||
|
||||
@router.post("/photo")
async def photo_attempt(
    question_id: str = Form(...),
    photo: UploadFile = File(...),
    user_id: str = Depends(get_current_user_id),
):
    """Photo answer pipeline: upload -> OCR -> AI grading -> stored attempt.

    Args:
        question_id: Id of the question being answered.
        photo: Uploaded picture of the handwritten answer.
        user_id: Authenticated user, injected by the auth dependency.

    Returns:
        ``{"attempt": <stored row>, "ocr_text": <str>, "grade": <grader output>}``.

    Raises:
        HTTPException: 404 when the question does not exist.
    """
    from pathlib import PurePosixPath

    sb = get_supabase()

    # 1. Validate the question first: an unknown id fails fast, before any
    #    upload/OCR work, and only a known id ends up in the storage path.
    q_result = sb.table("paper_questions").select("*").eq("id", question_id).execute()
    if not q_result.data:
        raise HTTPException(status_code=404, detail="Question not found")
    question = q_result.data[0]

    # 2. Read photo
    photo_bytes = await photo.read()

    # 3. Upload to storage. The filename is client-controlled: keep only its
    #    last path component so it cannot escape the user's folder, and fall
    #    back to a fixed name when it is missing or degenerate.
    raw_name = (photo.filename or "").replace("\\", "/")
    safe_name = PurePosixPath(raw_name).name
    if safe_name in ("", ".", ".."):
        safe_name = "photo.jpg"
    storage_path = f"attempts/{user_id}/{question_id}/{safe_name}"
    bucket = sb.storage.from_("attempt-photos")
    bucket.upload(
        storage_path, photo_bytes,
        file_options={"content-type": photo.content_type or "image/jpeg", "upsert": "true"},
    )
    photo_url = bucket.get_public_url(storage_path)

    # 4. OCR (run in thread pool to avoid blocking the event loop)
    ocr_text = await asyncio.to_thread(ocr_photo, photo_bytes)

    # 5. AI grading (run in thread pool)
    grade_result = await asyncio.to_thread(grade_answer, question, ocr_text)
    is_correct = grade_result.get("is_correct", False)

    # 6. Save attempt; anything not graded correct goes to the error book.
    record = {
        "user_id": user_id,
        "question_id": question_id,
        "attempt_type": "photo",
        "photo_url": photo_url,
        "photo_ocr_text": ocr_text,
        "is_correct": is_correct,
        "feedback": grade_result.get("feedback", ""),
        "error_at_step": grade_result.get("error_at_step"),
        "in_error_book": not is_correct,
    }
    result = sb.table("user_attempts").insert(record).execute()

    return {
        "attempt": result.data[0],
        "ocr_text": ocr_text,
        "grade": grade_result,
    }
|
||||
|
||||
|
||||
@router.get("/error-book")
async def get_error_book(
    course_code: str | None = None,
    user_id: str = Depends(get_current_user_id),
):
    """Return the user's error book: unmastered attempts flagged ``in_error_book``.

    Args:
        course_code: Optional course filter, matched upper-cased against the
            paper's ``course_code``.
        user_id: Authenticated user, injected by the auth dependency.

    Returns:
        Attempts, newest first, each extended with a ``paper_questions`` key
        holding the question row plus its ``paper`` row. Attempts whose
        question no longer exists are dropped. When a course filter is given,
        attempts whose paper cannot be resolved are dropped as well, since
        they cannot be confirmed to belong to that course.
    """
    sb = get_supabase()
    attempts = (
        sb.table("user_attempts")
        .select("*")
        .eq("user_id", user_id)
        .eq("in_error_book", True)
        .eq("mastered", False)
        .order("created_at", desc=True)
        .execute()
        .data
    )
    if not attempts:
        return []

    question_ids = list({attempt["question_id"] for attempt in attempts})
    questions = (
        sb.table("paper_questions")
        .select("*")
        .in_("id", question_ids)
        .execute()
        .data
    )
    if not questions:
        # No attempt can be enriched; avoid an ``in_`` query with no ids.
        return []
    questions_by_id = {question["id"]: question for question in questions}

    paper_ids = list({question["paper_id"] for question in questions})
    papers = (
        sb.table("papers")
        .select("id, course_code, year, term, exam_type, part_label")
        .in_("id", paper_ids)
        .execute()
        .data
    )
    papers_by_id = {paper["id"]: paper for paper in papers}

    wanted_course = course_code.upper() if course_code else None

    enriched = []
    for attempt in attempts:
        question = questions_by_id.get(attempt["question_id"])
        if not question:
            continue
        paper = papers_by_id.get(question["paper_id"])
        # With a filter active, an unresolved paper must not slip through.
        if wanted_course and (not paper or paper.get("course_code") != wanted_course):
            continue

        enriched.append(
            {
                **attempt,
                "paper_questions": {
                    **question,
                    "paper": paper,
                },
            }
        )
    return enriched
|
||||
|
||||
|
||||
@router.get("/by-paper/{paper_id}")
async def get_paper_attempts(paper_id: str, user_id: str = Depends(get_current_user_id)):
    """Return the latest photo-graded attempt for each question of a paper.

    The question ids of the paper are looked up first and the attempt query
    is restricted to those ids and to ``attempt_type == "photo"``, instead
    of loading every attempt of the user and filtering in Python (which
    transfers unrelated rows and can miss this paper's attempts if the
    result set is capped).

    Args:
        paper_id: Paper whose questions are inspected.
        user_id: Authenticated user, injected by the auth dependency.

    Returns:
        At most one attempt row per question (the newest), newest first.
    """
    sb = get_supabase()
    question_rows = (
        sb.table("paper_questions")
        .select("id")
        .eq("paper_id", paper_id)
        .execute()
        .data
    )
    question_ids = [row["id"] for row in question_rows]
    if not question_ids:
        return []

    attempts = (
        sb.table("user_attempts")
        .select("question_id, is_correct, feedback, photo_ocr_text, attempt_type, created_at")
        .eq("user_id", user_id)
        .eq("attempt_type", "photo")
        .in_("question_id", question_ids)
        .order("created_at", desc=True)
        .execute()
        .data
    )

    # Rows arrive newest first, so the first row seen per question is the latest.
    seen: set[str] = set()
    result = []
    for attempt in attempts:
        qid = attempt["question_id"]
        if qid in seen:
            continue
        seen.add(qid)
        result.append(attempt)
    return result
|
||||
|
||||
|
||||
@router.patch("/{attempt_id}")
async def update_attempt(
    attempt_id: str,
    data: AttemptUpdate,
    user_id: str = Depends(get_current_user_id),
):
    """Update the error-book flags of one of the caller's attempts.

    The endpoint requires authentication and the update is restricted to
    rows owned by the caller, so a user cannot modify someone else's attempt.

    Args:
        attempt_id: Id of the attempt to update.
        data: Flags to change; ``None`` fields are left untouched.
        user_id: Authenticated user, injected by the auth dependency.

    Returns:
        The updated attempt row.

    Raises:
        HTTPException: 400 when no flag was supplied, 404 when no attempt
            with this id belongs to the caller.
    """
    sb = get_supabase()
    update = {}
    if data.in_error_book is not None:
        update["in_error_book"] = data.in_error_book
    if data.mastered is not None:
        update["mastered"] = data.mastered
    if not update:
        raise HTTPException(status_code=400, detail="Nothing to update")

    result = (
        sb.table("user_attempts")
        .update(update)
        .eq("id", attempt_id)
        .eq("user_id", user_id)  # ownership check
        .execute()
    )
    if not result.data:
        raise HTTPException(status_code=404, detail="Attempt not found")
    return result.data[0]
|
||||
142
backend/app/routers/papers.py
Normal file
142
backend/app/routers/papers.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""试卷上传 + 处理管线"""
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
from fastapi import APIRouter, UploadFile, File, Form, HTTPException, Depends
|
||||
from app.services.supabase_client import get_supabase
|
||||
from app.services.text_extractor import extract_pdf, get_full_text
|
||||
from app.services.paper_processor import process_paper
|
||||
from app.dependencies.auth import get_current_user_id
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _upload_and_process_sync(
    paper_id: str,
    storage_path: str,
    paper_bytes: bytes,
    answer_bytes: bytes | None,
):
    """Thread entry point: upload the PDFs to storage, then run AI processing.

    The storage upload is best-effort: a failure is logged (no longer
    silently discarded) and processing still runs on the in-memory bytes.

    Args:
        paper_id: Id of the ``papers`` row to update and process.
        storage_path: Folder inside the ``papers`` bucket for this upload.
        paper_bytes: Content of the question paper PDF.
        answer_bytes: Content of the answer PDF, or ``None`` when not provided.
    """
    import logging

    log = logging.getLogger(__name__)
    sb = get_supabase()
    bucket = sb.storage.from_("papers")

    def _store(file_name: str, content: bytes) -> str:
        # Upload one PDF (overwriting an existing object) and return its public URL.
        object_path = f"{storage_path}/{file_name}"
        bucket.upload(
            object_path, content,
            file_options={"content-type": "application/pdf", "upsert": "true"},
        )
        return bucket.get_public_url(object_path)

    try:
        update_data: dict = {"paper_file_url": _store("paper.pdf", paper_bytes)}
        if answer_bytes:
            update_data["answer_file_url"] = _store("answer.pdf", answer_bytes)
        sb.table("papers").update(update_data).eq("id", paper_id).execute()
    except Exception:
        log.exception(
            "Storage upload failed for paper %s; continuing with processing", paper_id
        )

    # process_paper is async, so run it in a fresh event loop owned by this thread.
    asyncio.run(process_paper(paper_id, paper_bytes, answer_bytes))
|
||||
|
||||
|
||||
@router.get("/")
async def list_papers():
    """List all papers, newest first (papers are public assets shared by all users)."""
    columns = (
        "id, course_code, year, term, exam_type, status, question_count, total_score, "
        "difficulty_level, processing_step, processing_progress, processing_total, created_at"
    )
    client = get_supabase()
    response = (
        client.table("papers")
        .select(columns)
        .order("created_at", desc=True)
        .execute()
    )
    return response.data
|
||||
|
||||
|
||||
@router.get("/mine")
async def my_papers(user_id: str = Depends(get_current_user_id)):
    """List the papers uploaded by the current user, newest first (including ones still processing)."""
    columns = (
        "id, course_code, year, term, exam_type, part_label, status, question_count, "
        "processing_step, processing_progress, processing_total, created_at"
    )
    client = get_supabase()
    response = (
        client.table("papers")
        .select(columns)
        .eq("user_id", user_id)
        .order("created_at", desc=True)
        .execute()
    )
    return response.data
|
||||
|
||||
|
||||
def _storage_segment(value: object) -> str:
    """Make *value* usable as a single storage path segment.

    Path separators are replaced with ``_`` and the degenerate segments
    ``""``, ``"."`` and ``".."`` collapse to ``_``, so client-supplied form
    fields cannot redirect an upload into another folder.
    """
    segment = str(value).replace("/", "_").replace("\\", "_")
    return "_" if segment in ("", ".", "..") else segment


@router.post("/upload")
async def upload_paper(
    paper_file: UploadFile = File(...),
    answer_file: UploadFile | None = File(None),
    course_code: str = Form(...),
    year: int = Form(...),
    term: str = Form(...),
    exam_type: str = Form(...),
    user_id: str = Depends(get_current_user_id),
):
    """Upload a paper PDF (optionally with an answer PDF) and start background processing.

    The ``papers`` row is created immediately with status ``processing``;
    the storage upload and the AI pipeline run on a daemon thread so the
    request returns right away.

    Returns:
        ``{"paper_id": ..., "status": "processing", "message": ...}``.

    Raises:
        HTTPException: 400 when the uploaded paper file is empty.
    """
    sb = get_supabase()

    # 1. Read the file contents (already in memory, so this is fast).
    paper_bytes = await paper_file.read()
    if not paper_bytes:
        # An empty upload would otherwise create a row with nothing to process.
        raise HTTPException(status_code=400, detail="Paper file is empty")
    answer_bytes = await answer_file.read() if answer_file else None

    # 2. Create the record right away (status=processing).
    #    Form fields are client-controlled, so each path segment is sanitized;
    #    ordinary values produce the same path as before.
    course = course_code.upper()
    folder = _storage_segment(f"{year}_{term}_{exam_type}")
    storage_path = f"{_storage_segment(course)}/{folder}"
    paper_record = sb.table("papers").insert({
        "user_id": user_id,
        "course_code": course,
        "year": year,
        "term": term,
        "exam_type": exam_type,
        "paper_file_url": "",  # filled in by the background upload
        "answer_file_url": None,
        "status": "processing",
    }).execute()

    paper_id = paper_record.data[0]["id"]

    # 3. Run upload + processing on a separate thread so the event loop is never blocked.
    threading.Thread(
        target=_upload_and_process_sync,
        args=(paper_id, storage_path, paper_bytes, answer_bytes),
        daemon=True,
    ).start()

    return {
        "paper_id": paper_id,
        "status": "processing",
        "message": "试卷已上传,正在处理中...",
    }
|
||||
|
||||
|
||||
@router.get("/{paper_id}")
async def get_paper(paper_id: str):
    """Return one paper row (metadata plus processing status); 404 when it does not exist."""
    client = get_supabase()
    rows = client.table("papers").select("*").eq("id", paper_id).execute().data
    if rows:
        return rows[0]
    raise HTTPException(status_code=404, detail="Paper not found")
|
||||
|
||||
|
||||
@router.get("/{paper_id}/questions")
async def get_questions(paper_id: str):
    """Return all question rows of a paper (including the AI trio), ordered by ``display_order``."""
    client = get_supabase()
    query = client.table("paper_questions").select("*").eq("paper_id", paper_id)
    response = query.order("display_order").execute()
    return response.data
|
||||
325
backend/app/routers/questions.py
Normal file
325
backend/app/routers/questions.py
Normal file
@@ -0,0 +1,325 @@
|
||||
"""题目相关:变式题生成 + 相似题召回"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from pydantic import BaseModel
|
||||
from app.services.supabase_client import get_supabase
|
||||
from app.services.grader import generate_variant
|
||||
from app.dependencies.auth import get_current_user_id
|
||||
|
||||
# Simple in-memory cache: question_id → (timestamp, result)
|
||||
_similar_cache: dict[str, tuple[float, list]] = {}
|
||||
_CACHE_TTL = 300 # 5 minutes
|
||||
|
||||
|
||||
class VariantUpdate(BaseModel):
    """Partial update of a question variant; ``None`` leaves the flag untouched."""

    favorited: bool | None = None
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def normalized_labels(values: list[str] | None) -> dict[str, str]:
    """Index labels by their lower-cased form.

    Falsy entries are skipped and ``None`` is treated as an empty list. When
    two labels differ only by case, the later spelling wins.
    """
    return {value.lower(): value for value in (values or []) if value}
|
||||
|
||||
|
||||
def question_family(question: dict) -> str:
    """Return the question's format, else its type, else ``"unknown"`` (truthy values only)."""
    fmt = question.get("question_format")
    if fmt:
        return fmt
    kind = question.get("question_type")
    if kind:
        return kind
    return "unknown"
|
||||
|
||||
|
||||
def display_topics(question: dict) -> list[str]:
    """Return the topic labels to display for a question, without duplicates.

    ``analytics_topic`` comes first, followed by ``topic_tags``. The legacy
    ``topics`` field is used only when those two yield nothing.
    """

    def _extend_unique(seed: list[str], values) -> list[str]:
        # Copy of ``seed`` plus each truthy value not already present.
        merged = list(seed)
        for value in values or []:
            if value and value not in merged:
                merged.append(value)
        return merged

    primary = question.get("analytics_topic")
    tagged = _extend_unique([primary] if primary else [], question.get("topic_tags"))
    if tagged:
        return tagged
    return _extend_unique([], question.get("topics"))
|
||||
|
||||
|
||||
def similarity_score(
    target: dict,
    candidate: dict,
    text_score: float = 0.0,
) -> tuple[int, list[str]]:
    """Score how similar ``candidate`` is to ``target`` and explain why.

    Points: same ``analytics_topic`` 40; shared ``topic_tags`` 10 each (max
    20); shared ``skill_tags`` 10 each (max 20); same question family 10;
    same difficulty 5; full-text score ``round(text_score * 60)`` (max 20).
    The total is capped at 99.

    Returns:
        ``(score, reasons)`` where ``reasons`` lists the human-readable
        matches in the order they were evaluated.
    """
    points = 0
    why: list[str] = []

    # Primary topic bucket: 40 pts
    primary = target.get("analytics_topic")
    same_primary = bool(primary) and primary == candidate.get("analytics_topic")
    if same_primary:
        points += 40
        why.append(f"Same topic: {primary}")

    # Concept overlap: up to 20 pts
    target_tags = normalized_labels(target.get("topic_tags"))
    candidate_tags = normalized_labels(candidate.get("topic_tags"))
    common_tags = sorted(target_tags.keys() & candidate_tags.keys())
    if common_tags:
        points += min(10 * len(common_tags), 20)
        # Skip the concept reason when the primary topic already matched,
        # to avoid a redundant explanation.
        if not same_primary:
            shown = ", ".join(target_tags[key] for key in common_tags[:2])
            why.append("Shared concept: " + shown)

    # Skill overlap: up to 20 pts
    target_skills = normalized_labels(target.get("skill_tags"))
    candidate_skills = normalized_labels(candidate.get("skill_tags"))
    common_skills = sorted(target_skills.keys() & candidate_skills.keys())
    if common_skills:
        points += min(10 * len(common_skills), 20)
        shown = ", ".join(target_skills[key] for key in common_skills[:2])
        why.append("Shared skill: " + shown)

    # Same question format: 10 pts
    if question_family(target) == question_family(candidate):
        points += 10
        why.append("Same format")

    # Same difficulty: 5 pts
    candidate_difficulty = candidate.get("difficulty")
    if candidate_difficulty and candidate_difficulty == target.get("difficulty"):
        points += 5
        why.append("Same difficulty")

    # Full-text similarity from PostgreSQL ts_rank_cd: up to 20 pts
    if text_score > 0:
        text_points = min(round(text_score * 60), 20)
        points += text_points
        if text_points >= 4:
            why.append("Similar wording")

    return min(points, 99), why
|
||||
|
||||
|
||||
@router.get("/variants/favorited")
async def get_favorited_variants(user_id: str = Depends(get_current_user_id)):
    """Return all variants the user has favorited, newest first (used by the Error Book)."""
    columns = (
        "*, paper_questions(question_number, paper_id, "
        "papers(id, course_code, year, term, exam_type, part_label))"
    )
    client = get_supabase()
    response = (
        client.table("question_variants")
        .select(columns)
        .eq("user_id", user_id)
        .eq("favorited", True)
        .order("created_at", desc=True)
        .execute()
    )
    return response.data
|
||||
|
||||
|
||||
@router.post("/{question_id}/variant")
async def create_variant(question_id: str, user_id: str = Depends(get_current_user_id)):
    """Generate a variant of a question, store it for the user and return the stored row.

    The returned row additionally carries ``source_question_number``.
    Raises 404 when the source question does not exist.
    """
    client = get_supabase()
    found = client.table("paper_questions").select("*").eq("id", question_id).execute()
    if not found.data:
        raise HTTPException(status_code=404, detail="Question not found")
    source = found.data[0]

    # The generator is a blocking call, so it runs in a worker thread.
    variant_data = await asyncio.to_thread(generate_variant, source)
    # The variant reuses the source question's knowledge reminder.
    variant_data["knowledge_reminder"] = source.get("knowledge_reminder", "")

    new_row = {
        "user_id": user_id,
        "source_question_id": question_id,
        "variant_data": variant_data,
        "favorited": False,
    }
    stored = client.table("question_variants").insert(new_row).execute().data[0]
    stored["source_question_number"] = source["question_number"]
    return stored
|
||||
|
||||
|
||||
@router.get("/{question_id}/variants")
async def list_variants(question_id: str, user_id: str = Depends(get_current_user_id)):
    """List the current user's variants of one question, newest first.

    Every returned row gets ``source_question_number`` set to the source
    question's number, or to an empty string when the question is not found.
    """
    sb = get_supabase()

    lookup = sb.table("paper_questions").select("question_number").eq("id", question_id).execute()
    if lookup.data:
        number = lookup.data[0]["question_number"]
    else:
        number = ""

    query = sb.table("question_variants").select("*")
    query = query.eq("user_id", user_id).eq("source_question_id", question_id)
    variants = query.order("created_at", desc=True).execute().data

    for variant in variants:
        variant["source_question_number"] = number
    return variants
|
||||
|
||||
|
||||
@router.patch("/variant/{variant_id}")
async def update_variant(variant_id: str, data: VariantUpdate, user_id: str = Depends(get_current_user_id)):
    """Update a variant (favorite / unfavorite).

    Raises:
        HTTPException: 400 when the payload carries nothing to change; 404
            when no row matches both the variant id and the current user.
    """
    sb = get_supabase()

    # Only fields the client actually sent are written.
    changes: dict = {} if data.favorited is None else {"favorited": data.favorited}
    if not changes:
        raise HTTPException(status_code=400, detail="Nothing to update")

    query = sb.table("question_variants").update(changes)
    outcome = query.eq("id", variant_id).eq("user_id", user_id).execute()
    if outcome.data:
        return outcome.data[0]
    raise HTTPException(status_code=404, detail="Variant not found")
|
||||
|
||||
|
||||
@router.delete("/variant/{variant_id}", status_code=204)
async def delete_variant(variant_id: str, user_id: str = Depends(get_current_user_id)):
    """Delete a variant; the delete is filtered by both variant id and user id."""
    deletion = get_supabase().table("question_variants").delete()
    deletion.eq("id", variant_id).eq("user_id", user_id).execute()
|
||||
|
||||
|
||||
def _paper_source_label(paper: dict) -> str:
    """Build the "<year> <Term> <Exam Type>[ Part <label>]" label for a paper row."""
    # `or ""` guards against stored nulls, on which .title() would raise.
    label = (
        f"{paper.get('year') or ''} {(paper.get('term') or '').title()} "
        f"{(paper.get('exam_type') or '').title()}"
    ).strip()
    if paper.get("part_label"):
        label = f"{label} Part {paper['part_label']}"
    return label


@router.get("/{question_id}/similar")
async def get_similar_questions(question_id: str, limit: int = 6):
    """Retrieve similar questions from the same course.

    Lookup order:
      1. the in-process ``_similar_cache`` (entries younger than ``_CACHE_TTL``);
      2. the pre-computed ``similar_questions`` column on the question row;
      3. on-the-fly scoring of candidates from other "ready" papers of the
         same course, whose result is cached and written back to the row.

    Args:
        question_id: id of the target question.
        limit: requested number of results; clamped to the range 1..12.

    Returns:
        A list of similar-question dicts, best match first, at most one per paper.

    Raises:
        HTTPException: 404 when the question or its paper cannot be found.
    """
    import logging  # function-scope: the module's import block is outside this block

    logger = logging.getLogger(__name__)
    max_items = max(1, min(limit, 12))

    # 1) Cache hit
    cached = _similar_cache.get(question_id)
    if cached and (time.time() - cached[0]) < _CACHE_TTL:
        return cached[1][:max_items]

    sb = get_supabase()
    result = sb.table("paper_questions").select("*, similar_questions").eq("id", question_id).execute()
    if not result.data:
        raise HTTPException(status_code=404, detail="Question not found")

    target = result.data[0]

    # 2) Pre-computed result stored on the row: return it immediately
    if target.get("similar_questions"):
        precomputed = target["similar_questions"]
        _similar_cache[question_id] = (time.time(), precomputed)
        return precomputed[:max_items]

    # 3) Fallback: compute on the fly for questions not yet backfilled
    paper_result = sb.table("papers").select("id, course_code").eq("id", target["paper_id"]).execute()
    if not paper_result.data:
        raise HTTPException(status_code=404, detail="Paper not found")

    course_code = paper_result.data[0]["course_code"]
    papers = (
        sb.table("papers")
        .select("id, course_code, year, term, exam_type, part_label")
        .eq("course_code", course_code)
        .eq("status", "ready")
        .execute()
        .data
    )
    paper_ids = [paper["id"] for paper in papers if paper["id"] != target["paper_id"]]
    if not paper_ids:
        return []

    papers_by_id = {paper["id"]: paper for paper in papers}

    # Pre-filter by analytics_topic in the DB when the target has one,
    # which keeps the candidate set small.
    candidates_query = (
        sb.table("paper_questions")
        .select(
            "id, paper_id, question_number, question_type, question_format, "
            "question_text, score, topics, analytics_topic, topic_tags, skill_tags, "
            "difficulty, knowledge_reminder, ai_hint, solution"
        )
        .in_("paper_id", paper_ids)
    )
    target_topic = target.get("analytics_topic")
    if target_topic:
        candidates_query = candidates_query.eq("analytics_topic", target_topic)

    candidates = candidates_query.execute().data
    if not candidates:
        return []

    # Batch full-text scores from PostgreSQL (skipped for large candidate sets).
    # Best effort: on failure every candidate simply gets a text score of 0.
    text_scores: dict[str, float] = {}
    if len(candidates) <= 50:
        try:
            rpc_result = sb.rpc(
                "text_similarity_scores",
                {
                    "query_text": target.get("question_text") or "",
                    "candidate_ids": [c["id"] for c in candidates],
                },
            ).execute()
            for row in rpc_result.data or []:
                text_scores[row["question_id"]] = float(row["text_score"] or 0)
        except Exception:
            logger.warning(
                "text_similarity_scores RPC failed for question %s; ranking without text scores",
                question_id,
                exc_info=True,
            )

    ranked = []
    for candidate in candidates:
        text_score = text_scores.get(candidate["id"], 0.0)
        match_percent, reasons = similarity_score(target, candidate, text_score)
        if match_percent < 20:
            continue
        paper = papers_by_id.get(candidate["paper_id"], {})
        ranked.append(
            {
                "id": candidate["id"],
                "paper_id": candidate["paper_id"],
                "source": _paper_source_label(paper),
                "question_number": candidate["question_number"],
                "match_percent": match_percent,
                "match_reasons": reasons,
                "question_type": question_family(candidate),
                "question_text": candidate["question_text"],
                "topics": display_topics(candidate),
                "difficulty": candidate.get("difficulty"),
                "knowledge_reminder": candidate.get("knowledge_reminder", ""),
                "ai_hint": candidate.get("ai_hint", ""),
                "solution": candidate.get("solution", ""),
            }
        )

    # Best match first; ties broken by source label, then question number.
    # str(... or "") keeps the tie-break comparable if a number is missing.
    ranked.sort(
        key=lambda item: (-item["match_percent"], item["source"], str(item["question_number"] or ""))
    )

    # Keep only the best-scoring question per paper
    seen_papers: set[str] = set()
    deduped = []
    for item in ranked:
        if item["paper_id"] not in seen_papers:
            seen_papers.add(item["paper_id"])
            deduped.append(item)

    _similar_cache[question_id] = (time.time(), deduped)

    # Persist to the DB so later requests can take the pre-computed path.
    # Best effort: a failed write must not fail the request.
    try:
        sb.table("paper_questions").update({"similar_questions": deduped}).eq("id", question_id).execute()
    except Exception:
        logger.warning("Could not persist similar_questions for question %s", question_id, exc_info=True)

    return deduped[:max_items]
|
||||
0
backend/app/services/__init__.py
Normal file
0
backend/app/services/__init__.py
Normal file
146
backend/app/services/grader.py
Normal file
146
backend/app/services/grader.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""OCR, grading, and variant generation prompts"""
|
||||
|
||||
import json
|
||||
import base64
|
||||
from app.services.llm_clients import get_vision_client, get_deepseek_client
|
||||
|
||||
OCR_PROMPT = """You are an expert at recognizing handwritten answers. Analyze this photo of a student's handwritten answer and extract the text and mathematical formulas.
|
||||
|
||||
Requirements:
|
||||
- Faithfully extract what the student wrote, do not modify or correct
|
||||
- Use LaTeX format for math formulas (e.g. $x^2 + 1$)
|
||||
- If there are multiple steps, list them in original order
|
||||
- If some handwriting is unclear, mark with [unclear]
|
||||
|
||||
Return only the extracted text, no additional explanation."""
|
||||
|
||||
GRADING_PROMPT = """You are an expert academic grader. Grade the following student answer. ALL output must be in English.
|
||||
|
||||
Question info:
|
||||
- Number: {question_number}
|
||||
- Type: {question_type}
|
||||
- Question: {question_text}
|
||||
- Score: {score}
|
||||
|
||||
Reference answer / solution:
|
||||
{reference_answer}
|
||||
|
||||
Student answer:
|
||||
{student_answer}
|
||||
|
||||
Grade and return JSON:
|
||||
{{
|
||||
"is_correct": true/false,
|
||||
"score_given": 0-{score},
|
||||
"feedback": "<HTML> Step-by-step analysis of the student's answer, pointing out correct parts and errors, using KaTeX formulas </HTML>",
|
||||
"error_at_step": null or the step number where errors begin (integer)
|
||||
}}
|
||||
|
||||
Grading rules:
|
||||
- MC / fill-blank: only correct if answer matches exactly
|
||||
- Long questions: give partial credit for correct steps even if the final answer is wrong
|
||||
- feedback in HTML format, supports KaTeX ($..$ inline, $$...$$ block)
|
||||
- Mark errors with <div class="common-error">...</div>
|
||||
- Identify exactly which step the error starts"""
|
||||
|
||||
VARIANT_PROMPT = """You are an expert exam question creator. Generate a similar but different variant question based on the original below. ALL output must be in English.
|
||||
|
||||
Original question info:
|
||||
- Type: {question_type}
|
||||
- Question: {question_text}
|
||||
- Topics: {topics}
|
||||
- Difficulty: {difficulty}
|
||||
- Reference answer: {answer}
|
||||
|
||||
Requirements:
|
||||
- Variant must test the same knowledge points at similar difficulty
|
||||
- Data/scenario/wording must differ — don't just change numbers
|
||||
- Must provide a complete correct answer
|
||||
|
||||
Format requirements (CRITICAL):
|
||||
- All text in HTML format, absolutely NO markdown syntax
|
||||
- Code: <pre><code class="language-xxx">...</code></pre>, NOT ```
|
||||
- Math: $...$ (inline) or $$...$$ (block), KaTeX compatible
|
||||
- Line breaks: <br>, paragraphs: <p>
|
||||
|
||||
Return JSON:
|
||||
{{
|
||||
"question_text": "HTML formatted variant question",
|
||||
"question_type": "{question_type}",
|
||||
"options": [MC only, format {{"label":"A","text":"..."}}, ...] or null,
|
||||
"correct_answer": "Correct answer (plain text)",
|
||||
"ai_hint": "HTML formatted hint that guides thinking WITHOUT giving the answer",
|
||||
"solution": "HTML formatted complete step-by-step solution"
|
||||
}}"""
|
||||
|
||||
|
||||
def ocr_photo(photo_bytes: bytes, mime_type: str = "image/jpeg") -> str:
    """Run OCR on a photo of a handwritten answer with the vision model.

    Args:
        photo_bytes: raw image bytes.
        mime_type: MIME type declared in the data URL sent to the model.
            Defaults to "image/jpeg" (the previously hard-coded value); pass
            e.g. "image/png" when the upload is not a JPEG.

    Returns:
        The text extracted by the model, or "" when the model returns no content.
    """
    client = get_vision_client()
    b64 = base64.b64encode(photo_bytes).decode("utf-8")

    resp = client.chat.completions.create(
        model="gemini-2.5-flash",
        messages=[
            {"role": "system", "content": OCR_PROMPT},
            {"role": "user", "content": [
                {"type": "image_url", "image_url": {
                    "url": f"data:{mime_type};base64,{b64}",
                }},
            ]},
        ],
        temperature=0,
        max_tokens=2000,
    )
    return resp.choices[0].message.content or ""
|
||||
|
||||
|
||||
def grade_answer(question: dict, student_answer: str) -> dict:
    """Grade a student's answer with the DeepSeek chat model.

    Args:
        question: question row; must contain ``question_number``,
            ``question_type`` and ``question_text``. ``raw_answer_text`` or
            ``solution`` is used as the reference answer when present.
        student_answer: the student's answer text.

    Returns:
        The parsed JSON object produced by the model (see GRADING_PROMPT).

    Raises:
        KeyError: a required question field is missing.
        ValueError: the model returned no content, or content that is not
            valid JSON (``json.JSONDecodeError`` is a ``ValueError``).
    """
    reference = question.get("raw_answer_text") or question.get("solution") or "No reference answer"
    score = question.get("score") or "unknown"

    ds = get_deepseek_client()
    resp = ds.chat.completions.create(
        model="deepseek-chat",
        messages=[
            {"role": "system", "content": GRADING_PROMPT.format(
                question_number=question["question_number"],
                question_type=question["question_type"],
                question_text=question["question_text"],
                score=score,
                reference_answer=reference,
                student_answer=student_answer,
            )},
        ],
        temperature=0.2,
        response_format={"type": "json_object"},
    )
    content = resp.choices[0].message.content
    if not content:
        # json.loads(None) would raise an opaque TypeError; fail with a clear message.
        raise ValueError("Grading model returned an empty response")
    return json.loads(content)
|
||||
|
||||
|
||||
def generate_variant(question: dict) -> dict:
    """Generate a variant of a question with the DeepSeek chat model.

    Args:
        question: question row; must contain ``question_type`` and
            ``question_text``. ``topics``, ``difficulty`` and the answer
            fields are optional and may be null.

    Returns:
        The parsed JSON object produced by the model (see VARIANT_PROMPT).

    Raises:
        KeyError: a required question field is missing.
        ValueError: the model returned no content, or content that is not
            valid JSON (``json.JSONDecodeError`` is a ``ValueError``).
    """
    answer = (
        question.get("correct_option")
        or question.get("correct_answer")
        or question.get("raw_answer_text")
        or "N/A"
    )

    # A key that is present but null would make .get(key, default) return
    # None, so fall back with `or` instead.
    topics = question.get("topics") or []
    if isinstance(topics, str):
        topics = [topics]

    ds = get_deepseek_client()
    resp = ds.chat.completions.create(
        model="deepseek-chat",
        messages=[
            {"role": "system", "content": VARIANT_PROMPT.format(
                question_type=question["question_type"],
                question_text=question["question_text"],
                topics=", ".join(str(topic) for topic in topics),
                difficulty=question.get("difficulty") or "medium",
                answer=answer,
            )},
        ],
        temperature=0.5,
        response_format={"type": "json_object"},
    )
    content = resp.choices[0].message.content
    if not content:
        # json.loads(None) would raise an opaque TypeError; fail with a clear message.
        raise ValueError("Variant model returned an empty response")
    return json.loads(content)
|
||||
74
backend/app/services/llm_clients.py
Normal file
74
backend/app/services/llm_clients.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import httpx
|
||||
from openai import OpenAI
|
||||
from app.config import get_settings
|
||||
|
||||
_TIMEOUT = httpx.Timeout(connect=10, read=300, write=60, pool=10)
|
||||
|
||||
_gpt_client: OpenAI | None = None
|
||||
_qwen_client: OpenAI | None = None
|
||||
_gemini_flash_client: OpenAI | None = None
|
||||
_gemini_lite_client: OpenAI | None = None
|
||||
_deepseek_client: OpenAI | None = None
|
||||
|
||||
|
||||
def get_gpt_client() -> OpenAI:
    """Return the shared laozhang API client (gpt-4o / gpt-4o-mini).

    The client is created on first use and reused afterwards. It uses the
    module-wide ``_TIMEOUT``, like the other client factories in this module.
    """
    global _gpt_client
    if _gpt_client is None:
        s = get_settings()
        _gpt_client = OpenAI(
            base_url=s.laozhang_base_url,
            api_key=s.laozhang_api_key,
            timeout=_TIMEOUT,
        )
    return _gpt_client
|
||||
|
||||
|
||||
def get_qwen_client() -> OpenAI:
    """Return the shared DashScope client (qwen-plus).

    The client is created on first use and reused afterwards. It uses the
    module-wide ``_TIMEOUT``, like the other client factories in this module.
    """
    global _qwen_client
    if _qwen_client is None:
        s = get_settings()
        _qwen_client = OpenAI(
            base_url=s.dashscope_base_url,
            api_key=s.dashscope_api_key,
            timeout=_TIMEOUT,
        )
    return _qwen_client
|
||||
|
||||
|
||||
def get_vision_client() -> OpenAI:
    """Return the shared client for Google's official Gemini API.

    Used for vision work (question extraction and OCR). Author's note,
    carried over from the original docstring: reachable when deployed in
    Singapore. The client is created on first use and reused afterwards.
    """
    global _gemini_flash_client
    if _gemini_flash_client is not None:
        return _gemini_flash_client

    settings = get_settings()
    _gemini_flash_client = OpenAI(
        api_key=settings.google_gemini_api_key,
        base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
        timeout=_TIMEOUT,
    )
    return _gemini_flash_client
|
||||
|
||||
|
||||
def get_gemini_lite_client() -> OpenAI:
    """Return the shared laozhang client for gemini-3.1-flash-lite-preview.

    Lightweight model intended for the AI trio. The client is created on
    first use and reused afterwards.
    """
    global _gemini_lite_client
    if _gemini_lite_client is not None:
        return _gemini_lite_client

    settings = get_settings()
    _gemini_lite_client = OpenAI(
        api_key=settings.laozhang_api_key,
        base_url=settings.laozhang_base_url,
        timeout=_TIMEOUT,
    )
    return _gemini_lite_client
|
||||
|
||||
|
||||
def get_deepseek_client() -> OpenAI:
    """Return the shared DeepSeek client (deepseek-chat, used for the AI trio).

    The client is created on first use and reused afterwards.
    """
    global _deepseek_client
    if _deepseek_client is not None:
        return _deepseek_client

    settings = get_settings()
    _deepseek_client = OpenAI(
        api_key=settings.deepseek_api_key,
        base_url=settings.deepseek_base_url,
        timeout=_TIMEOUT,
    )
    return _deepseek_client
|
||||
576
backend/app/services/paper_processor.py
Normal file
576
backend/app/services/paper_processor.py
Normal file
@@ -0,0 +1,576 @@
|
||||
"""试卷处理管线:PDF → 结构化题目 → AI 三件套(Vision 模式)"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import re
|
||||
import traceback
|
||||
from contextlib import redirect_stdout
|
||||
import fitz # pymupdf
|
||||
from app.services.supabase_client import get_supabase
|
||||
from app.services.llm_clients import get_vision_client, get_deepseek_client
|
||||
|
||||
|
||||
def strip_nulls(obj):
    """Recursively remove NUL characters (U+0000) from strings; PostgreSQL rejects them.

    Strings are cleaned; dicts and lists are rebuilt with cleaned contents.
    Dict keys are cleaned as well as values, so a NUL inside a key cannot
    survive. Two keys that differ only by NUL characters collapse into one,
    and the later one wins. Any other value is returned unchanged.
    """
    if isinstance(obj, str):
        return obj.replace("\u0000", "")
    if isinstance(obj, dict):
        # Non-string keys pass through strip_nulls unchanged.
        return {strip_nulls(k): strip_nulls(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [strip_nulls(i) for i in obj]
    return obj
|
||||
|
||||
|
||||
# ============================================
|
||||
# Prompts
|
||||
# ============================================
|
||||
|
||||
STRUCTURE_PROMPT = """You are an expert exam paper structure analyst. You are given images of a past exam paper. Analyze every page carefully and extract all questions into structured JSON.
|
||||
All generated values must be in English. Do not output Chinese.
|
||||
|
||||
CRITICAL RULES for question_text:
|
||||
- Each question's question_text must be FULLY SELF-CONTAINED. Include ALL context needed to solve it.
|
||||
- For sub-questions (e.g. (a)(i)), copy the ENTIRE parent question setup (variable definitions, code blocks, problem description) into the question_text, then append the specific sub-question.
|
||||
- For Python/code questions: include ALL variable definitions and import statements verbatim, exactly as they appear in the exam, preserving multi-line arrays and data structures completely.
|
||||
- Never truncate code. If a variable is defined across multiple lines (e.g. a numpy array), include every line.
|
||||
|
||||
Output JSON format (strictly follow):
|
||||
{
|
||||
"total_score": 100,
|
||||
"difficulty_level": "medium",
|
||||
"topics_summary": {"Topic A": 40, "Topic B": 30, "Topic C": 30},
|
||||
"questions": [
|
||||
{
|
||||
"question_number": "1a",
|
||||
"parent_question": "1",
|
||||
"question_type": "mc",
|
||||
"question_text": "Original question text...",
|
||||
"score": 5,
|
||||
"page_number": 1,
|
||||
"options": [{"label": "A", "text": "Option content"}, {"label": "B", "text": "..."}],
|
||||
"topics": ["Linked List", "Pointer"],
|
||||
"difficulty": "easy"
|
||||
},
|
||||
{
|
||||
"question_number": "2",
|
||||
"parent_question": null,
|
||||
"question_type": "long_question",
|
||||
"question_text": "Original question text...",
|
||||
"score": 15,
|
||||
"page_number": 2,
|
||||
"options": null,
|
||||
"topics": ["Recursion"],
|
||||
"difficulty": "hard"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
Rules:
|
||||
- question_type must be one of: "mc" (multiple choice), "true_false" (true/false), "fill_blank" (fill in blank), "long_question" (long question)
|
||||
- True/False questions MUST use "true_false" type, with options set to [{"label":"True","text":"True"},{"label":"False","text":"False"}], correct_option as "True" or "False"
|
||||
- Multiple choice must extract the options array
|
||||
- Sub-questions use parent_question to link to parent: "1a" parent is "1"
|
||||
- Independent questions without sub-questions set parent_question to null
|
||||
- page_number inferred from where the question appears
|
||||
- topics inferred from the question content
|
||||
- difficulty: "easy" | "medium" | "hard"
|
||||
- Extract ALL questions, do not miss any
|
||||
- Keep topic labels in English only
|
||||
"""
|
||||
|
||||
ANSWER_MATCH_PROMPT = """You are an expert exam answer matching specialist. Below is the answer text for an exam paper. Extract and match answers to their corresponding question numbers.
|
||||
All generated values must be in English. Do not output Chinese.
|
||||
|
||||
Question structure:
|
||||
{questions_json}
|
||||
|
||||
Answer text:
|
||||
{answer_text}
|
||||
|
||||
Output JSON format:
|
||||
{{
|
||||
"answers": [
|
||||
{{
|
||||
"question_number": "1a",
|
||||
"correct_option": "B",
|
||||
"correct_answer": null,
|
||||
"raw_answer_text": "Original answer text..."
|
||||
}},
|
||||
{{
|
||||
"question_number": "2",
|
||||
"correct_option": null,
|
||||
"correct_answer": null,
|
||||
"raw_answer_text": "Complete solution process and answer..."
|
||||
}}
|
||||
]
|
||||
}}
|
||||
|
||||
Rules:
|
||||
- For MC questions, fill correct_option (e.g. "B")
|
||||
- For fill-blank questions, fill correct_answer (e.g. "O(n log n)")
|
||||
- For long questions, only fill raw_answer_text (complete solution process)
|
||||
- Match all questions where answers can be found
|
||||
- Keep raw_answer_text faithful to the source answer, but do not add Chinese commentary
|
||||
"""
|
||||
|
||||
ANALYSIS_PROMPT = """You are an expert academic answer analyst. Generate three sections for the following exam question. ALL output must be in English.
|
||||
|
||||
Question info:
|
||||
- Number: {question_number}
|
||||
- Type: {question_type}
|
||||
- Score: {score}
|
||||
- Question: {question_text}
|
||||
- Topics: {topics}
|
||||
{answer_section}
|
||||
|
||||
Generate THREE sections in HTML format (supports KaTeX: block $$ ... $$ inline $ ... $):
|
||||
|
||||
Output JSON:
|
||||
{{
|
||||
"knowledge_reminder": "<HTML> Prerequisite knowledge points needed for this question, as a concise bullet list </HTML>",
|
||||
"ai_hint": "<HTML> A hint that guides thinking direction WITHOUT giving away the answer </HTML>",
|
||||
"solution": "<HTML> Complete step-by-step solution (Step 1, Step 2, ...) with derivations, formulas, and common mistake warnings </HTML>"
|
||||
}}
|
||||
|
||||
Solution requirements:
|
||||
- Must include complete working process, not just the answer
|
||||
- Each step must have an explanation
|
||||
- If a reference answer is provided, derive the solution based on it
|
||||
- If no reference answer, work out the complete solution independently
|
||||
- For MC questions, explain why the correct option is right AND why others are wrong
|
||||
- Use <ol> or numbered steps
|
||||
- Mark common mistakes with <div class="common-error">...</div>
|
||||
|
||||
KaTeX formula rules:
|
||||
- Block formula: $$ on its own line, with blank lines before and after
|
||||
- Inline formula: $x^2$ no line break
|
||||
- Matrix: \\begin{{bmatrix}} ... \\end{{bmatrix}}
|
||||
- Fraction: \\frac{{a}}{{b}}
|
||||
"""
|
||||
|
||||
BATCH_ANALYSIS_PROMPT = """You are an expert academic answer analyst. Generate three study sections for each question below. ALL output must be in English.
|
||||
|
||||
For every question, return:
|
||||
- knowledge_reminder: concise prerequisite bullets in HTML
|
||||
- ai_hint: a helpful hint in HTML without revealing the final answer
|
||||
- solution: a complete step-by-step solution in HTML
|
||||
|
||||
Return JSON in this exact format:
|
||||
{{
|
||||
"analyses": [
|
||||
{{
|
||||
"question_number": "1a",
|
||||
"knowledge_reminder": "<HTML>...</HTML>",
|
||||
"ai_hint": "<HTML>...</HTML>",
|
||||
"solution": "<HTML>...</HTML>"
|
||||
}}
|
||||
]
|
||||
}}
|
||||
|
||||
Rules:
|
||||
- Return one item for every provided question_number
|
||||
- Keep each item matched to the same question_number
|
||||
- All text must be in English
|
||||
- HTML only, KaTeX compatible
|
||||
- For MC questions, explain why the correct option is right and why the others are wrong
|
||||
- For long questions, show a complete derivation or reasoning chain
|
||||
- Use <ol> or numbered steps in solution when appropriate
|
||||
- Mark common mistakes with <div class="common-error">...</div>
|
||||
- CRITICAL: When a question_text contains "[Context from parent question X]" followed by "[Sub-question Y]", the parent section is background context only. You MUST solve ONLY the specific sub-question labeled [Sub-question Y]. Do NOT solve other sub-questions listed in the parent context. Give one precise answer for that single sub-question only.
|
||||
|
||||
Questions:
|
||||
{questions_payload}
|
||||
"""
|
||||
|
||||
|
||||
# ============================================
|
||||
# 处理管线
|
||||
# ============================================
|
||||
|
||||
# Lower-case substrings that mark an exception message as transient.
RETRYABLE_ERROR_MARKERS = (
    "429", "rate limit", "rate_limit", "too many requests",
    "timeout", "timed out", "connection",
)


def is_retryable_error(exc: Exception) -> bool:
    """Return True when the lower-cased message of *exc* contains a retry marker."""
    text = str(exc).lower()
    for marker in RETRYABLE_ERROR_MARKERS:
        if marker in text:
            return True
    return False
|
||||
|
||||
|
||||
def pdf_to_images(pdf_bytes: bytes, dpi: int = 96) -> list[str]:
    """Render every page of a PDF to a base64-encoded PNG.

    Args:
        pdf_bytes: raw PDF file content.
        dpi: render resolution; the default of 96 trades clarity against cost.

    Returns:
        One base64 string per page, in page order.
    """
    # PDF user space is 72 units per inch, so dpi / 72 is the zoom factor.
    mat = fitz.Matrix(dpi / 72, dpi / 72)
    images = []
    # The context manager closes the document even if rendering a page raises.
    with fitz.open(stream=pdf_bytes, filetype="pdf") as doc:
        for page in doc:
            pix = page.get_pixmap(matrix=mat, colorspace=fitz.csRGB)
            img_bytes = pix.tobytes("png")
            images.append(base64.b64encode(img_bytes).decode())
    return images
|
||||
|
||||
|
||||
def parse_json_response(text: str) -> dict:
    """Parse JSON returned by a model, tolerating common formatting slips.

    Repairs applied before parsing:
      * a surrounding markdown code fence (```json ... ```) is removed;
      * control characters other than tab, newline and carriage return are dropped;
      * a backslash that starts an invalid JSON escape (for example LaTeX
        ``\\alpha``, or ``\\u`` not followed by four hex digits as in
        ``\\underline``) is doubled so it parses as a literal backslash.

    Parsing is non-strict, so literal newlines and tabs inside string values
    are accepted. Sequences that are valid JSON escapes (such as ``\\b`` or
    ``\\f``) are left as they are.

    Raises:
        json.JSONDecodeError: the text is empty or still not valid JSON.
    """
    text = (text or "").strip()
    # Strip a ```json ... ``` wrapper.
    if text.startswith("```"):
        lines = text.splitlines()
        text = "\n".join(lines[1:-1] if lines[-1].strip() == "```" else lines[1:])
    # Remove illegal control characters (0x00-0x1F except \t \n \r).
    text = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f]', '', text)
    # Fix invalid escapes. Only a backslash preceded by an even number of
    # backslashes starts an escape, so already-escaped backslashes are untouched.
    text = re.sub(
        r'(?<!\\)((?:\\\\)*)\\([^"\\/bfnrtu]|u(?![0-9a-fA-F]{4}))',
        r'\1\\\\\2',
        text,
    )
    return json.loads(text, strict=False)
|
||||
|
||||
|
||||
async def gemini_vision_json(
    *,
    system_prompt: str,
    images: list[str],
    user_text: str = "",
    temperature: float = 0,
    max_attempts: int = 6,
) -> dict:
    """Send images plus a prompt to the Gemini vision model and return parsed JSON.

    Args:
        system_prompt: instructions; a JSON-only reminder is appended.
        images: base64-encoded PNG images, sent in order.
        user_text: optional text appended after the images.
        temperature: sampling temperature.
        max_attempts: total attempts. Errors for which ``is_retryable_error``
            is true are retried with exponential backoff (2s doubling, capped
            at 30s).

    Raises:
        ValueError: ``max_attempts`` is less than 1.
        Exception: the last error once attempts are exhausted, or any
            non-retryable error (including JSON parse failures).
    """
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")

    client = get_vision_client()
    delay_seconds = 2

    content: list = [
        {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64}"}}
        for b64 in images
    ]
    if user_text:
        content.append({"type": "text", "text": user_text})

    messages = [
        {"role": "system", "content": system_prompt + "\n\nIMPORTANT: Your entire response must be valid JSON only. No markdown, no code fences, no extra text."},
        {"role": "user", "content": content},
    ]

    for attempt in range(1, max_attempts + 1):
        try:
            # The SDK call is synchronous; run it in a worker thread so the
            # event loop is not blocked while the model responds.
            response = await asyncio.to_thread(
                client.chat.completions.create,
                model="gemini-2.5-flash",
                messages=messages,
                temperature=temperature,
                max_tokens=16384,
            )
            return parse_json_response(response.choices[0].message.content)
        except Exception as exc:
            if attempt == max_attempts or not is_retryable_error(exc):
                raise
            await asyncio.sleep(delay_seconds)
            delay_seconds = min(delay_seconds * 2, 30)
|
||||
|
||||
|
||||
async def deepseek_json_completion(
    *,
    system_prompt: str,
    user_prompt: str | None = None,
    temperature: float = 0,
    max_attempts: int = 6,
) -> dict:
    """Run a text-only DeepSeek JSON completion (used to generate the AI trio).

    Args:
        system_prompt: system message content.
        user_prompt: optional user message.
        temperature: sampling temperature.
        max_attempts: total attempts. Errors for which ``is_retryable_error``
            is true are retried with exponential backoff (2s doubling, capped
            at 30s).

    Raises:
        ValueError: ``max_attempts`` is less than 1.
        Exception: the last error once attempts are exhausted, or any
            non-retryable error (including JSON parse failures).
    """
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")

    client = get_deepseek_client()
    delay_seconds = 2

    # The messages do not change between attempts, so build them once.
    messages = [{"role": "system", "content": system_prompt}]
    if user_prompt:
        messages.append({"role": "user", "content": user_prompt})

    for attempt in range(1, max_attempts + 1):
        try:
            # The SDK call is synchronous; run it in a worker thread so the
            # event loop is not blocked while the model responds.
            response = await asyncio.to_thread(
                client.chat.completions.create,
                model="deepseek-chat",
                messages=messages,
                temperature=temperature,
                max_tokens=8192,
                response_format={"type": "json_object"},
            )
            # Shared sanitising and parsing (control characters, invalid escapes).
            return parse_json_response(response.choices[0].message.content)
        except Exception as exc:
            if attempt == max_attempts or not is_retryable_error(exc):
                raise
            await asyncio.sleep(delay_seconds)
            delay_seconds = min(delay_seconds * 2, 30)
|
||||
|
||||
|
||||
def chunked(items: list[dict], size: int) -> list[list[dict]]:
    """Split *items* into consecutive slices of at most *size* elements."""
    batches = []
    for start in range(0, len(items), size):
        batches.append(items[start:start + size])
    return batches
|
||||
|
||||
|
||||
def _question_sort_key(qnum: str) -> tuple:
|
||||
"""自然排序题号:1a < 1b < ... < 1i < 1j < 2ai < 2aii < 10a"""
|
||||
parts = re.findall(r'(\d+|[a-zA-Z]+|[()]+)', qnum)
|
||||
key = []
|
||||
for idx, p in enumerate(parts):
|
||||
if p.isdigit():
|
||||
key.append((0, int(p), ''))
|
||||
elif p in ('(', ')'):
|
||||
continue
|
||||
else:
|
||||
# Single letter (a-z): always sort alphabetically (a=1, b=2, ..., j=10)
|
||||
if len(p) == 1 and p.isalpha():
|
||||
key.append((1, ord(p.lower()) - ord('a') + 1, p))
|
||||
else:
|
||||
# Multi-letter: roman numerals for sub-sub-questions (i=1, ii=2, iii=3, ...)
|
||||
romans = {'i':1,'ii':2,'iii':3,'iv':4,'v':5,'vi':6,'vii':7,'viii':8,'ix':9,'x':10,'xi':11,'xii':12,'xiii':13}
|
||||
if p.lower() in romans:
|
||||
key.append((2, romans[p.lower()], p))
|
||||
else:
|
||||
key.append((1, 0, p))
|
||||
return tuple(key)
|
||||
|
||||
|
||||
def sort_questions(questions: list[dict]) -> list[dict]:
    """Return a new list of the questions in natural question-number order."""
    def by_number(question: dict) -> tuple:
        return _question_sort_key(question.get("question_number", ""))

    ordered = list(questions)
    ordered.sort(key=by_number)
    return ordered
|
||||
|
||||
|
||||
# A line starts code when it begins with an import, an assignment or print(.
_CODE_START_RE = re.compile(r"^\s*(import |from \w|[A-Za-z_]\w*\s*=|print\()")


def extract_code_block(text: str) -> str:
    """Extract Python code lines from a question text (heuristic).

    A line is kept when it looks like the start of code (``import``/``from``,
    an assignment, or ``print(``), or when it continues a previous kept line
    whose brackets are still open (multi-line arrays, calls, dicts). All
    other lines are dropped. Kept lines are stripped of surrounding
    whitespace and joined with newlines.

    Brackets are counted textually, so brackets inside string literals are
    counted too. The running count never drops below zero, so a stray
    closing bracket cannot hide a later multi-line statement.
    """
    collected = []
    open_brackets = 0

    for line in text.splitlines():
        stripped = line.strip()
        if open_brackets > 0 or _CODE_START_RE.match(line):
            collected.append(stripped)
            opened = stripped.count("(") + stripped.count("[") + stripped.count("{")
            closed = stripped.count(")") + stripped.count("]") + stripped.count("}")
            open_brackets = max(0, open_brackets + opened - closed)

    return "\n".join(collected)
|
||||
|
||||
|
||||
# 保持向后兼容
|
||||
extract_code_lines = extract_code_block
|
||||
|
||||
|
||||
def try_exec_python(code: str, shared_ns: dict) -> str | None:
|
||||
"""
|
||||
在 shared_ns 命名空间中执行 code,捕获 stdout。
|
||||
返回输出字符串,失败返回 None。
|
||||
"""
|
||||
buf = io.StringIO()
|
||||
try:
|
||||
with redirect_stdout(buf):
|
||||
exec(code, shared_ns) # noqa: S102
|
||||
output = buf.getvalue().strip()
|
||||
return output if output else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def _resume_ai_trio(sb, paper_id: str, questions: list[dict]):
    """Generate the AI trio for questions that still lack a solution.

    The trio is ``knowledge_reminder``, ``ai_hint`` and ``solution``. Each
    result is written back to its ``paper_questions`` row as soon as its
    batch returns, so an interrupted run can be resumed. The paper is marked
    "ready" at the end, even if some batches failed: those questions keep an
    empty solution and are picked up again on the next run.

    Args:
        sb: Supabase client.
        paper_id: id of the paper being processed.
        questions: the paper's question rows, each with ``id``,
            ``question_number`` and ``solution``.
    """
    batch_size = 3  # questions sent per model request

    need = [q for q in questions if not q.get("solution")]
    if not need:
        # Every question already has a solution: just mark the paper as done.
        sb.table("papers").update({"status": "ready", "processing_step": None}).eq("id", paper_id).execute()
        return

    total_q = len(questions)
    done_q = total_q - len(need)

    # question_number -> row id, used to match the model's answers back to rows.
    # Keys are normalised to str so a number echoed back as an int still matches.
    id_map = {str(q["question_number"]): q["id"] for q in need}

    # The full question text and answers are needed to generate the trio.
    full_data = sb.table("paper_questions").select(
        "id, question_number, question_type, question_text, score, correct_option, correct_answer, raw_answer_text"
    ).eq("paper_id", paper_id).in_("id", [q["id"] for q in need]).execute().data

    payloads = []
    for q in full_data:
        answer_section = q.get("raw_answer_text") or ""
        if not answer_section and q.get("correct_option"):
            answer_section = f"Correct option: {q['correct_option']}"
        elif not answer_section and q.get("correct_answer"):
            answer_section = f"Correct answer: {q['correct_answer']}"
        payloads.append({
            "question_number": q["question_number"],
            "question_type": q["question_type"] or "long_question",
            "score": q.get("score") or "unknown",
            "question_text": q["question_text"] or "",
            "reference_answer": answer_section,
        })

    batches = chunked(payloads, batch_size)
    for batch_idx, batch in enumerate(batches, 1):
        current = done_q + batch_idx * batch_size
        _update_progress(sb, paper_id, f"Generating solutions ({min(current, total_q)}/{total_q} questions)", batch_idx, len(batches))
        try:
            result = await deepseek_json_completion(
                system_prompt=BATCH_ANALYSIS_PROMPT.format(
                    questions_payload=json.dumps(batch, ensure_ascii=False),
                ),
                temperature=0.3,
            )
            for item in result.get("analyses", []):
                qid = id_map.get(str(item.get("question_number")))
                if qid:
                    sb.table("paper_questions").update({
                        "knowledge_reminder": item.get("knowledge_reminder", ""),
                        "ai_hint": item.get("ai_hint", ""),
                        "solution": item.get("solution", ""),
                    }).eq("id", qid).execute()
        except Exception:
            # Best effort: one failed batch must not stop the others, but the
            # failure is reported rather than silently dropped.
            traceback.print_exc()
        await asyncio.sleep(1)

    # Mark the paper as done.
    sb.table("papers").update({"status": "ready", "processing_step": None}).eq("id", paper_id).execute()
|
||||
|
||||
|
||||
def _update_progress(sb, paper_id: str, step: str, progress: int = 0, total: int = 0):
    """Write the current processing step and its counters to the ``papers`` row.

    Args:
        sb: Client exposing the ``table(...).update(...).eq(...).execute()`` chain.
        paper_id: Value matched against the ``id`` column of ``papers``.
        step: Text stored in ``processing_step``.
        progress: Value stored in ``processing_progress`` (default 0).
        total: Value stored in ``processing_total`` (default 0).
    """
    progress_fields = {
        "processing_step": step,
        "processing_progress": progress,
        "processing_total": total,
    }
    pending_update = sb.table("papers").update(progress_fields)
    pending_update.eq("id", paper_id).execute()
|
||||
|
||||
|
||||
async def process_paper(paper_id: str, paper_bytes: bytes, answer_bytes: bytes | None):
    """Background processing pipeline: PDF pages -> Vision structuring -> AI trio.

    Design principle: each step persists its result to the DB as soon as it
    completes, so an interrupted run can be resumed. If questions already exist
    for the paper, extraction is skipped and only the AI trio is filled in.

    Args:
        paper_id: Id of the ``papers`` row being processed.
        paper_bytes: Raw bytes of the exam paper PDF.
        answer_bytes: Raw bytes of the answer PDF; ``None`` or empty skips
            answer matching.

    Raises:
        Exception: Any failure outside the best-effort answer-matching step is
            recorded on the paper (``status = "error"``) and then re-raised.
    """
    # Imported locally because the file's import header is outside this block.
    import logging

    log = logging.getLogger(__name__)
    sb = get_supabase()

    try:
        # Resume scenario: questions were already saved by an earlier run.
        existing = (
            sb.table("paper_questions")
            .select("id, question_number, solution")
            .eq("paper_id", paper_id)
            .execute()
            .data
        )
        if existing:
            # Skip extraction and only fill in the missing AI trio.
            await _resume_ai_trio(sb, paper_id, existing)
            return

        # -- Step 1: PDF -> images --
        _update_progress(sb, paper_id, "Rendering PDF pages...")
        paper_images = pdf_to_images(paper_bytes)

        # -- Step 2: Vision-based question extraction, PAGE_BATCH pages per call --
        PAGE_BATCH = 8
        questions: list = []
        meta: dict = {}
        num_page_batches = -(-len(paper_images) // PAGE_BATCH)  # ceiling division
        for i in range(0, len(paper_images), PAGE_BATCH):
            batch_imgs = paper_images[i:i + PAGE_BATCH]
            batch_idx = i // PAGE_BATCH + 1
            first_page = i + 1
            last_page = i + len(batch_imgs)
            _update_progress(
                sb, paper_id, f"Reading pages {first_page}-{last_page}...",
                batch_idx, num_page_batches,
            )
            batch_result = await gemini_vision_json(
                system_prompt=STRUCTURE_PROMPT,
                images=batch_imgs,
                user_text=f"Pages {first_page}-{last_page} of the exam paper. Extract all questions visible on these pages.",
                temperature=0,
            )
            if not meta:
                # Paper-level overview is taken from the first batch only.
                meta = {
                    k: batch_result.get(k)
                    for k in ("total_score", "difficulty_level", "topics_summary")
                }
            questions.extend(batch_result.get("questions", []))

        questions = sort_questions(questions)

        # Update the paper overview.
        sb.table("papers").update({
            "total_score": meta.get("total_score"),
            "question_count": len(questions),
            "topics_summary": meta.get("topics_summary"),
            "difficulty_level": meta.get("difficulty_level"),
        }).eq("id", paper_id).execute()

        # -- Step 3: answer matching (batched, best-effort: failures are skipped) --
        answers_map: dict = {}
        if answer_bytes:
            _update_progress(sb, paper_id, "Matching answers...")
            try:
                answer_images = pdf_to_images(answer_bytes)
                questions_json = json.dumps(
                    [
                        {"question_number": q["question_number"], "question_type": q["question_type"]}
                        for q in questions
                    ],
                    ensure_ascii=False,
                )
                all_answers: list = []
                for start in range(0, len(answer_images), PAGE_BATCH):
                    batch_ans_imgs = answer_images[start:start + PAGE_BATCH]
                    try:
                        match_result = await gemini_vision_json(
                            system_prompt=ANSWER_MATCH_PROMPT.format(
                                questions_json=questions_json, answer_text="(See images)",
                            ),
                            images=batch_ans_imgs,
                            user_text=f"Match answers to these questions: {questions_json}",
                            temperature=0,
                        )
                        all_answers.extend(match_result.get("answers", []))
                    except Exception:
                        # One failed batch must not discard the other batches,
                        # but the failure should still be visible in the logs.
                        log.warning(
                            "Answer matching failed for answer pages %d-%d of paper %s",
                            start + 1, start + len(batch_ans_imgs), paper_id,
                            exc_info=True,
                        )
                # Ignore malformed entries instead of letting a single missing
                # key throw away every successfully matched answer.
                answers_map = {
                    a["question_number"]: a
                    for a in all_answers
                    if isinstance(a, dict) and "question_number" in a
                }
            except Exception:
                # Answers are optional: continue without them.
                log.warning("Answer matching skipped for paper %s", paper_id, exc_info=True)

        # -- Step 4: save questions to the DB right away (AI trio not included yet) --
        _update_progress(sb, paper_id, "Saving questions...")
        for order, q in enumerate(questions):
            qnum = q["question_number"]
            answer = answers_map.get(qnum, {})
            # "topics" may be missing, null or an empty list; indexing [0]
            # directly would raise and abort the whole paper.
            topic_list = q.get("topics") or []
            sb.table("paper_questions").insert(strip_nulls({
                "paper_id": paper_id,
                "question_number": qnum,
                "parent_question": q.get("parent_question"),
                "display_order": order,
                "question_type": q["question_type"],
                "question_text": q["question_text"],
                "score": q.get("score"),
                "page_number": q.get("page_number"),
                "options": q.get("options"),
                "correct_option": answer.get("correct_option"),
                "correct_answer": answer.get("correct_answer"),
                "raw_answer_text": answer.get("raw_answer_text"),
                "topics": q.get("topics", []),
                "analytics_topic": topic_list[0] if topic_list else None,
                "topic_tags": q.get("topics", []),
                "difficulty": q.get("difficulty"),
            })).execute()

        # -- Step 5: AI trio (updated row by row, resumable) --
        saved = (
            sb.table("paper_questions")
            .select("id, question_number, solution")
            .eq("paper_id", paper_id)
            .execute()
            .data
        )
        await _resume_ai_trio(sb, paper_id, saved)

    except Exception as e:
        # Build the message first, while the original traceback is current.
        error_message = f"{type(e).__name__}: {str(e)}\n{traceback.format_exc()[-500:]}"
        try:
            sb.table("papers").update({
                "status": "error",
                "error_message": error_message,
            }).eq("id", paper_id).execute()
        except Exception:
            # Failing to record the status must not mask the original error.
            log.exception("Could not record error status for paper %s", paper_id)
        raise
|
||||
13
backend/app/services/supabase_client.py
Normal file
13
backend/app/services/supabase_client.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from supabase import create_client, Client
|
||||
from app.config import get_settings
|
||||
|
||||
# Process-wide cached client; stays None until get_supabase() is first called.
_client: Client | None = None


def get_supabase() -> Client:
    """Return the shared Supabase client, creating it on first use.

    The client is built from the configured URL and the service-role key
    (service_role, which bypasses RLS).
    """
    global _client
    if _client is not None:
        return _client

    settings = get_settings()
    _client = create_client(settings.supabase_url, settings.supabase_service_role_key)
    return _client
|
||||
48
backend/app/services/text_extractor.py
Normal file
48
backend/app/services/text_extractor.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""PDF 文本提取 — 复用 SOS 的 text_extractor 逻辑"""
|
||||
|
||||
import base64
|
||||
import fitz # PyMuPDF
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
class ExtractedContent:
    """Container for the text and page images extracted from one PDF."""

    # Text of every page, in page order.
    pages_text: list[str]
    # Page key -> base64-encoded image, kept for image-heavy pages
    # (extract_pdf uses the 0-based page index as the key).
    page_images: dict[int, str]
    # Number of pages in the document.
    total_pages: int
    # True when at least one page image was stored.
    has_images: bool
|
||||
|
||||
|
||||
def extract_pdf(
    file_bytes: bytes,
    *,
    min_text_chars: int = 50,
    dpi: int = 200,
) -> ExtractedContent:
    """Extract per-page text from a PDF and render text-sparse pages as images.

    A page whose stripped text is shorter than ``min_text_chars`` may be a
    scan, so it is rendered to PNG for Vision OCR. Only the text length is
    checked; the page is rendered whether or not it embeds any image.

    Args:
        file_bytes: Raw bytes of the PDF document.
        min_text_chars: Pages with fewer stripped characters than this are
            rendered to an image.
        dpi: Resolution used when rendering a page to PNG.

    Returns:
        ExtractedContent with the text of every page, base64-encoded PNGs keyed
        by 0-based page index, the page count, and whether any image was stored.
    """
    pages_text: list[str] = []
    page_images: dict[int, str] = {}

    # The context manager closes the document even if a page fails to extract
    # or render, which a trailing doc.close() would not.
    with fitz.open(stream=file_bytes, filetype="pdf") as doc:
        for i, page in enumerate(doc):
            text = page.get_text("text")
            pages_text.append(text)

            if len(text.strip()) < min_text_chars:
                pix = page.get_pixmap(dpi=dpi)
                img_bytes = pix.tobytes("png")
                page_images[i] = base64.b64encode(img_bytes).decode("utf-8")

    return ExtractedContent(
        pages_text=pages_text,
        page_images=page_images,
        total_pages=len(pages_text),
        has_images=bool(page_images),
    )
|
||||
|
||||
|
||||
def get_full_text(extracted: ExtractedContent) -> str:
    """Merge the text of all pages into one string.

    Each non-blank page is prefixed with a ``--- Page N ---`` header (N is the
    1-based position in ``pages_text``) and sections are separated by a blank
    line. Pages containing only whitespace are left out.
    """
    sections = []
    for page_no, page_text in enumerate(extracted.pages_text, start=1):
        if not page_text.strip():
            continue
        sections.append(f"--- Page {page_no} ---\n{page_text}")
    return "\n\n".join(sections)
|
||||
Reference in New Issue
Block a user