Issue
I am building on top of an existing code framework, and I am having trouble extracting the weights from a neural network that is defined as a class. The code is below:
import numpy as np
import torch
import torch.nn as nn
class Solver:
    """Attach a prediction head to a base network.

    ``Solver`` itself is a plain Python object, not an ``nn.Module``, so
    printing it only shows the default ``<__main__.Solver object ...>``
    repr. The trainable parameters live on the wrapped modules, e.g.
    ``solver.head.last_layer.weight`` or ``solver.head.state_dict()``.
    """

    class Head(nn.Module):
        """Default head: run ``base`` and then a 100 -> 10 linear layer."""

        def __init__(self, base):
            super().__init__()
            self.base = base
            # last_layer takes 100 input features, so ``base`` must emit 100
            # features per sample (Full_Solver.Base's linear_3 does).
            self.last_layer = nn.Linear(100, 10)

        def forward(self, x):
            x = self.base(x)
            x = self.last_layer(x)
            return x

    def __init__(self, bases, HeadClass=None):
        """Build ``self.head`` around ``bases`` and print it.

        Args:
            bases: The base module that the head wraps; stored as
                ``self.base``.
            HeadClass: Optional callable (typically an ``nn.Module``
                subclass) that takes the base and returns the head module.
                When falsy, the nested :class:`Solver.Head` is used.
        """
        self.base = bases
        # Previously both branches built ``self.Head``, so a supplied
        # HeadClass was silently ignored.
        head_cls = HeadClass if HeadClass else self.Head
        self.head = head_cls(self.base)
        print('Head Class:', self.head)
class Full_Solver:
    """Build a base network and wrap it in a :class:`Solver`.

    The wrapped model is kept on ``self.solver`` so its weights can be
    inspected after construction, e.g.
    ``self.solver.head.last_layer.weight``.
    """

    class Base(nn.Module):
        """Three tanh-activated linear layers: 1 -> 100 -> 100 -> 100."""

        def __init__(self):
            super().__init__()
            self.linear_1 = nn.Linear(1, 100)
            self.linear_2 = nn.Linear(100, 100)
            self.linear_3 = nn.Linear(100, 100)

        def forward(self, x):
            x = self.linear_1(x)
            x = torch.tanh(x)
            x = self.linear_2(x)
            x = torch.tanh(x)
            x = self.linear_3(x)
            x = torch.tanh(x)
            return x

    def __init__(self, BaseClass=None):
        """Create (or accept) the base model and build the solver around it.

        Args:
            BaseClass: A base ``nn.Module`` instance to wrap. When omitted
                (``None``), a fresh :class:`Full_Solver.Base` is created for
                this instance.
        """
        # The old default ``BaseClass=Base()`` was evaluated once, at class
        # definition time, so every default-constructed Full_Solver shared
        # the same Base module and therefore the same weights.
        self.base = self.Base() if BaseClass is None else BaseClass
        print('Base model:', self.base)
        print('Base model type:', type(self.base))
        # Keep the solver on the instance; as a bare local it (and its
        # weights) became unreachable once __init__ returned.
        self.solver = Solver(self.base)
        print('Full model:', self.solver)
        print('Full model type:', type(self.solver))
# Build the model at import time; Full_Solver.__init__ prints the base and full models.
xx = Full_Solver()
In the `Full_Solver` class, I define a base neural network, and the `Solver` class adds a head to it based on some conditions (I have left out all the conditions for the sake of brevity). When I run `print('Full model:', solver_1)`, the output is `Full model: <__main__.Solver object at 0x7f83a82e9cd0>`. How do I extract the output weights from this class object? (Assume that I just want to extract the randomly initialized weights.)
Solution
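`Solver` is a plain Python object rather than an `nn.Module`, so printing it only shows the default object representation. The weights live on the modules it holds. For the last layer, use `print('Full model:', solver_1.head.last_layer.weight)`, or `solver_1.head.state_dict()` for all parameters. Put this line inside `Full_Solver.__init__`, because `solver_1` is a local variable there.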
Answered By - StBlaize
0 comments:
Post a Comment
Note: Only a member of this blog may post a comment.