मैं PyTorch में कई मैट्रिक्स के मैट्रिक्स गुणन करने की कोशिश कर रहा हूं और सोच रहा था कि PyTorch में numpy.linalg.multi_dot() के बराबर क्या है?

यदि कोई नहीं है, तो अगला सबसे अच्छा तरीका क्या है (गति और स्मृति के मामले में) मैं इसे PyTorch में कर सकता हूं?

कोड:

import numpy as np
import torch

A = np.random.rand(3, 3)
B = np.random.rand(3, 3)
C = np.random.rand(3, 3)

results = np.linalg.multi_dot(A, B, C)

A_tsr = torch.tensor(A)
B_tsr = torch.tensor(B)
C_tsr = torch.tensor(C)

# What is the PyTorch equivalent of np.linalg.multi_dot()?

बहुत धन्यवाद!

1
Leockl 25 अक्टूबर 2020, 09:31

1 उत्तर

सबसे बढ़िया उत्तर

~~ऐसा लगता है कि कोई मल्टी_डॉट में टेंसर भेज सकता है~~

ऐसा लगता है कि numpy कार्यान्वयन सब कुछ numpy arrays में डाल देता है। यदि आपके टेंसर सीपीयू पर हैं और अलग हो गए हैं तो यह काम करना चाहिए। अन्यथा, numpy में रूपांतरण विफल हो जाएगा।

तो सामान्य तौर पर - संभावना है कि कोई विकल्प नहीं है। मुझे लगता है कि आपका सबसे अच्छा शॉट multi_dot कार्यान्वयन करना है, उदा। यहां से numpy v1.19.0 के लिए< /a> और इसे टेंसर को संभालने के लिए समायोजित करें / कास्ट को numpy पर छोड़ दें। समान इंटरफ़ेस और कोड सादगी को देखते हुए मुझे लगता है कि यह बहुत सीधा होना चाहिए।

1
Yuri Feldman 25 अक्टूबर 2020, 09:44