8000 chore(internal): add support for TypeAliasType (#786) · anthropics/anthropic-sdk-python@5c632ea · GitHub
[go: up one dir, main page]

Skip to content

Commit 5c632ea

Browse files
chore(internal): add support for TypeAliasType (#786)
1 parent 2b0c039 commit 5c632ea

File tree

8 files changed

+76
-19
lines changed

8 files changed

+76
-19
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ authors = [
1010
dependencies = [
1111
"httpx>=0.23.0, <1",
1212
"pydantic>=1.9.0, <3",
13-
"typing-extensions>=4.7, <5",
13+
"typing-extensions>=4.10, <5",
1414
"anyio>=3.5.0, <5",
1515
"distro>=1.7.0, <2",
1616
"sniffio",

src/anthropic/_legacy_response.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import pydantic
2525

2626
from ._types import NoneType
27-
from ._utils import is_given, extract_type_arg, is_annotated_type
27+
from ._utils import is_given, extract_type_arg, is_annotated_type, is_type_alias_type
2828
from ._models import BaseModel, is_basemodel, add_request_id
2929
from ._constants import RAW_RESPONSE_HEADER
3030
from ._streaming import Stream, AsyncStream, is_stream_class_type, extract_stream_chunk_type
@@ -197,9 +197,15 @@ def elapsed(self) -> datetime.timedelta:
197197
return self.http_response.elapsed
198198

199199
def _parse(self, *, to: type[_T] | None = None) -> R | _T:
200+
cast_to = to if to is not None else self._cast_to
201+
202+
# unwrap `TypeAlias('Name', T)` -> `T`
203+
if is_type_alias_type(cast_to):
204+
cast_to = cast_to.__value__ # type: ignore[unreachable]
205+
200206
# unwrap `Annotated[T, ...]` -> `T`
201-
if to and is_annotated_type(to):
202-
to = extract_type_arg(to, 0)
207+
if cast_to and is_annotated_type(cast_to):
208+
cast_to = extract_type_arg(cast_to, 0)
203209

204210
cast_to = to if to is not None else self._cast_to
205211
origin = get_origin(cast_to) or cast_to
@@ -259,16 +265,12 @@ def _parse(self, *, to: type[_T] | None = None) -> R | _T:
259265
return cast(
260266
R,
261267
stream_cls(
262-
cast_to=self._cast_to,
268+
cast_to=cast_to,
263269
response=self.http_response,
264270
client=cast(Any, self._client),
265271
),
266272
)
267273

268-
# unwrap `Annotated[T, ...]` -> `T`
269-
if is_annotated_type(cast_to):
270-
cast_to = extract_type_arg(cast_to, 0)
271-
272274
if cast_to is NoneType:
273275
return cast(R, None)
274276

src/anthropic/_models.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
strip_not_given,
4747
extract_type_arg,
4848
is_annotated_type,
49+
is_type_alias_type,
4950
strip_annotated_type,
5051
)
5152
from ._compat import (
@@ -444,6 +445,8 @@ def construct_type(*, value: object, type_: object) -> object:
444445
# we allow `object` as the input type because otherwise, passing things like
445446
# `Literal['value']` will be reported as a type error by type checkers
446447
type_ = cast("type[object]", type_)
448+
if is_type_alias_type(type_):
449+
type_ = type_.__value__ # type: ignore[unreachable]
447450

448451
# unwrap `Annotated[T, ...]` -> `T`
449452
if is_annotated_type(type_):

src/anthropic/_response.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
import pydantic
2626

2727
from ._types import NoneType
28-
from ._utils import is_given, extract_type_arg, is_annotated_type, extract_type_var_from_base
28+
from ._utils import is_given, extract_type_arg, is_annotated_type, is_type_alias_type, extract_type_var_from_base
2929
from ._models import BaseModel, is_basemodel, add_request_id
3030
from ._constants import RAW_RESPONSE_HEADER, OVERRIDE_CAST_TO_HEADER
3131
from ._streaming import Stream, AsyncStream, is_stream_class_type, extract_stream_chunk_type
@@ -127,9 +127,15 @@ def __repr__(self) -> str:
127127
)
128128

129129
def _parse(self, *, to: type[_T] | None = None) -> R | _T:
130+
cast_to = to if to is not None else self._cast_to
131+
132+
# unwrap `TypeAlias('Name', T)` -> `T`
133+
if is_type_alias_type(cast_to):
134+
cast_to = cast_to.__value__ # type: ignore[unreachable]
135+
130136
# unwrap `Annotated[T, ...]` -> `T`
131-
if to and is_annotated_type(to):
132-
to = extract_type_arg(to, 0)
137+
if cast_to and is_annotated_type(cast_to):
138+
cast_to = extract_type_arg(cast_to, 0)
133139

134140
cast_to = to if to is not None else self._cast_to
135141
origin = get_origin(cast_to) or cast_to
@@ -189,16 +195,12 @@ def _parse(self, *, to: type[_T] | None = None) -> R | _T:
189195
return cast(
190196
R,
191197
stream_cls(
192-
cast_to=self._cast_to,
198+
cast_to=cast_to,
193199
response=self.http_response,
194200
client=cast(Any, self._client),
195201
),
196202
)
197203

198-
# unwrap `Annotated[T, ...]` -> `T`
199-
if is_annotated_type(cast_to):
200-
cast_to = extract_type_arg(cast_to, 0)
201-
202204
if cast_to is NoneType:
203205
return cast(R, None)
204206

src/anthropic/_utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
is_iterable_type as is_iterable_type,
4040
is_required_type as is_required_type,
4141
is_annotated_type as is_annotated_type,
42+
is_type_alias_type as is_type_alias_type,
4243
strip_annotated_type as strip_annotated_type,
4344
extract_type_var_from_base as extract_type_var_from_base,
4445
)

src/anthropic/_utils/_typing.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,17 @@
11
from __future__ import annotations
22

3+
import sys
4+
import typing
5+
import typing_extensions
36
from typing import Any, TypeVar, Iterable, cast
47
from collections import abc as _c_abc
5-
from typing_extensions import Required, Annotated, get_args, get_origin
8+
from typing_extensions import (
9+
TypeIs,
10+
Required,
11+
Annotated,
12+
get_args,
13+
get_origin,
14+
)
615

716
from .._types import InheritsGeneric
817
from .._compat import is_union as _is_union
@@ -36,6 +45,26 @@ def is_typevar(typ: type) -> bool:
3645
return type(typ) == TypeVar # type: ignore
3746

3847

48+
_TYPE_ALIAS_TYPES: tuple[type[typing_extensions.TypeAliasType], ...] = (typing_extensions.TypeAliasType,)
49+
if sys.version_info >= (3, 12):
50+
_TYPE_ALIAS_TYPES = (*_TYPE_ALIAS_TYPES, typing.TypeAliasType)
51+
52+
53+
def is_type_alias_type(tp: Any, /) -> TypeIs[typing_extensions.TypeAliasType]:
54+
"""Return whether the provided argument is an instance of `TypeAliasType`.
55+
56+
```python
57+
type Int = int
58+
is_type_alias_type(Int)
59+
# > True
60+
Str = TypeAliasType("Str", str)
61+
is_type_alias_type(Str)
62+
# > True
63+
```
64+
"""
65+
return isinstance(tp, _TYPE_ALIAS_TYPES)
66+
67+
3968
# Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]]
4069
def strip_annotated_type(typ: type) -> type:
4170
if is_required_type(typ) or is_annotated_type(typ):

tests/test_models.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
from typing import Any, Dict, List, Union, Optional, cast
33
from datetime import datetime, timezone
4-
from typing_extensions import Literal, Annotated
4+
from typing_extensions import Literal, Annotated, TypeAliasType
55

66
import pytest
77
import pydantic
@@ -828,3 +828,19 @@ class B(BaseModel):
828828
# if the discriminator details object stays the same between invocations then
829829
# we hit the cache
830830
assert UnionType.__discriminator__ is discriminator
831+
832+
833+
@pytest.mark.skipif(not PYDANTIC_V2, reason="TypeAliasType is not supported in Pydantic v1")
834+
def test_type_alias_type() -> None:
835+
Alias = TypeAliasType("Alias", str)
836+
837+
class Model(BaseModel):
838+
alias: Alias
839+
union: Union[int, Alias]
840+
841+
m = construct_type(value={"alias": "foo", "union": "bar"}, type_=Model)
842+
assert isinstance(m, Model)
843+
assert isinstance(m.alias, str)
844+
assert m.alias == "foo"
845+
assert isinstance(m.union, str)
846+
assert m.union == "bar"

tests/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
is_union_type,
1717
extract_type_arg,
1818
is_annotated_type,
19+
is_type_alias_type,
1920
)
2021
from anthropic._compat import PYDANTIC_V2, field_outer_type, get_model_fields
2122
from anthropic._models import BaseModel
@@ -51,6 +52,9 @@ def assert_matches_type(
5152
path: list[str],
5253
allow_none: bool = False,
5354
) -> None:
55+
if is_type_alias_type(type_):
56+
type_ = type_.__value__
57+
5458
# unwrap `Annotated[T, ...]` -> `T`
5559
if is_annotated_type(type_):
5660
type_ = extract_type_arg(type_, 0)

0 commit comments

Comments
 (0)
0