The snippet applies a single Gram-Schmidt step that removes from the gradient g its component along the momentum buffer buf, so that the result is orthogonal to the momentum direction:
ref_norm = buf / (buf.norm() + 1e-8)
proj = g @ ref_norm.T @ ref_norm
g = g - proj
This process:
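1. Normalizes the momentum buffer buf to unit length; the 1e-8 term guards against division by zero.
2. Computes the projection of the gradient g onto that unit direction.
3. Subtracts the projection from g, leaving only the orthogonal component.

The result is written back to g (the snippet defines no separate g_orth). It is orthogonal to the momentum direction up to floating-point rounding and the small bias introduced by the 1e-8 term. This holds when g and buf are row vectors of shape (1, n), so that g @ ref_norm.T is a single scalar coefficient; for general 2-D matrices the same expression computes g multiplied by ref_norm.T @ ref_norm, which is not a projection under the Frobenius inner product.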
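
The three steps can be checked numerically with a small dependency-free sketch. It is plain Python rather than the tensor code above, it treats g and buf as flat lists of floats, and the function name project_out is a hypothetical helper chosen for illustration, not part of the original snippet.

```python
import math

def project_out(g, buf, eps=1e-8):
    """Return g minus its component along buf (both flat lists of floats)."""
    # Step 1: scale buf to (almost) unit length; eps avoids division by zero.
    norm = math.sqrt(sum(b * b for b in buf))
    ref = [b / (norm + eps) for b in buf]
    # Step 2: scalar coefficient of g along ref, i.e. the dot product <g, ref>.
    coeff = sum(gi * ri for gi, ri in zip(g, ref))
    # Step 3: subtract the projection coeff * ref from g.
    return [gi - coeff * ri for gi, ri in zip(g, ref)]

g = [3.0, 4.0, 0.0]
buf = [2.0, 0.0, 0.0]
g_orth = project_out(g, buf)

# The dot product with buf should be close to zero, not exactly zero,
# because eps makes ref slightly shorter than unit length.
residual = sum(a * b for a, b in zip(g_orth, buf))
print(g_orth, residual)
```

If buf is all zeros, ref is also all zeros and g comes back unchanged, which matches the behavior of the original expression with its 1e-8 guard.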