Skip to content

Commit c52f6a0

Browse files
authored
Default Export for Mv2 with Real Image Input (#13019)
### Summary When testing executorch inference runners on real MCUs, it's helpful if the default export of the model includes a real input (rather than just filled with random values or all 1.0s). This PR does this for the mobilenet_v2 model, following the steps from the tutorial [here](https://pytorch.org/hub/pytorch_vision_mobilenet_v2/). ### Test plan Ran `python -m examples.portable.scripts.export --model_name="mv2"` and saw the PTE now contains values associated with the `dog.jpg` image.
1 parent d4f5ee6 commit c52f6a0

File tree

1 file changed

+32
-1
lines changed

1 file changed

+32
-1
lines changed

examples/models/mobilenet_v2/model.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515

1616

1717
class MV2Model(EagerModelBase):
18-
def __init__(self):
18+
def __init__(self, use_real_input=True):
19+
self.use_real_input = use_real_input
1920
pass
2021

2122
def get_eager_model(self) -> torch.nn.Module:
@@ -26,6 +27,36 @@ def get_eager_model(self) -> torch.nn.Module:
2627

2728
def get_example_inputs(self):
2829
tensor_size = (1, 3, 224, 224)
30+
input_batch = (torch.randn(tensor_size),)
31+
if self.use_real_input:
32+
logging.info("Loaded real input image dog.jpg")
33+
import urllib
34+
35+
url, filename = (
36+
"https://github.com/pytorch/hub/raw/master/images/dog.jpg",
37+
"dog.jpg",
38+
)
39+
try:
40+
urllib.URLopener().retrieve(url, filename)
41+
except:
42+
urllib.request.urlretrieve(url, filename)
43+
from PIL import Image
44+
from torchvision import transforms
45+
46+
input_image = Image.open(filename)
47+
preprocess = transforms.Compose(
48+
[
49+
transforms.Resize(256),
50+
transforms.CenterCrop(224),
51+
transforms.ToTensor(),
52+
transforms.Normalize(
53+
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
54+
),
55+
]
56+
)
57+
input_tensor = preprocess(input_image)
58+
input_batch = input_tensor.unsqueeze(0)
59+
input_batch = (input_batch,)
2960
return (torch.randn(tensor_size),)
3061

3162

0 commit comments

Comments
 (0)