Issue
Goal: I have a function calling torch.linalg.solve() that I want to run as fast as I can.
Setup: I have an input_array (size 50x100x100) and a host_array (size 100x100x100). My function solver takes input_array[:,i,j] as input and outputs a vector of size 100, which is stored in host_array[:,i,j]. I run a nested loop over all rows and columns of input_array to populate host_array.
Problem: This is slow, especially in my real case, where each call of the function takes a second. I currently run it in a nested loop, and I would like to know whether parallelizing the function would make it faster.
Example Code:
import torch
from tqdm import tqdm

# Create host_array and input_array with random data
host_array = torch.zeros(100, 500, 500)
input_array = torch.randn(50, 500, 500)

# Create a dummy coefficient matrix A (50x100)
A = torch.randn(50, 100)

# Define your function to solve for input_array[:, i, j] and update host_array[:, i, j]
def solver(input_vector, A):
    # Solve the linear system of equations
    solution = torch.linalg.solve(A.T@A, A.T@input_vector)
    return solution

# Calculate total runtime
total_iterations = int(host_array.shape[1]*host_array.shape[2])
progress_bar = tqdm(total=total_iterations, dynamic_ncols=False, mininterval=1.0)

# Iterate through the input_array
for i in range(host_array.shape[1]):
    for j in range(host_array.shape[2]):
        host_array[:,i,j] = solver(input_array[:,i,j], A)
        progress_bar.update(1)
Solution
You can make use of the broadcasting capabilities of torch.linalg.solve() to gain a significant speedup: see the section marked # Proposed solution and, in particular, the function solver_batch() in my code below. I annotated the shapes that result from the necessary reshaping steps (squeezing, unsqueezing, and permuting) of the inputs.
from datetime import datetime
import torch

torch.manual_seed(42)  # Make result reproducible

B, C = 500, 500
M, N = 100, 50
input_array = torch.randn(N, B, C)
A = torch.randn(N, M)

# Proposed solution
t_start = datetime.now()
def solver_batch(inpt, a):
    at_a = a.T @ a  # MxM
    inpt = inpt.permute(1, 2, 0).unsqueeze(-1)  # BxCxNx1
    at_input = a.T @ inpt  # BxCxMx1
    result = torch.linalg.solve(at_a, at_input)  # BxCxMx1
    return result.squeeze(-1).permute(2, 0, 1)  # MxBxC
proposed_result = solver_batch(input_array, A)
t_stop = datetime.now()
print(f"Proposed solution took {(t_stop - t_start).total_seconds():.2f}s")

# Previous solution
t_start = datetime.now()
host_array = torch.zeros(M, B, C)  # Will hold the result
def solver(input_vector, A):
    return torch.linalg.solve(A.T@A, A.T@input_vector)
for i in range(host_array.shape[1]):
    for j in range(host_array.shape[2]):
        host_array[:,i,j] = solver(input_array[:,i,j], A)
t_stop = datetime.now()
print(f"Previous solution took {(t_stop - t_start).total_seconds():.2f}s")

# Check results
left = (A.T @ A) @ host_array.permute(1, 2, 0).unsqueeze(-1)
right = A.T @ input_array.permute(1, 2, 0).unsqueeze(-1)
print("(A.T @ A) @ host_array == A.T @ input_array?",
      torch.allclose(left, right, atol=1e-3))
left = (A.T @ A) @ proposed_result.permute(1, 2, 0).unsqueeze(-1)
right = A.T @ input_array.permute(1, 2, 0).unsqueeze(-1)
print("(A.T @ A) @ proposed_result == A.T @ input_array?",
      torch.allclose(left, right, atol=1e-3))
On my machine, I got:
>>> Proposed solution took 0.62s
>>> Previous solution took 21.74s
>>> (A.T @ A) @ host_array == A.T @ input_array? True
>>> (A.T @ A) @ proposed_result == A.T @ input_array? True
Note that, while both host_array and proposed_result hold valid solutions, they do not necessarily hold the same solution (in fact, for the given random seed, they are not the same). This, if I understand correctly, is because the result of torch.linalg.solve() is unique if and only if its first argument (A.T @ A in our case) is invertible, which it cannot be here: A is 50x100, so the 100x100 matrix A.T @ A has rank at most 50. If you want to check that your approach and mine indeed produce the same solution (up to numerical error) for an invertible matrix, you could construct an MxM invertible matrix following this approach and replace A.T @ A with it for testing purposes.
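A minimal sketch of such a check, with small sizes so the loop finishes instantly. The matrix construction used here (G @ G.T plus a multiple of the identity, which is symmetric positive definite and therefore invertible) is an assumption chosen for illustration and not necessarily the approach referred to above; lhs and rhs are hypothetical stand-ins for A.T @ A and A.T @ input_array.

```python
import torch

torch.manual_seed(0)
B, C, M = 4, 5, 6  # deliberately small sizes

# Assumption: G @ G.T + M * I is symmetric positive definite, hence invertible.
G = torch.randn(M, M, dtype=torch.float64)
lhs = G @ G.T + M * torch.eye(M, dtype=torch.float64)  # MxM, replaces A.T @ A
rhs = torch.randn(M, B, C, dtype=torch.float64)        # replaces A.T @ input_array

# Batched solve: the single MxM matrix is broadcast over the BxC batch
batched = torch.linalg.solve(lhs, rhs.permute(1, 2, 0).unsqueeze(-1))  # BxCxMx1
batched = batched.squeeze(-1).permute(2, 0, 1)                         # MxBxC

# Nested-loop solve, one right-hand-side vector at a time
looped = torch.zeros(M, B, C, dtype=torch.float64)
for i in range(B):
    for j in range(C):
        looped[:, i, j] = torch.linalg.solve(lhs, rhs[:, i, j])

same = torch.allclose(batched, looped)
print("Batched and looped results agree?", same)
```

With an invertible first argument the solution is unique, so the two results should agree up to rounding.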
Also note that, when comparing the results with torch.allclose(), I had to be quite generous (atol=1e-3), as quite a bit of numerical error seems to build up.
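One possible way to reduce that error, sketched under assumptions: skip forming A.T @ A altogether and pass A directly to torch.linalg.lstsq(), with all B*C input vectors stacked as columns of a single right-hand-side matrix. This is a different technique from the code above, not a drop-in equivalent: it returns a least-squares solution of A @ x = input, and any x that solves A @ x = input exactly also satisfies the normal equations. The sketch assumes CPU tensors and uses reduced values for B and C.

```python
import torch

torch.manual_seed(0)
N, M = 50, 100
B, C = 8, 8  # reduced from 500x500 to keep the sketch quick

A = torch.randn(N, M)
input_array = torch.randn(N, B, C)

# Stack all B*C input vectors as columns of one N x (B*C) matrix
rhs = input_array.reshape(N, B * C)

# Least-squares solve of A @ x = rhs for every column at once (CPU assumed)
solution = torch.linalg.lstsq(A, rhs).solution  # M x (B*C)
result = solution.reshape(M, B, C)              # MxBxC, same layout as host_array

# Largest absolute deviation of A @ x from the inputs
residual = (A @ solution - rhs).abs().max().item()
print(f"Max residual of A @ x - input: {residual:.2e}")
```

Whether this is faster or more accurate than the normal-equations route for the real problem would need to be measured; casting the inputs to torch.float64 is another option worth trying.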
Answered By - simon