Most people tune the optimizer. Almost nobody touches the thing the optimizer actually eats: the gradient. But the gap between loss.backward() and optimizer.step() is a real hook point, and you can do useful work there. This is a short, runnable guide to intercepting gradients and transforming them before the update lands.
The mental model
A training step is really four moves:
- Forward pass: compute predictions and loss
- Backward pass: loss.backward() fills parameter.grad for every parameter
- Update: optimizer.step() reads those .grad tensors and adjusts the weights
- Reset: optimizer.zero_grad() clears the gradients so they don't accumulate into the next step
The window we care about is between steps 2 and 3. After backward(), the gradients exist as plain tensors in p.grad. The optimizer has not touched them yet. That is where we intervene.
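To see the window directly, here is a small sketch using the same toy model as the steps that follow. It runs a backward pass, prints the gradient that now sits on each parameter, and checks that the weights only change once step() is called. The seed and the printed fields are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 1))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()
x = torch.randn(64, 10)
y = torch.randn(64, 1)

optimizer.zero_grad()
loss = loss_fn(model(x), y)
loss.backward()

# Snapshot the weights: backward() has filled .grad but changed no weight.
before = [p.detach().clone() for p in model.parameters()]

# Inside the window: each parameter has a gradient tensor of its own shape.
for name, p in model.named_parameters():
    print(f"{name:10s} grad shape={tuple(p.grad.shape)} norm={p.grad.norm().item():.4f}")

optimizer.step()

# Only step() actually moves the weights.
changed = [not torch.equal(b, p) for b, p in zip(before, model.parameters())]
print("weights changed by step():", all(changed))
```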
Step 1: a baseline step
```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 1))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

x = torch.randn(64, 10)
y = torch.randn(64, 1)

optimizer.zero_grad()
loss = loss_fn(model(x), y)
loss.backward()
optimizer.step()
```
Standard. Now let's get between the last two lines.
Step 2: transform the gradients in place
After backward(), iterate the parameters and modify each .grad before stepping:
```python
def soft_threshold(g, lam=1e-4):
    # Shrink every gradient component toward zero by lam; components with
    # |g| <= lam become exactly zero. This is the same shrinkage rule that
    # wavelet denoising applies to wavelet coefficients.
    return torch.sign(g) * torch.clamp(g.abs() - lam, min=0.0)

optimizer.zero_grad()
loss = loss_fn(model(x), y)
loss.backward()

# The window: gradients exist, and the optimizer has not read them yet.
for p in model.parameters():
    if p.grad is not None:
        p.grad = soft_threshold(p.grad)

optimizer.step()
```
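A quick numeric check makes the shrinkage rule concrete. With an exaggerated lam=0.1 (chosen only so the effect is visible), components whose magnitude is at most 0.1 are zeroed and larger ones lose exactly 0.1 in magnitude while keeping their sign.

```python
import torch

def soft_threshold(g, lam=1e-4):
    # Shrink each component toward zero by lam; |g| <= lam becomes zero.
    return torch.sign(g) * torch.clamp(g.abs() - lam, min=0.0)

g = torch.tensor([-0.3, 0.05, 0.2, -0.08])
out = soft_threshold(g, lam=0.1)
print(out)

# Small components are zeroed, large ones shrink by lam.
expected = torch.tensor([-0.2, 0.0, 0.1, 0.0])
print(torch.allclose(out, expected))
```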
That is the entire idea. Any function that takes a tensor and returns one of the same shape, dtype, and device can be applied to the gradient: clipping, smoothing, denoising, sign-based updates, masking. The optimizer never knows the difference.
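Here is a sketch of a few of those transforms written in the same tensor-in, tensor-out shape. The function names and defaults (limit, frac) are illustrative choices, not library APIs. PyTorch's own torch.nn.utils.clip_grad_norm_, called between backward() and step(), is a built-in instance of the same pattern.

```python
import torch

def clip_elementwise(g, limit=0.01):
    # Hard-clip each component into [-limit, limit].
    return g.clamp(-limit, limit)

def sign_only(g):
    # Keep only the direction of each component (signSGD-style).
    return torch.sign(g)

def keep_top_fraction(g, frac=0.1):
    # Zero all but (roughly) the largest-magnitude `frac` of components;
    # ties at the cutoff are all kept.
    k = max(1, int(frac * g.numel()))
    cutoff = g.abs().flatten().kthvalue(g.numel() - k + 1).values
    return g * (g.abs() >= cutoff)

g = torch.tensor([0.5, -0.02, 0.003, -0.4, 0.0])
for fn in (clip_elementwise, sign_only, keep_top_fraction):
    out = fn(g)
    # Same shape and dtype as the input, so it can be written back to p.grad.
    assert out.shape == g.shape and out.dtype == g.dtype
    print(fn.__name__, out)
```

Any of these drops into the Step 2 loop in place of soft_threshold, for example p.grad = sign_only(p.grad).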
Step 3: make it reusable and optimizer-agnostic
Editing the training loop by hand gets messy. Wrap it instead, so the transform works with any first-order optimizer without changing your loop:
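```python
class GradTransform:
    """Wraps an optimizer and applies `transform` to every gradient before stepping."""

    def __init__(self, optimizer, transform):
        self.optimizer = optimizer
        self.transform = transform

    def zero_grad(self, *args, **kwargs):
        self.optimizer.zero_grad(*args, **kwargs)

    def step(self, *args, **kwargs):
        # Rewrite every gradient the wrapped optimizer is about to read.
        for group in self.optimizer.param_groups:
            for p in group["params"]:
                if p.grad is not None:
                    p.grad = self.transform(p.grad)
        # Then let the wrapped optimizer update as usual.
        return self.optimizer.step(*args, **kwargs)
```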
Usage is a drop-in:
```python
base = torch.optim.Adam(model.parameters(), lr=1e-3)
optimizer = GradTransform(base, soft_threshold)

# training loop is unchanged
optimizer.zero_grad()
loss = loss_fn(model(x), y)
loss.backward()
optimizer.step()
```
Swap Adam for SGD and it still works, because the wrapper only touches .grad, which every first-order optimizer reads the same way.
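To check the swap in practice, this sketch runs the same wrapper around Adam and around SGD with momentum, starting both from identical weights. It repeats GradTransform so the block runs on its own. compose is a hypothetical helper (not a PyTorch function) for chaining a soft threshold with an element-wise clamp, and the learning rates and step count are arbitrary.

```python
import torch
import torch.nn as nn

def soft_threshold(g, lam=1e-4):
    return torch.sign(g) * torch.clamp(g.abs() - lam, min=0.0)

class GradTransform:
    # Same wrapper as above.
    def __init__(self, optimizer, transform):
        self.optimizer = optimizer
        self.transform = transform

    def zero_grad(self, *args, **kwargs):
        self.optimizer.zero_grad(*args, **kwargs)

    def step(self, *args, **kwargs):
        for group in self.optimizer.param_groups:
            for p in group["params"]:
                if p.grad is not None:
                    p.grad = self.transform(p.grad)
        return self.optimizer.step(*args, **kwargs)

def compose(*fns):
    # Hypothetical helper: apply gradient transforms left to right.
    def apply(g):
        for fn in fns:
            g = fn(g)
        return g
    return apply

torch.manual_seed(0)
x, y = torch.randn(64, 10), torch.randn(64, 1)
loss_fn = nn.MSELoss()

losses = {}
for name, make_opt in [
    ("adam", lambda ps: torch.optim.Adam(ps, lr=1e-3)),
    ("sgd", lambda ps: torch.optim.SGD(ps, lr=1e-2, momentum=0.9)),
]:
    torch.manual_seed(1)  # identical initial weights for both runs
    model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 1))
    optimizer = GradTransform(
        make_opt(model.parameters()),
        compose(soft_threshold, lambda g: g.clamp(-1.0, 1.0)),
    )
    for _ in range(50):
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
    losses[name] = loss.item()

print(losses)
```

The wrapper never needs to know which optimizer it holds. It is not a torch.optim.Optimizer subclass, though, so utilities that check for one (learning-rate schedulers, for instance) should be given the wrapped base optimizer instead.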
A quicker alternative: tensor hooks
If you want the transform to fire automatically during the backward pass instead of after it, register a hook directly on a parameter:
```python
for p in model.parameters():
    p.register_hook(soft_threshold)
```
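The hook runs during backward(): each time a gradient for that parameter is computed, whatever the hook returns replaces that gradient before it is accumulated into p.grad. The wrapper approach is usually easier to reason about because everything happens in one obvious place, but hooks are handy when you want the change to be invisible to the rest of your code.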
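Two practical details are worth knowing. First, register_hook returns a handle, and the hook stays attached until you call handle.remove(). Second, a parameter hook sees each gradient contribution before it is added into p.grad. If you accumulate gradients over several backward() calls, the hook thresholds each micro-batch separately, while the wrapper thresholds the accumulated sum, and for a nonlinear transform like soft thresholding those are not the same. The sketch below checks that with the hook attached, p.grad equals soft_threshold applied to the raw gradient; that holds here because each parameter receives exactly one gradient per backward pass.

```python
import torch
import torch.nn as nn

def soft_threshold(g, lam=1e-4):
    return torch.sign(g) * torch.clamp(g.abs() - lam, min=0.0)

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 1))
x, y = torch.randn(64, 10), torch.randn(64, 1)
loss_fn = nn.MSELoss()

# Keep the handles: hooks stay registered until they are removed.
handles = [p.register_hook(soft_threshold) for p in model.parameters()]

model.zero_grad()
loss_fn(model(x), y).backward()
hooked = [p.grad.clone() for p in model.parameters()]

# Remove the hooks and recompute: now .grad holds the raw gradient.
for h in handles:
    h.remove()
model.zero_grad()
loss_fn(model(x), y).backward()
raw = [p.grad.clone() for p in model.parameters()]

# With one gradient per parameter per pass, hooked == soft_threshold(raw).
print(all(torch.allclose(h, soft_threshold(r)) for h, r in zip(hooked, raw)))
```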
One honest warning
Intercepting gradients is powerful, which means it is also a good way to quietly break training. A transform that helps on a noisy problem can actively hurt on a clean one, where the gradient signal was fine to begin with. So treat any gradient transform as a hypothesis, not a free win: run it against an untouched baseline on your actual task, with the same seeds, and keep it only if the numbers say so. The hook is easy. Earning the improvement is the hard part.
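A minimal version of that comparison might look like the sketch below. The random regression data, three seeds, 200 steps, and final training loss as the metric are all placeholder assumptions; on a real task you would compare a held-out validation metric over more seeds. Each run reseeds before building its data and model, so the baseline and the treated run start from identical weights and see identical data.

```python
import torch
import torch.nn as nn

def soft_threshold(g, lam=1e-4):
    return torch.sign(g) * torch.clamp(g.abs() - lam, min=0.0)

def run(transform=None, seed=0, steps=200):
    # Same seed -> same data and same initial weights for every arm.
    torch.manual_seed(seed)
    x, y = torch.randn(256, 10), torch.randn(256, 1)
    model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 1))
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.MSELoss()
    for _ in range(steps):
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        if transform is not None:
            for p in model.parameters():
                if p.grad is not None:
                    p.grad = transform(p.grad)
        optimizer.step()
    # Placeholder metric: final loss on the training data.
    with torch.no_grad():
        return loss_fn(model(x), y).item()

seeds = [0, 1, 2]
baseline = [run(None, s) for s in seeds]
treated = [run(soft_threshold, s) for s in seeds]
for s, b, t in zip(seeds, baseline, treated):
    print(f"seed {s}: baseline={b:.4f} soft_threshold={t:.4f} diff={t - b:+.4f}")
```

If the diff column does not favor the transform consistently across seeds, the transform has not earned its place.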
If you want to see this idea taken all the way, soft-thresholding the gradient is exactly the building block behind WaveGuard, a gradient denoiser I built that swaps the flat threshold for a Haar wavelet transform and gates it so it stays quiet when the gradient is already clean. I benchmarked it honestly, including a task where it actively made things worse.
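For a feel of what swapping the flat threshold for a Haar transform can mean, here is a generic single-level Haar shrinkage. This is not WaveGuard's implementation (see the repository for that) and it has no gating; the single level, pairing adjacent elements of the flattened tensor, and repeating the last element to pad odd lengths are assumptions made for this sketch. It splits the gradient into pairwise averages and pairwise differences, soft-thresholds only the differences, then inverts the transform.

```python
import math
import torch

def haar_soft_threshold(g, lam=1e-4):
    # Illustrative one-level Haar shrinkage; not WaveGuard's code.
    flat = g.flatten()
    n = flat.numel()
    if n % 2:  # repeat the last element so every element has a partner
        flat = torch.cat([flat, flat[-1:]])
    even, odd = flat[0::2], flat[1::2]
    s = math.sqrt(2.0)
    approx = (even + odd) / s   # smooth part: pairwise averages
    detail = (even - odd) / s   # rough part: pairwise differences
    detail = torch.sign(detail) * torch.clamp(detail.abs() - lam, min=0.0)
    out = torch.empty_like(flat)
    out[0::2] = (approx + detail) / s  # inverse Haar step
    out[1::2] = (approx - detail) / s
    return out[:n].reshape_as(g)

# With a huge lam every detail coefficient is removed, so each pair
# collapses to its mean: [1, 3] -> [2, 2], and the padded [-2, -2] stays.
print(haar_soft_threshold(torch.tensor([1.0, 3.0, -2.0]), lam=100.0))
```

With lam=0 the function returns its input unchanged, because the Haar step is exactly invertible.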
Write-up here: https://medium.com/@chukwudieke61/adam-cant-hear-the-signal-through-the-noise-waveguard-can-ffb1d8963a38
Code here: https://github.com/Harry-Potter20/wavelet-grad