Issue
I'd like to manually predict data (regression task) using the model weights and biases in Keras (i.e., Z = f(WX+B), where Z is the output of a layer, W is the weight matrix, X is the input vector, and B is the bias vector).
My model architecture is as follows:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

model = Sequential()
model.add(Dense(207, input_dim=173, activation='relu'))
model.add(Dense(207, activation='relu'))
model.add(Dense(207, activation='relu'))
model.add(Dense(240))
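For reference, a Keras Dense layer stores its kernel with shape (input_dim, units) and its bias with shape (units,), so each layer maps a (batch, input_dim) array to (batch, units). The sketch below is plain Python with no Keras import; it lists the shapes get_weights() should return for this architecture and the resulting parameter count. The helper name dense_shapes is hypothetical.

```python
def dense_shapes(input_dim, units_per_layer):
    """Return (kernel_shape, bias_shape) for each Dense layer in a stack."""
    shapes = []
    fan_in = input_dim
    for units in units_per_layer:
        # Keras Dense: kernel is (fan_in, units), bias is (units,)
        shapes.append(((fan_in, units), (units,)))
        fan_in = units  # the next layer consumes this layer's output
    return shapes


shapes = dense_shapes(173, [207, 207, 207, 240])
for i, (w_shape, b_shape) in enumerate(shapes, start=1):
    print(f"layer {i}: W {w_shape}, b {b_shape}")

# Each layer has fan_in * units weights plus units biases.
total = sum(w[0] * w[1] + b[0] for w, b in shapes)
print("total parameters:", total)
```

The total should match the "Total params" line printed by model.summary().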
I have already trained the model.
I have already tried extracting the weights and biases using the following:
layer_num = 1
layer_weights = model.layers[layer_num-1].get_weights()[0]
layer_biases = model.layers[layer_num-1].get_weights()[1]
where layer_num is the layer number (layer_num runs from 1 to 4 for my model). However, I can't figure out how to compute Z = f(WX+B) manually for my test set x_test.
Using this manual method, I'd like to compare its results with those of model.predict, as well as get the prediction computation time per instance/row of x_test (but I suppose I can figure out this last part on my own).
Edit 1: I mentioned that I can figure out the computation time per instance/row on my own, but if you could help me with this also, I'd really appreciate it.
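For the per-row timing asked about in Edit 1, one option is to time a NumPy forward pass one row at a time with time.perf_counter. The sketch below uses random weights with this model's shapes so it runs without Keras; with the trained model you would pass the W and b arrays extracted via get_weights() instead. The function names forward and time_per_row are hypothetical, and absolute timings depend on the machine and the BLAS library NumPy is linked against.

```python
import time
import numpy as np

def forward(Ws, bs, x):
    """NumPy forward pass: ReLU on hidden layers, identity on the last layer."""
    z = x
    for i, (W, b) in enumerate(zip(Ws, bs)):
        z = z @ W + b
        if i < len(Ws) - 1:
            z = np.maximum(z, 0)
    return z

def time_per_row(Ws, bs, X, repeats=10):
    """Average seconds per row when predicting rows one at a time."""
    start = time.perf_counter()
    for _ in range(repeats):
        for row in X:
            forward(Ws, bs, row[np.newaxis, :])  # keep a (1, features) shape
    elapsed = time.perf_counter() - start
    return elapsed / (repeats * len(X))

# Random stand-ins with the question's layer sizes (173 -> 207 -> 207 -> 207 -> 240).
rng = np.random.default_rng(0)
sizes = [173, 207, 207, 207, 240]
Ws = [rng.standard_normal((m, n)).astype(np.float32) for m, n in zip(sizes, sizes[1:])]
bs = [np.zeros(n, dtype=np.float32) for n in sizes[1:]]
x_test = rng.standard_normal((100, 173)).astype(np.float32)

print(f"{time_per_row(Ws, bs, x_test) * 1e6:.1f} microseconds per row")
```

Timing model.predict on single rows the same way tends to be dominated by per-call overhead, so for a fairer comparison you may also want to time both approaches on the full batch and divide by the number of rows.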
Solution
For a dense layer, calling the get_weights method returns a list whose first element is the weight matrix W (the kernel, with shape [input_dim, units]) and whose second element is the bias vector b:
W, b = dense_layer.get_weights()
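Because the kernel has shape (inputs, units), Keras applies it as X @ W + b, with one sample per row of X. The question's column-vector form Z = f(WX+B) is the same computation with W transposed. The sketch below checks this using random NumPy arrays as stand-ins for real weights; the sizes come from the first layer, the values are made up.

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.standard_normal((173, 207))   # kernel, shaped like get_weights()[0]
b = rng.standard_normal(207)          # bias vector
X = rng.standard_normal((5, 173))     # 5 samples, one per row

# Row convention used by Keras: (batch, 173) @ (173, 207) -> (batch, 207)
Z_rows = X @ W + b

# Column convention from the question: one sample x as a column vector
x = X[0]
z_col = W.T @ x + b

print(Z_rows.shape)                   # (5, 207)
print(np.allclose(Z_rows[0], z_col))  # True
```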
Suppose you have a test data matrix X with shape [batch, 173]. Then you can collect the weights and biases of every layer with a snippet like the one below:
Ws, bs = [], []
for layer in model.layers:
    W, b = layer.get_weights()  # kernel and bias of this Dense layer
    Ws.append(W)
    bs.append(b)
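Instead of hard-coding which layers use ReLU, you could read each layer's activation name, for example from layer.get_config()["activation"], which for these Dense layers should be "relu" or "linear" (the default when no activation is given). The mapping from names to NumPy functions below is a hypothetical helper and covers only the two activations this model uses.

```python
import numpy as np

# NumPy equivalents for the activation names used in this model.
NUMPY_ACTIVATIONS = {
    "relu": lambda x: np.maximum(x, 0),
    "linear": lambda x: x,
}

def activations_from_names(names):
    """Map Keras activation names to NumPy callables; fail loudly on unknown ones."""
    try:
        return [NUMPY_ACTIVATIONS[name] for name in names]
    except KeyError as err:
        raise ValueError(f"no NumPy equivalent for activation {err}") from None

# With a real model: names = [layer.get_config()["activation"] for layer in model.layers]
names = ["relu", "relu", "relu", "linear"]
acts = activations_from_names(names)
print(acts[0](np.array([-1.0, 2.0])))  # [0. 2.]
```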
With these in hand, we can define the forward pass:
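import numpy as np

def relu(x):
    # Elementwise max(x, 0)
    return np.maximum(x, 0)

def identity(x):
    # Linear activation, used by the output layer
    return x

def predict(Ws, bs, activations, X):
    # Apply each Dense layer in turn: Z <- activation(Z @ W + b)
    Z = X
    for W, b, activation in zip(Ws, bs, activations):
        Z = Z @ W + b
        Z = activation(Z)
    return Z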
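Written out in the question's notation, the loop in predict computes the following, with samples as rows of X and the bias added to every row (here 1 is a column of ones of length batch):

```latex
\begin{aligned}
Z^{(0)} &= X, \qquad X \in \mathbb{R}^{\text{batch} \times 173},\\
Z^{(l)} &= f_l\left(Z^{(l-1)} W^{(l)} + \mathbf{1}\, {b^{(l)}}^{\top}\right), \qquad l = 1, \dots, 4,\\
W^{(1)} &\in \mathbb{R}^{173 \times 207}, \quad W^{(2)}, W^{(3)} \in \mathbb{R}^{207 \times 207}, \quad W^{(4)} \in \mathbb{R}^{207 \times 240},\\
f_1 = f_2 = f_3 &= \mathrm{ReLU}, \qquad f_4 = \mathrm{id}, \qquad Y = Z^{(4)} \in \mathbb{R}^{\text{batch} \times 240}.
\end{aligned}
```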
Then you can use the predict function to make manual predictions for the test set x_test:
Y = predict(Ws, bs, (relu, relu, relu, identity), x_test)
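To compare Y with model.predict(x_test), an exact equality check is too strict: Keras keeps the weights in float32 and may accumulate sums in a different order, so small numerical differences are expected. A sketch of a tolerance-based comparison follows; compare_predictions is a hypothetical helper and its default tolerances are assumptions to adjust for your data.

```python
import numpy as np

def compare_predictions(manual, keras_out, rtol=1e-4, atol=1e-5):
    """Return the largest absolute difference and whether the arrays agree."""
    manual = np.asarray(manual, dtype=np.float64)
    keras_out = np.asarray(keras_out, dtype=np.float64)
    if manual.shape != keras_out.shape:
        raise ValueError(f"shape mismatch: {manual.shape} vs {keras_out.shape}")
    max_abs_diff = float(np.max(np.abs(manual - keras_out)))
    return max_abs_diff, bool(np.allclose(manual, keras_out, rtol=rtol, atol=atol))

# With a real model: compare_predictions(Y, model.predict(x_test))
a = np.array([[1.0, 2.0], [3.0, 4.0]])
diff, ok = compare_predictions(a, a + 1e-7)
print(ok)  # True
```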
Answered By - Terence