8000 Fix crash due to file in parent sub-dir. · github/CodeSearchNet@901fdfd · GitHub
[go: up one dir, main page]

Skip to content
This repository was archived by the owner on Apr 11, 2023. It is now read-only.

Commit 901fdfd

Browse files
committed
Fix crash due to file in parent sub-dir.
1 parent 7f7416f commit 901fdfd

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

src/predict.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050

5151
import pickle
5252
import re
53+
import shutil
5354
import sys
5455

5556
from annoy import AnnoyIndex
@@ -88,12 +89,14 @@ def query_model(query, model, indices, language, topk=100):
8889
sys.exit(1)
8990
wandb_api = wandb.Api()
9091
# retrieve saved model from W&B for this run
92+
print("Fetching run from W&B...")
9193
try:
9294
run = wandb_api.run(args_wandb_run_id)
9395
except wandb.CommError as e:
9496
print("ERROR: Problem querying W&B for wandb_run_id: %s" % args_wandb_run_id, file=sys.stderr)
9597
sys.exit(1)
9698

99+
print("Fetching run files from W&B...")
97100
gz_run_files = [f for f in run.files() if f.name.endswith('gz')]
98101
if not gz_run_files:
99102
print("ERROR: Run contains no model-like files")
@@ -129,10 +132,18 @@ def query_model(query, model, indices, language, topk=100):
129132
df = pd.DataFrame(predictions, columns=['query', 'language', 'identifier', 'url'])
130133
df.to_csv(predictions_csv, index=False)
131134

135+
132136
if run_id:
137+
print('Uploading predictions to W&B')
133138
# upload model predictions CSV file to W&B
134139

135140
# we checked that there are three path components above
136141
entity, project, name = args_wandb_run_id.split('/')
142+
143+
# make sure the file is in our cwd, with the correct name
144+
predictions_base_csv = "model_predictions.csv"
145+
shutil.copyfile(predictions_csv, predictions_base_csv)
146+
147+
# Using internal wandb API. TODO: Update when available as a public API
137148
internal_api = InternalApi()
138-
internal_api.push([predictions_csv], run=name, entity=entity, project=project)
149+
internal_api.push([predictions_base_csv], run=name, entity=entity, project=project)

0 commit comments

Comments
 (0)
0