diff --git a/intermediate_source/mario_rl_tutorial.py b/intermediate_source/mario_rl_tutorial.py
index 74a08a47b37..20f1b618773 100755
--- a/intermediate_source/mario_rl_tutorial.py
+++ b/intermediate_source/mario_rl_tutorial.py
@@ -43,7 +43,7 @@ import numpy as np
 from pathlib import Path
 from collections import deque
-import random, datetime, os, copy
+import random, datetime, os
 
 # Gym is an OpenAI toolkit for RL
 import gym
 
@@ -424,20 +424,10 @@ def __init__(self, input_dim, output_dim):
         if w != 84:
             raise ValueError(f"Expecting input width: 84, got: {w}")
 
-        self.online = nn.Sequential(
-            nn.Conv2d(in_channels=c, out_channels=32, kernel_size=8, stride=4),
-            nn.ReLU(),
-            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
-            nn.ReLU(),
-            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
-            nn.ReLU(),
-            nn.Flatten(),
-            nn.Linear(3136, 512),
-            nn.ReLU(),
-            nn.Linear(512, output_dim),
-        )
+        self.online = self.__build_cnn(c, output_dim)
 
-        self.target = copy.deepcopy(self.online)
+        self.target = self.__build_cnn(c, output_dim)
+        self.target.load_state_dict(self.online.state_dict())
 
         # Q_target parameters are frozen.
         for p in self.target.parameters():
@@ -449,6 +439,20 @@ def forward(self, input, model):
         elif model == "target":
             return self.target(input)
 
+    def __build_cnn(self, c, output_dim):
+        return nn.Sequential(
+            nn.Conv2d(in_channels=c, out_channels=32, kernel_size=8, stride=4),
+            nn.ReLU(),
+            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
+            nn.ReLU(),
+            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
+            nn.ReLU(),
+            nn.Flatten(),
+            nn.Linear(3136, 512),
+            nn.ReLU(),
+            nn.Linear(512, output_dim),
+        )
+
 
 ######################################################################
 # TD Estimate & TD Target