Skip to content

Commit f4d1b1c

Browse files
committed
Update
1 parent 5e85c0a commit f4d1b1c

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

intermediate_source/dqn_with_rnn_tutorial.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -344,8 +344,7 @@
344344
#
345345
from torchrl.modules import set_recurrent_mode
346346

347-
with set_recurrent_mode(True):
348-
policy = Seq(feature, lstm, mlp, qval)
347+
policy = Seq(feature, lstm, mlp, qval)
349348

350349
######################################################################
351350
# Because we still have a couple of uninitialized parameters we should
@@ -428,7 +427,8 @@
428427
rb.extend(data.unsqueeze(0).to_tensordict().cpu())
429428
for _ in range(utd):
430429
s = rb.sample().to(device, non_blocking=True)
431-
loss_vals = loss_fn(s)
430+
with set_recurrent_mode(True):
431+
loss_vals = loss_fn(s)
432432
loss_vals["loss"].backward()
433433
optim.step()
434434
optim.zero_grad()

0 commit comments

Comments
 (0)