huggingface/huggingface.js · commit 4e05d9e

Harmonize inferenceProviderMapping additional parameter in modelInfo / listModels (#1515)

Equivalent to huggingface/huggingface_hub#3022 in Python. The main problem is that `expand[]=inferenceProviderMapping` does not return the same data structure when getting model info as when listing models. We will fix this in ~3 months (?), but in the meantime we want a client compatible with both the current and the future structure (see huggingface/huggingface_hub#3022 for more details).

**TODO:**
- [x] adapt call in `@huggingface/inference`
1 parent 6d68f66 commit 4e05d9e
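
For illustration (not part of the commit): the two payload shapes the client now has to accept look roughly like the sketch below. Provider names and model IDs are invented; the field names follow the `ApiModelInferenceProviderMappingEntry` / `InferenceProviderMappingEntry` types introduced in this commit.

```ts
// Hypothetical payloads, for illustration only.

// Today, `expand[]=inferenceProviderMapping` on the single-model endpoint returns a record keyed by provider:
const fromModelInfo = {
	"fireworks-ai": { providerId: "accounts/fireworks/models/some-model", status: "live", task: "conversational" },
};

// ...while the list-models endpoint (and, eventually, both endpoints) returns an array of entries:
const fromListModels = [
	{
		provider: "fireworks-ai",
		hfModelId: "some-org/some-model",
		providerId: "accounts/fireworks/models/some-model",
		status: "live",
		task: "conversational",
	},
];

// The client normalizes both shapes to the array form, so downstream code only ever sees one structure.
```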

13 files changed · +220 −36 lines changed

packages/hub/src/lib/list-models.spec.ts

Lines changed: 19 additions & 0 deletions
@@ -115,4 +115,23 @@ describe("listModels", () => {
 
 		expect(count).to.equal(10);
 	});
+
+	it("should list deepseek-ai models with inference provider mapping", async () => {
+		let count = 0;
+		for await (const entry of listModels({
+			search: { owner: "deepseek-ai" },
+			additionalFields: ["inferenceProviderMapping"],
+			limit: 1,
+		})) {
+			count++;
+			expect(entry.inferenceProviderMapping).to.be.an("array").that.is.not.empty;
+			for (const item of entry.inferenceProviderMapping ?? []) {
+				expect(item).to.have.property("provider").that.is.a("string").and.is.not.empty;
+				expect(item).to.have.property("hfModelId").that.is.a("string").and.is.not.empty;
+				expect(item).to.have.property("providerId").that.is.a("string").and.is.not.empty;
+			}
+		}
+
+		expect(count).to.equal(1);
+	});
 });

packages/hub/src/lib/list-models.ts

Lines changed: 14 additions & 1 deletion
@@ -5,6 +5,7 @@ import type { CredentialsParams, PipelineType } from "../types/public";
 import { checkCredentials } from "../utils/checkCredentials";
 import { parseLinkHeader } from "../utils/parseLinkHeader";
 import { pick } from "../utils/pick";
+import { normalizeInferenceProviderMapping } from "../utils/normalizeInferenceProviderMapping";
 
 export const MODEL_EXPAND_KEYS = [
 	"pipeline_tag",
@@ -113,8 +114,20 @@ export async function* listModels<
 		const items: ApiModelInfo[] = await res.json();
 
 		for (const item of items) {
+			// Handle inferenceProviderMapping normalization
+			const normalizedItem = { ...item };
+			if (
+				(params?.additionalFields as string[])?.includes("inferenceProviderMapping") &&
+				item.inferenceProviderMapping
+			) {
+				normalizedItem.inferenceProviderMapping = normalizeInferenceProviderMapping(
+					item.id,
+					item.inferenceProviderMapping
+				);
+			}
+
 			yield {
-				...(params?.additionalFields && pick(item, params.additionalFields)),
+				...(params?.additionalFields && pick(normalizedItem, params.additionalFields)),
 				id: item._id,
 				name: item.id,
 				private: item.private,
packages/hub/src/lib/model-info.spec.ts

Lines changed: 16 additions & 0 deletions
@@ -56,4 +56,20 @@ describe("modelInfo", () => {
 			sha: "f27b190eeac4c2302d24068eabf5e9d6044389ae",
 		});
 	});
+
+	it("should return model info deepseek-ai models with inference provider mapping", async () => {
+		const info = await modelInfo({
+			name: "deepseek-ai/DeepSeek-R1-0528",
+			additionalFields: ["inferenceProviderMapping"],
+		});
+
+		expect(info.inferenceProviderMapping).toBeDefined();
+		expect(info.inferenceProviderMapping).toBeInstanceOf(Array);
+		expect(info.inferenceProviderMapping?.length).toBeGreaterThan(0);
+		info.inferenceProviderMapping?.forEach((item) => {
+			expect(item).toHaveProperty("provider");
+			expect(item).toHaveProperty("hfModelId", "deepseek-ai/DeepSeek-R1-0528");
+			expect(item).toHaveProperty("providerId");
+		});
+	});
 });

packages/hub/src/lib/model-info.ts

Lines changed: 8 additions & 1 deletion
@@ -4,6 +4,7 @@ import type { ApiModelInfo } from "../types/api/api-model";
 import type { CredentialsParams } from "../types/public";
 import { checkCredentials } from "../utils/checkCredentials";
 import { pick } from "../utils/pick";
+import { normalizeInferenceProviderMapping } from "../utils/normalizeInferenceProviderMapping";
 import { MODEL_EXPAND_KEYS, type MODEL_EXPANDABLE_KEYS, type ModelEntry } from "./list-models";
 
 export async function modelInfo<
@@ -48,8 +49,14 @@ export async function modelInfo<
 
 	const data = await response.json();
 
+	// Handle inferenceProviderMapping normalization
+	const normalizedData = { ...data };
+	if ((params?.additionalFields as string[])?.includes("inferenceProviderMapping") && data.inferenceProviderMapping) {
+		normalizedData.inferenceProviderMapping = normalizeInferenceProviderMapping(data.id, data.inferenceProviderMapping);
+	}
+
 	return {
-		...(params?.additionalFields && pick(data, params.additionalFields)),
+		...(params?.additionalFields && pick(normalizedData, params.additionalFields)),
 		id: data._id,
 		name: data.id,
 		private: data.private,

packages/hub/src/types/api/api-model.ts

Lines changed: 12 additions & 3 deletions
@@ -18,9 +18,7 @@ export interface ApiModelInfo {
 	downloadsAllTime: number;
 	files: string[];
 	gitalyUid: string;
-	inferenceProviderMapping: Partial<
-		Record<string, { providerId: string; status: "live" | "staging"; task: WidgetType }>
-	>;
+	inferenceProviderMapping?: ApiModelInferenceProviderMappingEntry[];
 	lastAuthor: { email: string; user?: string };
 	lastModified: string; // convert to date
 	library_name?: ModelLibraryKey;
@@ -271,3 +269,14 @@ export interface ApiModelMetadata {
 	extra_gated_description?: string;
 	extra_gated_button_content?: string;
 }
+
+export interface ApiModelInferenceProviderMappingEntry {
+	provider: string; // Provider name
+	hfModelId: string; // ID of the model on the Hugging Face Hub
+	providerId: string; // ID of the model on the provider's side
+	status: "live" | "staging";
+	task: WidgetType;
+	adapter?: string;
+	adapterWeightsPath?: string;
+	type?: "single-file" | "tag-filter";
+}
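
For illustration only (values are invented, not taken from the commit), an entry of the new type, including the optional adapter fields, could look like this:

```ts
import type { ApiModelInferenceProviderMappingEntry } from "../types/api/api-model"; // path as used inside packages/hub/src

// Hypothetical entry, for illustration only.
const entry: ApiModelInferenceProviderMappingEntry = {
	provider: "fal-ai",
	hfModelId: "some-org/some-lora",
	providerId: "fal-ai/some-base-model",
	status: "live",
	task: "text-to-image",
	adapter: "lora", // optional: set when the Hub model is an adapter on top of a provider-side base model
	adapterWeightsPath: "some_lora_weights.safetensors",
};
```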
packages/hub/src/utils/normalizeInferenceProviderMapping.ts

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+import type { WidgetType } from "@huggingface/tasks";
+import type { ApiModelInferenceProviderMappingEntry } from "../types/api/api-model";
+
+/**
+ * Normalize inferenceProviderMapping to always return an array format.
+ *
+ * Little hack to simplify Inference Providers logic and make it backward and forward compatible.
+ * Right now, API returns a dict on model-info and a list on list-models. Let's harmonize to list.
+ */
+export function normalizeInferenceProviderMapping(
+	hfModelId: string,
+	inferenceProviderMapping?:
+		| ApiModelInferenceProviderMappingEntry[]
+		| Record<string, { providerId: string; status: "live" | "staging"; task: WidgetType }>
+): ApiModelInferenceProviderMappingEntry[] {
+	if (!inferenceProviderMapping) {
+		return [];
+	}
+
+	// If it's already an array, return it as is
+	if (Array.isArray(inferenceProviderMapping)) {
+		return inferenceProviderMapping.map((entry) => ({
+			...entry,
+			hfModelId,
+		}));
+	}
+
+	// Convert mapping to array format
+	return Object.entries(inferenceProviderMapping).map(([provider, mapping]) => ({
+		provider,
+		hfModelId,
+		providerId: mapping.providerId,
+		status: mapping.status,
+		task: mapping.task,
+	}));
+}
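
A quick sketch of what the helper does with each input shape (model and provider values are illustrative, and the import path assumes the caller sits next to the new file):

```ts
import { normalizeInferenceProviderMapping } from "./normalizeInferenceProviderMapping";

// Record input (current model-info shape) is converted to an array of entries:
normalizeInferenceProviderMapping("some-org/some-model", {
	"fireworks-ai": { providerId: "accounts/fireworks/models/some-model", status: "live", task: "conversational" },
});
// → [{ provider: "fireworks-ai", hfModelId: "some-org/some-model", providerId: "accounts/fireworks/models/some-model", status: "live", task: "conversational" }]

// Array input (current list-models shape, and the future shape everywhere) is kept, with hfModelId set on each entry:
normalizeInferenceProviderMapping("some-org/some-model", [
	{ provider: "fireworks-ai", hfModelId: "some-org/some-model", providerId: "accounts/fireworks/models/some-model", status: "live", task: "conversational" },
]);

// A missing mapping yields an empty array:
normalizeInferenceProviderMapping("some-org/some-model"); // → []
```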

packages/inference/src/lib/getInferenceProviderMapping.ts

Lines changed: 50 additions & 20 deletions
@@ -6,19 +6,48 @@ import type { InferenceProvider, InferenceProviderOrPolicy, ModelId } from "../t
 import { typedInclude } from "../utils/typedInclude.js";
 import { InferenceClientHubApiError, InferenceClientInputError } from "../errors.js";
 
-export const inferenceProviderMappingCache = new Map<ModelId, InferenceProviderMapping>();
+export const inferenceProviderMappingCache = new Map<ModelId, InferenceProviderMappingEntry[]>();
 
-export type InferenceProviderMapping = Partial<
-	Record<InferenceProvider, Omit<InferenceProviderModelMapping, "hfModelId">>
->;
-
-export interface InferenceProviderModelMapping {
+export interface InferenceProviderMappingEntry {
 	adapter?: string;
 	adapterWeightsPath?: string;
 	hfModelId: ModelId;
+	provider: string;
 	providerId: string;
 	status: "live" | "staging";
 	task: WidgetType;
+	type?: "single-model" | "tag-filter";
+}
+
+/**
+ * Normalize inferenceProviderMapping to always return an array format.
+ * This provides backward and forward compatibility for the API changes.
+ *
+ * Vendored from @huggingface/hub to avoid extra dependency.
+ */
+function normalizeInferenceProviderMapping(
+	modelId: ModelId,
+	inferenceProviderMapping?:
+		| InferenceProviderMappingEntry[]
+		| Record<string, { providerId: string; status: "live" | "staging"; task: WidgetType }>
+): InferenceProviderMappingEntry[] {
+	if (!inferenceProviderMapping) {
+		return [];
+	}
+
+	// If it's already an array, return it as is
+	if (Array.isArray(inferenceProviderMapping)) {
+		return inferenceProviderMapping;
+	}
+
+	// Convert mapping to array format
+	return Object.entries(inferenceProviderMapping).map(([provider, mapping]) => ({
+		provider,
+		hfModelId: modelId,
+		providerId: mapping.providerId,
+		status: mapping.status,
+		task: mapping.task,
+	}));
 }
 
 export async function fetchInferenceProviderMappingForModel(
@@ -27,8 +56,8 @@ export async function fetchInferenceProviderMappingForModel(
 	options?: {
 		fetch?: (input: RequestInfo, init?: RequestInit) => Promise<Response>;
 	}
-): Promise<InferenceProviderMapping> {
-	let inferenceProviderMapping: InferenceProviderMapping | null;
+): Promise<InferenceProviderMappingEntry[]> {
+	let inferenceProviderMapping: InferenceProviderMappingEntry[] | null;
 	if (inferenceProviderMappingCache.has(modelId)) {
 		// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
 		inferenceProviderMapping = inferenceProviderMappingCache.get(modelId)!;
@@ -55,7 +84,11 @@ export async function fetchInferenceProviderMappingForModel(
 			);
 		}
 	}
-	let payload: { inferenceProviderMapping?: InferenceProviderMapping } | null = null;
+	let payload: {
+		inferenceProviderMapping?:
+			| InferenceProviderMappingEntry[]
+			| Record<string, { providerId: string; status: "live" | "staging"; task: WidgetType }>;
+	} | null = null;
 	try {
 		payload = await resp.json();
 	} catch {
@@ -72,7 +105,8 @@ export async function fetchInferenceProviderMappingForModel(
 				{ requestId: resp.headers.get("x-request-id") ?? "", status: resp.status, body: await resp.text() }
 			);
 		}
-		inferenceProviderMapping = payload.inferenceProviderMapping;
+		inferenceProviderMapping = normalizeInferenceProviderMapping(modelId, payload.inferenceProviderMapping);
+		inferenceProviderMappingCache.set(modelId, inferenceProviderMapping);
 	}
 	return inferenceProviderMapping;
 }
@@ -87,16 +121,12 @@ export async function getInferenceProviderMapping(
 	options: {
 		fetch?: (input: RequestInfo, init?: RequestInit) => Promise<Response>;
 	}
-): Promise<InferenceProviderModelMapping | null> {
+): Promise<InferenceProviderMappingEntry | null> {
 	if (HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]) {
 		return HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId];
 	}
-	const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(
-		params.modelId,
-		params.accessToken,
-		options
-	);
-	const providerMapping = inferenceProviderMapping[params.provider];
+	const mappings = await fetchInferenceProviderMappingForModel(params.modelId, params.accessToken, options);
+	const providerMapping = mappings.find((mapping) => mapping.provider === params.provider);
 	if (providerMapping) {
 		const equivalentTasks =
 			params.provider === "hf-inference" && typedInclude(EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS, params.task)
@@ -112,7 +142,7 @@ export async function getInferenceProviderMapping(
 				`Model ${params.modelId} is in staging mode for provider ${params.provider}. Meant for test purposes only.`
 			);
 		}
-		return { ...providerMapping, hfModelId: params.modelId };
+		return providerMapping;
 	}
 	return null;
 }
@@ -139,8 +169,8 @@ export async function resolveProvider(
 		if (!modelId) {
 			throw new InferenceClientInputError("Specifying a model is required when provider is 'auto'");
 		}
-		const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(modelId);
-		provider = Object.keys(inferenceProviderMapping)[0] as InferenceProvider | undefined;
+		const mappings = await fetchInferenceProviderMappingForModel(modelId);
+		provider = mappings[0]?.provider as InferenceProvider | undefined;
 	}
 	if (!provider) {
 		throw new InferenceClientInputError(`No Inference Provider available for model ${modelId}.`);
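
With the mapping normalized to an array, provider lookup and `provider: "auto"` resolution reduce to plain array operations. A minimal sketch, assuming an already-fetched `mappings` array with invented values:

```ts
import type { InferenceProviderMappingEntry } from "./getInferenceProviderMapping.js";

// Hypothetical normalized mapping for one model.
const mappings: InferenceProviderMappingEntry[] = [
	{ provider: "fireworks-ai", hfModelId: "some-org/some-model", providerId: "accounts/fireworks/models/some-model", status: "live", task: "conversational" },
	{ provider: "together", hfModelId: "some-org/some-model", providerId: "some-org/some-model", status: "staging", task: "conversational" },
];

// Explicit provider: find the entry by name instead of indexing a record keyed by provider.
const forProvider = mappings.find((m) => m.provider === "together");

// provider: "auto": pick the first entry's provider, if any.
const autoProvider = mappings[0]?.provider;
```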

packages/inference/src/lib/makeRequestOptions.ts

Lines changed: 4 additions & 3 deletions
@@ -1,7 +1,7 @@
 import { HF_HEADER_X_BILL_TO, HF_HUB_URL } from "../config.js";
 import { PACKAGE_NAME, PACKAGE_VERSION } from "../package.js";
 import type { InferenceTask, Options, RequestArgs } from "../types.js";
-import type { InferenceProviderModelMapping } from "./getInferenceProviderMapping.js";
+import type { InferenceProviderMappingEntry } from "./getInferenceProviderMapping.js";
 import { getInferenceProviderMapping } from "./getInferenceProviderMapping.js";
 import type { getProviderHelper } from "./getProviderHelper.js";
 import { isUrl } from "./isUrl.js";
@@ -64,14 +64,15 @@ export async function makeRequestOptions(
 
 	const inferenceProviderMapping = providerHelper.clientSideRoutingOnly
 		? ({
+				provider: provider,
 				// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
 				providerId: removeProviderPrefix(maybeModel!, provider),
 				// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
 				hfModelId: maybeModel!,
 				status: "live",
 				// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
 				task: task!,
-		  } satisfies InferenceProviderModelMapping)
+		  } satisfies InferenceProviderMappingEntry)
 		: await getInferenceProviderMapping(
 			{
 				modelId: hfModel,
@@ -109,7 +110,7 @@ export function makeRequestOptionsFromResolvedModel(
 		data?: Blob | ArrayBuffer;
 		stream?: boolean;
 	},
-	mapping: InferenceProviderModelMapping | undefined,
+	mapping: InferenceProviderMappingEntry | undefined,
 	options?: Options & {
 		task?: InferenceTask;
 	}

packages/inference/src/providers/consts.ts

Lines changed: 2 additions & 2 deletions
@@ -1,4 +1,4 @@
-import type { InferenceProviderModelMapping } from "../lib/getInferenceProviderMapping.js";
+import type { InferenceProviderMappingEntry } from "../lib/getInferenceProviderMapping.js";
 import type { InferenceProvider } from "../types.js";
 import { type ModelId } from "../types.js";
 
@@ -11,7 +11,7 @@ import { type ModelId } from "../types.js";
  */
 export const HARDCODED_MODEL_INFERENCE_MAPPING: Record<
 	InferenceProvider,
-	Record<ModelId, InferenceProviderModelMapping>
+	Record<ModelId, InferenceProviderMappingEntry>
 > = {
 	/**
 	 * "HF model ID" => "Model ID on Inference Provider's side"

packages/inference/src/snippets/getInferenceSnippets.ts

Lines changed: 4 additions & 4 deletions
@@ -8,7 +8,7 @@ import {
 } from "@huggingface/tasks";
 import type { PipelineType, WidgetType } from "@huggingface/tasks";
 import type { ChatCompletionInputMessage, GenerationParameters } from "@huggingface/tasks";
-import type { InferenceProviderModelMapping } from "../lib/getInferenceProviderMapping.js";
+import type { InferenceProviderMappingEntry } from "../lib/getInferenceProviderMapping.js";
 import { getProviderHelper } from "../lib/getProviderHelper.js";
 import { makeRequestOptionsFromResolvedModel } from "../lib/makeRequestOptions.js";
 import type { InferenceProviderOrPolicy, InferenceTask, RequestArgs } from "../types.js";
@@ -138,7 +138,7 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
 	return (
 		model: ModelDataMinimal,
 		provider: InferenceProviderOrPolicy,
-		inferenceProviderMapping?: InferenceProviderModelMapping,
+		inferenceProviderMapping?: InferenceProviderMappingEntry,
 		opts?: InferenceSnippetOptions
 	): InferenceSnippet[] => {
 		const providerModelId = inferenceProviderMapping?.providerId ?? model.id;
@@ -331,7 +331,7 @@ const snippets: Partial<
 	(
 		model: ModelDataMinimal,
 		provider: InferenceProviderOrPolicy,
-		inferenceProviderMapping?: InferenceProviderModelMapping,
+		inferenceProviderMapping?: InferenceProviderMappingEntry,
 		opts?: InferenceSnippetOptions
 	) => InferenceSnippet[]
 >
@@ -370,7 +370,7 @@ const snippets: Partial<
 export function getInferenceSnippets(
 	model: ModelDataMinimal,
 	provider: InferenceProviderOrPolicy,
-	inferenceProviderMapping?: InferenceProviderModelMapping,
+	inferenceProviderMapping?: InferenceProviderMappingEntry,
 	opts?: Record<string, unknown>
 ): InferenceSnippet[] {
 	return model.pipeline_tag && model.pipeline_tag in snippets

0 commit comments
