feat: evoeval · bigcode-project/selfcodealign@1d85858 · GitHub
[go: up one dir, main page]

Skip to content

Commit 1d85858

Browse files
committed
feat: evoeval
1 parent fc38d29 commit 1d85858

File tree

7 files changed

+130
-41
lines changed

7 files changed

+130
-41
lines changed

evaluation/text2code_vllm.py

Lines changed: 59 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,17 @@
44
from typing import Literal, TypedDict, cast
55
from evalplus.data import get_human_eval_plus, get_mbpp_plus, write_jsonl
66

7-
# from evoeval.data import get_evo_eval
7+
from evoeval.data import get_evo_eval
88
from transformers import HfArgumentParser
99

10-
from star_align.prompt_template import SC2_INSTRUCT_PROMPT
11-
from star_align.utils import infer_prompt_template
10+
from star_align.utils import infer_prompt_template, is_base_model
1211

1312
from vllm import LLM, SamplingParams
1413

1514

16-
PROMPT_TEMPLATE = SC2_INSTRUCT_PROMPT
17-
18-
1915
class Text2CodeProblem(TypedDict):
2016
id: str
17+
prompt: str
2118
instruction: str
2219
response_prefix: str
2320

@@ -39,6 +36,14 @@ def get_humaneval_raw_problems() -> list[dict]:
3936
return list(problems.values())
4037

4138

39+
def get_evoeval_raw_problems(dataset: str):
    """Build a zero-argument loader for one EvoEval dataset.

    The returned callable passes ``dataset`` to ``get_evo_eval`` when it is
    invoked (not when this factory is called) and returns the values of the
    resulting mapping as a list.
    """

    def load_problems() -> list[dict]:
        # Only the problem records are needed; the mapping keys are dropped.
        return list(get_evo_eval(dataset).values())

    return load_problems
45+
46+
4247
def map_mbpp_problem(p: dict) -> Text2CodeProblem:
4348
id = p["task_id"]
4449
prompt = p["prompt"]
@@ -52,14 +57,16 @@ def map_mbpp_problem(p: dict) -> Text2CodeProblem:
5257
assertion = prompt[assert_index:].strip()
5358
instruction = f"""{instruction}
5459
55-
Your code should pass the following assertion:
5660
```python
5761
{assertion}
5862
```"""
5963
prefix = ""
6064
response_prefix = f"""{prefix}```python"""
6165
return Text2CodeProblem(
62-
id=str(id), instruction=instruction, response_prefix=response_prefix
66+
id=str(id),
67+
prompt=prompt,
68+
instruction=instruction,
69+
response_prefix=response_prefix,
6370
)
6471

6572

@@ -91,7 +98,10 @@ def map_humaneval_problem(p: dict) -> Text2CodeProblem:
9198
# response_prefix = f"""{prefix}```python
9299
# {prompt}"""
93100
return Text2CodeProblem(
94-
id=id, instruction=instruction, response_prefix=response_prefix
101+
id=id,
102+
prompt=prompt,
103+
instruction=instruction,
104+
response_prefix=response_prefix,
95105
)
96106

97107

@@ -120,44 +130,69 @@ class Args:
120130
def main():
121131
args = cast(Args, HfArgumentParser(Args).parse_args_into_dataclasses()[0])
122132
raw_problem_fn, map_problem_fn = (
123-
(get_humaneval_raw_problems, map_humaneval_problem)
124-
if args.dataset == "humaneval"
125-
else (get_mbpp_raw_problems, map_mbpp_problem)
133+
(get_evoeval_raw_problems(args.dataset), map_humaneval_problem)
134+
if args.dataset.startswith("EvoEval_")
135+
else (
136+
(get_humaneval_raw_problems, map_humaneval_problem)
137+
if args.dataset == "humaneval"
138+
else (get_mbpp_raw_problems, map_mbpp_problem)
139+
)
126140
)
127141
raw_problems = raw_problem_fn()
128142
problems = list(map(map_problem_fn, raw_problems))
129143

130-
engine = LLM(args.model_name_or_path or args.model_key)
144+
engine = LLM(
145+
tokenizer=args.model_key, model=args.model_name_or_path or args.model_key
146+
)
147+
148+
base_model_prompt = is_base_model(args.model_key)
149+
150+
stop: str | list[str] = (
151+
"\n```\n"
152+
if not base_model_prompt
153+
else ["\ndef ", "\nclass ", "\nimport ", "\nfrom ", "\nassert ", "\n# "]
154+
)
131155
sampling_params = SamplingParams(
132156
n=args.n_samples_per_problem,
133157
temperature=args.temperature,
134158
max_tokens=args.max_new_tokens,
135159
top_k=-1,
136160
top_p=args.top_p,
137-
stop="\n```\n",
161+
stop=stop,
138162
)
139163

140-
# state = get_model_context(args.model_key, args.model_name_or_path)
141-
try:
164+
if base_model_prompt:
165+
print("Base model")
166+
else:
142167
prompt_template = infer_prompt_template(
143168
os.getenv("TOKENIZER") or args.model_name_or_path or args.model_key
144169
)
145-
except:
146-
prompt_template = PROMPT_TEMPLATE
147-
# prompt_template = PROMPT_TEMPLATE
148-
print("Using:", prompt_template)
170+
# prompt_template = PROMPT_TEMPLATE
171+
print("Using:", prompt_template)
149172

150173
prompts: list[str] = []
151174
for problem in problems:
152-
prompt = prompt_template.format(
153-
instruction=problem["instruction"], response=problem["response_prefix"]
154-
)
175+
if not base_model_prompt:
176+
prompt = prompt_template.format(
177+
instruction=problem["instruction"], response=problem["response_prefix"]
178+
)
179+
else:
180+
prompt = problem["prompt"]
155181
prompts.append(prompt)
156182

157183
results = engine.generate(prompts, sampling_params)
158184
Path(args.save_path).write_text("")
185+
159186
step = 20
160187
print_or_not = [idx == 0 or idx % step == 0 for idx in range(len(problems))]
188+
189+
def sanitize(output: str) -> str:
    """Reduce a raw generation to the completion text that gets saved.

    When the prompt was not a base-model prompt, keep what follows the last
    "```python" fence up to the next "```". Otherwise, for every stop
    sequence in turn, drop everything from its last occurrence onward.
    """
    if base_model_prompt:
        trimmed = output
        for stop_seq in stop:
            # rsplit(..., 1)[0] leaves the text untouched when stop_seq is absent.
            trimmed = trimmed.rsplit(stop_seq, 1)[0]
        return trimmed
    after_fence = output.split("```python")[-1]
    return after_fence.split("```")[0]
195+
161196
for problem, prompt, result, print_debug in zip(
162197
problems, prompts, results, print_or_not
163198
):
@@ -169,7 +204,7 @@ def main():
169204
samples = [
170205
dict(
171206
task_id=problem["id"],
172-
completion=output.text.split("```python")[-1].split("```")[0],
207+
completion=sanitize(output.text),
173208
)
174209
for output in result.outputs
175210
]

prompts/self-ossinstruct-fewshot.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ Design the tasks so that the relevant concepts emerge naturally as the most appr
1414
### System: S->C
1515
Extract key programming concepts from the provided code snippet. Programming concepts refer to the foundational principles and techniques used in programming, which are crucial for developers to master. List these concepts in a comma-separated format.
1616

17+
### System: S->I
18+
Gain inspiration from the given code snippets and create a series of independent coding tasks that are original, distinct, diverse, and high-quality, fostering logical thinking.
19+
1720
### Example 1
1821
[Code]
1922
value = int(round((value - prev) * 1e5))

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@ tiktoken~=0.6.0
77
accelerate>=0.27.2
88
datasets>=2.17.1
99
evalplus @ git+https://github.com/evalplus/evalplus.git@25e195e024b614f2671ad9ac5b8fdcd9b95a2b24#egg=evalplus
10+
evoeval~=0.1.0

src/star_align/collect_snippets.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,12 @@ class Args:
3434
max_avg_chars_per_line: int = field(default=80)
3535
# max_fragments: int = field(default=3)
3636
chunk_size: int = field(default=1000)
37-
content_chunk_lines: int = field(default=100)
37+
# A small value lets one document be used by multiple seeds
38+
content_chunk_lines: int = field(default=99999999999)
3839

3940
dataset_name: str = field(default="bigcode/starcoderdata")
4041
data_files: list[str] | None = field(default=None)
41-
max_considered_data: int | None = field(default=100000)
42+
max_considered_data: int | None = field(default=200000)
4243

4344
collect_function: bool = field(default=False)
4445
max_nodes_to_traverse: int = field(default=20000)

src/star_align/sanitize_data.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class Args:
2424
data_files: list[str]
2525
output_file: str
2626
shuffle: bool = field(default=True)
27+
remove_strange: bool = field(default=True)
2728
parse_raw_response: bool = field(default=True)
2829
passing_only: bool = field(default=True)
2930
data_augmentation: bool = field(default=False)
@@ -369,6 +370,10 @@ def mk_key(instruction: str) -> str:
369370

370371
def iterate(dataset: Dataset):
371372
for d in tqdm(dataset):
373+
if args.remove_strange:
374+
# NOTE: newly added
375+
if len(d["instruction"].split()) > 200:
376+
continue
372377
key_i, key_r = mk_key(d["instruction"]), mk_key(d["response"])
373378
if key_i in seen_keys or key_r in seen_keys:
374379
continue

src/star_align/self_ossinstruct.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import star_align
1616

17-
InstructMode = Literal["I->R", "S->C", "C->I"]
17+
InstructMode = Literal["I->R", "S->C", "C->I", "S->I"]
1818

1919
LANGUAGE_MAP = {
2020
"cpp": "C++",
@@ -170,6 +170,8 @@ def prefix_template(mode: InstructMode) -> str:
170170
return "### Snippet\n{snippet}\n\n### Concepts\n"
171171
elif mode == "C->I":
172172
return "### Properties\n{property}\n\n### Task\n"
173+
elif mode == "S->I":
174+
return "### Snippet\n{snippet}\n\n### Task\n"
173175
else:
174176
assert False
175177

@@ -199,6 +201,9 @@ def prompt(
199201
# property_prompt += f"\nnum_words: {num_words}"
200202
kwargs = dict(property=property_prompt)
201203
suffix = self.instruction
204+
elif mode == "S->I":
205+
kwargs = dict(snippet=self.snippet)
206+
suffix = self.instruction
202207
else:
203208
assert False
204209
prefix = self.prefix_template(mode).format(**kwargs)
@@ -213,6 +218,7 @@ class Fewshot:
213218
sys_i_r: str
214219
sys_c_i: str
215220
sys_s_c: str
221+
sys_s_i: str
216222

217223
examples: list[Example]
218224

@@ -297,8 +303,8 @@ def get_ossinstruct_fewshots() -> Fewshot:
297303
splits = re.split(r"### Example \d+", content)
298304
system_prompt = splits[0].strip()
299305
# "I->R", "E->S", "I->I", "PI->PI", "S->C"
300-
sys_pattern = r"### System: I->R|### System: C->I|### System: S->C"
301-
_, i_r, c_i, s_c = list(map(str.strip, re.split(sys_pattern, system_prompt)))
306+
sys_pattern = r"### System: I->R|### System: C->I|### System: S->C|### System: S->I"
307+
_, i_r, c_i, s_c, s_i = list(map(str.strip, re.split(sys_pattern, system_prompt)))
302308
if LLAMA3:
303309
i_r = f"{i_r}\n\nFor each '## Example' below, make sure you provide a '### Response' and a '### Tests' section."
304310
# system_prompt = re.split(r"### System: Instruction", system_prompt)[1]
@@ -331,6 +337,7 @@ def get_ossinstruct_fewshots() -> Fewshot:
331337
sys_i_r=i_r,
332338
sys_c_i=c_i,
333339
sys_s_c=s_c,
340+
sys_s_i=s_i,
334341
examples=examples,
335342
)
336343

@@ -343,6 +350,8 @@ def parse_generated_content(content: str, instruct_mode: InstructMode) -> dict |
343350
return dict(concepts=concepts)
344351
elif instruct_mode == "C->I":
345352
return dict(instruction=content.strip())
353+
elif instruct_mode == "S->I":
354+
return dict(instruction=content.strip())
346355
else:
347356
assert False
348357

@@ -352,11 +361,11 @@ def build_kwargs(instruct_mode: InstructMode, example: dict) -> dict[str, str]:
352361
if instruct_mode == "I->R":
353362
kwargs["instruction"] = example["instruction"]
354363
# Hack
355-
category_index = example["prompt"].rindex("category: ") + len("category: ")
356-
category_end = example["prompt"].index("\n", category_index)
357-
category = example["prompt"][category_index:category_end].strip()
358-
kwargs["category"] = category # type: ignore
359-
elif instruct_mode == "S->C":
364+
# category_index = example["prompt"].rindex("category: ") + len("category: ")
365+
# category_end = example["prompt"].index("\n", category_index)
366+
# category = example["prompt"][category_index:category_end].strip()
367+
# kwargs["category"] = category # type: ignore
368+
elif instruct_mode in ["S->C", "S->I"]:
360369
kwargs["snippet"] = example["seed"]
361370
elif instruct_mode == "C->I":
362371
lang = example.get("data_dir", "dummy_key_not_in_example")

src/star_align/utils.py

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -192,17 +192,52 @@ def find_codeblock_indices(
192192
return all_indices
193193

194194

195+
DEFAULT_TEMPLATE = """\
196+
### Instruction
197+
{instruction}
198+
199+
### Response
200+
{response}"""
201+
202+
203+
def is_base_model(tokenizer_name: str) -> bool:
    """Tell whether ``tokenizer_name`` should be prompted as a base model.

    A model counts as a base model when its tokenizer defines no chat
    template, except that any name containing "octocoder" is never treated
    as a base model.

    Args:
        tokenizer_name: Name or path handed to
            ``AutoTokenizer.from_pretrained``.

    Returns:
        ``True`` if the tokenizer has no chat template and the name does not
        contain "octocoder"; ``False`` otherwise.
    """
    # The name check alone decides this case, so answer before paying for the
    # transformers import and the tokenizer load.
    if "octocoder" in tokenizer_name:
        return False

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    return tokenizer.chat_template is None
208+
209+
210+
OCTOCODER_CHAT_TEMPLATE = """\
211+
{%- for message in messages %}
212+
{%- if message['role'] == 'system' %}
213+
{{ raise_exception('System messages are not allowed in this template.') }}
214+
{%- else %}
215+
{%- if message['role'] == 'user' %}
216+
{{'Question: ' + message['content'] + '\n\n'}}
217+
{%- else %}
218+
{{'Answer: ' + message['content'] + '\n\n'}}
219+
{%- endif %}
220+
{%- endif %}
221+
{%- endfor %}
222+
{{'Question: '}}"""
223+
224+
195225
def infer_prompt_template(tokenizer_name: str) -> str:
    """Derive an ``{instruction}`` / ``{response}`` format string for a tokenizer.

    The tokenizer's chat template is rendered with placeholder user and
    assistant messages; names containing "octocoder" first get
    ``OCTOCODER_CHAT_TEMPLATE`` assigned. When no chat template is set,
    ``DEFAULT_TEMPLATE`` is used instead. The result is cut right after the
    last "{response}" so generation can continue from the response prefix.

    Raises:
        ValueError: if the rendered template contains no "{response}".
    """
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained(tokenizer_name)
    if "octocoder" in tokenizer_name:
        tok.chat_template = OCTOCODER_CHAT_TEMPLATE

    if tok.chat_template is None:
        rendered = DEFAULT_TEMPLATE
    else:
        conversation = [
            {"role": "user", "content": "{instruction}"},
            {"role": "assistant", "content": "{response}"},
        ]
        rendered = tok.apply_chat_template(conversation, tokenize=False)

    # Drop whatever the template appends after the assistant message.
    placeholder = "{response}"
    cut = rendered.rindex(placeholder) + len(placeholder)
    return rendered[:cut]

0 commit comments

Comments
 (0)
0