Issue
Let's say I want to compute an inner product along the last dimension of two matrices
import jax
import jax.numpy as jnp

a = jax.random.normal(jax.random.PRNGKey(0), shape=(64, 16), dtype=jnp.float32)
b = jax.random.normal(jax.random.PRNGKey(1), shape=(64, 16), dtype=jnp.float32)
I can do it with jnp.einsum:
inner_prod1 = jnp.einsum('i d, j d -> i j', a, b)
or manually call jnp.dot in a loop:
inner_prod2 = jnp.zeros((64, 64))
for i1 in range(64):
    for i2 in range(64):
        inner_prod2 = inner_prod2.at[i1, i2].set(jnp.dot(a[i1], b[i2]))
print(jnp.amax(inner_prod1 - inner_prod2)) # 0.03830552
This is quite a large difference between the two results, even though they are mathematically equivalent. What gives?
Solution
All floating-point operations accumulate rounding errors, so in general, when you express the same computation in two different ways, you should not expect the results to be bitwise equivalent.
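As a small illustration (not part of your question; the specific values are chosen just for demonstration), float32 addition is not associative, so the order in which terms are accumulated affects the result:

import jax.numpy as jnp

# float32 addition is not associative: adding a tiny value to 1.0 first
# rounds it away, while grouping the operations differently preserves it.
eps = jnp.float32(1e-8)
one = jnp.float32(1.0)
print((eps + one) - one)  # 0.0 -- eps is rounded away when absorbed into 1.0
print(eps + (one - one))  # 1e-08

An einsum dispatches to a single fused matmul kernel that accumulates products in a different order (and, on some backends, at a different precision) than an explicit Python loop over jnp.dot calls, so some discrepancy is expected.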
The magnitude of the difference you're seeing is larger than is typical for float32 precision; it makes me think you're probably running your code on a TPU, where matrix multiplication is done at lower precision by default. You can adjust this using the default_matmul_precision configuration, for example like this:
with jax.default_matmul_precision('float32'):
    inner_prod1 = jnp.einsum('i d, j d -> i j', a, b)
    inner_prod2 = jnp.zeros((64, 64))
    for i1 in range(64):
        for i2 in range(64):
            inner_prod2 = inner_prod2.at[i1, i2].set(jnp.dot(a[i1], b[i2]))
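If you'd rather not change the default for a whole block, the same thing can be requested per call via the precision argument that jnp.dot and jnp.einsum accept; here's a minimal sketch reusing the a and b arrays from above:

# Request full float32 accumulation for these specific calls only.
inner_prod1 = jnp.einsum('i d, j d -> i j', a, b,
                         precision=jax.lax.Precision.HIGHEST)
row_dot = jnp.dot(a[0], b[0], precision=jax.lax.Precision.HIGHEST)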
If you do the computation this way, I suspect you'll see a smaller difference more typical of float32 computations, on the order of 1E-6 or so.
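One way to confirm this is to compare the two results with a tolerance rather than exactly; the atol below is just an assumed notion of "close enough" for float32 accumulation error:

# True if the two results agree to within typical float32 accumulation error.
print(jnp.allclose(inner_prod1, inner_prod2, atol=1e-5))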
Answered By - jakevdp