@@ -442,6 +442,21 @@ def forward_step(self, input, hidden):
 # :alt:
 #
 #
+# Bahdanau attention, also known as additive attention, is a commonly used
+# attention mechanism in sequence-to-sequence models, particularly in neural
+# machine translation tasks. It was introduced by Dzmitry Bahdanau et al. in their
+# paper titled `Neural Machine Translation by Jointly Learning to Align and Translate <https://arxiv.org/pdf/1409.0473.pdf>`__.
+# This attention mechanism employs a learned alignment model to compute attention
+# scores between the encoder and decoder hidden states. It utilizes a feed-forward
+# neural network to calculate alignment scores.
+#
+# However, there are alternative attention mechanisms available, such as Luong attention,
+# which computes attention scores by taking the dot product between the decoder hidden
+# state and the encoder hidden states. It does not involve the non-linear transformation
+# used in Bahdanau attention.
+#
+# In this tutorial, we will be using Bahdanau attention. However, it would be a valuable
+# exercise to explore modifying the attention mechanism to use Luong attention.
 
 class BahdanauAttention(nn.Module):
     def __init__(self, hidden_size):
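
Picking up the Luong-attention exercise suggested in the added comments above, here is a minimal sketch of a dot-product variant that keeps the same (query, keys) interface as BahdanauAttention. The class name, tensor shapes, and drop-in usage are assumptions for illustration, not part of this commit:

import torch
import torch.nn as nn
import torch.nn.functional as F

class LuongDotAttention(nn.Module):
    # Hypothetical replacement for BahdanauAttention: scores are plain dot
    # products between the decoder query and each encoder hidden state, so
    # there is no learned alignment network and no tanh non-linearity.
    def forward(self, query, keys):
        # query: (batch, 1, hidden_size); keys: (batch, seq_len, hidden_size)
        scores = torch.bmm(query, keys.transpose(1, 2))   # (batch, 1, seq_len)
        weights = F.softmax(scores, dim=-1)                # attention over input steps
        context = torch.bmm(weights, keys)                 # (batch, 1, hidden_size)
        return context, weights

Assuming the decoder invokes its attention module as ``context, attn_weights = self.attention(query, keys)``, this sketch could be swapped in without changing the call site; only the score computation differs.
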
@@ -775,7 +790,7 @@ def evaluateRandomly(encoder, decoder, n=10):
 encoder = EncoderRNN(input_lang.n_words, hidden_size).to(device)
 decoder = AttnDecoderRNN(hidden_size, output_lang.n_words).to(device)
 
-train(train_dataloader, encoder, decoder, 100, print_every=5, plot_every=5)
+train(train_dataloader, encoder, decoder, 80, print_every=5, plot_every=5)
 
 ######################################################################
 #
@@ -793,18 +808,8 @@ def evaluateRandomly(encoder, decoder, n=10):
 # at each time step.
 #
 # You could simply run ``plt.matshow(attentions)`` to see attention output
-# displayed as a matrix, with the columns being input steps and rows being
-# output steps:
-#
-
-output_words, attentions = evaluate(
-    encoder, decoder, 'je suis trop froid', input_lang, output_lang)
-plt.matshow(attentions.cpu().numpy()[0])
-
-
-######################################################################
-# For a better viewing experience we will do the extra work of adding axes
-# and labels:
+# displayed as a matrix. For a better viewing experience we will do the
+# extra work of adding axes and labels:
 #
 
 def showAttention(input_sentence, output_words, attentions):
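
The body of ``showAttention`` lies outside this hunk; purely as a rough illustration of "adding axes and labels" to a matshow plot, something along these lines would work (the helper name, the ``matplotlib.ticker`` usage, and the assumption that ``attentions`` is a 2D array of shape (output_len, input_len) are mine, not the tutorial's exact code):

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

def show_attention_sketch(input_sentence, output_words, attentions):
    # attentions: 2D array-like of shape (output_len, input_len)
    fig, ax = plt.subplots()
    cax = ax.matshow(attentions, cmap='bone')
    fig.colorbar(cax)

    # Label columns with input tokens and rows with output tokens
    ax.set_xticklabels([''] + input_sentence.split(' ') + ['<EOS>'], rotation=90)
    ax.set_yticklabels([''] + output_words)

    # Put a tick (and therefore a label) at every cell
    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

    plt.show()
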