Files
PastpaperMaster/backend/app/routers/questions.py
Zhao 7a09167261 Initial commit: PastPaper Master full stack
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-21 12:27:47 +07:00

326 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""题目相关:变式题生成 + 相似题召回"""
from __future__ import annotations

import asyncio
import logging
import time

from fastapi import APIRouter, HTTPException, Depends
from pydantic import BaseModel

from app.services.supabase_client import get_supabase
from app.services.grader import generate_variant
from app.dependencies.auth import get_current_user_id
# Simple in-memory cache: question_id → (timestamp, result)
_similar_cache: dict[str, tuple[float, list]] = {}
_CACHE_TTL = 300 # 5 minutes
class VariantUpdate(BaseModel):
    """Request body accepted by the variant PATCH endpoint."""

    # New favorited flag; None means the field was not supplied and is left unchanged.
    favorited: bool | None = None
# Router on which the endpoint decorators in this module register their routes.
router = APIRouter()
def normalized_labels(values: list[str] | None) -> dict[str, str]:
    """Index labels by their lowercased form.

    Falsy entries (``None``, ``""``) are skipped, and ``None`` input is treated
    as an empty list. When several labels share the same lowercase form, the
    last one in ``values`` is the spelling that is kept.

    Returns:
        Mapping of lowercase label -> original label.
    """
    return {label.lower(): label for label in (values or []) if label}
def question_family(question: dict) -> str:
    """Return the question's format, falling back to its type, then ``"unknown"``.

    ``question_format`` is preferred; ``question_type`` is used only when the
    format is missing or empty.
    """
    for field in ("question_format", "question_type"):
        family = question.get(field)
        if family:
            return family
    return "unknown"
def display_topics(question: dict) -> list[str]:
    """Return the de-duplicated topic labels to show for a question.

    The preferred labels are ``analytics_topic`` followed by ``topic_tags``.
    Only when those produce nothing does the function fall back to the
    ``topics`` field. Empty labels are dropped and first-seen order is kept.
    """
    primary = question.get("analytics_topic")
    preferred = ([primary] if primary else []) + list(question.get("topic_tags") or [])
    fallback = question.get("topics") or []

    # Try each source in priority order; the first one that yields labels wins.
    for source in (preferred, fallback):
        unique: list[str] = []
        for label in source:
            if label and label not in unique:
                unique.append(label)
        if unique:
            return unique
    return []
def similarity_score(
    target: dict,
    candidate: dict,
    text_score: float = 0.0,
) -> tuple[int, list[str]]:
    """Score how closely ``candidate`` resembles ``target`` and explain why.

    Points are awarded per signal and summed:

    * same non-empty ``analytics_topic``: 40
    * shared ``topic_tags`` (case-insensitive): 10 each, at most 20
    * shared ``skill_tags`` (case-insensitive): 10 each, at most 20
    * same *known* question format: 10
    * same non-empty ``difficulty``: 5
    * ``text_score``: ``round(text_score * 60)``, at most 20

    Args:
        target: The reference question, as a dict of question fields.
        candidate: The question being compared against ``target``.
        text_score: Full-text similarity for the candidate (the caller takes
            it from a PostgreSQL text-ranking RPC). Values <= 0 add nothing.

    Returns:
        ``(score, reasons)``: ``score`` is capped at 99 and ``reasons`` is a
        list of short human-readable explanations for the match.
    """
    score = 0
    reasons: list[str] = []

    # Primary topic bucket: 40 pts
    target_topic = target.get("analytics_topic")
    same_topic = bool(target_topic) and target_topic == candidate.get("analytics_topic")
    if same_topic:
        score += 40
        reasons.append(f"Same topic: {target_topic}")

    # Concept overlap: up to 20 pts
    target_topics = normalized_labels(target.get("topic_tags"))
    candidate_topics = normalized_labels(candidate.get("topic_tags"))
    shared_topics = sorted(set(target_topics) & set(candidate_topics))
    if shared_topics:
        score += min(len(shared_topics) * 10, 20)
        # The concept reason is redundant once the primary topic already matched.
        if not same_topic:
            reasons.append(
                "Shared concept: "
                + ", ".join(target_topics[key] for key in shared_topics[:2])
            )

    # Skill overlap: up to 20 pts
    target_skills = normalized_labels(target.get("skill_tags"))
    candidate_skills = normalized_labels(candidate.get("skill_tags"))
    shared_skills = sorted(set(target_skills) & set(candidate_skills))
    if shared_skills:
        score += min(len(shared_skills) * 10, 20)
        reasons.append(
            "Shared skill: "
            + ", ".join(target_skills[key] for key in shared_skills[:2])
        )

    # Same question format: 10 pts.
    # question_family() returns the placeholder "unknown" when a question has no
    # format/type; two questions that both lack that data must not count as a
    # match (mirrors the non-empty guards on topic and difficulty).
    target_family = question_family(target)
    if target_family != "unknown" and target_family == question_family(candidate):
        score += 10
        reasons.append("Same format")

    # Same difficulty: 5 pts
    candidate_difficulty = candidate.get("difficulty")
    if candidate_difficulty and candidate_difficulty == target.get("difficulty"):
        score += 5
        reasons.append("Same difficulty")

    # Full-text similarity: up to 20 pts
    if text_score > 0:
        text_pts = min(round(text_score * 60), 20)
        score += text_pts
        # Only mention wording when the contribution is non-trivial.
        if text_pts >= 4:
            reasons.append("Similar wording")

    # Never report a perfect 100% match for a different question.
    return min(score, 99), reasons
@router.get("/variants/favorited")
async def get_favorited_variants(user_id: str = Depends(get_current_user_id)):
    """Return every variant the current user has favorited, newest first.

    Used by the Error Book. Each row embeds its source question's number and
    paper id together with selected fields of that paper.
    """
    columns = "*, paper_questions(question_number, paper_id, papers(id, course_code, year, term, exam_type, part_label))"
    query = (
        get_supabase()
        .table("question_variants")
        .select(columns)
        .eq("user_id", user_id)
        .eq("favorited", True)
        .order("created_at", desc=True)
    )
    response = query.execute()
    return response.data
@router.post("/{question_id}/variant")
async def create_variant(question_id: str, user_id: str = Depends(get_current_user_id)):
    """Generate a variant of a question and store it for the current user.

    Args:
        question_id: Id of the source row in ``paper_questions``.
        user_id: Id of the authenticated user (injected dependency).

    Returns:
        The inserted ``question_variants`` row, with an extra
        ``source_question_number`` key copied from the source question.

    Raises:
        HTTPException: 404 if the source question does not exist; 500 if the
            insert returned no row.
    """
    sb = get_supabase()
    result = sb.table("paper_questions").select("*").eq("id", question_id).execute()
    if not result.data:
        raise HTTPException(status_code=404, detail="Question not found")
    question = result.data[0]

    # Run generate_variant in a worker thread so it does not block the event loop.
    variant_data = await asyncio.to_thread(generate_variant, question)

    # dict.get()'s default only applies when the key is missing; a row fetched
    # with select("*") carries the key even when the column is null, so
    # normalise None to "" explicitly.
    variant_data["knowledge_reminder"] = question.get("knowledge_reminder") or ""

    saved = sb.table("question_variants").insert({
        "user_id": user_id,
        "source_question_id": question_id,
        "variant_data": variant_data,
        "favorited": False,
    }).execute()
    if not saved.data:
        # Surface a clear error instead of an IndexError on saved.data[0].
        raise HTTPException(status_code=500, detail="Failed to save variant")

    row = saved.data[0]
    row["source_question_number"] = question["question_number"]
    return row
@router.get("/{question_id}/variants")
async def list_variants(question_id: str, user_id: str = Depends(get_current_user_id)):
    """List the current user's variants of one question, newest first.

    Every returned row gets a ``source_question_number`` key holding the
    source question's number, or ``""`` when the question row is not found.
    """
    client = get_supabase()

    lookup = (
        client.table("paper_questions")
        .select("question_number")
        .eq("id", question_id)
        .execute()
    )
    if lookup.data:
        number = lookup.data[0]["question_number"]
    else:
        number = ""

    variants = (
        client.table("question_variants")
        .select("*")
        .eq("user_id", user_id)
        .eq("source_question_id", question_id)
        .order("created_at", desc=True)
        .execute()
        .data
    )
    for variant in variants:
        variant["source_question_number"] = number
    return variants
@router.patch("/variant/{variant_id}")
async def update_variant(variant_id: str, data: VariantUpdate, user_id: str = Depends(get_current_user_id)):
    """Update a variant owned by the current user (favorite / unfavorite).

    Raises:
        HTTPException: 400 when the body carries no field to change; 404 when
            no variant with this id belongs to the user.
    """
    client = get_supabase()

    # Collect only the fields the caller actually supplied.
    changes: dict = {}
    if data.favorited is not None:
        changes["favorited"] = data.favorited
    if len(changes) == 0:
        raise HTTPException(status_code=400, detail="Nothing to update")

    # Filtering on user_id as well keeps users from editing each other's variants.
    query = (
        client.table("question_variants")
        .update(changes)
        .eq("id", variant_id)
        .eq("user_id", user_id)
    )
    response = query.execute()
    if response.data:
        return response.data[0]
    raise HTTPException(status_code=404, detail="Variant not found")
@router.delete("/variant/{variant_id}", status_code=204)
async def delete_variant(variant_id: str, user_id: str = Depends(get_current_user_id)):
    """Delete one of the current user's variants.

    The result of the delete is not inspected, so the endpoint responds 204
    whether or not a matching row existed.
    """
    deletion = (
        get_supabase()
        .table("question_variants")
        .delete()
        .eq("id", variant_id)
        .eq("user_id", user_id)
    )
    deletion.execute()
def _clamp_similar_limit(limit: int) -> int:
    """Clamp the requested number of results to the supported 1..12 range."""
    return max(1, min(limit, 12))


def _paper_source_label(paper: dict) -> str:
    """Build a display label for a paper row, e.g. "<year> <Term> <Exam Type> Part <label>".

    Null or missing fields are rendered as empty strings, so a null ``term``
    or ``exam_type`` no longer raises and a null ``year`` is not shown as "None".
    """
    year = paper.get("year")
    term = paper.get("term") or ""
    exam_type = paper.get("exam_type") or ""
    label = f"{'' if year is None else year} {term.title()} {exam_type.title()}".strip()
    if paper.get("part_label"):
        label = f"{label} Part {paper['part_label']}"
    return label


def _best_per_paper(ranked: list[dict]) -> list[dict]:
    """Keep only the first (best-ranked) item for each ``paper_id``, preserving order."""
    seen_papers: set[str] = set()
    best: list[dict] = []
    for item in ranked:
        if item["paper_id"] not in seen_papers:
            seen_papers.add(item["paper_id"])
            best.append(item)
    return best


@router.get("/{question_id}/similar")
async def get_similar_questions(question_id: str, limit: int = 6):
    """Retrieve similar questions from the same course.

    Lookup order:
      1. the in-memory cache (entries younger than ``_CACHE_TTL`` seconds);
      2. the pre-computed ``similar_questions`` value stored on the question row;
      3. an on-the-fly ranking of questions from other ready papers of the same
         course, whose result is cached and written back to the question row.

    Args:
        question_id: Id of the target row in ``paper_questions``.
        limit: Maximum number of results; clamped to the range 1..12.

    Returns:
        A list of at most ``limit`` result dicts, best match first. May be empty.

    Raises:
        HTTPException: 404 if the question or its paper cannot be found.
    """
    # NOTE(review): the Supabase calls below are not awaited, so they appear to
    # run synchronously inside this async handler. Confirm whether they should
    # be moved off the event loop.
    logger = logging.getLogger(__name__)
    max_results = _clamp_similar_limit(limit)

    # 1) In-memory cache
    cached = _similar_cache.get(question_id)
    if cached is not None:
        if (time.time() - cached[0]) < _CACHE_TTL:
            return cached[1][:max_results]
        # Expired: drop it so stale entries for this key do not linger.
        _similar_cache.pop(question_id, None)

    sb = get_supabase()
    result = sb.table("paper_questions").select("*, similar_questions").eq("id", question_id).execute()
    if not result.data:
        raise HTTPException(status_code=404, detail="Question not found")
    target = result.data[0]

    # 2) Pre-computed result stored on the row: cache it and return immediately.
    if target.get("similar_questions"):
        precomputed = target["similar_questions"]
        _similar_cache[question_id] = (time.time(), precomputed)
        return precomputed[:max_results]

    # 3) Fallback: compute on the fly for questions not yet backfilled.
    paper_result = sb.table("papers").select("id, course_code").eq("id", target["paper_id"]).execute()
    if not paper_result.data:
        raise HTTPException(status_code=404, detail="Paper not found")
    course_code = paper_result.data[0]["course_code"]

    papers = (
        sb.table("papers")
        .select("id, course_code, year, term, exam_type, part_label")
        .eq("course_code", course_code)
        .eq("status", "ready")
        .execute()
        .data
    )
    # Only other papers of the course are searched, never the target's own paper.
    paper_ids = [paper["id"] for paper in papers if paper["id"] != target["paper_id"]]
    if not paper_ids:
        return []
    papers_by_id = {paper["id"]: paper for paper in papers}

    # Pre-filter by analytics_topic in the DB when possible to shrink the candidate set.
    candidates_query = (
        sb.table("paper_questions")
        .select(
            "id, paper_id, question_number, question_type, question_format, "
            "question_text, score, topics, analytics_topic, topic_tags, skill_tags, "
            "difficulty, knowledge_reminder, ai_hint, solution"
        )
        .in_("paper_id", paper_ids)
    )
    target_topic = target.get("analytics_topic")
    if target_topic:
        candidates_query = candidates_query.eq("analytics_topic", target_topic)
    candidates = candidates_query.execute().data
    if not candidates:
        return []

    # Batch full-text scores from PostgreSQL (skipped for large candidate sets: slow).
    text_scores: dict[str, float] = {}
    if len(candidates) <= 50:
        try:
            rpc_result = sb.rpc(
                "text_similarity_scores",
                {
                    "query_text": target.get("question_text") or "",
                    "candidate_ids": [c["id"] for c in candidates],
                },
            ).execute()
            for row in rpc_result.data or []:
                text_scores[row["question_id"]] = float(row["text_score"] or 0)
        except Exception:
            # Best effort: rank without text similarity, but leave a trace.
            logger.warning(
                "text_similarity_scores RPC failed for question %s",
                question_id,
                exc_info=True,
            )

    ranked = []
    for candidate in candidates:
        text_score = text_scores.get(candidate["id"], 0.0)
        match_percent, reasons = similarity_score(target, candidate, text_score)
        # Drop weak matches.
        if match_percent < 20:
            continue
        paper = papers_by_id.get(candidate["paper_id"], {})
        ranked.append(
            {
                "id": candidate["id"],
                "paper_id": candidate["paper_id"],
                "source": _paper_source_label(paper),
                "question_number": candidate["question_number"],
                "match_percent": match_percent,
                "match_reasons": reasons,
                "question_type": question_family(candidate),
                "question_text": candidate["question_text"],
                "topics": display_topics(candidate),
                "difficulty": candidate.get("difficulty"),
                # These columns are selected explicitly, so the keys exist even
                # when null; `or ""` is what actually supplies the default.
                "knowledge_reminder": candidate.get("knowledge_reminder") or "",
                "ai_hint": candidate.get("ai_hint") or "",
                "solution": candidate.get("solution") or "",
            }
        )

    # Best match first; ties broken by source label, then question number.
    ranked.sort(key=lambda item: (-item["match_percent"], item["source"], item["question_number"]))

    # Keep only the best-scoring question per paper.
    deduped = _best_per_paper(ranked)
    _similar_cache[question_id] = (time.time(), deduped)

    # Persist to the DB so future requests can use the pre-computed path.
    try:
        sb.table("paper_questions").update({"similar_questions": deduped}).eq("id", question_id).execute()
    except Exception:
        # Best effort: the response is still valid if persisting fails.
        logger.warning(
            "Failed to persist similar_questions for question %s",
            question_id,
            exc_info=True,
        )

    return deduped[:max_results]