"""题目相关:变式题生成 + 相似题召回""" from __future__ import annotations import asyncio import time from fastapi import APIRouter, HTTPException, Depends from pydantic import BaseModel from app.services.supabase_client import get_supabase from app.services.grader import generate_variant from app.dependencies.auth import get_current_user_id # Simple in-memory cache: question_id → (timestamp, result) _similar_cache: dict[str, tuple[float, list]] = {} _CACHE_TTL = 300 # 5 minutes class VariantUpdate(BaseModel): favorited: bool | None = None router = APIRouter() def normalized_labels(values: list[str] | None) -> dict[str, str]: labels: dict[str, str] = {} for value in values or []: if value: labels[value.lower()] = value return labels def question_family(question: dict) -> str: return question.get("question_format") or question.get("question_type") or "unknown" def display_topics(question: dict) -> list[str]: labels: list[str] = [] analytics_topic = question.get("analytics_topic") if analytics_topic: labels.append(analytics_topic) for topic in question.get("topic_tags") or []: if topic and topic not in labels: labels.append(topic) if labels: return labels for topic in question.get("topics") or []: if topic and topic not in labels: labels.append(topic) return labels def similarity_score( target: dict, candidate: dict, text_score: float = 0.0, ) -> tuple[int, list[str]]: score = 0 reasons: list[str] = [] # Primary topic bucket: 40 pts target_topic = target.get("analytics_topic") candidate_topic = candidate.get("analytics_topic") if target_topic and target_topic == candidate_topic: score += 40 reasons.append(f"Same topic: {target_topic}") # Concept overlap: up to 20 pts target_topics = normalized_labels(target.get("topic_tags")) candidate_topics = normalized_labels(candidate.get("topic_tags")) shared_topics = sorted(set(target_topics) & set(candidate_topics)) if shared_topics: score += min(len(shared_topics) * 10, 20) # Only show concept reason if analytics_topic didn't already match (avoid redundancy) if not (target_topic and target_topic == candidate_topic): reasons.append( "Shared concept: " + ", ".join(target_topics[key] for key in shared_topics[:2]) ) # Skill overlap: up to 20 pts target_skills = normalized_labels(target.get("skill_tags")) candidate_skills = normalized_labels(candidate.get("skill_tags")) shared_skills = sorted(set(target_skills) & set(candidate_skills)) if shared_skills: score += min(len(shared_skills) * 10, 20) reasons.append( "Shared skill: " + ", ".join(target_skills[key] for key in shared_skills[:2]) ) # Same question format: 10 pts if question_family(candidate) == question_family(target): score += 10 reasons.append("Same format") # Same difficulty: 5 pts if candidate.get("difficulty") and candidate.get("difficulty") == target.get("difficulty"): score += 5 reasons.append("Same difficulty") # Full-text similarity from PostgreSQL ts_rank_cd: up to 20 pts if text_score > 0: text_pts = min(round(text_score * 60), 20) score += text_pts if text_pts >= 4: reasons.append("Similar wording") return min(score, 99), reasons @router.get("/variants/favorited") async def get_favorited_variants(user_id: str = Depends(get_current_user_id)): """获取用户收藏的所有 variant(用于 Error Book)""" sb = get_supabase() rows = ( sb.table("question_variants") .select("*, paper_questions(question_number, paper_id, papers(id, course_code, year, term, exam_type, part_label))") .eq("user_id", user_id) .eq("favorited", True) .order("created_at", desc=True) .execute() .data ) return rows @router.post("/{question_id}/variant") async def create_variant(question_id: str, user_id: str = Depends(get_current_user_id)): """生成变式题并入库""" sb = get_supabase() result = sb.table("paper_questions").select("*").eq("id", question_id).execute() if not result.data: raise HTTPException(status_code=404, detail="Question not found") question = result.data[0] variant_data = await generate_variant(question) variant_data["knowledge_reminder"] = question.get("knowledge_reminder", "") saved = sb.table("question_variants").insert({ "user_id": user_id, "source_question_id": question_id, "variant_data": variant_data, "favorited": False, }).execute() row = saved.data[0] row["source_question_number"] = question["question_number"] return row @router.get("/{question_id}/variants") async def list_variants(question_id: str, user_id: str = Depends(get_current_user_id)): """获取某道题的用户所有 variant""" sb = get_supabase() q_result = sb.table("paper_questions").select("question_number").eq("id", question_id).execute() question_number = q_result.data[0]["question_number"] if q_result.data else "" rows = ( sb.table("question_variants") .select("*") .eq("user_id", user_id) .eq("source_question_id", question_id) .order("created_at", desc=True) .execute() .data ) for row in rows: row["source_question_number"] = question_number return rows @router.patch("/variant/{variant_id}") async def update_variant(variant_id: str, data: VariantUpdate, user_id: str = Depends(get_current_user_id)): """更新 variant(收藏/取消收藏)""" sb = get_supabase() update: dict = {} if data.favorited is not None: update["favorited"] = data.favorited if not update: raise HTTPException(status_code=400, detail="Nothing to update") result = ( sb.table("question_variants") .update(update) .eq("id", variant_id) .eq("user_id", user_id) .execute() ) if not result.data: raise HTTPException(status_code=404, detail="Variant not found") return result.data[0] @router.delete("/variant/{variant_id}", status_code=204) async def delete_variant(variant_id: str, user_id: str = Depends(get_current_user_id)): """删除 variant""" sb = get_supabase() sb.table("question_variants").delete().eq("id", variant_id).eq("user_id", user_id).execute() @router.get("/{question_id}/similar") async def get_similar_questions(question_id: str, limit: int = 6): """Retrieve similar questions from the same course.""" # Cache hit cached = _similar_cache.get(question_id) if cached and (time.time() - cached[0]) < _CACHE_TTL: return cached[1][:max(1, min(limit, 12))] sb = get_supabase() result = sb.table("paper_questions").select("*, similar_questions").eq("id", question_id).execute() if not result.data: raise HTTPException(status_code=404, detail="Question not found") target = result.data[0] # Return pre-computed immediately; schedule background refresh if target.get("similar_questions"): precomputed = target["similar_questions"] _similar_cache[question_id] = (time.time(), precomputed) return precomputed[:max(1, min(limit, 12))] paper_result = sb.table("papers").select("id, course_code").eq("id", target["paper_id"]).execute() # (fallback: compute on-the-fly for questions not yet backfilled) if not paper_result.data: raise HTTPException(status_code=404, detail="Paper not found") course_code = paper_result.data[0]["course_code"] papers = ( sb.table("papers") .select("id, course_code, year, term, exam_type, part_label") .eq("course_code", course_code) .eq("status", "ready") .execute() .data ) paper_ids = [paper["id"] for paper in papers if paper["id"] != target["paper_id"]] if not paper_ids: return [] papers_by_id = {paper["id"]: paper for paper in papers} # Pre-filter by analytics_topic in DB when possible (cuts candidates from ~250 to ~30) candidates_query = ( sb.table("paper_questions") .select( "id, paper_id, question_number, question_type, question_format, " "question_text, score, topics, analytics_topic, topic_tags, skill_tags, " "difficulty, knowledge_reminder, ai_hint, solution" ) .in_("paper_id", paper_ids) ) target_topic = target.get("analytics_topic") if target_topic: candidates_query = candidates_query.eq("analytics_topic", target_topic) candidates = candidates_query.execute().data if not candidates: return [] # Batch full-text scores from PostgreSQL (skip if too many candidates — slow) text_scores: dict[str, float] = {} if len(candidates) <= 50: try: rpc_result = sb.rpc( "text_similarity_scores", { "query_text": target.get("question_text") or "", "candidate_ids": [c["id"] for c in candidates], }, ).execute() for row in rpc_result.data or []: text_scores[row["question_id"]] = float(row["text_score"] or 0) except Exception: pass ranked = [] for candidate in candidates: text_score = text_scores.get(candidate["id"], 0.0) match_percent, reasons = similarity_score(target, candidate, text_score) if match_percent < 20: continue paper = papers_by_id.get(candidate["paper_id"], {}) source = ( f"{paper.get('year', '')} {paper.get('term', '').title()} " f"{paper.get('exam_type', '').title()}" ).strip() if paper.get("part_label"): source = f"{source} Part {paper['part_label']}" ranked.append( { "id": candidate["id"], "paper_id": candidate["paper_id"], "source": source, "question_number": candidate["question_number"], "match_percent": match_percent, "match_reasons": reasons, "question_type": question_family(candidate), "question_text": candidate["question_text"], "topics": display_topics(candidate), "difficulty": candidate.get("difficulty"), "knowledge_reminder": candidate.get("knowledge_reminder", ""), "ai_hint": candidate.get("ai_hint", ""), "solution": candidate.get("solution", ""), } ) ranked.sort(key=lambda item: (-item["match_percent"], item["source"], item["question_number"])) # Keep only the best-scoring question per paper seen_papers: set[str] = set() deduped = [] for item in ranked: if item["paper_id"] not in seen_papers: seen_papers.add(item["paper_id"]) deduped.append(item) _similar_cache[question_id] = (time.time(), deduped) # Persist to DB so future requests are instant try: sb.table("paper_questions").update({"similar_questions": deduped}).eq("id", question_id).execute() except Exception: pass return deduped[:max(1, min(limit, 12))]