|
| 1 | +import sys |
| 2 | +from datasets import load_dataset |
| 3 | + |
| 4 | + |
def clean_columns(dataset):
    """Drop every column except the four fields of the commit0 schema.

    Keeps 'canonical_solution', 'test', 'instance_id', and 'prompt';
    all other columns are removed. Returns the trimmed dataset.
    """
    keep = {"canonical_solution", "test", "instance_id", "prompt"}
    extra = [name for name in dataset.column_names if name not in keep]
    return dataset.remove_columns(extra)
| 8 | + |
def convert_mbpp_tests(assert_list):
    """Turn a list of MBPP assert statements into pytest-style functions.

    Each assertion is wrapped in its own ``def testN():`` function (1-based)
    so a test runner reports pass/fail per assertion. Returns the functions
    joined by blank lines as a single source string.
    """
    wrapped = [
        f"def test{idx}():\n    {stmt}"
        for idx, stmt in enumerate(assert_list, 1)
    ]
    return "\n\n".join(wrapped)
| 17 | + |
def convert_humaneval_tests(test_code, entrypoint):
    """Turn a HumanEval ``check(candidate)`` body into standalone pytest functions.

    Extracts every ``assert`` line from *test_code* and wraps each in its own
    ``def testN():`` function, preceded by a module-level
    ``candidate = <entrypoint>`` alias so the asserts resolve ``candidate``.
    Returns the pieces joined by blank lines as a single source string.

    NOTE(review): non-assert setup lines and multi-line asserts in the
    original check body are dropped — presumably acceptable for HumanEval,
    where checks are overwhelmingly single-line asserts; verify if reused.
    """
    lines = test_code.strip().split("\n")

    # Keep only the assertion statements from the check body.
    assert_lines = [line for line in lines if line.lstrip().startswith("assert")]

    # Module-level alias referenced by every generated test.
    test_functions = [f"candidate = {entrypoint}"]
    for i, assert_line in enumerate(assert_lines, 1):
        # BUGFIX: re-indent each assert to exactly one level. The previous
        # version preserved the assert's original indentation, which produced
        # an IndentationError whenever the assert was nested deeper than one
        # level (e.g. inside a loop in check); this also matches the 4-space
        # convention used by convert_mbpp_tests.
        test_func = f"def test{i}():\n    {assert_line.strip()}"
        test_functions.append(test_func)

    return "\n\n".join(test_functions)
| 32 | + |
def convert_humaneval():
    """Convert OpenAI HumanEval to the commit0 schema and push it to the Hub.

    For every split: renames ``task_id`` -> ``instance_id``, rewrites the
    ``test`` column into standalone pytest functions via
    ``convert_humaneval_tests``, and strips all non-schema columns.
    """
    dataset = load_dataset("openai/openai_humaneval")
    for split_name in dataset:
        split = dataset[split_name].rename_column('task_id', 'instance_id')
        # Rebuild 'test' per row before dropping the original column.
        rewritten = [
            convert_humaneval_tests(row['test'], row['entry_point'])
            for row in split
        ]
        split = split.remove_columns(['test'])
        split = split.add_column(name='test', column=rewritten)
        dataset[split_name] = clean_columns(split)
    dataset.push_to_hub("commit0/openai_humaneval")
| 43 | + |
def convert_codecontests():
    """Placeholder: conversion for the CodeContests dataset is not implemented yet."""
    # Intentionally a no-op; the CLI dispatcher below already routes here.
    pass
| 46 | + |
def convert_bigcodebench():
    """Placeholder: conversion for the BigCodeBench dataset is not implemented yet."""
    # Intentionally a no-op; the CLI dispatcher below already routes here.
    pass
| 49 | + |
def convert_mbpp():
    """Convert Google MBPP to the commit0 schema and push it to the Hub.

    For every split: renames ``task_id`` -> ``instance_id`` and
    ``code`` -> ``canonical_solution``, builds a ``test`` column from
    ``test_list`` via ``convert_mbpp_tests``, and strips all non-schema
    columns.
    """
    dataset = load_dataset("google-research-datasets/mbpp")
    for split_name in dataset:
        split = dataset[split_name].rename_column('task_id', 'instance_id')
        split = split.rename_column('code', 'canonical_solution')
        generated = [convert_mbpp_tests(asserts) for asserts in split['test_list']]
        split = split.add_column(name='test', column=generated)
        dataset[split_name] = clean_columns(split)
    dataset.push_to_hub("commit0/mbpp")
| 59 | + |
if __name__ == "__main__":
    # Dispatch table: dataset name -> converter. Adding a new dataset is a
    # one-line change instead of another elif branch.
    converters = {
        "mbpp": convert_mbpp,
        "humaneval": convert_humaneval,
        "codecontests": convert_codecontests,
        "bigcodebench": convert_bigcodebench,
    }
    # Guard against a missing CLI argument: previously this crashed with a
    # bare IndexError from sys.argv[1]; exit with a usage message instead.
    if len(sys.argv) < 2:
        raise SystemExit(f"usage: python {sys.argv[0]} {{{'|'.join(converters)}}}")
    data = sys.argv[1].lower()
    if data not in converters:
        # Name the unsupported dataset so the failure is actionable.
        raise NotImplementedError(f"unsupported dataset: {data!r}")
    converters[data]()
| 72 | + |
0 commit comments