From 933d976f4e06f2284043f8e45af15db37ed84076 Mon Sep 17 00:00:00 2001
From: sudomaze
Date: Thu, 5 May 2022 16:27:47 -0700
Subject: [PATCH 1/7] add siamese network example

---
 .gitignore                       |   3 +
 siamese_network/README.md        |   7 +
 siamese_network/main.py          | 214 +++++++++++++++++++++++++++++++
 siamese_network/requirements.txt |   2 +
 4 files changed, 226 insertions(+)
 create mode 100644 siamese_network/README.md
 create mode 100644 siamese_network/main.py
 create mode 100644 siamese_network/requirements.txt

diff --git a/.gitignore b/.gitignore
index 14ec8ef205..56cd7649a4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -17,3 +17,6 @@ docs/venv

 # vi backups
 *~
+
+# development
+.vscode
\ No newline at end of file
diff --git a/siamese_network/README.md b/siamese_network/README.md
new file mode 100644
index 0000000000..dc97bd5c53
--- /dev/null
+++ b/siamese_network/README.md
@@ -0,0 +1,7 @@
+# Basic MNIST Example
+
+```bash
+pip install -r requirements.txt
+python main.py
+# CUDA_VISIBLE_DEVICES=2 python main.py # to specify GPU id (e.g. 2)
+```
diff --git a/siamese_network/main.py b/siamese_network/main.py
new file mode 100644
index 0000000000..2695f2af9d
--- /dev/null
+++ b/siamese_network/main.py
@@ -0,0 +1,214 @@
+from __future__ import print_function
+import argparse, random, copy
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+import torchvision
+from torch.utils.data import Dataset
+from torchvision import datasets
+from torchvision import transforms as T
+from torch.optim.lr_scheduler import StepLR
+
+
+class SiameseNetwork(nn.Module):
+    def __init__(self):
+        super(SiameseNetwork, self).__init__()
+        # get resnet model
+        self.resnet = torchvision.models.resnet18(pretrained=False)
+
+        # overwrite the first conv layer to be able to read MNIST images
+        # as resnet18 reads (3,x,x) where 3 is RGB channels
+        # whereas MNIST has (1,x,x) where 1 is a gray-scale channel
+        self.resnet.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
+        self.fc_in_features = self.resnet.fc.in_features
+
+        # remove the last layer of resnet18 (the linear layer that comes after the avgpool layer)
+        self.resnet = torch.nn.Sequential(*(list(self.resnet.children())[:-1]))
+
+        # add linear layers to compare between the features of the two images
+        self.fc = nn.Sequential(
+            nn.Linear(self.fc_in_features * 2, 256),
+            nn.ReLU(inplace=True),
+
+            nn.Linear(256, 1),
+        )
+
+        self.sigmoid = nn.Sigmoid()
+
+        # initialize the weights
+        self.resnet.apply(self.init_weights)
+        self.fc.apply(self.init_weights)
+
+    def init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            torch.nn.init.xavier_uniform_(m.weight)
+            m.bias.data.fill_(0.01)
+
+    def forward_once(self, x):
+        output = self.resnet(x)
+        output = output.view(output.size()[0], -1)
+        return output
+
+    def forward(self, input1, input2):
+        # get two images' features
+        output1 = self.forward_once(input1)
+        output2 = self.forward_once(input2)
+
+        # concatnate both images' features
+        output = torch.cat((output1, output2), 1)
+
+        # pass the concatnation to the linear layers
+        output = self.fc(output)
+        output = self.sigmoid(output)
+
+        return output
+
+class APP_MATCHER(Dataset):
+    def __init__(self, root, train, download=False):
+        super(APP_MATCHER, self).__init__()
+        # get MNIST dataset
+        self.dataset = datasets.MNIST(root, train=train, download=download)
+
+        # get targets (labels) and data (images)
+        self.targets = copy.deepcopy(self.dataset.targets)
+        self.data = copy.deepcopy(self.dataset.data.unsqueeze(1))
+
+        self.group_sets()
+
+    def group_sets(self):
+        np_arr = np.array(self.dataset.targets.clone())
+        self.grouped_indices = {}
+        for i in range(0,10):
+            self.grouped_indices[i] = np.where((np_arr==i))[0]
+
+    def __len__(self):
+        return self.data.shape[0]
+
+    def __getitem__(self, index):
+        selected_class = random.randint(0, 9)
+        random_index_1 = random.randint(0, self.grouped_indices[selected_class].shape[0]-1)
+        index_1 = self.grouped_indices[selected_class][random_index_1]
+        image_1 = self.data[index_1].clone().float()
+
+        # same class
+        if index % 2 == 0:
+            random_index_2 = random.randint(0, self.grouped_indices[selected_class].shape[0]-1)
+            while random_index_2 == random_index_1:
+                random_index_2 = random.randint(0, self.grouped_indices[selected_class].shape[0]-1)
+            index_2 = self.grouped_indices[selected_class][random_index_2]
+            image_2 = self.data[index_2].clone().float()
+            target = torch.tensor(1, dtype=torch.float)
+
+        # different class
+        else:
+            other_selected_class = random.randint(0, 9)
+            while other_selected_class == selected_class:
+                other_selected_class = random.randint(0, 9)
+            random_index_2 = random.randint(0, self.grouped_indices[other_selected_class].shape[0]-1)
+            index_2 = self.grouped_indices[other_selected_class][random_index_2]
+            image_2 = self.data[index_2].clone().float()
+            target = torch.tensor(0, dtype=torch.float)
+
+        return image_1, image_2, target
+
+
+def train(args, model, device, train_loader, optimizer, epoch):
+    model.train()
+    criterion = nn.BCELoss()
+    for batch_idx, (images_1, images_2, targets) in enumerate(train_loader):
+        images_1, images_2, targets = images_1.to(device), images_2.to(device), targets.to(device)
+        optimizer.zero_grad()
+        outputs = model(images_1, images_2).squeeze()
+        loss = criterion(outputs, targets)
+        loss.backward()
+        optimizer.step()
+        if batch_idx % args.log_interval == 0:
+            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
+                epoch, batch_idx * len(images_1), len(train_loader.dataset),
+                100. * batch_idx / len(train_loader), loss.item()))
+            if args.dry_run:
+                break
+
+
+def test(model, device, test_loader):
+    model.eval()
+    test_loss = 0
+    correct = 0
+    criterion = nn.BCELoss()
+    with torch.no_grad():
+        for (images_1, images_2, targets) in test_loader:
+            images_1, images_2, targets = images_1.to(device), images_2.to(device), targets.to(device)
+            outputs = model(images_1, images_2).squeeze()
+            test_loss += criterion(outputs, targets).sum().item()  # accumulate the mean batch loss
+            pred = torch.where(outputs > 0.5, 1, 0)  # predict "same class" when the sigmoid output exceeds 0.5
+            correct += pred.eq(targets.view_as(pred)).sum().item()
+
+    test_loss /= len(test_loader.dataset)
+
+    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
+        test_loss, correct, len(test_loader.dataset),
+        100. * correct / len(test_loader.dataset)))
+
+
+def main():
+    # Training settings
+    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
+    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
+                        help='input batch size for training (default: 64)')
+    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
+                        help='input batch size for testing (default: 1000)')
+    parser.add_argument('--epochs', type=int, default=14, metavar='N',
+                        help='number of epochs to train (default: 14)')
+    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
+                        help='learning rate (default: 1.0)')
+    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
+                        help='Learning rate step gamma (default: 0.7)')
+    parser.add_argument('--no-cuda', action='store_true', default=False,
+                        help='disables CUDA training')
+    parser.add_argument('--dry-run', action='store_true', default=False,
+                        help='quickly check a single pass')
+    parser.add_argument('--seed', type=int, default=1, metavar='S',
+                        help='random seed (default: 1)')
+    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
+                        help='how many batches to wait before logging training status')
+    parser.add_argument('--save-model', action='store_true', default=False,
+                        help='For Saving the current Model')
+    args = parser.parse_args()
+    use_cuda = not args.no_cuda and torch.cuda.is_available()
+
+    torch.manual_seed(args.seed)
+
+    device = torch.device("cuda" if use_cuda else "cpu")
+
+    train_kwargs = {'batch_size': args.batch_size}
+    test_kwargs = {'batch_size': args.test_batch_size}
+    if use_cuda:
+        cuda_kwargs = {'num_workers': 1,
+                       'pin_memory': True,
+                       'shuffle': True}
+        train_kwargs.update(cuda_kwargs)
+        test_kwargs.update(cuda_kwargs)
+
+    train_dataset = APP_MATCHER('../data', train=True, download=True)
+    test_dataset = APP_MATCHER('../data', train=False)
+    train_loader = torch.utils.data.DataLoader(train_dataset, **train_kwargs)
+    test_loader = torch.utils.data.DataLoader(test_dataset, **test_kwargs)
+
+    model = SiameseNetwork().to(device)
+    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
+
+    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
+    for epoch in range(1, args.epochs + 1):
+        train(args, model, device, train_loader, optimizer, epoch)
+        test(model, device, test_loader)
+        scheduler.step()
+
+    if args.save_model:
+        torch.save(model.state_dict(), "siamese_network.pt")
+
+
+if __name__ == '__main__':
+    main()
diff --git a/siamese_network/requirements.txt b/siamese_network/requirements.txt
new file mode 100644
index 0000000000..ac988bdf84
--- /dev/null
+++ b/siamese_network/requirements.txt
@@ -0,0 +1,2 @@
+torch
+torchvision

From aec9e7035f3b9ab4defa19a3666886cac78a1958 Mon Sep 17 00:00:00 2001
From: sudomaze
Date: Thu, 5 May 2022 16:28:19 -0700
Subject: [PATCH 2/7] update README

---
 siamese_network/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/siamese_network/README.md b/siamese_network/README.md
index dc97bd5c53..84612bd401 100644
--- a/siamese_network/README.md
+++ b/siamese_network/README.md
@@ -1,4 +1,4 @@
-# Basic MNIST Example
+# Basic Siamese Network Example

 ```bash
 pip install -r requirements.txt

From 11cb38f9a39476055dc6e0a1a6f287300a4da989 Mon Sep 17 00:00:00 2001
From: sudomaze
Date: Sat, 7 May 2022 19:10:31 -0700
Subject: [PATCH 3/7] Updated based on PR's feedback

- Included `siamese_network` setup in `run_python_examples.sh`
- In `siamese_network/main.py`, included more explanation and detailed
  comments per @msaroufim's feedback.
---
 run_python_examples.sh    |   6 ++
 siamese_network/README.md |   2 +-
 siamese_network/main.py   | 125 +++++++++++++++++++++++++++++++------
 3 files changed, 110 insertions(+), 23 deletions(-)

diff --git a/run_python_examples.sh b/run_python_examples.sh
index 10fd427bb6..7244ff7ccf 100755
--- a/run_python_examples.sh
+++ b/run_python_examples.sh
@@ -110,6 +110,11 @@ function regression() {
   python main.py --epochs 1 $CUDA_FLAG || error "regression failed"
 }

+function siamese_network() {
+  start
+  python main.py --epochs 1 --dry-run || error "siamese network example failed"
+}
+
 function reinforcement_learning() {
   start
   python reinforce.py || error "reinforcement learning reinforce failed"
@@ -193,6 +198,7 @@ function run_all() {
   mnist_hogwild
   regression
   reinforcement_learning
+  siamese_network
   super_resolution
   time_sequence_prediction
   vae

diff --git a/siamese_network/README.md b/siamese_network/README.md
index 84612bd401..973a0414a4 100644
--- a/siamese_network/README.md
+++ b/siamese_network/README.md
@@ -1,4 +1,4 @@
-# Basic Siamese Network Example
+# Siamese Network Example

 ```bash
 pip install -r requirements.txt

diff --git a/siamese_network/main.py b/siamese_network/main.py
index 2695f2af9d..d6f146bd48 100644
--- a/siamese_network/main.py
+++ b/siamese_network/main.py
@@ -14,6 +14,16 @@


 class SiameseNetwork(nn.Module):
+    """
+    Siamese network for image similarity estimation.
+    The network is composed of two identical networks, one for each input.
+    The output of each network is concatenated and passed to a linear layer.
+    The output of the linear layer is passed through a sigmoid function.
+    `"FaceNet" <https://arxiv.org/pdf/1503.03832.pdf>`_ is a variant of the Siamese network.
+    This implementation varies from FaceNet as we use the `ResNet-18` model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_ as our feature extractor.
+    In addition, we aren't using `TripletLoss` as the MNIST dataset is simple, so `BCELoss` can do the trick.
+    """
     def __init__(self):
         super(SiameseNetwork, self).__init__()
         # get resnet model
@@ -32,7 +42,6 @@ def __init__(self):
         self.fc = nn.Sequential(
             nn.Linear(self.fc_in_features * 2, 256),
             nn.ReLU(inplace=True),
-
             nn.Linear(256, 1),
         )

@@ -57,11 +66,13 @@ def forward(self, input1, input2):
         output1 = self.forward_once(input1)
         output2 = self.forward_once(input2)

-        # concatnate both images' features
+        # concatenate both images' features
         output = torch.cat((output1, output2), 1)

-        # pass the concatnation to the linear layers
+        # pass the concatenation to the linear layers
         output = self.fc(output)
+
+        # pass the output of the linear layers to the sigmoid layer
         output = self.sigmoid(output)

         return output
@@ -69,47 +80,107 @@ def forward(self, input1, input2):
 class APP_MATCHER(Dataset):
     def __init__(self, root, train, download=False):
         super(APP_MATCHER, self).__init__()
+
         # get MNIST dataset
         self.dataset = datasets.MNIST(root, train=train, download=download)
-
-        # get targets (labels) and data (images)
-        self.targets = copy.deepcopy(self.dataset.targets)
-        self.data = copy.deepcopy(self.dataset.data.unsqueeze(1))
-
-        self.group_sets()
-
-    def group_sets(self):
+
+        # as `self.dataset.data`'s shape is (Nx28x28), where N is the number of
+        # examples in the MNIST dataset, a single example has the dimensions of
+        # (28x28) for (WxH), where W and H are the width and the height of the image.
+        # However, every example should have (CxWxH) dimensions where C is the number
+        # of channels to be passed to the network. As MNIST contains gray-scale images,
+        # we add an additional dimension that corresponds to the number of channels.
+        self.data = self.dataset.data.unsqueeze(1).clone()
+
+        self.group_examples()
+
+    def group_examples(self):
+        """
+        To ease the accessibility of data based on the class, we will use `group_examples` to group
+        examples based on class.
+
+        Every key in `grouped_examples` corresponds to a class in the MNIST dataset. For every key in
+        `grouped_examples`, the value will hold all of the indices of the MNIST dataset examples
+        that belong to that class.
+        """
+
+        # get the targets from MNIST dataset
         np_arr = np.array(self.dataset.targets.clone())
-        self.grouped_indices = {}
+
+        # group examples based on class
+        self.grouped_examples = {}
         for i in range(0,10):
-            self.grouped_indices[i] = np.where((np_arr==i))[0]
+            self.grouped_examples[i] = np.where((np_arr==i))[0]

     def __len__(self):
         return self.data.shape[0]

     def __getitem__(self, index):
+        """
+        For every example, we will select two images. There are two cases,
+        positive and negative examples. For positive examples, we will have two
+        images from the same class. For negative examples, we will have two images
+        from a different classes.
+
+        Given the index, if the index is even, we will pick the second image from the same class,
+        but it won't be the same image we chose for the first class. This is used to ensure as it is easy
+        for the network to notice that two images that are the same are positive examples. However, if
+        we are given two different images from the same class, the network will need to learn the similarity
+        between two different images representing the same class.
+        """
+
+        # pick some random class for the first image
         selected_class = random.randint(0, 9)
-        random_index_1 = random.randint(0, self.grouped_indices[selected_class].shape[0]-1)
-        index_1 = self.grouped_indices[selected_class][random_index_1]
+
+        # pick a random index for the first image in the grouped indices based on the label
+        # of the class
+        random_index_1 = random.randint(0, self.grouped_examples[selected_class].shape[0]-1)
+
+        # pick the index to get the first image
+        index_1 = self.grouped_examples[selected_class][random_index_1]
+
+        # get the first image
         image_1 = self.data[index_1].clone().float()

         # same class
         if index % 2 == 0:
-            random_index_2 = random.randint(0, self.grouped_indices[selected_class].shape[0]-1)
+            # pick a random index for the second image
+            random_index_2 = random.randint(0, self.grouped_examples[selected_class].shape[0]-1)
+
+            # ensure that the index of the second image isn't the same as the first image
             while random_index_2 == random_index_1:
-                random_index_2 = random.randint(0, self.grouped_indices[selected_class].shape[0]-1)
-            index_2 = self.grouped_indices[selected_class][random_index_2]
+                random_index_2 = random.randint(0, self.grouped_examples[selected_class].shape[0]-1)
+
+            # pick the index to get the second image
+            index_2 = self.grouped_examples[selected_class][random_index_2]
+
+            # get the second image
             image_2 = self.data[index_2].clone().float()
-            target = torch.tensor(1, dtype=torch.float)

+            # set the label for this example to be positive (1)
+            target = torch.tensor(1, dtype=torch.float)
+
         # different class
         else:
+            # pick a random class
             other_selected_class = random.randint(0, 9)
+
+            # ensure that the class of the second image isn't the same as the first image
             while other_selected_class == selected_class:
                 other_selected_class = random.randint(0, 9)
-            random_index_2 = random.randint(0, self.grouped_indices[other_selected_class].shape[0]-1)
-            index_2 = self.grouped_indices[other_selected_class][random_index_2]
+
+
+            # pick a random index for the second image in the grouped indices based on the label
+            # of the class
+            random_index_2 = random.randint(0, self.grouped_examples[other_selected_class].shape[0]-1)
+
+            # pick the index to get the second image
+            index_2 = self.grouped_examples[other_selected_class][random_index_2]
+
+            # get the second image
             image_2 = self.data[index_2].clone().float()
+
+            # set the label for this example to be negative (0)
             target = torch.tensor(0, dtype=torch.float)

         return image_1, image_2, target
@@ -117,7 +188,10 @@ def __getitem__(self, index):


 def train(args, model, device, train_loader, optimizer, epoch):
     model.train()
+
+    # we aren't using `TripletLoss` as the MNIST dataset is simple, so `BCELoss` can do the trick.
     criterion = nn.BCELoss()
+
     for batch_idx, (images_1, images_2, targets) in enumerate(train_loader):
         images_1, images_2, targets = images_1.to(device), images_2.to(device), targets.to(device)
         optimizer.zero_grad()
@@ -137,7 +211,10 @@ def test(model, device, test_loader):
     model.eval()
     test_loss = 0
     correct = 0
+
+    # we aren't using `TripletLoss` as the MNIST dataset is simple, so `BCELoss` can do the trick.
     criterion = nn.BCELoss()
+
     with torch.no_grad():
         for (images_1, images_2, targets) in test_loader:
             images_1, images_2, targets = images_1.to(device), images_2.to(device), targets.to(device)
@@ -148,6 +225,9 @@ def test(model, device, test_loader):

     test_loss /= len(test_loader.dataset)

+    # for the 1st epoch, the average loss is 0.0001 and the accuracy 97-98%
+    # using default settings. After completing the 10th epoch, the average
+    # loss is 0.0000 and the accuracy 99.5-100% using default settings.
     print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
         test_loss, correct, len(test_loader.dataset),
         100. * correct / len(test_loader.dataset)))
@@ -155,7 +235,7 @@ def main():

 def main():
     # Training settings
-    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
+    parser = argparse.ArgumentParser(description='PyTorch Siamese network Example')
     parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                         help='input batch size for training (default: 64)')
     parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
@@ -177,6 +257,7 @@ def main():
     parser.add_argument('--save-model', action='store_true', default=False,
                         help='For Saving the current Model')
     args = parser.parse_args()
+
     use_cuda = not args.no_cuda and torch.cuda.is_available()

     torch.manual_seed(args.seed)

From 57903baa98ab1602dc552eb04dcae7ee7360aed1 Mon Sep 17 00:00:00 2001
From: sudomaze
Date: Sat, 7 May 2022 19:14:38 -0700
Subject: [PATCH 4/7] included siamese network example to docs

---
 docs/source/index.rst | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/docs/source/index.rst b/docs/source/index.rst
index a5a89d8644..dffc26ab11 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -17,6 +17,17 @@ experiment with PyTorch.

    ---

+   Measuring Similarity using Siamese Network
+   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+   This example demonstrates how to measure similarity between two images
+   using `Siamese network <https://en.wikipedia.org/wiki/Siamese_neural_network>`__
+   on the `MNIST <http://yann.lecun.com/exdb/mnist/>`__ database.
+
+   `GO TO EXAMPLE <https://github.com/pytorch/examples/tree/main/siamese_network>`__ :opticon:`link-external`
+
+   ---
+
    Word-level Language Modeling using RNN and Transformer
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

From 7cb4788bfed88b7cdfe7ea70b69112e4dfdc7abd Mon Sep 17 00:00:00 2001
From: sudomaze
Date: Sat, 7 May 2022 19:14:48 -0700
Subject: [PATCH 5/7] Fixed a typo to run sphinx locally

`sphinx-serve -d build` doesn't work as the `-d` flag doesn't exist. The
correct flag is `-b` per the `sphinx-serve` help page. Hence, the edit was
meant to reflect the correct command `sphinx-serve -b build`.
---
 CONTRIBUTING.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 1e6e23fd44..1c15a5513e 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -26,7 +26,7 @@ If you're new we encourage you to take a look at issues tagged with [good first

 5. Verify that there are no issues in your doc build. You can check preview locally
    by installing [sphinx-serve](https://pypi.org/project/sphinx-serve/) and
-   then running `sphinx-serve -d build`.
+   then running `sphinx-serve -b build`.

 5. Ensure your test passes locally
 6. If you haven't already, complete the Contributor License Agreement ("CLA").

From 8c1727546c7c42ee3fc851e06d28e1bd262c3e9b Mon Sep 17 00:00:00 2001
From: sudomaze
Date: Sat, 7 May 2022 19:31:54 -0700
Subject: [PATCH 6/7] fix typos

---
 siamese_network/main.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/siamese_network/main.py b/siamese_network/main.py
index d6f146bd48..ec37b7ca39 100644
--- a/siamese_network/main.py
+++ b/siamese_network/main.py
@@ -120,13 +120,14 @@ def __getitem__(self, index):
         For every example, we will select two images. There are two cases,
         positive and negative examples. For positive examples, we will have two
         images from the same class. For negative examples, we will have two images
-        from a different classes.
+        from different classes.

         Given the index, if the index is even, we will pick the second image from the same class,
         but it won't be the same image we chose for the first class. This is used to ensure as it is easy
         for the network to notice that two images that are the same are positive examples. However, if
         we are given two different images from the same class, the network will need to learn the similarity
-        between two different images representing the same class.
+        between two different images representing the same class. If the index is odd, we will pick the second
+        image from a different class than the first image.
         """

From 332e138a193805a231c7dbe165708594bf7df2fc Mon Sep 17 00:00:00 2001
From: sudomaze
Date: Sat, 7 May 2022 19:36:17 -0700
Subject: [PATCH 7/7] better description

---
 siamese_network/main.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/siamese_network/main.py b/siamese_network/main.py
index ec37b7ca39..33a5f71517 100644
--- a/siamese_network/main.py
+++ b/siamese_network/main.py
@@ -122,12 +122,12 @@ def __getitem__(self, index):
         images from the same class. For negative examples, we will have two images
         from different classes.

-        Given the index, if the index is even, we will pick the second image from the same class,
-        but it won't be the same image we chose for the first class. This is used to ensure as it is easy
-        for the network to notice that two images that are the same are positive examples. However, if
-        we are given two different images from the same class, the network will need to learn the similarity
-        between two different images representing the same class. If the index is odd, we will pick the second
-        image from a different class than the first image.
+        Given an index, if the index is even, we will pick the second image from the same class,
+        but it won't be the same image we chose as the first image. This is used to ensure the positive
+        example isn't trivial, as the network would easily recognize two identical images as similar.
+        However, if the network is given two different images from the same class, it will need to learn
+        the similarity between two different images representing the same class. If the index is odd, we
+        will pick the second image from a different class than the first image.
         """

         # pick some random class for the first image
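
With the series applied, `python main.py --save-model` leaves a `siamese_network.pt` checkpoint behind, but none of the patches show how to query it. Below is a rough inference sketch, not part of the patch series: the `from main import SiameseNetwork` import, the `../data` root, and running from inside `siamese_network/` are all assumptions based on the code above.

```python
# Hypothetical usage sketch -- not part of the patch series.
# Assumes `python main.py --save-model` has already been run from siamese_network/.
import torch
from torchvision import datasets

from main import SiameseNetwork  # the model class added in PATCH 1/7

model = SiameseNetwork()
model.load_state_dict(torch.load("siamese_network.pt"))
model.eval()

# Mirror APP_MATCHER's preprocessing: raw uint8 MNIST tensors, a channel
# dimension added via unsqueeze, cast to float (no normalization).
mnist = datasets.MNIST('../data', train=False, download=True)
image_1 = mnist.data[0].unsqueeze(0).unsqueeze(0).float()  # shape (1, 1, 28, 28)
image_2 = mnist.data[1].unsqueeze(0).unsqueeze(0).float()

with torch.no_grad():
    similarity = model(image_1, image_2).squeeze()  # sigmoid output in [0, 1]

# Same 0.5 threshold that test() uses to call a pair "same class".
label = 'same' if similarity.item() > 0.5 else 'different'
print(f"similarity={similarity.item():.4f} -> {label} class")
```

The threshold mirrors the `torch.where(outputs > 0.5, 1, 0)` line in `test()`, so the sketch should agree with how accuracy is computed during training.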