कैटगरी: कंपाइल टाइम: हार्डवेयर पर काम न करने वाला RHS DataType
यह गड़बड़ी तब होती है, जब मैट्रिक्स के गुणन (जैसे, jax.lax.dot_general, jax.lax.conv,
jax.numpy.matmul या @ ऑपरेटर) को, इस्तेमाल किए जा रहे टीपीयू जनरेशन के साथ मूल रूप से काम करने की सुविधा नहीं मिलती.
गड़बड़ी के मैसेज के उदाहरण:
INTERNAL: Mosaic failed to compile TPU kernel: Unsupported matmul RHS type on target: 'vector<256x256xi8>'
...
The MLIR operation involved:
%13440 = "tpu.matmul"(%13435, %13437, %13439) <dimension_numbers = #tpu.dot_dimension_numbers<...>
XLA बैकएंड: टीपीयू
खास जानकारी
टीपीयू की मैट्रिक्स मल्टिप्लाई यूनिट (एमएक्सयू), सभी हार्डवेयर जनरेशन पर Float32 कार्रवाइयों के साथ काम करती है.
हालांकि, BFloat16 और अन्य क्वांटाइज़्ड डेटा टाइप के लिए नेटिव सपोर्ट उपलब्ध है. जैसे, Int4, Int8 या Float8) हार्डवेयर जनरेशन के हिसाब से अलग-अलग होता है. यह गड़बड़ी तब होती है, जब आपका कर्नल, मैट्रिक्स मल्टिप्लिकेशन को MXU पर मैप करने की कोशिश करता है. इसके लिए, वह ऐसे डेटा टाइप का इस्तेमाल करता है जिसे आपके टीपीयू जनरेशन के पास फ़िज़िकल सर्किट्री के तौर पर लागू करने की सुविधा नहीं होती.
आम तौर पर, इस गड़बड़ी का मतलब है कि कंपाइलर का कैननिकल वर्शन में बदलने वाला पास, काम नहीं कर सका. यह पास, काम न करने वाले टाइप को काम करने वाले टाइप में अपने-आप बदलने की कोशिश करता है. जैसे, सॉफ़्टवेयर इम्यूलेशन के ज़रिए. ऐसा इसलिए हुआ, क्योंकि उसे कन्वर्ज़न का कोई मान्य नियम नहीं मिला या कंपैटिबिलिटी मोड बंद होने की वजह से, ऐसा नहीं हो सका.
डीबग करना
इस गड़बड़ी को ठीक करने के लिए, आपको अपने डेटा टाइप को हार्डवेयर की क्षमताओं के हिसाब से अलाइन करना होगा. आपके पास ये विकल्प हैं:
1. नेटिव टाइप पर कास्ट करना
इस समस्या को ठीक करने का सबसे भरोसेमंद तरीका यह है कि matmul ऑपरेशन से पहले, अपने कर्नल में ऑपरेंड को हार्डवेयर के साथ काम करने वाले डेटाटाइप (जैसे कि टीपीयू v4+ पर Float32 या BFloat16) में मैन्युअल तरीके से कास्ट करें.
- क्यों:
Float32एक यूनिवर्सल डेटा टाइप है. यह सभी टीपीयू जनरेशन पर MXU के साथ काम करता है. - ट्रेड-ऑफ़: इसमें वीपीयू (वेक्टर प्रोसेसिंग यूनिट) की लागत शामिल होती है. यह लागत, कास्ट करने के लिए ज़रूरी साइकल की होती है. हालांकि, इससे यह गारंटी मिलती है कि आपका कर्नल मौजूदा हार्डवेयर पर चलेगा.
2. कंपैटिबिलिटी मोड की जांच करना
आम तौर पर, कंपाइलर कंपैटिबिलिटी मोड में टाइप मैच न होने की इन समस्याओं को अपने-आप ठीक कर सकता है. यह मोड डिफ़ॉल्ट रूप से चालू होता है. XLA कॉन्फ़िगरेशन की दोबारा जांच करें, ताकि यह पक्का किया जा सके कि --xla_mosaic_compat_mode को 'गलत है' पर सेट न किया गया हो.
यह "पॉलीफ़िल" के तौर पर काम करता है. यह उन कार्रवाइयों के लिए सॉफ़्टवेयर इम्यूलेशन सीक्वेंस डालता है जिन्हें आपका हार्डवेयर मूल रूप से सपोर्ट नहीं करता.
कंपैटबिलिटी मोड की मदद से ये काम किए जा सकते हैं:
- मिक्स-प्रिसिशन MatMuls: इससे पूर्णांक ऑपरेंड को फ़्लोट
एक्युमुलेटर के साथ मिक्स किया जा सकता है.इसके लिए, कास्ट ऑपरेशन अपने-आप डाले जाते हैं. जैसे, matmul से पहले पूर्णांकों को
Float32तक बढ़ाना. - कम सटीक इम्यूलेशन: कुछ हार्डवेयर जनरेशन पर, यह
4-bitफ़्लोटिंग पॉइंट (4E2M1FN) या8-bitफ़्लोटिंग पॉइंट (8E4M3FN) जैसे ऐसे टाइप का इम्यूलेशन करता है जो काम नहीं करते. इसके लिए, यह उन्हेंBFloat16याFloat32जैसे काम करने वाले टाइप में बदलकर, एक्ज़ीक्यूट करता है.
ध्यान दें कि यह मोड, सबसे अच्छी परफ़ॉर्मेंस के बजाय कंपैटिबिलिटी को प्राथमिकता देता है. ऐसा इसलिए, क्योंकि एम्यूलेशन के लिए, MXU के डेटा फ़ॉर्मैट पर काम करने से पहले, उन्हें बदलने के लिए अतिरिक्त निर्देशों की ज़रूरत होती है.
3. हार्डवेयर अपग्रेड करें या सहायता का अनुरोध करें
अगर आपके एल्गोरिदम को Int4 या Float8 जैसे टाइप के लिए, कास्टिंग या इम्यूलेशन के ओवरहेड के बिना नेटिव परफ़ॉर्मेंस की ज़रूरत है, तो आपको नेटिव सपोर्ट वाले नए टीपीयू जनरेशन पर काम करना होगा.
सुविधा का अनुरोध: अगर आपको लगता है कि आपका हार्डवेयर इस ऑपरेशन के साथ काम करता है या कंपाइलर में कंपैटिबिलिटी मोड में भी मान्य इम्यूलेशन पाथ मौजूद नहीं है, तो कृपया सुविधा का अनुरोध सबमिट करें. हम आम तौर पर इस बात की गारंटी देते हैं कि ऑपरेशन, आगे आने वाले वर्शन के साथ काम करेंगे. इसलिए, अगर आपका कर्नल किसी टीपीयू जनरेशन पर काम करता है, तो उसे आने वाले सभी जनरेशन पर काम करना चाहिए. हालांकि, इस बात की कोई गारंटी नहीं है कि पुराने जनरेशन के लिए इम्यूलेशन उपलब्ध होगा. इनमें से कुछ के लिए, कास्टिंग बहुत महंगी होगी.