
Commit 48d2ae7: comments added

1 parent 9709016 commit 48d2ae7

File tree

1 file changed: +18 -13 lines changed

intermediate_source/seq2seq_translation_tutorial.py

Lines changed: 18 additions & 13 deletions
@@ -442,6 +442,21 @@ def forward_step(self, input, hidden):
 # :alt:
 #
 #
+# Bahdanau attention, also known as additive attention, is a commonly used
+# attention mechanism in sequence-to-sequence models, particularly in neural
+# machine translation tasks. It was introduced by Dzmitry Bahdanau et al. in their
+# paper titled `Neural Machine Translation by Jointly Learning to Align and Translate <https://arxiv.org/pdf/1409.0473.pdf>`__.
+# This attention mechanism employs a learned alignment model to compute attention
+# scores between the encoder and decoder hidden states. It uses a feed-forward
+# neural network to calculate alignment scores.
+#
+# However, there are alternative attention mechanisms available, such as Luong attention,
+# which computes attention scores by taking the dot product between the decoder hidden
+# state and the encoder hidden states. It does not involve the non-linear transformation
+# used in Bahdanau attention.
+#
+# In this tutorial, we will be using Bahdanau attention. However, it would be a valuable
+# exercise to explore modifying the attention mechanism to use Luong attention.

 class BahdanauAttention(nn.Module):
     def __init__(self, hidden_size):
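
As a rough sketch of the exercise suggested in the added comments, a Luong-style dot-product attention module might look like the following. The class name LuongDotAttention and the (query, keys) forward signature are assumptions chosen to mirror the BahdanauAttention interface shown above; tensor shapes assume a batch-first layout with query of shape (batch, 1, hidden_size) and keys of shape (batch, seq_len, hidden_size).

import torch
import torch.nn as nn
import torch.nn.functional as F

class LuongDotAttention(nn.Module):
    """Dot-product (Luong) attention: alignment scores are the raw dot
    products between the decoder hidden state (query) and the encoder
    hidden states (keys); no feed-forward alignment network is needed."""

    def __init__(self, hidden_size):
        super(LuongDotAttention, self).__init__()
        self.hidden_size = hidden_size  # kept only for interface parity

    def forward(self, query, keys):
        # query: (batch, 1, hidden_size), keys: (batch, seq_len, hidden_size)
        scores = torch.bmm(query, keys.transpose(1, 2))  # (batch, 1, seq_len)
        weights = F.softmax(scores, dim=-1)              # attention distribution
        context = torch.bmm(weights, keys)               # (batch, 1, hidden_size)
        return context, weights

If the decoder only relies on the returned (context, weights) pair, a module like this could plausibly be swapped into AttnDecoderRNN in place of BahdanauAttention without further changes.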
@@ -775,7 +790,7 @@ def evaluateRandomly(encoder, decoder, n=10):
 encoder = EncoderRNN(input_lang.n_words, hidden_size).to(device)
 decoder = AttnDecoderRNN(hidden_size, output_lang.n_words).to(device)

-train(train_dataloader, encoder, decoder, 100, print_every=5, plot_every=5)
+train(train_dataloader, encoder, decoder, 80, print_every=5, plot_every=5)

 ######################################################################
 #
@@ -793,18 +808,8 @@ def evaluateRandomly(encoder, decoder, n=10):
 # at each time step.
 #
 # You could simply run ``plt.matshow(attentions)`` to see attention output
-# displayed as a matrix, with the columns being input steps and rows being
-# output steps:
-#
-
-output_words, attentions = evaluate(
-    encoder, decoder, 'je suis trop froid', input_lang, output_lang)
-plt.matshow(attentions.cpu().numpy()[0])
-
-
-######################################################################
-# For a better viewing experience we will do the extra work of adding axes
-# and labels:
+# displayed as a matrix. For a better viewing experience we will do the
+# extra work of adding axes and labels:
 #

 def showAttention(input_sentence, output_words, attentions):
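
The body of showAttention is not part of this diff, but as a rough sketch of the "axes and labels" step described above, a matplotlib version could look like the following. The helper name show_attention_sketch is hypothetical, and the (1, output_len, input_len) shape of attentions is an assumption based on how a single sentence is evaluated in the removed lines above.

import matplotlib.pyplot as plt

def show_attention_sketch(input_sentence, output_words, attentions):
    # attentions is assumed to have shape (1, output_len, input_len), with the
    # input length covering the source tokens plus the appended <EOS> token.
    input_tokens = input_sentence.split(' ') + ['<EOS>']

    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(attentions.cpu().numpy()[0], cmap='bone')
    fig.colorbar(cax)

    # Columns are input steps, rows are output steps.
    ax.set_xticks(range(len(input_tokens)))
    ax.set_xticklabels(input_tokens, rotation=90)
    ax.set_yticks(range(len(output_words)))
    ax.set_yticklabels(output_words)

    plt.show()

Called with the output of an evaluate() run, for example on 'je suis trop froid' as in the removed snippet, this reproduces the matrix view with labeled axes.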
