feat: optimized sanitization and evaluation · bigcode-project/selfcodealign@14f4294 · GitHub

Commit 14f4294

feat: optimized sanitization and evaluation
1 parent c9315a0 commit 14f4294

File tree

9 files changed: +1283 −36 lines changed

README.md

Lines changed: 5 additions & 7 deletions

@@ -239,13 +239,11 @@ Also, the container connection may be lost during execution. In this case, you c
 <summary>Data sanitization and selection</summary>

 ```shell
-python src/star_align/sanitize_data.py \
-    --data_files /path/to/filtered.jsonl* \
-    --output_file /path/to/final_dataset.jsonl \
-    --parse_raw_response True \
-    --passing_only True \
-    --exact_match_dedup True \
-    --data_augmentation False
+# Uncomment to do decontamination
+# export MBPP_PATH="/path/to/mbpp.jsonl"
+# export DS1000_PATH="/path/to/ds1000_data"
+# export DECONTAMINATION=1
+./sanitize.sh /path/to/exec-filtered.jsonl /path/to/sanitized.jsonl
 ```

 </details>
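For orientation, the decontamination gated by these environment variables amounts to dropping generated samples that textually overlap with the evaluation benchmarks pointed to by `MBPP_PATH` and `DS1000_PATH`. The snippet below is only a minimal sketch of that idea under assumed field names (`response` for the sample output, `code` for the MBPP reference solution); it is not the logic of `sanitize.sh` or `sanitize_data.py`.

```python
# Minimal sketch of exact-match decontamination; assumed logic, not the repo's implementation.
import json
import os


def load_jsonl(path: str) -> list[dict]:
    with open(path) as f:
        return [json.loads(line) for line in f]


def decontaminate(samples: list[dict], benchmark_texts: set[str]) -> list[dict]:
    # Keep only samples whose response is not an exact copy of a benchmark solution.
    return [s for s in samples if s["response"] not in benchmark_texts]


if os.getenv("DECONTAMINATION") == "1":
    benchmark_texts: set[str] = set()
    if mbpp_path := os.getenv("MBPP_PATH"):
        # Field name assumed: "code" holds the reference solution in the MBPP jsonl release.
        benchmark_texts |= {ex["code"] for ex in load_jsonl(mbpp_path)}
    samples = load_jsonl("exec-filtered.jsonl")  # hypothetical input file
    kept = decontaminate(samples, benchmark_texts)
    print(f"kept {len(kept)}/{len(samples)} samples after decontamination")
```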

evaluation/README.md

Lines changed: 25 additions & 5 deletions

@@ -1,9 +1,29 @@
-# Reproduce the experiments
+# Evaluation

 > [!IMPORTANT]
 > **General requirements**
 >
 > Before you start, make sure you have cloned the repository and you are in the **root directory of the project**. Make sure you installed the required packages with `pip install -e .`. Different package versions may impact the reproducibility of the results.
+
+## Running EvalPlus with vLLM
+
+We implemented batched inference in [evaluation/text2code_vllm.py](text2code_vllm.py) using [vLLM](https://docs.vllm.ai/en/latest/). This speeds up the evaluation significantly: **a greedy decoding run can be finished within 20 seconds**. Here is the command:
+
+```bash
+MODEL=/path/to/your/model
+DATASET=humaneval # or mbpp
+SAVE_PATH=evalplus-$(basename $MODEL)-$DATASET.jsonl
+CUDA_VISIBLE_DEVICES=0 python -m evaluation.text2code_vllm \
+  --model_key $MODEL \
+  --dataset $DATASET \
+  --save_path $SAVE_PATH
+
+python -m evalplus.evaluate --dataset $DATASET --samples $SAVE_PATH
+```
+
+## Reproduce StarCoder2-Instruct
+
+> [!NOTE]
 >
 > We obtained the results with the following hardware and environment:
 >
@@ -12,13 +32,13 @@
 >
 > In case you face issues, we provide the raw outputs we generated in the [evalplus_results](evalplus_results) directory.

-## Reproduce HumanEval(+) and MBPP(+)
+### Reproduce HumanEval(+) and MBPP(+)

 We pack multiple problems into one batch to speed up the inference. A different batch size may lead to slightly better or worse results due to floating-point round-off resulting from the underlying [cuBLAS](https://docs.nvidia.com/cuda/cublas/index.html) optimizations.

 Make sure you set `CUDA_VISIBLE_DEVICES` to the GPU you want to use and have `cd`ed to the root directory of the repo. We assume you use device 0 in the following commands.

-### HumanEval(+)
+#### HumanEval(+)

 ```bash
 MODEL_KEY=bigcode/starcoder2-15b-instruct-v0.1
@@ -46,7 +66,7 @@ python -m evalplus.evaluate --dataset $DATASET --samples $SAVE_PATH
 # pass@1: 0.634
 ```

-### MBPP(+)
+#### MBPP(+)

 ```bash
 MODEL_KEY=bigcode/starcoder2-15b-instruct-v0.1
@@ -71,4 +91,4 @@ python -m evalplus.evaluate --dataset $DATASET --samples $SAVE_PATH
 # pass@1: 0.642
 # mbpp+ (base + extra tests)
 # pass@1: 0.526
-```
+```
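The 20-second figure above comes from vLLM batching all benchmark problems into a single generation call. As a rough standalone illustration (not the evaluation script itself; the model path and prompt are placeholders), the pattern adopted in `evaluation/text2code_vllm.py` boils down to:

```python
# Minimal sketch of the vLLM batched-generation pattern used by text2code_vllm.py.
# The prompt and model path are placeholders; sampling values mirror the new defaults.
from vllm import LLM, SamplingParams

engine = LLM("/path/to/your/model")
sampling_params = SamplingParams(
    n=1,                 # one sample per problem
    temperature=0.0,     # greedy decoding
    top_p=1.0,
    top_k=-1,
    max_tokens=1024,
    stop="\n```\n",      # cut generation at the end of the fenced code block
)
prompts = ["### Instruction\nWrite a function that adds two numbers.\n### Response\n```python\n"]
outputs = engine.generate(prompts, sampling_params)
print(outputs[0].outputs[0].text)
```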

evaluation/text2code_vllm.py

Lines changed: 10 additions & 21 deletions

@@ -1,16 +1,12 @@
-import itertools
 import os
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Literal, TypedDict, cast
-from functools import partial
 from evalplus.data import get_human_eval_plus, get_mbpp_plus, write_jsonl

 # from evoeval.data import get_evo_eval
-from tqdm.auto import tqdm
 from transformers import HfArgumentParser

-from star_align.llm_wrapper import GenerationConfig, get_model_context
 from star_align.prompt_template import SC2_INSTRUCT_PROMPT
 from star_align.utils import infer_prompt_template

@@ -60,7 +56,7 @@ def map_mbpp_problem(p: dict) -> Text2CodeProblem:
 ```python
 {assertion}
 ```"""
-    prefix = "" if PROMPT_TEMPLATE.endswith("\n") else "\n"
+    prefix = ""
     response_prefix = f"""{prefix}```python"""
     return Text2CodeProblem(
         id=str(id), instruction=instruction, response_prefix=response_prefix
@@ -85,7 +81,6 @@ def map_humaneval_problem(p: dict) -> Text2CodeProblem:
 ```python
 {prompt}
 ```"""
-    prefix = "" if PROMPT_TEMPLATE.endswith("\n") else "\n"
     prefix = ""
     prefix_template = os.getenv("PREFIX_TEMPLATE", "```python")
     response_prefix = prefix + (
@@ -115,21 +110,15 @@ class Args:
         "EvoEval_concise",
     ]
     save_path: str
-
-    n_batches: int
-    n_problems_per_batch: int
-    n_samples_per_problem: int
-    # prompted: bool
-
+    n_samples_per_problem: int = field(default=1)
+    max_new_tokens: int = field(default=1024)
+    top_p: float = field(default=1.0)
+    temperature: float = field(default=0.0)
     model_name_or_path: str | None = None


 def main():
-    parser = HfArgumentParser((Args, GenerationConfig))
-    args, generation_config = cast(
-        tuple[Args, GenerationConfig],
-        parser.parse_args_into_dataclasses(),
-    )
+    args = cast(Args, HfArgumentParser(Args).parse_args_into_dataclasses()[0])
     raw_problem_fn, map_problem_fn = (
         (get_humaneval_raw_problems, map_humaneval_problem)
         if args.dataset == "humaneval"
@@ -141,10 +130,10 @@ def main():
     engine = LLM(args.model_name_or_path or args.model_key)
     sampling_params = SamplingParams(
         n=args.n_samples_per_problem,
-        temperature=generation_config.temperature,
-        max_tokens=generation_config.max_new_tokens,
+        temperature=args.temperature,
+        max_tokens=args.max_new_tokens,
         top_k=-1,
-        top_p=generation_config.top_p,
+        top_p=args.top_p,
         stop="\n```\n",
     )
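For readers unfamiliar with `HfArgumentParser`, the new one-line parse replaces the old two-dataclass setup (`Args` plus a shared `GenerationConfig`): generation options now live directly on `Args`. The sketch below shows the pattern in isolation; the fields are trimmed down from the real `Args`, and an explicit argument list is passed only so the example runs without a CLI.

```python
# Sketch of parsing a single dataclass with transformers.HfArgumentParser,
# mirroring the pattern adopted in text2code_vllm.py.
from dataclasses import dataclass, field
from typing import cast

from transformers import HfArgumentParser


@dataclass
class Args:
    model_key: str
    save_path: str
    max_new_tokens: int = field(default=1024)
    temperature: float = field(default=0.0)


# The real script reads sys.argv; here we pass argv explicitly for a self-contained run.
args = cast(
    Args,
    HfArgumentParser(Args).parse_args_into_dataclasses(
        args=["--model_key", "bigcode/starcoder2-15b-instruct-v0.1", "--save_path", "out.jsonl"]
    )[0],
)
print(args.model_key, args.temperature)
```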
