मैं एक छवि डेटासेट (CIFAR-10) में माध्य वर्ग त्रुटि की गणना करने का प्रयास कर रहा हूं। मेरे पास numpy array आयाम 5*10000*32*32*3 है, जो शब्दों में, 32*32*3 के आयामों के साथ प्रत्येक 10000 छवियों के 5 बैच हैं। ये छवियां छवियों की 10 श्रेणियों से संबंधित हैं। मैंने प्रत्येक वर्ग के औसत की गणना की है और अब मैं ५०००० छवियों में से १० औसत छवियों में से प्रत्येक की औसत वर्ग त्रुटि की गणना करने की कोशिश कर रहा हूं। यहाँ कोड है:

for i in range(0, 5):
  for j in range(0, 10000):
      min_diff, min_class = float('inf'), 0
      for avg in class_avg:  # avg class comprises of 10 average images
          temp = mse(avg[1], images[i][j])
          if temp < min_diff:
              min_diff = temp
              min_class = avg[0]
      train_pred[i][j] = min_class

समस्या: क्या इसे तेज करने का कोई तरीका है। कोई सुन्न जादू? धन्यवाद।

0
Midhun 30 सितंबर 2020, 09:06

1 उत्तर

सबसे बढ़िया उत्तर

आप expand_dims और tile का उपयोग कर सकते हैं।

एक सरणी के आयाम को विस्तारित करने के कई तरीके हैं, मैं उनमें से एक का उपयोग करूंगा, जो कि [:,None,:] जैसा कुछ है, यह बीच में एक नया अक्ष जोड़ता है।

नीचे एक उदाहरण दिया गया है कि आप अपने कार्य को पूरा करने के लिए दो विधियों को कैसे जोड़ सकते हैं:

test = np.ones((5,100,32,32,3)) # batches of images 
average = np.ones((10,32,32,3)) # the 10 images 
average = average[None,None,...] # reshape to (1,1,10,32,32,3)

test = test[:,:,None,...] # insert an axis 
test = np.tile(test,(1,1,10,1,1,1)) # reshape to (5,100,10,32,32,3)
print(test.shape,average.shape)

mse = ((test-average)**2).mean(axis=(3,4,5))
class_idx = np.argmin(mse,axis=-1)

अपडेट करें

expand_dims और tile का उपयोग करने का उद्देश्य for-loop के उपयोग से बचना है। हालांकि, np.tile ऑपरेशन मूल सरणी के 10 प्रतिकृतियां बनाएगा, यह निश्चित रूप से प्रदर्शन को नुकसान पहुंचाएगा यदि सरणी बड़ी है। np.tile का उपयोग करने से बचने के लिए, आप नीचे दिए गए कोड को आजमा सकते हैं:

labels = np.empty((5,100,10))
average = np.ones((10,32,32,3))
average = average[None,...]

test = np.ones((5,100,32,32,3))

for ind in range(10):
    labels[...,ind] = ((test-average[:,ind,...])**2).mean(axis=(2,3,4))
labels = np.argmin(labels,axis=-1)  
1
meTchaikovsky 4 अक्टूबर 2020, 12:31