Issue
I'm trying to work out how to compute the KL divergence between two torch.distributions.Distribution objects. So far I haven't found a function that does this. Here is what I've tried:
import torch as t
from torch import distributions as tdist
import torch.nn.functional as F
def kl_divergence(x: t.distributions.Distribution, y: t.distributions.Distribution):
    """Compute the KL divergence between two distributions."""
    return F.kl_div(x, y)
a = tdist.Normal(0, 1)
b = tdist.Normal(1, 1)
print(kl_divergence(a, b)) # TypeError: kl_div(): argument 'input' (position 1) must be Tensor, not Normal
Solution
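torch.nn.functional.kl_div computes the KL-divergence loss between tensors: it expects its input argument to hold log-probabilities and its target argument to hold probabilities, so it cannot accept Distribution objects. The KL divergence between two distributions can be computed with torch.distributions.kl.kl_divergence.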
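As a minimal sketch (not part of the original answer), the asker's helper can be repaired by delegating to torch.distributions.kl.kl_divergence. The alias dist_kl and the Categorical probabilities below are illustrative choices of mine, not anything from the question. For Normal(0, 1) and Normal(1, 1), the closed form log(s2/s1) + (s1^2 + (m1 - m2)^2) / (2 s2^2) - 1/2 gives 0.5. The Categorical part shows how F.kl_div relates: it returns the same value only when it is given log-probabilities of y as input and probabilities of x as target.

```python
import torch
import torch.nn.functional as F
from torch import distributions as tdist
from torch.distributions.kl import kl_divergence as dist_kl  # alias chosen for this sketch


def kl_divergence(x: tdist.Distribution, y: tdist.Distribution) -> torch.Tensor:
    """Compute KL(x || y) between two Distribution objects.

    Delegates to torch.distributions.kl.kl_divergence, which dispatches on the
    pair of distribution types and raises NotImplementedError if no KL
    formula is registered for that pair.
    """
    return dist_kl(x, y)


a = tdist.Normal(0.0, 1.0)
b = tdist.Normal(1.0, 1.0)
print(kl_divergence(a, b).item())  # 0.5, matching the closed form

# Discrete case with illustrative probabilities: F.kl_div(input, target)
# sums target * (log(target) - input), i.e. KL(target || input-distribution),
# so the argument order is reversed relative to kl_divergence(p, q).
p = tdist.Categorical(probs=torch.tensor([0.2, 0.5, 0.3]))
q = tdist.Categorical(probs=torch.tensor([0.4, 0.4, 0.2]))
kl_pq = kl_divergence(p, q)
kl_pq_loss = F.kl_div(q.logits, p.probs, reduction="sum")  # q.logits are normalized log-probs
print(kl_pq.item(), kl_pq_loss.item())
```

For pairs of distribution types that PyTorch does not cover, torch.distributions.kl.register_kl can be used to register your own formula.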
Answered By - jodag