Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 61 additions & 37 deletions scripts/generate_ergonomic_coercion.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
"ListCreativesResponse",
"ListCreativeFormatsResponse",
"CreateMediaBuyResponse1",
"UpdateMediaBuyResponse1",
"GetMediaBuyDeliveryResponse",
]

Expand Down Expand Up @@ -227,25 +228,45 @@ def get_symbol_name(cls) -> str:
def generate_code() -> str:
"""Generate the _ergonomic.py module content."""
# Import all the types we need to analyze
from pydantic import BaseModel as _PydBaseModel

from adcp.types.generated_poc.creative.list_creatives_request import (
ListCreativesRequest,
Sort,
)
from adcp.types.generated_poc.creative.list_creatives_response import ListCreativesResponse
from adcp.types.generated_poc.media_buy import (
create_media_buy_response as _cmbr_module,
)
from adcp.types.generated_poc.media_buy import (
update_media_buy_response as _umbr_module,
)
from adcp.types.generated_poc.media_buy.create_media_buy_request import CreateMediaBuyRequest
from adcp.types.generated_poc.media_buy.get_media_buy_delivery_response import (
GetMediaBuyDeliveryResponse,
)
from adcp.types.generated_poc.media_buy.get_products_request import (
GetProductsRequest,
)
from adcp.types.generated_poc.media_buy.get_products_response import GetProductsResponse
from adcp.types.generated_poc.media_buy.list_creative_formats_request import (
ListCreativeFormatsRequest,
)
from adcp.types.generated_poc.media_buy.list_creative_formats_response import (
ListCreativeFormatsResponse,
)
from adcp.types.generated_poc.media_buy.package_request import PackageRequest
from adcp.types.generated_poc.media_buy.package_update import PackageUpdate

# Resolve the CreateMediaBuyResponse success variant. Different
# Resolve success variants. Different
# datamodel-codegen versions emit the variants under shifting names:
# sometimes `CreateMediaBuyResponse1`/`...2`, sometimes `CreateMediaBuyResponse`
# sometimes `FooResponse1`/`...2`, sometimes `FooResponse`
# (success, unnumbered) + `CreateMediaBuyResponseN`. Find the success
# variant by scanning the module for a pydantic model whose fields include
# the success-only `media_buy_id` key.
from pydantic import BaseModel as _PydBaseModel

def _find_success_variant() -> type[_PydBaseModel] | None:
for name in dir(_cmbr_module):
obj = getattr(_cmbr_module, name)
def _find_media_buy_success_variant(module: Any) -> type[_PydBaseModel] | None:
for name in dir(module):
obj = getattr(module, name)
if (
isinstance(obj, type)
and issubclass(obj, _PydBaseModel)
Expand All @@ -254,28 +275,8 @@ def _find_success_variant() -> type[_PydBaseModel] | None:
return obj
return None

CreateMediaBuyResponse1 = _find_success_variant()
from adcp.types.generated_poc.media_buy.get_products_request import (
Field1 as GetProductsField,
GetProductsRequest,
)

# Response types
from adcp.types.generated_poc.media_buy.get_products_response import GetProductsResponse
from adcp.types.generated_poc.media_buy.list_creative_formats_request import (
ListCreativeFormatsRequest,
)
from adcp.types.generated_poc.media_buy.list_creative_formats_response import (
ListCreativeFormatsResponse,
)
from adcp.types.generated_poc.creative.list_creatives_request import (
Field1 as ListCreativesField,
ListCreativesRequest,
Sort,
)
from adcp.types.generated_poc.creative.list_creatives_response import ListCreativesResponse
from adcp.types.generated_poc.media_buy.package_request import PackageRequest
from adcp.types.generated_poc.media_buy.package_update import PackageUpdate
create_media_buy_response1 = _find_media_buy_success_variant(_cmbr_module)
update_media_buy_response1 = _find_media_buy_success_variant(_umbr_module)

# Map names to classes
request_classes = {
Expand All @@ -291,8 +292,10 @@ def _find_success_variant() -> type[_PydBaseModel] | None:
"ListCreativeFormatsResponse": ListCreativeFormatsResponse,
"GetMediaBuyDeliveryResponse": GetMediaBuyDeliveryResponse,
}
if CreateMediaBuyResponse1 is not None:
response_classes["CreateMediaBuyResponse1"] = CreateMediaBuyResponse1
if create_media_buy_response1 is not None:
response_classes["CreateMediaBuyResponse1"] = create_media_buy_response1
if update_media_buy_response1 is not None:
response_classes["UpdateMediaBuyResponse1"] = update_media_buy_response1

nested_classes = {
"Sort": Sort,
Expand Down Expand Up @@ -422,14 +425,22 @@ def _find_success_variant() -> type[_PydBaseModel] | None:
# Add response type imports. CreateMediaBuyResponse1's numeric suffix can
# shift between codegen versions; in schema versions that collapse the
# response envelope, there may be no success-specific response variant.
if CreateMediaBuyResponse1 is not None:
cmbr_name = CreateMediaBuyResponse1.__name__
if create_media_buy_response1 is not None:
cmbr_name = create_media_buy_response1.__name__
lines.append("from adcp.types.generated_poc.media_buy.create_media_buy_response import (")
if cmbr_name == "CreateMediaBuyResponse1":
lines.append(" CreateMediaBuyResponse1,")
else:
lines.append(f" {cmbr_name} as CreateMediaBuyResponse1,")
lines.append(")")
if update_media_buy_response1 is not None:
umbr_name = update_media_buy_response1.__name__
lines.append("from adcp.types.generated_poc.media_buy.update_media_buy_response import (")
if umbr_name == "UpdateMediaBuyResponse1":
lines.append(" UpdateMediaBuyResponse1,")
else:
lines.append(f" {umbr_name} as UpdateMediaBuyResponse1,")
lines.append(")")
lines.append("from adcp.types.generated_poc.media_buy.get_media_buy_delivery_response import (")
lines.append(" GetMediaBuyDeliveryResponse,")
lines.append(" MediaBuyDelivery,")
Expand Down Expand Up @@ -461,6 +472,7 @@ def _find_success_variant() -> type[_PydBaseModel] | None:
"PackageRequest",
"PackageUpdate",
"CreateMediaBuyResponse1",
"UpdateMediaBuyResponse1",
"GetMediaBuyDeliveryResponse",
"MediaBuyDelivery",
"NotificationType",
Expand Down Expand Up @@ -500,6 +512,7 @@ def _find_success_variant() -> type[_PydBaseModel] | None:
"ListCreativesResponse",
"ListCreativeFormatsResponse",
"CreateMediaBuyResponse1",
"UpdateMediaBuyResponse1",
"GetMediaBuyDeliveryResponse",
]

Expand All @@ -523,8 +536,16 @@ def _find_success_variant() -> type[_PydBaseModel] | None:
elif c["type"] == "ext":
field_comments.append(f'{c["field"]}: ExtensionObject | dict | None')
elif c["type"] == "subclass_list":
from collections.abc import Sequence as AbcSequence

field_info = cls.model_fields[c["field"]]
ann = field_info.annotation
base_ann = get_base_type(ann)
is_seq = get_origin(base_ann if base_ann is not None else ann) is AbcSequence
container = "Sequence" if is_seq else "list"
field_comments.append(
f'{c["field"]}: list[{c["target_class"].__name__}] (accepts subclass instances)'
f'{c["field"]}: {container}[{c["target_class"].__name__}] '
"(accepts subclass instances)"
)

lines.append(f" # Apply coercion to {type_name}")
Expand All @@ -540,7 +561,8 @@ def _find_success_variant() -> type[_PydBaseModel] | None:
lines.append(f" {type_name},")
lines.append(f' "{field}",')
lines.append(
f" Annotated[{target} | None, BeforeValidator(coerce_to_enum({target}))],"
f" Annotated[{target} | None, "
f"BeforeValidator(coerce_to_enum({target}))],"
)
lines.append(" )")
elif c["type"] == "enum_list":
Expand All @@ -558,15 +580,17 @@ def _find_success_variant() -> type[_PydBaseModel] | None:
lines.append(f" {type_name},")
lines.append(f' "{field}",')
lines.append(
" Annotated[ContextObject | None, BeforeValidator(coerce_to_model(ContextObject))],"
" Annotated[ContextObject | None, "
"BeforeValidator(coerce_to_model(ContextObject))],"
)
lines.append(" )")
elif c["type"] == "ext":
lines.append(" _patch_field_annotation(")
lines.append(f" {type_name},")
lines.append(f' "{field}",')
lines.append(
" Annotated[ExtensionObject | None, BeforeValidator(coerce_to_model(ExtensionObject))],"
" Annotated[ExtensionObject | None, "
"BeforeValidator(coerce_to_model(ExtensionObject))],"
)
lines.append(" )")
elif c["type"] == "subclass_list":
Expand Down
38 changes: 34 additions & 4 deletions src/adcp/types/_ergonomic.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@
from adcp.types.generated_poc.media_buy.create_media_buy_response import (
CreateMediaBuyResponse1,
)
from adcp.types.generated_poc.media_buy.update_media_buy_response import (
UpdateMediaBuyResponse1,
)
from adcp.types.generated_poc.media_buy.get_media_buy_delivery_response import (
GetMediaBuyDeliveryResponse,
MediaBuyDelivery,
Expand Down Expand Up @@ -238,7 +241,7 @@ def _apply_coercion() -> None:
# Apply coercion to PackageRequest
# - pacing: Pacing | str | None
# - creative_assignments: list[CreativeAssignment] (accepts subclass instances)
# - creatives: list[CreativeAsset] (accepts subclass instances)
# - creatives: Sequence[CreativeAsset] (accepts subclass instances)
# - context: ContextObject | dict | None
# - ext: ExtensionObject | dict | None
_patch_field_annotation(
Expand Down Expand Up @@ -275,7 +278,7 @@ def _apply_coercion() -> None:
PackageRequest.model_rebuild(force=True)

# Apply coercion to CreateMediaBuyRequest
# - packages: list[PackageRequest] (accepts subclass instances)
# - packages: Sequence[PackageRequest] (accepts subclass instances)
# - advertiser_industry: AdvertiserIndustry | str | None
# - context: ContextObject | dict | None
# - ext: ExtensionObject | dict | None
Expand Down Expand Up @@ -391,7 +394,7 @@ def _apply_coercion() -> None:
# Apply coercion to ListCreativesResponse
# - context: ContextObject | dict | None
# - status: TaskStatus | str | None
# - creatives: list[Creative] (accepts subclass instances)
# - creatives: Sequence[Creative] (accepts subclass instances)
# - errors: list[Error] (accepts subclass instances)
# - ext: ExtensionObject | dict | None
_patch_field_annotation(
Expand Down Expand Up @@ -499,11 +502,38 @@ def _apply_coercion() -> None:
)
CreateMediaBuyResponse1.model_rebuild(force=True)

# Apply coercion to UpdateMediaBuyResponse1
# - affected_packages: Sequence[Package] (accepts subclass instances)
# - packages: list[Package] (accepts subclass instances)
# - media_buy_status: MediaBuyStatus | str | None
_patch_field_annotation(
UpdateMediaBuyResponse1,
"affected_packages",
Annotated[
Sequence[Package] | None,
BeforeValidator(coerce_subclass_list(Package)),
],
)
_patch_field_annotation(
UpdateMediaBuyResponse1,
"packages",
Annotated[
list[Package] | None,
BeforeValidator(coerce_subclass_list(Package)),
],
)
_patch_field_annotation(
UpdateMediaBuyResponse1,
"media_buy_status",
Annotated[MediaBuyStatus | None, BeforeValidator(coerce_to_enum(MediaBuyStatus))],
)
UpdateMediaBuyResponse1.model_rebuild(force=True)

# Apply coercion to GetMediaBuyDeliveryResponse
# - context: ContextObject | dict | None
# - status: TaskStatus | str | None
# - notification_type: NotificationType | str | None
# - media_buy_deliveries: list[MediaBuyDelivery] (accepts subclass instances)
# - media_buy_deliveries: Sequence[MediaBuyDelivery] (accepts subclass instances)
# - errors: list[Error] (accepts subclass instances)
# - ext: ExtensionObject | dict | None
_patch_field_annotation(
Expand Down
43 changes: 43 additions & 0 deletions tests/test_type_coercion.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,49 @@ class ExtendedPackage(Package):
# Internal field is preserved at runtime
assert response.packages[0].campaign_id == "campaign-456" # type: ignore[attr-defined]

def test_update_media_buy_response_accepts_package_subclasses(self):
"""UpdateMediaBuySuccessResponse package fields accept Package subclasses."""
from pydantic import Field

from adcp.types import Package, UpdateMediaBuySuccessResponse

class ExtendedPackage(Package):
"""Extended with internal tracking fields."""

campaign_id: str | None = Field(None, exclude=True)

package = ExtendedPackage(
package_id="pkg1",
campaign_id="campaign-456",
)

response = UpdateMediaBuySuccessResponse(
media_buy_id="mb1",
buyer_ref="buyer-ref",
packages=[package], # type: ignore[list-item] # Ignoring due to Python list covariance limitation
affected_packages=[package],
)

assert response.packages is not None
assert response.affected_packages is not None
assert response.packages[0].package_id == "pkg1"
assert response.affected_packages[0].package_id == "pkg1"
assert response.packages[0].campaign_id == "campaign-456" # type: ignore[attr-defined]
assert (
response.affected_packages[0].campaign_id == "campaign-456" # type: ignore[attr-defined]
)

def test_update_media_buy_response_accepts_media_buy_status_string(self):
"""UpdateMediaBuySuccessResponse.media_buy_status accepts strings."""
from adcp.types import MediaBuyStatus, UpdateMediaBuySuccessResponse

response = UpdateMediaBuySuccessResponse(
media_buy_id="mb1",
media_buy_status="active",
)

assert response.media_buy_status == MediaBuyStatus.active

def test_get_media_buy_delivery_response_accepts_dict_context(self):
"""GetMediaBuyDeliveryResponse.context accepts dict."""
from datetime import datetime, timezone
Expand Down
Loading