Commit dee4cd3

Example on how to save and load model along with Albumentations preprocessing (#914)
* Update docs
* Update notebook
1 parent 85bb28e commit dee4cd3

2 files changed: +196 additions, -9 deletions


docs/save_load.rst

Lines changed: 34 additions & 1 deletion
@@ -59,6 +59,37 @@ For example:
     # Or saved and pushed to the Hub simultaneously
     model.save_pretrained('username/my-model', push_to_hub=True, metrics={'accuracy': 0.95}, dataset='my_dataset')
 
+Saving with preprocessing transform (Albumentations)
+----------------------------------------------------
+
+You can save the preprocessing transform along with the model and push it to the Hub.
+This can be useful when you want to share the model together with the preprocessing transform that was used during training,
+to make sure that the inference pipeline is consistent with the training pipeline.
+
+.. code:: python
+
+    import albumentations as A
+    import segmentation_models_pytorch as smp
+
+    # Define a preprocessing transform for the image that will be used during inference
+    preprocessing_transform = A.Compose([
+        A.Resize(256, 256),
+        A.Normalize()
+    ])
+
+    model = smp.Unet()
+
+    directory_or_repo_on_the_hub = "qubvel-hf/unet-with-transform"  # <username>/<repo-name>
+
+    # Save the model and transform (and push to the Hub, if needed)
+    model.save_pretrained(directory_or_repo_on_the_hub, push_to_hub=True)
+    preprocessing_transform.save_pretrained(directory_or_repo_on_the_hub, push_to_hub=True)
+
+    # Load the transform and model back
+    restored_model = smp.from_pretrained(directory_or_repo_on_the_hub)
+    restored_transform = A.Compose.from_pretrained(directory_or_repo_on_the_hub)
+
+    print(restored_transform)
 
 Conclusion
 ----------
@@ -71,4 +102,6 @@ By following these steps, you can easily save, share, and load your models, faci
     :target: https://colab.research.google.com/github/qubvel/segmentation_models.pytorch/blob/main/examples/binary_segmentation_intro.ipynb
     :alt: Open In Colab
 
-
+.. |colab-badge| image:: https://colab.research.google.com/assets/colab-badge.svg
+    :target: https://colab.research.google.com/github/qubvel/segmentation_models.pytorch/blob/main/examples/save_load_model_and_share_with_hf_hub.ipynb
+    :alt: Open In Colab
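
For reference, a minimal sketch of how the restored model and transform might be used together at inference time; the dummy NumPy image, the tensor conversion, and the expected output shape are illustrative assumptions rather than part of the committed docs:

.. code:: python

    import albumentations as A
    import numpy as np
    import segmentation_models_pytorch as smp
    import torch

    repo = "qubvel-hf/unet-with-transform"

    # Restore the model and the exact preprocessing saved alongside it
    model = smp.from_pretrained(repo).eval()
    transform = A.Compose.from_pretrained(repo)

    # Dummy RGB image (H, W, C), uint8 -- stands in for a real input
    image = np.random.randint(0, 255, size=(512, 512, 3), dtype=np.uint8)

    # Apply the saved preprocessing, then convert HWC -> CHW and add a batch dimension
    preprocessed = transform(image=image)["image"]
    x = torch.from_numpy(preprocessed).permute(2, 0, 1).unsqueeze(0).float()

    with torch.no_grad():
        mask = model(x)

    print(mask.shape)  # e.g. torch.Size([1, 1, 256, 256]) with the saved Resize(256, 256) and a default Unet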

examples/save_load_model_and_share_with_hf_hub.ipynb

Lines changed: 162 additions & 8 deletions
@@ -48,7 +48,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 3,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -70,7 +70,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 4,
    "metadata": {},
    "outputs": [
     {
@@ -82,9 +82,11 @@
      "license: mit\n",
      "pipeline_tag: image-segmentation\n",
      "tags:\n",
+      "- model_hub_mixin\n",
+      "- pytorch_model_hub_mixin\n",
+      "- segmentation-models-pytorch\n",
      "- semantic-segmentation\n",
      "- pytorch\n",
-      "- segmentation-models-pytorch\n",
      "languages:\n",
      "- python\n",
      "---\n",
@@ -157,7 +159,7 @@
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "075ae026811542bdb4030e53b943efc7",
+       "model_id": "1d6fe9d868c24175aa5f23a2893a2c21",
        "version_major": 2,
        "version_minor": 0
       },
@@ -179,13 +181,13 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 6,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "2921a81d7fd747939b4a425cc17d6104",
+       "model_id": "2f4f5e4973e44f9a857e89d9ac707b53",
        "version_major": 2,
        "version_minor": 0
       },
@@ -199,10 +201,10 @@
     {
      "data": {
       "text/plain": [
-       "CommitInfo(commit_url='https://huggingface.co/qubvel-hf/unet-with-metadata/commit/9f821c7bc3a12db827c0da96a31f354ec6ba5253', commit_message='Push model using huggingface_hub.', commit_description='', oid='9f821c7bc3a12db827c0da96a31f354ec6ba5253', pr_url=None, pr_revision=None, pr_num=None)"
+       "CommitInfo(commit_url='https://huggingface.co/qubvel-hf/unet-with-metadata/commit/4ac3d2925d34cf183dc79a2e21b6e2f4bfe87586', commit_message='Push model using huggingface_hub.', commit_description='', oid='4ac3d2925d34cf183dc79a2e21b6e2f4bfe87586', pr_url=None, pr_revision=None, pr_num=None)"
       ]
      },
-     "execution_count": 8,
+     "execution_count": 6,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -224,6 +226,158 @@
     "\n",
     "# see result here https://huggingface.co/qubvel-hf/unet-with-metadata"
    ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Save model with preprocessing (using albumentations)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!pip install -U albumentations numpy==1.*"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import albumentations as A\n",
+    "import segmentation_models_pytorch as smp"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# define a preprocessing transform for image that would be used during inference\n",
+    "preprocessing_transform = A.Compose([\n",
+    "    A.Resize(256, 256),\n",
+    "    A.Normalize()\n",
+    "])\n",
+    "\n",
+    "model = smp.Unet()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "1aa3f4db4cd2489baeac3b844977d5a2",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "model.safetensors: 0%| | 0.00/97.8M [00:00<?, ?B/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "CommitInfo(commit_url='https://huggingface.co/qubvel-hf/unet-with-transform/commit/680dad16431fa6efbb25832d33a24056bdf7dc1a', commit_message='Push transform using huggingface_hub.', commit_description='', oid='680dad16431fa6efbb25832d33a24056bdf7dc1a', pr_url=None, pr_revision=None, pr_num=None)"
+      ]
+     },
+     "execution_count": 4,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "directory_or_repo_on_the_hub = \"qubvel-hf/unet-with-transform\"\n",
+    "\n",
+    "# save the model\n",
+    "model.save_pretrained(directory_or_repo_on_the_hub, push_to_hub=True)\n",
+    "\n",
+    "# save transform\n",
+    "preprocessing_transform.save_pretrained(directory_or_repo_on_the_hub, push_to_hub=True)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Now, let's restore model and preprocessing transform for inference:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Loading weights from local directory\n",
+      "Compose([\n",
+      "  Resize(p=1.0, height=256, width=256, interpolation=1),\n",
+      "  Normalize(p=1.0, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, normalization='standard'),\n",
+      "], p=1.0, bbox_params=None, keypoint_params=None, additional_targets={}, is_check_shapes=True)\n"
+     ]
+    }
+   ],
+   "source": [
+    "restored_model = smp.from_pretrained(directory_or_repo_on_the_hub)\n",
+    "restored_transform = A.Compose.from_pretrained(directory_or_repo_on_the_hub)\n",
+    "\n",
+    "print(restored_transform)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Compose([\n",
+      "  HorizontalFlip(p=0.5),\n",
+      "  RandomBrightnessContrast(p=0.2, brightness_limit=(-0.2, 0.2), contrast_limit=(-0.2, 0.2), brightness_by_max=True),\n",
+      "  ShiftScaleRotate(p=0.5, shift_limit_x=(-0.0625, 0.0625), shift_limit_y=(-0.0625, 0.0625), scale_limit=(-0.09999999999999998, 0.10000000000000009), rotate_limit=(-45, 45), interpolation=1, border_mode=4, value=0.0, mask_value=0.0, rotate_method='largest_box'),\n",
+      "], p=1.0, bbox_params=None, keypoint_params=None, additional_targets={}, is_check_shapes=True)\n"
+     ]
+    }
+   ],
+   "source": [
+    "# You can also save training augmentations to the Hub too (and load it back)!\n",
+    "#! Just make sure to provide key=\"train\" when saving and loading the augmentations.\n",
+    "\n",
+    "train_augmentations = A.Compose([\n",
+    "    A.HorizontalFlip(p=0.5),\n",
+    "    A.RandomBrightnessContrast(p=0.2),\n",
+    "    A.ShiftScaleRotate(p=0.5),\n",
+    "])\n",
+    "\n",
+    "train_augmentations.save_pretrained(directory_or_repo_on_the_hub, key=\"train\", push_to_hub=True)\n",
+    "\n",
+    "restored_train_augmentations = A.Compose.from_pretrained(directory_or_repo_on_the_hub, key=\"train\")\n",
+    "print(restored_train_augmentations)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "See saved model and `albumentations` configs on the hub: https://huggingface.co/qubvel-hf/unet-with-transform/tree/main"
+   ]
   }
  ],
  "metadata": {
