Initial commit: PastPaper Master full stack

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Zhao
2026-04-21 12:15:35 +07:00
commit 7a09167261
105 changed files with 24799 additions and 0 deletions

0
backend/app/__init__.py Normal file
View File

36
backend/app/config.py Normal file
View File

@@ -0,0 +1,36 @@
from pydantic_settings import BaseSettings
from functools import lru_cache
import os
class Settings(BaseSettings):
# Supabase
supabase_url: str
supabase_anon_key: str
supabase_service_role_key: str
# LLM - laozhang (gpt-4o, gpt-4o-mini)
laozhang_base_url: str = "https://api.laozhang.ai/v1"
laozhang_api_key: str = ""
# LLM - DashScope (qwen-plus)
dashscope_base_url: str = "https://dashscope.aliyuncs.com/compatible-mode/v1"
dashscope_api_key: str = ""
# LLM - DeepSeek
deepseek_base_url: str = "https://api.deepseek.com/v1"
deepseek_api_key: str = ""
# Google Gemini (official)
google_gemini_api_key: str = ""
model_config = {
"env_file": os.path.join(os.path.dirname(__file__), "../../.env"),
"env_file_encoding": "utf-8",
"extra": "ignore",
}
@lru_cache
def get_settings() -> Settings:
return Settings()

View File

View File

@@ -0,0 +1,34 @@
"""Auth dependency: validate Supabase JWT and return user_id"""
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from app.services.supabase_client import get_supabase
bearer_scheme = HTTPBearer(auto_error=False)
async def get_current_user_id(
credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme),
) -> str:
"""Extract and validate Bearer token, return user_id."""
if not credentials:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
)
token = credentials.credentials
sb = get_supabase()
try:
result = sb.auth.get_user(token)
user = result.user
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid token",
)
return user.id
except Exception:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or expired token",
)

59
backend/app/main.py Normal file
View File

@@ -0,0 +1,59 @@
import asyncio
import threading
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.routers import analytics, papers, attempts, questions
def _resume_stale_papers():
"""启动时检查卡在 processing 的 paper自动续传 AI trio"""
try:
from app.services.supabase_client import get_supabase
from app.services.paper_processor import process_paper
sb = get_supabase()
stale = sb.table("papers").select("id").eq("status", "processing").execute().data
if not stale:
return
for p in stale:
paper_id = p["id"]
print(f"[STARTUP] Resuming processing for paper {paper_id[:8]}...")
def run(pid=paper_id):
asyncio.run(process_paper(pid, b"", None))
threading.Thread(target=run, daemon=True).start()
except Exception as e:
print(f"[STARTUP] Resume skipped: {e}")
@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup
_resume_stale_papers()
yield
# Shutdown (nothing to do)
app = FastAPI(title="PastPaper Master API", version="0.1.0", lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # 开发阶段先放开,上线收紧
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.include_router(papers.router, prefix="/api/papers", tags=["papers"])
app.include_router(attempts.router, prefix="/api/attempts", tags=["attempts"])
app.include_router(questions.router, prefix="/api/questions", tags=["questions"])
app.include_router(analytics.router, prefix="/api/analytics", tags=["analytics"])
@app.get("/health")
def health():
return {"status": "ok"}

View File

View File

@@ -0,0 +1,285 @@
"""Course-level analytics endpoints."""
from __future__ import annotations
from collections import Counter, defaultdict
from fastapi import APIRouter
from app.services.supabase_client import get_supabase
router = APIRouter()
DIFFICULTY_SCORE = {"easy": 1, "medium": 2, "hard": 3}
DIFFICULTY_LABEL = {1: "Easy", 2: "Medium", 3: "Hard"}
# ── Topic normalization ──────────────────────────────────────
# Map variant spellings to canonical label
_TOPIC_ALIASES: dict[str, str] = {
"numpy": "NumPy",
"naïve bayes": "Naive Bayes",
"naïve bayes classifier": "Naive Bayes",
"naive bayes classifier": "Naive Bayes",
"bayes classifier": "Naive Bayes",
"bayes model": "Naive Bayes",
"bayes' theorem": "Naive Bayes",
"bayes' rule": "Naive Bayes",
"k-nearest neighbors": "K-Nearest Neighbors (KNN)",
"knn": "K-Nearest Neighbors (KNN)",
"k-means clustering": "K-Means Clustering",
"k-means": "K-Means Clustering",
"k means": "K-Means Clustering",
"multilayer perceptron": "Multilayer Perceptron (MLP)",
"multi-layer perceptron": "Multilayer Perceptron (MLP)",
"multi-layer perceptron (mlp)": "Multilayer Perceptron (MLP)",
"mlp": "Multilayer Perceptron (MLP)",
"single layer perceptron": "Perceptron",
"convolutional neural network": "CNN",
"convolutional neural network (cnn)": "CNN",
"convolutional neural networks": "CNN",
"cnn architecture": "CNN",
"cnn properties": "CNN",
"python fundamentals": "Python",
"python programming": "Python",
"python implementation": "Python",
"advanced python programming": "Python",
"python programming: convolutional neural network": "CNN",
"cross-validation": "Cross Validation",
"model evaluation implementation": "Model Evaluation",
"digital image processing": "Image Processing",
"computer vision": "Image Processing",
"array slicing": "Array Slicing",
"slicing": "Array Slicing",
"array indexing": "Array Slicing",
"array reshaping": "Reshape",
"array views": "Array Slicing",
"view vs copy": "Array Slicing",
"boolean indexing": "Array Slicing",
"arange": "NumPy",
"newaxis": "NumPy",
"expand dims": "NumPy",
"transpose": "NumPy",
"type casting": "NumPy",
"element-wise operation": "NumPy",
"array reduction": "NumPy",
"multi-dimensional array": "NumPy",
"dot product": "NumPy",
"vectorization": "NumPy",
"activation functions": "Activation Function",
"linear activation function": "Activation Function",
"neural network architecture": "Neural Networks",
"hidden layer": "Neural Networks",
"deep learning": "Neural Networks",
"deep learning frameworks": "Neural Networks",
"alpha-beta pruning": "Alpha-Beta Pruning",
"minimax algorithm": "Minimax",
"ethics of ai": "AI Ethics",
"ethics": "AI Ethics",
"cosine distance": "Cosine Similarity",
"distance calculation": "Distance Metrics",
"euclidean distance": "Distance Metrics",
"manhattan distance": "Distance Metrics",
"hamming distance": "Distance Metrics",
"precision": "Model Evaluation",
"recall": "Model Evaluation",
"f1 score": "Model Evaluation",
"macro f1 score": "Model Evaluation",
"accuracy": "Model Evaluation",
"classification accuracy": "Model Evaluation",
"confusion matrix": "Model Evaluation",
"convolution operation": "Convolution",
"dilated convolution": "Convolution",
"3d convolution": "Convolution",
"gaussian likelihood": "Probability",
"gaussian distribution": "Probability",
"categorical likelihood": "Probability",
"conditional probability": "Probability",
"total probability theorem": "Probability",
"probability assumptions": "Probability",
"tensorflow": "Keras",
"model summary": "Keras",
"model construction": "Keras",
"trainable parameters": "Parameter Calculation",
"parameter reduction": "Parameter Calculation",
"output shape calculation": "Parameter Calculation",
"shape calculation": "Parameter Calculation",
}
def normalize_topic(label: str) -> str:
return _TOPIC_ALIASES.get(label.lower().strip(), label)
def extract_topic_labels(question: dict) -> list[str]:
labels: list[str] = []
raw_labels: list[str] = []
analytics_topic = question.get("analytics_topic")
if analytics_topic:
raw_labels.append(analytics_topic)
for tag in question.get("topic_tags") or []:
if tag and tag not in raw_labels:
raw_labels.append(tag)
if not raw_labels:
for tag in question.get("topics") or []:
if tag and tag not in raw_labels:
raw_labels.append(tag)
# Normalize and deduplicate
seen: set[str] = set()
for raw in raw_labels:
norm = normalize_topic(raw)
if norm not in seen:
seen.add(norm)
labels.append(norm)
return labels
def extract_question_family(question: dict) -> str:
return (
question.get("question_format")
or question.get("question_type")
or "unknown"
)
@router.get("/courses")
async def list_courses():
"""返回所有有 ready 状态试卷的课程列表"""
sb = get_supabase()
rows = (
sb.table("papers")
.select("course_code")
.eq("status", "ready")
.execute()
.data
)
codes = sorted({row["course_code"] for row in rows if row.get("course_code")})
return codes
@router.get("/course/{course_code}")
async def get_course_analytics(course_code: str):
sb = get_supabase()
papers = (
sb.table("papers")
.select("id, course_code, year, term, exam_type, part_label, status")
.eq("course_code", course_code.upper())
.eq("status", "ready")
.order("year", desc=True)
.execute()
.data
)
if not papers:
return {
"course_code": course_code.upper(),
"kpi": {"papers": 0, "questions": 0, "topics": 0, "difficulty": "N/A"},
"topic_frequency": [],
"question_types": [],
"difficulty_distribution": {"easy": 0, "medium": 0, "hard": 0},
"high_yield_topics": [],
}
paper_ids = [paper["id"] for paper in papers]
questions = (
sb.table("paper_questions")
.select(
"id, paper_id, question_number, question_type, question_format, "
"question_text, score, topics, analytics_topic, topic_tags, difficulty"
)
.in_("paper_id", paper_ids)
.order("display_order")
.execute()
.data
)
papers_by_id = {paper["id"]: paper for paper in papers}
total_questions = len(questions)
topic_counter: Counter[str] = Counter()
type_counter: Counter[str] = Counter()
difficulty_counter: Counter[str] = Counter()
topic_examples: dict[str, list[dict]] = defaultdict(list)
difficulty_scores: list[int] = []
all_question_items: list[dict] = []
for question in questions:
question_type = extract_question_family(question)
type_counter[question_type] += 1
difficulty = question.get("difficulty")
if difficulty in DIFFICULTY_SCORE:
difficulty_counter[difficulty] += 1
difficulty_scores.append(DIFFICULTY_SCORE[difficulty])
paper = papers_by_id.get(question["paper_id"], {})
source_label = (
f"{paper.get('year', '')} {paper.get('term', '').title()} "
f"{paper.get('exam_type', '').title()}"
).strip()
if paper.get("part_label"):
source_label = f"{source_label} Part {paper['part_label']}"
topics = extract_topic_labels(question)
q_item = {
"paper_id": paper.get("id"),
"source": source_label,
"question_number": question["question_number"],
"preview": question["question_text"][:220],
"difficulty": question.get("difficulty"),
"question_type": question_type,
"year": paper.get("year"),
"term": paper.get("term"),
"exam_type": paper.get("exam_type"),
"topics": topics,
}
all_question_items.append(q_item)
for topic in topics:
topic_counter[topic] += 1
topic_examples[topic].append(q_item)
avg_difficulty = "N/A"
if difficulty_scores:
rounded = round(sum(difficulty_scores) / len(difficulty_scores))
avg_difficulty = DIFFICULTY_LABEL.get(rounded, "Medium")
topic_frequency = []
for topic, count in topic_counter.most_common():
pct = round((count / total_questions) * 100) if total_questions else 0
topic_frequency.append(
{
"label": topic,
"count": count,
"pct": pct,
"questions": topic_examples[topic],
}
)
question_types = []
for label, count in type_counter.most_common():
pct = round((count / total_questions) * 100) if total_questions else 0
question_types.append({"label": label, "count": count, "pct": pct})
return {
"course_code": course_code.upper(),
"kpi": {
"papers": len(papers),
"questions": total_questions,
"topics": len(topic_counter),
"difficulty": avg_difficulty,
},
"topic_frequency": topic_frequency,
"question_types": question_types,
"all_questions": all_question_items,
"difficulty_distribution": {
"easy": difficulty_counter.get("easy", 0),
"medium": difficulty_counter.get("medium", 0),
"hard": difficulty_counter.get("hard", 0),
},
"high_yield_topics": [topic for topic, _ in topic_counter.most_common(5)],
}

View File

@@ -0,0 +1,208 @@
"""用户答题记录 + 拍照批改 + 错题本"""
import asyncio
from fastapi import APIRouter, UploadFile, File, Form, HTTPException, Depends
from pydantic import BaseModel
from app.services.supabase_client import get_supabase
from app.services.grader import ocr_photo, grade_answer
from app.dependencies.auth import get_current_user_id
router = APIRouter()
class AttemptCreate(BaseModel):
question_id: str
attempt_type: str # "select" | "input" | "photo"
user_answer: str | None = None
is_correct: bool | None = None
class AttemptUpdate(BaseModel):
in_error_book: bool | None = None
mastered: bool | None = None
@router.post("/")
async def create_attempt(data: AttemptCreate, user_id: str = Depends(get_current_user_id)):
"""记录一次答题"""
sb = get_supabase()
record = {
"user_id": user_id,
"question_id": data.question_id,
"attempt_type": data.attempt_type,
"user_answer": data.user_answer,
"is_correct": data.is_correct,
}
# Auto add to error book if wrong
if data.is_correct is False:
record["in_error_book"] = True
result = sb.table("user_attempts").insert(record).execute()
return result.data[0]
@router.post("/photo")
async def photo_attempt(
question_id: str = Form(...),
photo: UploadFile = File(...),
user_id: str = Depends(get_current_user_id),
):
"""拍照上传 → OCR → AI批改"""
sb = get_supabase()
# 1. Read photo
photo_bytes = await photo.read()
# 2. Upload to storage
storage_path = f"attempts/{user_id}/{question_id}/{photo.filename}"
sb.storage.from_("attempt-photos").upload(
storage_path, photo_bytes,
file_options={"content-type": photo.content_type or "image/jpeg", "upsert": "true"},
)
photo_url = sb.storage.from_("attempt-photos").get_public_url(storage_path)
# 3. OCR (run in thread pool to avoid blocking event loop)
ocr_text = await asyncio.to_thread(ocr_photo, photo_bytes)
# 4. Fetch question for grading context
q_result = sb.table("paper_questions").select("*").eq("id", question_id).execute()
if not q_result.data:
raise HTTPException(status_code=404, detail="Question not found")
question = q_result.data[0]
# 5. AI grading (run in thread pool)
grade_result = await asyncio.to_thread(grade_answer, question, ocr_text)
# 6. Save attempt
record = {
"user_id": user_id,
"question_id": question_id,
"attempt_type": "photo",
"photo_url": photo_url,
"photo_ocr_text": ocr_text,
"is_correct": grade_result.get("is_correct", False),
"feedback": grade_result.get("feedback", ""),
"error_at_step": grade_result.get("error_at_step"),
"in_error_book": not grade_result.get("is_correct", False),
}
result = sb.table("user_attempts").insert(record).execute()
return {
"attempt": result.data[0],
"ocr_text": ocr_text,
"grade": grade_result,
}
@router.get("/error-book")
async def get_error_book(
course_code: str | None = None,
user_id: str = Depends(get_current_user_id),
):
"""获取错题本"""
sb = get_supabase()
attempts = (
sb.table("user_attempts")
.select("*")
.eq("user_id", user_id)
.eq("in_error_book", True)
.eq("mastered", False)
.order("created_at", desc=True)
.execute()
.data
)
if not attempts:
return []
question_ids = list({attempt["question_id"] for attempt in attempts})
questions = (
sb.table("paper_questions")
.select("*")
.in_("id", question_ids)
.execute()
.data
)
questions_by_id = {question["id"]: question for question in questions}
paper_ids = list({question["paper_id"] for question in questions})
papers = (
sb.table("papers")
.select("id, course_code, year, term, exam_type, part_label")
.in_("id", paper_ids)
.execute()
.data
)
papers_by_id = {paper["id"]: paper for paper in papers}
enriched = []
for attempt in attempts:
question = questions_by_id.get(attempt["question_id"])
if not question:
continue
paper = papers_by_id.get(question["paper_id"])
if course_code and paper and paper.get("course_code") != course_code.upper():
continue
enriched.append(
{
**attempt,
"paper_questions": {
**question,
"paper": paper,
},
}
)
return enriched
@router.get("/by-paper/{paper_id}")
async def get_paper_attempts(paper_id: str, user_id: str = Depends(get_current_user_id)):
"""获取某张试卷所有题目的最新判卷记录"""
sb = get_supabase()
attempts = (
sb.table("user_attempts")
.select("question_id, is_correct, feedback, photo_ocr_text, attempt_type, created_at")
.eq("user_id", user_id)
.order("created_at", desc=True)
.execute()
.data
)
# 只保留 photo 类型的,且只保留每题最新一条
question_ids = (
sb.table("paper_questions")
.select("id")
.eq("paper_id", paper_id)
.execute()
.data
)
qid_set = {q["id"] for q in question_ids}
seen: set[str] = set()
result = []
for a in attempts:
if a["question_id"] not in qid_set:
continue
if a["question_id"] in seen:
continue
if a["attempt_type"] != "photo":
continue
seen.add(a["question_id"])
result.append(a)
return result
@router.patch("/{attempt_id}")
async def update_attempt(attempt_id: str, data: AttemptUpdate):
"""更新错题状态(标记掌握等)"""
sb = get_supabase()
update = {}
if data.in_error_book is not None:
update["in_error_book"] = data.in_error_book
if data.mastered is not None:
update["mastered"] = data.mastered
if not update:
raise HTTPException(status_code=400, detail="Nothing to update")
result = sb.table("user_attempts").update(update).eq("id", attempt_id).execute()
if not result.data:
raise HTTPException(status_code=404, detail="Attempt not found")
return result.data[0]

View File

@@ -0,0 +1,142 @@
"""试卷上传 + 处理管线"""
import asyncio
import threading
from fastapi import APIRouter, UploadFile, File, Form, HTTPException, Depends
from app.services.supabase_client import get_supabase
from app.services.text_extractor import extract_pdf, get_full_text
from app.services.paper_processor import process_paper
from app.dependencies.auth import get_current_user_id
router = APIRouter()
def _upload_and_process_sync(
paper_id: str,
storage_path: str,
paper_bytes: bytes,
answer_bytes: bytes | None,
):
"""在独立线程中运行Storage 上传 + AI 处理"""
sb = get_supabase()
try:
paper_storage_path = f"{storage_path}/paper.pdf"
sb.storage.from_("papers").upload(
paper_storage_path, paper_bytes,
file_options={"content-type": "application/pdf", "upsert": "true"},
)
paper_url = sb.storage.from_("papers").get_public_url(paper_storage_path)
update_data: dict = {"paper_file_url": paper_url}
if answer_bytes:
answer_storage_path = f"{storage_path}/answer.pdf"
sb.storage.from_("papers").upload(
answer_storage_path, answer_bytes,
file_options={"content-type": "application/pdf", "upsert": "true"},
)
update_data["answer_file_url"] = sb.storage.from_("papers").get_public_url(answer_storage_path)
sb.table("papers").update(update_data).eq("id", paper_id).execute()
except Exception:
pass
# process_paper 是 async在新事件循环里跑
asyncio.run(process_paper(paper_id, paper_bytes, answer_bytes))
@router.get("/")
async def list_papers():
"""获取试卷列表(公共资产,所有用户共享)"""
sb = get_supabase()
return (
sb.table("papers")
.select("id, course_code, year, term, exam_type, status, question_count, total_score, difficulty_level, processing_step, processing_progress, processing_total, created_at")
.order("created_at", desc=True)
.execute()
.data
)
@router.get("/mine")
async def my_papers(user_id: str = Depends(get_current_user_id)):
"""当前用户上传的试卷(含 processing 状态)"""
sb = get_supabase()
return (
sb.table("papers")
.select("id, course_code, year, term, exam_type, part_label, status, question_count, processing_step, processing_progress, processing_total, created_at")
.eq("user_id", user_id)
.order("created_at", desc=True)
.execute()
.data
)
@router.post("/upload")
async def upload_paper(
paper_file: UploadFile = File(...),
answer_file: UploadFile | None = File(None),
course_code: str = Form(...),
year: int = Form(...),
term: str = Form(...),
exam_type: str = Form(...),
user_id: str = Depends(get_current_user_id),
):
"""上传试卷 PDF可选答案 PDF触发后台处理"""
sb = get_supabase()
# 1. 读取文件内容(已在内存中,快)
paper_bytes = await paper_file.read()
answer_bytes = await answer_file.read() if answer_file else None
# 2. 立即创建记录status=processing马上返回
storage_path = f"{course_code.upper()}/{year}_{term}_{exam_type}"
paper_record = sb.table("papers").insert({
"user_id": user_id,
"course_code": course_code.upper(),
"year": year,
"term": term,
"exam_type": exam_type,
"paper_file_url": "", # 后台上传后更新
"answer_file_url": None,
"status": "processing",
}).execute()
paper_id = paper_record.data[0]["id"]
# 3. 在独立线程中运行,完全不阻塞事件循环
threading.Thread(
target=_upload_and_process_sync,
args=(paper_id, storage_path, paper_bytes, answer_bytes),
daemon=True,
).start()
return {
"paper_id": paper_id,
"status": "processing",
"message": "试卷已上传,正在处理中...",
}
@router.get("/{paper_id}")
async def get_paper(paper_id: str):
"""获取试卷信息 + 处理状态"""
sb = get_supabase()
result = sb.table("papers").select("*").eq("id", paper_id).execute()
if not result.data:
raise HTTPException(status_code=404, detail="Paper not found")
return result.data[0]
@router.get("/{paper_id}/questions")
async def get_questions(paper_id: str):
"""获取试卷的所有题目(含 AI 三件套)"""
sb = get_supabase()
result = (
sb.table("paper_questions")
.select("*")
.eq("paper_id", paper_id)
.order("display_order")
.execute()
)
return result.data

View File

@@ -0,0 +1,325 @@
"""题目相关:变式题生成 + 相似题召回"""
from __future__ import annotations
import asyncio
import time
from fastapi import APIRouter, HTTPException, Depends
from pydantic import BaseModel
from app.services.supabase_client import get_supabase
from app.services.grader import generate_variant
from app.dependencies.auth import get_current_user_id
# Simple in-memory cache: question_id → (timestamp, result)
_similar_cache: dict[str, tuple[float, list]] = {}
_CACHE_TTL = 300 # 5 minutes
class VariantUpdate(BaseModel):
favorited: bool | None = None
router = APIRouter()
def normalized_labels(values: list[str] | None) -> dict[str, str]:
labels: dict[str, str] = {}
for value in values or []:
if value:
labels[value.lower()] = value
return labels
def question_family(question: dict) -> str:
return question.get("question_format") or question.get("question_type") or "unknown"
def display_topics(question: dict) -> list[str]:
labels: list[str] = []
analytics_topic = question.get("analytics_topic")
if analytics_topic:
labels.append(analytics_topic)
for topic in question.get("topic_tags") or []:
if topic and topic not in labels:
labels.append(topic)
if labels:
return labels
for topic in question.get("topics") or []:
if topic and topic not in labels:
labels.append(topic)
return labels
def similarity_score(
target: dict,
candidate: dict,
text_score: float = 0.0,
) -> tuple[int, list[str]]:
score = 0
reasons: list[str] = []
# Primary topic bucket: 40 pts
target_topic = target.get("analytics_topic")
candidate_topic = candidate.get("analytics_topic")
if target_topic and target_topic == candidate_topic:
score += 40
reasons.append(f"Same topic: {target_topic}")
# Concept overlap: up to 20 pts
target_topics = normalized_labels(target.get("topic_tags"))
candidate_topics = normalized_labels(candidate.get("topic_tags"))
shared_topics = sorted(set(target_topics) & set(candidate_topics))
if shared_topics:
score += min(len(shared_topics) * 10, 20)
# Only show concept reason if analytics_topic didn't already match (avoid redundancy)
if not (target_topic and target_topic == candidate_topic):
reasons.append(
"Shared concept: "
+ ", ".join(target_topics[key] for key in shared_topics[:2])
)
# Skill overlap: up to 20 pts
target_skills = normalized_labels(target.get("skill_tags"))
candidate_skills = normalized_labels(candidate.get("skill_tags"))
shared_skills = sorted(set(target_skills) & set(candidate_skills))
if shared_skills:
score += min(len(shared_skills) * 10, 20)
reasons.append(
"Shared skill: "
+ ", ".join(target_skills[key] for key in shared_skills[:2])
)
# Same question format: 10 pts
if question_family(candidate) == question_family(target):
score += 10
reasons.append("Same format")
# Same difficulty: 5 pts
if candidate.get("difficulty") and candidate.get("difficulty") == target.get("difficulty"):
score += 5
reasons.append("Same difficulty")
# Full-text similarity from PostgreSQL ts_rank_cd: up to 20 pts
if text_score > 0:
text_pts = min(round(text_score * 60), 20)
score += text_pts
if text_pts >= 4:
reasons.append("Similar wording")
return min(score, 99), reasons
@router.get("/variants/favorited")
async def get_favorited_variants(user_id: str = Depends(get_current_user_id)):
"""获取用户收藏的所有 variant用于 Error Book"""
sb = get_supabase()
rows = (
sb.table("question_variants")
.select("*, paper_questions(question_number, paper_id, papers(id, course_code, year, term, exam_type, part_label))")
.eq("user_id", user_id)
.eq("favorited", True)
.order("created_at", desc=True)
.execute()
.data
)
return rows
@router.post("/{question_id}/variant")
async def create_variant(question_id: str, user_id: str = Depends(get_current_user_id)):
"""生成变式题并入库"""
sb = get_supabase()
result = sb.table("paper_questions").select("*").eq("id", question_id).execute()
if not result.data:
raise HTTPException(status_code=404, detail="Question not found")
question = result.data[0]
variant_data = await asyncio.to_thread(generate_variant, question)
variant_data["knowledge_reminder"] = question.get("knowledge_reminder", "")
saved = sb.table("question_variants").insert({
"user_id": user_id,
"source_question_id": question_id,
"variant_data": variant_data,
"favorited": False,
}).execute()
row = saved.data[0]
row["source_question_number"] = question["question_number"]
return row
@router.get("/{question_id}/variants")
async def list_variants(question_id: str, user_id: str = Depends(get_current_user_id)):
"""获取某道题的用户所有 variant"""
sb = get_supabase()
q_result = sb.table("paper_questions").select("question_number").eq("id", question_id).execute()
question_number = q_result.data[0]["question_number"] if q_result.data else ""
rows = (
sb.table("question_variants")
.select("*")
.eq("user_id", user_id)
.eq("source_question_id", question_id)
.order("created_at", desc=True)
.execute()
.data
)
for row in rows:
row["source_question_number"] = question_number
return rows
@router.patch("/variant/{variant_id}")
async def update_variant(variant_id: str, data: VariantUpdate, user_id: str = Depends(get_current_user_id)):
"""更新 variant收藏/取消收藏)"""
sb = get_supabase()
update: dict = {}
if data.favorited is not None:
update["favorited"] = data.favorited
if not update:
raise HTTPException(status_code=400, detail="Nothing to update")
result = (
sb.table("question_variants")
.update(update)
.eq("id", variant_id)
.eq("user_id", user_id)
.execute()
)
if not result.data:
raise HTTPException(status_code=404, detail="Variant not found")
return result.data[0]
@router.delete("/variant/{variant_id}", status_code=204)
async def delete_variant(variant_id: str, user_id: str = Depends(get_current_user_id)):
"""删除 variant"""
sb = get_supabase()
sb.table("question_variants").delete().eq("id", variant_id).eq("user_id", user_id).execute()
@router.get("/{question_id}/similar")
async def get_similar_questions(question_id: str, limit: int = 6):
"""Retrieve similar questions from the same course."""
# Cache hit
cached = _similar_cache.get(question_id)
if cached and (time.time() - cached[0]) < _CACHE_TTL:
return cached[1][:max(1, min(limit, 12))]
sb = get_supabase()
result = sb.table("paper_questions").select("*, similar_questions").eq("id", question_id).execute()
if not result.data:
raise HTTPException(status_code=404, detail="Question not found")
target = result.data[0]
# Return pre-computed immediately; schedule background refresh
if target.get("similar_questions"):
precomputed = target["similar_questions"]
_similar_cache[question_id] = (time.time(), precomputed)
return precomputed[:max(1, min(limit, 12))]
paper_result = sb.table("papers").select("id, course_code").eq("id", target["paper_id"]).execute()
# (fallback: compute on-the-fly for questions not yet backfilled)
if not paper_result.data:
raise HTTPException(status_code=404, detail="Paper not found")
course_code = paper_result.data[0]["course_code"]
papers = (
sb.table("papers")
.select("id, course_code, year, term, exam_type, part_label")
.eq("course_code", course_code)
.eq("status", "ready")
.execute()
.data
)
paper_ids = [paper["id"] for paper in papers if paper["id"] != target["paper_id"]]
if not paper_ids:
return []
papers_by_id = {paper["id"]: paper for paper in papers}
# Pre-filter by analytics_topic in DB when possible (cuts candidates from ~250 to ~30)
candidates_query = (
sb.table("paper_questions")
.select(
"id, paper_id, question_number, question_type, question_format, "
"question_text, score, topics, analytics_topic, topic_tags, skill_tags, "
"difficulty, knowledge_reminder, ai_hint, solution"
)
.in_("paper_id", paper_ids)
)
target_topic = target.get("analytics_topic")
if target_topic:
candidates_query = candidates_query.eq("analytics_topic", target_topic)
candidates = candidates_query.execute().data
if not candidates:
return []
# Batch full-text scores from PostgreSQL (skip if too many candidates — slow)
text_scores: dict[str, float] = {}
if len(candidates) <= 50:
try:
rpc_result = sb.rpc(
"text_similarity_scores",
{
"query_text": target.get("question_text") or "",
"candidate_ids": [c["id"] for c in candidates],
},
).execute()
for row in rpc_result.data or []:
text_scores[row["question_id"]] = float(row["text_score"] or 0)
except Exception:
pass
ranked = []
for candidate in candidates:
text_score = text_scores.get(candidate["id"], 0.0)
match_percent, reasons = similarity_score(target, candidate, text_score)
if match_percent < 20:
continue
paper = papers_by_id.get(candidate["paper_id"], {})
source = (
f"{paper.get('year', '')} {paper.get('term', '').title()} "
f"{paper.get('exam_type', '').title()}"
).strip()
if paper.get("part_label"):
source = f"{source} Part {paper['part_label']}"
ranked.append(
{
"id": candidate["id"],
"paper_id": candidate["paper_id"],
"source": source,
"question_number": candidate["question_number"],
"match_percent": match_percent,
"match_reasons": reasons,
"question_type": question_family(candidate),
"question_text": candidate["question_text"],
"topics": display_topics(candidate),
"difficulty": candidate.get("difficulty"),
"knowledge_reminder": candidate.get("knowledge_reminder", ""),
"ai_hint": candidate.get("ai_hint", ""),
"solution": candidate.get("solution", ""),
}
)
ranked.sort(key=lambda item: (-item["match_percent"], item["source"], item["question_number"]))
# Keep only the best-scoring question per paper
seen_papers: set[str] = set()
deduped = []
for item in ranked:
if item["paper_id"] not in seen_papers:
seen_papers.add(item["paper_id"])
deduped.append(item)
_similar_cache[question_id] = (time.time(), deduped)
# Persist to DB so future requests are instant
try:
sb.table("paper_questions").update({"similar_questions": deduped}).eq("id", question_id).execute()
except Exception:
pass
return deduped[:max(1, min(limit, 12))]

View File

View File

@@ -0,0 +1,146 @@
"""OCR, grading, and variant generation prompts"""
import json
import base64
from app.services.llm_clients import get_vision_client, get_deepseek_client
OCR_PROMPT = """You are an expert at recognizing handwritten answers. Analyze this photo of a student's handwritten answer and extract the text and mathematical formulas.
Requirements:
- Faithfully extract what the student wrote, do not modify or correct
- Use LaTeX format for math formulas (e.g. $x^2 + 1$)
- If there are multiple steps, list them in original order
- If some handwriting is unclear, mark with [unclear]
Return only the extracted text, no additional explanation."""
GRADING_PROMPT = """You are an expert academic grader. Grade the following student answer. ALL output must be in English.
Question info:
- Number: {question_number}
- Type: {question_type}
- Question: {question_text}
- Score: {score}
Reference answer / solution:
{reference_answer}
Student answer:
{student_answer}
Grade and return JSON:
{{
"is_correct": true/false,
"score_given": 0-{score},
"feedback": "<HTML> Step-by-step analysis of the student's answer, pointing out correct parts and errors, using KaTeX formulas </HTML>",
"error_at_step": null or the step number where errors begin (integer)
}}
Grading rules:
- MC / fill-blank: only correct if answer matches exactly
- Long questions: give partial credit for correct steps even if the final answer is wrong
- feedback in HTML format, supports KaTeX ($..$ inline, $$...$$ block)
- Mark errors with <div class="common-error">...</div>
- Identify exactly which step the error starts"""
VARIANT_PROMPT = """You are an expert exam question creator. Generate a similar but different variant question based on the original below. ALL output must be in English.
Original question info:
- Type: {question_type}
- Question: {question_text}
- Topics: {topics}
- Difficulty: {difficulty}
- Reference answer: {answer}
Requirements:
- Variant must test the same knowledge points at similar difficulty
- Data/scenario/wording must differ — don't just change numbers
- Must provide a complete correct answer
Format requirements (CRITICAL):
- All text in HTML format, absolutely NO markdown syntax
- Code: <pre><code class="language-xxx">...</code></pre>, NOT ```
- Math: $...$ (inline) or $$...$$ (block), KaTeX compatible
- Line breaks: <br>, paragraphs: <p>
Return JSON:
{{
"question_text": "HTML formatted variant question",
"question_type": "{question_type}",
"options": [MC only, format {{"label":"A","text":"..."}}, ...] or null,
"correct_answer": "Correct answer (plain text)",
"ai_hint": "HTML formatted hint that guides thinking WITHOUT giving the answer",
"solution": "HTML formatted complete step-by-step solution"
}}"""
def ocr_photo(photo_bytes: bytes) -> str:
"""Gemini Vision OCR for handwritten answers"""
client = get_vision_client()
b64 = base64.b64encode(photo_bytes).decode("utf-8")
resp = client.chat.completions.create(
model="gemini-2.5-flash",
messages=[
{"role": "system", "content": OCR_PROMPT},
{"role": "user", "content": [
{"type": "image_url", "image_url": {
"url": f"data:image/jpeg;base64,{b64}",
}},
]},
],
temperature=0,
max_tokens=2000,
)
return resp.choices[0].message.content or ""
def grade_answer(question: dict, student_answer: str) -> dict:
"""Qwen grades student answer"""
reference = question.get("raw_answer_text") or question.get("solution") or "No reference answer"
score = question.get("score") or "unknown"
ds = get_deepseek_client()
resp = ds.chat.completions.create(
model="deepseek-chat",
messages=[
{"role": "system", "content": GRADING_PROMPT.format(
question_number=question["question_number"],
question_type=question["question_type"],
question_text=question["question_text"],
score=score,
reference_answer=reference,
student_answer=student_answer,
)},
],
temperature=0.2,
response_format={"type": "json_object"},
)
return json.loads(resp.choices[0].message.content)
def generate_variant(question: dict) -> dict:
"""Gemini generates a variant question"""
answer = (
question.get("correct_option")
or question.get("correct_answer")
or question.get("raw_answer_text")
or "N/A"
)
ds = get_deepseek_client()
resp = ds.chat.completions.create(
model="deepseek-chat",
messages=[
{"role": "system", "content": VARIANT_PROMPT.format(
question_type=question["question_type"],
question_text=question["question_text"],
topics=", ".join(question.get("topics", [])),
difficulty=question.get("difficulty", "medium"),
answer=answer,
)},
],
temperature=0.5,
response_format={"type": "json_object"},
)
return json.loads(resp.choices[0].message.content)

View File

@@ -0,0 +1,74 @@
import httpx
from openai import OpenAI
from app.config import get_settings
_TIMEOUT = httpx.Timeout(connect=10, read=300, write=60, pool=10)
_gpt_client: OpenAI | None = None
_qwen_client: OpenAI | None = None
_gemini_flash_client: OpenAI | None = None
_gemini_lite_client: OpenAI | None = None
_deepseek_client: OpenAI | None = None
def get_gpt_client() -> OpenAI:
"""laozhang API — gpt-4o / gpt-4o-mini"""
global _gpt_client
if _gpt_client is None:
s = get_settings()
_gpt_client = OpenAI(
base_url=s.laozhang_base_url,
api_key=s.laozhang_api_key,
)
return _gpt_client
def get_qwen_client() -> OpenAI:
"""DashScope — qwen-plus"""
global _qwen_client
if _qwen_client is None:
s = get_settings()
_qwen_client = OpenAI(
base_url=s.dashscope_base_url,
api_key=s.dashscope_api_key,
)
return _qwen_client
def get_vision_client() -> OpenAI:
"""Google Gemini 官方 API视觉用于拆题+OCR— 部署在新加坡可用"""
global _gemini_flash_client
if _gemini_flash_client is None:
s = get_settings()
_gemini_flash_client = OpenAI(
base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
api_key=s.google_gemini_api_key,
timeout=_TIMEOUT,
)
return _gemini_flash_client
def get_gemini_lite_client() -> OpenAI:
"""laozhang — gemini-3.1-flash-lite-preview轻量用于 AI trio"""
global _gemini_lite_client
if _gemini_lite_client is None:
s = get_settings()
_gemini_lite_client = OpenAI(
base_url=s.laozhang_base_url,
api_key=s.laozhang_api_key,
timeout=_TIMEOUT,
)
return _gemini_lite_client
def get_deepseek_client() -> OpenAI:
"""DeepSeek — deepseek-chat用于 AI trio"""
global _deepseek_client
if _deepseek_client is None:
s = get_settings()
_deepseek_client = OpenAI(
base_url=s.deepseek_base_url,
api_key=s.deepseek_api_key,
timeout=_TIMEOUT,
)
return _deepseek_client

View File

@@ -0,0 +1,576 @@
"""试卷处理管线PDF → 结构化题目 → AI 三件套Vision 模式)"""
import asyncio
import base64
import io
import json
import re
import traceback
from contextlib import redirect_stdout
import fitz # pymupdf
from app.services.supabase_client import get_supabase
from app.services.llm_clients import get_vision_client, get_deepseek_client
def strip_nulls(obj):
"""Recursively remove \\u0000 null bytes from strings (PostgreSQL rejects them)."""
if isinstance(obj, str):
return obj.replace("\u0000", "")
if isinstance(obj, dict):
return {k: strip_nulls(v) for k, v in obj.items()}
if isinstance(obj, list):
return [strip_nulls(i) for i in obj]
return obj
# ============================================
# Prompts
# ============================================
STRUCTURE_PROMPT = """You are an expert exam paper structure analyst. You are given images of a past exam paper. Analyze every page carefully and extract all questions into structured JSON.
All generated values must be in English. Do not output Chinese.
CRITICAL RULES for question_text:
- Each question's question_text must be FULLY SELF-CONTAINED. Include ALL context needed to solve it.
- For sub-questions (e.g. (a)(i)), copy the ENTIRE parent question setup (variable definitions, code blocks, problem description) into the question_text, then append the specific sub-question.
- For Python/code questions: include ALL variable definitions and import statements verbatim, exactly as they appear in the exam, preserving multi-line arrays and data structures completely.
- Never truncate code. If a variable is defined across multiple lines (e.g. a numpy array), include every line.
Output JSON format (strictly follow):
{
"total_score": 100,
"difficulty_level": "medium",
"topics_summary": {"Topic A": 40, "Topic B": 30, "Topic C": 30},
"questions": [
{
"question_number": "1a",
"parent_question": "1",
"question_type": "mc",
"question_text": "Original question text...",
"score": 5,
"page_number": 1,
"options": [{"label": "A", "text": "Option content"}, {"label": "B", "text": "..."}],
"topics": ["Linked List", "Pointer"],
"difficulty": "easy"
},
{
"question_number": "2",
"parent_question": null,
"question_type": "long_question",
"question_text": "Original question text...",
"score": 15,
"page_number": 2,
"options": null,
"topics": ["Recursion"],
"difficulty": "hard"
}
]
}
Rules:
- question_type must be one of: "mc" (multiple choice), "true_false" (true/false), "fill_blank" (fill in blank), "long_question" (long question)
- True/False questions MUST use "true_false" type, with options set to [{"label":"True","text":"True"},{"label":"False","text":"False"}], correct_option as "True" or "False"
- Multiple choice must extract the options array
- Sub-questions use parent_question to link to parent: "1a" parent is "1"
- Independent questions without sub-questions set parent_question to null
- page_number inferred from where the question appears
- topics inferred from the question content
- difficulty: "easy" | "medium" | "hard"
- Extract ALL questions, do not miss any
- Keep topic labels in English only
"""
ANSWER_MATCH_PROMPT = """You are an expert exam answer matching specialist. Below is the answer text for an exam paper. Extract and match answers to their corresponding question numbers.
All generated values must be in English. Do not output Chinese.
Question structure:
{questions_json}
Answer text:
{answer_text}
Output JSON format:
{{
"answers": [
{{
"question_number": "1a",
"correct_option": "B",
"correct_answer": null,
"raw_answer_text": "Original answer text..."
}},
{{
"question_number": "2",
"correct_option": null,
"correct_answer": null,
"raw_answer_text": "Complete solution process and answer..."
}}
]
}}
Rules:
- For MC questions, fill correct_option (e.g. "B")
- For fill-blank questions, fill correct_answer (e.g. "O(n log n)")
- For long questions, only fill raw_answer_text (complete solution process)
- Match all questions where answers can be found
- Keep raw_answer_text faithful to the source answer, but do not add Chinese commentary
"""
ANALYSIS_PROMPT = """You are an expert academic answer analyst. Generate three sections for the following exam question. ALL output must be in English.
Question info:
- Number: {question_number}
- Type: {question_type}
- Score: {score}
- Question: {question_text}
- Topics: {topics}
{answer_section}
Generate THREE sections in HTML format (supports KaTeX: block $$ ... $$ inline $ ... $):
Output JSON:
{{
"knowledge_reminder": "<HTML> Prerequisite knowledge points needed for this question, as a concise bullet list </HTML>",
"ai_hint": "<HTML> A hint that guides thinking direction WITHOUT giving away the answer </HTML>",
"solution": "<HTML> Complete step-by-step solution (Step 1, Step 2, ...) with derivations, formulas, and common mistake warnings </HTML>"
}}
Solution requirements:
- Must include complete working process, not just the answer
- Each step must have an explanation
- If a reference answer is provided, derive the solution based on it
- If no reference answer, work out the complete solution independently
- For MC questions, explain why the correct option is right AND why others are wrong
- Use <ol> or numbered steps
- Mark common mistakes with <div class="common-error">...</div>
KaTeX formula rules:
- Block formula: $$ on its own line, with blank lines before and after
- Inline formula: $x^2$ no line break
- Matrix: \\begin{{bmatrix}} ... \\end{{bmatrix}}
- Fraction: \\frac{{a}}{{b}}
"""
BATCH_ANALYSIS_PROMPT = """You are an expert academic answer analyst. Generate three study sections for each question below. ALL output must be in English.
For every question, return:
- knowledge_reminder: concise prerequisite bullets in HTML
- ai_hint: a helpful hint in HTML without revealing the final answer
- solution: a complete step-by-step solution in HTML
Return JSON in this exact format:
{{
"analyses": [
{{
"question_number": "1a",
"knowledge_reminder": "<HTML>...</HTML>",
"ai_hint": "<HTML>...</HTML>",
"solution": "<HTML>...</HTML>"
}}
]
}}
Rules:
- Return one item for every provided question_number
- Keep each item matched to the same question_number
- All text must be in English
- HTML only, KaTeX compatible
- For MC questions, explain why the correct option is right and why the others are wrong
- For long questions, show a complete derivation or reasoning chain
- Use <ol> or numbered steps in solution when appropriate
- Mark common mistakes with <div class="common-error">...</div>
- CRITICAL: When a question_text contains "[Context from parent question X]" followed by "[Sub-question Y]", the parent section is background context only. You MUST solve ONLY the specific sub-question labeled [Sub-question Y]. Do NOT solve other sub-questions listed in the parent context. Give one precise answer for that single sub-question only.
Questions:
{questions_payload}
"""
# ============================================
# 处理管线
# ============================================
RETRYABLE_ERROR_MARKERS = (
"429",
"rate limit",
"rate_limit",
"too many requests",
"timeout",
"timed out",
"connection",
)
def is_retryable_error(exc: Exception) -> bool:
message = str(exc).lower()
return any(marker in message for marker in RETRYABLE_ERROR_MARKERS)
def pdf_to_images(pdf_bytes: bytes, dpi: int = 96) -> list[str]:
"""将 PDF 每页渲染为 base64 PNG 图片列表96dpi 平衡清晰度与成本)"""
doc = fitz.open(stream=pdf_bytes, filetype="pdf")
images = []
mat = fitz.Matrix(dpi / 72, dpi / 72)
for page in doc:
pix = page.get_pixmap(matrix=mat, colorspace=fitz.csRGB)
img_bytes = pix.tobytes("png")
images.append(base64.b64encode(img_bytes).decode())
doc.close()
return images
def parse_json_response(text: str) -> dict:
"""解析模型返回的 JSON兼容 markdown 代码块包装"""
text = text.strip()
# 去掉 ```json ... ``` 包装
if text.startswith("```"):
lines = text.splitlines()
text = "\n".join(lines[1:-1] if lines[-1].strip() == "```" else lines[1:])
# 移除 JSON 字符串中的非法控制字符0x00-0x1F 除了 \t \n \r
text = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f]', '', text)
# 修复模型返回的无效 JSON 转义序列:只修奇数个反斜杠后的非法字符
text = re.sub(r'(?<!\\)((?:\\\\)*)\\([^"\\/bfnrtu])', r'\1\\\\\2', text)
return json.loads(text)
async def gemini_vision_json(
*,
system_prompt: str,
images: list[str],
user_text: str = "",
temperature: float = 0,
max_attempts: int = 6,
) -> dict:
"""发送图片 + prompt 给 Gemini vision 模型,返回 JSON"""
client = get_vision_client()
delay_seconds = 2
content: list = []
for b64 in images:
content.append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64}"}})
if user_text:
content.append({"type": "text", "text": user_text})
for attempt in range(1, max_attempts + 1):
try:
response = client.chat.completions.create(
model="gemini-2.5-flash",
messages=[
{"role": "system", "content": system_prompt + "\n\nIMPORTANT: Your entire response must be valid JSON only. No markdown, no code fences, no extra text."},
{"role": "user", "content": content},
],
temperature=temperature,
max_tokens=16384,
)
return parse_json_response(response.choices[0].message.content)
except Exception as exc:
if attempt == max_attempts or not is_retryable_error(exc):
raise
await asyncio.sleep(delay_seconds)
delay_seconds = min(delay_seconds * 2, 30)
async def deepseek_json_completion(
*,
system_prompt: str,
user_prompt: str | None = None,
temperature: float = 0,
max_attempts: int = 6,
) -> dict:
"""DeepSeek 纯文本 JSON completion用于 AI trio 生成)"""
client = get_deepseek_client()
delay_seconds = 2
for attempt in range(1, max_attempts + 1):
try:
messages = [{"role": "system", "content": system_prompt}]
if user_prompt:
messages.append({"role": "user", "content": user_prompt})
response = client.chat.completions.create(
model="deepseek-chat",
messages=messages,
temperature=temperature,
max_tokens=8192,
response_format={"type": "json_object"},
)
raw = response.choices[0].message.content
raw = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f]', '', raw)
raw = re.sub(r'(?<!\\)((?:\\\\)*)\\([^"\\/bfnrtu])', r'\1\\\\\2', raw)
return json.loads(raw)
except Exception as exc:
if attempt == max_attempts or not is_retryable_error(exc):
raise
await asyncio.sleep(delay_seconds)
delay_seconds = min(delay_seconds * 2, 30)
def chunked(items: list[dict], size: int) -> list[list[dict]]:
return [items[i:i + size] for i in range(0, len(items), size)]
def _question_sort_key(qnum: str) -> tuple:
"""自然排序题号1a < 1b < ... < 1i < 1j < 2ai < 2aii < 10a"""
parts = re.findall(r'(\d+|[a-zA-Z]+|[()]+)', qnum)
key = []
for idx, p in enumerate(parts):
if p.isdigit():
key.append((0, int(p), ''))
elif p in ('(', ')'):
continue
else:
# Single letter (a-z): always sort alphabetically (a=1, b=2, ..., j=10)
if len(p) == 1 and p.isalpha():
key.append((1, ord(p.lower()) - ord('a') + 1, p))
else:
# Multi-letter: roman numerals for sub-sub-questions (i=1, ii=2, iii=3, ...)
romans = {'i':1,'ii':2,'iii':3,'iv':4,'v':5,'vi':6,'vii':7,'viii':8,'ix':9,'x':10,'xi':11,'xii':12,'xiii':13}
if p.lower() in romans:
key.append((2, romans[p.lower()], p))
else:
key.append((1, 0, p))
return tuple(key)
def sort_questions(questions: list[dict]) -> list[dict]:
"""按题号自然排序"""
return sorted(questions, key=lambda q: _question_sort_key(q.get("question_number", "")))
def extract_code_block(text: str) -> str:
"""
从题目文本中提取 Python 代码块。
策略找到第一个明确的代码起始行import/赋值/print
然后把后续所有缩进或延续行一并带上,直到明显的非代码段落。
"""
lines = text.splitlines()
result = []
in_code = False
open_brackets = 0
CODE_START = re.compile(r"^\s*(import |from \w|[A-Za-z_]\w*\s*=|print\()")
for line in lines:
stripped = line.strip()
# 已在代码块内:括号未闭合时继续收集
if in_code and open_brackets > 0:
result.append(stripped)
open_brackets += stripped.count("(") + stripped.count("[") + stripped.count("{")
open_brackets -= stripped.count(")") + stripped.count("]") + stripped.count("}")
continue
# 检测新的代码起始行
if CODE_START.match(line):
in_code = True
result.append(stripped)
open_brackets += stripped.count("(") + stripped.count("[") + stripped.count("{")
open_brackets -= stripped.count(")") + stripped.count("]") + stripped.count("}")
continue
# 非代码行:重置(但保留 in_code=True 以便继续接后续代码行)
in_code = False
return "\n".join(result)
# 保持向后兼容
extract_code_lines = extract_code_block
def try_exec_python(code: str, shared_ns: dict) -> str | None:
"""
在 shared_ns 命名空间中执行 code捕获 stdout。
返回输出字符串,失败返回 None。
"""
buf = io.StringIO()
try:
with redirect_stdout(buf):
exec(code, shared_ns) # noqa: S102
output = buf.getvalue().strip()
return output if output else None
except Exception:
return None
async def _resume_ai_trio(sb, paper_id: str, questions: list[dict]):
"""为缺 solution 的题目生成 AI trio逐条写回 DB。支持断点续传。"""
need = [q for q in questions if not q.get("solution")]
if not need:
# 全部已有 solution直接标记完成
sb.table("papers").update({"status": "ready", "processing_step": None}).eq("id", paper_id).execute()
return
total_q = len(questions)
done_q = total_q - len(need)
# 构建 payload
id_map = {q["question_number"]: q["id"] for q in need}
# 需要完整的 question_text 来生成 AI trio
full_data = sb.table("paper_questions").select(
"id, question_number, question_type, question_text, score, correct_option, correct_answer, raw_answer_text"
).eq("paper_id", paper_id).in_("id", [q["id"] for q in need]).execute().data
payloads = []
for q in full_data:
answer_section = q.get("raw_answer_text") or ""
if not answer_section and q.get("correct_option"):
answer_section = f"Correct option: {q['correct_option']}"
elif not answer_section and q.get("correct_answer"):
answer_section = f"Correct answer: {q['correct_answer']}"
payloads.append({
"question_number": q["question_number"],
"question_type": q["question_type"] or "long_question",
"score": q.get("score") or "unknown",
"question_text": q["question_text"] or "",
"reference_answer": answer_section,
})
batches = chunked(payloads, 3)
for batch_idx, batch in enumerate(batches, 1):
current = done_q + batch_idx * 3
_update_progress(sb, paper_id, f"Generating solutions ({min(current, total_q)}/{total_q} questions)", batch_idx, len(batches))
try:
result = await deepseek_json_completion(
system_prompt=BATCH_ANALYSIS_PROMPT.format(
questions_payload=json.dumps(batch, ensure_ascii=False),
),
temperature=0.3,
)
for item in result.get("analyses", []):
qnum = item.get("question_number")
qid = id_map.get(qnum)
if qid:
sb.table("paper_questions").update({
"knowledge_reminder": item.get("knowledge_reminder", ""),
"ai_hint": item.get("ai_hint", ""),
"solution": item.get("solution", ""),
}).eq("id", qid).execute()
except Exception:
pass # 单批失败不影响其他批
await asyncio.sleep(1)
# 标记完成
sb.table("papers").update({"status": "ready", "processing_step": None}).eq("id", paper_id).execute()
def _update_progress(sb, paper_id: str, step: str, progress: int = 0, total: int = 0):
"""更新处理进度到 DB"""
sb.table("papers").update({
"processing_step": step,
"processing_progress": progress,
"processing_total": total,
}).eq("id", paper_id).execute()
async def process_paper(paper_id: str, paper_bytes: bytes, answer_bytes: bytes | None):
"""后台处理管线: PDF pages → Vision 结构化 → AI 三件套
设计原则:每个步骤完成后立即持久化到 DB支持断点续传。
"""
sb = get_supabase()
try:
# 检查是否已有题目(断点续传场景)
existing = sb.table("paper_questions").select("id, question_number, solution").eq("paper_id", paper_id).execute().data
if existing:
# 已有题目 → 跳过提取,直接补 AI trio
await _resume_ai_trio(sb, paper_id, existing)
return
# ── Step 1: PDF → 图片 ──
_update_progress(sb, paper_id, "Rendering PDF pages...")
paper_images = pdf_to_images(paper_bytes)
# ── Step 2: Vision 结构化拆题 ──
PAGE_BATCH = 8
all_questions: list = []
meta: dict = {}
num_page_batches = -(-len(paper_images) // PAGE_BATCH)
for i in range(0, len(paper_images), PAGE_BATCH):
batch_imgs = paper_images[i:i + PAGE_BATCH]
batch_idx = i // PAGE_BATCH + 1
_update_progress(sb, paper_id, f"Reading pages {i+1}-{i+len(batch_imgs)}...", batch_idx, num_page_batches)
batch_result = await gemini_vision_json(
system_prompt=STRUCTURE_PROMPT,
images=batch_imgs,
user_text=f"Pages {i+1}-{i+len(batch_imgs)} of the exam paper. Extract all questions visible on these pages.",
temperature=0,
)
if not meta:
meta = {k: batch_result.get(k) for k in ("total_score", "difficulty_level", "topics_summary")}
all_questions.extend(batch_result.get("questions", []))
all_questions = sort_questions(all_questions)
questions = all_questions
# 更新 paper 概览
sb.table("papers").update({
"total_score": meta.get("total_score"),
"question_count": len(questions),
"topics_summary": meta.get("topics_summary"),
"difficulty_level": meta.get("difficulty_level"),
}).eq("id", paper_id).execute()
# ── Step 3: 答案匹配(分批,失败跳过)──
answers_map = {}
if answer_bytes:
_update_progress(sb, paper_id, "Matching answers...")
try:
answer_images = pdf_to_images(answer_bytes)
questions_json = json.dumps(
[{"question_number": q["question_number"], "question_type": q["question_type"]}
for q in questions], ensure_ascii=False,
)
all_answers: list = []
for ai in range(0, len(answer_images), 8):
batch_ans_imgs = answer_images[ai:ai + 8]
try:
match_result = await gemini_vision_json(
system_prompt=ANSWER_MATCH_PROMPT.format(
questions_json=questions_json, answer_text="(See images)",
),
images=batch_ans_imgs,
user_text=f"Match answers to these questions: {questions_json}",
temperature=0,
)
all_answers.extend(match_result.get("answers", []))
except Exception:
pass
answers_map = {a["question_number"]: a for a in all_answers}
except Exception:
pass
# ── Step 4: 立即写入题目到 DB先不含 AI trio──
_update_progress(sb, paper_id, "Saving questions...")
for i, q in enumerate(questions):
qnum = q["question_number"]
answer = answers_map.get(qnum, {})
sb.table("paper_questions").insert(strip_nulls({
"paper_id": paper_id,
"question_number": qnum,
"parent_question": q.get("parent_question"),
"display_order": i,
"question_type": q["question_type"],
"question_text": q["question_text"],
"score": q.get("score"),
"page_number": q.get("page_number"),
"options": q.get("options"),
"correct_option": answer.get("correct_option"),
"correct_answer": answer.get("correct_answer"),
"raw_answer_text": answer.get("raw_answer_text"),
"topics": q.get("topics", []),
"analytics_topic": q.get("topics", [None])[0],
"topic_tags": q.get("topics", []),
"difficulty": q.get("difficulty"),
})).execute()
# ── Step 5: AI trio逐条更新支持断点续传──
saved = sb.table("paper_questions").select("id, question_number, solution").eq("paper_id", paper_id).execute().data
await _resume_ai_trio(sb, paper_id, saved)
except Exception as e:
sb.table("papers").update({
"status": "error",
"error_message": f"{type(e).__name__}: {str(e)}\n{traceback.format_exc()[-500:]}",
}).eq("id", paper_id).execute()
raise

View File

@@ -0,0 +1,13 @@
from supabase import create_client, Client
from app.config import get_settings
_client: Client | None = None
def get_supabase() -> Client:
"""获取 Supabase client (service_role绕过 RLS)"""
global _client
if _client is None:
s = get_settings()
_client = create_client(s.supabase_url, s.supabase_service_role_key)
return _client

View File

@@ -0,0 +1,48 @@
"""PDF 文本提取 — 复用 SOS 的 text_extractor 逻辑"""
import base64
import fitz # PyMuPDF
from dataclasses import dataclass
@dataclass
class ExtractedContent:
pages_text: list[str] # 每页文本
page_images: dict[int, str] # 页码 → base64 图片(图片密集型页面)
total_pages: int
has_images: bool
def extract_pdf(file_bytes: bytes) -> ExtractedContent:
"""从 PDF 提取文本和图片"""
doc = fitz.open(stream=file_bytes, filetype="pdf")
pages_text = []
page_images = {}
for i, page in enumerate(doc):
text = page.get_text("text")
pages_text.append(text)
# 如果某页文本很少但有图片,可能是扫描件 → 保存为图片用于 Vision OCR
if len(text.strip()) < 50:
pix = page.get_pixmap(dpi=200)
img_bytes = pix.tobytes("png")
page_images[i] = base64.b64encode(img_bytes).decode("utf-8")
doc.close()
return ExtractedContent(
pages_text=pages_text,
page_images=page_images,
total_pages=len(pages_text),
has_images=len(page_images) > 0,
)
def get_full_text(extracted: ExtractedContent) -> str:
"""合并所有页面文本"""
return "\n\n".join(
f"--- Page {i+1} ---\n{text}"
for i, text in enumerate(extracted.pages_text)
if text.strip()
)