मैं अपने मॉडल को पुनर्स्थापित करने के लिए इस कोड का उपयोग करता हूं, लेकिन मुझे नहीं पता कि इसे पुनर्स्थापित करने के बाद भविष्यवाणी कैसे करें, मैं किस फ़ंक्शन का उपयोग कर सकता हूं? मैं टेंसरफ़्लो में एक नौसिखिया हूं, मुझे नहीं पता कि कौन से पैरामीटर या फ़ंक्शन सहेजे जाएंगे।

मेटा मॉडल में:

sess = tf.Session()
saver = tf.train.import_meta_graph("/home/MachineLearning/model.ckpt.meta")
saver.restore(sess,tf.train.latest_checkpoint('./'))
print("Model restored with success ")
x_predict,y_predict= load_svmlight_file('/MachineLearning/to_predict.csv')
x_predict = x_valid.toarray()
sess.run([] ,feed_dict ) #i don't know how to use predict function

ये परिणाम हैं:

$python predict.py
Model restored with success 
Traceback (most recent call last):
  File "predict.py", line 23, in <module>
    sess.run([] ,feed_dict )
NameError: name 'feed_dict' is not defined
2
Vectoria 18 अप्रैल 2018, 17:31

1 उत्तर

सबसे बढ़िया उत्तर

तुम लगभग वहां थे। Tensorflow बस एक गणित पुस्तकालय है। आपका ग्राफ संबंधित निर्भरताओं के साथ गणित संचालन का एक संग्रह है (उदाहरण के लिए एक ग्राफ, विशेष रूप से डीएजी)।

जब आप ग्राफ़ और संबद्ध चर (भार) लोड करते हैं तो आपने सभी परिभाषाएं लोड की हैं। अब आपको ग्राफ में कुछ मान की गणना करने के लिए टेंसरफ़्लो से पूछने की आवश्यकता है। बहुत सारे मूल्य हैं जिनकी यह गणना कर सकता है, जिसे आप चाहते हैं उसे अक्सर logits नाम दिया जाता है (तंत्रिका नेटवर्क की आउटपुट परत के लिए एक विशिष्ट नाम)। लेकिन ध्यान दें कि इसे कुछ भी नाम दिया जा सकता है (विशेषकर यदि यह एक तंत्रिका नेटवर्क मॉडल नहीं है), तो आपको मॉडल को समझने की आवश्यकता है। आप accuracy नाम के एक ऑपरेशन की गणना भी करना चाह सकते हैं, जिसे इनपुट के एक विशेष बैच की सटीकता की गणना करने के लिए परिभाषित किया गया है (फिर से आपके मॉडल पर निर्भर करता है)।

ध्यान दें कि इन गणनाओं को करने के लिए आपको जो कुछ भी चाहिए, उसके साथ आपको टेंसरफ़्लो प्रदान करने की आवश्यकता होगी। आम तौर पर एक placeholder होता है जहां आप अपने डेटा में पास होते हैं (और अपने लेबल के लिए एक placeholder प्रशिक्षण के दौरान, जिसकी आपको भविष्यवाणी करने की आवश्यकता नहीं होती है क्योंकि कोई भी ऑपरेशन जिसे आप टेंसरफ़्लो को गणना करने के लिए नहीं कहेंगे, उस पर निर्भर करता है। )

लेकिन आपको इन विभिन्न परिचालनों (logits, और accuracy) और प्लेसहोल्डर्स (x एक विशिष्ट नाम है) के संदर्भ प्राप्त करने की आवश्यकता होगी। चूंकि आपने डिस्क से अपना ग्राफ लोड किया है, आपके पास संदर्भ नहीं हैं (ध्यान दें कि मॉडल को लोड करने का एक वैकल्पिक तरीका मॉडल बनाने वाले कोड को फिर से चलाना है, जो आपको आवश्यक संदर्भों तक आसान पहुंच प्रदान करता है)।

सही संदर्भ प्राप्त करने के लिए आप उन्हें नाम से देख सकते हैं। यहां बताया गया है कि आपको सभी कार्यों की सूची कैसे मिलेगी:

Tensorflow में ग्राफ़ में टेंसर नामों की सूची

फिर नाम से एक विशिष्ट ओपी (ऑपरेशन) प्राप्त करने के लिए:

नाम से टेंसरफ़्लो सेशन कैसे प्राप्त करें?

तो आपके पास कुछ ऐसा होगा:

logits = tf.get_default_graph().get_operation_by_name("logits:0")
x = tf.get_default_graph().get_operation_by_name("x:0")
accuracy = tf.get_default_graph().get_operation_by_name("accuracy:0")

ध्यान दें कि :0 डुप्लीकेट नामों से बचने के लिए टेंसरफ़्लो में सभी नामों में जोड़ा गया एक इंडेक्स है। अब आपके पास आवश्यक सभी संदर्भ हैं और आप एक विशिष्ट गणना करने के लिए sess.run का उपयोग कर सकते हैं, इनपुट डेटा प्रदान कर सकते हैं, और ओपी जिन्हें आप गणना करना चाहते हैं:

sess.run([logits, accuracy], feed_dict={x:your_input_data_in_numpy_format})

आपके कार्यान्वयन में इन तत्वों के नाम अलग-अलग होंगे, मैंने सबसे सामान्य नामों का उपयोग किया है। अगर उन्हें सुंदर नाम नहीं दिए गए तो उन्हें पहचानना मुश्किल होगा और आपको उस मूल कोड को देखना होगा जिसने ग्राफ़ तैयार किया था। वास्तव में अगर उनका नाम ठीक से नहीं रखा गया था तो उन्हें नाम से देखना इतना दर्दनाक है कि मेटा ग्राफ़ को आयात करने के बजाय मूल ग्राफ़ का उत्पादन करने वाले कोड को फिर से चलाने के लिए शायद बेहतर है। ध्यान दें कि saver.restore केवल वास्तविक डेटा को पुनर्स्थापित करता है, import_meta_graph वैकल्पिक टुकड़ा है जिसे प्रोग्रामेटिक रूप से ग्राफ़ को फिर से बनाकर प्रतिस्थापित किया जा सकता है।

2
David Parks 18 अप्रैल 2018, 18:42