diff --git a/input.py b/input.py
new file mode 100644
index 0000000..d644fee
--- /dev/null
+++ b/input.py
@@ -0,0 +1,5 @@
+def fact(n):
+    if n == 0:
+        return 1
+    else:
+        return n * fact(n-1)
diff --git a/interactive_predict.py b/interactive_predict.py
index 039c037..a245909 100644
--- a/interactive_predict.py
+++ b/interactive_predict.py
@@ -1,10 +1,10 @@
 from common import Common
-from extractor import Extractor
+from python_extractor.extractor import Extractor
 
 SHOW_TOP_CONTEXTS = 10
 MAX_PATH_LENGTH = 8
 MAX_PATH_WIDTH = 2
-EXTRACTION_API = 'https://po3g2dx2qa.execute-api.us-east-1.amazonaws.com/production/extractmethods'
+# EXTRACTION_API = 'https://po3g2dx2qa.execute-api.us-east-1.amazonaws.com/production/extractmethods'
 
 
 class InteractivePredictor:
@@ -14,7 +14,8 @@ def __init__(self, config, model):
         model.predict([])
         self.model = model
         self.config = config
-        self.path_extractor = Extractor(config, EXTRACTION_API, self.config.MAX_PATH_LENGTH, max_path_width=2)
+        # self.path_extractor = Extractor(config, EXTRACTION_API, self.config.MAX_PATH_LENGTH, max_path_width=2)
+        self.path_extractor = Extractor(self.config.MAX_PATH_LENGTH, MAX_PATH_WIDTH)
 
     @staticmethod
     def read_file(input_filename):
@@ -22,7 +23,7 @@ def read_file(input_filename):
         return file.readlines()
 
     def predict(self):
-        input_filename = 'Input.java'
+        input_filename = 'input.py'
         print('Serving')
         while True:
             print('Modify the file: "' + input_filename + '" and press any key when ready, or "q" / "exit" to exit')
@@ -30,11 +31,16 @@ def predict(self):
             if user_input.lower() in self.exit_keywords:
                 print('Exiting...')
                 return
-            user_input = ' '.join(self.read_file(input_filename))
             try:
-                predict_lines, pc_info_dict = self.path_extractor.extract_paths(user_input)
-            except ValueError:
+                predict_lines = list(path.strip() for path in self.path_extractor.extract_paths(input_filename))
+                contexts = predict_lines[0].split()
+                # space_padding = ' ' * (self.config.MAX_CONTEXTS - len(contexts) + 1)
+                space_padding = ' ' * (200 - len(contexts) + 1)
+                predict_lines[0] = ' '.join(contexts) + space_padding
+            except ValueError as e:
+                print(e)
                 continue
+            pc_info_dict = UnitDict()
             model_results = self.model.predict(predict_lines)
 
             prediction_results = Common.parse_results(model_results, pc_info_dict, topk=SHOW_TOP_CONTEXTS)
@@ -53,3 +59,8 @@ def predict(self):
             print('Predicted:')
             for predicted_seq in method_prediction.predictions:
                 print('\t%s' % predicted_seq.prediction)
+
+class UnitDict(dict):
+
+    def __getitem__(self, key):
+        return key
diff --git a/preprocess_python.sh b/preprocess_python.sh
new file mode 100755
index 0000000..4dfeb89
--- /dev/null
+++ b/preprocess_python.sh
@@ -0,0 +1,62 @@
+#!/usr/bin/env bash
+###########################################################
+# Change the following values to preprocess a new dataset.
+# TRAIN_DIR, VAL_DIR and TEST_DIR should be paths to
+# directories containing sub-directories with .py files
+# DATASET_NAME is just a name for the currently extracted
+# dataset.
+# MAX_DATA_CONTEXTS is the number of contexts to keep in the dataset for each
+# method (by default 1000). At training time, these contexts
+# will be downsampled dynamically to MAX_CONTEXTS.
+# MAX_CONTEXTS - the number of actual contexts (by default 200)
+# that are taken into consideration (out of MAX_DATA_CONTEXTS)
+# every training iteration. To avoid randomness at test time,
+# for the test and validation sets only MAX_CONTEXTS contexts are kept
+# (while for training, MAX_DATA_CONTEXTS are kept and MAX_CONTEXTS are
+# selected dynamically during training).
+# SUBTOKEN_VOCAB_SIZE, TARGET_VOCAB_SIZE -
+# - the number of subtokens and target words to keep
+# in the vocabulary (the top occurring words and paths will be kept).
+# NUM_THREADS - the number of parallel threads to use. It is
+# recommended to use a multi-core machine for the preprocessing
+# step and set this value to the number of cores.
+# PYTHON - python3 interpreter alias.
+DATASET_NAME=python_20k
+MAX_DATA_CONTEXTS=1000
+MAX_CONTEXTS=200
+SUBTOKEN_VOCAB_SIZE=186277
+TARGET_VOCAB_SIZE=26347
+NUM_THREADS=64
+PYTHON=python
+###########################################################
+REPO_DIR=repos
+CONTEXTS_DIR=output
+DATA_DIR=data
+
+TRAIN_DATA_FILE=${CONTEXTS_DIR}/train/path_contexts.csv
+VAL_DATA_FILE=${CONTEXTS_DIR}/val/path_contexts.csv
+TEST_DATA_FILE=${CONTEXTS_DIR}/test/path_contexts.csv
+
+mkdir -p ${DATA_DIR}/${DATASET_NAME}
+
+echo "Extracting paths..."
+${PYTHON} python_extractor/extract.py --in_dir ${REPO_DIR} --out_dir ${CONTEXTS_DIR} --max_path_length 8 --max_path_width 2 --max_workers ${NUM_THREADS}
+
+TARGET_HISTOGRAM_FILE=data/${DATASET_NAME}/${DATASET_NAME}.histo.tgt.c2s
+SOURCE_SUBTOKEN_HISTOGRAM=data/${DATASET_NAME}/${DATASET_NAME}.histo.ori.c2s
+NODE_HISTOGRAM_FILE=data/${DATASET_NAME}/${DATASET_NAME}.histo.node.c2s
+
+echo "Creating histograms from the training data"
+cat ${TRAIN_DATA_FILE} | cut -d' ' -f1 | tr '|' '\n' | awk '{n[$0]++} END {for (i in n) print i,n[i]}' > ${TARGET_HISTOGRAM_FILE}
+cat ${TRAIN_DATA_FILE} | cut -d' ' -f2- | tr ' ' '\n' | cut -d',' -f1,3 | tr ',|' '\n' | awk '{n[$0]++} END {for (i in n) print i,n[i]}' > ${SOURCE_SUBTOKEN_HISTOGRAM}
+cat ${TRAIN_DATA_FILE} | cut -d' ' -f2- | tr ' ' '\n' | cut -d',' -f2 | tr '|' '\n' | awk '{n[$0]++} END {for (i in n) print i,n[i]}' > ${NODE_HISTOGRAM_FILE}
+
+${PYTHON} preprocess.py --train_data ${TRAIN_DATA_FILE} --test_data ${TEST_DATA_FILE} --val_data ${VAL_DATA_FILE} \
+  --max_contexts ${MAX_CONTEXTS} --max_data_contexts ${MAX_DATA_CONTEXTS} --subtoken_vocab_size ${SUBTOKEN_VOCAB_SIZE} \
+  --target_vocab_size ${TARGET_VOCAB_SIZE} --subtoken_histogram ${SOURCE_SUBTOKEN_HISTOGRAM} \
+  --node_histogram ${NODE_HISTOGRAM_FILE} --target_histogram ${TARGET_HISTOGRAM_FILE} --output_name data/${DATASET_NAME}/${DATASET_NAME}
+
+# If all went well, the raw data files can be deleted, because preprocess.py creates new files
+# with truncated and padded number of paths for each example.
+rm ${TARGET_HISTOGRAM_FILE} ${SOURCE_SUBTOKEN_HISTOGRAM} ${NODE_HISTOGRAM_FILE}
+
diff --git a/python_extractor/.gitignore b/python_extractor/.gitignore
new file mode 100644
index 0000000..4a87948
--- /dev/null
+++ b/python_extractor/.gitignore
@@ -0,0 +1,5 @@
+.*
+__pycache__/
+main.json
+*.c2v
+!.git*
\ No newline at end of file
diff --git a/python_extractor/__init__.py b/python_extractor/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/python_extractor/extract.py b/python_extractor/extract.py
new file mode 100644
index 0000000..365a760
--- /dev/null
+++ b/python_extractor/extract.py
@@ -0,0 +1,123 @@
+from __future__ import annotations
+
+
+from argparse import ArgumentParser
+from concurrent.futures import ThreadPoolExecutor
+from pebble import ProcessPool
+import concurrent.futures as cf
+from functools import partial
+from pathlib import PosixPath
+import os
+import math
+import itertools as it
+from glob import iglob, glob
+import subprocess as sp
+from tqdm import tqdm
+from typing import Iterator, List, Iterable, Tuple, Union
+
+import re
+
+from extractor import Extractor
+
+
+Path = Union[str, PosixPath]
+
+
+def process(fname: str, max_length: int, max_width: int) -> List[str]:
+    extractor = Extractor(max_path_length=max_length, max_path_width=max_width)
+    try:
+        paths = extractor.extract_paths(fname)
+    except (ValueError, SyntaxError, RecursionError):
+        return list()
+    return list(paths)
+
+
+def write_lines(fname: str, lines: Iterable[str]) -> None:
+    with open(fname, "a", encoding="ISO-8859-1") as stream:
+        stream.writelines(map(mask_method_name, lines))
+
+
+def mask_method_name(line: str) -> str:
+    method_name, _, _ = line.partition(" ")
+    pattern = re.compile(re.escape(f" {method_name},"))
+    return pattern.sub(" METHOD_NAME,", line)
+
+
+def to_str_path(list_path: List[str]) -> str:
+    return f"{list_path[0]},{'|'.join(list_path[1:-1])},{list_path[-1]}"
+
+
+def make_posix_path(path: Path) -> PosixPath:
+    return PosixPath(path) if isinstance(path, str) else path
+
+
+def concatenate_path_context_files(mined_dir_path: Path) -> None:
+    mined_dir_path = make_posix_path(mined_dir_path)
+    dtq = tqdm(["train", "test", "val"], desc="concatenating ast path context files")
+    for _dir in dtq:
+        file_dir = str(mined_dir_path / f"{_dir}")
+        concat_sh = f"cat {file_dir}/*.c2v > {file_dir}/path_contexts.csv"
+        sp.run(concat_sh, shell=True, check=True)
+
+    for f in iglob(str(mined_dir_path / "*/*.c2v")):
+        os.remove(f)
+
+    print("Done concatenating all path_contexts from AST miner to a single file")
+
+
+def source_files(data_dir: str):
+    for fname in iglob(f"{data_dir}/*/**/[!setup]*.py", recursive=True):
+        if os.path.isfile(fname) and not fname.startswith("test"):
+            yield fname
+
+
+def chunker(iterable, n, fillvalue=None):
+    "Collect data into fixed-length chunks or blocks"
+    # chunker('ABCDEFG', 3, 'x') --> ABC DEF Gxx
+    args = [iter(iterable)] * n
+    return it.zip_longest(*args, fillvalue=fillvalue)
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser()
+    parser.add_argument("-maxlen", "--max_path_length", dest="max_path_length", type=int, required=False, default=8)
+    parser.add_argument("-maxwidth", "--max_path_width", dest="max_path_width", type=int, required=False, default=2)
+    parser.add_argument("-workers", "--max_workers", dest="max_workers", required=False, default=None)
+    parser.add_argument("-in_dir", "--in_dir", dest="in_dir", required=True)
+    parser.add_argument("-out_dir", "--out_dir", dest="out_dir", required=True)
+    # parser.add_argument("-file", "--file", dest="file", required=False)
+    args = parser.parse_args()
+
+    TIMEOUT = 60 * 10
+    MAX_WORKERS = int(args.max_workers or os.cpu_count())
+    MAX_LENGTH = args.max_path_length
+    MAX_WIDTH = args.max_path_width
+    REPOS = args.in_dir
+    OUTPUT = args.out_dir
+
+    writes = list()
+    futures = list()
+    with ProcessPool(max_workers=MAX_WORKERS) as pool, ThreadPoolExecutor(
+        max_workers=1
+    ) as writer:
+        futures = {
+            pool.schedule(process, args=[fname, MAX_LENGTH, MAX_WIDTH], timeout=TIMEOUT): fname
+            for fname in source_files(REPOS)
+        }
+
+        for future in tqdm(cf.as_completed(futures), total=len(futures)):
+            fname = futures[future]
+            splitted = fname.split("/")
+            project = splitted[2]
+            bin_ = splitted[1]
+            c2v_file = f"{OUTPUT}/{bin_}/{project}.c2v"
+            try:
+                paths = future.result()
+            except cf.TimeoutError:
+                continue
+            if paths:
+                writes.append(writer.submit(partial(write_lines, c2v_file), paths))
+
+        cf.wait(writes)
+
+    concatenate_path_context_files(mined_dir_path=OUTPUT)
diff --git a/python_extractor/extractor.py b/python_extractor/extractor.py
new file mode 100644
index 0000000..e7a83ba
--- /dev/null
+++ b/python_extractor/extractor.py
@@ -0,0 +1,259 @@
+from __future__ import annotations
+
+import ast
+from ast import NodeVisitor, increment_lineno
+from contextlib import contextmanager
+import dataclasses as dc
+import collections
+import itertools as it
+import re
+from typing import (
+    Iterable,
+    List,
+    Iterator,
+    Union,
+    Dict,
+    Optional,
+    TypeVar,
+)
+
+A = TypeVar("A")
+
+
+class Extractor(NodeVisitor):
+    def __init__(self, max_path_length, max_path_width):
+        self._stack: List[JsonTree] = list()
+        self._func_name: str = str()
+        self.tree: List[JsonTree] = list()
+        self._json_tree: List[dict] = list()
+        self.paths: List[List[JsonTree]] = list()
+        self.paths_map: Dict[str, List[List[str]]] = dict()
+        self.MAX_DEPTH = max_path_length
+        self.MAX_WIDTH = max_path_width
+        self._build_json = False
+        self.replace_pattern = re.compile("[0-9_,]")
+
+    def add_path(self, path: List[JsonTree]):
+        """
+        The hyperparameters for filtering out some paths could be applied here.
+
+        """
+
+        def transform(tree: JsonTree) -> str:
+            return tree.node_type if isinstance(tree, JsonNode) else tree.value
+
+        for prev_path in self.paths:
+            merged = self.merge(prev_path, path)
+            if merged:
+                # For some reason we must cast to list here or the last token goes missing
+                self.paths_map[self._func_name].append(list(map(transform, merged)))
+
+        self.paths.append(list(path))
+
+    def merge(
+        self, left_path: List[JsonTree], right_path: List[JsonTree]
+    ) -> Optional[Iterator[JsonTree]]:
+        """
+          V
+         / \
+
+        Once we have vertex, the index of left and right cannot be more than W apart?
+
+        """
+        if self.fails_depth_check(left_path, right_path):
+            return None
+
+        lefts, rights = iter(left_path), iter(right_path)
+        for left, right in zip(lefts, rights):
+            if id(left) == id(right):
+                vertex = left
+            else:
+                break
+
+        if self.fails_width_check(vertex.children, left, right):
+            return None
+
+        merged = it.chain(reversed(list(lefts)), [left, vertex, right], rights)
+        return merged
+
+    def fails_depth_check(self, left_path: Iterable, right_path: Iterable) -> bool:
+        length = len(set(map(id, left_path)) ^ set(map(id, right_path))) + 1
+        return length > self.MAX_DEPTH
+
+    def fails_width_check(self, children: List[A], left: A, right: A) -> bool:
+        left_index = children.index(left)
+        right_index = children.index(right)
+
+        return right_index - left_index > self.MAX_WIDTH
+
+    def _parse(self, fname: str) -> List[dict]:
+        """I only care about function/method definitions for now"""
+
+        with open(fname, "r", encoding="ISO-8859-1") as stream:
+            tree = ast.parse(stream.read())
+
+        for node in ast.walk(tree):
+            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
+                self._func_name = self.clean(node.name)
+                self.paths_map[self._func_name] = []
+                self.visit(node)
+                self.paths = list()
+
+        return self._json_tree
+
+    def clean(self, token: str) -> str:
+        token = self.to_snake_case(token)
+        return self.replace_pattern.sub("|", token).strip("|")
+
+    @staticmethod
+    def to_snake_case(token: str) -> str:
+        splitted = re.sub(
+            "([A-Z][a-z]+)", r" \1", re.sub("([A-Z]+)", r" \1", token)
+        ).split()
+        return "_".join(word.strip('_').lower() for word in splitted)
+
+    def extract_paths(self, fname: str) -> Iterator[str]:
+        # Can perhaps take the transformation function as a parameter for more flexibility
+        def transform(path_contexts: Iterable[Iterable[str]]) -> str:
+            def to_str_path(path_context: Iterable[str]) -> str:
+                context = list(path_context)
+                return f"{self.clean(context[0])},{'|'.join(context[1:-1])},{self.clean(context[-1])}"
+
+            str_paths = map(to_str_path, path_contexts)
+            return f"{' '.join(str_paths)}".encode("unicode_escape").decode(
+                "ISO-8859-1"
+            )
+
+        self._parse(fname)
+
+        paths = (
+            f"{func_name} {transform(contexts)}\n"
+            for func_name, contexts in self.paths_map.items()
+        )
+        self.paths_map = dict()
+        return paths
+
+    def to_json(self, fname: str) -> List[dict]:
+        """Parse the syntax tree and output to json
+
+        :param fname: The filename.
+        :type fname: str
+        :return: A json representation of the syntax tree.
+        :rtype: dict
+        """
+        self._build_json = True
+        json_tree = self._parse(fname)
+        self._json_tree = list()
+        self._build_json = False
+        return json_tree
+
+    def visit(self, node: ast.AST):
+        if isinstance(
+            node, (ast.boolop, ast.cmpop, ast.unaryop, ast.operator, ast.expr_context)
+        ):
+            return
+        if isinstance(node, (ast.BoolOp, ast.BinOp, ast.UnaryOp, ast.AugAssign)):
+            self.visit_Op(node)
+        elif isinstance(node, (ast.Break, ast.Continue)):
+            self.visit_Break_Continue(node)
+        elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
+            self.visit_Function(node)
+        else:
+            super().visit(node)
+
+    def generic_visit(self, node: ast.AST):
+        """Called if no explicit visitor function exists for a node."""
+        json_node = JsonNode(type(node).__name__)
+        with self.push_to_stack(json_node):
+            super().generic_visit(node)
+
+    def visit_Function(self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]):
+        method_name = JsonLeaf("MethodName", node.name)
+        if ast.get_docstring(node):
+            node.body = node.body[1:]
+        json_node = JsonNode(type(node).__name__)
+        with self.push_to_stack(json_node):
+            with self.push_to_stack(method_name):
+                self.add_path(self._stack)
+            super().generic_visit(node)
+        self.tree.append(json_node)
+        if self._build_json:
+            self._json_tree.append(json_node.to_dict())
+
+    def visit_arg(self, node: ast.arg):
+        arg_name = JsonLeaf("ArgName", node.arg)
+        json_node = JsonNode(type(node).__name__)
+        with self.push_to_stack(json_node):
+            with self.push_to_stack(arg_name):
+                self.add_path(self._stack)
+            super().generic_visit(node)
+
+    def visit_Op(self, node: Union[ast.UnaryOp, ast.BoolOp, ast.BinOp, ast.AugAssign]):
+        operator = type(node.op).__name__
+        node_type = "_".join([type(node).__name__, operator])
+        json_node = JsonNode(node_type)
+        with self.push_to_stack(json_node):
+            super().generic_visit(node)
+
+    def visit_Compare(self, node: ast.Compare):
+        json_node = JsonNode("Compare_" + type(node.ops[0]).__name__)
+        with self.push_to_stack(json_node):
+            super().generic_visit(node)
+
+    def visit_Name(self, node: ast.Name):
+        json_node = JsonLeaf(type(node).__name__, node.id)
+        with self.push_to_stack(json_node):
+            self.add_path(self._stack)
+
+    def visit_Constant(self, node: ast.Constant):
+        json_node = JsonLeaf(
+            type(node).__name__,
+            str(node.value).encode("unicode_escape").decode().replace(" ", ""),
+        )
+        with self.push_to_stack(json_node):
+            self.add_path(self._stack)
+
+    def visit_Break_Continue(self, node: Union[ast.Break, ast.Continue]):
+        json_node = JsonLeaf(type(node).__name__, type(node).__name__)
+        with self.push_to_stack(json_node):
+            self.add_path(self._stack)
+
+    def add_to_tree(self, json_tree: JsonTree):
+        current_node = self._stack.pop()
+        if isinstance(current_node, JsonNode):
+            current_node.children.append(json_tree)
+        else:
+            raise RuntimeError("JsonLeaf node left on stack!")
+        self._stack.append(current_node)
+
+    @contextmanager
+    def push_to_stack(self, json_node: JsonTree):
+        if self._stack:
+            self.add_to_tree(json_node)
+        self._stack.append(json_node)
+        try:
+            yield
+        finally:
+            self._stack.pop()
+
+
+class _JsonTree:
+    """"""
+
+    def to_dict(self):
+        return dc.asdict(self)
+
+
+@dc.dataclass(frozen=True)
+class JsonNode(_JsonTree):
+    node_type: str
+    children: List[_JsonTree] = dc.field(default_factory=list)
+
+
+@dc.dataclass(frozen=True)
+class JsonLeaf(_JsonTree):
+    node_type: str
+    value: str
+
+
+JsonTree = Union[JsonNode, JsonLeaf]
diff --git a/train.sh b/train.sh
old mode 100644
new mode 100755
index be20fef..4a36f49
--- a/train.sh
+++ b/train.sh
@@ -1,3 +1,4 @@
+#!/usr/bin/env bash
 ###########################################################
 # Change the following values to train a new model.
 # type: the name of the new model, only affects the saved file name.
@@ -5,13 +6,14 @@
 # test_data: by default, points to the validation set, since this is the set that
 # will be evaluated after each training iteration. If you wish to test
 # on the final (held-out) test set, change 'val' to 'test'.
-type=java-large-model
-dataset_name=java-large
-data_dir=data/java-large
+type=python_20k
+dataset_name=python_20k
+data_dir=data/${dataset_name}
 data=${data_dir}/${dataset_name}
 test_data=${data_dir}/${dataset_name}.val.c2s
 model_dir=models/${type}
+PYTHON=python
 
 mkdir -p ${model_dir}
 set -e
-python3 -u code2seq.py --data ${data} --test ${test_data} --save_prefix ${model_dir}/model
+${PYTHON} -u code2seq.py --data ${data} --test ${test_data} --save_prefix ${model_dir}/model