diff --git a/yourbench/pipeline/question_rewriting.py b/yourbench/pipeline/question_rewriting.py index 238a3ee7..7bc3af8f 100644 --- a/yourbench/pipeline/question_rewriting.py +++ b/yourbench/pipeline/question_rewriting.py @@ -111,7 +111,7 @@ def _build_question_rewriting_calls( messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}] - calls.append(InferenceCall(messages=messages, tags=[STAGE_TAG])) + calls.append(InferenceCall(messages=messages, tags=STAGE_TAG)) indices.append(idx) return calls, indices @@ -156,6 +156,10 @@ def _process_question_rewriting_responses( "question_rewriting_rationale": rewritten.question_rewriting_rationale, "raw_question_rewriting_response": response, }) + + # Ensure question_mode is present (required by QuestionRow but may be missing from older datasets) + if "question_mode" not in new_row_dict: + new_row_dict["question_mode"] = "open-ended" # Default for older datasets try: # Validate and structure the data using QuestionRow @@ -191,7 +195,14 @@ def _process_question_type( """ try: logger.info(f"Processing {question_type} questions...") - dataset = custom_load_dataset(config=config, subset=load_subset) + try: # skipping question rewriting if subset not found + dataset = custom_load_dataset(config=config, subset=load_subset) + except Exception as e: + if "not found" in str(e).lower(): + logger.warning(f"Subset '{load_subset}' not found. Skipping {question_type} question rewriting.") + return + else: + raise e if not dataset or len(dataset) == 0: logger.warning(f"No {question_type} questions found or dataset is empty.")