Issue
Goal: I have a function, called inside a loop, that takes a 1D tensor and a 2D tensor as input and uses torch.linalg.solve(). I want to parallelize the loop to optimize the runtime.
Setup: I have 3 main tensors:
input_tensor: size 50x100x100
host_tensor: size 100x100x100
A: size 50x100 (design matrix)
input_tensor holds 100x100 input_vectors, each of length 50. Each of them contains a different number of NaNs that I mask, so the masked input_vector has a length less than or equal to 50. Note that the design matrix A is masked in the same way and then has size (mask x 100).
Because each input_vector and its masked A have different lengths, the function currently has to be run point by point.
Problem: Is there a way to make the following code faster? How can I deal with the design matrix A and input_vector having different sizes at each iteration?
Important: The NaNs cannot simply be replaced by 0, as that would defeat the purpose of the linear solve. As background, I asked a question about a similar process here.
Code:
import torch
from tqdm import tqdm
import numpy as np
from datetime import datetime
# Create "device" so we can migrate the tensors to GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Set the seed for reproducibility
torch.manual_seed(42)
# Set shapes to generate tensors
B, C = 500, 500
M, N = 100, 50
# Generate tensors
input_tensor = torch.randn(N, B, C)
host_tensor = torch.randn(M, B, C)
A = torch.randn(N, M)
# --- Here we input random NaNs in the input_tensor to simulate missing data --- #
# Define the probability of inserting NaN at each element
probability = 0.2 # You can adjust this as needed
# Generate random indices based on the probability
shape = input_tensor.shape
random_indices = torch.rand(shape) < probability
# Replace the selected indices with NaN values
input_tensor[random_indices] = float('nan')
# --- Migrate matrices to GPU --- #
A = A.to(device)
input_tensor = input_tensor.to(device)
host_tensor = host_tensor.to(device)
t_start = datetime.now()
# --- Function that creates a vector of size M from input_vector (size N) and A --- #
def solver(input_vector, A):
    # Create a mask to reduce the row count of A: rows where input_vector is NaN are not considered in the solve
    mask = ~torch.isnan(input_vector)
    # Mask the vector
    input_vector_masked = input_vector[mask]
    # Mask the design matrix
    A_masked = A[mask]
    A_trans = A_masked.T
    # Solve the normal equations: (A.T A) x = A.T vec_obs
    return torch.linalg.solve(A_trans @ A_masked, A_trans @ input_vector_masked)
# --- Iterate through each vector of the input_tensor --- #
# Define the total number of iterations
total_iterations = B*C
# Create a tqdm progress bar
progress_bar = tqdm(total=total_iterations, dynamic_ncols=False, mininterval=1.0)
# Iterate through every cell of input_array
for i in range(host_tensor.shape[1]):
    for j in range(host_tensor.shape[2]):
        host_tensor[:, i, j] = solver(input_tensor[:, i, j], A)
        progress_bar.update(1)  # Update the progress bar
t_stop = datetime.now()
print(f"Inversion took {(t_stop - t_start).total_seconds():.2f}s")
Solution
I ended up with a somewhat unsatisfactory answer here, but let's go through it step by step.
Zeroing NaNs == dropping NaNs
First of all, you can replace the NaNs with zeros. Take the following example: assume you have a vector v and a matrix A, given as
v = [v1 v2 v3]      # N elements
A = [[a11 a12 a13]  # NxM elements
     [a21 a22 a23]
     [a31 a32 a33]]
Now, assume v2 = nan and thus needs to be suppressed. What you are currently doing in solver() is getting the non-NaN elements of v as m and the corresponding rows of A as M, and then calculating A_for_solving = M.T @ M and B_for_solving = M.T @ m, namely
m = [v1 v3] # Masked v (n < N elements)
M = [[a11 a12 a13]  # Masked A (nxM elements)
     [a31 a32 a33]]
A_for_solving = M.T @ M # MxM elements
B_for_solving = M.T @ m # M elements
result = linalg.solve(A_for_solving, B_for_solving)
You should notice two things here:
1. The shapes of A_for_solving and B_for_solving always remain the same, no matter how many elements of v (and thus rows of A) are dropped: A_for_solving is always an MxM matrix and B_for_solving is always an M-element vector. This hints at the possibility that we can still parallelize our calculation.
2. What's more, if you replace the NaNs in v and the corresponding rows in A with zeros, you get exactly the same values in A_for_solving and B_for_solving, since a zeroed row only adds zero terms to each entry of the products.
In other words, you could do the following:
z = [v1 0 v3]       # Zeroed v
Z = [[a11 a12 a13]  # Zeroed A
     [ 0   0   0]
     [a31 a32 a33]]
A_for_solving = Z.T @ Z
B_for_solving = Z.T @ z
result = linalg.solve(A_for_solving, B_for_solving)
… and you would get exactly the same inputs to linalg.solve() as before!
You can easily check this with your current code by extending it for testing purposes as follows:
def solver(input_vector, A):
    mask = ~torch.isnan(input_vector)
    input_vector_masked = input_vector[mask]
    A_masked = A[mask]
    A_trans = A_masked.T
    # Start sanity check: nan-zeroing is the same as nan-dropping
    A_zeroed = A.clone(); A_zeroed[~mask] = 0
    input_vector_zeroed = input_vector.clone(); input_vector_zeroed[~mask] = 0
    assert torch.allclose(A_masked.T @ A_masked,
                          A_zeroed.T @ A_zeroed, atol=1e-5)
    assert torch.allclose(A_masked.T @ input_vector_masked,
                          A_zeroed.T @ input_vector_zeroed, atol=1e-5)
    # End sanity check
    return torch.linalg.solve(A_trans @ A_masked, A_trans @ input_vector_masked)
Batched calculation
If we use the zeroing approach, we can parallelize our code again, because the inputs now have the same size for every mask. The corresponding function could look as follows:
def solver_batch(inpt, a):
    inpt = inpt.permute(1, 2, 0).unsqueeze(-1)  # BxCxNx1
    mask = torch.isnan(inpt)  # CAUTION: True for NaNs, unlike `mask` in the question!
    a_zeroed = a.repeat(*inpt.shape[:2], 1, 1)  # BxCxNxM
    a_zeroed[mask.expand(-1, -1, -1, a.shape[-1])] = 0
    at_a = a_zeroed.transpose(-2, -1) @ a_zeroed  # BxCxMxM
    inpt_zeroed = inpt.clone()
    inpt_zeroed[mask] = 0
    at_input = a_zeroed.transpose(-2, -1) @ inpt_zeroed  # BxCxMx1
    result = torch.linalg.solve(at_a, at_input)
    return result.squeeze(-1).permute(2, 0, 1)  # MxBxC
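For reference, calling the batched solver on the full tensors from the question's code then becomes a one-liner. This is just a minimal usage sketch, and only works if everything fits into memory at once (see Caveat 1 below):

# Solve all B*C systems in one batched call; same MxBxC layout the loop version fills
host_tensor = solver_batch(input_tensor, A)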
Caveats
The batched solution is quite similar to the answer that I posted to your previous question. There are two caveats though:
Caveat 1: Memory usage
As we now need a different matrix A, and thus a different A.T @ A, for each input vector, we end up with a tensor at_a of size 500×500×100×100 in your given example. This is huge (2.5 billion elements). In my case it doesn't fit on the GPU, so I had to process the input tensor in chunks:
chunk_size = 50  # TODO: adjust chunk size for your hardware
for lo in range(0, input_tensor.shape[-1], chunk_size):
    chunk_result = solver_batch(input_tensor[..., lo:lo+chunk_size], A)
    host_tensor[..., lo:lo+chunk_size] = chunk_result
This is still much faster than processing the input element-wise though.
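As a rough rule of thumb for picking chunk_size, you can estimate the size of the dominant at_a buffer per chunk. This is only a sketch assuming float32 tensors and the B and M values from the question's code; intermediates such as a_zeroed add more on top of it:

# Dominant at_a buffer per chunk: B x chunk_size x M x M float32 elements
bytes_per_element = 4
at_a_bytes = B * chunk_size * M * M * bytes_per_element
print(f"at_a per chunk: {at_a_bytes / 1e9:.2f} GB")  # 500*50*100*100*4 bytes = 1.0 GB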
Caveat 2: Numerical instability
I tried to sanity-check results with the following for-loop, similar to the sanity check in my previous answer:
for i in range(host_tensor.shape[1]):
    for j in range(host_tensor.shape[2]):
        input_vec = input_tensor[..., i, j]
        res_vec = host_tensor[..., i, j]
        mask = ~torch.isnan(input_vec)
        M = A[mask]
        assert torch.allclose((M.T @ M) @ res_vec, M.T @ input_vec[mask], atol=1e-3)
What we check here is that A @ X = B should hold for X = solve(A, B) by definition. However, this does not seem to be the case with the given data, neither for my approach nor for yours. I don't know whether this is a problem of numerical instability (my maths skills are lacking there) or whether I made some stupid mistake.
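If you want to dig into this, one option is to look at a relative residual instead of a fixed absolute tolerance, which separates genuine solver error from mere scaling of the data. This is only a sketch of such a check, not part of the answer above; relative_residual is a hypothetical helper:

def relative_residual(A_full, x, vec, mask):
    # ||(M.T M) x - M.T v|| / ||M.T v||, with M the masked design matrix and v the masked vector
    M = A_full[mask]
    rhs = M.T @ vec[mask]
    lhs = (M.T @ M) @ x
    return (torch.linalg.norm(lhs - rhs) / torch.linalg.norm(rhs)).item()

# Example for a single cell
mask00 = ~torch.isnan(input_tensor[:, 0, 0])
print(relative_residual(A, host_tensor[:, 0, 0], input_tensor[:, 0, 0], mask00))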
Answered By - simon