רקע
PJRT הוא ה-API האחיד של המכשיר, שאנחנו רוצים להוסיף לסביבה העסקית של למידת המכונה. החזון לטווח ארוך הוא: (1) frameworks (JAX, TF וכו') יקראו ל-PJRT, שיש בו יישומים אטומים למסגרות של מכשיר ספציפיים; (2) כל מכשיר מתמקד בהטמעת ממשקי API של PJRT, ויכול להיות אטום למסגרות.
המסמך הזה מתמקד בהמלצות בנושא שילוב עם PJRT ובדרכים לבדוק את השילוב של PJRT עם JAX.
איך מבצעים שילוב עם PJRT
שלב 1: מטמיעים ממשק API של PJRT C
אפשרות א': אפשר להטמיע את PJRT C API ישירות.
אפשרות ב': אם אתם יכולים לפתח מול קוד C++ במאגר ה-xla (באמצעות פיצול או bazel), אתם יכולים גם להטמיע את PJRT C++ API ולהשתמש ב-wrapper של C←C++:
- מטמיעים לקוח C++ PJRT שעובר בירושה מלקוח PJRT הבסיסי (וממחלקות PJRT קשורות). לפניכם כמה דוגמאות של לקוח C++ PJRT: pjrt_stream_executor_client.h, tfrt_cpu_pjrt_client.h.
- מטמיעים כמה שיטות C API שאינן חלק מלקוח C++ PJRT:
- PJRT_Client_Create. בהמשך מופיע פסאודו קוד לדוגמה (בהנחה ש-
GetPluginPjRtClient
מחזיר לקוח C++ PJRT שהוטמע למעלה):
- PJRT_Client_Create. בהמשך מופיע פסאודו קוד לדוגמה (בהנחה ש-
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
args->client = pjrt::CreateWrapperClient(std::move(client));
return nullptr;
}
} // namespace my_plugin
חשוב לשים לב ש-PJRT_Client_Create יכול לקבל את האפשרויות שהועברו מה-framework. כאן יש דוגמה לאופן שבו לקוח GPU משתמש בתכונה הזו.
- [אופציונלי] PJRT_TopologyDescription_Create.
- [אופציונלי] PJRT_Plugin_Initialize. זוהי הגדרה חד-פעמית של פלאגין, שה-framework יקרא לה לפני קריאה לפונקציות אחרות.
- [אופציונלי] PJRT_Plugin_Attributes.
באמצעות ה-wrapper, לא צריך להטמיע את שאר ממשקי ה-API של C.
שלב 2: מטמיעים את GetPjRtApi
צריך להטמיע שיטה GetPjRtApi
שמחזירה PJRT_Api*
שמכיל מצביעי פונקציות אל הטמעות של PJRT C API. דוגמה להטמעה דרך wrapper (בדומה ל-pjrt_c_api_cpu.cc):
const PJRT_Api* GetPjrtApi() {
static const PJRT_Api pjrt_api =
pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
return &pjrt_api;
}
שלב 3: בדיקת הטמעות של API C
אפשר לקרוא ל-RegisterPjRtCApiTestFactory כדי להריץ קבוצה קטנה של בדיקות עבור התנהגויות בסיסיות של PJRT C API.
איך להשתמש בפלאגין PJRT מ-JAX
שלב 1: הגדרת JAX
אפשר להשתמש ב-JAX ללילה
pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>
או ליצור JAX מהמקור.
בשלב זה, צריך להתאים את גרסת jaxlib לגרסת PJRT C API. בדרך כלל מספיק להשתמש בגרסת jaxlib ללילה מאותו יום שבו חלה ההתחייבות של TF שעבורה אתה בונה את הפלאגין, למשל.
pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
אפשר גם לבנות ג'קסליב מהמקור בדיוק במסגרת ה-XLSA שאתם בונים לפיה (instructions).
בקרוב נתחיל לתמוך בתאימות לממשקי ABI.
שלב 2: שימוש במרחב השמות של jax_Plugins או הגדרה של record_point
יש שתי דרכים לגלות את הפלאגין על ידי JAX.
- שימוש בחבילות מרחב שמות (ref). יש להגדיר מודול ייחודי גלובלי במסגרת חבילת מרחב השמות של
jax_plugins
(כלומר, פשוט ליצור ספרייתjax_plugins
ולהגדיר את המודול מתחתיה). הנה מבנה ספרייה לדוגמה:
jax_plugins/
my_plugin/
__init__.py
my_plugin.so
- שימוש במטא-נתונים של חבילה (ref). אם יוצרים חבילה באמצעות pyproject.toml או setup.py, מפרסמים את שם מודול הפלאגין על ידי הכללת נקודת כניסה תחת הקבוצה
jax_plugins
שמפנה לשם המודול המלא. לפניכם דוגמה דרך pyproject.toml או setup.py:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'
# use setup.py
entry_points={
"jax_plugins": [
"my_plugin = my_plugin",
],
}
לפניכם דוגמאות לאופן ההטמעה של openxla-pjrt-Plugin באמצעות אפשרות 2: https://github.com/openxla/openxla-pjrt-plugin/pull/119, https://github.com/openxla/openxla-pjrt-plugin/pull/120
שלב 3: הטמעה של שיטת אתחול()
עליך ליישם שיטת אתחול() במודול python כדי לרשום את הפלאגין. לדוגמה:
import os
import jax._src.xla_bridge as xb
def initialize():
path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)
כאן מוסבר איך להשתמש ב-xla_bridge.register_plugin
. זו שיטה פרטית כרגע. בעתיד יושק ממשק API ציבורי.
אפשר להריץ את השורה הבאה כדי לוודא שהפלאגין רשום ולהעלות הודעת שגיאה אם לא ניתן לטעון אותו.
jax.config.update("jax_platforms", "my_plugin")
ל-JAX יכולים להיות כמה קצוות עורפיים/יישומי פלאגין. יש כמה דרכים לוודא שהפלאגין ישמש כקצה העורפי כברירת המחדל:
- אפשרות 1: מריצים את
jax.config.update("jax_platforms", "my_plugin")
בתחילת התוכנה. - אפשרות 2: מגדירים את ENV
JAX_PLATFORMS=my_plugin
. - אפשרות 3: הגדירו עדיפות גבוהה מספיק בזמן הקריאה ל-xb.register_Plugin (ערך ברירת המחדל הוא 400, שהוא גבוה יותר מקצוות עורפיים אחרים). לתשומת ליבך, הקצה העורפי עם העדיפות הגבוהה ביותר ישמש רק כאשר
JAX_PLATFORMS=''
. ערך ברירת המחדל שלJAX_PLATFORMS
הוא''
אבל לפעמים הוא יוחלף.
איך לבצע בדיקה באמצעות JAX
כמה מקרי בדיקה בסיסיים שכדאי לנסות:
# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2
# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0
# pmap
arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))
# single device: [0]
# 4 devices: [6 7 8 9]
(בקרוב נוסיף הוראות להפעלת הבדיקות של יחידת jax בפלאגין שלך!)
דוגמה: פלאגין JAX CUDA
- הטמעת PJRT C API דרך wrapper (pjrt_c_api_gpu.h).
- מגדירים את נקודת הכניסה לחבילה (setup.py).
- מטמיעים שיטת startize() (__init__.py).
- אפשר לבדוק את זה בכל בדיקות jax ל-CUDA. ```