@@ -762,7 +762,7 @@ def embed(
         """
         assert self._ctx.ctx is not None
         n_embd = self.n_embd()
-        n_ctx = self.n_ctx()
+        n_batch = self.n_batch

         if self.context_params.embedding == False:
             raise RuntimeError(
@@ -782,54 +782,55 @@ def embed(

         # decode and fetch embeddings
         data: List[List[float]] = []
-        def decode_batch(sizes: List[int]):
+        def decode_batch(n_seq: int):
             assert self._ctx.ctx is not None
             llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
             self._ctx.decode(self._batch)
             self._batch.reset()

             # store embeddings
-            for i, s in enumerate(sizes):
-                embedding = llama_cpp.llama_get_embeddings_ith(self._ctx.ctx, i)[
+            for i in range(n_seq):
+                embedding: List[float] = llama_cpp.llama_get_embeddings_ith(self._ctx.ctx, i)[
                     :n_embd
                 ]
-                norm = np.linalg.norm(embedding) if normalize else s
-                embedding: List[float] = [v / float(norm) for v in embedding]
+                if normalize:
+                    norm = float(np.linalg.norm(embedding))
+                    embedding = [v / norm for v in embedding]
                 data.append(embedding)

         # init state
         total_tokens = 0
         t_batch = 0
-        s_sizes: List[int] = []
+        p_batch = 0

         # accumulate batches and encode
         for text in inputs:
             tokens = self.tokenize(text.encode("utf-8"))
             if truncate:
-                tokens = tokens[:n_ctx]
+                tokens = tokens[:n_batch]

             n_tokens = len(tokens)
             total_tokens += n_tokens

             # check for overrun
-            if n_tokens > n_ctx:
+            if n_tokens > n_batch:
                 raise ValueError(
-                    f"Requested tokens ({n_tokens}) exceed context window of {n_ctx}"
+                    f"Requested tokens ({n_tokens}) exceed batch size of {n_batch}"
                 )

             # time to eval batch
-            if t_batch + n_tokens > self._n_ctx:
-                decode_batch(s_sizes)
+            if t_batch + n_tokens > n_batch:
+                decode_batch(p_batch)
                 t_batch = 0
-                s_sizes = []
+                p_batch = 0

             # add to batch
-            self._batch.add_sequence(tokens, len(s_sizes), False)
+            self._batch.add_sequence(tokens, p_batch, False)
             t_batch += n_tokens
-            s_sizes.append(n_tokens)
+            p_batch += 1

         # handle last batch
-        decode_batch(s_sizes)
+        decode_batch(p_batch)

         if self.verbose:
             llama_cpp.llama_print_timings(self._ctx.ctx)
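
For orientation, a minimal usage sketch of the batched path after this change; the model path and `n_batch` value below are illustrative assumptions, not part of the diff. Inputs are packed into one llama.cpp batch until the next input would overflow `n_batch` tokens, the batch is decoded, and each embedding is read back by its sequence id.

```python
# Hedged usage sketch, not part of the commit: model path and n_batch
# are placeholder assumptions.
from llama_cpp import Llama

llm = Llama(
    model_path="./models/example.gguf",  # hypothetical model file
    embedding=True,  # embed() raises RuntimeError when this is False
    n_batch=512,     # with this change, also the per-input token cap
)

# Each input becomes one sequence in the batch; decode_batch() runs whenever
# the next input would push the running token count past n_batch, and once
# more for the final partial batch.
vectors = llm.embed(["first text", "second text"], normalize=True)
print(len(vectors), len(vectors[0]))  # 2 inputs, n_embd floats each
```

The `sizes` list in the old code existed only so unnormalized embeddings could be divided by their token count; since the new code returns raw embeddings when `normalize` is False and uses the L2 norm otherwise, a plain sequence counter (`p_batch`) is all `decode_batch` needs.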