Skip to content

Commit 216b8a5

Browse files
fix: no blank children names (#666)
Co-authored-by: Andrew Truong <[email protected]>
1 parent 0cd86d2 commit 216b8a5

File tree

4 files changed

+63
-4
lines changed

4 files changed

+63
-4
lines changed

polyfactory/factories/pydantic_factory.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,14 +181,21 @@ def from_field_info(
181181
if is_union(annotation):
182182
constraints = {}
183183
children = []
184+
185+
# create a child for each of the possible union values
184186
for arg in get_args(annotation):
187+
# don't add the NoneType in an optional to the list of children
185188
if arg is NoneType:
186189
continue
187190
child_field_info = FieldInfo.from_annotation(arg)
188191
merged_field_info = FieldInfo.merge_field_infos(field_info, child_field_info)
192+
189193
children.append(
194+
# recurse for each element of the union
190195
cls.from_field_info(
191-
field_name="",
196+
# this is a fake field name, but it makes it possible to debug which type variant
197+
# is the source of an exception downstream
198+
field_name=field_name,
192199
field_info=merged_field_info,
193200
use_alias=use_alias,
194201
),

polyfactory/field_meta.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,8 @@ def from_type(
144144
metadata = cls.get_constraints_metadata(annotation)
145145
constraints = cls.parse_constraints(metadata)
146146

147+
# annotations can take many forms: Optional, an Annotated type, or anything with __args__
148+
# in order to normalize the annotation, we need to unwrap the annotation.
147149
if not annotated and (origin := get_origin(annotation)) and origin in TYPE_MAPPING:
148150
container = TYPE_MAPPING[origin]
149151
annotation = container[get_args(annotation)] # type: ignore[index]

polyfactory/utils/helpers.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def unwrap_new_type(annotation: Any) -> Any:
3838

3939
@deprecated("v2.21.0")
4040
def unwrap_union(annotation: Any, random: Random) -> Any:
41-
"""Unwraps union types - recursively.
41+
"""Unwraps union types recursively and picks a random type from each union.
4242
4343
:param annotation: A type annotation, possibly a type union.
4444
:param random: An instance of random.Random.
@@ -86,7 +86,12 @@ def unwrap_annotation(annotation: Any, random: Random | None = None) -> Any:
8686

8787

8888
def flatten_annotation(annotation: Any) -> list[Any]:
89-
"""Flattens an annotation.
89+
"""Flattens an annotation into an array of possible types. For example:
90+
91+
* Union[str, int] → [str, int]
92+
* Optional[str] → [str, None]
93+
* Union[str, Optional[int]] → [str, int, None]
94+
* NewType('UserId', int) → [int]
9095
9196
:param annotation: A type annotation.
9297

tests/test_pydantic_factory.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from decimal import Decimal
77
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
88
from pathlib import Path
9-
from typing import Callable, Dict, FrozenSet, List, Literal, Optional, Sequence, Set, Tuple, Type, Union
9+
from typing import Any, Callable, Dict, FrozenSet, List, Literal, Optional, Sequence, Set, Tuple, Type, Union
1010
from uuid import UUID
1111

1212
import pytest
@@ -64,8 +64,10 @@
6464
validator,
6565
)
6666

67+
from polyfactory.exceptions import ParameterException
6768
from polyfactory.factories import DataclassFactory
6869
from polyfactory.factories.pydantic_factory import _IS_PYDANTIC_V1, ModelFactory
70+
from polyfactory.field_meta import FieldMeta
6971
from tests.models import Person, PetFactory
7072

7173
IS_PYDANTIC_V1 = _IS_PYDANTIC_V1
@@ -634,6 +636,49 @@ class A(BaseModel):
634636
assert AFactory.build()
635637

636638

639+
@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires modern union types")
640+
@pytest.mark.skipif(IS_PYDANTIC_V1, reason="pydantic 2 only test")
641+
def test_optional_custom_type() -> None:
642+
from pydantic_core import core_schema
643+
644+
class CustomType:
645+
def __init__(self, _: Any) -> None:
646+
pass
647+
648+
def __get_pydantic_core_schema__(self, _: Any) -> core_schema.StringSchema:
649+
# for pydantic to stop complaining
650+
return core_schema.str_schema()
651+
652+
class OptionalFormOne(BaseModel):
653+
optional_custom_type: Optional[CustomType]
654+
655+
@classmethod
656+
def should_set_none_value(cls, field_meta: FieldMeta) -> bool:
657+
return False
658+
659+
class OptionalFormOneFactory(ModelFactory[OptionalFormOne]):
660+
@classmethod
661+
def should_set_none_value(cls, field_meta: FieldMeta) -> bool:
662+
return False
663+
664+
class OptionalFormTwo(BaseModel):
665+
# this is represented differently than `Optional[None]` internally
666+
optional_custom_type_second_form: CustomType | None
667+
668+
class OptionalFormTwoFactory(ModelFactory[OptionalFormTwo]):
669+
@classmethod
670+
def should_set_none_value(cls, field_meta: FieldMeta) -> bool:
671+
return False
672+
673+
# ensure the custom type field name and variant is in the error message
674+
675+
with pytest.raises(ParameterException, match=r"optional_custom_type"):
676+
OptionalFormOneFactory.build()
677+
678+
with pytest.raises(ParameterException, match=r"optional_custom_type_second_form"):
679+
OptionalFormTwoFactory.build()
680+
681+
637682
def test_collection_unions_with_models() -> None:
638683
class A(BaseModel):
639684
a: int

0 commit comments

Comments
 (0)