|
50 | 50 |
|
51 | 51 | import pickle
|
52 | 52 | import re
|
| 53 | +import shutil |
53 | 54 | import sys
|
54 | 55 |
|
55 | 56 | from annoy import AnnoyIndex
|
@@ -88,12 +89,14 @@ def query_model(query, model, indices, language, topk=100):
|
88 | 89 | sys.exit(1)
|
89 | 90 | wandb_api = wandb.Api()
|
90 | 91 | # retrieve saved model from W&B for this run
|
| 92 | + print("Fetching run from W&B...") |
91 | 93 | try:
|
92 | 94 | run = wandb_api.run(args_wandb_run_id)
|
93 | 95 | except wandb.CommError as e:
|
94 | 96 | print("ERROR: Problem querying W&B for wandb_run_id: %s" % args_wandb_run_id, file=sys.stderr)
|
95 | 97 | sys.exit(1)
|
96 | 98 |
|
| 99 | + print("Fetching run files from W&B...") |
97 | 100 | gz_run_files = [f for f in run.files() if f.name.endswith('gz')]
|
98 | 101 | if not gz_run_files:
|
99 | 102 | print("ERROR: Run contains no model-like files")
|
@@ -129,10 +132,18 @@ def query_model(query, model, indices, language, topk=100):
|
129 | 132 | df = pd.DataFrame(predictions, columns=['query', 'language', 'identifier', 'url'])
|
130 | 133 | df.to_csv(predictions_csv, index=False)
|
131 | 134 |
|
| 135 | + |
132 | 136 | if run_id:
|
| 137 | + print('Uploading predictions to W&B') |
133 | 138 | # upload model predictions CSV file to W&B
|
134 | 139 |
|
135 | 140 | # we checked that there are three path components above
|
136 | 141 | entity, project, name = args_wandb_run_id.split('/')
|
| 142 | + |
| 143 | + # make sure the file is in our cwd, with the correct name |
| 144 | + predictions_base_csv = "model_predictions.csv" |
| 145 | + shutil.copyfile(predictions_csv, predictions_base_csv) |
| 146 | + |
| 147 | + # Using internal wandb API. TODO: Update when available as a public API |
137 | 148 | internal_api = InternalApi()
|
138 |
| - internal_api.push([predictions_csv], run=name, entity=entity, project=project) |
| 149 | + internal_api.push([predictions_base_csv], run=name, entity=entity, project=project) |
0 commit comments