diff --git a/actor_critic/main.py b/actor_critic/main.py
index 52c6f68f7a..a23dbb4d3b 100644
--- a/actor_critic/main.py
+++ b/actor_critic/main.py
@@ -34,17 +34,15 @@
 class Policy(nn.Module):
     def __init__(self):
         super(Policy, self).__init__()
-        self.affine1 = nn.Linear(4, 16)
-        self.affine2 = nn.Linear(16, 32)
-        self.action_head = nn.Linear(32, 2)
-        self.value_head = nn.Linear(32, 1)
+        self.affine1 = nn.Linear(4, 128)
+        self.action_head = nn.Linear(128, 2)
+        self.value_head = nn.Linear(128, 1)
 
         self.saved_actions = []
         self.rewards = []
 
     def forward(self, x):
         x = F.relu(self.affine1(x))
-        x = F.relu(self.affine2(x))
         action_scores = self.action_head(x)
         state_values = self.value_head(x)
         return F.softmax(action_scores), state_values