मैं ग्राफ कनवल्शनल नेटवर्क्स (जीसीएन) और मॉडल की इंटरमीडिएट परतों के आउटपुट तक पहुंचने की कोशिश कर रहा हूं। भविष्यवाणी इनपुट वैल्यू के लिए अमान्य आर्ग्यूमेंट त्रुटि फेंक रही है जहां मॉडल.फिट उसी इनपुट के साथ ठीक काम कर रहा है।
यह मेरा कोड है और यह spektral लाइब्रेरी द्वारा प्रदान किए गए OGB से 'CORA' उद्धरण डेटासेट का उपयोग कर रहा है जो ग्राफ़ के लिए एल्गोरिदम और उदाहरण प्रदान करता है कनवल्शनल नेटवर्क। मेरा कोड उसी लाइब्रेरी के एक उदाहरण पर आधारित है, यहां
from spektral.datasets import citation
from spektral.layers import GraphConv
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dropout, Dense
import numpy as np
A, X, y, train_mask, val_mask, test_mask = citation.load_data('cora')
At = A.transpose()
N = A.shape[0]
F = X.shape[-1]
n_classes = y.shape[-1]
X_in = Input(shape=(F, ))
A_in = Input((N, ), sparse=True)
X_1 = GraphConv(16, 'relu', name="layer1")([X_in, A_in])
X_1 = Dropout(0.5, name="layer2")(X_1)
X_2 = GraphConv(n_classes, 'softmax', name="output")([X_1, A_in])
model = Model(inputs=[X_in, A_in], outputs=X_2)
A = GraphConv.preprocess(A).astype('f4')
At = GraphConv.preprocess(At).astype('f4')
model.compile(optimizer='adam',
loss='categorical_crossentropy',
weighted_metrics=['acc'])
model.summary()
# Prepare data
X = X.toarray()
A = A.astype('f4')
At = At.astype('f4')
validation_data = ([X, A], y, val_mask)
# Train model
model.fit([X, A],
y,
sample_weight=train_mask,
validation_data=validation_data,
epochs=1,
batch_size=N,
shuffle=False
)
# Access intemediate layers of model
layer_name = 'layer2'
intermediate_layer_model = Model(inputs=model.input,
outputs=model.get_layer(layer_name).output)
model_input = [X,A]
intermediate_output = intermediate_layer_model.predict(model_input)
print("\n\nIntermediate_output=",intermediate_output,"\n\n")
यहाँ त्रुटि संदेश है:
Traceback (most recent call last):
File "PLGcn_example4_stackflow_debug.py", line 53, in <module>
intermediate_output = intermediate_layer_model.predict(model_input)
File "/home/mansoor4/.local/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 130, in _method_wrapper
return method(self, *args, **kwargs)
File "/home/mansoor4/.local/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 1599, in predict
tmp_batch_outputs = predict_function(iterator)
File "/home/mansoor4/.local/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 780, in __call__
result = self._call(*args, **kwds)
File "/home/mansoor4/.local/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 846, in _call
return self._concrete_stateful_fn._filtered_call(canon_args, canon_kwds) # pylint: disable=protected-access
File "/home/mansoor4/.local/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1848, in _filtered_call
cancellation_manager=cancellation_manager)
File "/home/mansoor4/.local/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1924, in _call_flat
ctx, args, cancellation_manager=cancellation_manager))
File "/home/mansoor4/.local/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 550, in call
ctx=ctx)
File "/home/mansoor4/.local/lib/python3.7/site-packages/tensorflow/python/eager/execute.py", line 60, in quick_execute
inputs, attrs, num_outputs)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot multiply A and B because inner dimension does not match: 2708 vs. 32. Did you forget a transpose? Dimensions of A: [32, 2708). Dimensions of B: [32,16]
[[node functional_3/layer1/SparseTensorDenseMatMul/SparseTensorDenseMatMul (defined at /home/mansoor4/.local/lib/python3.7/site-packages/spektral/layers/ops/matmul.py:33) ]] [Op:__inference_predict_function_22928]
Errors may have originated from an input operation.
Input Source operations connected to node functional_3/layer1/SparseTensorDenseMatMul/SparseTensorDenseMatMul:
stack (defined at PLGcn_example4_stackflow_debug.py:53)
functional_3/layer1/MatMul (defined at /home/mansoor4/.local/lib/python3.7/site-packages/spektral/layers/ops/matmul.py:45)
Function call stack:
predict_function
त्रुटि संदेश गुणन के लिए आंतरिक आयामों के बेमेल होने से संबंधित है। मैंने समस्या को ठीक करने के लिए model_input = [X, At] जैसे इनपुट के लिए ट्रांसपोंस का उपयोग करने का प्रयास किया लेकिन फिर भी उसी त्रुटि का सामना करना पड़ा।
मैं केरस और स्पेक्ट्रल के लिए नया हूँ। मैंने स्टैकफ्लो पर संबंधित पोस्ट की खोज की है और कई संभावनाओं की कोशिश की है लेकिन नेटवर्क से मध्यवर्ती मूल्यों का आउटपुट नहीं मिल सका।
1 उत्तर
समाधान
केरस Model
के predict
फ़ंक्शन में batch_size=32
का डिफ़ॉल्ट तर्क होता है। आप इसे दो तरह से हल कर सकते हैं।
intermediate_output = intermediate_layer_model.predict(model_input, batch_size=N)
या
intermediate_output = intermediate_layer_model.predict_on_batch(model_input)
आपके कोड में, आपके आसन्न मैट्रिक्स और नोड फीचर मैट्रिक्स के पहले आयाम को 32 के बैचों में विभाजित किया जाएगा। हालांकि, मॉडल को हर समय पूर्ण ग्राफ की उम्मीद है, इसलिए आपको अपने बैच का आकार N
पर सेट करना चाहिए। (model.fit
को कॉल करने पर आप यही करते हैं)।
व्याख्या
यह देखने के लिए कि इसकी आवश्यकता क्यों है, उन कार्यों के बारे में सोचें जो एक GCN परत हुड के नीचे करती है: A @ X @ W
। यह आकार (एन, एन) एक्स (एन, एफ) एक्स (एफ, एफ ') के साथ एक मैट्रिक्स गुणन है। ध्यान दें कि कैसे गुणा के आंतरिक आयाम हमेशा समान होते हैं: एन के साथ एन और एफ के साथ एफ।
अब, यदि आप बैचिंग करते हैं, तो आप A और X के पहले आयाम को B=32 पर सेट कर रहे हैं। यह आपको गुणा (बी, एन) एक्स (बी, एफ) एक्स (एफ, एफ') देता है। देखें कि पहले गुणन के आंतरिक आयाम अब कैसे मेल नहीं खाते? यह वह त्रुटि है जिसे TF उठा रहा है। यह आपको बता रहा है:
Cannot multiply A and B because inner dimension does not match: 2708 vs. 32
इस मामले में, एन = 2708 और बी = 32।
चियर्स
संबंधित सवाल
नए सवाल
python
पायथन एक बहु-प्रतिमान है, गतिशील रूप से टाइप किया हुआ, बहुउद्देशीय प्रोग्रामिंग भाषा है। यह एक साफ और एक समान वाक्यविन्यास सीखने, समझने और उपयोग करने के लिए त्वरित होने के लिए डिज़ाइन किया गया है। कृपया ध्यान दें कि अजगर 2 आधिकारिक तौर पर 01-01-2020 के समर्थन से बाहर है। फिर भी, संस्करण-विशिष्ट पायथन सवालों के लिए, [अजगर -२.०] या [अजगर -३.x] टैग जोड़ें। पायथन वेरिएंट (जैसे, ज्योथन, PyPy) या लाइब्रेरी (उदा।, पांडस और न्यूमपी) का उपयोग करते समय, कृपया इसे टैग में शामिल करें।