Initial commit: PastPaper Master full stack
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
0
backend/app/services/__init__.py
Normal file
0
backend/app/services/__init__.py
Normal file
146
backend/app/services/grader.py
Normal file
146
backend/app/services/grader.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""OCR, grading, and variant generation prompts"""
|
||||
|
||||
import json
|
||||
import base64
|
||||
from app.services.llm_clients import get_vision_client, get_deepseek_client
|
||||
|
||||
OCR_PROMPT = """You are an expert at recognizing handwritten answers. Analyze this photo of a student's handwritten answer and extract the text and mathematical formulas.
|
||||
|
||||
Requirements:
|
||||
- Faithfully extract what the student wrote, do not modify or correct
|
||||
- Use LaTeX format for math formulas (e.g. $x^2 + 1$)
|
||||
- If there are multiple steps, list them in original order
|
||||
- If some handwriting is unclear, mark with [unclear]
|
||||
|
||||
Return only the extracted text, no additional explanation."""
|
||||
|
||||
GRADING_PROMPT = """You are an expert academic grader. Grade the following student answer. ALL output must be in English.
|
||||
|
||||
Question info:
|
||||
- Number: {question_number}
|
||||
- Type: {question_type}
|
||||
- Question: {question_text}
|
||||
- Score: {score}
|
||||
|
||||
Reference answer / solution:
|
||||
{reference_answer}
|
||||
|
||||
Student answer:
|
||||
{student_answer}
|
||||
|
||||
Grade and return JSON:
|
||||
{{
|
||||
"is_correct": true/false,
|
||||
"score_given": 0-{score},
|
||||
"feedback": "<HTML> Step-by-step analysis of the student's answer, pointing out correct parts and errors, using KaTeX formulas </HTML>",
|
||||
"error_at_step": null or the step number where errors begin (integer)
|
||||
}}
|
||||
|
||||
Grading rules:
|
||||
- MC / fill-blank: only correct if answer matches exactly
|
||||
- Long questions: give partial credit for correct steps even if the final answer is wrong
|
||||
- feedback in HTML format, supports KaTeX ($..$ inline, $$...$$ block)
|
||||
- Mark errors with <div class="common-error">...</div>
|
||||
- Identify exactly which step the error starts"""
|
||||
|
||||
VARIANT_PROMPT = """You are an expert exam question creator. Generate a similar but different variant question based on the original below. ALL output must be in English.
|
||||
|
||||
Original question info:
|
||||
- Type: {question_type}
|
||||
- Question: {question_text}
|
||||
- Topics: {topics}
|
||||
- Difficulty: {difficulty}
|
||||
- Reference answer: {answer}
|
||||
|
||||
Requirements:
|
||||
- Variant must test the same knowledge points at similar difficulty
|
||||
- Data/scenario/wording must differ — don't just change numbers
|
||||
- Must provide a complete correct answer
|
||||
|
||||
Format requirements (CRITICAL):
|
||||
- All text in HTML format, absolutely NO markdown syntax
|
||||
- Code: <pre><code class="language-xxx">...</code></pre>, NOT ```
|
||||
- Math: $...$ (inline) or $$...$$ (block), KaTeX compatible
|
||||
- Line breaks: <br>, paragraphs: <p>
|
||||
|
||||
Return JSON:
|
||||
{{
|
||||
"question_text": "HTML formatted variant question",
|
||||
"question_type": "{question_type}",
|
||||
"options": [MC only, format {{"label":"A","text":"..."}}, ...] or null,
|
||||
"correct_answer": "Correct answer (plain text)",
|
||||
"ai_hint": "HTML formatted hint that guides thinking WITHOUT giving the answer",
|
||||
"solution": "HTML formatted complete step-by-step solution"
|
||||
}}"""
|
||||
|
||||
|
||||
def ocr_photo(photo_bytes: bytes, mime_type: str = "image/jpeg") -> str:
    """Extract the handwritten text from a photo of a student's answer.

    The image is base64-encoded into a data URL and sent, with ``OCR_PROMPT``
    as the system message, to the ``gemini-2.5-flash`` model through the
    client returned by ``get_vision_client()``.

    Args:
        photo_bytes: Raw bytes of the photo.
        mime_type: MIME type advertised in the data URL. Defaults to
            ``"image/jpeg"`` (the previously hard-coded value) so existing
            callers are unaffected; pass e.g. ``"image/png"`` for PNG uploads.

    Returns:
        The text returned by the model, or ``""`` when it returns no content.
    """
    client = get_vision_client()
    b64 = base64.b64encode(photo_bytes).decode("utf-8")

    resp = client.chat.completions.create(
        model="gemini-2.5-flash",
        messages=[
            {"role": "system", "content": OCR_PROMPT},
            {"role": "user", "content": [
                {"type": "image_url", "image_url": {
                    "url": f"data:{mime_type};base64,{b64}",
                }},
            ]},
        ],
        temperature=0,  # deterministic transcription
        max_tokens=2000,
    )
    return resp.choices[0].message.content or ""
|
||||
|
||||
|
||||
def grade_answer(question: dict, student_answer: str) -> dict:
    """Grade a student's answer with the DeepSeek chat model.

    ``GRADING_PROMPT`` is filled in with the question, its reference answer
    and the student's answer, then sent as the system message to
    ``deepseek-chat`` with a JSON-object response format.

    Args:
        question: Question record. ``question_number``, ``question_type`` and
            ``question_text`` are required; ``raw_answer_text``, ``solution``
            and ``score`` are optional.
        student_answer: The student's answer text.

    Returns:
        The model's JSON reply parsed into a dict.

    Raises:
        KeyError: If a required key is missing from ``question``.
        json.JSONDecodeError: If the reply is empty or not valid JSON.
    """
    reference = question.get("raw_answer_text") or question.get("solution") or "No reference answer"
    score = question.get("score") or "unknown"

    prompt = GRADING_PROMPT.format(
        question_number=question["question_number"],
        question_type=question["question_type"],
        question_text=question["question_text"],
        score=score,
        reference_answer=reference,
        student_answer=student_answer,
    )

    ds = get_deepseek_client()
    resp = ds.chat.completions.create(
        model="deepseek-chat",
        messages=[
            {"role": "system", "content": prompt},
        ],
        temperature=0.2,
        response_format={"type": "json_object"},
    )
    # The SDK may hand back None as content; fall back to "" so the failure
    # surfaces as a JSONDecodeError instead of a TypeError from json.loads.
    return json.loads(resp.choices[0].message.content or "")
|
||||
|
||||
|
||||
def generate_variant(question: dict) -> dict:
    """Generate a variant of a question with the DeepSeek chat model.

    ``VARIANT_PROMPT`` is filled in from the original question and sent as
    the system message to ``deepseek-chat`` with a JSON-object response format.

    Args:
        question: Question record. ``question_type`` and ``question_text``
            are required. The reference answer is the first truthy value of
            ``correct_option``, ``correct_answer``, ``raw_answer_text``, else
            ``"N/A"``. ``topics`` and ``difficulty`` are optional.

    Returns:
        The model's JSON reply parsed into a dict.

    Raises:
        KeyError: If a required key is missing from ``question``.
        json.JSONDecodeError: If the reply is empty or not valid JSON.
    """
    answer = (
        question.get("correct_option")
        or question.get("correct_answer")
        or question.get("raw_answer_text")
        or "N/A"
    )
    # ``topics`` may be present but None; ``.get("topics", [])`` would then
    # hand None to str.join and raise TypeError.
    topics = question.get("topics") or []

    prompt = VARIANT_PROMPT.format(
        question_type=question["question_type"],
        question_text=question["question_text"],
        topics=", ".join(str(topic) for topic in topics),
        difficulty=question.get("difficulty", "medium"),
        answer=answer,
    )

    ds = get_deepseek_client()
    resp = ds.chat.completions.create(
        model="deepseek-chat",
        messages=[
            {"role": "system", "content": prompt},
        ],
        temperature=0.5,
        response_format={"type": "json_object"},
    )
    # Treat a missing completion as an empty document so the error is a
    # JSONDecodeError rather than a TypeError.
    return json.loads(resp.choices[0].message.content or "")
|
||||
74
backend/app/services/llm_clients.py
Normal file
74
backend/app/services/llm_clients.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import httpx
|
||||
from openai import OpenAI
|
||||
from app.config import get_settings
|
||||
|
||||
_TIMEOUT = httpx.Timeout(connect=10, read=300, write=60, pool=10)
|
||||
|
||||
_gpt_client: OpenAI | None = None
|
||||
_qwen_client: OpenAI | None = None
|
||||
_gemini_flash_client: OpenAI | None = None
|
||||
_gemini_lite_client: OpenAI | None = None
|
||||
_deepseek_client: OpenAI | None = None
|
||||
|
||||
|
||||
def get_gpt_client() -> OpenAI:
    """Return the cached client for the laozhang API (gpt-4o / gpt-4o-mini).

    The client is built from the laozhang settings on first use and reused
    on every later call.
    """
    global _gpt_client
    if _gpt_client is not None:
        return _gpt_client

    settings = get_settings()
    _gpt_client = OpenAI(
        base_url=settings.laozhang_base_url,
        api_key=settings.laozhang_api_key,
    )
    return _gpt_client
|
||||
|
||||
|
||||
def get_qwen_client() -> OpenAI:
    """Return the cached client for DashScope (qwen-plus).

    The client is built from the DashScope settings on first use and reused
    on every later call.
    """
    global _qwen_client
    if _qwen_client is not None:
        return _qwen_client

    settings = get_settings()
    _qwen_client = OpenAI(
        base_url=settings.dashscope_base_url,
        api_key=settings.dashscope_api_key,
    )
    return _qwen_client
|
||||
|
||||
|
||||
def get_vision_client() -> OpenAI:
    """Return the cached client for the official Google Gemini API.

    Vision client used for question splitting and OCR; per the original
    note, the endpoint is usable when deployed in Singapore. Built on first
    use with the module-wide ``_TIMEOUT`` and reused afterwards.
    """
    global _gemini_flash_client
    if _gemini_flash_client is not None:
        return _gemini_flash_client

    settings = get_settings()
    _gemini_flash_client = OpenAI(
        base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
        api_key=settings.google_gemini_api_key,
        timeout=_TIMEOUT,
    )
    return _gemini_flash_client
|
||||
|
||||
|
||||
def get_gemini_lite_client() -> OpenAI:
    """Return the cached laozhang client for gemini-3.1-flash-lite-preview.

    Lightweight client used for the AI trio. Built on first use with the
    module-wide ``_TIMEOUT`` and reused afterwards.
    """
    global _gemini_lite_client
    if _gemini_lite_client is not None:
        return _gemini_lite_client

    settings = get_settings()
    _gemini_lite_client = OpenAI(
        base_url=settings.laozhang_base_url,
        api_key=settings.laozhang_api_key,
        timeout=_TIMEOUT,
    )
    return _gemini_lite_client
|
||||
|
||||
|
||||
def get_deepseek_client() -> OpenAI:
    """Return the cached DeepSeek client (deepseek-chat, used for the AI trio).

    Built on first use with the module-wide ``_TIMEOUT`` and reused afterwards.
    """
    global _deepseek_client
    if _deepseek_client is not None:
        return _deepseek_client

    settings = get_settings()
    _deepseek_client = OpenAI(
        base_url=settings.deepseek_base_url,
        api_key=settings.deepseek_api_key,
        timeout=_TIMEOUT,
    )
    return _deepseek_client
|
||||
576
backend/app/services/paper_processor.py
Normal file
576
backend/app/services/paper_processor.py
Normal file
@@ -0,0 +1,576 @@
|
||||
"""试卷处理管线:PDF → 结构化题目 → AI 三件套(Vision 模式)"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import re
|
||||
import traceback
|
||||
from contextlib import redirect_stdout
|
||||
import fitz # pymupdf
|
||||
from app.services.supabase_client import get_supabase
|
||||
from app.services.llm_clients import get_vision_client, get_deepseek_client
|
||||
|
||||
|
||||
def strip_nulls(obj):
    """Return a copy of *obj* with NUL characters removed from every string.

    Dict values and list items are processed recursively (PostgreSQL rejects
    NUL characters in text). Dict keys and values of any other type are
    returned unchanged.
    """
    if isinstance(obj, dict):
        cleaned = {}
        for key, value in obj.items():
            cleaned[key] = strip_nulls(value)
        return cleaned
    if isinstance(obj, list):
        return list(map(strip_nulls, obj))
    if isinstance(obj, str):
        return obj.replace("\u0000", "")
    return obj
|
||||
|
||||
|
||||
# ============================================
|
||||
# Prompts
|
||||
# ============================================
|
||||
|
||||
STRUCTURE_PROMPT = """You are an expert exam paper structure analyst. You are given images of a past exam paper. Analyze every page carefully and extract all questions into structured JSON.
|
||||
All generated values must be in English. Do not output Chinese.
|
||||
|
||||
CRITICAL RULES for question_text:
|
||||
- Each question's question_text must be FULLY SELF-CONTAINED. Include ALL context needed to solve it.
|
||||
- For sub-questions (e.g. (a)(i)), copy the ENTIRE parent question setup (variable definitions, code blocks, problem description) into the question_text, then append the specific sub-question.
|
||||
- For Python/code questions: include ALL variable definitions and import statements verbatim, exactly as they appear in the exam, preserving multi-line arrays and data structures completely.
|
||||
- Never truncate code. If a variable is defined across multiple lines (e.g. a numpy array), include every line.
|
||||
|
||||
Output JSON format (strictly follow):
|
||||
{
|
||||
"total_score": 100,
|
||||
"difficulty_level": "medium",
|
||||
"topics_summary": {"Topic A": 40, "Topic B": 30, "Topic C": 30},
|
||||
"questions": [
|
||||
{
|
||||
"question_number": "1a",
|
||||
"parent_question": "1",
|
||||
"question_type": "mc",
|
||||
"question_text": "Original question text...",
|
||||
"score": 5,
|
||||
"page_number": 1,
|
||||
"options": [{"label": "A", "text": "Option content"}, {"label": "B", "text": "..."}],
|
||||
"topics": ["Linked List", "Pointer"],
|
||||
"difficulty": "easy"
|
||||
},
|
||||
{
|
||||
"question_number": "2",
|
||||
"parent_question": null,
|
||||
"question_type": "long_question",
|
||||
"question_text": "Original question text...",
|
||||
"score": 15,
|
||||
"page_number": 2,
|
||||
"options": null,
|
||||
"topics": ["Recursion"],
|
||||
"difficulty": "hard"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
Rules:
|
||||
- question_type must be one of: "mc" (multiple choice), "true_false" (true/false), "fill_blank" (fill in blank), "long_question" (long question)
|
||||
- True/False questions MUST use "true_false" type, with options set to [{"label":"True","text":"True"},{"label":"False","text":"False"}], correct_option as "True" or "False"
|
||||
- Multiple choice must extract the options array
|
||||
- Sub-questions use parent_question to link to parent: "1a" parent is "1"
|
||||
- Independent questions without sub-questions set parent_question to null
|
||||
- page_number inferred from where the question appears
|
||||
- topics inferred from the question content
|
||||
- difficulty: "easy" | "medium" | "hard"
|
||||
- Extract ALL questions, do not miss any
|
||||
- Keep topic labels in English only
|
||||
"""
|
||||
|
||||
ANSWER_MATCH_PROMPT = """You are an expert exam answer matching specialist. Below is the answer text for an exam paper. Extract and match answers to their corresponding question numbers.
|
||||
All generated values must be in English. Do not output Chinese.
|
||||
|
||||
Question structure:
|
||||
{questions_json}
|
||||
|
||||
Answer text:
|
||||
{answer_text}
|
||||
|
||||
Output JSON format:
|
||||
{{
|
||||
"answers": [
|
||||
{{
|
||||
"question_number": "1a",
|
||||
"correct_option": "B",
|
||||
"correct_answer": null,
|
||||
"raw_answer_text": "Original answer text..."
|
||||
}},
|
||||
{{
|
||||
"question_number": "2",
|
||||
"correct_option": null,
|
||||
"correct_answer": null,
|
||||
"raw_answer_text": "Complete solution process and answer..."
|
||||
}}
|
||||
]
|
||||
}}
|
||||
|
||||
Rules:
|
||||
- For MC questions, fill correct_option (e.g. "B")
|
||||
- For fill-blank questions, fill correct_answer (e.g. "O(n log n)")
|
||||
- For long questions, only fill raw_answer_text (complete solution process)
|
||||
- Match all questions where answers can be found
|
||||
- Keep raw_answer_text faithful to the source answer, but do not add Chinese commentary
|
||||
"""
|
||||
|
||||
ANALYSIS_PROMPT = """You are an expert academic answer analyst. Generate three sections for the following exam question. ALL output must be in English.
|
||||
|
||||
Question info:
|
||||
- Number: {question_number}
|
||||
- Type: {question_type}
|
||||
- Score: {score}
|
||||
- Question: {question_text}
|
||||
- Topics: {topics}
|
||||
{answer_section}
|
||||
|
||||
Generate THREE sections in HTML format (supports KaTeX: block $$ ... $$ inline $ ... $):
|
||||
|
||||
Output JSON:
|
||||
{{
|
||||
"knowledge_reminder": "<HTML> Prerequisite knowledge points needed for this question, as a concise bullet list </HTML>",
|
||||
"ai_hint": "<HTML> A hint that guides thinking direction WITHOUT giving away the answer </HTML>",
|
||||
"solution": "<HTML> Complete step-by-step solution (Step 1, Step 2, ...) with derivations, formulas, and common mistake warnings </HTML>"
|
||||
}}
|
||||
|
||||
Solution requirements:
|
||||
- Must include complete working process, not just the answer
|
||||
- Each step must have an explanation
|
||||
- If a reference answer is provided, derive the solution based on it
|
||||
- If no reference answer, work out the complete solution independently
|
||||
- For MC questions, explain why the correct option is right AND why others are wrong
|
||||
- Use <ol> or numbered steps
|
||||
- Mark common mistakes with <div class="common-error">...</div>
|
||||
|
||||
KaTeX formula rules:
|
||||
- Block formula: $$ on its own line, with blank lines before and after
|
||||
- Inline formula: $x^2$ no line break
|
||||
- Matrix: \\begin{{bmatrix}} ... \\end{{bmatrix}}
|
||||
- Fraction: \\frac{{a}}{{b}}
|
||||
"""
|
||||
|
||||
BATCH_ANALYSIS_PROMPT = """You are an expert academic answer analyst. Generate three study sections for each question below. ALL output must be in English.
|
||||
|
||||
For every question, return:
|
||||
- knowledge_reminder: concise prerequisite bullets in HTML
|
||||
- ai_hint: a helpful hint in HTML without revealing the final answer
|
||||
- solution: a complete step-by-step solution in HTML
|
||||
|
||||
Return JSON in this exact format:
|
||||
{{
|
||||
"analyses": [
|
||||
{{
|
||||
"question_number": "1a",
|
||||
"knowledge_reminder": "<HTML>...</HTML>",
|
||||
"ai_hint": "<HTML>...</HTML>",
|
||||
"solution": "<HTML>...</HTML>"
|
||||
}}
|
||||
]
|
||||
}}
|
||||
|
||||
Rules:
|
||||
- Return one item for every provided question_number
|
||||
- Keep each item matched to the same question_number
|
||||
- All text must be in English
|
||||
- HTML only, KaTeX compatible
|
||||
- For MC questions, explain why the correct option is right and why the others are wrong
|
||||
- For long questions, show a complete derivation or reasoning chain
|
||||
- Use <ol> or numbered steps in solution when appropriate
|
||||
- Mark common mistakes with <div class="common-error">...</div>
|
||||
- CRITICAL: When a question_text contains "[Context from parent question X]" followed by "[Sub-question Y]", the parent section is background context only. You MUST solve ONLY the specific sub-question labeled [Sub-question Y]. Do NOT solve other sub-questions listed in the parent context. Give one precise answer for that single sub-question only.
|
||||
|
||||
Questions:
|
||||
{questions_payload}
|
||||
"""
|
||||
|
||||
|
||||
# ============================================
|
||||
# 处理管线
|
||||
# ============================================
|
||||
|
||||
# Lower-case fragments that mark an error message as transient
# (rate limiting, timeouts, connection problems).
RETRYABLE_ERROR_MARKERS = (
    "429",
    "rate limit",
    "rate_limit",
    "too many requests",
    "timeout",
    "timed out",
    "connection",
)


def is_retryable_error(exc: Exception) -> bool:
    """Return True when the exception's message contains a retryable marker.

    The match is a case-insensitive substring test of ``str(exc)`` against
    ``RETRYABLE_ERROR_MARKERS``.
    """
    lowered = str(exc).lower()
    for marker in RETRYABLE_ERROR_MARKERS:
        if marker in lowered:
            return True
    return False
|
||||
|
||||
|
||||
def pdf_to_images(pdf_bytes: bytes, dpi: int = 96) -> list[str]:
    """Render every page of a PDF to a base64-encoded PNG.

    Args:
        pdf_bytes: Raw PDF content.
        dpi: Render resolution. The default of 96 balances legibility
            against cost.

    Returns:
        One base64 string (PNG data) per page, in page order.
    """
    doc = fitz.open(stream=pdf_bytes, filetype="pdf")
    try:
        # The scale matrix is relative to 72 dpi, so dpi / 72 is the zoom factor.
        zoom = dpi / 72
        mat = fitz.Matrix(zoom, zoom)
        images = []
        for page in doc:
            pix = page.get_pixmap(matrix=mat, colorspace=fitz.csRGB)
            images.append(base64.b64encode(pix.tobytes("png")).decode())
        return images
    finally:
        # Close the document even when rendering a page raises.
        doc.close()
|
||||
|
||||
|
||||
def parse_json_response(text: str) -> dict:
    """Parse JSON returned by a model, tolerating common formatting slips.

    Repairs applied before parsing:

    * a surrounding markdown code fence is removed;
    * control characters other than tab, newline and carriage return are
      dropped;
    * a backslash that starts an invalid JSON escape is doubled so it is
      read as a literal backslash. This includes ``\\u`` not followed by
      four hex digits, as in LaTeX ``\\underline``.

    Parsing uses ``strict=False`` so raw tabs and newlines inside string
    values are accepted as well.

    Args:
        text: Raw model output. ``None`` is treated as an empty string.

    Returns:
        The decoded JSON value.

    Raises:
        json.JSONDecodeError: If the repaired text is still not valid JSON
            (including empty input).
    """
    text = (text or "").strip()
    # Strip a ```json ... ``` wrapper.
    if text.startswith("```"):
        lines = text.splitlines()
        text = "\n".join(lines[1:-1] if lines[-1].strip() == "```" else lines[1:])
    # Remove control characters that are illegal in JSON (0x00-0x1F except \t \n \r).
    text = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f]', '', text)
    # Fix invalid escapes: only a backslash preceded by an even number of
    # backslashes starts an escape, and only those are touched.
    text = re.sub(
        r'(?<!\\)((?:\\\\)*)\\(u(?![0-9a-fA-F]{4})|[^"\\/bfnrtu])',
        r'\1\\\\\2',
        text,
    )
    return json.loads(text, strict=False)
|
||||
|
||||
|
||||
async def gemini_vision_json(
    *,
    system_prompt: str,
    images: list[str],
    user_text: str = "",
    temperature: float = 0,
    max_attempts: int = 6,
) -> dict:
    """Send images plus a prompt to the Gemini vision model and return its JSON.

    Errors classified as retryable by ``is_retryable_error`` are retried
    with exponential backoff (2s, doubling, capped at 30s). Any other error,
    or the error from the final attempt, is re-raised.

    Args:
        system_prompt: System instructions; a "JSON only" reminder is appended.
        images: Base64-encoded PNG images, sent in order as data URLs.
        user_text: Optional text placed after the images in the user message.
        temperature: Sampling temperature.
        max_attempts: Maximum number of API calls; must be at least 1.

    Returns:
        The model reply parsed by ``parse_json_response``.

    Raises:
        ValueError: If ``max_attempts`` is smaller than 1.
    """
    if max_attempts < 1:
        # Otherwise the loop would not run and the function would return None.
        raise ValueError("max_attempts must be at least 1")

    client = get_vision_client()
    delay_seconds = 2

    content: list = [
        {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64}"}}
        for b64 in images
    ]
    if user_text:
        content.append({"type": "text", "text": user_text})

    messages = [
        {"role": "system", "content": system_prompt + "\n\nIMPORTANT: Your entire response must be valid JSON only. No markdown, no code fences, no extra text."},
        {"role": "user", "content": content},
    ]

    for attempt in range(1, max_attempts + 1):
        try:
            # The SDK call is synchronous; run it in a worker thread so the
            # event loop is not blocked while waiting for the model.
            response = await asyncio.to_thread(
                client.chat.completions.create,
                model="gemini-2.5-flash",
                messages=messages,
                temperature=temperature,
                max_tokens=16384,
            )
            return parse_json_response(response.choices[0].message.content)
        except Exception as exc:
            if attempt == max_attempts or not is_retryable_error(exc):
                raise
            await asyncio.sleep(delay_seconds)
            delay_seconds = min(delay_seconds * 2, 30)
|
||||
|
||||
|
||||
async def deepseek_json_completion(
    *,
    system_prompt: str,
    user_prompt: str | None = None,
    temperature: float = 0,
    max_attempts: int = 6,
) -> dict:
    """Run a text-only JSON completion on DeepSeek (used for the AI trio).

    Errors classified as retryable by ``is_retryable_error`` are retried
    with exponential backoff (2s, doubling, capped at 30s). Any other error,
    or the error from the final attempt, is re-raised.

    Args:
        system_prompt: System message.
        user_prompt: Optional user message; omitted when empty or ``None``.
        temperature: Sampling temperature.
        max_attempts: Maximum number of API calls; must be at least 1.

    Returns:
        The model reply parsed by ``parse_json_response``.

    Raises:
        ValueError: If ``max_attempts`` is smaller than 1.
    """
    if max_attempts < 1:
        # Otherwise the loop would not run and the function would return None.
        raise ValueError("max_attempts must be at least 1")

    client = get_deepseek_client()
    delay_seconds = 2

    # The messages do not change between attempts, so build them once.
    messages = [{"role": "system", "content": system_prompt}]
    if user_prompt:
        messages.append({"role": "user", "content": user_prompt})

    for attempt in range(1, max_attempts + 1):
        try:
            # The SDK call is synchronous; run it in a worker thread so the
            # event loop is not blocked while waiting for the model.
            response = await asyncio.to_thread(
                client.chat.completions.create,
                model="deepseek-chat",
                messages=messages,
                temperature=temperature,
                max_tokens=8192,
                response_format={"type": "json_object"},
            )
            # Reuse the shared sanitiser (control characters, invalid
            # escapes) instead of repeating its regexes here.
            return parse_json_response(response.choices[0].message.content)
        except Exception as exc:
            if attempt == max_attempts or not is_retryable_error(exc):
                raise
            await asyncio.sleep(delay_seconds)
            delay_seconds = min(delay_seconds * 2, 30)
|
||||
|
||||
|
||||
def chunked(items: list[dict], size: int) -> list[list[dict]]:
    """Split *items* into consecutive sub-lists of at most *size* elements."""
    batches = []
    for start in range(0, len(items), size):
        batches.append(items[start:start + size])
    return batches
|
||||
|
||||
|
||||
# Roman numerals recognised as sub-sub-question labels.
_ROMAN_VALUES = {
    'i': 1, 'ii': 2, 'iii': 3, 'iv': 4, 'v': 5, 'vi': 6, 'vii': 7,
    'viii': 8, 'ix': 9, 'x': 10, 'xi': 11, 'xii': 12, 'xiii': 13,
}


def _roman_component(label: str) -> tuple:
    """Key component for a sub-sub-question label (roman numeral if known)."""
    if label in _ROMAN_VALUES:
        return (2, _ROMAN_VALUES[label], label)
    # Unknown labels sort after all roman numerals, alphabetically.
    return (3, 0, label)


def _question_sort_key(qnum: str) -> tuple:
    """Natural sort key: 1a < 1b < ... < 1i < 1j < 2ai < 2aii < 10a.

    The number is split into digit runs and letter runs; brackets and other
    punctuation are ignored. Every component is an ``(int, int, str)``
    tuple, so keys always compare without type errors.

    * Digit run: compared numerically.
    * Single letter straight after a number: sub-question, alphabetical
      (a=1 ... j=10).
    * Multi-letter run that is itself a roman numeral ("ii", "iv"): roman.
    * Other multi-letter run ("ai", "biv"): first letter is the
      sub-question, the rest is the roman sub-sub-question. This keeps
      "1bi" after "1a" and "2av" before "2aix".
    * Letter run following another letter run ("1(a)(v)"): roman numeral.

    ``None`` yields an empty key; other non-string values are converted
    with ``str``.
    """
    text = "" if qnum is None else str(qnum)
    key = []
    after_letter = False
    for token in re.findall(r'\d+|[a-zA-Z]+', text):
        if token.isdigit():
            key.append((0, int(token), ''))
            after_letter = False
            continue
        low = token.lower()
        if after_letter:
            key.append(_roman_component(low))
        elif len(low) == 1:
            key.append((1, ord(low) - ord('a') + 1, token))
        elif low in _ROMAN_VALUES:
            key.append((2, _ROMAN_VALUES[low], token))
        else:
            head, tail = low[0], low[1:]
            key.append((1, ord(head) - ord('a') + 1, head))
            key.append(_roman_component(tail))
        after_letter = True
    return tuple(key)


def sort_questions(questions: list[dict]) -> list[dict]:
    """Return the questions sorted naturally by their ``question_number``."""
    return sorted(questions, key=lambda q: _question_sort_key(q.get("question_number", "")))
|
||||
|
||||
|
||||
# A line starts code when it begins with an import, an assignment or a print call.
_CODE_START_RE = re.compile(r"^\s*(import |from \w|[A-Za-z_]\w*\s*=|print\()")


def extract_code_block(text: str) -> str:
    """Extract the Python code lines embedded in a question's text.

    A line is collected when it looks like the start of a statement (import,
    assignment or ``print(`` call), or when it continues an earlier
    collected line whose brackets are still open (multi-line literals and
    calls). Collected lines are stripped of surrounding whitespace and
    joined with newlines.

    Brackets are counted textually, so brackets inside string literals are
    counted too. The balance never goes below zero: a stray closing bracket
    cannot stop a later multi-line literal from being collected in full.
    """
    collected: list[str] = []
    open_brackets = 0

    for line in text.splitlines():
        stripped = line.strip()
        if open_brackets > 0 or _CODE_START_RE.match(line):
            collected.append(stripped)
            opened = sum(stripped.count(ch) for ch in "([{")
            closed = sum(stripped.count(ch) for ch in ")]}")
            open_brackets = max(open_brackets + opened - closed, 0)
        # Any other line is prose and is skipped.

    return "\n".join(collected)


# Kept for backward compatibility
extract_code_lines = extract_code_block
|
||||
|
||||
|
||||
def try_exec_python(code: str, shared_ns: dict) -> str | None:
|
||||
"""
|
||||
在 shared_ns 命名空间中执行 code,捕获 stdout。
|
||||
返回输出字符串,失败返回 None。
|
||||
"""
|
||||
buf = io.StringIO()
|
||||
try:
|
||||
with redirect_stdout(buf):
|
||||
exec(code, shared_ns) # noqa: S102
|
||||
output = buf.getvalue().strip()
|
||||
return output if output else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def _resume_ai_trio(sb, paper_id: str, questions: list[dict]):
    """Generate the AI trio for questions that have no solution yet.

    The trio (``knowledge_reminder``, ``ai_hint``, ``solution``) is
    generated in small batches and written back per question, so an
    interrupted run can be resumed: questions that already have a solution
    are skipped.

    A failed batch is logged and skipped so it does not affect the others;
    its questions stay without a solution and are picked up on the next
    run. The paper is marked ``ready`` at the end either way.

    Args:
        sb: Supabase client.
        paper_id: ID of the paper being processed.
        questions: Rows of the paper's questions with at least ``id``,
            ``question_number`` and ``solution``.
    """
    import logging  # local import: the block stays independent of file-level imports

    logger = logging.getLogger(__name__)
    batch_size = 3  # questions per model call

    need = [q for q in questions if not q.get("solution")]
    if not need:
        # Every question already has a solution: just mark the paper done.
        sb.table("papers").update({"status": "ready", "processing_step": None}).eq("id", paper_id).execute()
        return

    total_q = len(questions)
    done_q = total_q - len(need)

    # Maps the question number echoed back by the model to the row to update.
    id_map = {q["question_number"]: q["id"] for q in need}
    # The full question text is needed to generate the trio.
    full_data = sb.table("paper_questions").select(
        "id, question_number, question_type, question_text, score, correct_option, correct_answer, raw_answer_text"
    ).eq("paper_id", paper_id).in_("id", [q["id"] for q in need]).execute().data

    payloads = []
    for q in full_data:
        # Prefer the raw answer text, then the option, then the short answer.
        answer_section = q.get("raw_answer_text") or ""
        if not answer_section and q.get("correct_option"):
            answer_section = f"Correct option: {q['correct_option']}"
        elif not answer_section and q.get("correct_answer"):
            answer_section = f"Correct answer: {q['correct_answer']}"
        payloads.append({
            "question_number": q["question_number"],
            "question_type": q["question_type"] or "long_question",
            "score": q.get("score") or "unknown",
            "question_text": q["question_text"] or "",
            "reference_answer": answer_section,
        })

    batches = chunked(payloads, batch_size)
    for batch_idx, batch in enumerate(batches, 1):
        current = done_q + batch_idx * batch_size
        _update_progress(sb, paper_id, f"Generating solutions ({min(current, total_q)}/{total_q} questions)", batch_idx, len(batches))
        try:
            result = await deepseek_json_completion(
                system_prompt=BATCH_ANALYSIS_PROMPT.format(
                    questions_payload=json.dumps(batch, ensure_ascii=False),
                ),
                temperature=0.3,
            )
            for item in result.get("analyses", []):
                qnum = item.get("question_number")
                qid = id_map.get(qnum)
                if qid:
                    sb.table("paper_questions").update({
                        "knowledge_reminder": item.get("knowledge_reminder", ""),
                        "ai_hint": item.get("ai_hint", ""),
                        "solution": item.get("solution", ""),
                    }).eq("id", qid).execute()
        except Exception:
            # Best effort: one failed batch must not stop the others, but
            # the failure is recorded instead of being dropped silently.
            logger.exception(
                "AI trio batch %d/%d failed for paper %s",
                batch_idx, len(batches), paper_id,
            )
        await asyncio.sleep(1)

    # Mark the paper as finished.
    sb.table("papers").update({"status": "ready", "processing_step": None}).eq("id", paper_id).execute()
|
||||
|
||||
|
||||
def _update_progress(sb, paper_id: str, step: str, progress: int = 0, total: int = 0):
    """Store the current processing step and progress counters on the paper row."""
    fields = {
        "processing_step": step,
        "processing_progress": progress,
        "processing_total": total,
    }
    query = sb.table("papers").update(fields)
    query.eq("id", paper_id).execute()
|
||||
|
||||
|
||||
async def process_paper(paper_id: str, paper_bytes: bytes, answer_bytes: bytes | None):
    """Background pipeline: PDF pages -> vision-based structuring -> AI trio.

    Design principle: every step is persisted to the DB as soon as it
    finishes, so an interrupted run can be resumed. When the paper already
    has questions stored, extraction is skipped and only the missing AI
    trio is generated.

    Args:
        paper_id: ID of the ``papers`` row being processed.
        paper_bytes: Raw PDF of the exam paper.
        answer_bytes: Raw PDF of the answer sheet, or ``None``.

    Raises:
        Exception: Any failure outside the best-effort answer matching is
            recorded on the paper row (``status="error"``) and re-raised.
    """
    import logging  # local import: the block stays independent of file-level imports

    logger = logging.getLogger(__name__)
    sb = get_supabase()

    try:
        # Resume case: questions were already extracted by an earlier run.
        existing = sb.table("paper_questions").select("id, question_number, solution").eq("paper_id", paper_id).execute().data

        if existing:
            # Skip extraction and only fill in the missing AI trio.
            await _resume_ai_trio(sb, paper_id, existing)
            return

        # -- Step 1: PDF -> images --
        _update_progress(sb, paper_id, "Rendering PDF pages...")
        paper_images = pdf_to_images(paper_bytes)

        # -- Step 2: vision-based question extraction --
        PAGE_BATCH = 8
        all_questions: list = []
        meta: dict = {}
        num_page_batches = -(-len(paper_images) // PAGE_BATCH)  # ceiling division
        for i in range(0, len(paper_images), PAGE_BATCH):
            batch_imgs = paper_images[i:i + PAGE_BATCH]
            batch_idx = i // PAGE_BATCH + 1
            _update_progress(sb, paper_id, f"Reading pages {i+1}-{i+len(batch_imgs)}...", batch_idx, num_page_batches)
            batch_result = await gemini_vision_json(
                system_prompt=STRUCTURE_PROMPT,
                images=batch_imgs,
                user_text=f"Pages {i+1}-{i+len(batch_imgs)} of the exam paper. Extract all questions visible on these pages.",
                temperature=0,
            )
            if not meta:
                # Paper-level metadata is taken from the first batch only.
                meta = {k: batch_result.get(k) for k in ("total_score", "difficulty_level", "topics_summary")}
            all_questions.extend(batch_result.get("questions", []))

        questions = sort_questions(all_questions)

        # Update the paper overview.
        sb.table("papers").update({
            "total_score": meta.get("total_score"),
            "question_count": len(questions),
            "topics_summary": meta.get("topics_summary"),
            "difficulty_level": meta.get("difficulty_level"),
        }).eq("id", paper_id).execute()

        # -- Step 3: answer matching (batched, best effort) --
        answers_map = {}
        if answer_bytes:
            _update_progress(sb, paper_id, "Matching answers...")
            try:
                answer_images = pdf_to_images(answer_bytes)
                questions_json = json.dumps(
                    [{"question_number": q["question_number"], "question_type": q["question_type"]}
                     for q in questions], ensure_ascii=False,
                )
                all_answers: list = []
                for ai in range(0, len(answer_images), PAGE_BATCH):
                    batch_ans_imgs = answer_images[ai:ai + PAGE_BATCH]
                    try:
                        match_result = await gemini_vision_json(
                            system_prompt=ANSWER_MATCH_PROMPT.format(
                                questions_json=questions_json, answer_text="(See images)",
                            ),
                            images=batch_ans_imgs,
                            user_text=f"Match answers to these questions: {questions_json}",
                            temperature=0,
                        )
                        all_answers.extend(match_result.get("answers", []))
                    except Exception:
                        # A failed page batch is skipped, not fatal.
                        logger.warning(
                            "Answer matching failed for answer pages starting at %d of paper %s",
                            ai + 1, paper_id, exc_info=True,
                        )
                # Skip malformed items so one bad entry cannot discard every
                # matched answer.
                answers_map = {
                    a["question_number"]: a
                    for a in all_answers
                    if isinstance(a, dict) and a.get("question_number")
                }
            except Exception:
                # Answers are optional: continue without them.
                logger.warning("Answer matching skipped for paper %s", paper_id, exc_info=True)

        # -- Step 4: store the questions right away (without the AI trio) --
        _update_progress(sb, paper_id, "Saving questions...")
        for i, q in enumerate(questions):
            qnum = q["question_number"]
            answer = answers_map.get(qnum, {})
            # The model may return an empty list or null for topics;
            # indexing [0] on that would abort the pipeline mid-insert.
            topics = q.get("topics") or []
            sb.table("paper_questions").insert(strip_nulls({
                "paper_id": paper_id,
                "question_number": qnum,
                "parent_question": q.get("parent_question"),
                "display_order": i,
                "question_type": q["question_type"],
                "question_text": q["question_text"],
                "score": q.get("score"),
                "page_number": q.get("page_number"),
                "options": q.get("options"),
                "correct_option": answer.get("correct_option"),
                "correct_answer": answer.get("correct_answer"),
                "raw_answer_text": answer.get("raw_answer_text"),
                "topics": topics,
                "analytics_topic": topics[0] if topics else None,
                "topic_tags": topics,
                "difficulty": q.get("difficulty"),
            })).execute()

        # -- Step 5: AI trio (written per question, resumable) --
        saved = sb.table("paper_questions").select("id, question_number, solution").eq("paper_id", paper_id).execute().data
        await _resume_ai_trio(sb, paper_id, saved)

    except Exception as e:
        # Record the failure on the paper row, then let it propagate.
        sb.table("papers").update({
            "status": "error",
            "error_message": f"{type(e).__name__}: {str(e)}\n{traceback.format_exc()[-500:]}",
        }).eq("id", paper_id).execute()
        raise
|
||||
13
backend/app/services/supabase_client.py
Normal file
13
backend/app/services/supabase_client.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from supabase import create_client, Client
|
||||
from app.config import get_settings
|
||||
|
||||
_client: Client | None = None
|
||||
|
||||
|
||||
def get_supabase() -> Client:
    """Return the cached Supabase client (service_role key, bypasses RLS).

    The client is created from the settings on first use and reused on
    every later call.
    """
    global _client
    if _client is not None:
        return _client

    settings = get_settings()
    _client = create_client(settings.supabase_url, settings.supabase_service_role_key)
    return _client
|
||||
48
backend/app/services/text_extractor.py
Normal file
48
backend/app/services/text_extractor.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""PDF 文本提取 — 复用 SOS 的 text_extractor 逻辑"""
|
||||
|
||||
import base64
|
||||
import fitz # PyMuPDF
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
class ExtractedContent:
    """Text and page images extracted from one PDF."""

    # Text of each page, in page order
    pages_text: list[str]
    # Page number -> base64 image (image-heavy pages)
    page_images: dict[int, str]
    # Number of pages
    total_pages: int
    # Whether any page image was stored
    has_images: bool
|
||||
|
||||
|
||||
def extract_pdf(file_bytes: bytes) -> ExtractedContent:
    """Extract the text of every page of a PDF, plus renders of sparse pages.

    Pages whose stripped text is shorter than 50 characters (possibly
    scanned pages) are also rendered to PNG at 200 dpi, so a vision model
    can OCR them. Renders are stored base64-encoded under the page's
    0-based index.

    Args:
        file_bytes: Raw PDF content.

    Returns:
        An ``ExtractedContent`` with the per-page text, the page renders,
        the page count and whether any render was produced.
    """
    doc = fitz.open(stream=file_bytes, filetype="pdf")
    pages_text: list[str] = []
    page_images: dict[int, str] = {}

    try:
        for i, page in enumerate(doc):
            text = page.get_text("text")
            pages_text.append(text)

            # Very little text: keep a render of the page for vision OCR.
            if len(text.strip()) < 50:
                pix = page.get_pixmap(dpi=200)
                img_bytes = pix.tobytes("png")
                page_images[i] = base64.b64encode(img_bytes).decode("utf-8")
    finally:
        # Close the document even when a page fails to extract or render.
        doc.close()

    return ExtractedContent(
        pages_text=pages_text,
        page_images=page_images,
        total_pages=len(pages_text),
        has_images=len(page_images) > 0,
    )
|
||||
|
||||
|
||||
def get_full_text(extracted: ExtractedContent) -> str:
    """Join the text of all non-blank pages, each under a page header.

    Each section starts with a ``--- Page N ---`` line (1-based page
    number) and sections are separated by a blank line. Pages whose text is
    empty or whitespace-only are left out.
    """
    sections = []
    for page_no, text in enumerate(extracted.pages_text, start=1):
        if not text.strip():
            continue
        sections.append(f"--- Page {page_no} ---\n{text}")
    return "\n\n".join(sections)
|
||||
Reference in New Issue
Block a user