From cad795cceb363f0b9f91fcc305e999df95b8010b Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Tue, 25 Feb 2025 21:03:17 +0400 Subject: [PATCH 1/9] multihead_attention_optimization: allocate in init --- src/nf/nf_multihead_attention.f90 | 12 ++ src/nf/nf_multihead_attention_submodule.f90 | 119 ++++++++------------ 2 files changed, 57 insertions(+), 74 deletions(-) diff --git a/src/nf/nf_multihead_attention.f90 b/src/nf/nf_multihead_attention.f90 index 80a59dfb..2222b27b 100644 --- a/src/nf/nf_multihead_attention.f90 +++ b/src/nf/nf_multihead_attention.f90 @@ -39,6 +39,18 @@ module nf_multihead_attention_layer real, allocatable :: k_input(:, :) real, allocatable :: v_input(:, :) real, allocatable :: o_input(:, :) + + ! temporary storages for forward and backward passes + real, allocatable, private :: q_or_dq(:, :, :) + real, allocatable, private :: k_or_dk(:, :, :) + real, allocatable, private :: v_or_dv(:, :, :) + real, allocatable, private :: d_output(:, :, :) + real, allocatable, private :: v_heads(:, :, :) + real, allocatable, private :: k_heads(:, :, :) + real, allocatable, private :: q_heads(:, :, :) + real, allocatable, private :: d_sdpa(:, :) + real, allocatable, private :: jacobian(:, :) + real, allocatable, private :: d_normalize(:, :, :) contains procedure :: common_backward diff --git a/src/nf/nf_multihead_attention_submodule.f90 b/src/nf/nf_multihead_attention_submodule.f90 index d0e43a2e..80eabda9 100644 --- a/src/nf/nf_multihead_attention_submodule.f90 +++ b/src/nf/nf_multihead_attention_submodule.f90 @@ -1,5 +1,4 @@ submodule(nf_multihead_attention_layer) nf_multihead_attention_layer_submodule -! use iso_fortran_env, only: stderr => error_unit use nf_activation, only: softmax use nf_base_layer, only: base_layer use nf_linear2d_layer, only: linear2d_layer @@ -19,46 +18,26 @@ pure module subroutine common_backward(self, input, gradient) real, intent(in) :: input(:, :) real, intent(in) :: gradient(:, :) - real, allocatable :: d_output(:, :, :) - real, allocatable :: v_heads(:, :, :) - real, allocatable :: k_heads(:, :, :) - real, allocatable :: q_heads(:, :, :) - real, allocatable :: dv(:, :, :) - real, allocatable :: d_sdpa(:, :) - real, allocatable :: jacobian(:, :) - real, allocatable :: d_normalize(:, :, :) - real, allocatable :: dq(:, :, :) - real, allocatable :: dk(:, :, :) integer :: head, seq, i, j - ! allocate temporary storages for backward computation - allocate(d_output(self % sequence_length, self % head_size, self % n_heads)) - allocate(v_heads(self % sequence_length, self % head_size, self % n_heads)) - allocate(k_heads(self % sequence_length, self % head_size, self % n_heads)) - allocate(q_heads(self % sequence_length, self % head_size, self % n_heads)) - - allocate(dv(self % sequence_length, self % head_size, self % n_heads)) - allocate(d_sdpa(self % sequence_length, self % sequence_length)) - allocate(jacobian(self % sequence_length, self % sequence_length)) - allocate(d_normalize(self % sequence_length, self % sequence_length, self % n_heads)) - allocate(dq(self % sequence_length, self % head_size, self % n_heads)) - allocate(dk(self % sequence_length, self % head_size, self % n_heads)) - ! calculate output layer delta call self % output_layer % backward(self % o_input, gradient) ! 
split heads from output gradient - d_output = self % split_heads(self % output_layer % gradient) - v_heads = self % split_heads(self % value_layer % output) - k_heads = self % split_heads(self % key_layer % output) - q_heads = self % split_heads(self % query_layer % output) + self % d_output = self % split_heads(self % output_layer % gradient) + self % v_heads = self % split_heads(self % value_layer % output) + self % k_heads = self % split_heads(self % key_layer % output) + self % q_heads = self % split_heads(self % query_layer % output) ! iterate over heads to calculate deltas for each of them do concurrent(head = 1: self % n_heads) - dv(:, :, head) = matmul(transpose(self % attention_matrix(:, :, head)), d_output(:, :, head)) + self % v_or_dv(:, :, head) = matmul(& + transpose(self % attention_matrix(:, :, head)),& + self % d_output(:, :, head)& + ) ! calculate delta for attention matrix - d_sdpa = matmul(d_output(:, :, head), transpose(v_heads(:, :, head))) + self % d_sdpa = matmul(self % d_output(:, :, head), transpose(self % v_heads(:, :, head))) ! this monstrosity below is scaled derivative of softmax do concurrent(seq = 1: self % sequence_length) @@ -69,11 +48,11 @@ pure module subroutine common_backward(self, input, gradient) ! should be: `softmax(x_i) * (1 - softmax(x_i))` ! for off-diagonal: `-softmax(x_i) * softmax(x_j)` if (i == j) then - jacobian(i, j) = & + self % jacobian(i, j) = & self % attention_matrix(seq, i, head) & * (1 - self % attention_matrix(seq, i, head)) else - jacobian(i, j) = & + self % jacobian(i, j) = & - self % attention_matrix(seq, i, head) & * self % attention_matrix(seq, j, head) end if @@ -82,49 +61,29 @@ pure module subroutine common_backward(self, input, gradient) ! multiply output of softmax by temp jacobian matrix ! For computational efficiency (avoid more temp storages), scaling is also done here ! reshapes: [3] -> [1, 3] @ [3, 3] = [1, 3] -> [3] - d_normalize(seq, :, head) = reshape(matmul(& - reshape(d_sdpa(seq, :), [1, self % sequence_length]),& - jacobian * self % scaling_factor& + self % d_normalize(seq, :, head) = reshape(matmul(& + reshape(self % d_sdpa(seq, :), [1, self % sequence_length]),& + self % jacobian * self % scaling_factor& ), [self % sequence_length]) end do ! calculate delta for query - dq(:, :, head) = matmul(d_normalize(:, :, head), k_heads(:, :, head)) + self % q_or_dq(:, :, head) = matmul(self % d_normalize(:, :, head), self % k_heads(:, :, head)) ! calculate delta for key, attention matrix should be transposed unlike for query - dk(:, :, head) = matmul(transpose(d_normalize(:, :, head)), q_heads(:, :, head)) + self % k_or_dk(:, :, head) = matmul(transpose(self % d_normalize(:, :, head)), self % q_heads(:, :, head)) end do ! calculate deltas for input layers - call self % value_layer % backward(self % v_input, self % combine_heads(dv)) - call self % key_layer % backward(self % k_input, self % combine_heads(dk)) - call self % query_layer % backward(self % q_input, self % combine_heads(dq)) - - ! 
free temporary storages - deallocate(d_output) - deallocate(v_heads) - deallocate(k_heads) - deallocate(q_heads) - deallocate(d_sdpa) - deallocate(jacobian) - deallocate(d_normalize) - deallocate(dq) - deallocate(dk) + call self % value_layer % backward(self % v_input, self % combine_heads(self % v_or_dv)) + call self % key_layer % backward(self % k_input, self % combine_heads(self % k_or_dk)) + call self % query_layer % backward(self % q_input, self % combine_heads(self % q_or_dq)) end subroutine common_backward pure module subroutine common_forward(self, query, key, value) class(multihead_attention_layer), intent(in out) :: self real, intent(in) :: query(:, :), key(:, :), value(:, :) - real, allocatable :: q(:, :, :) - real, allocatable :: k(:, :, :) - real, allocatable :: v(:, :, :) - - ! allocate storage for intermidiate stages - allocate(q(self % sequence_length, self % head_size, self % n_heads)) - allocate(k(self % sequence_length, self % head_size, self % n_heads)) - allocate(v(self % sequence_length, self % head_size, self % n_heads)) - self % q_input = query self % k_input = key self % v_input = value @@ -135,25 +94,20 @@ pure module subroutine common_forward(self, query, key, value) call self % value_layer % forward(value) ! split attention heads for more efficient computation - q = self % split_heads(self % query_layer % output) - k = self % split_heads(self % key_layer % output) - v = self % split_heads(self % value_layer % output) + self % q_or_dq = self % split_heads(self % query_layer % output) + self % k_or_dk = self % split_heads(self % key_layer % output) + self % v_or_dv = self % split_heads(self % value_layer % output) ! create key by value matrix - call self % create_attention_matrix(q, k) + call self % create_attention_matrix(self % q_or_dq, self % k_or_dk) ! apply softmax and scaling call self % normalize_attention_matrix() ! multiply attention matrix by value - call self % scaled_dot_product_attention(v) + call self % scaled_dot_product_attention(self % v_or_dv) self % o_input = self % combine_heads(self % sdpa) call self % output_layer % forward(self % o_input) self % output = self % output_layer % output - - ! free temp vars from memory - deallocate(q) - deallocate(k) - deallocate(v) end subroutine common_forward pure module function split_heads(self, input) result(output) @@ -335,9 +289,26 @@ module subroutine init_base(self, input_shape) self % scaling_factor = sqrt(1 / real(self % head_size)) - allocate(self % q_input(self % sequence_length, self % model_dimension)) - allocate(self % k_input(self % sequence_length, self % model_dimension)) - allocate(self % v_input(self % sequence_length, self % model_dimension)) - allocate(self % o_input(self % sequence_length, self % model_dimension)) + allocate(self % q_input, mold=self % output) + allocate(self % k_input, mold=self % output) + allocate(self % v_input, mold=self % output) + allocate(self % o_input, mold=self % output) + + ! allocate temporary storages + ! the following three are used twice: + ! Forward pass: As inputs after the corresponding linear layer and head reshape + ! Backward pass: As deltas for each input array + allocate(self % q_or_dq, mold=self % sdpa) + allocate(self % k_or_dk, mold=self % sdpa) + allocate(self % v_or_dv, mold=self % sdpa) + + ! 
the other seven below are for backward pass + allocate(self % d_output, mold=self % sdpa) + allocate(self % v_heads, mold=self % sdpa) + allocate(self % k_heads, mold=self % sdpa) + allocate(self % q_heads, mold=self % sdpa) + allocate(self % d_sdpa(self % sequence_length, self % sequence_length)) + allocate(self % jacobian, mold=self % d_sdpa) + allocate(self % d_normalize, mold=self % attention_matrix) end subroutine init_base end submodule nf_multihead_attention_layer_submodule \ No newline at end of file From f1d6fde8f0177de6180f7918662a8f7169186c2d Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Tue, 25 Feb 2025 21:08:58 +0400 Subject: [PATCH 2/9] multihead_attention_optimization: remove last runtime allocation --- src/nf/nf_multihead_attention.f90 | 1 + src/nf/nf_multihead_attention_submodule.f90 | 14 ++++++-------- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/src/nf/nf_multihead_attention.f90 b/src/nf/nf_multihead_attention.f90 index 2222b27b..1def2966 100644 --- a/src/nf/nf_multihead_attention.f90 +++ b/src/nf/nf_multihead_attention.f90 @@ -41,6 +41,7 @@ module nf_multihead_attention_layer real, allocatable :: o_input(:, :) ! temporary storages for forward and backward passes + real, allocatable, private :: normalized_attention(:, :, :) real, allocatable, private :: q_or_dq(:, :, :) real, allocatable, private :: k_or_dk(:, :, :) real, allocatable, private :: v_or_dv(:, :, :) diff --git a/src/nf/nf_multihead_attention_submodule.f90 b/src/nf/nf_multihead_attention_submodule.f90 index 80eabda9..88915816 100644 --- a/src/nf/nf_multihead_attention_submodule.f90 +++ b/src/nf/nf_multihead_attention_submodule.f90 @@ -131,12 +131,8 @@ end subroutine create_attention_matrix pure module subroutine normalize_attention_matrix(self, attention_mask) class(multihead_attention_layer), intent(in out) :: self real, optional, intent(in) :: attention_mask(:, :, :) - real, allocatable :: output(:, :, :) integer :: head, seq - ! temporary storage - allocate(output(self % sequence_length, self % sequence_length, self % n_heads)) - ! scale dowm by square root of each head's size self % attention_matrix = self % attention_matrix * self % scaling_factor ! attention mask is used to mask out some of the tokens if necessary @@ -145,11 +141,9 @@ pure module subroutine normalize_attention_matrix(self, attention_mask) end if ! softmax by last sequnce_length do concurrent(head = 1: self % n_heads, seq = 1: self % sequence_length) - output(seq, :, head) = self % softmax_func % eval_1d(self % attention_matrix(seq, :, head)) + self % normalized_attention(seq, :, head) = self % softmax_func % eval_1d(self % attention_matrix(seq, :, head)) end do - self % attention_matrix = output - - deallocate(output) + self % attention_matrix = self % normalized_attention end subroutine normalize_attention_matrix pure module subroutine scaled_dot_product_attention(self, value) @@ -295,6 +289,10 @@ module subroutine init_base(self, input_shape) allocate(self % o_input, mold=self % output) ! allocate temporary storages + + ! this one is for forward pass + allocate(self % normalized_attention, mold=self % attention_matrix) + ! the following three are used twice: ! Forward pass: As inputs after the corresponding linear layer and head reshape ! 
Backward pass: As deltas for each input array From b064882c694a60c565fb36328d54ecb0ef8d6367 Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Tue, 25 Feb 2025 22:52:18 +0400 Subject: [PATCH 3/9] multihead_attention_optimization: make attention mask actually useable --- src/nf/nf_multihead_attention.f90 | 8 ++- src/nf/nf_multihead_attention_submodule.f90 | 18 ++++-- test/test_multihead_attention_layer.f90 | 67 +++++++++++++++++++++ 3 files changed, 85 insertions(+), 8 deletions(-) diff --git a/src/nf/nf_multihead_attention.f90 b/src/nf/nf_multihead_attention.f90 index 1def2966..3cdb33f7 100644 --- a/src/nf/nf_multihead_attention.f90 +++ b/src/nf/nf_multihead_attention.f90 @@ -81,7 +81,7 @@ end function multihead_attention_layer_cons interface - pure module subroutine common_backward(self, input, gradient) + pure module subroutine common_backward(self, input, gradient, attention_mask) !! General backprop for MultiHead Attention mechanism !! Might be used for both Self and Cross Attention !! Self Attention: sum output gradients @@ -89,15 +89,17 @@ pure module subroutine common_backward(self, input, gradient) class(multihead_attention_layer), intent(in out) :: self real, intent(in) :: input(:, :) real, intent(in) :: gradient(:, :) + real, optional, intent(in) :: attention_mask(:, :) end subroutine common_backward - pure module subroutine common_forward(self, query, key, value) + pure module subroutine common_forward(self, query, key, value, attention_mask) !! General forward propagation for MultiHead Attention Mechanism !! Might be used for both Self and Cross Attention !! Self Attention: pass the same value thrice !! Cross Attention: pass three values for your query, key and value class(multihead_attention_layer), intent(in out) :: self real, intent(in) :: query(:, :), key(:, :), value(:, :) + real, optional, intent(in) :: attention_mask(:, :) end subroutine common_forward pure module subroutine init(self, input_shape) @@ -132,7 +134,7 @@ pure module subroutine normalize_attention_matrix(self, attention_mask) !! Output dims: sequence_length, sequence_length, n_heads class(multihead_attention_layer), intent(in out) :: self !! (sequence_length, sequence_length, n_heads) - real, optional, intent(in) :: attention_mask(:, :, :) + real, optional, intent(in) :: attention_mask(:, :) !! (sequence_length, sequence_length, n_heads) end subroutine normalize_attention_matrix diff --git a/src/nf/nf_multihead_attention_submodule.f90 b/src/nf/nf_multihead_attention_submodule.f90 index 88915816..41ff5a53 100644 --- a/src/nf/nf_multihead_attention_submodule.f90 +++ b/src/nf/nf_multihead_attention_submodule.f90 @@ -13,10 +13,11 @@ module function multihead_attention_layer_cons(n_heads) result(res) res % n_heads = n_heads end function multihead_attention_layer_cons - pure module subroutine common_backward(self, input, gradient) + pure module subroutine common_backward(self, input, gradient, attention_mask) class(multihead_attention_layer), intent(in out) :: self real, intent(in) :: input(:, :) real, intent(in) :: gradient(:, :) + real, intent(in), optional :: attention_mask(:, :) integer :: head, seq, i, j @@ -39,6 +40,10 @@ pure module subroutine common_backward(self, input, gradient) ! calculate delta for attention matrix self % d_sdpa = matmul(self % d_output(:, :, head), transpose(self % v_heads(:, :, head))) + if (present(attention_mask)) then + self % d_sdpa = self % d_sdpa + attention_mask + end if + ! 
this monstrosity below is scaled derivative of softmax do concurrent(seq = 1: self % sequence_length) ! create jacobian matrix @@ -80,9 +85,10 @@ pure module subroutine common_backward(self, input, gradient) call self % query_layer % backward(self % q_input, self % combine_heads(self % q_or_dq)) end subroutine common_backward - pure module subroutine common_forward(self, query, key, value) + pure module subroutine common_forward(self, query, key, value, attention_mask) class(multihead_attention_layer), intent(in out) :: self real, intent(in) :: query(:, :), key(:, :), value(:, :) + real, intent(in), optional :: attention_mask(:, :) self % q_input = query self % k_input = key @@ -101,7 +107,7 @@ pure module subroutine common_forward(self, query, key, value) ! create key by value matrix call self % create_attention_matrix(self % q_or_dq, self % k_or_dk) ! apply softmax and scaling - call self % normalize_attention_matrix() + call self % normalize_attention_matrix(attention_mask) ! multiply attention matrix by value call self % scaled_dot_product_attention(self % v_or_dv) @@ -130,14 +136,16 @@ end subroutine create_attention_matrix pure module subroutine normalize_attention_matrix(self, attention_mask) class(multihead_attention_layer), intent(in out) :: self - real, optional, intent(in) :: attention_mask(:, :, :) + real, optional, intent(in) :: attention_mask(:, :) integer :: head, seq ! scale dowm by square root of each head's size self % attention_matrix = self % attention_matrix * self % scaling_factor ! attention mask is used to mask out some of the tokens if necessary if (present(attention_mask)) then - self % attention_matrix = self % attention_matrix + attention_mask + do concurrent(head = 1: self % n_heads) + self % attention_matrix(:, :, head) = self % attention_matrix(:, :, head) + attention_mask + end do end if ! 
softmax by last sequnce_length do concurrent(head = 1: self % n_heads, seq = 1: self % sequence_length) diff --git a/test/test_multihead_attention_layer.f90 b/test/test_multihead_attention_layer.f90 index fdc6862d..7ff4e684 100644 --- a/test/test_multihead_attention_layer.f90 +++ b/test/test_multihead_attention_layer.f90 @@ -27,6 +27,7 @@ program test_multihead_attention_layer call test_multihead_attention_backward(attention, ok) call test_multihead_attention_update_gradients(attention, ok) call test_multihead_attention_forward_reallife_shape(ok) + call test_multihead_attention_mask(ok) call test_self_attention(ok) call test_cross_attention(ok) @@ -315,6 +316,72 @@ subroutine test_multihead_attention_update_gradients(attention, ok) end if end subroutine test_multihead_attention_update_gradients + subroutine test_multihead_attention_mask(ok) + logical, intent(in out) :: ok + type(multihead_attention_layer) :: attention + real :: input(3, 4) = reshape([0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12], [3, 4]) + real :: gradient(3, 4) = reshape([0.1, 3., 2., 0.1, 3., 3., 0.1, 2., 0.1, 3., 0.1, 3.], [3, 4]) + real :: attention_mask(3, 3) = reshape([& + 0., 0., 0.,& + 0., 0., -100.,& + -100., 0., -100.& + ], [3, 3]) + real :: output(3, 4) + real, volatile :: output_flat(12) + real, volatile :: attn_weights_flat(18) + real :: expected_output_flat(12) = [& + 0.94935673, 1.0040786, 0.626,& + 0.94935673, 1.0040786, 0.626,& + 0.94935673, 1.0040786, 0.626,& + 0.94935673, 1.0040786, 0.626& + ] + real :: expected_attn_weights_flat(18) = [& + 0.149956360, 2.28110179E-02, 1.0,& + 0.850043654, 0.464612424, 0.0,& + 0.0, 0.512576580, 0.0,& + 0.149956360, 2.28110179E-02, 1.0,& + 0.850043654, 0.464612424, 0.0,& + 0.0, 0.512576580, 0.0& + ] + real :: gradient_flat(12) + real :: expacted_gradient_flat(12) = [& + 0.32137412, 0.30436403, 0.1854456,& + 0.32137412, 0.30436403, 0.1854456,& + 0.32137412, 0.30436403, 0.1854456,& + 0.32137412, 0.30436403, 0.1854456& + ] + + attention = multihead_attention_layer(n_heads=2) + call attention % init_base([3, 4]) + call set_weights(attention) + + call attention % common_forward(input, input, input, attention_mask=attention_mask) + + output_flat = reshape(attention % output, shape(output_flat)) + if (.not. allclose(output_flat, expected_output_flat)) then + ok = .false. + write(stderr, '(a)') 'forward w. attention mask returned incorrect values.. failed' + end if + + attn_weights_flat = reshape(attention % attention_matrix, shape(attn_weights_flat)) + if (.not. allclose(attn_weights_flat, expected_attn_weights_flat)) then + ok = .false. + write(stderr, '(a)') 'forward w. attention mask returned incorrect attention weights values.. failed' + end if + + call attention % common_backward(input, gradient, attention_mask) + gradient_flat = reshape(& + attention % query_layer % gradient & + + attention % key_layer % gradient & + + attention % value_layer % gradient,& + [12]& + ) + if (.not. allclose(gradient_flat, expacted_gradient_flat)) then + ok = .false. + write(stderr, '(a)') 'backward w. attention mask returned incorrect gradient values.. 
failed' + end if + end subroutine test_multihead_attention_mask + subroutine test_self_attention(ok) logical, intent(in out) :: ok type(self_attention_layer) :: attention From 81d386921d82475ca5cca2c206d812221baffc98 Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Tue, 25 Feb 2025 22:55:04 +0400 Subject: [PATCH 4/9] multihead_attention_optimization: tests cleanup --- test/test_multihead_attention_layer.f90 | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_multihead_attention_layer.f90 b/test/test_multihead_attention_layer.f90 index 7ff4e684..c8704a33 100644 --- a/test/test_multihead_attention_layer.f90 +++ b/test/test_multihead_attention_layer.f90 @@ -258,9 +258,9 @@ subroutine test_multihead_attention_backward(attention, ok) ! sample for Self Attention: sum of output gradients ! FIXME: remove reshapes when linear2d situation is resolved output = & - reshape(attention % query_layer % gradient, [attention % sequence_length, attention % model_dimension]) & - + reshape(attention % key_layer % gradient, [attention % sequence_length, attention % model_dimension]) & - + reshape(attention % value_layer % gradient, [attention % sequence_length, attention % model_dimension]) + attention % query_layer % gradient & + + attention % key_layer % gradient & + + attention % value_layer % gradient output_shape = shape(output) if (.not. all(output_shape.eq.expected_shape)) then From 7d1a10de341feed06766dc5e05e4f8c4685e7d37 Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Tue, 25 Feb 2025 23:00:05 +0400 Subject: [PATCH 5/9] multihead_attention_optimization: cleanup --- src/nf/nf_multihead_attention.f90 | 6 +++--- src/nf/nf_multihead_attention_submodule.f90 | 7 +++---- test/test_multihead_attention_layer.f90 | 1 - 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/nf/nf_multihead_attention.f90 b/src/nf/nf_multihead_attention.f90 index 3cdb33f7..3f67c2a0 100644 --- a/src/nf/nf_multihead_attention.f90 +++ b/src/nf/nf_multihead_attention.f90 @@ -158,18 +158,18 @@ elemental module function get_num_params(self) result(num_params) end function get_num_params module function get_params(self) result(params) - class(multihead_attention_layer), intent(in), target :: self + class(multihead_attention_layer), intent(in) :: self real, allocatable :: params(:) end function get_params module function get_gradients(self) result(gradients) - class(multihead_attention_layer), intent(in), target :: self + class(multihead_attention_layer), intent(in) :: self real, allocatable :: gradients(:) end function get_gradients module subroutine set_params(self, params) class(multihead_attention_layer), intent(in out) :: self - real, intent(in), target :: params(:) + real, intent(in) :: params(:) end subroutine set_params module subroutine init_base(self, input_shape) diff --git a/src/nf/nf_multihead_attention_submodule.f90 b/src/nf/nf_multihead_attention_submodule.f90 index 41ff5a53..766d1eb7 100644 --- a/src/nf/nf_multihead_attention_submodule.f90 +++ b/src/nf/nf_multihead_attention_submodule.f90 @@ -187,7 +187,7 @@ elemental module function get_num_params(self) result(num_params) end function get_num_params module function get_params(self) result(params) - class(multihead_attention_layer), intent(in), target :: self + class(multihead_attention_layer), intent(in) :: self real, allocatable :: params(:) params = [& @@ -203,7 +203,7 @@ module function get_params(self) result(params) end function get_params module function get_gradients(self) result(gradients) - 
class(multihead_attention_layer), intent(in), target :: self
+    class(multihead_attention_layer), intent(in) :: self
     real, allocatable :: gradients(:)
 
     gradients = [ &
@@ -220,8 +220,7 @@ end function get_gradients
 
   module subroutine set_params(self, params)
     class(multihead_attention_layer), intent(in out) :: self
-    real, intent(in), target :: params(:)
-    real, pointer :: p_(:,:) => null()
+    real, intent(in) :: params(:)
     integer :: i, j, window
 
     ! check if the number of parameters is correct
diff --git a/test/test_multihead_attention_layer.f90 b/test/test_multihead_attention_layer.f90
index c8704a33..e9845bba 100644
--- a/test/test_multihead_attention_layer.f90
+++ b/test/test_multihead_attention_layer.f90
@@ -256,7 +256,6 @@ subroutine test_multihead_attention_backward(attention, ok)
     call attention % common_backward(input, gradient)
 
     ! sample for Self Attention: sum of output gradients
-    ! FIXME: remove reshapes when linear2d situation is resolved
     output = &
         attention % query_layer % gradient &
       + attention % key_layer % gradient &

From aa59523e67dbcbf31dfd06a0cda66f5f0fc9ea81 Mon Sep 17 00:00:00 2001
From: Mikhail Voronov
Date: Wed, 26 Feb 2025 12:05:57 +0400
Subject: [PATCH 6/9] multihead_attention_optimization: refactoring, split
 methods even more (will be needed for llama attention)

---
 src/nf/nf_multihead_attention.f90           | 13 +++
 src/nf/nf_multihead_attention_submodule.f90 | 92 ++++++++++++---------
 2 files changed, 68 insertions(+), 37 deletions(-)

diff --git a/src/nf/nf_multihead_attention.f90 b/src/nf/nf_multihead_attention.f90
index 3f67c2a0..cc4cbf9d 100644
--- a/src/nf/nf_multihead_attention.f90
+++ b/src/nf/nf_multihead_attention.f90
@@ -56,6 +56,8 @@ module nf_multihead_attention_layer
 
     procedure :: common_backward
     procedure :: common_forward
+    procedure :: sdpa_forward
+    procedure :: sdpa_backward
     procedure :: get_num_params
     procedure :: get_params
     procedure :: get_gradients
@@ -102,6 +104,17 @@ pure module subroutine common_forward(self, query, key, value, attention_mask)
       real, optional, intent(in) :: attention_mask(:, :)
     end subroutine common_forward
 
+    pure module subroutine sdpa_forward(self, attention_mask)
+      class(multihead_attention_layer), intent(in out) :: self
+      real, intent(in), optional :: attention_mask(:, :)
+    end subroutine sdpa_forward
+
+    pure module subroutine sdpa_backward(self, gradient, attention_mask)
+      class(multihead_attention_layer), intent(in out) :: self
+      real, intent(in) :: gradient(:, :)
+      real, intent(in), optional :: attention_mask(:, :)
+    end subroutine sdpa_backward
+
     pure module subroutine init(self, input_shape)
       !! Initialize the layer data structures.
       !!
diff --git a/src/nf/nf_multihead_attention_submodule.f90 b/src/nf/nf_multihead_attention_submodule.f90
index 766d1eb7..37ecd25f 100644
--- a/src/nf/nf_multihead_attention_submodule.f90
+++ b/src/nf/nf_multihead_attention_submodule.f90
@@ -21,6 +21,60 @@ pure module subroutine common_backward(self, input, gradient, attention_mask)
 
     integer :: head, seq, i, j
 
+    ! backward through the attention mechanism
+    call self % sdpa_backward(gradient, attention_mask)
+
+    ! 
calculate deltas for input layers + call self % value_layer % backward(self % v_input, self % combine_heads(self % v_or_dv)) + call self % key_layer % backward(self % k_input, self % combine_heads(self % k_or_dk)) + call self % query_layer % backward(self % q_input, self % combine_heads(self % q_or_dq)) + end subroutine common_backward + + pure module subroutine common_forward(self, query, key, value, attention_mask) + class(multihead_attention_layer), intent(in out) :: self + real, intent(in) :: query(:, :), key(:, :), value(:, :) + real, intent(in), optional :: attention_mask(:, :) + + self % q_input = query + self % k_input = key + self % v_input = value + + ! run inputs through linear layers (trainable params) + call self % query_layer % forward(query) + call self % key_layer % forward(key) + call self % value_layer % forward(value) + + ! split attention heads for more efficient computation + self % q_or_dq = self % split_heads(self % query_layer % output) + self % k_or_dk = self % split_heads(self % key_layer % output) + self % v_or_dv = self % split_heads(self % value_layer % output) + + call self % sdpa_forward(attention_mask) + end subroutine common_forward + + pure module subroutine sdpa_forward(self, attention_mask) + class(multihead_attention_layer), intent(in out) :: self + real, intent(in), optional :: attention_mask(:, :) + + ! create key by value matrix + call self % create_attention_matrix(self % q_or_dq, self % k_or_dk) + ! apply softmax and scaling + call self % normalize_attention_matrix(attention_mask) + ! multiply attention matrix by value + call self % scaled_dot_product_attention(self % v_or_dv) + + self % o_input = self % combine_heads(self % sdpa) + call self % output_layer % forward(self % o_input) + self % output = self % output_layer % output + end subroutine sdpa_forward + + pure module subroutine sdpa_backward(self, gradient, attention_mask) + class(multihead_attention_layer), intent(in out) :: self + real, intent(in) :: gradient(:, :) + real, intent(in), optional :: attention_mask(:, :) + + integer :: head, seq, i, j + ! calculate output layer delta call self % output_layer % backward(self % o_input, gradient) @@ -78,43 +132,7 @@ pure module subroutine common_backward(self, input, gradient, attention_mask) ! calculate delta for key, attention matrix should be transposed unlike for query self % k_or_dk(:, :, head) = matmul(transpose(self % d_normalize(:, :, head)), self % q_heads(:, :, head)) end do - - ! calculate deltas for input layers - call self % value_layer % backward(self % v_input, self % combine_heads(self % v_or_dv)) - call self % key_layer % backward(self % k_input, self % combine_heads(self % k_or_dk)) - call self % query_layer % backward(self % q_input, self % combine_heads(self % q_or_dq)) - end subroutine common_backward - - pure module subroutine common_forward(self, query, key, value, attention_mask) - class(multihead_attention_layer), intent(in out) :: self - real, intent(in) :: query(:, :), key(:, :), value(:, :) - real, intent(in), optional :: attention_mask(:, :) - - self % q_input = query - self % k_input = key - self % v_input = value - - ! run inputs through linear layers (trainable params) - call self % query_layer % forward(query) - call self % key_layer % forward(key) - call self % value_layer % forward(value) - - ! 
split attention heads for more efficient computation - self % q_or_dq = self % split_heads(self % query_layer % output) - self % k_or_dk = self % split_heads(self % key_layer % output) - self % v_or_dv = self % split_heads(self % value_layer % output) - - ! create key by value matrix - call self % create_attention_matrix(self % q_or_dq, self % k_or_dk) - ! apply softmax and scaling - call self % normalize_attention_matrix(attention_mask) - ! multiply attention matrix by value - call self % scaled_dot_product_attention(self % v_or_dv) - - self % o_input = self % combine_heads(self % sdpa) - call self % output_layer % forward(self % o_input) - self % output = self % output_layer % output - end subroutine common_forward + end subroutine sdpa_backward pure module function split_heads(self, input) result(output) class(multihead_attention_layer), intent(in) :: self From a07624626624915ecea4d25e7ac634e142891af2 Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Wed, 26 Feb 2025 23:57:00 +0400 Subject: [PATCH 7/9] multihead_attention_optimization: make attributes public --- src/nf/nf_multihead_attention.f90 | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/nf/nf_multihead_attention.f90 b/src/nf/nf_multihead_attention.f90 index cc4cbf9d..6ff51dd5 100644 --- a/src/nf/nf_multihead_attention.f90 +++ b/src/nf/nf_multihead_attention.f90 @@ -41,17 +41,17 @@ module nf_multihead_attention_layer real, allocatable :: o_input(:, :) ! temporary storages for forward and backward passes - real, allocatable, private :: normalized_attention(:, :, :) - real, allocatable, private :: q_or_dq(:, :, :) - real, allocatable, private :: k_or_dk(:, :, :) - real, allocatable, private :: v_or_dv(:, :, :) - real, allocatable, private :: d_output(:, :, :) - real, allocatable, private :: v_heads(:, :, :) - real, allocatable, private :: k_heads(:, :, :) - real, allocatable, private :: q_heads(:, :, :) - real, allocatable, private :: d_sdpa(:, :) - real, allocatable, private :: jacobian(:, :) - real, allocatable, private :: d_normalize(:, :, :) + real, allocatable :: normalized_attention(:, :, :) + real, allocatable :: q_or_dq(:, :, :) + real, allocatable :: k_or_dk(:, :, :) + real, allocatable :: v_or_dv(:, :, :) + real, allocatable :: d_output(:, :, :) + real, allocatable :: v_heads(:, :, :) + real, allocatable :: k_heads(:, :, :) + real, allocatable :: q_heads(:, :, :) + real, allocatable :: d_sdpa(:, :) + real, allocatable :: jacobian(:, :) + real, allocatable :: d_normalize(:, :, :) contains procedure :: common_backward From 0a399cf49f9ee8dc2dde8181b484374320be2378 Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Thu, 27 Feb 2025 23:38:49 +0400 Subject: [PATCH 8/9] multihead_attention_optimization: move heads separation out of sdpa backward --- src/nf/nf_multihead_attention_submodule.f90 | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/nf/nf_multihead_attention_submodule.f90 b/src/nf/nf_multihead_attention_submodule.f90 index 37ecd25f..f78abafd 100644 --- a/src/nf/nf_multihead_attention_submodule.f90 +++ b/src/nf/nf_multihead_attention_submodule.f90 @@ -21,6 +21,10 @@ pure module subroutine common_backward(self, input, gradient, attention_mask) integer :: head, seq, i, j + self % v_heads = self % split_heads(self % value_layer % output) + self % k_heads = self % split_heads(self % key_layer % output) + self % q_heads = self % split_heads(self % query_layer % output) + ! 
backward through the attention mechanism
     call self % sdpa_backward(gradient, attention_mask)
 
     ! calculate deltas for input layers
@@ -80,9 +84,6 @@ pure module subroutine sdpa_backward(self, gradient, attention_mask)
 
     ! split heads from output gradient
     self % d_output = self % split_heads(self % output_layer % gradient)
-    self % v_heads = self % split_heads(self % value_layer % output)
-    self % k_heads = self % split_heads(self % key_layer % output)
-    self % q_heads = self % split_heads(self % query_layer % output)
 
     ! iterate over heads to calculate deltas for each of them
     do concurrent(head = 1: self % n_heads)

From 1f3be869b27a850aa5190b5a9cdbe3826dc5bba3 Mon Sep 17 00:00:00 2001
From: Mikhail Voronov
Date: Fri, 28 Feb 2025 17:40:21 +0400
Subject: [PATCH 9/9] multihead_attention_optimization: add attention mask to
 self_attention

---
 src/nf/nf_self_attention_layer.f90 | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/nf/nf_self_attention_layer.f90 b/src/nf/nf_self_attention_layer.f90
index 15e8f40c..0b5f217d 100644
--- a/src/nf/nf_self_attention_layer.f90
+++ b/src/nf/nf_self_attention_layer.f90
@@ -35,14 +35,15 @@ module function self_attention_layer_cons(n_heads) result(res)
     res % n_heads = n_heads
   end function self_attention_layer_cons
 
-  pure module subroutine backward(self, input, gradient)
+  pure module subroutine backward(self, input, gradient, attention_mask)
     !! Self Attention back propagation
     !! Returns sum of Query, Key and Value gradients
     class(self_attention_layer), intent(in out) :: self
    real, intent(in) :: input(:, :)
    real, intent(in) :: gradient(:, :)
+    real, intent(in), optional :: attention_mask(:, :)

-    call self % common_backward(input, gradient)
+    call self % common_backward(input, gradient, attention_mask)
    self % gradient = &
      self % query_layer % gradient &
      + self % key_layer % gradient &
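
Besides the allocation changes, the user-facing addition in this series is the optional attention_mask argument on common_forward, common_backward, and the self_attention backward. The sketch below shows how a caller might build and pass a causal mask; it follows the convention exercised by test_multihead_attention_mask: a (sequence_length, sequence_length) array with 0. at allowed positions and a large negative value at masked ones, added to the pre-softmax attention scores of every head. This is an illustrative sketch, not part of the patches: the program name, the use statement, and the assumption that init_base leaves the query/key/value/output sublayers ready for a forward pass are guesses based on how the tests drive the layer.

program attention_mask_demo
  ! Illustrative sketch only; not part of the patch series.
  ! Assumes the nf_multihead_attention_layer module of neural-fortran is
  ! available under this name and that init_base initializes the sublayers,
  ! as the tests in this series rely on.
  use nf_multihead_attention_layer, only: multihead_attention_layer
  implicit none

  integer, parameter :: seq_len = 3, model_dim = 4
  type(multihead_attention_layer) :: attention
  real :: input(seq_len, model_dim), gradient(seq_len, model_dim)
  real :: mask(seq_len, seq_len)
  integer :: i, j

  call random_number(input)
  call random_number(gradient)

  ! Causal mask: assuming the first mask index is the query position (this
  ! matches the row-wise softmax in normalize_attention_matrix), position i
  ! may not attend to positions j > i. Masked entries get a large negative
  ! value so softmax drives their weights to ~0 after the mask is added to
  ! the scaled scores of every head.
  do i = 1, seq_len
    do j = 1, seq_len
      if (j > i) then
        mask(i, j) = -1e9
      else
        mask(i, j) = 0.
      end if
    end do
  end do

  attention = multihead_attention_layer(n_heads=2)
  call attention % init_base([seq_len, model_dim])

  ! The same mask is accepted by both passes (patch 3 for common_*,
  ! patch 9 for the self_attention wrapper).
  call attention % common_forward(input, input, input, attention_mask=mask)
  call attention % common_backward(input, gradient, attention_mask=mask)
end program attention_mask_demo

Note that patch 3 also applies the mask to d_sdpa in the backward pass, so the forward and backward calls should receive the same mask array.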