parser = argparse.ArgumentParser()

# IO Args
- parser.add_argument('--data_files', type=str, required=True)
- parser.add_argument('--output', type=str, required=True)
- parser.add_argument('--num_proc', type=int, default=os.cpu_count())
+ parser.add_argument("--data_files", type=str, required=True)
+ parser.add_argument("--output", type=str, required=True)
+ parser.add_argument("--num_proc", type=int, default=os.cpu_count())

# Meta Args
- parser.add_argument('--column', type=str, required=True)
- parser.add_argument('--batch_size', type=int, default=10_000)
+ parser.add_argument("--column", type=str, required=True)
+ parser.add_argument("--batch_size", type=int, default=10_000)

# MinHash Args
- parser.add_argument('--ngram', type=int, default=5)
- parser.add_argument('--min_length', type=int, default=5)
- parser.add_argument('--seed', type=int, default=42)
- parser.add_argument('--num_perm', type=int, default=250)
- parser.add_argument('--threshold', type=float, default=0.7)
- parser.add_argument('--b', type=int, default=None)
- parser.add_argument('--r', type=int, default=None)
- parser.add_argument('--hash_func', type=str, default="sha1")
- parser.add_argument('--hash_bits', type=int, default=64)
+ parser.add_argument("--ngram", type=int, default=5)
+ parser.add_argument("--min_length", type=int, default=5)
+ parser.add_argument("--ignore_empty", type=bool, default=False)
+ parser.add_argument("--seed", type=int, default=42)
+ parser.add_argument("--num_perm", type=int, default=250)
+ parser.add_argument("--threshold", type=float, default=0.7)
+ parser.add_argument("--b", type=int, default=None)
+ parser.add_argument("--r", type=int, default=None)
+ parser.add_argument("--hash_func", type=str, default="sha1")
+ parser.add_argument("--hash_bits", type=int, default=64)

args = parser.parse_args()
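A note on the new flag: argparse's type=bool is a known pitfall, because bool() is applied to the raw command-line string and bool("False") is True, so any non-empty value enables the flag. A minimal sketch of the more conventional spelling (an assumption, not what this commit does):

    import argparse

    parser = argparse.ArgumentParser()
    # store_true makes --ignore_empty a real on/off switch:
    # absent -> False, present -> True, no string-to-bool parsing involved
    parser.add_argument("--ignore_empty", action="store_true")
    assert parser.parse_args(["--ignore_empty"]).ignore_empty is True
    assert parser.parse_args([]).ignore_empty is False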
@@ -147,11 +148,17 @@ def sha1_hash(data: bytes, d: int = 32) -> int:
    Generate a d-bit hash value from the given data.
    """
    if d == 32:
-         return struct.unpack("<I", hashlib.sha1(data, usedforsecurity=False).digest()[:4])[0]
+         return struct.unpack(
+             "<I", hashlib.sha1(data, usedforsecurity=False).digest()[:4]
+         )[0]
    if d == 64:
-         return struct.unpack("<Q", hashlib.sha1(data, usedforsecurity=False).digest()[:8])[0]
+         return struct.unpack(
+             "<Q", hashlib.sha1(data, usedforsecurity=False).digest()[:8]
+         )[0]
    # struct is faster but does not support arbitrary bit lengths
-     return int.from_bytes(hashlib.sha1(data, usedforsecurity=False).digest()[: d // 8], byteorder="little")
+     return int.from_bytes(
+         hashlib.sha1(data, usedforsecurity=False).digest()[: d // 8], byteorder="little"
+     )


def xxh3_16hash(data: bytes, seed: int = 0) -> int:
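The rewrapped returns are behavior-preserving. For the first two branches, struct.unpack on a 4- or 8-byte digest prefix and int.from_bytes over the same slice yield the same little-endian integer, which is why the fallback branch works for arbitrary d. A quick self-contained check (illustrative input only):

    import hashlib
    import struct

    digest = hashlib.sha1(b"example", usedforsecurity=False).digest()
    # two equivalent ways to read a little-endian uint32 from the first 4 bytes
    via_struct = struct.unpack("<I", digest[:4])[0]
    via_int = int.from_bytes(digest[:4], byteorder="little")
    assert via_struct == via_int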
@@ -240,10 +247,13 @@ def embed_func(
    # split content on whitespace (NON_ALPHA regex), tokenize with ngrams(), and join these n-grams into a single space separated string.
    # we then convert to lower case and then bytestrings which is then hashed. Only unique hashed n-grams are left.
    tokens: set[bytes] = {
-         bytes(" ".join(t).lower(), "utf-8") for t in ngrams(NON_ALPHA.split(content.lower()), ngram_size, min_length)
+         bytes(" ".join(t).lower(), "utf-8")
+         for t in ngrams(NON_ALPHA.split(content.lower()), ngram_size, min_length)
    }

-     hashvalues: np.ndarray = np.array([hash_func(token) for token in tokens], dtype=dtype).reshape(len(tokens), 1)
+     hashvalues: np.ndarray = np.array(
+         [hash_func(token) for token in tokens], dtype=dtype
+     ).reshape(len(tokens), 1)
    # Permute the hash values to produce new universal hashes
    # Element-wise multiplication with 'hashvalues' and a (non 0 random value) and then adding b
    # Then, take modulo 'MODULO_PRIME' and bitwise_and with 'MAX_HASH' to keep only the necessary bits.
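The comprehension shingles the lowercased, whitespace-split words into n-grams and keeps only unique byte strings, so each shingle is hashed once. A small worked example of the tokenization step (the NON_ALPHA pattern and the sliding-window behavior of ngrams() are assumed here, not quoted from the diff):

    import re

    NON_ALPHA = re.compile(r"\W+")
    words = NON_ALPHA.split("The quick brown fox".lower())
    # 2-gram sliding window over the words, joined and encoded as above
    tokens = {bytes(" ".join(t), "utf-8") for t in zip(words, words[1:])}
    assert b"quick brown" in tokens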
@@ -255,7 +265,9 @@ def embed_func(
    # Originally, byteswap was done for speed. Testing show it has a negligible impact
    # keeping for backward compatibility, even though theoretically and empirically
    # it doesnt matter if it is there or not. github.com/ekzhu/datasketch/issues/114
-     Hs: list[bytes] = [bytes(hashvalues[start:end].byteswap().data) for start, end in hashranges]
+     Hs: list[bytes] = [
+         bytes(hashvalues[start:end].byteswap().data) for start, end in hashranges
+     ]
    return {SIGNATURE_COLUMN: Hs, INDEX_COLUMN: idx}
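hashranges carves the num_perm signature rows into b contiguous bands of r rows each; the byteswapped bytes of one band become that band's LSH bucket key. A hedged sketch of how such ranges are typically derived (names assumed, matching the --b/--r/--num_perm flags above):

    # num_perm == b * r, e.g. 250 permutations in 25 bands of 10 rows
    b, r = 25, 10
    hashranges = [(i * r, (i + 1) * r) for i in range(b)]
    assert hashranges[0] == (0, 10) and hashranges[-1] == (240, 250)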
@@ -310,13 +322,22 @@ def hash_func(byte_data):
    # Loading
    data_files_list = [x.strip() for x in args.data_files.split(",")]
    ds = datasets.load_dataset("json", data_files=data_files_list, split="train")
-     ds = ds.map(lambda x, i: {INDEX_COLUMN: i}, with_indices=True, num_proc=args.num_proc)
+     ds = ds.map(
+         lambda x, i: {INDEX_COLUMN: i}, with_indices=True, num_proc=args.num_proc
+     )
+
+     if args.ignore_empty:
+         ds_rest = ds.filter(lambda x: len(x[args.column].strip()) == 0)
+         ds = ds.filter(lambda x: len(x[args.column].strip()) > 0)
+
    ds = ds.filter(
        lambda x: len(NON_ALPHA.split(x[args.column].lower())) >= args.min_length,
        num_proc=args.num_proc,
    )

    LEN_DATASET = len(ds)
+     if args.ignore_empty:
+         LEN_DATASET += len(ds_rest)

    # MinHashing
    embedded = ds.map(
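The new --ignore_empty branch parks whitespace-only rows in ds_rest so they never enter MinHashing and can be concatenated back after deduplication. A toy illustration of the split (hypothetical data, column name assumed):

    from datasets import Dataset

    ds = Dataset.from_dict({"text": ["hello world", "   ", "foo"]})
    ds_rest = ds.filter(lambda x: len(x["text"].strip()) == 0)
    ds = ds.filter(lambda x: len(x["text"].strip()) > 0)
    assert len(ds_rest) == 1 and len(ds) == 2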
@@ -354,7 +375,9 @@ def hash_func(byte_data):
            contiguous=True,
            writer_batch_size=args.batch_size,
        )
-         for key, Hs in zip(embedded_shard[INDEX_COLUMN], embedded_shard[SIGNATURE_COLUMN]):
+         for key, Hs in zip(
+             embedded_shard[INDEX_COLUMN], embedded_shard[SIGNATURE_COLUMN]
+         ):
            for i, H in enumerate(Hs):
                HASH_TABLES[i][H].add(key)
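Each per-band hash table maps a band signature to the set of row indices sharing it; any bucket holding two or more ids marks those rows as duplicate candidates. A minimal sketch of the structure this loop fills (shape inferred from the loop, values invented):

    from collections import defaultdict

    HASH_TABLES = [defaultdict(set) for _ in range(25)]  # one table per band
    # two rows with identical band-0 signatures land in the same bucket
    HASH_TABLES[0][b"\x01\x02"].add(7)
    HASH_TABLES[0][b"\x01\x02"].add(42)
    assert HASH_TABLES[0][b"\x01\x02"] == {7, 42}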
@@ -387,6 +410,8 @@ def hash_func(byte_data):
        num_proc=args.num_proc,
        desc="Filtering clusters...",
    )
+     if args.ignore_empty and len(ds_rest) > 0:
+         final_data = datasets.concatenate_datasets([ds_rest, final_data])

    # Saving
    final_data = final_data.remove_columns([CLUSTER_COLUMN, INDEX_COLUMN])
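concatenate_datasets prepends the parked empty rows to the deduplicated split before saving; note that it requires both datasets to share the same features, so this step assumes ds_rest's columns line up with final_data's at this point. Toy example (hypothetical single-column datasets):

    import datasets

    empty = datasets.Dataset.from_dict({"text": ["   "]})
    kept = datasets.Dataset.from_dict({"text": ["kept doc"]})
    merged = datasets.concatenate_datasets([empty, kept])
    assert merged["text"] == ["   ", "kept doc"]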