Learning Workbook: Deep Q-Learning for Pong¶
Author: Dr. Yves J. Hilpisch
Context: Hands-on excerpt from A Short History of Computer Science
Run every cell in order on Google Colab with a GPU runtime. The workflow installs dependencies, builds the Atari Pong environment, defines a modern Deep Q-Network (DQN) agent, trains it with settings that fit a 12 GB GPU, records a rollout, and saves artifacts for later analysis.
0. Mounting Google Drive¶
Mount Drive to save checkpoints and logs in a permanent folder on your Google Drive.
from google.colab import drive
drive.mount("/content/drive")
import os
project_dir = "/content/drive/MyDrive/pong" # adjust to your setup
os.makedirs(project_dir, exist_ok=True)
os.chdir(project_dir)
!pwd
1. Runtime & dependencies¶
- Use Runtime → Change runtime type → GPU in Colab.
- Re-run the installation cell whenever the runtime restarts.
- The installation cell also downloads the Atari ROMs via AutoROM.
# Install required packages (idempotent on Colab)
!pip install -q gymnasium[atari] autorom[accept-rom-license] imageio imageio-ffmpeg
# Accept ROM license and download Atari ROMs
# (adjust the python3.x segment below if the cell above reports a different Python version)
!python -m autorom --accept-license --install-dir /usr/local/lib/python3.12/dist-packages/ale_py/roms
# Verify installations
import sys
print(f"Python version: {sys.version}")
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
import gymnasium as gym
print(f"Gymnasium version: {gym.__version__}")
import imageio
print(f"Imageio version: {imageio.__version__}")
# Verify Atari ROMs are installed
try:
import ale_py
print(f"ALE-Py version: {ale_py.__version__}")
print("Atari ROMs installed successfully!")
except ImportError:
print("Warning: ALE-Py not found")
2. Imports and runtime context¶
Common imports, deterministic seeding, and basic hardware diagnostics.
import collections
import random
from dataclasses import dataclass
from pathlib import Path
from typing import Tuple
import numpy as np
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
from gymnasium.wrappers import RecordEpisodeStatistics, AtariPreprocessing
import imageio
# Speed-friendly CuDNN settings
if torch.backends.cudnn.is_available():
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False
# Allow TF32 on Ampere+ GPUs for faster float32 matmuls and convolutions
if hasattr(torch.backends.cuda, 'matmul'):
    torch.backends.cuda.matmul.allow_tf32 = True
if hasattr(torch.backends.cudnn, 'allow_tf32'):
    torch.backends.cudnn.allow_tf32 = True
# Matmul precision hint (PyTorch >= 1.12); 'high' likewise permits TF32 matmuls
if hasattr(torch, 'set_float32_matmul_precision'):
    torch.set_float32_matmul_precision('high')
def set_global_seeds(seed: int) -> None:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
set_global_seeds(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if device.type == "cuda":
print(f"GPU: {torch.cuda.get_device_name(0)}")
else:
print("⚠️ GPU not detected. Training on CPU will be significantly slower.")
3. Training configuration¶
A single dataclass keeps all hyperparameters, including replay-buffer and frame limits sized for a standard 12 GB Colab runtime. The epsilon schedule, target-network update, and checkpoint helpers are defined here as well.
@dataclass
class TrainConfig:
env_id: str = "ALE/Pong-v5"
seed: int = 42
    replay_capacity: int = 120_000  # roughly 6.8 GB of host RAM when stored as uint8 frame stacks
warmup_steps: int = 20_000
batch_size: int = 64
gamma: float = 0.99
learning_rate: float = 1e-4
adam_eps: float = 1e-4
eps_start: float = 1.0
eps_final: float = 0.05
eps_decay_frames: int = 250_000
max_frames: int = 600_000
updates_per_step: int = 1
target_sync_interval: int = 2_000
target_tau: float = 1.0 # tau=1.0 -> hard update
grad_clip: float = 10.0
log_interval: int = 5_000
eval_epsilon: float = 0.01
video_path: str = "pong_dqn.mp4"
checkpoint_path: str = "dqn_pong_checkpoint.pt"
checkpoint_interval: int = 20_000
cfg = TrainConfig()
if device.type != "cuda":
original_frames = cfg.max_frames
cfg.max_frames = min(cfg.max_frames, 60_000)
print(f"Adjusted max_frames from {original_frames:,} to {cfg.max_frames:,} for CPU execution.")
rng = np.random.default_rng(cfg.seed)
def linear_schedule(start: float, end: float, duration: int, t: int) -> float:
if duration <= 0:
return end
fraction = min(t / duration, 1.0)
return start + fraction * (end - start)
def update_target_network(target: nn.Module, online: nn.Module, tau: float = 1.0) -> None:
with torch.no_grad():
for target_param, online_param in zip(target.parameters(), online.parameters()):
target_param.data.mul_(1.0 - tau).add_(online_param.data, alpha=tau)
def save_checkpoint(path: str, step: int):
torch.save(
{
"model_state_dict": q_online.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"config": cfg.__dict__,
"steps": step,
"returns": returns,
},
path,
)
print(f"Checkpoint written to {path} at step {step:,}.")
# Optional: resume from a saved checkpoint.
# Set this to True if you want the notebook to
# try loading an existing checkpoint before training.
RESUME_FROM_CHECKPOINT = True
CHECKPOINT_PATH = cfg.checkpoint_path
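As an optional sanity check, the short cell below prints the exploration rate implied by linear_schedule at a few milestones; the values follow directly from the configuration above.
# Optional sanity check: epsilon at selected steps of the linear schedule
for t in (0, cfg.warmup_steps, cfg.eps_decay_frames // 2, cfg.eps_decay_frames, cfg.max_frames):
    eps_t = linear_schedule(cfg.eps_start, cfg.eps_final, cfg.eps_decay_frames, t)
    print(f"step {t:>7,}: epsilon = {eps_t:.3f}")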
4. Environment helpers¶
Create Pong with standard preprocessing: grayscale, resize to 84×84, a frame skip of 4 applied by AtariPreprocessing, and a stack of the last 4 frames. The base environment is created with frameskip=1 so that frame skipping happens exactly once, inside the wrapper. A simple stacking wrapper keeps the code explicit and easy to reason about.
class FrameStackWrapper(gym.Wrapper):
"""Stack the last `num_stack` frames along the channel axis."""
def __init__(self, env: gym.Env, num_stack: int = 4):
super().__init__(env)
self.num_stack = num_stack
self.frames = collections.deque(maxlen=num_stack)
single_frame_shape = env.observation_space.shape
self.observation_space = gym.spaces.Box(
low=0,
high=255,
shape=(num_stack,) + single_frame_shape,
dtype=env.observation_space.dtype,
)
def reset(self, **kwargs):
obs, info = self.env.reset(**kwargs)
self.frames.clear()
processed = np.expand_dims(obs, axis=0)
for _ in range(self.num_stack):
self.frames.append(processed)
return self._get_obs(), info
def step(self, action):
obs, reward, terminated, truncated, info = self.env.step(action)
processed = np.expand_dims(obs, axis=0)
self.frames.append(processed)
return self._get_obs(), reward, terminated, truncated, info
def _get_obs(self):
return np.concatenate(list(self.frames), axis=0)
def make_env(seed: int, render_mode: str = "rgb_array") -> gym.Env:
env = gym.make(cfg.env_id, render_mode=render_mode, frameskip=1)
env = RecordEpisodeStatistics(env)
env = AtariPreprocessing(
env,
frame_skip=4,
screen_size=84,
grayscale_obs=True,
scale_obs=False,
)
env = FrameStackWrapper(env, num_stack=4)
env.reset(seed=seed)
env.action_space.seed(seed)
env.observation_space.seed(seed)
return env
train_env = make_env(cfg.seed)
eval_env = make_env(cfg.seed + 1, render_mode="rgb_array")
print(f"Action space: {train_env.action_space}")
print(f"Observation space: {train_env.observation_space}")
5. Replay buffer & policy helpers¶
Use a NumPy-backed replay buffer that stores uint8 frames (memory friendly) and serves PyTorch tensors on demand.
class ReplayBuffer:
def __init__(self, capacity: int, obs_shape: Tuple[int, ...]):
self.capacity = capacity
self.obs_shape = obs_shape
self.states = np.zeros((capacity, *obs_shape), dtype=np.uint8)
self.next_states = np.zeros((capacity, *obs_shape), dtype=np.uint8)
self.actions = np.zeros(capacity, dtype=np.int64)
self.rewards = np.zeros(capacity, dtype=np.float32)
self.dones = np.zeros(capacity, dtype=np.bool_)
self.pos = 0
self.full = False
def add(self, state, action, reward, next_state, done):
self.states[self.pos] = state
self.next_states[self.pos] = next_state
self.actions[self.pos] = action
self.rewards[self.pos] = reward
self.dones[self.pos] = done
self.pos = (self.pos + 1) % self.capacity
if self.pos == 0:
self.full = True
def __len__(self):
return self.capacity if self.full else self.pos
def can_sample(self, batch_size: int) -> bool:
return len(self) >= batch_size
def sample(self, batch_size: int):
idxs = rng.choice(len(self), size=batch_size, replace=False)
states = torch.as_tensor(self.states[idxs], dtype=torch.float32, device=device)
next_states = torch.as_tensor(self.next_states[idxs], dtype=torch.float32, device=device)
actions = torch.as_tensor(self.actions[idxs], dtype=torch.int64, device=device)
rewards = torch.as_tensor(self.rewards[idxs], dtype=torch.float32, device=device)
dones = torch.as_tensor(self.dones[idxs].astype(np.float32), dtype=torch.float32, device=device)
return states, actions, rewards, next_states, dones
def select_action(state: np.ndarray, q_net: nn.Module, epsilon: float) -> int:
if random.random() < epsilon:
return train_env.action_space.sample()
state_t = torch.as_tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
with torch.no_grad():
q_values = q_net(state_t)
return int(q_values.argmax(dim=1).item())
obs_shape = train_env.observation_space.shape
buffer = ReplayBuffer(cfg.replay_capacity, obs_shape)
print(f"Replay buffer capacity: {cfg.replay_capacity:,} transitions")
6. DQN architecture¶
Standard Atari CNN (32/64/64 conv layers + dense head). We keep weights light enough for Colab and apply orthogonal initialisation for stability.
class DQN(nn.Module):
def __init__(self, action_dim: int):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(4, 32, kernel_size=8, stride=4),
nn.ReLU(),
nn.Conv2d(32, 64, kernel_size=4, stride=2),
nn.ReLU(),
nn.Conv2d(64, 64, kernel_size=3, stride=1),
nn.ReLU(),
nn.Flatten(),
nn.Linear(64 * 7 * 7, 512),
nn.ReLU(),
nn.Linear(512, action_dim),
)
self.apply(self._init_weights)
@staticmethod
def _init_weights(module):
if isinstance(module, (nn.Conv2d, nn.Linear)):
nn.init.orthogonal_(module.weight, gain=nn.init.calculate_gain("relu"))
if module.bias is not None:
nn.init.constant_(module.bias, 0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x / 255.0
return self.net(x)
action_dim = train_env.action_space.n
q_online = DQN(action_dim).to(device)
q_target = DQN(action_dim).to(device)
update_target_network(q_target, q_online, tau=1.0)
optimizer = optim.Adam(q_online.parameters(), lr=cfg.learning_rate, eps=cfg.adam_eps)
try:
scaler = torch.amp.GradScaler(device='cuda', enabled=(device.type == "cuda"))
except AttributeError:
scaler = torch.cuda.amp.GradScaler(enabled=(device.type == "cuda"))
print(f"Model parameters: {sum(p.numel() for p in q_online.parameters()):,}")
7. Training loop¶
Implements Double DQN targets, gradient clipping, automatic mixed precision (AMP) on GPU, and periodic logging. Adjust cfg.max_frames or cfg.warmup_steps if you need shorter runs.
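For reference, the Double DQN target computed in the loop selects the next action with the online network and evaluates it with the target network:
$$y = r + \gamma\,(1 - d)\,Q_{\text{target}}\bigl(s',\ \arg\max_{a'} Q_{\text{online}}(s', a')\bigr)$$
where $d$ is 1 for terminal transitions, so the bootstrap term vanishes at episode boundaries.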
%%time
import time
from pathlib import Path
# Initial values; may be overwritten if we resume from checkpoint
global_step = 0
returns = []
# Handle optional resume-from-checkpoint logic.
if RESUME_FROM_CHECKPOINT and Path(CHECKPOINT_PATH).exists():
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=False)  # checkpoint also stores plain-Python metadata
q_online.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
saved_cfg = checkpoint.get("config", {})
for k, v in saved_cfg.items():
if hasattr(cfg, k):
setattr(cfg, k, v)
global_step = int(checkpoint.get("steps", 0))
returns = list(checkpoint.get("returns", []))
print(f"Resumed from {CHECKPOINT_PATH} at step {global_step:,}.")
    # Extend the frame budget so the resumed run continues training past the original limit.
    cfg.max_frames = 800_000
elif RESUME_FROM_CHECKPOINT:
print(f"No checkpoint found at {CHECKPOINT_PATH}. Starting from scratch.")
state, _ = train_env.reset(seed=cfg.seed)
update_target_network(q_target, q_online, tau=1.0)
loss_history = collections.deque(maxlen=200)
lengths = []
train_seed = cfg.seed
start_time = time.time()
checkpoint_interval = cfg.checkpoint_interval
print("Starting training...")
while global_step < cfg.max_frames:
epsilon = linear_schedule(cfg.eps_start, cfg.eps_final, cfg.eps_decay_frames, global_step)
action = select_action(state, q_online, epsilon)
next_state, reward, terminated, truncated, info = train_env.step(action)
done = terminated or truncated
buffer.add(state, action, reward, next_state, done)
state = next_state
global_step += 1
if done:
if "episode" in info:
returns.append(info["episode"]["r"])
lengths.append(info["episode"]["l"])
train_seed += 1
state, _ = train_env.reset(seed=train_seed)
if len(buffer) >= max(cfg.batch_size, cfg.warmup_steps):
for _ in range(cfg.updates_per_step):
if not buffer.can_sample(cfg.batch_size):
break
states_b, actions_b, rewards_b, next_states_b, dones_b = buffer.sample(cfg.batch_size)
with torch.amp.autocast(device_type="cuda", enabled=(device.type == "cuda")):
q_values = q_online(states_b)
current_q = q_values.gather(1, actions_b.unsqueeze(1)).squeeze(1)
with torch.no_grad():
next_q_online = q_online(next_states_b)
best_actions = next_q_online.argmax(dim=1)
next_q_target = q_target(next_states_b)
target_q = next_q_target.gather(1, best_actions.unsqueeze(1)).squeeze(1)
targets = rewards_b + cfg.gamma * (1.0 - dones_b) * target_q
loss = nn.functional.smooth_l1_loss(current_q, targets)
scaler.scale(loss).backward()
if cfg.grad_clip is not None:
scaler.unscale_(optimizer)
nn.utils.clip_grad_norm_(q_online.parameters(), cfg.grad_clip)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
loss_history.append(loss.item())
if global_step % cfg.target_sync_interval == 0:
update_target_network(q_target, q_online, tau=cfg.target_tau)
if global_step > 0 and global_step % checkpoint_interval == 0:
save_checkpoint(cfg.checkpoint_path, global_step)
if global_step % cfg.log_interval == 0:
elapsed_min = (time.time() - start_time) / 60.0
recent_returns = returns[-10:]
mean_return = np.mean(recent_returns) if recent_returns else float("nan")
mean_loss = np.mean(loss_history) if loss_history else float("nan")
print(
f"Step {global_step:>7,}/{cfg.max_frames:,} | "
f"epsilon {epsilon:.3f} | "
f"buffer {len(buffer):>7,} | "
f"episodes {len(returns):>5} | "
f"mean return (last 10): {mean_return:>6.2f} | "
f"loss {mean_loss:>8.4f} | "
f"elapsed {elapsed_min:>5.1f} min"
)
print("Training complete.")
if global_step % checkpoint_interval != 0:
save_checkpoint(cfg.checkpoint_path, global_step)
if returns:
print(f"Final mean return (last 10 episodes): {np.mean(returns[-10:]):.2f}")
8. Record and evaluate the agent¶
Render one evaluation episode with a nearly greedy policy, write an MP4, and report the return.
@torch.inference_mode()
def play_and_record(env: gym.Env, q_net: nn.Module, epsilon_eval: float, path: str) -> float:
frames = []
state, _ = env.reset()
done = False
total_reward = 0.0
q_net.eval()
while not done:
frame = env.render()
frames.append(frame)
action = select_action(state, q_net, epsilon_eval)
state, reward, terminated, truncated, info = env.step(action)
total_reward += reward
done = terminated or truncated
imageio.mimsave(path, frames, fps=30)
print(f"Saved rollout to {path} ({len(frames)} frames)")
print(f"Evaluation return: {total_reward}")
return total_reward
eval_reward = play_and_record(eval_env, q_online, cfg.eval_epsilon, cfg.video_path)
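A single episode is a noisy estimate of performance. For a steadier number, the optional sketch below averages the near-greedy policy over a few evaluation episodes without recording frames (the evaluate helper is an addition, not part of the workflow above):
# Optional: average the near-greedy policy over several evaluation episodes
def evaluate(env: gym.Env, q_net: nn.Module, epsilon_eval: float, episodes: int = 3) -> float:
    totals = []
    for _ in range(episodes):
        state, _ = env.reset()
        done, total = False, 0.0
        while not done:
            action = select_action(state, q_net, epsilon_eval)
            state, reward, terminated, truncated, _ = env.step(action)
            total += reward
            done = terminated or truncated
        totals.append(total)
    return float(np.mean(totals))

# print(f"Mean evaluation return: {evaluate(eval_env, q_online, cfg.eval_epsilon):.2f}")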
9. Display the recorded video¶
from IPython.display import Video
Video(cfg.video_path, embed=True, width=400)
10. Save the model checkpoint¶
Save the trained weights together with optimizer state and training metadata for reproducibility.
# Manual checkpoint save (optional)
save_checkpoint(cfg.checkpoint_path, global_step)
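To reload the trained weights later, for example in a fresh Colab session, here is a minimal sketch assuming the checkpoint layout written by save_checkpoint above:
# Optional: restore the saved weights into a fresh network for later analysis
restored = DQN(action_dim).to(device)
payload = torch.load(cfg.checkpoint_path, map_location=device, weights_only=False)  # checkpoint contains plain-Python metadata
restored.load_state_dict(payload["model_state_dict"])
restored.eval()
print(f"Restored checkpoint from step {payload['steps']:,} "
      f"covering {len(payload['returns'])} recorded episodes.")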