remove Tail-Free sampling, https://github.com/ggml-org/llama.cpp/pull… · JamePeng/llama-cpp-python@fef4f25 · GitHub
[go: up one dir, main page]

Skip to content

Commit fef4f25

Browse files
committed
remove Tail-Free sampling, ggml-org/llama.cpp#10071
more top_n_sigma, xtc_threshold: float = 0.1, xtc_probability: float params
1 parent d984742 commit fef4f25

File tree

5 files changed

+93
-57
lines changed
  • llama_cpp
  • 5 files changed

    +93
    -57
    lines changed

    examples/low_level_api/common.py

    Lines changed: 23 additions & 9 deletions
    Original file line numberDiff line numberDiff line change
    @@ -21,8 +21,9 @@ class GptParams:
    2121
    ignore_eos: bool = False
    2222
    logit_bias: dict[int, float] = field(default_factory=dict)
    2323
    top_k: int = 40
    24+
    top_n_sigma: float = -1.00
    2425
    top_p: float = 0.95
    25-
    tfs_z: float = 1.00
    26+
    2627
    typical_p: float = 1.00
    2728
    temp: float = 0.80
    2829
    repeat_penalty: float = 1.10
    @@ -32,7 +33,8 @@ class GptParams:
    3233
    mirostat: int = 0
    3334
    mirostat_tau: float = 5.0
    3435
    mirostat_eta: float = 0.1
    35-
    36+
    xtc_threshold: float = 0.1
    37+
    xtc_probability: float = 0.0
    3638
    model: str = "./models/llama-7B/ggml-model.bin"
    3739
    prompt: str = ""
    3840
    path_session: str = ""
    @@ -147,14 +149,10 @@ def gpt_params_parse(argv=None):
    147149
    "--top_k", type=int, default=40, help="top-k sampling", dest="top_k"
    148150
    )
    149151
    parser.add_argument(
    150-
    "--top_p", type=float, default=0.95, help="top-p sampling", dest="top_p"
    152+
    "--top_n_sigma", type=int, default=40, help="top-n-sigma sampling", dest="top_n_sigma"
    151153
    )
    152154
    parser.add_argument(
    153-
    "--tfs",
    154-
    type=float,
    155-
    default=1.0,
    156-
    help="tail free sampling, parameter z (1.0 = disabled)",
    157-
    dest="tfs_z",
    155+
    "--top_p", type=float, default=0.95, help="top-p sampling", dest="top_p"
    158156
    )
    159157
    parser.add_argument(
    160158
    "--temp", type=float, default=0.80, help="temperature", dest="temp"
    @@ -178,7 +176,7 @@ def gpt_params_parse(argv=None):
    178176
    type=float,
    179177
    default=0.0,
    180178
    help="repeat alpha frequency penalty (0.0 = disabled)",
    181-
    dest="tfs_z",
    179+
    dest="frequency_penalty",
    182180
    )
    183181
    parser.add_argument(
    184182
    "--presence_penalty",
    @@ -209,6 +207,22 @@ def gpt_params_parse(argv=None):
    209207
    dest="mirostat_eta",
    210208
    )
    211209

    210+
    parser.add_argument(
    211+
    "--xtc_threshold",
    212+
    type=float,
    213+
    default=0.1,
    214+
    help=" Sets a minimum probability threshold for tokens to be removed (default: 0.1)",
    215+
    dest="xtc_threshold",
    216+
    )
    217+
    218+
    parser.add_argument(
    219+
    "--xtc_probability",
    220+
    type=float,
    221+
    default=0.0,
    222+
    help="Sets the chance for token removal (checked once on sampler start) (default: 0.0)",
    223+
    dest="xtc_probability",
    224+
    )
    225+
    212226
    parser.add_argument(
    213227
    "-m",
    214228
    "--model",

    examples/low_level_api/low_level_api_chat_cpp.py

    Lines changed: 15 additions & 12 deletions
    Original file line numberDiff line numberDiff line change
    @@ -275,14 +275,17 @@ def __init__(self, params: GptParams) -> None:
    275275
    presence_penalty = {self.params.presence_penalty},\
    276276
    frequency_penalty = {self.params.frequency_penalty},\
    277277
    top_k = {self.params.top_k},\
    278-
    tfs_z = {self.params.tfs_z},\
    278+
    top_n_sigma = {self.params.top_n_sigma},\
    279279
    top_p = {self.params.top_p},\
    280280
    typical_p = {self.params.typical_p},\
    281281
    temp = {self.params.temp},\
    282282
    mirostat = {self.params.mirostat},\
    283283
    mirostat_lr = {self.params.mirostat_eta},\
    284284
    mirostat_ent = {self.params.mirostat_tau},\
    285285
    286+
    xtc_threshold = {self.params.xtc_threshold},\
    287+
    xtc_probability = {self.params.xtc_probability},\
    288+
    286289
    generate: n_ctx = {self.n_ctx},\
    287290
    n_batch = {self.params.n_batch},\
    288291
    n_predict = {self.params.n_predict},\
    @@ -454,7 +457,7 @@ def generate(self):
    454457
    _arr = (llama_cpp.llama_token * last_n_repeat)(
    455458
    *self.last_n_tokens[len(self.last_n_tokens) - last_n_repeat :]
    456459
    )
    457-
    llama_cpp.llama_sample_repetition_penalties(
    460+
    llama_cpp.llama_sampler_init_penalties(
    458461
    ctx=self.ctx,
    459462
    candidates=candidates_p,
    460463
    last_tokens_data=_arr,
    @@ -474,15 +477,15 @@ def generate(self):
    474477

    475478
    if self.params.temp <= 0:
    476479
    # Greedy sampling
    477-
    id = llama_cpp.llama_sample_token_greedy(self.ctx, candidates_p)
    480+
    id = llama_cpp.llama_sampler_init_greedy(self.ctx, candidates_p)
    478481
    else:
    479482
    if self.params.mirostat == 1:
    480483
    mirostat_mu = 2.0 * self.params.mirostat_tau
    481484
    mirostat_m = 100
    482-
    llama_cpp.llama_sample_temperature(
    485+
    llama_cpp.llama_sampler_init_temp(
    483486
    self.ctx, candidates_p, llama_cpp.c_float(self.params.temp)
    484487
    )
    485-
    id = llama_cpp.llama_sample_token_mirostat(
    488+
    id = llama_cpp.llama_sampler_init_mirostat(
    486489
    self.ctx,
    487490
    candidates_p,
    488491
    llama_cpp.c_float(self.params.mirostat_tau),
    @@ -495,7 +498,7 @@ def generate(self):
    495498
    llama_cpp.llama_sample_temperature(
    496499
    self.ctx, candidates_p, llama_cpp.c_float(self.params.temp)
    497500
    )
    498-
    id = llama_cpp.llama_sample_token_mirostat_v2(
    501+
    id = llama_cpp.llama_sampler_init_mirostat_v2(
    499502
    self.ctx,
    500503
    candidates_p,
    501504
    llama_cpp.c_float(self.params.mirostat_tau),
    @@ -504,31 +507,31 @@ def generate(self):
    504507
    )
    505508
    else:
    506509
    # Temperature sampling
    507-
    llama_cpp.llama_sample_top_k(
    510+
    llama_cpp.llama_sampler_init_top_k(
    508511
    self.ctx,
    509512
    candidates_p,
    510513
    top_k,
    511514
    min_keep=llama_cpp.c_size_t(1),
    512515
    )
    513-
    llama_cpp.llama_sample_tail_free(
    516+
    llama_cpp.llama_sampler_init_top_n_sigma(
    514517
    self.ctx,
    515518
    candidates_p,
    516-
    llama_cpp.c_float(self.params.tfs_z),
    519+
    llama_cpp.c_float(self.params.top_n_sigma),
    517520
    min_keep=llama_cpp.c_size_t(1),
    518521
    )
    519-
    llama_cpp.llama_sample_typical(
    522+
    llama_cpp.llama_sampler_init_typical(
    520523
    self.ctx,
    521524
    candidates_p,
    522525
    llama_cpp.c_float(self.params.typical_p),
    523526
    min_keep=llama_cpp.c_size_t(1),
    524527
    )
    525-
    llama_cpp.llama_sample_top_p(
    528+
    llama_cpp.llama_sampler_init_top_p(
    526529
    self.ctx,
    527530
    candidates_p,
    528531
    llama_cpp.c_float(self.params.top_p),
    529532
    min_keep=llama_cpp.c_size_t(1),
    530533
    )
    531-
    llama_cpp.llama_sample_temperature(
    534+
    llama_cpp.llama_sampler_init_temp(
    532535
    self.ctx, candidates_p, llama_cpp.c_float(self.params.temp)
    533536
    )
    534537
    id = llama_cpp.llama_sample_token(self.ctx, candidates_p)

    llama_cpp/_internals.py

    Lines changed: 1 addition & 1 deletion
    Original file line numberDiff line numberDiff line change
    @@ -570,9 +570,9 @@ class LlamaSamplingParams:
    570570
    n_prev: int = 64
    571571
    n_probs: int = 0
    572572
    top_k: int = 40
    573+
    top_n_sigma: float = -1.00
    573574
    top_p: float = 0.95
    574575
    min_p: float = 0.05
    575-
    tfs_z: float = 1.00
    576576
    typical_p: float = 1.00
    577577
    temp: float = 0.80
    578578
    penalty_last_n: int = 64

    llama_cpp/llama.py

    Lines changed: 0 additions & 17 deletions
    Original file line numberDiff line numberDiff line change
    @@ -677,7 +677,6 @@ def _init_sampler(
    677677
    repeat_penalty: float = 1.0,
    678678
    frequency_penalty: float = 0.0,
    679679
    presence_penalty: float = 0.0,
    680-
    tfs_z: float = 1.0,
    681680
    mirostat_mode: int = 0,
    682681
    mirostat_eta: float = 0.1,
    683682
    mirostat_tau: float = 5.0,
    @@ -771,7 +770,6 @@ def sample(
    771770
    repeat_penalty: float = 1.0,
    772771
    frequency_penalty: float = 0.0,
    773772
    presence_penalty: float = 0.0,
    774-
    tfs_z: float = 1.0,
    775773
    mirostat_mode: int = 0,
    776774
    mirostat_eta: float = 0.1,
    777775
    mirostat_tau: float = 5.0,
    @@ -809,7 +807,6 @@ def sample(
    809807
    repeat_penalty=repeat_penalty,
    810808
    frequency_penalty=frequency_penalty,
    811809
    presence_penalty=presence_penalty,
    812-
    tfs_z=tfs_z,
    813810
    mirostat_mode=mirostat_mode,
    814811
    mirostat_tau=mirostat_tau,
    815812
    mirostat_eta=mirostat_eta,
    @@ -841,7 +838,6 @@ def generate(
    841838
    reset: bool = True,
    842839
    frequency_penalty: float = 0.0,
    843840
    presence_penalty: float = 0.0,
    844-
    tfs_z: float = 1.0,
    845841
    mirostat_mode: int = 0,
    846842
    mirostat_tau: float = 5.0,
    847843
    mirostat_eta: float = 0.1,
    @@ -883,7 +879,6 @@ def generate(
    883879
    repeat_penalty=repeat_penalty,
    884880
    frequency_penalty=frequency_penalty,
    885881
    presence_penalty=presence_penalty,
    886-
    tfs_z=tfs_z,
    887882
    mirostat_mode=mirostat_mode,
    888883
    mirostat_tau=mirostat_tau,
    889884
    mirostat_eta=mirostat_eta,
    @@ -938,7 +933,6 @@ def generate(
    938933
    repeat_penalty=repeat_penalty,
    939934
    frequency_penalty=frequency_penalty,
    940935
    presence_penalty=presence_penalty,
    941-
    tfs_z=tfs_z,
    942936
    mirostat_mode=mirostat_mode,
    943937
    mirostat_tau=mirostat_tau,
    944938
    mirostat_eta=mirostat_eta,
    @@ -1157,7 +1151,6 @@ def _create_completion(
    11571151
    top_n_sigma: float = -1.00,
    11581152
    stream: bool = False,
    11591153
    seed: Optional[int] = None,
    1160-
    tfs_z: float = 1.0,
    11611154
    mirostat_mode: int = 0,
    11621155
    mirostat_tau: float = 5.0,
    11631156
    mirostat_eta: float = 0.1,
    @@ -1348,7 +1341,6 @@ def logit_bias_processor(
    13481341
    min_p=min_p,
    13491342
    typical_p=typical_p,
    13501343
    temp=temperature,
    1351-
    tfs_z=tfs_z,
    13521344
    mirostat_mode=mirostat_mode,
    13531345
    mirostat_tau=mirostat_tau,
    13541346
    mirostat_eta=mirostat_eta,
    @@ -1783,7 +1775,6 @@ def create_completion(
    17831775
    top_n_sigma: float = -1.00,
    17841776
    stream: bool = False,
    17851777
    seed: Optional[int] = None,
    1786-
    tfs_z: float = 1.0,
    17871778
    mirostat_mode: int = 0,
    17881779
    mirostat_tau: float = 5.0,
    17891780
    mirostat_eta: float = 0.1,
    @@ -1815,7 +1806,6 @@ def create_completion(
    18151806
    top_n_sigma: Limit the next token selection to a subset of tokens with pre-softmax logits that are within n * σ less than the max logit (default: -1.00, -1.00 = disabled).
    18161807
    stream: Whether to stream the results.
    18171808
    seed: The seed to use for sampling.
    1818-
    tfs_z: The tail-free sampling parameter. Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
    18191809
    mirostat_mode: The mirostat sampling mode.
    18201810
    mirostat_tau: The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
    18211811
    mirostat_eta: The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
    @@ -1852,7 +1842,6 @@ def create_completion(
    18521842
    top_n_sigma=top_n_sigma,
    18531843
    stream=stream,
    18541844
    seed=seed,
    1855-
    tfs_z=tfs_z,
    18561845
    mirostat_mode=mirostat_mode,
    18571846
    mirostat_tau=mirostat_tau,
    18581847
    mirostat_eta=mirostat_eta,
    @@ -1889,7 +1878,6 @@ def __call__(
    18891878
    top_n_sigma: float = -1.00,
    18901879
    stream: bool = False,
    18911880
    seed: Optional[int] = None,
    1892-
    tfs_z: float = 1.0,
    18931881
    mirostat_mode: int = 0,
    18941882
    mirostat_tau: float = 5.0,
    18951883
    mirostat_eta: float = 0.1,
    @@ -1921,7 +1909,6 @@ def __call__(
    19211909
    top_n_sigma: Limit the next token selection to a subset of tokens with pre-softmax logits that are within n * σ less than the max logit (default: -1.00, -1.00 = disabled).
    19221910
    stream: Whether to stream the results.
    19231911
    seed: The seed to use for sampling.
    1924-
    tfs_z: The tail-free sampling parameter. Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
    19251912
    mirostat_mode: The mirostat sampling mode.
    19261913
    mirostat_tau: The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
    19271914
    mirostat_eta: The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
    @@ -1958,7 +1945,6 @@ def __call__(
    19581945
    top_n_sigma=top_n_sigma,
    19591946
    stream=stream,
    19601947
    seed=seed,
    1961-
    tfs_z=tfs_z,
    19621948
    mirostat_mode=mirostat_mode,
    19631949
    mirostat_tau=mirostat_tau,
    19641950
    mirostat_eta=mirostat_eta,
    @@ -1992,7 +1978,6 @@ def create_chat_completion(
    19921978
    presence_penalty: float = 0.0,
    19931979
    frequency_penalty: float = 0.0,
    19941980
    repeat_penalty: float = 1.0,
    1995-
    tfs_z: float = 1.0,
    19961981
    mirostat_mode: int = 0,
    19971982
    mirostat_tau: float = 5.0,
    19981983
    mirostat_eta: float = 0.1,
    @@ -2029,7 +2014,6 @@ def create_chat_completion(
    20292014
    presence_penalty: The penalty to apply to tokens based on their presence in the prompt.
    20302015
    frequency_penalty: The penalty to apply to tokens based on their frequency in the prompt.
    20312016
    repeat_penalty: The penalty to apply to repeated tokens.
    2032-
    tfs_z: The tail-free sampling parameter.
    20332017
    mirostat_mode: The mirostat sampling mode.
    20342018
    mirostat_tau: The mirostat sampling tau parameter.
    20352019
    mirostat_eta: The mirostat sampling eta parameter.
    @@ -2071,7 +2055,6 @@ def create_chat_completion(
    20712055
    presence_penalty=presence_penalty,
    20722056
    frequency_penalty=frequency_penalty,
    20732057
    repeat_penalty=repeat_penalty,
    2074-
    tfs_z=tfs_z,
    20752058
    mirostat_mode=mirostat_mode,
    20762059
    mirostat_tau=mirostat_tau,
    20772060
    mirostat_eta=mirostat_eta,

    0 commit comments

    Comments
     (0)
    0