मैं एक छवि डेटासेट (CIFAR-10) में माध्य वर्ग त्रुटि की गणना करने का प्रयास कर रहा हूं। मेरे पास numpy array
आयाम 5*10000*32*32*3
है, जो शब्दों में, 32*32*3
के आयामों के साथ प्रत्येक 10000 छवियों के 5 बैच हैं। ये छवियां छवियों की 10 श्रेणियों से संबंधित हैं। मैंने प्रत्येक वर्ग के औसत की गणना की है और अब मैं ५०००० छवियों में से १० औसत छवियों में से प्रत्येक की औसत वर्ग त्रुटि की गणना करने की कोशिश कर रहा हूं। यहाँ कोड है:
for i in range(0, 5):
for j in range(0, 10000):
min_diff, min_class = float('inf'), 0
for avg in class_avg: # avg class comprises of 10 average images
temp = mse(avg[1], images[i][j])
if temp < min_diff:
min_diff = temp
min_class = avg[0]
train_pred[i][j] = min_class
समस्या: क्या इसे तेज करने का कोई तरीका है। कोई सुन्न जादू? धन्यवाद।
1 उत्तर
आप expand_dims
और tile
का उपयोग कर सकते हैं।
एक सरणी के आयाम को विस्तारित करने के कई तरीके हैं, मैं उनमें से एक का उपयोग करूंगा, जो कि [:,None,:]
जैसा कुछ है, यह बीच में एक नया अक्ष जोड़ता है।
नीचे एक उदाहरण दिया गया है कि आप अपने कार्य को पूरा करने के लिए दो विधियों को कैसे जोड़ सकते हैं:
test = np.ones((5,100,32,32,3)) # batches of images
average = np.ones((10,32,32,3)) # the 10 images
average = average[None,None,...] # reshape to (1,1,10,32,32,3)
test = test[:,:,None,...] # insert an axis
test = np.tile(test,(1,1,10,1,1,1)) # reshape to (5,100,10,32,32,3)
print(test.shape,average.shape)
mse = ((test-average)**2).mean(axis=(3,4,5))
class_idx = np.argmin(mse,axis=-1)
अपडेट करें
expand_dims
और tile
का उपयोग करने का उद्देश्य for-loop
के उपयोग से बचना है। हालांकि, np.tile
ऑपरेशन मूल सरणी के 10 प्रतिकृतियां बनाएगा, यह निश्चित रूप से प्रदर्शन को नुकसान पहुंचाएगा यदि सरणी बड़ी है। np.tile
का उपयोग करने से बचने के लिए, आप नीचे दिए गए कोड को आजमा सकते हैं:
labels = np.empty((5,100,10))
average = np.ones((10,32,32,3))
average = average[None,...]
test = np.ones((5,100,32,32,3))
for ind in range(10):
labels[...,ind] = ((test-average[:,ind,...])**2).mean(axis=(2,3,4))
labels = np.argmin(labels,axis=-1)
संबंधित सवाल
नए सवाल
python
पायथन एक बहु-प्रतिमान है, गतिशील रूप से टाइप किया हुआ, बहुउद्देशीय प्रोग्रामिंग भाषा है। यह एक साफ और एक समान वाक्यविन्यास सीखने, समझने और उपयोग करने के लिए त्वरित होने के लिए डिज़ाइन किया गया है। कृपया ध्यान दें कि अजगर 2 आधिकारिक तौर पर 01-01-2020 के समर्थन से बाहर है। फिर भी, संस्करण-विशिष्ट पायथन सवालों के लिए, [अजगर -२.०] या [अजगर -३.x] टैग जोड़ें। पायथन वेरिएंट (जैसे, ज्योथन, PyPy) या लाइब्रेरी (उदा।, पांडस और न्यूमपी) का उपयोग करते समय, कृपया इसे टैग में शामिल करें।