8000 Added how to use commit0 for sampling during STAR training by wenting-zhao · Pull Request #105 · commit-0/commit0 · GitHub
[go: up one dir, main page]

Skip to content

Added how to use commit0 for sampling during STAR training #105

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 18 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add parallel commit0 tests
  • Loading branch information
wenting-zhao committed Dec 8, 2024
commit cd1fbdf342653ff4c509a70ade97af221402a4b0
4 changes: 2 additions & 2 deletions commit0/harness/run_pytest_ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def main(
or "bigcodebench" in dataset_name
or "codecontests" in dataset_name
):
repo_name = example["instance_id"]
repo_name = str(example["instance_id"])
dataset_type = "simple"
else:
repo_name = example["repo"].split("/")[-1]
Expand Down Expand Up @@ -174,7 +174,7 @@ def main(
prompt = example["prompt"] if "prompt" in example.keys() else ""
matches = extract_code_blocks(solution)
if len(matches) > 0:
solution = "\n\n".join(matches)
solution = matches[0]
else:
solution = prompt + "\n\n" + solution
patch = solution + "\n\n" + example["test"]
Expand Down
19 changes: 15 additions & 4 deletions examples/star/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,21 @@ def generate_predictions(

prompts: List[str] = []
for example in dataset:
prompt = (
f"{example['text']} Your code should satisfy these tests:\n\n"
f"{'\n'.join(example['test_list'])}"
)
prompt = example["prompt"]
test = example["test"]
prompt = f"""Write a Python function implementation for the following prompt:

{prompt}

Your code should satisfy these tests:

{test}

Return only the implementation code, no tests or explanations. Be sure to include the relevant import statements:
```python
code
```
"""
prompts.append(prompt)

outputs = llm.generate(prompts, sampling_params)
Expand Down
18 changes: 17 additions & 1 deletion examples/star/star.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import argparse
from datasets import Dataset, load_dataset
from inference import generate_predictions
from utils import execute_tests


def main():
parser = argparse.ArgumentParser()
Expand All @@ -13,7 +15,21 @@ def main():

ds = load_dataset(args.dataset_name)
assert "train" in ds
samples = generate_predictions(args.model_name, ds["train"], args.temperature, args.n)
all_samples = generate_predictions(args.model_name, ds["train"], args.temperature, args.n)
for x in all_samples:
for xx in x:
print(xx)
print("-"*100)
assert len(ds["train"]) == len(all_samples)
all_traces, all_execution_results = execute_tests(ds["train"], all_samples)
passed_examples = []
for example, execution_results, samples in zip(ds["train"], all_execution_results, all_samples):
for execution_result, sample in zip(execution_results, samples):
if execution_result == 0:
example['prediction'] = sample
passed_examples.append(example)
break
print(len(passed_examples)/len(ds["train"]))

if __name__ == '__main__':
main()
74 changes: 74 additions & 0 deletions examples/star/utils.py
7B80
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import subprocess
from concurrent.futures import ThreadPoolExecutor, as_completed
from datasets import Dataset
from tqdm import tqdm
from typing import List, Tuple, Any

def execute_tests(
    examples: "Dataset",
    all_samples: List[List[str]],
    max_workers: int = 100
) -> Tuple[List[List[str]], List[List[int]]]:
    """
    Run `commit0 test` in parallel for all (example, sample) pairs and collect results.

    This function:
    1. Flattens the iteration over examples and samples into a single list of tasks.
    2. Executes them in parallel with a ThreadPoolExecutor.
    3. Reassembles the results into two lists of lists:
       - `all_traces`, containing the stdout for each sample.
       - `all_execution_results`, containing the exit code for each sample.

    Args:
        examples (Dataset): Examples to test; each row must contain an
            "instance_id" key identifying the repo/problem for `commit0 test`.
        all_samples (List[List[str]]): A 2D list of strings where
            `all_samples[i]` holds the candidate solutions for `examples[i]`.
            Rows may have different lengths; the results are shaped to match.
        max_workers (int): Number of worker threads for parallel execution.
            Default is 100 (the work is subprocess/I-O bound, so this can
            safely exceed the CPU count).

    Returns:
        Tuple[List[List[str]], List[List[int]]]:
            A tuple of (all_traces, all_execution_results) where:
            - all_traces[i][j] is the stdout from running `commit0 test` on
              `examples[i]` with `all_samples[i][j]` piped to stdin.
            - all_execution_results[i][j] is the exit code of that run
              (0 indicates the tests passed).
    """
    # Flatten tasks: each task is (example_index, sample_index, instance_id, sample).
    tasks = []
    for i, example in enumerate(examples):
        # instance_id may be stored as an int in some datasets; the CLI needs a str.
        instance_id = str(example["instance_id"])
        for j, sample in enumerate(all_samples[i]):
            tasks.append((i, j, instance_id, sample))

    # Size each result row to its own sample list so ragged inputs are handled
    # (sizing everything from all_samples[0] would crash on a longer later row).
    all_traces = [[None] * len(row) for row in all_samples]
    all_execution_results = [[None] * len(row) for row in all_samples]

    # Fast path: nothing to run (no examples, or every sample list is empty).
    if not tasks:
        return all_traces, all_execution_results

    # Run all tasks in parallel; futures maps each future back to its (i, j) slot.
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = {
            executor.submit(
                subprocess.run,
                ["commit0", "test", instance_id, "--stdin"],
                input=sample,
                text=True,
                capture_output=True,
            ): (i, j)
            for (i, j, instance_id, sample) in tasks
        }

        for future in tqdm(as_completed(futures), total=len(tasks), desc="Executing tests"):
            i, j = futures[future]
            result = future.result()
            all_traces[i][j] = result.stdout
            all_execution_results[i][j] = result.returncode

    return all_traces, all_execution_results
0