# utils.py
import difflib
import logging
import random
import re
from functools import cache
from itertools import islice
from typing import Iterable, List, Sequence, TypeVar
import nltk
from faker import Faker
from langroid.mytypes import Document
from langroid.parsing.document_parser import DocumentType
from langroid.parsing.parser import Parser, ParsingConfig
from langroid.parsing.repo_loader import RepoLoader
from langroid.parsing.url_loader import URLLoader
from langroid.parsing.urls import get_urls_paths_bytes_indices
Faker.seed(23)
random.seed(43)
logger = logging.getLogger(__name__)
# Ensures the NLTK resource is available
@cache
def download_nltk_resource(resource: str) -> None:
try:
nltk.data.find(resource)
except LookupError:
nltk.download(resource, quiet=True)
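
# Illustrative usage (not part of the original module): ensure an NLTK corpus is
# present before importing it, e.g.
#
#     download_nltk_resource("gutenberg")
#     from nltk.corpus import gutenberg
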
T = TypeVar("T")
def batched(iterable: Iterable[T], n: int) -> Iterable[Sequence[T]]:
"""Batch data into tuples of length n. The last batch may be shorter."""
# batched('ABCDEFG', 3) --> ABC DEF G
if n < 1:
raise ValueError("n must be at least one")
it = iter(iterable)
while batch := tuple(islice(it, n)):
yield batch
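
# Illustrative usage (not part of the original module): `batched` yields tuples of
# up to n items, e.g.
#
#     >>> list(batched("ABCDEFG", 3))
#     [('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)]
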
def generate_random_sentences(k: int) -> str:
# Load the sample text
download_nltk_resource("gutenberg")
from nltk.corpus import gutenberg
text = gutenberg.raw("austen-emma.txt")
# Split the text into sentences
sentences = nltk.tokenize.sent_tokenize(text)
# Generate k random sentences
random_sentences = random.choices(sentences, k=k)
return " ".join(random_sentences)
def generate_random_text(num_sentences: int) -> str:
fake = Faker()
text = ""
for _ in range(num_sentences):
text += fake.sentence() + " "
return text
def closest_string(query: str, string_list: List[str]) -> str:
"""Find the closest match to the query in a list of strings.
This function is case-insensitive and ignores leading and trailing whitespace.
If no match is found, it returns 'No match found'.
Args:
query (str): The string to match.
string_list (List[str]): The list of strings to search.
Returns:
str: The closest match to the query from the list, or 'No match found'
if no match is found.
"""
# Create a dictionary where the keys are the standardized strings and
# the values are the original strings.
str_dict = {s.lower().strip(): s for s in string_list}
# Standardize the query and find the closest match in the list of keys.
closest_match = difflib.get_close_matches(
query.lower().strip(), str_dict.keys(), n=1
)
# Retrieve the original string from the value in the dictionary.
original_closest_match = (
str_dict[closest_match[0]] if closest_match else "No match found"
)
return original_closest_match
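
# Illustrative usage (not part of the original module): matching ignores case and
# surrounding whitespace, e.g.
#
#     >>> closest_string(" apple ", ["Banana", "Apple", "Cherry"])
#     'Apple'
#     >>> closest_string("zzz", ["Banana", "Apple", "Cherry"])
#     'No match found'
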
def split_paragraphs(text: str) -> List[str]:
"""
    Split the input text into paragraphs, treating a blank line (a newline,
    optional spaces/tabs, then another newline) as the delimiter.
Args:
text (str): The input text.
Returns:
list: A list of paragraphs.
"""
# Split based on a newline, followed by spaces/tabs, then another newline.
paras = re.split(r"\n[ \t]*\n", text)
return [para.strip() for para in paras if para.strip()]
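
# Illustrative usage (not part of the original module): paragraphs are separated by
# blank lines, which may contain spaces or tabs, e.g.
#
#     >>> split_paragraphs("first paragraph\n  \nsecond paragraph")
#     ['first paragraph', 'second paragraph']
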
def split_newlines(text: str) -> List[str]:
"""
Split the input text into lines using "\n" as the delimiter.
Args:
text (str): The input text.
Returns:
list: A list of lines.
"""
lines = re.split(r"\n", text)
return [line.strip() for line in lines if line.strip()]
def number_segments(s: str, granularity: int = 1) -> str:
"""
Number the segments in a given text, preserving paragraph structure.
    A segment is a sequence of `granularity` consecutive "sentences", where a
    "sentence" is either a normal sentence or, when there isn't enough punctuation
    to identify sentences reliably, a pseudo-sentence produced by heuristics
    (split on newlines, or failing that, split every 40 words). The goal is simply
    to number segments at a reasonable granularity so that the LLM in the
    RelevanceExtractorAgent can identify the relevant segments.
Args:
s (str): The input text.
granularity (int): The number of sentences in a segment.
If this is -1, then the entire text is treated as a single segment,
and is numbered as <#1#>.
Returns:
str: The text with segments numbered in the style <#1#>, <#2#> etc.
Example:
>>> number_segments("Hello world! How are you? Have a good day.")
'<#1#> Hello world! <#2#> How are you? <#3#> Have a good day.'
"""
if granularity < 0:
return "<#1#> " + s
numbered_text = []
count = 0
paragraphs = split_paragraphs(s)
for paragraph in paragraphs:
sentences = nltk.sent_tokenize(paragraph)
# Some docs are problematic (e.g. resumes) and have no (or too few) periods,
# so we can't split usefully into sentences.
# We try a series of heuristics to split into sentences,
# until the avg num words per sentence is less than 40.
avg_words_per_sentence = sum(
len(nltk.word_tokenize(sentence)) for sentence in sentences
) / len(sentences)
if avg_words_per_sentence > 40:
sentences = split_newlines(paragraph)
avg_words_per_sentence = sum(
len(nltk.word_tokenize(sentence)) for sentence in sentences
) / len(sentences)
if avg_words_per_sentence > 40:
# Still too long, just split on every 40 words
sentences = []
for sentence in nltk.sent_tokenize(paragraph):
words = nltk.word_tokenize(sentence)
for i in range(0, len(words), 40):
                        # if taking a full 40-word chunk here would leave fewer
                        # than 20 words, fold the remainder into this chunk and stop
                        if len(words) - (i + 40) < 20:
sentences.append(" ".join(words[i:]))
break
else:
sentences.append(" ".join(words[i : i + 40]))
for i, sentence in enumerate(sentences):
num = count // granularity + 1
number_prefix = f"<#{num}#>" if count % granularity == 0 else ""
sentence = f"{number_prefix} {sentence}"
count += 1
sentences[i] = sentence
numbered_paragraph = " ".join(sentences)
numbered_text.append(numbered_paragraph)
return " \n\n ".join(numbered_text)
def number_sentences(s: str) -> str:
return number_segments(s, granularity=1)
def parse_number_range_list(specs: str) -> List[int]:
"""
Parse a specs string like "3,5,7-10" into a list of integers.
Args:
specs (str): A string containing segment numbers and/or ranges
(e.g., "3,5,7-10").
Returns:
List[int]: List of segment numbers.
Example:
>>> parse_number_range_list("3,5,7-10")
[3, 5, 7, 8, 9, 10]
"""
spec_indices = set() # type: ignore
for part in specs.split(","):
# some weak LLMs may generate <#1#> instead of 1, so extract just the digits
# or the "-"
part = "".join(char for char in part if char.isdigit() or char == "-")
if "-" in part:
start, end = map(int, part.split("-"))
spec_indices.update(range(start, end + 1))
else:
spec_indices.add(int(part))
return sorted(list(spec_indices))
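
# Illustrative usage (not part of the original module): stray segment markers in an
# LLM response are tolerated, since only digits and "-" are kept:
#
#     >>> parse_number_range_list("<#3#>,<#5#>-<#7#>")
#     [3, 5, 6, 7]
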
def strip_k(s: str, k: int = 2) -> str:
"""
    Strip leading and trailing whitespace from the input text, retaining at most
    `k` whitespace characters on each end.
    This is useful for trimming excess leading/trailing whitespace from a text
    while preserving some of its surrounding structure.
    Args:
        s (str): The input text.
        k (int): The maximum number of leading and trailing whitespace characters
            to retain.
    Returns:
        str: The text with leading and trailing whitespace trimmed to at most
            `k` characters on each end.
"""
# Count leading and trailing whitespaces
leading_count = len(s) - len(s.lstrip())
trailing_count = len(s) - len(s.rstrip())
# Determine how many whitespaces to retain
leading_keep = min(leading_count, k)
trailing_keep = min(trailing_count, k)
# Use slicing to get the desired output
return s[leading_count - leading_keep : len(s) - (trailing_count - trailing_keep)]
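
# Illustrative usage (not part of the original module): at most k whitespace
# characters are kept on each end, e.g.
#
#     >>> strip_k("    hello   ", k=2)
#     '  hello  '
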
def clean_whitespace(text: str) -> str:
"""Remove extra whitespace from the input text, while preserving
paragraph structure.
"""
paragraphs = split_paragraphs(text)
cleaned_paragraphs = [" ".join(p.split()) for p in paragraphs if p]
return "\n\n".join(cleaned_paragraphs) # Join the cleaned paragraphs.
def extract_numbered_segments(s: str, specs: str) -> str:
"""
Extract specified segments from a numbered text, preserving paragraph structure.
Args:
s (str): The input text containing numbered segments.
specs (str): A string containing segment numbers and/or ranges
(e.g., "3,5,7-10").
Returns:
str: Extracted segments, keeping original paragraph structures.
Example:
>>> text = "(1) Hello world! (2) How are you? (3) Have a good day."
>>> extract_numbered_segments(text, "1,3")
'Hello world! Have a good day.'
"""
# Use the helper function to get the list of indices from specs
if specs.strip() == "":
return ""
spec_indices = parse_number_range_list(specs)
# Regular expression to identify numbered segments like
# <#1#> Hello world! This is me. <#2#> How are you? <#3#> Have a good day.
# Note we match any character between segment markers, including newlines.
segment_pattern = re.compile(r"<#(\d+)#>([\s\S]*?)(?=<#\d+#>|$)")
# Split the text into paragraphs while preserving their boundaries
paragraphs = split_paragraphs(s)
extracted_paragraphs = []
for paragraph in paragraphs:
segments_with_numbers = segment_pattern.findall(paragraph)
# Extract the desired segments from this paragraph
extracted_segments = [
segment
for num, segment in segments_with_numbers
if int(num) in spec_indices
]
# If we extracted any segments from this paragraph,
# join them and append to results
if extracted_segments:
extracted_paragraphs.append(" ".join(extracted_segments))
return "\n\n".join(extracted_paragraphs)
def extract_content_from_path(
path: bytes | str | List[bytes | str],
parsing: ParsingConfig,
doc_type: str | DocumentType | None = None,
) -> str | List[str]:
"""
Extract the content from a file path or URL, or a list of file paths or URLs.
Args:
        path (bytes | str | List[bytes | str]): The file path or URL, or a list
            of file paths or URLs, or bytes content. The bytes option supports
            cases where upstream code has already loaded the content (e.g., from
            a database or API) and we want to avoid copying it to a temporary
            file.
parsing (ParsingConfig): The parsing configuration.
doc_type (str | DocumentType | None): The document type if known.
If multiple paths are given, this MUST apply to ALL docs.
Returns:
        str | List[str]: The extracted content if a single file path or URL is
            provided, or a list of extracted contents if a list of file paths or
            URLs is provided.
"""
if isinstance(path, str) or isinstance(path, bytes):
paths = [path]
elif isinstance(path, list) and len(path) == 0:
return ""
else:
paths = path
url_idxs, path_idxs, byte_idxs = get_urls_paths_bytes_indices(paths)
urls = [paths[i] for i in url_idxs]
path_list = [paths[i] for i in path_idxs]
byte_list = [paths[i] for i in byte_idxs]
path_list.extend(byte_list)
parser = Parser(parsing)
docs: List[Document] = []
try:
if len(urls) > 0:
loader = URLLoader(urls=urls, parser=parser) # type: ignore
docs = loader.load()
if len(path_list) > 0:
for p in path_list:
path_docs = RepoLoader.get_documents(
p, parser=parser, doc_type=doc_type
)
docs.extend(path_docs)
except Exception as e:
logger.warning(f"Error loading path {paths}: {e}")
return ""
if len(docs) == 1:
return docs[0].content
else:
return [d.content for d in docs]
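
# Illustrative usage (a sketch, not part of the original module). The file path and
# URL below are hypothetical, and ParsingConfig() is assumed to be constructible
# with default settings:
#
#     cfg = ParsingConfig()
#     text = extract_content_from_path("notes.txt", cfg)  # -> str
#     texts = extract_content_from_path(
#         ["notes.txt", "https://example.com/page.html"], cfg
#     )  # -> List[str]
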