diff --git a/examples/models/mobilenet_v2/model.py b/examples/models/mobilenet_v2/model.py
index f15178ac71b..5c2c7ff7016 100644
--- a/examples/models/mobilenet_v2/model.py
+++ b/examples/models/mobilenet_v2/model.py
@@ -15,7 +15,10 @@
 class MV2Model(EagerModelBase):
-    def __init__(self):
-        pass
+    def __init__(self, use_real_input=True):
+        # Whether get_example_inputs() should download and preprocess a real
+        # image (dog.jpg) instead of returning a random tensor. Defaults to
+        # True to match the example's intended demo behavior.
+        self.use_real_input = use_real_input
 
     def get_eager_model(self) -> torch.nn.Module:
@@ -26,6 +29,44 @@
     def get_example_inputs(self):
+        """Return a 1-tuple example input batch for MobileNetV2 (NCHW 1x3x224x224).
+
+        When ``self.use_real_input`` is set, downloads dog.jpg and applies the
+        standard ImageNet preprocessing; otherwise (or when the download
+        fails) falls back to a random tensor of the same shape.
+        """
         tensor_size = (1, 3, 224, 224)
+        # Default: random input with the model's expected shape.
+        input_batch = (torch.randn(tensor_size),)
+        if self.use_real_input:
+            # NOTE: Python 2's urllib.URLopener does not exist in Python 3;
+            # use urllib.request.urlretrieve directly instead of the old
+            # try/except compatibility dance.
+            import urllib.request
+
+            url, filename = (
+                "https://github.com/pytorch/hub/raw/master/images/dog.jpg",
+                "dog.jpg",
+            )
+            try:
+                urllib.request.urlretrieve(url, filename)
+            except OSError:
+                # Best effort: no network or the download failed -- keep the
+                # random example input instead of crashing.
+                logging.warning("Could not download dog.jpg; using random input")
+                return input_batch
+            logging.info("Loaded real input image dog.jpg")
+            from PIL import Image
+            from torchvision import transforms
+
+            input_image = Image.open(filename)
+            # Standard ImageNet preprocessing for torchvision models.
+            preprocess = transforms.Compose(
+                [
+                    transforms.Resize(256),
+                    transforms.CenterCrop(224),
+                    transforms.ToTensor(),
+                    transforms.Normalize(
+                        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+                    ),
+                ]
+            )
+            input_tensor = preprocess(input_image)
+            input_batch = (input_tensor.unsqueeze(0),)
-        return (torch.randn(tensor_size),)
+        # Fix: previously returned a fresh random tensor unconditionally,
+        # discarding the real-input batch built above.
+        return input_batch