feat(minihash): add option to adapt to empty key values · bigcode-project/selfcodealign@6c1530b

Commit 6c1530b

feat(minihash): add option to adapt to empty key values
1 parent 3920430 commit 6c1530b

File tree

1 file changed (+47 -22)

src/star_align/minhash_dedup.py

Lines changed: 47 additions & 22 deletions
@@ -33,22 +33,23 @@
 
 parser = argparse.ArgumentParser()
 # IO Args
-parser.add_argument('--data_files', type=str, required=True)
-parser.add_argument('--output', type=str, required=True)
-parser.add_argument('--num_proc', type=int, default=os.cpu_count())
+parser.add_argument("--data_files", type=str, required=True)
+parser.add_argument("--output", type=str, required=True)
+parser.add_argument("--num_proc", type=int, default=os.cpu_count())
 # Meta Args
-parser.add_argument('--column', type=str, required=True)
-parser.add_argument('--batch_size', type=int, default=10_000)
+parser.add_argument("--column", type=str, required=True)
+parser.add_argument("--batch_size", type=int, default=10_000)
 # MinHash Args
-parser.add_argument('--ngram', type=int, default=5)
-parser.add_argument('--min_length', type=int, default=5)
-parser.add_argument('--seed', type=int, default=42)
-parser.add_argument('--num_perm', type=int, default=250)
-parser.add_argument('--threshold', type=float, default=0.7)
-parser.add_argument('--b', type=int, default=None)
-parser.add_argument('--r', type=int, default=None)
-parser.add_argument('--hash_func', type=str, default="sha1")
-parser.add_argument('--hash_bits', type=int, default=64)
+parser.add_argument("--ngram", type=int, default=5)
+parser.add_argument("--min_length", type=int, default=5)
+parser.add_argument("--ignore_empty", type=bool, default=False)
+parser.add_argument("--seed", type=int, default=42)
+parser.add_argument("--num_perm", type=int, default=250)
+parser.add_argument("--threshold", type=float, default=0.7)
+parser.add_argument("--b", type=int, default=None)
+parser.add_argument("--r", type=int, default=None)
+parser.add_argument("--hash_func", type=str, default="sha1")
+parser.add_argument("--hash_bits", type=int, default=64)
 args = parser.parse_args()
 
 
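
Note on the hunk above: the new --ignore_empty option uses type=bool, which is a well-known argparse pitfall. bool() applied to a command-line string is truthy for any non-empty value, so --ignore_empty False enables the option just like --ignore_empty True; only omitting the flag (or passing an empty string) yields False. A minimal sketch of the conventional on/off spelling, offered as an observation rather than as part of this commit:

import argparse

parser = argparse.ArgumentParser()
# store_true makes the option a real switch: absent -> False, present -> True
parser.add_argument("--ignore_empty", action="store_true")
args = parser.parse_args(["--ignore_empty"])
assert args.ignore_empty is True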

@@ -147,11 +148,17 @@ def sha1_hash(data: bytes, d: int = 32) -> int:
     Generate a d-bit hash value from the given data.
     """
     if d == 32:
-        return struct.unpack("<I", hashlib.sha1(data, usedforsecurity=False).digest()[:4])[0]
+        return struct.unpack(
+            "<I", hashlib.sha1(data, usedforsecurity=False).digest()[:4]
+        )[0]
     if d == 64:
-        return struct.unpack("<Q", hashlib.sha1(data, usedforsecurity=False).digest()[:8])[0]
+        return struct.unpack(
+            "<Q", hashlib.sha1(data, usedforsecurity=False).digest()[:8]
+        )[0]
     # struct is faster but does not support arbitrary bit lengths
-    return int.from_bytes(hashlib.sha1(data, usedforsecurity=False).digest()[: d // 8], byteorder="little")
+    return int.from_bytes(
+        hashlib.sha1(data, usedforsecurity=False).digest()[: d // 8], byteorder="little"
+    )
 
 
 def xxh3_16hash(data: bytes, seed: int = 0) -> int:
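
A small worked example (not from the commit) to confirm the reformatted branches are equivalent readings of the same digest: for d=32 the function takes the first 4 bytes of the SHA-1 digest and decodes them as a little-endian unsigned int, which matches the int.from_bytes fallback on the same slice.

import hashlib
import struct

digest = hashlib.sha1(b"hello").digest()
via_struct = struct.unpack("<I", digest[:4])[0]
via_from_bytes = int.from_bytes(digest[:4], byteorder="little")
assert via_struct == via_from_bytes  # struct is merely the faster spelling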
@@ -240,10 +247,13 @@ def embed_func(
     # split content on whitespace (NON_ALPHA regex), tokenize with ngrams(), and join these n-grams into a single space separated string.
     # we then convert to lower case and then bytestrings which is then hashed. Only unique hashed n-grams are left.
     tokens: set[bytes] = {
-        bytes(" ".join(t).lower(), "utf-8") for t in ngrams(NON_ALPHA.split(content.lower()), ngram_size, min_length)
+        bytes(" ".join(t).lower(), "utf-8")
+        for t in ngrams(NON_ALPHA.split(content.lower()), ngram_size, min_length)
     }
 
-    hashvalues: np.ndarray = np.array([hash_func(token) for token in tokens], dtype=dtype).reshape(len(tokens), 1)
+    hashvalues: np.ndarray = np.array(
+        [hash_func(token) for token in tokens], dtype=dtype
+    ).reshape(len(tokens), 1)
     # Permute the hash values to produce new universal hashes
     # Element-wise multiplication with 'hashvalues' and a (non 0 random value) and then adding b
     # Then, take modulo 'MODULO_PRIME' and bitwise_and with 'MAX_HASH' to keep only the necessary bits.
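
The permutation described in the trailing comments is the standard universal-hash family h_i(x) = (a_i * x + b_i) mod p, masked down to the signature width. A toy sketch under assumed constants (the prime and mask here are illustrative choices in the datasketch style, not necessarily this module's values):

import numpy as np

MODULO_PRIME = (1 << 61) - 1         # assumed Mersenne prime
MAX_HASH = np.uint64((1 << 32) - 1)  # assumed 32-bit mask

gen = np.random.RandomState(42)
num_perm = 4
a = gen.randint(1, MODULO_PRIME, size=(1, num_perm), dtype=np.uint64)
b = gen.randint(0, MODULO_PRIME, size=(1, num_perm), dtype=np.uint64)

# two token hashes, shaped (len(tokens), 1) as in embed_func
hashvalues = np.array([[123456789], [987654321]], dtype=np.uint64)
# uint64 arithmetic wraps mod 2**64, as in the reference implementations
permuted = (hashvalues * a + b) % MODULO_PRIME & MAX_HASH
signature = permuted.min(axis=0)  # column-wise min over tokens = MinHash signature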
@@ -255,7 +265,9 @@ def embed_func(
     # Originally, byteswap was done for speed. Testing show it has a negligible impact
     # keeping for backward compatibility, even though theoretically and empirically
     # it doesnt matter if it is there or not. github.com/ekzhu/datasketch/issues/114
-    Hs: list[bytes] = [bytes(hashvalues[start:end].byteswap().data) for start, end in hashranges]
+    Hs: list[bytes] = [
+        bytes(hashvalues[start:end].byteswap().data) for start, end in hashranges
+    ]
     return {SIGNATURE_COLUMN: Hs, INDEX_COLUMN: idx}
 
 
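For context on hashranges, which this hunk uses but does not define: the num_perm signature values are split into b bands of r rows, and each band's bytes become one locality-sensitive bucket key. Two documents with Jaccard similarity s then collide in at least one band with probability 1 - (1 - s^r)^b, the S-curve that --threshold, --b, and --r tune. A small illustration (the b=25, r=10 split of num_perm=250 is hypothetical):

def collision_probability(s: float, b: int, r: int) -> float:
    # chance that two docs with Jaccard similarity s share at least one band
    return 1.0 - (1.0 - s**r) ** b

print(round(collision_probability(0.7, b=25, r=10), 2))  # ~0.51 at the default threshold
print(round(collision_probability(0.9, b=25, r=10), 2))  # ~1.0 for near-duplicates
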
@@ -310,13 +322,22 @@ def hash_func(byte_data):
     # Loading
     data_files_list = [x.strip() for x in args.data_files.split(",")]
     ds = datasets.load_dataset("json", data_files=data_files_list, split="train")
-    ds = ds.map(lambda x, i: {INDEX_COLUMN: i}, with_indices=True, num_proc=args.num_proc)
+    ds = ds.map(
+        lambda x, i: {INDEX_COLUMN: i}, with_indices=True, num_proc=args.num_proc
+    )
+
+    if args.ignore_empty:
+        ds_rest = ds.filter(lambda x: len(x[args.column].strip()) == 0)
+        ds = ds.filter(lambda x: len(x[args.column].strip()) > 0)
+
     ds = ds.filter(
         lambda x: len(NON_ALPHA.split(x[args.column].lower())) >= args.min_length,
         num_proc=args.num_proc,
     )
 
     LEN_DATASET = len(ds)
+    if args.ignore_empty:
+        LEN_DATASET += len(ds_rest)
 
     # MinHashing
     embedded = ds.map(
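
This hunk is the heart of the commit: with --ignore_empty set, rows whose key column is empty after strip() are set aside in ds_rest before MinHashing, so they are neither hashed nor deduplicated against one another, and LEN_DATASET is corrected so downstream counts still cover the full input. A self-contained toy illustration of the split (the "text" column name is made up):

import datasets

ds = datasets.Dataset.from_dict({"text": ["a b c d e f", "", "   ", "u v w x y z"]})
ds_rest = ds.filter(lambda x: len(x["text"].strip()) == 0)  # empty rows, kept verbatim
ds = ds.filter(lambda x: len(x["text"].strip()) > 0)        # only these are MinHashed
assert len(ds_rest) == 2 and len(ds) == 2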
@@ -354,7 +375,9 @@ def hash_func(byte_data):
             contiguous=True,
             writer_batch_size=args.batch_size,
         )
-        for key, Hs in zip(embedded_shard[INDEX_COLUMN], embedded_shard[SIGNATURE_COLUMN]):
+        for key, Hs in zip(
+            embedded_shard[INDEX_COLUMN], embedded_shard[SIGNATURE_COLUMN]
+        ):
             for i, H in enumerate(Hs):
                 HASH_TABLES[i][H].add(key)
 
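For readers skimming the rewrapped zip(): judging by the lines shown, HASH_TABLES is one mapping per band from a band's raw byte signature to the set of row indices that produced it, and any bucket holding more than one index marks a group of duplicate candidates. A sketch under that assumption:

from collections import defaultdict

num_bands = 3  # illustrative
HASH_TABLES: list[defaultdict] = [defaultdict(set) for _ in range(num_bands)]

HASH_TABLES[0][b"\x01\x02\x03\x04"].update({0, 7})  # rows 0 and 7 collide in band 0
candidates = [ids for table in HASH_TABLES for ids in table.values() if len(ids) > 1]
assert candidates == [{0, 7}]
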
@@ -387,6 +410,8 @@ def hash_func(byte_data):
         num_proc=args.num_proc,
         desc="Filtering clusters...",
     )
+    if args.ignore_empty and len(ds_rest) > 0:
+        final_data = datasets.concatenate_datasets([ds_rest, final_data])
 
     # Saving
     final_data = final_data.remove_columns([CLUSTER_COLUMN, INDEX_COLUMN])
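
With this final hunk, the set-aside empty rows are concatenated back in front of the deduplicated data, so the output holds every empty row plus the non-empty rows that survive cluster filtering. A plausible invocation of the new option (the column name is an example; per the type=bool caveat above, any non-empty value enables the flag):

python src/star_align/minhash_dedup.py \
    --data_files data.jsonl \
    --output deduped \
    --column response \
    --ignore_empty true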
