Proposition of API for the method network % evaluate #182

Merged: 14 commits, Jun 14, 2024
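In brief, this PR adds a generic `evaluate` method to the `network` type: it runs a batch through `predict` and returns, per output row, the value of the network's loss function (quadratic by default) plus, optionally, one extra metric selected with the new `metric` argument. A minimal usage sketch based on the diff below (the toy data and names are illustrative, not part of the PR):

```fortran
program evaluate_sketch
  use nf, only: dense, input, network, corr
  implicit none
  type(network) :: net
  real :: x(1,4), y(1,4)
  real, allocatable :: metrics(:,:)

  net = network([input(1), dense(1)])
  x = reshape([0.1, 0.3, 0.5, 0.7], [1, 4])  ! 1 input variable, batch of 4
  y = 2 * x

  ! Column 1 is the default loss (quadratic); passing metric=corr()
  ! appends a second column with the Pearson correlation.
  metrics = net % evaluate(x, y, metric=corr())
  print *, 'loss:', metrics(1,1), ' corr:', metrics(1,2)
end program evaluate_sketch
```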
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -50,6 +50,7 @@ add_library(neural-fortran
src/nf/nf_loss_submodule.f90
src/nf/nf_maxpool2d_layer.f90
src/nf/nf_maxpool2d_layer_submodule.f90
src/nf/nf_metrics.f90
src/nf/nf_network.f90
src/nf/nf_network_submodule.f90
src/nf/nf_optimizers.f90
1 change: 1 addition & 0 deletions README.md
@@ -20,6 +20,7 @@ Read the paper [here](https://arxiv.org/abs/1902.06714).
* Stochastic gradient descent optimizers: Classic, momentum, Nesterov momentum,
RMSProp, Adagrad, Adam, AdamW
* More than a dozen activation functions and their derivatives
* Loss functions and metrics: Quadratic, Mean Squared Error, Pearson Correlation, etc.
* Loading dense and convolutional models from Keras HDF5 (.h5) files
* Data-based parallelism

16 changes: 12 additions & 4 deletions example/dense_mnist.f90
@@ -1,6 +1,6 @@
program dense_mnist

use nf, only: dense, input, network, sgd, label_digits, load_mnist
use nf, only: dense, input, network, sgd, label_digits, load_mnist, corr

implicit none

@@ -38,9 +38,17 @@ program dense_mnist
optimizer=sgd(learning_rate=3.) &
)

if (this_image() == 1) &
print '(a,i2,a,f5.2,a)', 'Epoch ', n, ' done, Accuracy: ', accuracy( &
net, validation_images, label_digits(validation_labels)) * 100, ' %'
block
real, allocatable :: output_metrics(:,:)
real, allocatable :: mean_metrics(:)
! Two metrics: the first is the default loss (quadratic), the second is the Pearson correlation.
output_metrics = net % evaluate(validation_images, label_digits(validation_labels), metric=corr())
mean_metrics = sum(output_metrics, 1) / size(output_metrics, 1)
if (this_image() == 1) &
print '(a,i2,3(a,f6.3))', 'Epoch ', n, ' done, Accuracy: ', &
accuracy(net, validation_images, label_digits(validation_labels)) * 100, &
'%, Loss: ', mean_metrics(1), ', Pearson correlation: ', mean_metrics(2)
end block

end do epochs

2 changes: 1 addition & 1 deletion fpm.toml
@@ -1,5 +1,5 @@
name = "neural-fortran"
version = "0.16.1"
version = "0.17.0"
license = "MIT"
author = "Milan Curcic"
maintainer = "[email protected]"
1 change: 1 addition & 0 deletions src/nf.f90
@@ -5,6 +5,7 @@ module nf
use nf_layer_constructors, only: &
conv2d, dense, flatten, input, maxpool2d, reshape
use nf_loss, only: mse, quadratic
use nf_metrics, only: corr, maxabs
use nf_network, only: network
use nf_optimizers, only: sgd, rmsprop, adam, adagrad
use nf_activation, only: activation_function, elu, exponential, &
9 changes: 2 additions & 7 deletions src/nf/nf_loss.f90
@@ -7,25 +7,20 @@ module nf_loss
!! loss type that extends the abstract loss derived type, and that
!! implements concrete eval and derivative methods that accept vectors.

use nf_metrics, only: metric_type
implicit none

private
public :: loss_type
public :: mse
public :: quadratic

type, abstract :: loss_type
type, extends(metric_type), abstract :: loss_type
contains
procedure(loss_interface), nopass, deferred :: eval
procedure(loss_derivative_interface), nopass, deferred :: derivative
end type loss_type

abstract interface
pure function loss_interface(true, predicted) result(res)
real, intent(in) :: true(:)
real, intent(in) :: predicted(:)
real :: res
end function loss_interface
pure function loss_derivative_interface(true, predicted) result(res)
real, intent(in) :: true(:)
real, intent(in) :: predicted(:)
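Since `loss_type` now extends `metric_type`, the `eval` binding is inherited from the metric base type and only `derivative` remains loss-specific. As a hypothetical illustration (not part of this PR), a user-defined mean absolute error loss could be sketched against these abstract types, assuming `loss_derivative_interface` returns `res(size(true))`:

```fortran
module my_losses
  use nf_loss, only: loss_type
  implicit none

  type, extends(loss_type) :: mae
    !! Mean absolute error (hypothetical; not provided by this PR)
  contains
    procedure, nopass :: eval => mae_eval
    procedure, nopass :: derivative => mae_derivative
  end type mae

contains

  pure function mae_eval(true, predicted) result(res)
    real, intent(in) :: true(:)
    real, intent(in) :: predicted(:)
    real :: res
    res = sum(abs(true - predicted)) / size(true)
  end function mae_eval

  pure function mae_derivative(true, predicted) result(res)
    real, intent(in) :: true(:)
    real, intent(in) :: predicted(:)
    real :: res(size(true))
    ! Derivative of mean(|true - predicted|) w.r.t. predicted
    res = -sign(1., true - predicted) / size(true)
  end function mae_derivative

end module my_losses
```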
72 changes: 72 additions & 0 deletions src/nf/nf_metrics.f90
@@ -0,0 +1,72 @@
module nf_metrics

!! This module provides a collection of metric functions.

implicit none

private
public :: metric_type
public :: corr
public :: maxabs

type, abstract :: metric_type
contains
procedure(metric_interface), nopass, deferred :: eval
end type metric_type

abstract interface
pure function metric_interface(true, predicted) result(res)
real, intent(in) :: true(:)
real, intent(in) :: predicted(:)
real :: res
end function metric_interface
end interface

type, extends(metric_type) :: corr
!! Pearson correlation
contains
procedure, nopass :: eval => corr_eval
end type corr

type, extends(metric_type) :: maxabs
!! Maximum absolute difference
contains
procedure, nopass :: eval => maxabs_eval
end type maxabs

contains

pure function corr_eval(true, predicted) result(res)
!! Pearson correlation function:
!!
!! corr = cov(true, predicted) / (stddev(true) * stddev(predicted))
real, intent(in) :: true(:)
!! True values, i.e. labels from training datasets
real, intent(in) :: predicted(:)
!! Values predicted by the network
real :: res
!! Resulting correlation value
real :: m_true, m_pred

m_true = sum(true) / size(true)
m_pred = sum(predicted) / size(predicted)

res = dot_product(true - m_true, predicted - m_pred) / &
sqrt(sum((true - m_true)**2)*sum((predicted - m_pred)**2))

end function corr_eval

pure function maxabs_eval(true, predicted) result(res)
!! Maximum absolute difference function:
!!
!! maxabs = maxval(abs(true - predicted))
real, intent(in) :: true(:)
!! True values, i.e. labels from training datasets
real, intent(in) :: predicted(:)
!! Values predicted by the network
real :: res
!! Resulting maximum absolute difference value

res = maxval(abs(true - predicted))

end function maxabs_eval

end module nf_metrics
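Because `eval` is a `nopass` binding and the metric types carry no state, they can also be applied directly to plain arrays outside of any network. A small sketch with illustrative values:

```fortran
program metrics_demo
  use nf, only: corr, maxabs
  implicit none
  type(corr) :: correlation
  type(maxabs) :: max_difference
  real :: true_vals(4), pred_vals(4)

  true_vals = [1.0, 2.0, 3.0, 4.0]
  pred_vals = [1.1, 1.9, 3.2, 3.8]

  print *, correlation % eval(true_vals, pred_vals)    ! Pearson correlation, ~0.99
  print *, max_difference % eval(true_vals, pred_vals) ! max |difference| = 0.2
end program metrics_demo
```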
13 changes: 13 additions & 0 deletions src/nf/nf_network.f90
@@ -3,6 +3,7 @@ module nf_network
!! This module provides the network type to create new models.

use nf_layer, only: layer
use nf_metrics, only: metric_type
use nf_loss, only: loss_type
use nf_optimizers, only: optimizer_base_type

@@ -28,13 +29,15 @@ module nf_network
procedure :: train
procedure :: update

procedure, private :: evaluate_batch_1d
procedure, private :: forward_1d
procedure, private :: forward_3d
procedure, private :: predict_1d
procedure, private :: predict_3d
procedure, private :: predict_batch_1d
procedure, private :: predict_batch_3d

generic :: evaluate => evaluate_batch_1d
generic :: forward => forward_1d, forward_3d
generic :: predict => predict_1d, predict_3d, predict_batch_1d, predict_batch_3d

@@ -62,6 +65,16 @@ end function network_from_keras

end interface network

interface evaluate
module function evaluate_batch_1d(self, input_data, output_data, metric) result(res)
class(network), intent(in out) :: self
real, intent(in) :: input_data(:,:)
real, intent(in) :: output_data(:,:)
class(metric_type), intent(in), optional :: metric
real, allocatable :: res(:,:)
end function evaluate_batch_1d
end interface evaluate

interface forward

pure module subroutine forward_1d(self, input)
30 changes: 30 additions & 0 deletions src/nf/nf_network_submodule.f90
@@ -337,6 +337,36 @@ pure module subroutine backward(self, output, loss)
end subroutine backward


module function evaluate_batch_1d(self, input_data, output_data, metric) result(res)
class(network), intent(in out) :: self
real, intent(in) :: input_data(:,:)
real, intent(in) :: output_data(:,:)
class(metric_type), intent(in), optional :: metric
real, allocatable :: res(:,:)

integer :: i, n
real, allocatable :: output(:,:)

output = self % predict(input_data)

n = 1
if (present(metric)) n = n + 1

allocate(res(size(output, dim=1), n))

do concurrent (i = 1:size(output, dim=1))
res(i,1) = self % loss % eval(output_data(i,:), output(i,:))
end do

if (.not. present(metric)) return

do concurrent (i = 1:size(output, dim=1))
res(i,2) = metric % eval(output_data(i,:), output(i,:))
end do

end function evaluate_batch_1d


pure module subroutine forward_1d(self, input)
class(network), intent(in out) :: self
real, intent(in) :: input(:)
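Note the result layout implied by the implementation above: one row per element of the network output (dim 1 of the `predict` result), with the loss in column 1 and the optional metric in column 2. Scalar summaries then follow by averaging over the first dimension, as the dense_mnist example does. A fragment-level sketch, assuming `net`, `inputs(:,:)`, and `targets(:,:)` already exist:

```fortran
! Assumes net, inputs, and targets are already defined and shaped correctly.
real, allocatable :: per_output(:,:), means(:)

per_output = net % evaluate(inputs, targets, metric=corr())
means = sum(per_output, dim=1) / size(per_output, dim=1)
! means(1): mean quadratic loss; means(2): mean Pearson correlation
```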
1 change: 1 addition & 0 deletions test/CMakeLists.txt
@@ -17,6 +17,7 @@ foreach(execid
conv2d_network
optimizers
loss
metrics
)
add_executable(test_${execid} test_${execid}.f90)
target_link_libraries(test_${execid} PRIVATE neural-fortran h5fortran::h5fortran jsonfortran::jsonfortran ${LIBS})
70 changes: 70 additions & 0 deletions test/test_metrics.f90
@@ -0,0 +1,70 @@
program test_metrics
use iso_fortran_env, only: stderr => error_unit
use nf, only: dense, input, network, sgd, mse
implicit none
type(network) :: net
logical :: ok = .true.

! Minimal 2-layer network
net = network([ &
input(1), &
dense(1) &
])

training: block
real :: x(1), y(1)
real :: tolerance = 1e-3
integer :: n
integer, parameter :: num_iterations = 1000
real :: quadratic_loss, mse_metric
real, allocatable :: metrics(:,:)

x = [0.1234567]
y = [0.7654321]

do n = 1, num_iterations
call net % forward(x)
call net % backward(y)
call net % update(sgd(learning_rate=1.))
if (all(abs(net % predict(x) - y) < tolerance)) exit
end do

! Returns only one metric, based on the default loss function (quadratic).
metrics = net % evaluate(reshape(x, [1, 1]), reshape(y, [1, 1]))
quadratic_loss = metrics(1,1)

if (.not. all(shape(metrics) == [1, 1])) then
write(stderr, '(a)') 'metrics array is the correct shape (1, 1).. failed'
ok = .false.
end if

! Returns two metrics, one from the loss function and another specified by the user.
metrics = net % evaluate(reshape(x, [1, 1]), reshape(y, [1, 1]), metric=mse())

if (.not. all(shape(metrics) == [1, 2])) then
write(stderr, '(a)') 'metrics array is the correct shape (1, 2).. failed'
ok = .false.
end if

mse_metric = metrics(1,2)

if (.not. all(metrics < 1e-5)) then
write(stderr, '(a)') 'all metric values are small after training.. failed'
ok = .false.
end if

if (metrics(1,1) /= quadratic_loss) then
write(stderr, '(a)') 'first metric should be the same as that of the loss function.. failed'
ok = .false.
end if

end block training

if (ok) then
print '(a)', 'test_metrics: All tests passed.'
else
write(stderr, '(a)') 'test_metrics: One or more tests failed.'
stop 1
end if

end program test_metrics