253 lines
9.5 KiB
Python
253 lines
9.5 KiB
Python
"""
|
||
重新生成所有题目的 AI trio,子题带父题上下文。
|
||
用法: python backfill_ai_trio_with_context.py [--paper-id <id>] [--course <code>]
|
||
"""
|
||
|
||
import asyncio
|
||
import io
|
||
import json
|
||
import re
|
||
import sys
|
||
import time
|
||
import argparse
|
||
from contextlib import redirect_stdout
|
||
from app.services.supabase_client import get_supabase
|
||
from app.services.llm_clients import get_deepseek_client
|
||
|
||
|
||
def extract_code_lines(text: str) -> str:
    """Pull the lines that look like Python statements out of free-form text.

    A line is kept when it starts (after optional whitespace) with ``import``,
    ``from <name>``, an assignment to an identifier, or ``print(``.  While a
    kept statement still has unclosed ``(``/``[``/``{`` brackets, the following
    lines are kept as well so multi-line literals and calls stay intact.

    Args:
        text: Question text that may mix prose and code; ``None`` or ``""`` is
            treated as empty input.

    Returns:
        The kept lines, each stripped of surrounding whitespace, joined with
        newlines (``""`` when nothing matched).
    """
    code_start = re.compile(r"^\s*(import |from \w|[A-Za-z_]\w*\s*=|print\()")
    collected: list[str] = []
    depth = 0  # number of currently unclosed brackets

    for line in (text or "").splitlines():
        stripped = line.strip()
        if depth > 0 or code_start.match(line):
            collected.append(stripped)
            opened = sum(stripped.count(ch) for ch in "([{")
            closed = sum(stripped.count(ch) for ch in ")]}")
            # Brackets are counted textually, so a ")" inside a string literal
            # can push the balance below zero.  Clamp it: a negative balance
            # would otherwise cancel the opening bracket of a later statement
            # and drop that statement's continuation lines.
            depth = max(0, depth + opened - closed)

    return "\n".join(collected)
|
||
|
||
|
||
def try_exec_python(code: str, shared_ns: dict) -> str | None:
|
||
buf = io.StringIO()
|
||
try:
|
||
with redirect_stdout(buf):
|
||
exec(code, shared_ns) # noqa: S102
|
||
output = buf.getvalue().strip()
|
||
return output if output else None
|
||
except Exception:
|
||
return None
|
||
|
||
# Prompt template for one batch of questions.  It is rendered with
# ``str.format(questions_payload=...)``, which is why the literal braces of the
# JSON example are doubled ("{{" / "}}").  The template must contain only the
# instructions themselves: stray separator lines inside the string would be
# sent to the model verbatim.
BATCH_ANALYSIS_PROMPT = """You are an expert academic answer analyst. Generate three study sections for each question below. ALL output must be in English.

For every question, return:
- knowledge_reminder: concise prerequisite bullets in HTML
- ai_hint: a helpful hint in HTML without revealing the final answer
- solution: a complete step-by-step solution in HTML

Return JSON in this exact format:
{{
"analyses": [
{{
"question_number": "1a",
"knowledge_reminder": "<HTML>...</HTML>",
"ai_hint": "<HTML>...</HTML>",
"solution": "<HTML>...</HTML>"
}}
]
}}

Rules:
- Return one item for every provided question_number
- All text must be in English
- HTML only, KaTeX compatible (block $$ ... $$ inline $ ... $)
- For MC questions, explain why the correct option is right and why others are wrong
- For long questions, show a complete derivation or reasoning chain
- Use <ol> or numbered steps in solution when appropriate
- Mark common mistakes with <div class="common-error">...</div>
- CRITICAL: When a question_text contains "[Context from parent question X]" followed by "[Sub-question Y]", the parent section is background context only. You MUST solve ONLY the specific sub-question labeled [Sub-question Y]. Do NOT solve other sub-questions listed in the parent context. Give one precise answer for that single sub-question only.

Questions:
{questions_payload}
"""
|
||
|
||
|
||
def chunked(lst, size):
    """Split ``lst`` into consecutive slices holding at most ``size`` items each.

    The last slice may be shorter; an empty input yields an empty list.
    """
    pieces = []
    for start in range(0, len(lst), size):
        stop = start + size
        pieces.append(lst[start:stop])
    return pieces
|
||
|
||
|
||
# Control characters that may not appear raw in JSON (tab, LF and CR are kept).
_CONTROL_CHARS = re.compile(r'[\x00-\x08\x0b\x0c\x0e-\x1f]')

# A backslash (preceded by an even number of backslashes) that does not start a
# valid JSON escape.  "\u" only counts as valid when four hex digits follow, so
# LaTeX such as "\underline" is repaired instead of breaking json.loads.
_INVALID_JSON_ESCAPE = re.compile(
    r'(?<!\\)((?:\\\\)*)\\(?!["\\/bfnrt]|u[0-9a-fA-F]{4})'
)

_MAX_ATTEMPTS = 5


def _parse_analyses(content) -> list[dict]:
    """Clean up the raw model reply and return its list of analysis dicts.

    Raises an exception (``ValueError`` / ``json.JSONDecodeError``) when the
    reply is not a JSON object, so the caller can retry.
    """
    raw = _CONTROL_CHARS.sub('', content or '')
    # Double the lone backslash of every invalid escape (typically LaTeX
    # commands like "\alpha") so it survives JSON decoding as a literal one.
    # NOTE: commands beginning with b/f/n/r/t (e.g. "\frac") are still valid
    # JSON escapes and cannot be told apart here.
    raw = _INVALID_JSON_ESCAPE.sub(r'\1\\\\', raw)
    data = json.loads(raw)
    if not isinstance(data, dict):
        raise ValueError(f"expected a JSON object, got {type(data).__name__}")
    analyses = data.get("analyses")
    if not isinstance(analyses, list):
        # Missing or null "analyses": report "nothing returned" rather than
        # handing None to the caller, which iterates the result.
        return []
    return [item for item in analyses if isinstance(item, dict)]


async def deepseek_batch(batch: list[dict]) -> list[dict]:
    """Ask DeepSeek for the study sections of one batch of questions.

    Args:
        batch: JSON-serialisable question payloads that are embedded into
            ``BATCH_ANALYSIS_PROMPT``.

    Returns:
        The dict items of the reply's ``"analyses"`` list.  An empty list is
        returned when the reply has no such list or when all attempts failed;
        failures are printed, never raised.
    """
    client = get_deepseek_client()
    for attempt in range(_MAX_ATTEMPTS):
        try:
            # NOTE(review): this call is not awaited, i.e. the client is used
            # synchronously and blocks the event loop while the request runs.
            resp = client.chat.completions.create(
                model="deepseek-chat",
                messages=[{
                    "role": "system",
                    "content": BATCH_ANALYSIS_PROMPT.format(
                        questions_payload=json.dumps(batch, ensure_ascii=False)
                    ),
                }],
                temperature=0.3,
                max_tokens=8192,
                response_format={"type": "json_object"},
            )
            return _parse_analyses(resp.choices[0].message.content)
        except Exception as e:
            print(f" attempt {attempt+1} failed: {e}")
            if attempt < _MAX_ATTEMPTS - 1:
                # Exponential backoff: 2s, 4s, 8s, 16s.
                await asyncio.sleep(2 ** attempt * 2)
    return []
|
||
|
||
|
||
_QUESTION_COLUMNS = (
    "id, paper_id, question_number, question_type, question_text, "
    "parent_question, score, correct_option, correct_answer, raw_answer_text, "
    "analytics_topic, topic_tags, solution"
)
# The three generated columns written back to paper_questions.
_TRIO_FIELDS = ("knowledge_reminder", "ai_hint", "solution")
_PAGE_SIZE = 1000
_BATCH_SIZE = 3


def _fetch_questions(sb, paper_id: str | None) -> list[dict]:
    """Fetch every paper_questions row (optionally for one paper), page by page."""
    rows: list[dict] = []
    while True:
        query = sb.table("paper_questions").select(_QUESTION_COLUMNS)
        if paper_id:
            query = query.eq("paper_id", paper_id)
        # A single select is capped by the server's row limit, so read in
        # ranges until an empty page comes back.  The extra "id" ordering only
        # breaks ties, keeping page boundaries stable between requests.
        page = (
            query.order("paper_id")
            .order("display_order")
            .order("id")
            .range(len(rows), len(rows) + _PAGE_SIZE - 1)
            .execute()
            .data
        ) or []
        if not page:
            return rows
        rows.extend(page)


def _reference_answer(q: dict) -> str:
    """Return the stored answer text for a question, or "" when there is none."""
    if q.get("raw_answer_text"):
        return q["raw_answer_text"]
    if q.get("correct_option"):
        return f"Correct option: {q['correct_option']}"
    if q.get("correct_answer"):
        return f"Correct answer: {q['correct_answer']}"
    return ""


def _executed_output(q: dict, parent_text_map: dict, exec_namespaces: dict) -> str | None:
    """Run the code found in a question's text and return what it printed, if anything."""
    parent_q = q.get("parent_question")
    # Sub-questions of the same parent share one namespace.
    group_key = parent_q or q["question_number"]
    if group_key not in exec_namespaces:
        ns: dict = {}
        try:
            import numpy as np
            ns["np"] = np
        except ImportError:
            pass
        # Run the parent question's setup code first.
        if parent_q and parent_q in parent_text_map:
            try_exec_python(extract_code_lines(parent_text_map[parent_q]), ns)
        exec_namespaces[group_key] = ns

    sub_code = extract_code_lines(q["question_text"] or "")
    if not sub_code:
        return None
    return try_exec_python(sub_code, exec_namespaces[group_key])


def _build_payload(q: dict, parent_text_map: dict, exec_namespaces: dict) -> dict:
    """Build the payload for one question: text (with parent context) plus reference answer."""
    parent_q = q.get("parent_question")
    if parent_q and parent_q in parent_text_map:
        full_text = (
            f"[Context from parent question {parent_q}]\n"
            f"{parent_text_map[parent_q]}\n\n"
            f"[Sub-question {q['question_number']}]\n"
            f"{q['question_text'] or ''}"
        )
    else:
        full_text = q["question_text"] or ""

    answer_section = _reference_answer(q)
    if not answer_section:
        # No stored answer: try executing the question's code for real output.
        exec_out = _executed_output(q, parent_text_map, exec_namespaces)
        if exec_out is not None:
            answer_section = f"Executed output: {exec_out}"
            print(f" [exec] {q['question_number']}: {exec_out[:60]}")

    return {
        "_id": q["id"],
        "question_number": q["question_number"],
        "question_type": q["question_type"] or "long_question",
        "score": q.get("score") or "unknown",
        "question_text": full_text,
        "reference_answer": answer_section,
    }


async def main():
    """Regenerate knowledge_reminder / ai_hint / solution for the selected questions.

    Command-line flags: ``--paper-id`` restricts the run to one paper,
    ``--course`` to papers with that course code, and ``--missing-only`` to
    questions whose ``solution`` is empty.  Questions are sent to
    ``deepseek_batch`` in groups of three and each returned analysis is written
    back to its ``paper_questions`` row.
    """
    from collections import defaultdict

    parser = argparse.ArgumentParser()
    parser.add_argument("--paper-id", help="Only process this paper")
    parser.add_argument("--course", help="Only process papers with this course code")
    parser.add_argument("--missing-only", action="store_true", help="Only process questions missing solution")
    args = parser.parse_args()

    sb = get_supabase()

    # Fetch all questions (with paper info for filtering)
    all_questions = _fetch_questions(sb, args.paper_id)

    if args.course:
        # Filter by course via papers table
        papers_res = sb.table("papers").select("id").eq("course_code", args.course.upper()).execute()
        paper_ids = {p["id"] for p in papers_res.data}
        all_questions = [q for q in all_questions if q["paper_id"] in paper_ids]

    # Group by paper_id.  This grouping is deliberately done BEFORE the
    # --missing-only filter: a parent question that already has a solution must
    # still be available as context for its unsolved sub-questions.
    by_paper: dict[str, list] = defaultdict(list)
    for q in all_questions:
        by_paper[q["paper_id"]].append(q)

    targets_by_paper: dict[str, list] = {}
    for paper_id, paper_questions in by_paper.items():
        targets = [
            q for q in paper_questions
            if not (args.missing_only and q.get("solution"))
        ]
        if targets:
            targets_by_paper[paper_id] = targets

    total_targets = sum(len(t) for t in targets_by_paper.values())
    if args.missing_only:
        print(f"Questions missing solution: {total_targets}")
    else:
        print(f"Total questions to process: {total_targets}")

    total_updated = 0

    for paper_id, questions in targets_by_paper.items():
        print(f"\nPaper {paper_id} — {len(questions)} questions")

        # Any question of the paper may be the parent of another one.
        parent_text_map: dict[str, str] = {
            q["question_number"]: q["question_text"] or ""
            for q in by_paper[paper_id]
        }

        # Build payloads with context + Python exec
        exec_namespaces: dict[str, dict] = {}
        payloads = [
            _build_payload(q, parent_text_map, exec_namespaces)
            for q in questions
        ]

        for batch in chunked(payloads, _BATCH_SIZE):
            # Resolve answers only against the rows of THIS batch, so a wrong
            # or duplicated question_number in the reply cannot overwrite some
            # other question of the paper.
            ids_by_number = {str(p["question_number"]).strip(): p["_id"] for p in batch}
            # Strip internal _id before sending to model
            model_batch = [{k: v for k, v in p.items() if k != "_id"} for p in batch]
            nums = [p["question_number"] for p in batch]
            print(f" Batch {nums} ...", end=" ", flush=True)

            analyses = await deepseek_batch(model_batch) or []

            updated = 0
            for item in analyses:
                if not isinstance(item, dict):
                    continue
                qid = ids_by_number.get(str(item.get("question_number")).strip())
                if not qid:
                    continue
                # Write only the sections the model actually produced; a
                # missing or empty field must not erase existing content.
                fields = {name: item[name] for name in _TRIO_FIELDS if item.get(name)}
                if not fields:
                    continue
                sb.table("paper_questions").update(fields).eq("id", qid).execute()
                updated += 1

            total_updated += updated
            print(f"done ({updated} updated)")
            await asyncio.sleep(1)

    print(f"\nDone. Total updated: {total_updated}")
|
||
|
||
|
||
# Script entry point: drive the async backfill to completion on a new event loop.
if __name__ == "__main__":
    asyncio.run(main())
|