@@ -59,7 +59,7 @@ class TreeSitterWheelNotInstalled(Exception):
5959
6060def get_parser (location ):
6161 """
62- Get the appropriate tree-sitter parser and string identifier for
62+ Get the appropriate tree-sitter parser and grammar config for
6363 file at location.
6464 """
6565 file_type = Type (location )
@@ -108,27 +108,34 @@ def traverse(node, language_info, mutation_index):
108108 traverse (child , language_info , mutation_index )
109109
110110
def apply_mutation(
    text,
    start_point,
    end_point,
    replacement,
    successive_line_count=None,
):
    """
    Return `text` with the span between `start_point` and `end_point`
    replaced by the `replacement` string.

    `start_point` and `end_point` are (row, column) pairs with 0-based rows.
    `successive_line_count` is a list where item N is the offset in `text` of
    the first character of row N (item 0 is 0). When it is not provided, it
    is computed from `text`, so callers using the four-argument form keep
    working.

    When `replacement` is empty and the mutated row is left blank (only
    whitespace or a bare line ending), that row is dropped from the result.
    Line endings of all other rows are preserved as-is.
    """
    start_row, start_col = start_point
    end_row, end_col = end_point

    if successive_line_count is None:
        # Build the running total of line lengths, line endings included.
        successive_line_count = [0]
        for line in text.splitlines(keepends=True):
            successive_line_count.append(successive_line_count[-1] + len(line))

    # Compute 1D mutation position from 2D coordinates.
    # NOTE(review): columns are added to character counts of a `str`; if the
    # caller passes byte-based columns, non-ASCII text would misalign - confirm.
    start_index = successive_line_count[start_row] + start_col
    end_index = successive_line_count[end_row] + end_col

    modified_text = text[:start_index] + replacement + text[end_index:]
    modified_lines = modified_text.splitlines(keepends=True)

    # Remove empty comment lines. The bounds check covers the case where the
    # removed span was all that remained from `start_row` to the end of the
    # text (e.g. a comment on a last line without a trailing newline): there
    # is no row left to inspect then.
    if (
        not replacement
        and start_row < len(modified_lines)
        and not modified_lines[start_row].strip()
    ):
        del modified_lines[start_row]

    return "".join(modified_lines)
129129
130130
131131def get_stem_code (location ):
132+ """
133+ Return the stemmed code for the code file at the specified `location`.
134+
135+ Parse the code using tree-sitter, create a mutation index for tokens that
136+ need to be replaced or removed, and apply these mutations bottom-up to
137+ generate the stemmed code.
138+ """
132139 parser_result = get_parser (location )
133140 if not parser_result :
134141 return
@@ -143,11 +150,17 @@ def get_stem_code(location):
143150 # Apply mutations bottom-up
144151 mutations = dict (sorted (mutations .items (), reverse = True ))
145152 text = source .decode ()
153+ cur_count = 0
154+ lines = text .splitlines (keepends = True )
155+ successive_line_count = [cur_count := cur_count + len (line ) for line in lines ]
156+ successive_line_count .insert (0 , 0 )
157+
146158 for value in mutations .values ():
147159 text = apply_mutation (
148160 text = text ,
149161 end_point = value ["end_point" ],
150162 start_point = value ["start_point" ],
151163 replacement = ("idf" if value ["type" ] == "identifier" else "" ),
164+ successive_line_count = successive_line_count ,
152165 )
153166 return text
0 commit comments