239 lines
8.7 KiB
Python
239 lines
8.7 KiB
Python
"""
|
||
用 Vision 模式重新处理所有已 ready 的试卷:
|
||
- 从 Supabase Storage 拉 PDF → 图片 → Vision 拆题 → exec → AI trio → 更新 DB
|
||
|
||
用法:
|
||
python backfill_vision.py --course COMP2211
|
||
python backfill_vision.py --paper-id <uuid>
|
||
"""
|
||
|
||
import asyncio
|
||
import argparse
|
||
import requests
|
||
from app.services.supabase_client import get_supabase
|
||
from app.services.paper_processor import (
|
||
process_paper,
|
||
strip_nulls,
|
||
pdf_to_images,
|
||
gemini_vision_json,
|
||
deepseek_json_completion,
|
||
parse_json_response,
|
||
extract_code_lines,
|
||
try_exec_python,
|
||
chunked,
|
||
sort_questions,
|
||
STRUCTURE_PROMPT,
|
||
ANSWER_MATCH_PROMPT,
|
||
BATCH_ANALYSIS_PROMPT,
|
||
)
|
||
import json
|
||
import traceback
|
||
|
||
|
||
_PAGE_BATCH = 8  # pages sent per Vision extraction request
_ANALYSIS_BATCH = 3  # questions sent per AI-trio request
_META_KEYS = ("total_score", "difficulty_level", "topics_summary")  # paper-level fields


async def reprocess_paper(paper: dict):
    """Reprocess a single exam paper in Vision mode and overwrite its questions.

    Pipeline:

    1. Download the paper PDF and the optional answer PDF.
    2. Render the paper pages to images.
    3. Extract questions with Vision.
    4. Match answers with Vision, if an answer PDF is available.
    5. Fill missing reference answers by executing the question's last
       ``print(...)`` line.
    6. Generate the AI trio (knowledge reminder / hint / solution).
    7. Replace the paper's ``paper_questions`` rows and update its ``papers``
       summary columns.

    The existing questions are deleted only after every extraction batch
    succeeded and all new rows were built. A Vision outage or a malformed
    question therefore leaves the paper's current data untouched instead of
    wiping it.

    Args:
        paper: a ``papers`` row. Reads ``id``, ``course_code``, ``year``,
            ``term``, ``exam_type``, ``paper_file_url`` and, optionally,
            ``answer_file_url``.
    """
    sb = get_supabase()
    paper_id = paper["id"]
    label = f"{paper['course_code']} {paper['year']} {paper['term']} {paper['exam_type']}"
    print(f"\n=== {label} ({paper_id[:8]}) ===")

    # 1. Fetch PDFs: the question paper is required, the answer PDF is optional.
    try:
        pdf_bytes = _download(paper["paper_file_url"])
    except Exception as e:
        print(f" SKIP: failed to fetch PDF: {e}")
        return

    answer_bytes = None
    if paper.get("answer_file_url"):
        try:
            answer_bytes = _download(paper["answer_file_url"])
        except Exception as e:
            # Answers are optional; carry on without them, but make it visible.
            print(f" WARN: failed to fetch answer PDF, continuing without answers: {e}")

    # 2. Render the PDF once and reuse the images.
    print(" Rendering pages...", end=" ", flush=True)
    paper_images = pdf_to_images(pdf_bytes)
    print(f"done ({len(paper_images)} pages)")

    # 3. Vision question extraction.
    meta, extracted, failed_batches = await _extract_questions(paper_images)
    questions = sort_questions(extracted)
    print(f" Total: {len(questions)} questions extracted")
    # The write step deletes the paper's existing rows; never do that with a
    # partial or empty extraction result.
    if failed_batches:
        print(f" SKIP: {failed_batches} extraction batch(es) failed; existing questions left untouched")
        return
    if not questions:
        print(" SKIP: no questions extracted; existing questions left untouched")
        return

    # 4. Answer matching (best-effort).
    answers_map = await _match_answers(answer_bytes, questions) if answer_bytes else {}

    # 5. Per-question payloads with a reference answer where one can be found.
    payloads = _build_payloads(questions, answers_map)

    # 6. AI trio.
    analyses = await _generate_analyses(payloads)

    # 7. Build every row before touching the DB, so a bad question raises
    #    here rather than after the old rows have been deleted.
    rows = _build_rows(paper_id, questions, answers_map, analyses)
    _write_questions(sb, paper_id, rows, meta)


def _download(url: str) -> bytes:
    """Return the response body at *url*, raising ``requests.HTTPError`` on 4xx/5xx.

    Without the status check, a failed request (e.g. a missing Storage object)
    would hand its error body to the PDF renderer as if it were a PDF.
    """
    resp = requests.get(url, timeout=60)
    resp.raise_for_status()
    return resp.content


async def _extract_questions(paper_images: list) -> tuple:
    """Extract questions from rendered pages with Vision, ``_PAGE_BATCH`` pages per call.

    Returns:
        A ``(meta, questions, failed_batches)`` tuple:

        - *meta* maps each key in ``_META_KEYS`` to the first non-null value
          any batch reported.
        - *questions* is the unsorted concatenation of every batch's
          ``questions`` list.
        - *failed_batches* counts batches whose Vision call or response
          handling raised.
    """
    meta = dict.fromkeys(_META_KEYS)
    questions: list = []
    failed_batches = 0
    n_pages = len(paper_images)
    n_batches = -(-n_pages // _PAGE_BATCH)  # ceiling division
    print(f" Vision extraction ({n_pages} pages, {n_batches} batches)...")

    for start in range(0, n_pages, _PAGE_BATCH):
        batch_imgs = paper_images[start:start + _PAGE_BATCH]
        first, last = start + 1, start + len(batch_imgs)
        print(f" Pages {first}-{last}...", end=" ", flush=True)
        try:
            batch_result = await gemini_vision_json(
                system_prompt=STRUCTURE_PROMPT,
                images=batch_imgs,
                user_text=f"Pages {first}-{last} of the exam paper. Extract all questions visible on these pages.",
                temperature=0,
            )
            batch_questions = batch_result.get("questions") or []
            # Paper-level fields may only be reported for some pages, so fill
            # each key from the first batch that actually provides a value.
            for key in _META_KEYS:
                if meta[key] is None:
                    meta[key] = batch_result.get(key)
            questions.extend(batch_questions)
            print(f"done ({len(batch_questions)} questions)")
        except Exception as e:
            failed_batches += 1
            print(f"FAILED: {e}")

    return meta, questions, failed_batches


async def _match_answers(answer_bytes: bytes, questions: list) -> dict:
    """Match answers from the answer PDF to *questions* using Vision.

    Best-effort: a failure while rendering, calling Vision, or reading the
    response is logged and yields an empty mapping.

    Returns:
        Answers keyed by ``question_number``. Entries without a
        ``question_number`` are dropped.
    """
    print(" Vision answer matching...", end=" ", flush=True)
    try:
        answer_images = pdf_to_images(answer_bytes)
        questions_json = json.dumps(
            [{"question_number": q["question_number"], "question_type": q["question_type"]}
             for q in questions],
            ensure_ascii=False,
        )
        match_result = await gemini_vision_json(
            system_prompt=ANSWER_MATCH_PROMPT.format(
                questions_json=questions_json, answer_text="(See images)"
            ),
            images=answer_images,
            user_text=f"Match answers to these questions: {questions_json}",
            temperature=0,
        )
        answers_map = {
            a["question_number"]: a
            for a in (match_result.get("answers") or [])
            if isinstance(a, dict) and a.get("question_number") is not None
        }
    except Exception as e:
        print(f"FAILED: {e}")
        return {}
    print(f"done ({len(answers_map)} matched)")
    return answers_map


def _build_payloads(questions: list, answers_map: dict) -> list:
    """Build the per-question payloads sent to the AI-trio prompt.

    The reference answer is chosen in this order:

    1. the matched ``raw_answer_text``;
    2. ``correct_option``;
    3. ``correct_answer``;
    4. otherwise, the output of executing the last ``print(...)`` line of the
       question text.

    For step 4, the code lines of the first question in each
    ``parent_question`` group are executed once to seed a namespace that the
    whole group shares.
    """
    import numpy as np  # exposed as ``np`` inside the exec namespaces

    exec_namespaces: dict = {}
    payloads = []
    for q in questions:
        qnum = q["question_number"]
        answer = answers_map.get(qnum, {})
        full_text = q["question_text"] or ""

        if answer.get("raw_answer_text"):
            reference = answer["raw_answer_text"]
        elif answer.get("correct_option"):
            reference = f"Correct option: {answer['correct_option']}"
        elif answer.get("correct_answer"):
            reference = f"Correct answer: {answer['correct_answer']}"
        else:
            reference = ""

        if not reference:
            group_key = q.get("parent_question") or qnum
            ns = exec_namespaces.get(group_key)
            if ns is None:
                ns = {"np": np}
                # NOTE(review): this executes code text that Vision extracted
                # from the PDF; confirm try_exec_python sandboxes it, or that
                # the papers are trusted input.
                try_exec_python(extract_code_lines(full_text), ns)
                exec_namespaces[group_key] = ns
            print_lines = [line.strip() for line in full_text.splitlines()
                           if line.strip().startswith("print(")]
            if print_lines:
                out = try_exec_python(print_lines[-1], ns)
                if out is not None:
                    reference = f"Executed output: {out}"
                    print(f" [exec] {qnum}: {str(out)[:60]}")

        payloads.append({
            "question_number": qnum,
            "question_type": q["question_type"],
            "score": q.get("score", "unknown"),
            "question_text": full_text,
            "topics": q.get("topics", []),
            "reference_answer": reference,
        })
    return payloads


async def _generate_analyses(payloads: list) -> dict:
    """Generate the AI trio for *payloads*, ``_ANALYSIS_BATCH`` questions per request.

    Returns:
        Analyses keyed by ``question_number``. A failed batch is logged, and
        its questions simply get no analysis.
    """
    batches = list(chunked(payloads, _ANALYSIS_BATCH))
    print(f" Generating AI trio ({len(payloads)} questions, {len(batches)} batches)...")
    analyses: dict = {}
    for batch in batches:
        nums = [p["question_number"] for p in batch]
        print(f" Batch {nums}...", end=" ", flush=True)
        try:
            result = await deepseek_json_completion(
                system_prompt=BATCH_ANALYSIS_PROMPT.format(
                    questions_payload=json.dumps(batch, ensure_ascii=False)
                ),
                temperature=0.3,
            )
            items = result.get("analyses") or []
            for item in items:
                if item.get("question_number"):
                    analyses[item["question_number"]] = item
            print(f"done ({len(items)})")
        except Exception as e:
            print(f"FAILED: {e}")
        # Pause between requests (presumably to stay under the API rate limit).
        await asyncio.sleep(1)
    return analyses


def _build_rows(paper_id, questions: list, answers_map: dict, analyses: dict) -> list:
    """Build the ``paper_questions`` insert payloads, preserving *questions* order."""
    rows = []
    for order, q in enumerate(questions):
        qnum = q["question_number"]
        answer = answers_map.get(qnum, {})
        analysis = analyses.get(qnum, {})
        topics = q.get("topics")
        # An empty or null topics list yields no analytics_topic instead of raising.
        first_topic = topics[0] if isinstance(topics, list) and topics else None
        rows.append(strip_nulls({
            "paper_id": paper_id,
            "question_number": qnum,
            "parent_question": q.get("parent_question"),
            "display_order": order,
            "question_type": q["question_type"],
            "question_text": q["question_text"],
            "score": q.get("score"),
            "page_number": q.get("page_number"),
            "options": q.get("options"),
            "correct_option": answer.get("correct_option"),
            "correct_answer": answer.get("correct_answer"),
            "raw_answer_text": answer.get("raw_answer_text"),
            "topics": q.get("topics", []),
            "analytics_topic": first_topic,
            "topic_tags": q.get("topics", []),
            "difficulty": q.get("difficulty"),
            "knowledge_reminder": analysis.get("knowledge_reminder", ""),
            "ai_hint": analysis.get("ai_hint", ""),
            "solution": analysis.get("solution", ""),
        }))
    return rows


def _write_questions(sb, paper_id, rows: list, meta: dict) -> None:
    """Replace the paper's ``paper_questions`` rows and update its ``papers`` summary.

    Not transactional: the delete and each insert are separate requests, so a
    failure part-way through the inserts leaves the paper partially populated.
    """
    print(" Writing to DB...", end=" ", flush=True)
    sb.table("paper_questions").delete().eq("paper_id", paper_id).execute()
    for row in rows:
        sb.table("paper_questions").insert(row).execute()

    sb.table("papers").update({
        "question_count": len(rows),
        "total_score": meta.get("total_score"),
        "topics_summary": meta.get("topics_summary"),
        "difficulty_level": meta.get("difficulty_level"),
    }).eq("id", paper_id).execute()
    print(f"done ({len(rows)} questions written)")


async def main():
    """CLI entry point: select "ready" papers and reprocess them one at a time.

    ``--paper-id`` wins over ``--course``; with neither flag, every paper
    whose status is "ready" is selected (oldest ``created_at`` first). A
    failure on one paper is printed with its traceback and does not stop the
    remaining papers.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--course", help="Course code")
    cli.add_argument("--paper-id", help="Single paper ID")
    opts = cli.parse_args()

    # Decide which single extra filter (if any) narrows the "ready" set.
    if opts.paper_id:
        column, value = "id", opts.paper_id
    elif opts.course:
        column, value = "course_code", opts.course.upper()
    else:
        column = value = None

    client = get_supabase()
    ready_query = client.table("papers").select("*").eq("status", "ready")
    if column is not None:
        ready_query = ready_query.eq(column, value)
    todo = ready_query.order("created_at").execute().data

    print(f"Papers to reprocess: {len(todo)}")
    for row in todo:
        try:
            await reprocess_paper(row)
        except Exception as exc:
            print(f" ERROR: {exc}")
            traceback.print_exc()

    print("\nAll done.")


if __name__ == "__main__":
    # Only when executed as a script: create the top-level coroutine and
    # drive it to completion on a fresh event loop.
    entry_coro = main()
    asyncio.run(entry_coro)