from typing import Dict, Iterable, Optional
import pandas as pd
from datasets import Dataset, DatasetDict, load_dataset
def example_to_qa_row(ex: Dict) -> Optional[Dict[str, str]]:
    """Convert one PubMedQA pqa_artificial example to a question/answer row.

    - question: use the provided question string (omit context entirely).
    - answer: prefer `long_answer` if non-empty; otherwise use `final_decision`.

    Missing or ``None`` fields are treated as empty strings; every other
    value is coerced with ``str`` and stripped of surrounding whitespace.

    Args:
        ex: Mapping for a single example. The keys read are ``question``,
            ``long_answer`` and ``final_decision``.

    Returns:
        A dict with the keys ``question`` and ``answer``, or ``None`` when
        the question or the resolved answer is empty so that callers can
        skip the example.
    """

    def _clean(value: object) -> str:
        # Check for None explicitly: str(None) would yield the text "None".
        return "" if value is None else str(value).strip()

    question = _clean(ex.get("question"))
    # `or` falls through to final_decision when long_answer is empty/blank.
    answer = _clean(ex.get("long_answer")) or _clean(ex.get("final_decision"))

    if not question or not answer:
        # An incomplete pair is not a usable Q/A row.
        return None
    return {"question": question, "answer": answer}
def dataset_to_csv(ds_split: Iterable[Dict], out_path: str) -> None:
    """Write the question/answer rows derived from a dataset split to a CSV.

    Each example is converted with ``example_to_qa_row``; examples for which
    it returns a falsy value are skipped.

    Args:
        ds_split: Iterable of example mappings to convert.
        out_path: Destination path of the CSV file (written without the
            DataFrame index).
    """
    rows = [row for row in map(example_to_qa_row, ds_split) if row]
    # Pin the column names so that an empty split still produces a CSV with
    # the `question,answer` header instead of a header-less file.
    frame = pd.DataFrame(rows, columns=["question", "answer"])
    frame.to_csv(out_path, index=False)
def main() -> None:
    """Create train/validation question-answer CSVs from PubMedQA.

    Steps:
    - Load the `pqa_artificial` config of `qiaojin/PubMedQA` and take its
      `train` split.
    - Shuffle and divide it 80/20 with `train_test_split` (seed=42); the
      20% `test` part is used as the validation set.
    - Hand each part to `dataset_to_csv`, which writes one CSV per part
      using relative file names.
    - Print a summary with the number of examples in each part.

    Raises:
        RuntimeError: If the loaded dataset has no `train` split.
    """
    ds_all: DatasetDict = load_dataset("qiaojin/PubMedQA", "pqa_artificial")
    if "train" not in ds_all:
        raise RuntimeError("Expected a 'train' split in PubMedQA pqa_artificial")

    # Fixed seed keeps the 80/20 partition reproducible across runs.
    parts = ds_all["train"].train_test_split(test_size=0.2, seed=42, shuffle=True)
    train_ds, val_ds = parts["train"], parts["test"]

    dataset_to_csv(train_ds, "pubmedqa_pqa_artificial_train_qa.csv")
    dataset_to_csv(val_ds, "pubmedqa_pqa_artificial_validation_qa.csv")

    summary = (
        "Wrote pubmedqa_pqa_artificial_train_qa.csv ({} examples) and "
        "pubmedqa_pqa_artificial_validation_qa.csv ({} examples)"
    )
    print(summary.format(len(train_ds), len(val_ds)))
# Run the conversion only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()