From 92e1abc6d81264e4c983db2652eca7a32bb32896 Mon Sep 17 00:00:00 2001
From: Svetlana Karslioglu
Date: Mon, 14 Jul 2025 13:48:24 -0700
Subject: [PATCH 1/3] Fix DQN w RNN tutorial

---
 intermediate_source/dqn_with_rnn_tutorial.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/intermediate_source/dqn_with_rnn_tutorial.py b/intermediate_source/dqn_with_rnn_tutorial.py
index bcc484f0a00..f28ad9f6903 100644
--- a/intermediate_source/dqn_with_rnn_tutorial.py
+++ b/intermediate_source/dqn_with_rnn_tutorial.py
@@ -342,7 +342,10 @@
 # will return a new instance of the LSTM (with shared weights) that will
 # assume that the input data is sequential in nature.
 #
-policy = Seq(feature, lstm.set_recurrent_mode(True), mlp, qval)
+from torchrl.modules import set_recurrent_mode
+
+with set_recurrent_mode(True):
+    policy = Seq(feature, lstm, mlp, qval)
 
 ######################################################################
 # Because we still have a couple of uninitialized parameters we should
@@ -389,7 +392,9 @@
 # For the sake of efficiency, we're only running a few thousands iterations
 # here. In a real setting, the total number of frames should be set to 1M.
 #
-collector = SyncDataCollector(env, stoch_policy, frames_per_batch=50, total_frames=200, device=device)
+collector = SyncDataCollector(
+    env, stoch_policy, frames_per_batch=50, total_frames=200, device=device
+)
 rb = TensorDictReplayBuffer(
     storage=LazyMemmapStorage(20_000), batch_size=4, prefetch=10
 )
@@ -464,5 +469,5 @@
 #
 # Further Reading
 # ---------------
-# 
+#
 # - The TorchRL documentation can be found `here `_.

From 5e85c0a72b0c6b09fc876fd1c73ea55f09178391 Mon Sep 17 00:00:00 2001
From: Svetlana Karslioglu
Date: Mon, 14 Jul 2025 16:02:40 -0700
Subject: [PATCH 2/3] Update intermediate_source/dqn_with_rnn_tutorial.py

---
 intermediate_source/dqn_with_rnn_tutorial.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/intermediate_source/dqn_with_rnn_tutorial.py b/intermediate_source/dqn_with_rnn_tutorial.py
index f28ad9f6903..9b41dbfecaf 100644
--- a/intermediate_source/dqn_with_rnn_tutorial.py
+++ b/intermediate_source/dqn_with_rnn_tutorial.py
@@ -392,6 +392,7 @@
 # For the sake of efficiency, we're only running a few thousands iterations
 # here. In a real setting, the total number of frames should be set to 1M.
 #
+
 collector = SyncDataCollector(
     env, stoch_policy, frames_per_batch=50, total_frames=200, device=device
 )

From f4d1b1cc7b406a4bc43eba443da779ab2342415f Mon Sep 17 00:00:00 2001
From: Svetlana Karslioglu
Date: Thu, 17 Jul 2025 14:14:21 -0700
Subject: [PATCH 3/3] Update

---
 intermediate_source/dqn_with_rnn_tutorial.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/intermediate_source/dqn_with_rnn_tutorial.py b/intermediate_source/dqn_with_rnn_tutorial.py
index 9b41dbfecaf..462415dcc74 100644
--- a/intermediate_source/dqn_with_rnn_tutorial.py
+++ b/intermediate_source/dqn_with_rnn_tutorial.py
@@ -344,8 +344,7 @@
 #
 from torchrl.modules import set_recurrent_mode
 
-with set_recurrent_mode(True):
-    policy = Seq(feature, lstm, mlp, qval)
+policy = Seq(feature, lstm, mlp, qval)
 
 ######################################################################
 # Because we still have a couple of uninitialized parameters we should
@@ -428,7 +427,8 @@
 rb.extend(data.unsqueeze(0).to_tensordict().cpu())
 for _ in range(utd):
     s = rb.sample().to(device, non_blocking=True)
-    loss_vals = loss_fn(s)
+    with set_recurrent_mode(True):
+        loss_vals = loss_fn(s)
     loss_vals["loss"].backward()
     optim.step()
     optim.zero_grad()
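
The net pattern these three patches converge on, shown as a minimal sketch rather than a complete script: the policy is assembled without toggling the LSTM's mode, and recurrent mode is enabled only around the loss computation, where sampled trajectories are processed as whole sequences. The sketch assumes the tutorial's objects (feature, lstm, mlp, qval, loss_fn, rb, optim, device, utd) are constructed exactly as in dqn_with_rnn_tutorial.py.

from torchrl.modules import set_recurrent_mode

# Build the policy without switching the LSTM into recurrent mode up front.
policy = Seq(feature, lstm, mlp, qval)

# Enable recurrent (sequence) mode only while the loss is computed on
# replayed trajectories; data collection keeps using step-by-step mode.
for _ in range(utd):
    s = rb.sample().to(device, non_blocking=True)
    with set_recurrent_mode(True):
        loss_vals = loss_fn(s)
    loss_vals["loss"].backward()
    optim.step()
    optim.zero_grad()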