Issue
When I do
import torch
x = torch.ones(3, 4)
x.norm(p='nuc')
it gives
tensor(3.4641)
How is this value calculated?
Solution
The nuclear norm, also known as the trace norm, is the sum of the singular values of x,
or equivalently either of the expressions below (assuming x is real):
# singular values only; torch.linalg.svdvals supersedes the deprecated torch.svd(x, compute_uv=False)
s = torch.linalg.svdvals(x)
print(torch.sum(s))
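# eigenvalues of the symmetric matrix x^T x; torch.linalg.eigvalsh replaces torch.symeig,
# which has been removed from recent PyTorch releases
eigs = torch.linalg.eigvalsh(x.transpose(1,0) @ x)
# abs guards against tiny negative eigenvalues caused by round-off
print(torch.sum(torch.sqrt(torch.abs(eigs))))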
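To see where 3.4641 comes from for the matrix in the question: torch.ones(3, 4) is rank one, so it has a single nonzero singular value, sqrt(3) * sqrt(4) = sqrt(12), which is about 3.4641. The sketch below checks that the built-in norm, the singular-value sum and the eigenvalue route all agree with that number. It uses float64, which is a choice made here and not part of the original question, because taking square roots of near-zero eigenvalues amplifies float32 round-off.

```python
import math
import torch

# float64 keeps round-off small for the eigenvalue route
x = torch.ones(3, 4, dtype=torch.float64)

# Built-in nuclear norm
nuc_builtin = torch.linalg.matrix_norm(x, ord="nuc")

# Sum of singular values
nuc_svd = torch.linalg.svdvals(x).sum()

# Sum of square roots of the eigenvalues of x^T x
eigs = torch.linalg.eigvalsh(x.T @ x)
nuc_eig = torch.sqrt(torch.abs(eigs)).sum()

# Closed form for this rank-one matrix: sqrt(3) * sqrt(4) = sqrt(12)
expected = math.sqrt(3 * 4)

print(nuc_builtin.item(), nuc_svd.item(), nuc_eig.item(), expected)
```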
Finding the matrix square root isn't supported natively in PyTorch (you could build one from a
symmetric eigendecomposition, but then this would reduce to the previous expression). If you have
a separate sqrtm implementation available, then you can compute the nuclear norm using
print(torch.trace(sqrtm(x.transpose(1,0) @ x)))
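For readers without a sqrtm at hand, here is a minimal sketch of one built on torch.linalg.eigh. The name sqrtm_psd is hypothetical, it is not the implementation the answer refers to, and it only handles symmetric positive semi-definite input such as x^T x. As noted above, going through the eigendecomposition makes this the previous expression in disguise.

```python
import torch

def sqrtm_psd(a):
    """Hypothetical helper: square root of a symmetric PSD matrix via eigh."""
    eigvals, eigvecs = torch.linalg.eigh(a)
    # Clamp tiny negative eigenvalues caused by round-off before the square root
    root = torch.sqrt(torch.clamp(eigvals, min=0))
    # Reassemble V diag(sqrt(lambda)) V^T
    return eigvecs @ torch.diag(root) @ eigvecs.T

x = torch.ones(3, 4, dtype=torch.float64)
nuc_trace = torch.trace(sqrtm_psd(x.T @ x))
print(nuc_trace.item())
```

The trace of the result is the sum of the square roots of the eigenvalues of x^T x, which is the nuclear norm of x.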
From the above expression, it should be clear that if x
is symmetric positive semi-definite, then sqrtm(x.T @ x) is x itself, so the trace norm is just
# use this only if you know x is positive semi-definite
print(torch.trace(x))
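A quick way to convince yourself of the shortcut, and of the warning attached to it, is to compare the trace with the nuclear norm on a matrix that is positive semi-definite by construction and on one that is not. The matrices a and b below are made up for illustration.

```python
import torch

torch.manual_seed(0)
b = torch.randn(4, 4, dtype=torch.float64)
a = b @ b.T  # symmetric positive semi-definite by construction

trace_psd = torch.trace(a)
nuc_psd = torch.linalg.matrix_norm(a, ord="nuc")
print(trace_psd.item(), nuc_psd.item())  # these two agree

# -a is symmetric but not positive semi-definite: the shortcut fails
trace_neg = torch.trace(-a)
nuc_neg = torch.linalg.matrix_norm(-a, ord="nuc")
print(trace_neg.item(), nuc_neg.item())  # trace is negative, the norm is not
```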
Answered By - jodag