8000 adds support to default to tuned model (#287) · nod-ai/diffusers@48e9818 · GitHub
[go: up one dir, main page]

Skip to content

Commit 48e9818

Browse files
authored
adds support to default to tuned model (huggingface#287)
currently setup for tf bert/resnet50 going to refactor test class to avoid having to add an argument to 50+ files
1 parent 1485777 commit 48e9818

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

shark/shark_downloader.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def gs_download_model():
188188
return mlir_file, function_name, inputs_tuple, golden_out_tuple
189189

190190

191-
def download_tf_model(model_name):
191+
def download_tf_model(model_name, tuned=None):
192192
model_name = model_name.replace("/", "_")
193193
os.makedirs(WORKDIR, exist_ok=True)
194194
model_dir_name = model_name + "_tf"
@@ -230,7 +230,12 @@ def gs_download_model():
230230
)
231231

232232
model_dir = os.path.join(WORKDIR, model_dir_name)
233-
with open(os.path.join(model_dir, model_name + "_tf.mlir")) as f:
233+
suffix = "_tf.mlir" if tuned is None else "_tf_" + tuned + ".mlir"
234+
filename = os.path.join(model_dir, model_name + suffix)
235+
if not os.path.isfile(filename):
236+
filename = os.path.join(model_dir, model_name + "_tf.mlir")
237+
238+
with open(filename) as f:
234239
mlir_file = f.read()
235240

236241
function_name = str(np.load(os.path.join(model_dir, "function_name.npy")))

tank/MiniLM-L12-H384-uncased/MiniLM-L12-H384-uncased_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def __init__(
2020

2121
def create_and_check_module(self, dynamic, device):
2222
model, func_name, inputs, golden_out = download_tf_model(
23-
"microsoft/MiniLM-L12-H384-uncased"
23+
"microsoft/MiniLM-L12-H384-uncased", device
2424
)
2525

2626
shark_module = SharkInference(

0 commit comments

Comments
 (0)
0