|
| 1 | +import os |
| 2 | +import argparse |
| 3 | + |
| 4 | +from dataclasses import dataclass, field |
| 5 | +from typing import List, Optional |
| 6 | + |
| 7 | +# Based on https://github.com/ggerganov/llama.cpp/blob/master/examples/common.cpp |
| 8 | + |
| 9 | + |
@dataclass
class GptParams:
    """Generation/runtime parameters, mirroring llama.cpp's gpt_params struct."""

    seed: int = -1                                   # RNG seed; -1 means pick randomly
    n_threads: int = min(4, os.cpu_count() or 1)     # os.cpu_count() may return None
    n_predict: int = 128                             # tokens to predict
    repeat_last_n: int = 64                          # window considered for repeat penalty
    n_parts: int = -1                                # model parts; -1 = auto-detect
    n_ctx: int = 512                                 # context window size
    n_batch: int = 8                                 # batch size for prompt processing
    n_keep: int = 0                                  # prompt tokens kept on context overflow

    # Sampling parameters.
    top_k: int = 40
    top_p: float = 0.95
    temp: float = 0.80
    repeat_penalty: float = 1.10

    model: str = "./models/llama-7B/ggml-model.bin"
    prompt: str = ""
    input_prefix: str = " "
    fix_prefix: str = ""
    output_postfix: str = ""
    # BUG FIX: a trailing comma previously made this the tuple (True,) — a
    # truthy tuple, not a bool.
    input_echo: bool = True

    # Reverse prompts: strings that pause generation in interactive mode.
    antiprompt: List[str] = field(default_factory=list)

    memory_f16: bool = True
    random_prompt: bool = False
    use_color: bool = False
    interactive: bool = False

    embedding: bool = False
    interactive_start: bool = False

    instruct: bool = False
    ignore_eos: bool = False
    perplexity: bool = False
    use_mlock: bool = False
    mem_test: bool = False
    verbose_prompt: bool = False

    # Default instructions for Alpaca;
    # switch to "Human" and "Assistant" for Vicuna.
    # BUG FIX: trailing commas previously made these 1-tuples of str, not str.
    instruct_inp_prefix: str = "\n\n### Instruction:\n\n"
    instruct_inp_suffix: str = "\n\n### Response:\n\n"
| 54 | + |
| 55 | + |
def gpt_params_parse(argv=None, params: Optional[GptParams] = None) -> GptParams:
    """Parse command-line arguments into a GptParams instance.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
        params: existing ``GptParams`` to update in place; a fresh one is
            created when omitted.

    Returns:
        The populated ``GptParams``.
    """
    if params is None:
        params = GptParams()

    # NOTE: argparse registers -h/--help automatically; re-adding it (as the
    # original code did) raises argparse.ArgumentError at definition time.
    parser = argparse.ArgumentParser()
    parser.add_argument("-s", "--seed", type=int, default=-1, help="RNG seed (-1 for random)", dest="seed")
    parser.add_argument("-t", "--threads", type=int, default=1, help="number of threads to use", dest="n_threads")
    parser.add_argument("-p", "--prompt", type=str, default="", help="initial prompt", dest="prompt")
    parser.add_argument("-f", "--file", type=str, default=None, help="file containing the initial prompt")
    parser.add_argument("-c", "--ctx_size", type=int, default=512, help="size of the prompt context", dest="n_ctx")
    parser.add_argument("--memory_f32", action="store_false", help="use f32 instead of f16 for memory", dest="memory_f16")
    parser.add_argument("--top_p", type=float, default=0.9, help="top-p sampling", dest="top_p")
    parser.add_argument("--temp", type=float, default=1.0, help="sampling temperature", dest="temp")
    parser.add_argument("--repeat_last_n", type=int, default=64, help="last n tokens for repeat penalty", dest="repeat_last_n")
    parser.add_argument("--repeat_penalty", type=float, default=1.0, help="repeat penalty", dest="repeat_penalty")
    parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size for prompt processing", dest="n_batch")
    parser.add_argument("--keep", type=int, default=0, help="prompt tokens to keep on context overflow", dest="n_keep")
    parser.add_argument("-m", "--model", type=str, help="model path", dest="model")
    parser.add_argument(
        "-i", "--interactive", action="store_true", help="run in interactive mode", dest="interactive"
    )
    parser.add_argument("--embedding", action="store_true", help="", dest="embedding")
    parser.add_argument("--interactive-start", action="store_true", help="", dest="interactive_start")
    parser.add_argument(
        "--interactive-first",
        action="store_true",
        help="run in interactive mode and wait for input right away",
        dest="interactive"
    )
    parser.add_argument(
        "-ins",
        "--instruct",
        action="store_true",
        help="run in instruction mode (use with Alpaca or Vicuna models)",
        dest="instruct"
    )
    parser.add_argument(
        "--color",
        action="store_true",
        help="colorise output to distinguish prompt and user input from generations",
        dest="use_color"
    )
    parser.add_argument("--mlock", action="store_true", dest="use_mlock")
    parser.add_argument("--mtest", action="store_true", dest="mem_test")
    parser.add_argument(
        "-r",
        "--reverse-prompt",
        type=str,
        action='append',
        help="run in interactive mode and poll user input upon seeing PROMPT (can be\nspecified more than once for multiple prompts).",
        dest="antiprompt"
    )
    parser.add_argument("--perplexity", action="store_true", help="", dest="perplexity")
    parser.add_argument("--ignore-eos", action="store_true", help="", dest="ignore_eos")
    parser.add_argument("--n_parts", type=int, default=-1, help="", dest="n_parts")
    parser.add_argument("--random-prompt", action="store_true", help="", dest="random_prompt")
    parser.add_argument("--in-prefix", type=str, default=" ", help="string to prefix user inputs with", dest="input_prefix")
    parser.add_argument("--fix-prefix", type=str, default=" ", help="", dest="fix_prefix")
    parser.add_argument("--out-postfix", type=str, default="", help="", dest="output_postfix")
    parser.add_argument("--input-noecho", action="store_false", help="", dest="input_echo")
    args = parser.parse_args(argv)

    # -f/--file overrides -p/--prompt, as in llama.cpp's common.cpp.
    if getattr(args, "file", None):
        with open(args.file, "r", encoding="utf-8") as f:
            args.prompt = f.read()

    # BUG FIX: `params` was previously ignored and the bare Namespace
    # returned. Copy parsed values onto params; a value of None means the
    # flag was not supplied (e.g. --model, --reverse-prompt), so the
    # dataclass default is kept.
    for name, value in vars(args).items():
        if value is not None and hasattr(params, name):
            setattr(params, name, value)
    return params
| 119 | + |
def gpt_random_prompt(rng):
    """Pick one of ten canned prompt openers, indexed by ``rng`` modulo 10."""
    openers = (
        "So",
        "Once upon a time",
        "When",
        "The",
        "After",
        "If",
        "import",
        "He",
        "She",
        "They",
    )
    return openers[rng % 10]
| 133 | + |
if __name__ == "__main__":
    # BUG FIX: the parse result was previously passed as the first positional
    # field (seed) of a brand-new GptParams; print the parsed result directly.
    print(gpt_params_parse())