Skip to content

Commit 2e64151

Browse files
committed
Definitive bug fixes
1 parent c6b4d87 commit 2e64151

File tree

7 files changed

+24
-14
lines changed

7 files changed

+24
-14
lines changed

example/cnn_mnist.f90

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ program cnn_mnist
1212
real, allocatable :: validation_images(:,:), validation_labels(:)
1313
real, allocatable :: testing_images(:,:), testing_labels(:)
1414
integer :: n
15-
integer, parameter :: num_epochs = 10
15+
integer, parameter :: num_epochs = 250
1616

1717
call load_mnist(training_images, training_labels, &
1818
validation_images, validation_labels, &
@@ -37,7 +37,7 @@ program cnn_mnist
3737
label_digits(training_labels), &
3838
batch_size=16, &
3939
epochs=1, &
40-
optimizer=sgd(learning_rate=0.003) &
40+
optimizer=sgd(learning_rate=0.001) &
4141
)
4242

4343
print '(a,i2,a,f5.2,a)', 'Epoch ', n, ' done, Accuracy: ', accuracy( &

example/cnn_mnist_1d.f90

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
program cnn_mnist_1d
22

33
use nf, only: network, sgd, &
4-
input, conv1d, conv2d, maxpool1d, maxpool2d, flatten, dense, reshape, reshape2d, locally_connected_1d, &
4+
input, conv1d, maxpool1d, flatten, dense, reshape, reshape2d, locally_connected_1d, &
55
load_mnist, label_digits, softmax, relu
66

77
implicit none
@@ -12,7 +12,7 @@ program cnn_mnist_1d
1212
real, allocatable :: validation_images(:,:), validation_labels(:)
1313
real, allocatable :: testing_images(:,:), testing_labels(:)
1414
integer :: n
15-
integer, parameter :: num_epochs = 25
15+
integer, parameter :: num_epochs = 250
1616

1717
call load_mnist(training_images, training_labels, &
1818
validation_images, validation_labels, &
@@ -37,7 +37,7 @@ program cnn_mnist_1d
3737
label_digits(training_labels), &
3838
batch_size=16, &
3939
epochs=1, &
40-
optimizer=sgd(learning_rate=0.005) &
40+
optimizer=sgd(learning_rate=0.01) &
4141
)
4242

4343
print '(a,i2,a,f5.2,a)', 'Epoch ', n, ' done, Accuracy: ', accuracy( &

src/nf/nf_conv1d_layer_submodule.f90

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ pure module subroutine forward(self, input)
5959
class(conv1d_layer), intent(in out) :: self
6060
real, intent(in) :: input(:,:)
6161
integer :: input_channels, input_width
62-
integer :: j, n
62+
integer :: j, n, a, b
6363
integer :: iws, iwe, half_window
6464

6565
input_channels = size(input, dim=1)
@@ -95,7 +95,7 @@ pure module subroutine backward(self, input, gradient)
9595
real, intent(in) :: gradient(:,:)
9696

9797
integer :: input_channels, input_width, output_width
98-
integer :: j, n, k
98+
integer :: j, n, k, a, b, c
9999
integer :: iws, iwe, half_window
100100
real :: gdz_val
101101

src/nf/nf_datasets_mnist_submodule.f90

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,9 @@ module subroutine load_mnist(training_images, training_labels, &
5050
real, allocatable, intent(in out), optional :: testing_labels(:)
5151

5252
integer, parameter :: dtype = 4, image_size = 784
53-
integer, parameter :: num_training_images = 500
54-
integer, parameter :: num_validation_images = 100
55-
integer, parameter :: num_testing_images = 100
53+
integer, parameter :: num_training_images = 50000
54+
integer, parameter :: num_validation_images = 10000
55+
integer, parameter :: num_testing_images = 10000
5656
logical :: file_exists
5757

5858
! Check if MNIST data is present and download it if not.

src/nf/nf_network.f90

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ module integer function get_num_params(self)
201201
!! Network instance
202202
end function get_num_params
203203

204+
204205
module function get_params(self) result(params)
205206
!! Get the network parameters (weights and biases).
206207
class(network), intent(in) :: self

src/nf/nf_network_submodule.f90

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,6 @@ module function get_num_params(self)
460460

461461
end function get_num_params
462462

463-
464463
module function get_params(self) result(params)
465464
class(network), intent(in) :: self
466465
real, allocatable :: params(:)
@@ -480,7 +479,6 @@ module function get_params(self) result(params)
480479

481480
end function get_params
482481

483-
484482
module function get_gradients(self) result(gradients)
485483
class(network), intent(in) :: self
486484
real, allocatable :: gradients(:)
@@ -640,6 +638,12 @@ module subroutine update(self, optimizer, batch_size)
640638
type is(conv2d_layer)
641639
call co_sum(this_layer % dw)
642640
call co_sum(this_layer % db)
641+
type is(conv1d_layer)
642+
call co_sum(this_layer % dw)
643+
call co_sum(this_layer % db)
644+
type is(locally_connected_1d_layer)
645+
call co_sum(this_layer % dw)
646+
call co_sum(this_layer % db)
643647
end select
644648
end do
645649
#endif
@@ -657,6 +661,12 @@ module subroutine update(self, optimizer, batch_size)
657661
type is(conv2d_layer)
658662
this_layer % dw = 0
659663
this_layer % db = 0
664+
type is(conv1d_layer)
665+
this_layer % dw = 0
666+
this_layer % db = 0
667+
type is(locally_connected_1d_layer)
668+
this_layer % dw = 0
669+
this_layer % db = 0
660670
end select
661671
end do
662672

test/test_conv1d_network.f90

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,7 @@ program test_conv1d_network
6060
call cnn % forward(sample_input)
6161
call cnn % backward(y)
6262
call cnn % update(optimizer=sgd(learning_rate=1.))
63-
o = cnn % layers(2) % get_params()
64-
print *, o
63+
6564
if (all(abs(cnn % predict(sample_input) - y) < tolerance)) exit
6665
end do
6766

0 commit comments

Comments
 (0)