Skip to content

Commit d47f0f3

Browse files
jafraustrosoumith
authored andcommitted
Add accelerator API to GCN example.
- Add accel argument; update requirements file Signed-off-by: jafraustro <[email protected]>
1 parent c04a5a1 commit d47f0f3

File tree

3 files changed

+17
-33
lines changed

3 files changed

+17
-33
lines changed

gcn/README.md

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,14 @@ This repository contains an implementation of Graph Convolutional Networks (GCN)
55
## Overview
66
This project implements the GCN model proposed in the paper for semi-supervised node classification on graph-structured data. GCN leverages graph convolutions to aggregate information from neighboring nodes and learn node representations for downstream tasks. The implementation provides a flexible and efficient GCN model for graph-based machine learning tasks.
77

8-
# Requirements
9-
- Python 3.7 or higher
10-
- PyTorch 2.0 or higher
11-
- Requests 2.31 or higher
12-
- NumPy 1.24 or higher
13-
14-
15-
# Installation
8+
## Requirements
169
```bash
1710
pip install -r requirements.txt
18-
python main.py
11+
```
12+
13+
# Usage
14+
```bash
15+
python main.py --epochs 200 --lr 0.01 --l2 5e-4 --dropout-p 0.5 --hidden-dim 16 --val-every 20 --include-bias
1916
```
2017

2118
# Dataset
@@ -24,12 +21,6 @@ The implementation includes support for the Cora dataset, a standard benchmark d
2421
## Model Architecture
2522
The GCN model architecture follows the details provided in the paper. It consists of multiple graph convolutional layers with ReLU activation, followed by a final softmax layer for classification. The implementation supports customizable hyperparameters such as the number of hidden units, the number of layers, and dropout rate.
2623

27-
## Usage
28-
To train and evaluate the GCN model on the Cora dataset, use the following command:
29-
```bash
30-
python train.py --epochs 200 --lr 0.01 --l2 5e-4 --dropout-p 0.5 --hidden-dim 16 --val-every 20 --include-bias False --no-cuda False
31-
```
32-
3324
# Results
3425
The model achieves a classification accuracy of 82.5% on the test set of the Cora dataset after 200 epochs of training. This result is comparable to the performance reported in the original paper. However, the results can vary due to the randomness of the train/val/test split.
3526

gcn/main.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,6 @@ def load_cora(path='./cora', device='cpu'):
165165

166166
return features.to_sparse().to(device), labels.to(device), adj_mat.to_sparse().to(device)
167167

168-
169168
def train_iter(epoch, model, optimizer, criterion, input, target, mask_train, mask_val, print_every=10):
170169
start_t = time.time()
171170
model.train()
@@ -199,8 +198,6 @@ def test(model, criterion, input, target, mask):
199198

200199

201200
if __name__ == '__main__':
202-
device = 'cuda' if torch.cuda.is_available() else 'cpu'
203-
204201
parser = argparse.ArgumentParser(description='PyTorch Graph Convolutional Network')
205202
parser.add_argument('--epochs', type=int, default=200,
206203
help='number of epochs to train (default: 200)')
@@ -214,29 +211,25 @@ def test(model, criterion, input, target, mask):
214211
help='dimension of the hidden representation (default: 16)')
215212
parser.add_argument('--val-every', type=int, default=20,
216213
help='epochs to wait for print training and validation evaluation (default: 20)')
217-
parser.add_argument('--include-bias', action='store_true', default=False,
214+
parser.add_argument('--include-bias', action='store_true',
218215
help='use bias term in convolutions (default: False)')
219-
parser.add_argument('--no-cuda', action='store_true', default=False,
220-
help='disables CUDA training')
221-
parser.add_argument('--no-mps', action='store_true', default=False,
222-
help='disables macOS GPU training')
223-
parser.add_argument('--dry-run', action='store_true', default=False,
216+
parser.add_argument('--no-accel', action='store_true',
217+
help='disables accelerator')
218+
parser.add_argument('--dry-run', action='store_true',
224219
help='quickly check a single pass')
225220
parser.add_argument('--seed', type=int, default=42, metavar='S',
226221
help='random seed (default: 42)')
227222
args = parser.parse_args()
228223

229-
use_cuda = not args.no_cuda and torch.cuda.is_available()
230-
use_mps = not args.no_mps and torch.backends.mps.is_available()
224+
use_accel = not args.no_accel and torch.accelerator.is_available()
231225

232226
torch.manual_seed(args.seed)
233227

234-
if use_cuda:
235-
device = torch.device('cuda')
236-
elif use_mps:
237-
device = torch.device('mps')
228+
if use_accel:
229+
device = torch.accelerator.current_accelerator()
238230
else:
239231
device = torch.device('cpu')
232+
240233
print(f'Using {device} device')
241234

242235
cora_url = 'https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz'

gcn/requirements.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
torch
2-
torchvision==0.20.0
1+
torch>=2.6
2+
torchvision
33
requests
4-
numpy<2
4+
numpy

0 commit comments

Comments
 (0)