pytorch-patterns
PyTorch deep learning patterns and best practices for building robust, efficient, and reproducible training pipelines, model architectures, and data loading.
PyTorch Development Patterns
Idiomatic PyTorch patterns and best practices for building robust, efficient, and reproducible deep learning applications.
When to Activate
- Writing new PyTorch models or training scripts
- Reviewing deep learning code
- Debugging training loops or data pipelines
- Optimizing GPU memory usage or training speed
- Setting up reproducible experiments
Core Principles
1. Device-Agnostic Code
Always write code that works on both CPU and GPU without hardcoding devices.
# Good: Device-agnosticdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = MyModel().to(device)data = data.to(device)
# Bad: Hardcoded devicemodel = MyModel().cuda() # Crashes if no GPUdata = data.cuda()2. Reproducibility First
Set all random seeds for reproducible results.
# Good: Full reproducibility setupdef set_seed(seed: int = 42) -> None: torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) np.random.seed(seed) random.seed(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False
# Bad: No seed controlmodel = MyModel() # Different weights every run3. Explicit Shape Management
Always document and verify tensor shapes.
# Good: Shape-annotated forward passdef forward(self, x: torch.Tensor) -> torch.Tensor: # x: (batch_size, channels, height, width) x = self.conv1(x) # -> (batch_size, 32, H, W) x = self.pool(x) # -> (batch_size, 32, H//2, W//2) x = x.view(x.size(0), -1) # -> (batch_size, 32*H//2*W//2) return self.fc(x) # -> (batch_size, num_classes)
# Bad: No shape trackingdef forward(self, x): x = self.conv1(x) x = self.pool(x) x = x.view(x.size(0), -1) # What size is this? return self.fc(x) # Will this even work?Model Architecture Patterns
Clean nn.Module Structure
# Good: Well-organized moduleclass ImageClassifier(nn.Module): def __init__(self, num_classes: int, dropout: float = 0.5) -> None: super().__init__() self.features = nn.Sequential( nn.Conv2d(3, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True), nn.MaxPool2d(2), ) self.classifier = nn.Sequential( nn.Dropout(dropout), nn.Linear(64 * 16 * 16, num_classes), )
def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.features(x) x = x.view(x.size(0), -1) return self.classifier(x)
# Bad: Everything in forwardclass ImageClassifier(nn.Module): def __init__(self): super().__init__()
def forward(self, x): x = F.conv2d(x, weight=self.make_weight()) # Creates weight each call! return xProper Weight Initialization
# Good: Explicit initializationdef _init_weights(self, module: nn.Module) -> None: if isinstance(module, nn.Linear): nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.Conv2d): nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") elif isinstance(module, nn.BatchNorm2d): nn.init.ones_(module.weight) nn.init.zeros_(module.bias)
model = MyModel()model.apply(model._init_weights)Training Loop Patterns
Standard Training Loop
# Good: Complete training loop with best practicesdef train_one_epoch( model: nn.Module, dataloader: DataLoader, optimizer: torch.optim.Optimizer, criterion: nn.Module, device: torch.device, scaler: torch.amp.GradScaler | None = None,) -> float: model.train() # Always set train mode total_loss = 0.0
for batch_idx, (data, target) in enumerate(dataloader): data, target = data.to(device), target.to(device)
optimizer.zero_grad(set_to_none=True) # More efficient than zero_grad()
# Mixed precision training with torch.amp.autocast("cuda", enabled=scaler is not None): output = model(data) loss = criterion(output, target)
if scaler is not None: scaler.scale(loss).backward() scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) scaler.step(optimizer) scaler.update() else: loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) optimizer.step()
total_loss += loss.item()
return total_loss / len(dataloader)Validation Loop
# Good: Proper evaluation@torch.no_grad() # More efficient than wrapping in torch.no_grad() blockdef evaluate( model: nn.Module, dataloader: DataLoader, criterion: nn.Module, device: torch.device,) -> tuple[float, float]: model.eval() # Always set eval mode — disables dropout, uses running BN stats total_loss = 0.0 correct = 0 total = 0
for data, target in dataloader: data, target = data.to(device), target.to(device) output = model(data) total_loss += criterion(output, target).item() correct += (output.argmax(1) == target).sum().item() total += target.size(0)
return total_loss / len(dataloader), correct / totalData Pipeline Patterns
Custom Dataset
# Good: Clean Dataset with type hintsclass ImageDataset(Dataset): def __init__( self, image_dir: str, labels: dict[str, int], transform: transforms.Compose | None = None, ) -> None: self.image_paths = list(Path(image_dir).glob("*.jpg")) self.labels = labels self.transform = transform
def __len__(self) -> int: return len(self.image_paths)
def __getitem__(self, idx: int) -> tuple[torch.Tensor, int]: img = Image.open(self.image_paths[idx]).convert("RGB") label = self.labels[self.image_paths[idx].stem]
if self.transform: img = self.transform(img)
return img, labelEfficient DataLoader Configuration
# Good: Optimized DataLoaderdataloader = DataLoader( dataset, batch_size=32, shuffle=True, # Shuffle for training num_workers=4, # Parallel data loading pin_memory=True, # Faster CPU->GPU transfer persistent_workers=True, # Keep workers alive between epochs drop_last=True, # Consistent batch sizes for BatchNorm)
# Bad: Slow defaultsdataloader = DataLoader(dataset, batch_size=32) # num_workers=0, no pin_memoryCustom Collate for Variable-Length Data
# Good: Pad sequences in collate_fndef collate_fn(batch: list[tuple[torch.Tensor, int]]) -> tuple[torch.Tensor, torch.Tensor]: sequences, labels = zip(*batch) # Pad to max length in batch padded = nn.utils.rnn.pad_sequence(sequences, batch_first=True, padding_value=0) return padded, torch.tensor(labels)
dataloader = DataLoader(dataset, batch_size=32, collate_fn=collate_fn)Checkpointing Patterns
Save and Load Checkpoints
# Good: Complete checkpoint with all training statedef save_checkpoint( model: nn.Module, optimizer: torch.optim.Optimizer, epoch: int, loss: float, path: str,) -> None: torch.save({ "epoch": epoch, "model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict(), "loss": loss, }, path)
def load_checkpoint( path: str, model: nn.Module, optimizer: torch.optim.Optimizer | None = None,) -> dict: checkpoint = torch.load(path, map_location="cpu", weights_only=True) model.load_state_dict(checkpoint["model_state_dict"]) if optimizer: optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) return checkpoint
# Bad: Only saving model weights (can't resume training)torch.save(model.state_dict(), "model.pt")Performance Optimization
Mixed Precision Training
# Good: AMP with GradScalerscaler = torch.amp.GradScaler("cuda")for data, target in dataloader: with torch.amp.autocast("cuda"): output = model(data) loss = criterion(output, target) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update() optimizer.zero_grad(set_to_none=True)Gradient Checkpointing for Large Models
# Good: Trade compute for memoryfrom torch.utils.checkpoint import checkpoint
class LargeModel(nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: # Recompute activations during backward to save memory x = checkpoint(self.block1, x, use_reentrant=False) x = checkpoint(self.block2, x, use_reentrant=False) return self.head(x)torch.compile for Speed
# Good: Compile the model for faster execution (PyTorch 2.0+)model = MyModel().to(device)model = torch.compile(model, mode="reduce-overhead")
# Modes: "default" (safe), "reduce-overhead" (faster), "max-autotune" (fastest)Quick Reference: PyTorch Idioms
| Idiom | Description |
|---|---|
model.train() / model.eval() | Always set mode before train/eval |
torch.no_grad() | Disable gradients for inference |
optimizer.zero_grad(set_to_none=True) | More efficient gradient clearing |
.to(device) | Device-agnostic tensor/model placement |
torch.amp.autocast | Mixed precision for 2x speed |
pin_memory=True | Faster CPU→GPU data transfer |
torch.compile | JIT compilation for speed (2.0+) |
weights_only=True | Secure model loading |
torch.manual_seed | Reproducible experiments |
gradient_checkpointing | Trade compute for memory |
Anti-Patterns to Avoid
# Bad: Forgetting model.eval() during validationmodel.train()with torch.no_grad(): output = model(val_data) # Dropout still active! BatchNorm uses batch stats!
# Good: Always set eval modemodel.eval()with torch.no_grad(): output = model(val_data)
# Bad: In-place operations breaking autogradx = F.relu(x, inplace=True) # Can break gradient computationx += residual # In-place add breaks autograd graph
# Good: Out-of-place operationsx = F.relu(x)x = x + residual
# Bad: Moving data to GPU inside the training loop repeatedlyfor data, target in dataloader: model = model.cuda() # Moves model EVERY iteration!
# Good: Move model once before the loopmodel = model.to(device)for data, target in dataloader: data, target = data.to(device), target.to(device)
# Bad: Using .item() before backwardloss = criterion(output, target).item() # Detaches from graph!loss.backward() # Error: can't backprop through .item()
# Good: Call .item() only for loggingloss = criterion(output, target)loss.backward()print(f"Loss: {loss.item():.4f}") # .item() after backward is fine
# Bad: Not using torch.save properlytorch.save(model, "model.pt") # Saves entire model (fragile, not portable)
# Good: Save state_dicttorch.save(model.state_dict(), "model.pt")Remember: PyTorch code should be device-agnostic, reproducible, and memory-conscious. When in doubt, profile with torch.profiler and check GPU memory with torch.cuda.memory_summary().