Learning Workbook: Deep Q-Learning for Pong¶

Author: Dr. Yves J. Hilpisch
Context: Hands-on excerpt from A Short History of Computer Science

Run every cell in order on Google Colab with a GPU runtime. The workflow installs dependencies, builds the Atari Pong environment, defines a modern Deep Q-Network (DQN) agent, trains it with settings that fit a 12 GB GPU, records a rollout, and saves artifacts for later analysis.

0. Mounting Google Drive¶

Mount Drive to save checkpoints and logs in a permanent folder on your Google Drive.

In [ ]:
# Mount Google Drive so checkpoints, logs, and videos persist across Colab sessions.
from google.colab import drive
drive.mount("/content/drive")
In [ ]:
import os
# Work from a persistent Drive folder so every relative path below lands there.
project_dir = "/content/drive/MyDrive/pong"  # adjust to your setup
os.makedirs(project_dir, exist_ok=True)
os.chdir(project_dir)
In [ ]:
!pwd

1. Runtime & dependencies¶

  • Use Runtime → Change runtime type → GPU in Colab.
  • Re-run the installation cell whenever the runtime restarts.
  • Installation already includes the Atari ROMs via AutoROM.
In [ ]:
# Install required packages (idempotent on Colab)
!pip install -q gymnasium[atari] autorom[accept-rom-license] imageio imageio-ffmpeg

# Accept ROM license and download Atari ROMs
# NOTE(review): the install dir hard-codes python3.12 — verify it still matches
# the Colab runtime's Python minor version, or the ROMs land in the wrong place.
!python -m autorom --accept-license --install-dir /usr/local/lib/python3.12/dist-packages/ale_py/roms

# Verify installations
import sys
print(f"Python version: {sys.version}")

import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

import gymnasium as gym
print(f"Gymnasium version: {gym.__version__}")

import imageio
print(f"Imageio version: {imageio.__version__}")

# Verify Atari ROMs are installed
# NOTE(review): importing ale_py only confirms the package is present, not that
# the ROM files were actually installed — a failed autorom step would pass this.
try:
    import ale_py
    print(f"ALE-Py version: {ale_py.__version__}")
    print("Atari ROMs installed successfully!")
except ImportError:
    print("Warning: ALE-Py not found")

2. Imports and runtime context¶

Common imports, deterministic seeding, and basic hardware diagnostics.

In [ ]:
import collections
import random
from dataclasses import dataclass
from pathlib import Path
from typing import Tuple

import numpy as np
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
from gymnasium.wrappers import RecordEpisodeStatistics, AtariPreprocessing
import imageio

# Speed-friendly CuDNN settings
# benchmark=True lets CuDNN autotune conv kernels; fine here because the
# input shape (4, 84, 84) never changes during training.
if torch.backends.cudnn.is_available():
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False

# Matmul precision hint (PyTorch >= 1.12)
# NOTE(review): allow_tf32 is forced off here, but set_float32_matmul_precision('high')
# below permits TF32 matmuls again on supporting GPUs — the two settings conflict;
# confirm which precision behavior is intended.
if hasattr(torch.backends.cuda, 'matmul'):
    torch.backends.cuda.matmul.allow_tf32 = False
if hasattr(torch.backends.cudnn, 'allow_tf32'):
    torch.backends.cudnn.allow_tf32 = False
if hasattr(torch, 'set_float32_matmul_precision'):
    torch.set_float32_matmul_precision('high')


def set_global_seeds(seed: int) -> None:
    """Seed Python's, NumPy's, and PyTorch's RNGs for reproducible runs."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


set_global_seeds(42)

# Choose the compute device once; models and batches are moved here explicitly.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if device.type == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")
else:
    print("⚠️ GPU not detected. Training on CPU will be significantly slower.")

3. Training configuration¶

A single dataclass keeps all hyperparameters, including limits tailored for a 12 GB Colab GPU. The epsilon schedule and target-network updates are also defined here.

In [ ]:
@dataclass
class TrainConfig:
    """All hyperparameters for DQN training on Pong, sized for a 12 GB GPU."""

    env_id: str = "ALE/Pong-v5"
    seed: int = 42
    replay_capacity: int = 120_000  # fits easily in 12 GB when storing uint8 frames
    warmup_steps: int = 20_000  # env steps collected before any gradient update
    batch_size: int = 64
    gamma: float = 0.99  # discount factor
    learning_rate: float = 1e-4
    adam_eps: float = 1e-4  # larger than Adam's default, commonly used for DQN
    eps_start: float = 1.0  # initial exploration rate
    eps_final: float = 0.05  # floor of the epsilon schedule
    eps_decay_frames: int = 250_000  # frames over which epsilon anneals linearly
    max_frames: int = 600_000  # total environment steps in the training run
    updates_per_step: int = 1  # gradient updates per environment step
    target_sync_interval: int = 2_000  # steps between target-network syncs
    target_tau: float = 1.0  # tau=1.0 -> hard update
    grad_clip: float = 10.0  # global-norm gradient clipping threshold
    log_interval: int = 5_000  # steps between progress log lines
    eval_epsilon: float = 0.01  # small residual exploration during evaluation
    video_path: str = "pong_dqn.mp4"
    checkpoint_path: str = "dqn_pong_checkpoint.pt"
    checkpoint_interval: int = 20_000  # steps between periodic checkpoints


cfg = TrainConfig()
# On CPU-only runtimes, cap the run length so the notebook stays tractable.
if device.type != "cuda":
    original_frames = cfg.max_frames
    cfg.max_frames = min(cfg.max_frames, 60_000)
    print(f"Adjusted max_frames from {original_frames:,} to {cfg.max_frames:,} for CPU execution.")

# Dedicated NumPy generator, used later for replay-buffer sampling.
rng = np.random.default_rng(cfg.seed)


def linear_schedule(start: float, end: float, duration: int, t: int) -> float:
    """Linearly interpolate from `start` to `end` over `duration` steps.

    Returns `end` once `t >= duration`, and immediately when `duration <= 0`.
    """
    if duration <= 0:
        return end
    progress = min(1.0, t / duration)
    return start + (end - start) * progress


def update_target_network(target: nn.Module, online: nn.Module, tau: float = 1.0) -> None:
    """Polyak-average the online parameters into the target network.

    With tau=1.0 this is a hard copy; smaller tau blends the two.
    """
    with torch.no_grad():
        for t_param, o_param in zip(target.parameters(), online.parameters()):
            blended = (1.0 - tau) * t_param.data + tau * o_param.data
            t_param.data.copy_(blended)
In [ ]:
def save_checkpoint(path: str, step: int):
    """Persist model/optimizer state plus training metadata to `path`.

    The payload includes the config dict, the step counter, and the episode
    returns so a later resume can restore logging context too.
    """
    payload = {
        "model_state_dict": q_online.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "config": cfg.__dict__,
        "steps": step,
        "returns": returns,
    }
    torch.save(payload, path)
    print(f"Checkpoint written to {path} at step {step:,}.")
In [ ]:
# Optional: resume from a saved checkpoint.
# Set this to True if you want the notebook to
# try loading an existing checkpoint before training.
# (The training cell below also extends cfg.max_frames when a checkpoint loads.)
RESUME_FROM_CHECKPOINT = True
CHECKPOINT_PATH = cfg.checkpoint_path

4. Environment helpers¶

Create Pong with standard preprocessing (grayscale, resize to 84×84, frame stack of 4). A simple wrapper keeps the code explicit and easy to reason about.

In [ ]:
class FrameStackWrapper(gym.Wrapper):
    """Stack the last `num_stack` frames along the channel axis."""

    def __init__(self, env: gym.Env, num_stack: int = 4):
        super().__init__(env)
        self.num_stack = num_stack
        # Deque of raw (H, W) frames; maxlen drops the oldest automatically.
        self.frames = collections.deque(maxlen=num_stack)
        frame_shape = env.observation_space.shape
        self.observation_space = gym.spaces.Box(
            low=0,
            high=255,
            shape=(num_stack,) + frame_shape,
            dtype=env.observation_space.dtype,
        )

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        # Fill the stack with copies of the first frame so it is full from step 0.
        self.frames.clear()
        self.frames.extend(obs for _ in range(self.num_stack))
        return self._get_obs(), info

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        self.frames.append(obs)
        return self._get_obs(), reward, terminated, truncated, info

    def _get_obs(self):
        # (num_stack, H, W) array of the most recent frames, oldest first.
        return np.stack(list(self.frames), axis=0)


def make_env(seed: int, render_mode: str = "rgb_array") -> gym.Env:
    """Build Pong with DeepMind-style preprocessing and a 4-frame stack.

    Args:
        seed: Seed applied to the initial reset and to both spaces.
        render_mode: Gymnasium render mode; "rgb_array" enables frame capture.

    Returns:
        The fully wrapped Gymnasium environment.
    """
    # frameskip=1 at the ALE level because AtariPreprocessing applies its own
    # frame_skip=4 below; stacking both would skip far more raw frames per step.
    env = gym.make(cfg.env_id, render_mode=render_mode, frameskip=1)
    env = RecordEpisodeStatistics(env)
    env = AtariPreprocessing(
        env,
        frame_skip=4,
        screen_size=84,
        grayscale_obs=True,
        scale_obs=False,  # keep uint8 frames; scaling to [0,1] happens in the network
    )
    env = FrameStackWrapper(env, num_stack=4)
    env.reset(seed=seed)
    env.action_space.seed(seed)
    env.observation_space.seed(seed)
    return env


# Separate seeds so evaluation episodes differ from training episodes.
train_env = make_env(cfg.seed)
eval_env = make_env(cfg.seed + 1, render_mode="rgb_array")

print(f"Action space: {train_env.action_space}")
print(f"Observation space: {train_env.observation_space}")

5. Replay buffer & policy helpers¶

Use a NumPy-backed replay buffer that stores uint8 frames (memory friendly) and serves PyTorch tensors on demand.

In [ ]:
class ReplayBuffer:
    """Fixed-size circular buffer storing uint8 observations in NumPy arrays.

    Transitions live on the CPU as compact arrays; `sample` converts a random
    minibatch to float32 tensors on the training device.
    """

    def __init__(self, capacity: int, obs_shape: Tuple[int, ...]):
        self.capacity = capacity
        self.obs_shape = obs_shape
        self.states = np.zeros((capacity, *obs_shape), dtype=np.uint8)
        self.next_states = np.zeros((capacity, *obs_shape), dtype=np.uint8)
        self.actions = np.zeros(capacity, dtype=np.int64)
        self.rewards = np.zeros(capacity, dtype=np.float32)
        self.dones = np.zeros(capacity, dtype=np.bool_)
        self.pos = 0
        self.full = False

    def add(self, state, action, reward, next_state, done):
        """Insert one transition, overwriting the oldest slot once full."""
        idx = self.pos
        self.states[idx] = state
        self.next_states[idx] = next_state
        self.actions[idx] = action
        self.rewards[idx] = reward
        self.dones[idx] = done
        self.pos = (idx + 1) % self.capacity
        self.full = self.full or self.pos == 0

    def __len__(self):
        """Number of transitions currently stored."""
        return self.capacity if self.full else self.pos

    def can_sample(self, batch_size: int) -> bool:
        """Whether a minibatch of `batch_size` can be drawn without replacement."""
        return len(self) >= batch_size

    def sample(self, batch_size: int):
        """Draw a uniform minibatch (no replacement) as tensors on `device`."""
        idxs = rng.choice(len(self), size=batch_size, replace=False)

        def to_tensor(arr, dt):
            return torch.as_tensor(arr, dtype=dt, device=device)

        states = to_tensor(self.states[idxs], torch.float32)
        next_states = to_tensor(self.next_states[idxs], torch.float32)
        actions = to_tensor(self.actions[idxs], torch.int64)
        rewards = to_tensor(self.rewards[idxs], torch.float32)
        dones = to_tensor(self.dones[idxs].astype(np.float32), torch.float32)
        return states, actions, rewards, next_states, dones


def select_action(state: np.ndarray, q_net: nn.Module, epsilon: float, action_space=None) -> int:
    """Epsilon-greedy action selection.

    Args:
        state: Stacked observation as produced by the wrapped env.
        q_net: Network mapping a batched state to per-action Q-values.
        epsilon: Probability of taking a uniformly random action.
        action_space: Space to draw random actions from. Defaults to the
            training env's action space for backward compatibility; evaluation
            callers should pass their own env's space so exploration is not
            silently coupled to `train_env`.

    Returns:
        The chosen action index as a plain int (the random branch previously
        returned the space's native sample type; it is now coerced to match
        the declared return type).
    """
    if action_space is None:
        action_space = train_env.action_space
    if random.random() < epsilon:
        return int(action_space.sample())
    state_t = torch.as_tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
    with torch.no_grad():
        q_values = q_net(state_t)
    return int(q_values.argmax(dim=1).item())


# Allocate the replay buffer sized to the wrapped observation shape (4, 84, 84).
obs_shape = train_env.observation_space.shape
buffer = ReplayBuffer(cfg.replay_capacity, obs_shape)
print(f"Replay buffer capacity: {cfg.replay_capacity:,} transitions")

6. DQN architecture¶

Standard Atari CNN (32/64/64 conv layers + dense head). We keep weights light enough for Colab and apply orthogonal initialisation for stability.

In [ ]:
class DQN(nn.Module):
    """Classic Atari DQN: three conv layers plus a two-layer fully connected head.

    Expects input of shape (N, 4, 84, 84); pixel values are divided by 255
    inside `forward`, so raw uint8-range frames can be passed directly.
    """

    def __init__(self, action_dim: int):
        super().__init__()
        layers = [
            nn.Conv2d(4, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 512),  # conv stack yields 64 x 7 x 7 for 84x84 input
            nn.ReLU(),
            nn.Linear(512, action_dim),
        ]
        self.net = nn.Sequential(*layers)
        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(module):
        # Orthogonal weights with ReLU gain, zero biases, for stable early training.
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            nn.init.orthogonal_(module.weight, gain=nn.init.calculate_gain("relu"))
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-action Q-values for a batch of frame stacks."""
        # Scale raw pixel values from [0, 255] into [0, 1] before the conv stack.
        return self.net(x / 255.0)


# Build online and target networks; the target starts as an exact copy.
action_dim = train_env.action_space.n
q_online = DQN(action_dim).to(device)
q_target = DQN(action_dim).to(device)
update_target_network(q_target, q_online, tau=1.0)

optimizer = optim.Adam(q_online.parameters(), lr=cfg.learning_rate, eps=cfg.adam_eps)
# AMP gradient scaler: prefer the newer torch.amp API, fall back to the
# older torch.cuda.amp constructor on PyTorch versions that lack it.
try:
    scaler = torch.amp.GradScaler(device='cuda', enabled=(device.type == "cuda"))
except AttributeError:
    scaler = torch.cuda.amp.GradScaler(enabled=(device.type == "cuda"))

print(f"Model parameters: {sum(p.numel() for p in q_online.parameters()):,}")

7. Training loop¶

Implements Double DQN targets, gradient clipping, automatic mixed precision (AMP) on GPU, and periodic logging. Adjust cfg.max_frames or cfg.warmup_steps if you need shorter runs.

In [ ]:
%%time
# Main DQN training loop: Double-DQN targets, AMP on GPU, epsilon-greedy
# exploration, periodic target syncs, checkpointing, and progress logging.
import time
from pathlib import Path

# Initial values; may be overwritten if we resume from checkpoint
global_step = 0
returns = []

# Handle optional resume-from-checkpoint logic.
if RESUME_FROM_CHECKPOINT and Path(CHECKPOINT_PATH).exists():
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
    q_online.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    # Restore any hyperparameters that were saved with the checkpoint.
    saved_cfg = checkpoint.get("config", {})
    for k, v in saved_cfg.items():
        if hasattr(cfg, k):
            setattr(cfg, k, v)
    global_step = int(checkpoint.get("steps", 0))
    returns = list(checkpoint.get("returns", []))
    print(f"Resumed from {CHECKPOINT_PATH} at step {global_step:,}.")
    # Extend the frame budget so a resumed run trains past the original limit
    # (this deliberately overrides any max_frames loaded from the checkpoint).
    cfg.max_frames = 800_000
elif RESUME_FROM_CHECKPOINT:
    print(f"No checkpoint found at {CHECKPOINT_PATH}. Starting from scratch.")

state, _ = train_env.reset(seed=cfg.seed)
# Re-sync the target network (covers the case of freshly resumed weights).
update_target_network(q_target, q_online, tau=1.0)

loss_history = collections.deque(maxlen=200)
lengths = []

train_seed = cfg.seed
start_time = time.time()
checkpoint_interval = cfg.checkpoint_interval
print("Starting training...")

while global_step < cfg.max_frames:
    # Anneal epsilon linearly with the global step, then act epsilon-greedily.
    epsilon = linear_schedule(cfg.eps_start, cfg.eps_final, cfg.eps_decay_frames, global_step)
    action = select_action(state, q_online, epsilon)

    next_state, reward, terminated, truncated, info = train_env.step(action)
    done = terminated or truncated
    buffer.add(state, action, reward, next_state, done)
    state = next_state
    global_step += 1

    if done:
        # RecordEpisodeStatistics injects episode return/length on termination.
        if "episode" in info:
            returns.append(info["episode"]["r"])
            lengths.append(info["episode"]["l"])
        train_seed += 1
        state, _ = train_env.reset(seed=train_seed)

    # Learn only once the warmup phase has filled the buffer sufficiently.
    if len(buffer) >= max(cfg.batch_size, cfg.warmup_steps):
        for _ in range(cfg.updates_per_step):
            if not buffer.can_sample(cfg.batch_size):
                break
            states_b, actions_b, rewards_b, next_states_b, dones_b = buffer.sample(cfg.batch_size)

            with torch.amp.autocast(device_type="cuda", enabled=(device.type == "cuda")):
                q_values = q_online(states_b)
                current_q = q_values.gather(1, actions_b.unsqueeze(1)).squeeze(1)

                # Double DQN: the online net picks the argmax action,
                # the target net evaluates it.
                with torch.no_grad():
                    next_q_online = q_online(next_states_b)
                    best_actions = next_q_online.argmax(dim=1)
                    next_q_target = q_target(next_states_b)
                    target_q = next_q_target.gather(1, best_actions.unsqueeze(1)).squeeze(1)
                    targets = rewards_b + cfg.gamma * (1.0 - dones_b) * target_q

                loss = nn.functional.smooth_l1_loss(current_q, targets)

            scaler.scale(loss).backward()
            if cfg.grad_clip is not None:
                # Unscale before clipping so the norm is measured in true units.
                scaler.unscale_(optimizer)
                nn.utils.clip_grad_norm_(q_online.parameters(), cfg.grad_clip)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            loss_history.append(loss.item())

            if global_step % cfg.target_sync_interval == 0:
                update_target_network(q_target, q_online, tau=cfg.target_tau)

    if global_step > 0 and global_step % checkpoint_interval == 0:
        save_checkpoint(cfg.checkpoint_path, global_step)
    if global_step % cfg.log_interval == 0:
        elapsed_min = (time.time() - start_time) / 60.0
        recent_returns = returns[-10:]
        mean_return = np.mean(recent_returns) if recent_returns else float("nan")
        mean_loss = np.mean(loss_history) if loss_history else float("nan")
        print(
            f"Step {global_step:>7,}/{cfg.max_frames:,} | "
            f"epsilon {epsilon:.3f} | "
            f"buffer {len(buffer):>7,} | "
            f"episodes {len(returns):>5} | "
            f"mean return (last 10): {mean_return:>6.2f} | "
            f"loss {mean_loss:>8.4f} | "
            f"elapsed {elapsed_min:>5.1f} min"
        )

print("Training complete.")
# Make sure the final weights are saved even off the periodic schedule.
if global_step % checkpoint_interval != 0:
    save_checkpoint(cfg.checkpoint_path, global_step)
if returns:
    print(f"Final mean return (last 10 episodes): {np.mean(returns[-10:]):.2f}")

8. Record and evaluate the agent¶

Render one evaluation episode with a nearly greedy policy, write an MP4, and report the return.

In [ ]:
@torch.inference_mode()
def play_and_record(env: gym.Env, q_net: nn.Module, epsilon_eval: float, path: str) -> float:
    """Roll out one episode, save the rendered frames as an MP4, return the score.

    Args:
        env: Environment to evaluate in (must support rgb_array rendering).
        q_net: The trained Q-network; switched to eval mode here.
        epsilon_eval: Small residual exploration used during the rollout.
        path: Output MP4 file path.
    """
    q_net.eval()
    frames = []
    total_reward = 0.0
    state, _ = env.reset()
    terminated = truncated = False

    while not (terminated or truncated):
        frames.append(env.render())
        action = select_action(state, q_net, epsilon_eval)
        state, reward, terminated, truncated, info = env.step(action)
        total_reward += reward

    imageio.mimsave(path, frames, fps=30)
    print(f"Saved rollout to {path} ({len(frames)} frames)")
    print(f"Evaluation return: {total_reward}")
    return total_reward


eval_reward = play_and_record(eval_env, q_online, cfg.eval_epsilon, cfg.video_path)

9. Display the recorded video¶

In [ ]:
# Embed the recorded MP4 inline in the notebook output.
from IPython.display import Video

Video(cfg.video_path, embed=True, width=400)

10. Save the model checkpoint¶

Save the trained weights together with optimizer state and training metadata for reproducibility.

In [ ]:
# Manual checkpoint save (optional)
# Writes model + optimizer state and metadata at the current global step.
save_checkpoint(cfg.checkpoint_path, global_step)
In [ ]: