feat(eval): vllm-based ds1000 · bigcode-project/selfcodealign@d7b96df · GitHub

Commit d7b96df

feat(eval): vllm-based ds1000

1 parent cb66d7e commit d7b96df
1 file changed: +264 −0 lines changed

evaluation/ds_1000.py

Lines changed: 264 additions & 0 deletions
@@ -0,0 +1,264 @@
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, Literal, cast
from transformers import AutoTokenizer
from ds1000 import DS1000Dataset, DS1000Problem
from tqdm.auto import tqdm
from transformers import HfArgumentParser

from star_align.llm_wrapper import (
    GenerationConfig,
    ModelContext,
    create_infilling_prompt,
    get_model_context,
)
from star_align.utils import infer_prompt_template

from vllm import LLM, SamplingParams

PROMPT = cast(str, None)
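# PROMPT is filled in by main() from the INFER environment variable via
# infer_prompt_template; Completion mode requires it to be set.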


@dataclass
class Args:
    dataset_path: str
    model_key: str
    model_name_or_path: str
    mode: Literal["Insertion", "Completion"]
    output_dir: str

    temperature: float = field(default=0.2)
    top_p: float = field(default=0.95)
    max_length: int = field(default=1024)
    n_samples_per_batch: int = field(default=5)
    n_batches: int = field(default=8)

    # Retained for the older ModelContext path (now commented out in generate()).
    def to_generation_config(self) -> GenerationConfig:
        return GenerationConfig(
            # Use max_length to control the budget; max_new_tokens is
            # effectively unbounded.
            max_new_tokens=9999999999999,
            top_p=self.top_p,
            temperature=self.temperature,
            max_length=self.max_length,
        )


def postprocess(text: str) -> str:
    # Keep only the text before the first closing code fence.
    return text.split("```")[0]


def create_prompt(args: Args, tokenizer: AutoTokenizer, problem: DS1000Problem) -> str:
    prompt = problem["prompt"]
    if args.mode == "Insertion":
        prompt = preprocess_insertion_prompt(prompt)
        assert prompt.count("[insert]") == 1
        prefix, suffix = prompt.split("[insert]")
        prompt = create_infilling_prompt(
            model_key=args.model_key,
            prefix=prefix,
            suffix=suffix,
            tokenizer=tokenizer,
        )
    else:
        assert args.mode == "Completion"
        instruction, response_prefix = preprocess_completion_prompt(problem["prompt"])
        prompt = PROMPT.format(
            instruction=instruction,
            response=response_prefix,
        )
    return prompt
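

# generate() samples all completions for one problem in a single vLLM call
# (n = n_batches * n_samples_per_batch), writes each postprocessed sample to
# {output_dir}/{model_key}/{lib}/{mode}/q{problem_id}/{index}.py, and touches a
# FINISHED marker so that reruns can skip problems that already completed.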
def generate(
    args: Args,
    # model_context: ModelContext,
    engine: LLM,
    problem: DS1000Problem,
):
    lib: str = problem["lib"]
    model_key = args.model_key.replace("/", "-")
    problem_id: str = f"q{problem.problem_id}"
    path = Path(args.output_dir) / model_key / lib / args.mode / problem_id
    finishing_signal = path / "FINISHED"
    if finishing_signal.exists():
        print("Skipping:", path)
        return
    if not path.exists():
        print("Making directory:", path)
        path.mkdir(parents=True, exist_ok=True)
    # config = args.to_generation_config()
    prompt = create_prompt(args, engine.get_tokenizer(), problem)
    print("========PROMPT=======")
    print(prompt)
    print("========PROMPT=======")

    sampling_params = SamplingParams(
        n=args.n_batches * args.n_samples_per_batch,
        temperature=args.temperature,
        max_tokens=args.max_length,
        top_k=-1,
        top_p=args.top_p,
        stop=["```"],
    )

    # for batch_idx in range(args.n_batches):
    #     print(f"Generating batch {batch_idx} of {args.n_batches}")
    #     response = model_context.complete(
    #         config=config,
    #         prompts=[prompt] * args.n_samples_per_batch,
    #         stop_tokens=["```"] if os.getenv("STOP") is not None else None,
    #     )
    print(f"Generating {args.n_batches * args.n_samples_per_batch} samples")
    results = engine.generate(prompt, sampling_params)
    assert len(results) == 1
    print("=======RESPONSE[-1]=======")
    # postprocess_fn: Callable[[str], str] = (
    #     (lambda x: x) if args.mode == "Insertion" else postprocess
    # )
    postprocess_fn = postprocess
    print(postprocess_fn(results[0].outputs[-1].text))
    # print("=======RESPONSE[-1]=======")
    # print("=======RESPONSE[RAW]=======")
    # print(response.decoded_outputs[-1])
    # print("=======RESPONSE[RAW]=======")
    # exit()
    assert len(results[0].outputs) == args.n_batches * args.n_samples_per_batch
    for idx, output in enumerate(results[0].outputs):
        sample = output.text
        sample = postprocess_fn(sample)
        # global_index = batch_idx * args.n_samples_per_batch + idx
        global_index = idx
        output_file = path / f"{global_index}.py"
        output_file.write_text(sample)
    finishing_signal.touch()
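

# DS-1000 Completion prompts come in three layouts: a single open <code> block
# ending in "### BEGIN SOLUTION", two <code> blocks with a plain "BEGIN SOLUTION"
# marker, and function-style problems ending in "# SOLUTION START". The helper
# below normalizes all three into an instruction plus a response prefix.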
def preprocess_completion_prompt(prompt: str) -> tuple[str, str]:
    """Preprocess the DS-1000 prompt (Completion mode) into an instruction and a response prefix."""
    # hit = False
    if "SOLUTION START" not in prompt:
        answer_index = prompt.rindex("A:")
        answer = prompt[answer_index + 2 :].strip()
        instruction: str = prompt[:answer_index].strip()
        if instruction.startswith("Problem:"):
            instruction = instruction[len("Problem:") :].strip()
        if "### BEGIN SOLUTION" in prompt:
            assert prompt.count("<code>") == 1
            assert prompt.count("</code>") == 0
            lines = answer.splitlines(keepends=True)
            return_line, result_line, begin_line = lines[-3:]
            assert return_line.strip().startswith("# return")
            assert result_line.strip().startswith("# ")
            assert begin_line.strip() == "### BEGIN SOLUTION"
            response = "".join(lines[:-3]).strip()
            hint = begin_line.replace("###", "#").replace("BEGIN SOLUTION", "Solution")
            response += f"\n{hint}\n"
        else:
            assert "BEGIN SOLUTION" in prompt
            assert prompt.count("<code>") == 2
            assert prompt.count("</code>") == 1
            first_block_start = prompt.index("<code>")
            first_block_end = prompt.index("</code>")
            second_block_start = prompt.index("<code>", first_block_start + 1)
            assert first_block_end < second_block_start
            lines = answer.splitlines(keepends=True)
            block_end, instruction_line, begin_line, block_start = lines[-4:]
            assert begin_line.strip() == "BEGIN SOLUTION"
            assert block_start.strip() == "<code>"
            if block_end.strip() != "</code>":
                if lines[-6].strip() == "</code>":
                    response_prefix = lines[:-6]
                    starting_lines = lines[-5:-2]
                else:
                    assert instruction_line.strip() == "</code>"
                    response_prefix = lines[:-3]
                    # Empty slice: no starting lines to carry over in this case.
                    starting_lines = lines[-2:-2]
            else:
                response_prefix = lines[:-4]
                starting_lines = lines[-3:-2]
            starting_lines = [f"# {line.lstrip()}" for line in starting_lines]
            response = "".join([*response_prefix, *starting_lines]).strip()
            response += "\n# Solution\n"
    else:
        # hit = True
        assert prompt.count("<code>") == 0
        assert prompt.count("</code>") == 0
        assert prompt.strip().endswith("# SOLUTION START")
        code_prefix = prompt[: prompt.rindex("# SOLUTION START")].strip()
        instruction = f"""Write a solution to the following problem:
```python
{code_prefix}
```"""
        response = f"```python\n{code_prefix}\n# Solution\n"
    instruction = instruction.replace("<code>", "```python").replace("</code>", "```")
    response = response.replace("<code>", "```python").replace("</code>", "```")
    # if hit:
    #     print("[Instruction]")
    #     print(instruction)
    #     print("[Response]")
    #     print(response)
    #     breakpoint()
    return instruction, response
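

# For Insertion mode, the DS-1000 scaffold around the solution slot is replaced
# with a single [insert] marker and the <code> tags are rewritten as markdown
# fences, yielding a prefix/suffix pair suitable for infilling.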
def preprocess_insertion_prompt(prompt: str) -> str:
    pattern = """</code>
BEGIN SOLUTION
<code>
[insert]
</code>
END SOLUTION"""
    pattern_index = prompt.index(pattern)
    # pattern_block = prompt[pattern_index:]
    prefix = prompt[:pattern_index]
    # hit = False
    if pattern + "\n<code>" in prompt:
        index = prompt.index("<code>", pattern_index + len(pattern))
        suffix = prompt[index + len("<code>") :]
    else:
        # hit = True
        assert pattern in prompt
        suffix = ""
    final_prompt = prefix.strip() + "\n[insert]\n" + suffix.strip()
    final_prompt = final_prompt.replace("<code>", "```python").replace("</code>", "```")
    # if hit:
    #     print(final_prompt)
    #     breakpoint()
    return final_prompt


def main():
    args = cast(Args, HfArgumentParser(Args).parse_args_into_dataclasses()[0])
    dataset = DS1000Dataset(args.dataset_path, mode=args.mode)

    global PROMPT
    if (inferred := os.getenv("INFER")) is not None:
        if inferred == "1":
            PROMPT = infer_prompt_template(args.model_name_or_path)
        else:
            PROMPT = infer_prompt_template(inferred)

    print("Using prompt:")
    print(PROMPT)

    # Matplotlib problems are excluded in Insertion mode.
    all_problems = [
        problem
        for problems in dataset.data.values()
        for problem in problems
        if args.mode == "Completion" or problem["lib"] != "Matplotlib"
    ]
    engine = LLM(
        tokenizer=args.model_key, model=args.model_name_or_path or args.model_key
    )
    # model_context = get_model_context(
    #     model_key=args.model_key,
    #     model_name_or_path=args.model_name_or_path,
    # )
    for problem in tqdm(all_problems):
        # generate(args, model_context, problem)
        generate(args, engine, problem)


if __name__ == "__main__":
    main()
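
A minimal invocation sketch (hypothetical paths and model name; HfArgumentParser
derives the flags from the Args fields, and INFER=1 infers the prompt template
from the model):

    INFER=1 python evaluation/ds_1000.py \
        --dataset_path /path/to/ds1000_data \
        --model_key bigcode/starcoder2-15b \
        --model_name_or_path bigcode/starcoder2-15b \
        --mode Completion \
        --output_dir outputs/ds1000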

0 commit comments