LoRA and Weight Decay (2023)

TLDR

  • Weight decay applied to LoRA adapters implicitly regularizes the effective weights toward the frozen base model weights rather than toward zero, so the optimization problem differs from full finetuning with weight decay.
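As a rough illustration of the claim above, the two objectives can be written side by side. This is a sketch under assumptions that are not stated in the summary: weight decay is treated as an L2 penalty with coefficient \lambda (exact for plain SGD, only approximate for decoupled variants such as AdamW), \mathcal{L} is the task loss, and the LoRA weight is parameterized as W_init + AB.

```latex
\begin{align*}
\text{Full finetuning:} \quad
  & \min_{W} \; \mathcal{L}(W) + \frac{\lambda}{2}\,\lVert W \rVert_F^2 \\
\text{LoRA:} \quad
  & \min_{A,B} \; \mathcal{L}(W_{\text{init}} + AB)
    + \frac{\lambda}{2}\left(\lVert A \rVert_F^2 + \lVert B \rVert_F^2\right)
\end{align*}
```

The first penalty vanishes at W = 0, while the second vanishes at A = 0 and B = 0, which corresponds to W = W_init.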

Key Takeaways

  • Standard weight decay on the LoRA adapter matrices A and B pulls them toward zero, so the effective weight W = W_init + AB is pulled toward W_init, whereas weight decay in full finetuning pulls W toward zero.
  • Increasing the LoRA rank r, even to full rank, does not remove this mismatch: the implicit objective remains biased toward the frozen base model.
  • This can be a feature (preserving base model knowledge) or a bug (task adaptation is implicitly constrained), depending on use case.
  • A corrected weight decay can be derived: the decay terms become (W_init + AB)B^T for A and A^T(W_init + AB) for B, which are the gradients of a penalty on the effective weight W_init + AB and so match full finetuning’s regularization target.
  • The fix is implementable in Optax with a custom update_fn that extracts W_init, A, and B per layer and applies the corrected decay.
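The corrected decay terms in the takeaways above can be sketched in plain JAX. This is a minimal sketch, not the original post's implementation: it assumes A has shape (d, r), B has shape (r, k), and the penalty is 0.5 * ||W_init + AB||_F^2. The function names and the per-layer dict layout (keys "w_init", "a", "b") are hypothetical. The script also compares the closed-form terms with autodiff as a sanity check.

```python
import jax
import jax.numpy as jnp


def corrected_decay_terms(w_init, a, b):
    """Gradients of 0.5 * ||w_init + a @ b||_F^2 with respect to a and b."""
    w_eff = w_init + a @ b   # effective weight, shape (d, k)
    decay_a = w_eff @ b.T    # (W_init + AB) B^T, shape (d, r)
    decay_b = a.T @ w_eff    # A^T (W_init + AB), shape (r, k)
    return decay_a, decay_b


def add_corrected_decay(grads, layer, weight_decay):
    """Add the corrected decay to the task gradients of one LoRA layer.

    `layer` is a hypothetical dict with keys "w_init", "a", "b";
    `grads` is a dict with keys "a" and "b" (w_init is frozen).
    """
    decay_a, decay_b = corrected_decay_terms(
        layer["w_init"], layer["a"], layer["b"]
    )
    return {
        "a": grads["a"] + weight_decay * decay_a,
        "b": grads["b"] + weight_decay * decay_b,
    }


def full_penalty(a, b, w_init):
    """L2 penalty on the effective weight, as in full finetuning."""
    return 0.5 * jnp.sum((w_init + a @ b) ** 2)


# Small random example (shapes are arbitrary illustration values).
key_w, key_a, key_b = jax.random.split(jax.random.PRNGKey(0), 3)
d, k, r = 6, 5, 2
w_init = jax.random.normal(key_w, (d, k))
a = jax.random.normal(key_a, (d, r))
b = jax.random.normal(key_b, (r, k))

# Autodiff gradients of the penalty, for comparison with the closed form.
auto_a, auto_b = jax.grad(full_penalty, argnums=(0, 1))(a, b, w_init)
decay_a, decay_b = corrected_decay_terms(w_init, a, b)
```

Inside an Optax update_fn, a function like add_corrected_decay could be applied to each LoRA layer's updates, provided the frozen W_init is reachable from the params passed to the transformation; how the layers are located in the parameter tree depends on the model and is left out here.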

Hacker News Comment Review

  • No substantive HN discussion yet.
