Issue
- Is there a simpler and more performant way to do this?
- Can it be done for each channel individually instead of having a global max/min, in a single operation?
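Essentially this is "soft clamping": values above threshold are linearly rescaled into (threshold, boundary], and values below -threshold into [-boundary, -threshold), so that the input's max lands on boundary and its min on -boundary.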
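Reading the code below, the mapping can be written as a piecewise-linear function. The symbols are introduced here only for readability: t is threshold, b is boundary, and M and m are the input's max and min (global in the question's code). The formula assumes M > t and m < -t:

```latex
f(x) =
\begin{cases}
t + (b - t)\,\dfrac{x - t}{M - t}, & x > t \\[4pt]
x, & -t \le x \le t \\[4pt]
-t - (b - t)\,\dfrac{x + t}{m + t}, & x < -t
\end{cases}
```

Each outer piece is a straight line: the upper one runs from (t, t) to (M, b) and the lower one from (m, -b) to (-t, -t), so f is continuous at both thresholds.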
import torch
input_tensor = torch.tensor([[
[[-19, -7, -5, -4, -3, -1, 0, 1, 2, 3, 4, 7, 8, 10, 12, 13, 14, 15, 19],
[-17, -7, -5, -4, -3, -1, 0, 1, 2, 3, 4, 6, 7, 9, 12, 13, 14, 15, 16],
[-12, -7, -5, -4, -3, -1, 0, 1, 2, 3, 4, 7, 8, 9, 11, 13, 13, 15, 17],
[-11, -7, -5, -4, -3, -1, 0, 1, 2, 3, 4, 7, 8, 9, 12, 13, 14, 15, 19]]
]], dtype=torch.float16)
# Define the threshold and boundary
threshold = 3
boundary = 4
# Apply the soft clamping using the global max/min (the goal is to do this per channel)
soft_clamped = torch.where(
input_tensor > threshold, # Above threshold
((input_tensor - threshold) / (input_tensor.max() - threshold)) * (boundary - threshold) + threshold,
torch.where(
input_tensor < -threshold, # Below -threshold
((input_tensor + threshold) / (input_tensor.min() + threshold)) * (-boundary + threshold) - threshold,
input_tensor
)
)
print(soft_clamped)
tensor([[[[-4.0000, -3.2500, -3.1250, -3.0625, -3.0000, -1.0000, 0.0000,
1.0000, 2.0000, 3.0000, 3.0625, 3.2500, 3.3125, 3.4375,
3.5625, 3.6250, 3.6875, 3.7500, 4.0000],
[-3.8750, -3.2500, -3.1250, -3.0625, -3.0000, -1.0000, 0.0000,
1.0000, 2.0000, 3.0000, 3.0625, 3.1875, 3.2500, 3.3750,
3.5625, 3.6250, 3.6875, 3.7500, 3.8125],
[-3.5625, -3.2500, -3.1250, -3.0625, -3.0000, -1.0000, 0.0000,
1.0000, 2.0000, 3.0000, 3.0625, 3.2500, 3.3125, 3.3750,
3.5000, 3.6250, 3.6250, 3.7500, 3.8750],
[-3.5000, -3.2500, -3.1250, -3.0625, -3.0000, -1.0000, 0.0000,
1.0000, 2.0000, 3.0000, 3.0625, 3.2500, 3.3125, 3.3750,
3.5625, 3.6250, 3.6875, 3.7500, 4.0000]]]],
dtype=torch.float16)
Solution
Is there a simpler and more performant way to do this?
I'm not sure there is. If I understand your question correctly, the replacement values depend on the min/max of the input, i.e. on how far the input extends beyond the threshold (an input that reaches 5 below -threshold is interpolated differently from one that reaches only 3 below). This means the replacement values have to be recomputed for every input.
Can it be done for each channel individually instead of having a global max/min, in a single operation?
Yes. Take the max/min along the channel dimension with keepdim=True, so the reduced tensor keeps a size-1 axis and broadcasts back against the input in the same expression.
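A small sketch of the shapes involved, using random data with the same (1, 1, 4, 19) layout as the example (the random values are only for illustration). torch.amax/torch.amin are used here because they return just the values, without indices:

```python
import torch

torch.manual_seed(0)
x = torch.randn(1, 1, 4, 19)  # illustrative data with the example's layout

# Reduce over dim 3 (the 19 values in each row); keepdim=True leaves a size-1 axis
row_max = torch.amax(x, dim=3, keepdim=True)
row_min = torch.amin(x, dim=3, keepdim=True)
print(row_max.shape)  # torch.Size([1, 1, 4, 1])

# (1, 1, 4, 19) vs (1, 1, 4, 1) broadcasts to (1, 1, 4, 19):
# every element is compared with the max/min of its own row
print(bool((x <= row_max).all()), bool((x >= row_min).all()))  # True True

# Without keepdim the result has shape (1, 1, 4); its trailing size 4 does not
# line up with 19, so x - torch.amax(x, dim=3) would fail to broadcast
print(torch.amax(x, dim=3).shape)  # torch.Size([1, 1, 4])
```

The size-1 axis is what lets a single elementwise expression use a different max/min for each row.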
import torch
input_tensor = torch.tensor([[
[[-19, -7, -5, -4, -3, -1, 0, 1, 2, 3, 4, 7, 8, 10, 12, 13, 14, 15, 19],
[-17, -7, -5, -4, -3, -1, 0, 1, 2, 3, 4, 6, 7, 9, 12, 13, 14, 15, 16],
[-12, -7, -5, -4, -3, -1, 0, 1, 2, 3, 4, 7, 8, 9, 11, 13, 13, 15, 17],
[-11, -7, -5, -4, -3, -1, 0, 1, 2, 3, 4, 7, 8, 9, 12, 13, 14, 15, 19]]
]], dtype=torch.float16)
threshold = 3
boundary = 4
channel_dim = 3  # last dim of the (1, 1, 4, 19) input: one max/min per row of 19 values
# keepdim=True keeps dim 3 as size 1 so the results broadcast against input_tensor
max_vals = input_tensor.max(channel_dim, keepdim=True).values
max_replace = ((input_tensor - threshold) / (max_vals - threshold)) * (boundary - threshold) + threshold
over_mask = (input_tensor > threshold)
min_vals = input_tensor.min(channel_dim, keepdim=True).values
min_replace = ((input_tensor + threshold) / (min_vals + threshold)) * (-boundary + threshold) - threshold
under_mask = (input_tensor < -threshold)
# Use the rescaled value outside [-threshold, threshold] and the original value inside
soft_clamped = torch.where(over_mask, max_replace, torch.where(under_mask, min_replace, input_tensor))
print(soft_clamped)
tensor([[[[-4.0000, -3.2500, -3.1250, -3.0625, -3.0000, -1.0000, 0.0000,
1.0000, 2.0000, 3.0000, 3.0625, 3.2500, 3.3125, 3.4375,
3.5625, 3.6250, 3.6875, 3.7500, 4.0000],
[-4.0000, -3.2852, -3.1426, -3.0723, -3.0000, -1.0000, 0.0000,
1.0000, 2.0000, 3.0000, 3.0762, 3.2305, 3.3086, 3.4609,
3.6914, 3.7695, 3.8457, 3.9219, 4.0000],
[-4.0000, -3.4453, -3.2227, -3.1113, -3.0000, -1.0000, 0.0000,
1.0000, 2.0000, 3.0000, 3.0723, 3.2852, 3.3574, 3.4277,
3.5703, 3.7148, 3.7148, 3.8574, 4.0000],
[-4.0000, -3.5000, -3.2500, -3.1250, -3.0000, -1.0000, 0.0000,
1.0000, 2.0000, 3.0000, 3.0625, 3.2500, 3.3125, 3.3750,
3.5625, 3.6250, 3.6875, 3.7500, 4.0000]]]],
dtype=torch.float16)
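The same steps can be wrapped in a reusable helper. This is a sketch, not the answerer's code: the name soft_clamp and its dim parameter are hypothetical, and the default dim=-1 corresponds to channel_dim = 3 for the 4-D input above. Slices that never exceed the threshold simply keep their values; the unused branch may hold inf/NaN there, which torch.where discards.

```python
import torch


def soft_clamp(x: torch.Tensor, threshold: float, boundary: float, dim: int = -1) -> torch.Tensor:
    """Hypothetical helper: linearly squeeze values outside [-threshold, threshold]
    into [-boundary, boundary], using the max/min of each slice along `dim`."""
    max_vals = torch.amax(x, dim=dim, keepdim=True)  # per-slice maximum
    min_vals = torch.amin(x, dim=dim, keepdim=True)  # per-slice minimum

    # Above threshold: map (threshold, max] onto (threshold, boundary]
    upper = (x - threshold) / (max_vals - threshold) * (boundary - threshold) + threshold
    # Below -threshold: map [min, -threshold) onto [-boundary, -threshold)
    lower = (x + threshold) / (min_vals + threshold) * (threshold - boundary) - threshold

    return torch.where(x > threshold, upper, torch.where(x < -threshold, lower, x))


# Usage on two made-up rows with different extremes
rows = torch.tensor([[-11.0, -7.0, 0.0, 3.0, 4.0, 19.0],
                     [-17.0, -5.0, 1.0, 2.0, 9.0, 16.0]])
out = soft_clamp(rows, threshold=3, boundary=4)
print(out)  # first row holds -4, -3.5, 0, 3, 3.0625, 4
```

Because each row uses its own extremes, both rows end exactly at -4 and 4 even though their raw ranges differ.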
You can also combine the three cases with mask arithmetic instead of nested torch.where calls:
replace_mask = ~(over_mask | under_mask)  # True where the value is left unchanged
soft_clamped = (replace_mask * input_tensor) + (max_replace * over_mask) + (min_replace * under_mask)
But I think the torch.where syntax is likely more efficient, since it avoids the extra elementwise multiplications and additions over the whole tensor.
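One further difference, shown in a small sketch with a made-up row: the mask-arithmetic version multiplies every branch, so an inf or NaN in a branch that is not selected still leaks into the result (inf * 0 and NaN * 0 are NaN), whereas torch.where only copies the selected branch. This happens, for example, when a row's max equals threshold, making max_vals - threshold zero.

```python
import torch

threshold, boundary = 3, 4
# Made-up row whose max equals threshold, so max_vals - threshold == 0
x = torch.tensor([[-5.0, 0.0, 3.0]])

max_vals = x.amax(dim=-1, keepdim=True)
min_vals = x.amin(dim=-1, keepdim=True)
# Division by zero: this branch holds -inf/NaN, but no element selects it
max_replace = (x - threshold) / (max_vals - threshold) * (boundary - threshold) + threshold
min_replace = (x + threshold) / (min_vals + threshold) * (threshold - boundary) - threshold
over_mask = x > threshold
under_mask = x < -threshold

# torch.where takes each element from the branch its mask selects
via_where = torch.where(over_mask, max_replace, torch.where(under_mask, min_replace, x))

# Mask arithmetic: False * -inf and False * NaN both evaluate to NaN
keep_mask = ~(over_mask | under_mask)
via_masks = keep_mask * x + over_mask * max_replace + under_mask * min_replace

print(via_where)  # holds -4, 0, 3
print(via_masks)  # all NaN
```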
Answered By - Karl