Add `OpenRouterModel` by DanKing1903 · Pull Request #1870 · pydantic/pydantic-ai · GitHub

Add OpenRouterModel #1870

Closed
1 change: 1 addition & 0 deletions pydantic_ai_slim/pydantic_ai/models/openai.py
@@ -74,6 +74,7 @@
    'OpenAIModelSettings',
    'OpenAIResponsesModelSettings',
    'OpenAIModelName',
+    'OpenAISystemPromptRole',
)

OpenAIModelName = Union[str, ChatModel]
105 changes: 105 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/openrouter.py
@@ -0,0 +1,105 @@
from typing import Any, Literal, overload

from openai import AsyncStream
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from pydantic import BaseModel

from .. import ModelHTTPError
from ..messages import ModelMessage, ModelResponse
from ..profiles import ModelProfileSpec
from ..providers.openrouter import OpenRouterProvider
from . import ModelRequestParameters
from .openai import OpenAIModel, OpenAIModelName, OpenAIModelSettings, OpenAISystemPromptRole


class OpenRouterErrorResponse(BaseModel):
    """Represents an error response from an upstream LLM provider, relayed by OpenRouter.

    Attributes:
        code: The error code returned by the upstream LLM provider.
        message: The error message returned by OpenRouter.
        metadata: Additional error context provided by OpenRouter.

    See: https://openrouter.ai/docs/api-reference/errors
    """

    code: int
    message: str
    metadata: dict[str, Any] | None
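
As an illustration, a relayed error payload of this shape validates cleanly; a minimal sketch (the values below are invented, not taken from a real response):

    # Sketch: parsing a relayed upstream error into OpenRouterErrorResponse.
    payload = {
        'code': 429,
        'message': 'Provider returned error',
        'metadata': {'provider_name': 'Google', 'raw': 'temporarily rate-limited upstream'},
    }
    error = OpenRouterErrorResponse.model_validate(payload)
    assert error.code == 429 and error.metadata is not None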


class OpenRouterChatCompletion(ChatCompletion):
    """Extends ChatCompletion with OpenRouter-specific attributes.

    This class extends the base ChatCompletion model to include additional
    fields returned specifically by the OpenRouter API.

    Attributes:
        provider: The name of the upstream LLM provider (e.g., "Anthropic",
            "OpenAI") that processed the request through OpenRouter.
    """

    provider: str
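
Because `provider` is declared as a required field, a completion payload that includes it round-trips as expected. A sketch with an invented payload (only the OpenRouter-specific `provider` key matters here):

    completion = OpenRouterChatCompletion.model_validate({
        'id': 'gen-123',
        'choices': [{'finish_reason': 'stop', 'index': 0,
                     'message': {'content': 'Hi!', 'role': 'assistant'}}],
        'created': 1700000000,
        'model': 'google/gemini-2.0-flash-exp:free',
        'object': 'chat.completion',
        'provider': 'Google',  # the OpenRouter-specific field
    })
    assert completion.provider == 'Google'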


class OpenRouterModel(OpenAIModel):
Contributor:

Can we please override __init__ to automatically use the OpenRouterProvider as well when this model is used?

Author:

@DouweM I have added an __init__ method with narrowed type annotations for provider that does nothing apart from calling super().__init__. Is this acceptable or would you like me to add something like:

        if isinstance(provider, str):
            if provider != "openrouter":
                error_msg = ...
                raise ValueError(error_msg)
            provider = OpenRouterProvider()
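
For concreteness, the elided branch might look like the following (the error message wording is invented, not from the PR):

        if isinstance(provider, str):
            if provider != 'openrouter':
                # Hypothetical wording; the exact message is up to the maintainers.
                raise ValueError(f"Unexpected provider {provider!r}; OpenRouterModel only supports 'openrouter'")
            provider = OpenRouterProvider()
        super().__init__(model_name, provider=provider, profile=profile, system_prompt_role=system_prompt_role)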

"""Extends OpenAIModel to capture extra metadata for Openrouter."""

def __init__(
self,
model_name: OpenAIModelName,
*,
provider: Literal['openrouter'] | OpenRouterProvider = 'openrouter',
profile: ModelProfileSpec | None = None,
system_prompt_role: OpenAISystemPromptRole | None = None,
):
super().__init__(model_name, provider=provider, profile=profile, system_prompt_role=system_prompt_role)

    @overload
    async def _completions_create(
        self,
        messages: list[ModelMessage],
        stream: Literal[True],
        model_settings: OpenAIModelSettings,
        model_request_parameters: ModelRequestParameters,
    ) -> AsyncStream[ChatCompletionChunk]: ...

    @overload
    async def _completions_create(
        self,
        messages: list[ModelMessage],
        stream: Literal[False],
        model_settings: OpenAIModelSettings,
        model_request_parameters: ModelRequestParameters,
    ) -> ChatCompletion: ...

    async def _completions_create(
        self,
        messages: list[ModelMessage],
        stream: bool,
        model_settings: OpenAIModelSettings,
        model_request_parameters: ModelRequestParameters,
    ) -> ChatCompletion | AsyncStream[ChatCompletionChunk]:
        response = await super()._completions_create(
            messages=messages,
            stream=stream,
            model_settings=model_settings,
            model_request_parameters=model_request_parameters,
        )
        if error := getattr(response, 'error', None):
            parsed_error = OpenRouterErrorResponse.model_validate(error)
            raise ModelHTTPError(
                status_code=parsed_error.code, model_name=self.model_name, body=parsed_error.model_dump()
            )
        else:
            return response

    def _process_response(self, response: ChatCompletion) -> ModelResponse:
        response = OpenRouterChatCompletion.construct(**response.model_dump())
        model_response = super()._process_response(response=response)
        openrouter_provider: str | None = getattr(response, 'provider', None)
        if openrouter_provider:
            vendor_details: dict[str, Any] = model_response.vendor_details or {}
            vendor_details['provider'] = openrouter_provider
            model_response.vendor_details = vendor_details
        return model_response
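
Taken together: the model defaults to the OpenRouterProvider, surfaces relayed upstream errors as ModelHTTPError, and records the serving provider in vendor_details. A minimal usage sketch (the model name is an example, and the default provider is assumed to read the OpenRouter API key from the environment):

    import asyncio

    from pydantic_ai import Agent
    from pydantic_ai.models.openrouter import OpenRouterModel

    async def main() -> None:
        # provider='openrouter' is the default, so OpenRouterProvider is constructed implicitly.
        model = OpenRouterModel('google/gemini-2.0-flash-exp:free')
        agent = Agent(model, instructions='Be helpful.')
        result = await agent.run('Tell me a joke.')
        print(result.output)
        # The last message is the ModelResponse; its vendor_details carries
        # the upstream provider name, e.g. {'provider': 'Google'}.
        print(result.all_messages()[-1].vendor_details)

    asyncio.run(main())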
50 changes: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
interactions:
- request:
    headers:
      accept:
      - application/json
      accept-encoding:
      - gzip, deflate
      connection:
      - keep-alive
      content-length:
      - '164'
      content-type:
      - application/json
      host:
      - openrouter.ai
    method: POST
    parsed_body:
      messages:
      - content: Be helpful.
        role: system
      - content: Tell me a joke.
        role: user
      model: google/gemini-2.0-flash-exp:free
      n: 1
      stream: false
    uri: https://openrouter.ai/api/v1/chat/completions
  response:
    headers:
      access-control-allow-origin:
      - '*'
      connection:
      - keep-alive
      content-length:
      - '242'
      content-type:
      - application/json
      vary:
      - Accept-Encoding
    parsed_body:
      error:
        code: 429 # Upstream LLM provider error
        message: Provider returned error
        metadata:
          provider_name: Google
          raw: google/gemini-2.0-flash-exp:free is temporarily rate-limited upstream; please retry shortly.
          user_id: user_2uRh0l3Yi3hdjBArTOSmLXWJBc4
    status:
      code: 200 # OpenRouter returns 200 OK
      message: OK
version: 1
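
Note that OpenRouter relays the upstream 429 inside an HTTP 200 envelope, which is why the `_completions_create` override above inspects the response body for an `error` key. A sketch of handling the resulting ModelHTTPError in calling code (the model name and fallback behaviour are illustrative):

    from pydantic_ai import Agent, ModelHTTPError
    from pydantic_ai.models.openrouter import OpenRouterModel

    agent = Agent(OpenRouterModel('google/gemini-2.0-flash-exp:free'))

    async def ask(prompt: str) -> str:
        try:
            result = await agent.run(prompt)
            return result.output
        except ModelHTTPError as exc:
            # exc.status_code carries the upstream provider's code (429 in the
            # cassette above), even though OpenRouter itself answered 200 OK.
            return f'upstream error {exc.status_code}: {exc.body}'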
27 changes: 27 additions & 0 deletions tests/models/test_openrouter.py
@@ -0,0 +1,27 @@
import pytest
from inline_snapshot import snapshot

from pydantic_ai import Agent, ModelHTTPError

from ..conftest import try_import

with try_import() as imports_successful:
    from pydantic_ai.models.openrouter import OpenRouterModel
    from pydantic_ai.providers.openrouter import OpenRouterProvider

pytestmark = [
    pytest.mark.skipif(not imports_successful(), reason='openai not installed'),
    pytest.mark.vcr,
    pytest.mark.anyio,
]


async def test_openrouter_errors_raised(allow_model_requests: None, openrouter_api_key: str) -> None:
    provider = OpenRouterProvider(api_key=openrouter_api_key)
    model = OpenRouterModel('google/gemini-2.0-flash-exp:free', provider=provider)
    agent = Agent(model, instructions='Be helpful.', retries=1)
    with pytest.raises(ModelHTTPError) as exc_info:
        await agent.run('Tell me a joke.')
    assert str(exc_info.value) == snapshot(
        "status_code: 429, model_name: google/gemini-2.0-flash-exp:free, body: {'code': 429, 'message': 'Provider returned error', 'metadata': {'provider_name': 'Google', 'raw': 'google/gemini-2.0-flash-exp:free is temporarily rate-limited upstream; please retry shortly.'}}"
    )
40 changes: 36 additions & 4 deletions tests/providers/test_openrouter.py
@@ -7,6 +7,7 @@

from pydantic_ai.agent import Agent
from pydantic_ai.exceptions import UserError
from pydantic_ai.messages import ModelRequest, ModelResponse, TextPart, UserPromptPart
from pydantic_ai.profiles._json_schema import InlineDefsJsonSchemaTransformer
from pydantic_ai.profiles.amazon import amazon_model_profile
from pydantic_ai.profiles.anthropic import anthropic_model_profile
@@ -18,13 +19,14 @@
from pydantic_ai.profiles.mistral import mistral_model_profile
from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, openai_model_profile
from pydantic_ai.profiles.qwen import qwen_model_profile
from pydantic_ai.usage import Usage

-from ..conftest import TestEnv, try_import
+from ..conftest import IsDatetime, IsStr, TestEnv, try_import

with try_import() as imports_successful:
    import openai

-    from pydantic_ai.models.openai import OpenAIModel
+    from pydantic_ai.models.openrouter import OpenRouterModel
    from pydantic_ai.providers.openrouter import OpenRouterProvider


@@ -69,15 +71,45 @@ def test_openrouter_pass_openai_client() -> None:

async def test_openrouter_with_google_model(allow_model_requests: None, openrouter_api_key: str) -> None:
    provider = OpenRouterProvider(api_key=openrouter_api_key)
-    model = OpenAIModel('google/gemini-2.0-flash-exp:free', provider=provider)
-    agent = Agent(model, instructions='Be helpful.')
+    model = OpenRouterModel('google/gemini-2.0-flash-exp:free', provider=provider)
+    agent = Agent(model, instructions='Be helpful.', retries=1)
    response = await agent.run('Tell me a joke.')
    assert response.output == snapshot("""\
Why don't scientists trust atoms? \n\

Because they make up everything!
""")

    assert response.all_messages() == snapshot(
        [
            ModelRequest(
                parts=[
                    UserPromptPart(
                        content='Tell me a joke.',
                        timestamp=IsDatetime(iso_string=True),
                    )
                ],
                instructions='Be helpful.',
            ),
            ModelResponse(
                parts=[
                    TextPart(
                        content="""\
Why don't scientists trust atoms? \n\

Because they make up everything!
"""
                    )
                ],
                usage=Usage(requests=1, request_tokens=8, response_tokens=17, total_tokens=25, details={}),
                model_name='google/gemini-2.0-flash-exp:free',
                timestamp=IsDatetime(iso_string=True),
                vendor_details={'provider': 'Google'},
                vendor_id=IsStr(),
            ),
        ]
    )


def test_openrouter_provider_model_profile(mocker: MockerFixture):
    provider = OpenRouterProvider(api_key='api-key')