Issue
Currently, I have split my deep learning project into train.py and model.py.
The datasets are sent to the CUDA device inside the epoch loop, as shown below.
train.py
...
device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')
model = MyNet(~).to(device)
...
for batch_data in train_loader:
    s0 = batch_data[0].to(device)
    s1 = batch_data[1].to(device)
    pred = model(s0, s1)
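For reference, here is a minimal self-contained version of that loop that runs on either CPU or GPU. It is a sketch under assumptions: the toy TensorDataset, the batch size, and the TwoInputNet class are hypothetical stand-ins for the real dataset and MyNet, and it uses the generic 'cuda' device instead of 'cuda:2'.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Pick a GPU if one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class TwoInputNet(nn.Module):
    """Hypothetical stand-in for MyNet that takes two inputs."""

    def __init__(self, in_feats: int, out_feats: int):
        super().__init__()
        self.lin = nn.Linear(2 * in_feats, out_feats)

    def forward(self, s0: torch.Tensor, s1: torch.Tensor) -> torch.Tensor:
        return self.lin(torch.cat((s0, s1), dim=1))


# Toy data: 8 samples, each a pair of 4-feature vectors (kept on the CPU).
dataset = TensorDataset(torch.randn(8, 4), torch.randn(8, 4))
train_loader = DataLoader(dataset, batch_size=4)

# Move the model's parameters to the device once, before the loop.
model = TwoInputNet(4, 3).to(device)

for batch_data in train_loader:
    # Move only the current batch to the device, inside the loop.
    s0 = batch_data[0].to(device)
    s1 = batch_data[1].to(device)
    pred = model(s0, s1)
    print(pred.shape, pred.device)
```

Keeping the full dataset on the CPU and transferring one batch at a time bounds GPU memory use to roughly one batch of inputs plus the model.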
However, my model (in model.py) also needs access to the device variable for a skip-connection-like operation: it creates a new tensor from a copy of the hidden units (for a residual connection).
model.py
import copy

import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv


class MyNet(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats):
        super(MyNet, self).__init__()
        self.conv1 = GCNConv(in_feats, hid_feats)
        ...

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x1 = copy.copy(x.float())
        x = self.conv1(x, edge_index)
        skip_conn = torch.zeros(len(data.batch), x1.size(1)).to(device)  # <--
        # (some ops for x1 -> skip_conn)
        x = torch.cat((x, skip_conn), 1)
Currently, I pass device as a parameter, but I suspect this is not best practice.
- Where is the best place to send the dataset to CUDA?
- When multiple scripts need to access device, how should I handle it (as a parameter, a global variable, ...)?
Solution
You can add a new attribute to MyNet that stores the device info, and use it when initializing skip_conn.
class MyNet(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats, device):  # <--
        super(MyNet, self).__init__()
        self.conv1 = GCNConv(in_feats, hid_feats)
        ...
        self.device = device  # <--
        self.to(self.device)  # <-- call last, so every layer defined above is moved

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x1 = copy.copy(x.float())
        x = self.conv1(x, edge_index)
        skip_conn = torch.zeros(len(data.batch), x1.size(1), device=self.device)  # <--
        # (some ops for x1 -> skip_conn)
        x = torch.cat((x, skip_conn), 1)
Notice that in this example MyNet is responsible for all of the device logic, including the .to(device) call. This encapsulates all model-related device management in the model class itself.
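Below is a runnable sketch of the same pattern. It is a simplified, hypothetical MyNetSketch, not the original model: nn.Linear stands in for GCNConv (so torch_geometric is not required), a plain tensor replaces the graph data object, and the skip tensor is left as zeros rather than filled from x1.

```python
import torch
import torch.nn as nn


class MyNetSketch(nn.Module):
    """Hypothetical, simplified MyNet: nn.Linear stands in for GCNConv."""

    def __init__(self, in_feats, hid_feats, device):
        super().__init__()
        self.conv1 = nn.Linear(in_feats, hid_feats)  # stand-in layer
        self.device = device  # remember the device for later allocations
        self.to(self.device)  # called last: moves all submodules defined above

    def forward(self, x):
        x1 = x.float()
        x = self.conv1(x1)
        # Allocate the skip tensor directly on the model's device.
        skip_conn = torch.zeros(x1.size(0), x1.size(1), device=self.device)
        return torch.cat((x, skip_conn), 1)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MyNetSketch(in_feats=4, hid_feats=8, device=device)
out = model(torch.randn(5, 4, device=device))
print(out.shape)  # torch.Size([5, 12]): 8 hidden + 4 skip columns
```

A variant that avoids storing the device at all is to allocate from the input tensor, for example torch.zeros(n, m, device=x1.device) or x1.new_zeros((n, m)); the new tensor then lands on whatever device the data is already on.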
Answered By - Dani Cores