diff --git a/stagehand/config.py b/stagehand/config.py index a577230..d9f4b9b 100644 --- a/stagehand/config.py +++ b/stagehand/config.py @@ -1,8 +1,8 @@ import os -from typing import Any, Callable, Literal, Optional +from typing import Any, Callable, Literal, Optional, Union from browserbase.types import SessionCreateParams as BrowserbaseSessionCreateParams -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, field_validator from stagehand.schemas import AvailableModel @@ -65,7 +65,7 @@ class StagehandConfig(BaseModel): alias="domSettleTimeoutMs", description="Timeout for DOM to settle (in ms)", ) - browserbase_session_create_params: Optional[BrowserbaseSessionCreateParams] = Field( + browserbase_session_create_params: Optional[Union[BrowserbaseSessionCreateParams, dict[str, Any]]] = Field( None, alias="browserbaseSessionCreateParams", description="Browserbase session create params", @@ -111,6 +111,17 @@ class StagehandConfig(BaseModel): ) model_config = ConfigDict(populate_by_name=True) + + @field_validator('browserbase_session_create_params', mode='before') + @classmethod + def validate_browserbase_params(cls, v, info): + """Validate and convert browserbase session create params.""" + if isinstance(v, dict) and 'project_id' not in v: + values = info.data + project_id = values.get('project_id') or values.get('projectId') + if project_id: + v = {**v, 'project_id': project_id} + return v def with_overrides(self, **overrides) -> "StagehandConfig": """