Skip to content

Commit c2845bb

Browse files
author
Namrata Madan
committed
feat: support pipeline versioning
1 parent 23c3840 commit c2845bb

File tree

4 files changed

+165
-16
lines changed

4 files changed

+165
-16
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ classifiers = [
3232
]
3333
dependencies = [
3434
"attrs>=24,<26",
35-
"boto3>=1.35.36,<2.0",
35+
"boto3>=1.39.5,<2.0",
3636
"cloudpickle>=2.2.1",
3737
"docker",
3838
"fastapi",

src/sagemaker/workflow/pipeline.py

Lines changed: 47 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ def __init__(
124124
self._event_bridge_scheduler_helper = EventBridgeSchedulerHelper(
125125
self.sagemaker_session.boto_session.client("scheduler"),
126126
)
127+
self.latest_pipeline_version_id = None
127128

128129
def create(
129130
self,
@@ -166,7 +167,9 @@ def create(
166167
kwargs,
167168
Tags=tags,
168169
)
169-
return self.sagemaker_session.sagemaker_client.create_pipeline(**kwargs)
170+
response = self.sagemaker_session.sagemaker_client.create_pipeline(**kwargs)
171+
self.latest_pipeline_version_id = 1
172+
return response
170173

171174
def _create_args(
172175
self, role_arn: str, description: str, parallelism_config: ParallelismConfiguration
@@ -214,15 +217,21 @@ def _create_args(
214217
)
215218
return kwargs
216219

217-
def describe(self) -> Dict[str, Any]:
220+
def describe(self, pipeline_version_id: int = None) -> Dict[str, Any]:
218221
"""Describes a Pipeline in the Workflow service.
219222
223+
Args:
224+
pipeline_version_id (Optional[str]): version ID of the pipeline to describe.
225+
220226
Returns:
221227
Response dict from the service. See `boto3 client documentation
222228
<https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/\
223229
sagemaker.html#SageMaker.Client.describe_pipeline>`_
224230
"""
225-
return self.sagemaker_session.sagemaker_client.describe_pipeline(PipelineName=self.name)
231+
kwargs = dict(PipelineName=self.name)
232+
if pipeline_version_id:
233+
kwargs["PipelineVersionId"] = pipeline_version_id
234+
return self.sagemaker_session.sagemaker_client.describe_pipeline(**kwargs)
226235

227236
def update(
228237
self,
@@ -257,7 +266,10 @@ def update(
257266
return self.sagemaker_session.sagemaker_client.update_pipeline(self, description)
258267

259268
kwargs = self._create_args(role_arn, description, parallelism_config)
260-
return self.sagemaker_session.sagemaker_client.update_pipeline(**kwargs)
269+
response = self.sagemaker_session.sagemaker_client.update_pipeline(**kwargs)
270+
if "PipelineVersionId" in response:
271+
self.latest_pipeline_version_id = response["PipelineVersionId"]
272+
return response
261273

262274
def upsert(
263275
self,
@@ -332,6 +344,7 @@ def start(
332344
execution_description: str = None,
333345
parallelism_config: ParallelismConfiguration = None,
334346
selective_execution_config: SelectiveExecutionConfig = None,
347+
pipeline_version_id: int = None,
335348
):
336349
"""Starts a Pipeline execution in the Workflow service.
337350
@@ -345,6 +358,8 @@ def start(
345358
over the parallelism configuration of the parent pipeline.
346359
selective_execution_config (Optional[SelectiveExecutionConfig]): The configuration for
347360
selective step execution.
361+
pipeline_version_id (Optional[str]): version ID of the pipeline to start the execution from. If not
362+
specified, uses the latest version ID.
348363
349364
Returns:
350365
A `_PipelineExecution` instance, if successful.
@@ -366,6 +381,7 @@ def start(
366381
PipelineExecutionDisplayName=execution_display_name,
367382
ParallelismConfiguration=parallelism_config,
368383
SelectiveExecutionConfig=selective_execution_config,
384+
PipelineVersionId=pipeline_version_id,
369385
)
370386
if self.sagemaker_session.local_mode:
371387
update_args(kwargs, PipelineParameters=parameters)
@@ -461,6 +477,32 @@ def list_executions(
461477
if key in response
462478
}
463479

480+
def list_pipeline_versions(
481+
self, sort_order: str = None, max_results: int = None, next_token: str = None
482+
) -> str:
483+
"""Lists a pipeline's versions.
484+
485+
Args:
486+
sort_order (str): The sort order for results (Ascending/Descending).
487+
max_results (int): The maximum number of pipeline executions to return in the response.
488+
next_token (str): If the result of the previous `ListPipelineExecutions` request was
489+
truncated, the response includes a `NextToken`. To retrieve the next set of pipeline
490+
executions, use the token in the next request.
491+
492+
Returns:
493+
List of Pipeline Version Summaries. See
494+
boto3 client list_pipeline_versions
495+
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker/client/list_pipeline_versions.html#
496+
"""
497+
kwargs = dict(PipelineName=self.name)
498+
update_args(
499+
kwargs,
500+
SortOrder=sort_order,
501+
NextToken=next_token,
502+
MaxResults=max_results,
503+
)
504+
return self.sagemaker_session.sagemaker_client.list_pipeline_versions(**kwargs)
505+
464506
def _get_latest_execution_arn(self):
465507
"""Retrieves the latest execution of this pipeline"""
466508
response = self.list_executions(
@@ -855,7 +897,7 @@ def describe(self):
855897
sagemaker.html#SageMaker.Client.describe_pipeline_execution>`_.
856898
"""
857899
return self.sagemaker_session.sagemaker_client.describe_pipeline_execution(
858-
PipelineExecutionArn=self.arn,
900+
PipelineExecutionArn=self.arn
859901
)
860902

861903
def list_steps(self):

tests/integ/sagemaker/workflow/test_workflow.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,7 @@ def test_three_step_definition(
312312
rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
313313
create_arn,
314314
)
315+
assert pipeline.latest_pipeline_version_id == 1
315316
finally:
316317
try:
317318
pipeline.delete()
@@ -937,7 +938,6 @@ def test_large_pipeline(sagemaker_session_for_pipeline, role, pipeline_name, reg
937938
rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
938939
create_arn,
939940
)
940-
response = pipeline.describe()
941941
assert len(json.loads(pipeline.describe()["PipelineDefinition"])["Steps"]) == 2000
942942

943943
pipeline.parameters = [ParameterInteger(name="InstanceCount", default_value=1)]
@@ -1387,3 +1387,56 @@ def test_caching_behavior(
13871387
except Exception:
13881388
os.remove(script_dir + "/dummy_script.py")
13891389
pass
1390+
1391+
1392+
def test_pipeline_versioning(pipeline_session, role, pipeline_name, script_dir):
1393+
sklearn_train = SKLearn(
1394+
framework_version="0.20.0",
1395+
entry_point=os.path.join(script_dir, "train.py"),
1396+
instance_type="ml.m5.xlarge",
1397+
sagemaker_session=pipeline_session,
1398+
role=role,
1399+
)
1400+
1401+
step1 = TrainingStep(
1402+
name="my-train-1",
1403+
display_name="TrainingStep",
1404+
description="description for Training step",
1405+
step_args=sklearn_train.fit(),
1406+
)
1407+
1408+
step2 = TrainingStep(
1409+
name="my-train-2",
1410+
display_name="TrainingStep",
1411+
description="description for Training step",
1412+
step_args=sklearn_train.fit(),
1413+
)
1414+
pipeline = Pipeline(
1415+
name=pipeline_name,
1416+
steps=[step1],
1417+
sagemaker_session=pipeline_session,
1418+
)
1419+
1420+
try:
1421+
pipeline.create(role)
1422+
1423+
assert pipeline.latest_pipeline_version_id == 1
1424+
1425+
describe_response = pipeline.describe(pipeline_version_id=1)
1426+
assert len(json.loads(describe_response["PipelineDefinition"])["Steps"]) == 1
1427+
1428+
pipeline.steps.append(step2)
1429+
pipeline.upsert(role)
1430+
1431+
assert pipeline.latest_pipeline_version_id == 2
1432+
1433+
describe_response = pipeline.describe(pipeline_version_id=2)
1434+
assert len(json.loads(describe_response["PipelineDefinition"])["Steps"]) == 2
1435+
1436+
assert len(pipeline.list_pipeline_versions()["PipelineVersionSummaries"]) == 2
1437+
1438+
finally:
1439+
try:
1440+
pipeline.delete()
1441+
except Exception:
1442+
pass

tests/unit/sagemaker/workflow/test_pipeline.py

Lines changed: 63 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -92,12 +92,23 @@ def test_pipeline_create_and_update_with_config_injection(sagemaker_session_mock
9292
PipelineDefinition=pipeline.definition(),
9393
RoleArn=pipeline_role_arn,
9494
)
95+
96+
sagemaker_session_mock.sagemaker_client.update_pipeline.return_value = {
97+
"PipelineArn": "pipeline-arn",
98+
"PipelineVersionId": 2,
99+
}
100+
95101
pipeline.update()
96102
sagemaker_session_mock.sagemaker_client.update_pipeline.assert_called_with(
97103
PipelineName="MyPipeline",
98104
PipelineDefinition=pipeline.definition(),
99105
RoleArn=pipeline_role_arn,
100106
)
107+
108+
sagemaker_session_mock.sagemaker_client.update_pipeline.return_value = {
109+
"PipelineArn": "pipeline-arn",
110+
"PipelineVersionId": 3,
111+
}
101112
pipeline.upsert()
102113
sagemaker_session_mock.sagemaker_client.update_pipeline.assert_called_with(
103114
PipelineName="MyPipeline",
@@ -207,6 +218,11 @@ def test_pipeline_update(sagemaker_session_mock, role_arn):
207218
sagemaker_session=sagemaker_session_mock,
208219
)
209220
assert not pipeline.steps
221+
222+
sagemaker_session_mock.sagemaker_client.update_pipeline.return_value = {
223+
"PipelineArn": "pipeline-arn",
224+
"PipelineVersionId": 1,
225+
}
210226
pipeline.update(role_arn=role_arn)
211227
assert len(json.loads(pipeline.definition())["Steps"]) == 0
212228
sagemaker_session_mock.sagemaker_client.update_pipeline.assert_called_with(
@@ -251,6 +267,11 @@ def test_pipeline_update(sagemaker_session_mock, role_arn):
251267
)
252268
assert len(pipeline.steps) == 2
253269

270+
sagemaker_session_mock.sagemaker_client.update_pipeline.return_value = {
271+
"PipelineArn": "pipeline-arn",
272+
"PipelineVersionId": 2,
273+
}
274+
254275
pipeline.update(role_arn=role_arn)
255276
assert len(json.loads(pipeline.definition())["Steps"]) == 3
256277
sagemaker_session_mock.sagemaker_client.update_pipeline.assert_called_with(
@@ -345,6 +366,11 @@ def test_pipeline_update_with_parallelism_config(sagemaker_session_mock, role_ar
345366
role_arn=role_arn,
346367
parallelism_config=dict(MaxParallelExecutionSteps=10),
347368
)
369+
sagemaker_session_mock.sagemaker_client.update_pipeline.return_value = {
370+
"PipelineArn": "pipeline-arn",
371+
"PipelineVersionId": 2,
372+
}
373+
348374
pipeline.update(
349375
role_arn=role_arn,
350376
parallelism_config={"MaxParallelExecutionSteps": 10},
@@ -393,7 +419,8 @@ def _raise_does_already_exists_client_error(**kwargs):
393419
)
394420

395421
sagemaker_session_mock.sagemaker_client.update_pipeline.return_value = {
396-
"PipelineArn": "pipeline-arn"
422+
"PipelineArn": "pipeline-arn",
423+
"PipelineVersionId": 2,
397424
}
398425
sagemaker_session_mock.sagemaker_client.list_tags.return_value = {
399426
"Tags": [{"Key": "dummy", "Value": "dummy_tag"}]
@@ -428,6 +455,7 @@ def _raise_does_already_exists_client_error(**kwargs):
428455
sagemaker_session_mock.sagemaker_client.add_tags.assert_called_with(
429456
ResourceArn="pipeline-arn", Tags=tags
430457
)
458+
assert pipeline.latest_pipeline_version_id == 2
431459

432460

433461
def test_pipeline_upsert_create_unexpected_failure(sagemaker_session_mock, role_arn):
@@ -476,18 +504,11 @@ def _raise_unexpected_client_error(**kwargs):
476504
sagemaker_session_mock.sagemaker_client.add_tags.assert_not_called()
477505

478506

479-
def test_pipeline_upsert_resourse_doesnt_exist(sagemaker_session_mock, role_arn):
507+
def test_pipeline_upsert_resource_doesnt_exist(sagemaker_session_mock, role_arn):
480508

481509
# case 3: resource does not exist
482510
sagemaker_session_mock.sagemaker_client.create_pipeline = Mock(name="create_pipeline")
483511

484-
sagemaker_session_mock.sagemaker_client.update_pipeline.return_value = {
485-
"PipelineArn": "pipeline-arn"
486-
}
487-
sagemaker_session_mock.sagemaker_client.list_tags.return_value = {
488-
"Tags": [{"Key": "dummy", "Value": "dummy_tag"}]
489-
}
490-
491512
tags = [
492513
{"Key": "foo", "Value": "abc"},
493514
{"Key": "bar", "Value": "xyz"},
@@ -542,6 +563,11 @@ def test_pipeline_describe(sagemaker_session_mock):
542563
PipelineName="MyPipeline",
543564
)
544565

566+
pipeline.describe(pipeline_version_id=5)
567+
sagemaker_session_mock.sagemaker_client.describe_pipeline.assert_called_with(
568+
PipelineName="MyPipeline", PipelineVersionId=5
569+
)
570+
545571

546572
def test_pipeline_start(sagemaker_session_mock):
547573
sagemaker_session_mock.sagemaker_client.start_pipeline_execution.return_value = {
@@ -568,6 +594,11 @@ def test_pipeline_start(sagemaker_session_mock):
568594
PipelineName="MyPipeline", PipelineParameters=[{"Name": "alpha", "Value": "epsilon"}]
569595
)
570596

597+
pipeline.start(pipeline_version_id=5)
598+
sagemaker_session_mock.sagemaker_client.start_pipeline_execution.assert_called_with(
599+
PipelineName="MyPipeline", PipelineVersionId=5
600+
)
601+
571602

572603
def test_pipeline_start_selective_execution(sagemaker_session_mock):
573604
sagemaker_session_mock.sagemaker_client.start_pipeline_execution.return_value = {
@@ -809,6 +840,29 @@ def test_pipeline_list_executions(sagemaker_session_mock):
809840
assert executions["NextToken"] == "token"
810841

811842

843+
def test_pipeline_list_versions(sagemaker_session_mock):
844+
sagemaker_session_mock.sagemaker_client.list_pipeline_versions.return_value = {
845+
"PipelineVersionSummaries": [Mock()],
846+
"NextToken": "token",
847+
}
848+
pipeline = Pipeline(
849+
name="MyPipeline",
850+
parameters=[ParameterString("alpha", "beta"), ParameterString("gamma", "delta")],
851+
steps=[],
852+
sagemaker_session=sagemaker_session_mock,
853+
)
854+
versions = pipeline.list_pipeline_versions()
855+
assert len(versions["PipelineVersionSummaries"]) == 1
856+
assert versions["NextToken"] == "token"
857+
858+
sagemaker_session_mock.sagemaker_client.list_pipeline_versions.return_value = {
859+
"PipelineVersionSummaries": [Mock(), Mock()],
860+
}
861+
versions = pipeline.list_pipeline_versions(next_token=versions["NextToken"])
862+
assert len(versions["PipelineVersionSummaries"]) == 2
863+
assert "NextToken" not in versions
864+
865+
812866
def test_pipeline_build_parameters_from_execution(sagemaker_session_mock):
813867
pipeline = Pipeline(
814868
name="MyPipeline",

0 commit comments

Comments
 (0)