17
17
"""
18
18
JavaSitter module
19
19
"""
20
-
21
20
from itertools import groupby
22
21
from typing import List , Set , Dict
23
22
from tree_sitter import Language , Node , Parser , Query , Tree
26
25
27
26
from cldk .models .treesitter import Captures
28
27
28
+ import logging
29
+
30
+ logger = logging .getLogger (__name__ )
31
+
29
32
30
33
class JavaSitter :
31
34
"""
@@ -51,8 +54,7 @@ def method_is_not_in_class(self, method_name: str, class_body: str) -> bool:
51
54
bool
52
55
True if the method is not in the class, False otherwise.
53
56
"""
54
- methods_in_class = self .frame_query_and_capture_output ("(method_declaration name: (identifier) @name)" ,
55
- class_body )
57
+ methods_in_class = self .frame_query_and_capture_output ("(method_declaration name: (identifier) @name)" , class_body )
56
58
57
59
return method_name not in {method .node .text .decode () for method in methods_in_class }
58
60
@@ -103,8 +105,7 @@ def get_all_imports(self, source_code: str) -> Set[str]:
103
105
Returns:
104
106
Set[str]: A set of all the imports in the class.
105
107
"""
106
- import_declerations : Captures = self .frame_query_and_capture_output (
107
- query = "(import_declaration (scoped_identifier) @name)" , code_to_process = source_code )
108
+ import_declerations : Captures = self .frame_query_and_capture_output (query = "(import_declaration (scoped_identifier) @name)" , code_to_process = source_code )
108
109
return {capture .node .text .decode () for capture in import_declerations }
109
110
110
111
def get_pacakge_name (self , source_code : str ) -> str :
@@ -116,8 +117,7 @@ def get_pacakge_name(self, source_code: str) -> str:
116
117
Returns:
117
118
str: The package name.
118
119
"""
119
- package_name : Captures = self .frame_query_and_capture_output (query = "((package_declaration) @name)" ,
120
- code_to_process = source_code )
120
+ package_name : Captures = self .frame_query_and_capture_output (query = "((package_declaration) @name)" , code_to_process = source_code )
121
121
if package_name :
122
122
return package_name [0 ].node .text .decode ().replace ("package " , "" ).replace (";" , "" )
123
123
return None
@@ -143,8 +143,7 @@ def get_superclass(self, source_code: str) -> str:
143
143
Returns:
144
144
str: The superclass of the class, or an empty string if the class has no superclass.
145
145
"""
146
- superclass : Captures = self .frame_query_and_capture_output (
147
- query = "(class_declaration (superclass (type_identifier) @superclass))" , code_to_process = source_code )
146
+ superclass : Captures = self .frame_query_and_capture_output (query = "(class_declaration (superclass (type_identifier) @superclass))" , code_to_process = source_code )
148
147
149
148
if len (superclass ) == 0 :
150
149
return ""
@@ -161,9 +160,7 @@ def get_all_interfaces(self, source_code: str) -> Set[str]:
161
160
Set[str]: A set of all the interfaces implemented by the class.
162
161
"""
163
162
164
- interfaces = self .frame_query_and_capture_output (
165
- "(class_declaration (super_interfaces (type_list (type_identifier) @interface)))" ,
166
- code_to_process = source_code )
163
+ interfaces = self .frame_query_and_capture_output ("(class_declaration (super_interfaces (type_list (type_identifier) @interface)))" , code_to_process = source_code )
167
164
return {interface .node .text .decode () for interface in interfaces }
168
165
169
166
def frame_query_and_capture_output (self , query : str , code_to_process : str ) -> Captures :
@@ -182,8 +179,7 @@ def frame_query_and_capture_output(self, query: str, code_to_process: str) -> Ca
182
179
183
180
def get_method_name_from_declaration (self , method_name_string : str ) -> str :
184
181
"""Get the method name from the method signature."""
185
- captures : Captures = self .frame_query_and_capture_output ("(method_declaration name: (identifier) @method_name)" ,
186
- method_name_string )
182
+ captures : Captures = self .frame_query_and_capture_output ("(method_declaration name: (identifier) @method_name)" , method_name_string )
187
183
188
184
return captures [0 ].node .text .decode ()
189
185
@@ -192,8 +188,12 @@ def get_method_name_from_invocation(self, method_invocation: str) -> str:
192
188
Using the tree-sitter query, extract the method name from the method invocation.
193
189
"""
194
190
195
- captures : Captures = self .frame_query_and_capture_output (
196
- "(method_invocation object: (identifier) @class_name name: (identifier) @method_name)" , method_invocation )
191
+ captures : Captures = self .frame_query_and_capture_output ("(method_invocation name: (identifier) @method_name)" , method_invocation )
192
+ return captures [0 ].node .text .decode ()
193
+
194
+ def get_identifier_from_arbitrary_statement (self , statement : str ) -> str :
195
+ """Get the identifier from an arbitrary statement."""
196
+ captures : Captures = self .frame_query_and_capture_output ("(identifier) @identifier" , statement )
197
197
return captures [0 ].node .text .decode ()
198
198
199
199
def safe_ascend (self , node : Node , ascend_count : int ) -> Node :
@@ -260,7 +260,7 @@ def get_call_targets(self, method_body: str, declared_methods: dict) -> Set[str]
260
260
)
261
261
return call_targets
262
262
263
- def get_calling_lines (self , source_method_code : str , target_method_name : str ) -> List [int ]:
263
+ def get_calling_lines (self , source_method_code : str , target_method_name : str , is_target_method_a_constructor : bool ) -> List [int ]:
264
264
"""
265
265
Returns a list of line numbers in source method where target method is called.
266
266
@@ -272,26 +272,34 @@ def get_calling_lines(self, source_method_code: str, target_method_name: str) ->
272
272
target_method_name : str
273
273
target method name (or method signature)
274
274
275
+ is_target_method_a_constructor : bool
276
+ True if target method is a constructor, False otherwise.
277
+
275
278
Returns:
276
279
--------
277
280
List[int]
278
281
List of line numbers within the source method code block.
279
282
"""
280
- query = "(method_invocation name: (identifier) @method_name)"
283
+ if not source_method_code :
284
+ return []
285
+ query = "(object_creation_expression (type_identifier) @object_name) (object_creation_expression type: (scoped_type_identifier (type_identifier) @type_name)) (method_invocation name: (identifier) @method_name)"
286
+
281
287
# if target_method_name is a method signature, get the method name
282
288
# if it is not a signature, we will just keep the passed method name
289
+
290
+ target_method_name = target_method_name .split ("(" )[0 ] # remove the arguments from the constructor name
283
291
try :
284
- target_method_name = self .get_method_name_from_declaration ( target_method_name )
285
- except Exception :
286
- pass
287
-
288
- captures : Captures = self . frame_query_and_capture_output ( query , source_method_code )
289
- # Find the line numbers where target method calls happen in source method
290
- target_call_lines = []
291
- for c in captures :
292
- method_name = c . node . text . decode ( )
293
- if method_name == target_method_name :
294
- target_call_lines . append ( c . node . start_point [ 0 ])
292
+ captures : Captures = self .frame_query_and_capture_output ( query , source_method_code )
293
+ # Find the line numbers where target method calls happen in source method
294
+ target_call_lines = []
295
+ for c in captures :
296
+ method_name = c . node . text . decode ( )
297
+ if method_name == target_method_name :
298
+ target_call_lines . append ( c . node . start_point [ 0 ])
299
+ except :
300
+ logger . warning ( f"Unable to get calling lines for { target_method_name } in { source_method_code } ." )
301
+ return []
302
+
295
303
return target_call_lines
296
304
297
305
def get_test_methods (self , source_class_code : str ) -> Dict [str , str ]:
@@ -398,8 +406,7 @@ def get_method_return_type(self, source_code: str) -> str:
398
406
The return type of the method.
399
407
"""
400
408
401
- type_references : Captures = self .frame_query_and_capture_output (
402
- "(method_declaration type: ((type_identifier) @type_id))" , source_code )
409
+ type_references : Captures = self .frame_query_and_capture_output ("(method_declaration type: ((type_identifier) @type_id))" , source_code )
403
410
404
411
return type_references [0 ].node .text .decode ()
405
412
@@ -426,9 +433,9 @@ def collect_leaf_token_values(node):
426
433
if len (node .children ) == 0 :
427
434
if filter_by_node_type is not None :
428
435
if node .type in filter_by_node_type :
429
- lexical_tokens .append (code[node .start_byte : node .end_byte ])
436
+ lexical_tokens .append (code [node .start_byte : node .end_byte ])
430
437
else :
431
- lexical_tokens .append (code [node .start_byte : node .end_byte ])
438
+ lexical_tokens .append (code [node .start_byte : node .end_byte ])
432
439
else :
433
440
for child in node .children :
434
441
collect_leaf_token_values (child )
@@ -462,11 +469,9 @@ def remove_all_comments(self, source_code: str) -> str:
462
469
pruned_source_code = self .make_pruned_code_prettier (source_code )
463
470
464
471
# Remove all comment lines: the comment lines start with / (for // and /*) or * (for multiline comments).
465
- comment_blocks : Captures = self .frame_query_and_capture_output (query = "((block_comment) @comment_block)" ,
466
- code_to_process = source_code )
472
+ comment_blocks : Captures = self .frame_query_and_capture_output (query = "((block_comment) @comment_block)" , code_to_process = source_code )
467
473
468
- comment_lines : Captures = self .frame_query_and_capture_output (query = "((line_comment) @comment_line)" ,
469
- code_to_process = source_code )
474
+ comment_lines : Captures = self .frame_query_and_capture_output (query = "((line_comment) @comment_line)" , code_to_process = source_code )
470
475
471
476
for capture in comment_blocks :
472
477
pruned_source_code = pruned_source_code .replace (capture .node .text .decode (), "" )
@@ -490,8 +495,7 @@ def make_pruned_code_prettier(self, pruned_code: str) -> str:
490
495
The prettified pruned code.
491
496
"""
492
497
# First remove remaining block comments
493
- block_comments : Captures = self .frame_query_and_capture_output (query = "((block_comment) @comment_block)" ,
494
- code_to_process = pruned_code )
498
+ block_comments : Captures = self .frame_query_and_capture_output (query = "((block_comment) @comment_block)" , code_to_process = pruned_code )
495
499
496
500
for capture in block_comments :
497
501
pruned_code = pruned_code .replace (capture .node .text .decode (), "" )
0 commit comments