המאמר הזה מתמקד בהמלצות לשילוב עם PJRT ובבדיקת השילוב של PJRT עם JAX.
איך משלבים עם PJRT
שלב 1: מטמיעים את ממשק ה-API של PJRT C
אפשרות א': אפשר להטמיע את PJRT C API ישירות.
אפשרות ב': אם אתם יכולים לבנות קוד C++ ב-xla repo (באמצעות forking או bazel), תוכלו גם להטמיע את PJRT C++ API ולהשתמש ב-wrapper של C→C++:
- מטמיעים לקוח PJRT ב-C++ שעובר בירושה מלקוח PJRT הבסיסי (וכיתות PJRT קשורות). הנה כמה דוגמאות ללקוח PJRT ב-C++: pjrt_stream_executor_client.h, tfrt_cpu_pjrt_client.h.
- הטמעה של כמה שיטות API של C שלא נכללות בלקוח PJRT של C++:
- PJRT_Client_Create. בהמשך מופיע קוד-דמה לדוגמה (בהנחה ש-
GetPluginPjRtClient
מחזיר לקוח PJRT ב-C++ שהוטמע למעלה):
- PJRT_Client_Create. בהמשך מופיע קוד-דמה לדוגמה (בהנחה ש-
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
args->client = pjrt::CreateWrapperClient(std::move(client));
return nullptr;
}
} // namespace my_plugin
הערה: PJRT_Client_Create יכול לקבל אפשרויות שהועברו מהמסגרת. כאן מופיעה דוגמה לאופן שבו לקוח GPU משתמש בתכונה הזו.
- [אופציונלי] PJRT_TopologyDescription_Create.
- [אופציונלי] PJRT_Plugin_Initialize. זוהי הגדרה חד-פעמית של הפלאגין, שתתבצע על ידי המסגרת לפני ההפעלה של כל פונקציה אחרת.
- [אופציונלי] PJRT_Plugin_Attributes.
בעזרת העטיפה, אין צורך להטמיע את ממשקי ה-API הנותרים של C.
שלב 2: הטמעת GetPjRtApi
צריך להטמיע את השיטה GetPjRtApi
שמחזירה PJRT_Api*
שמכיל מצביע פונקציה להטמעות של PJRT C API. בדוגמה הבאה אנו מניחים שההטמעה מתבצעת באמצעות מעטפת (בדומה ל-pjrt_c_api_cpu.cc):
const PJRT_Api* GetPjrtApi() {
static const PJRT_Api pjrt_api =
pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
return &pjrt_api;
}
שלב 3: בדיקת הטמעות של ממשקי API ל-C
אפשר להפעיל את RegisterPjRtCApiTestFactory כדי להריץ קבוצה קטנה של בדיקות להתנהגויות בסיסיות של PJRT C API.
איך משתמשים בפלאגין PJRT מ-JAX
שלב 1: הגדרת JAX
אפשר להשתמש ב-JAX nightly
pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>
בשלב הזה, צריך להתאים את גרסת jaxlib לגרסה של PJRT C API. בדרך כלל מספיק להשתמש בגרסה נייטלי של jaxlib מאותו יום שבו בוצע השמירה ב-TF שעבורו אתם מפתחים את הפלאגין, למשל:
pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
אפשר גם ליצור jaxlib מקוד מקור בדיוק בגרסה של XLA שאליה רוצים לבנות (הוראות).
בקרוב נתחיל לתמוך בתאימות ABI.
שלב 2: משתמשים במרחב השמות jax_plugins או מגדירים את entry_point
יש שתי אפשרויות לזיהוי הפלאגין על ידי JAX.
- שימוש בחבילות של מרחבי שמות (מידע נוסף). מגדירים מודול ייחודי גלובלי בחבילת מרחב השמות
jax_plugins
(כלומר, פשוט יוצרים ספרייהjax_plugins
ומגדירים את המודול מתחתיה). דוגמה למבנה ספרייה:
jax_plugins/
my_plugin/
__init__.py
my_plugin.so
- באמצעות המטא-נתונים של החבילה (ref). אם אתם יוצרים חבילה באמצעות pyproject.toml או setup.py, כדאי לפרסם את שם המודול של הפלאגין על ידי הוספת נקודת כניסה בקבוצה
jax_plugins
שמצביעה על שם המודול המלא. דוגמה באמצעות pyproject.toml או setup.py:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'
# use setup.py
entry_points={
"jax_plugins": [
"my_plugin = my_plugin",
],
}
ריכזנו כאן דוגמאות להטמעה של openxla-pjrt-plugin באמצעות האפשרות השנייה: https://github.com/openxla/openxla-pjrt-plugin/pull/119, https://github.com/openxla/openxla-pjrt-plugin/pull/120
שלב 3: הטמעת שיטת initialize()
כדי לרשום את הפלאגין, צריך להטמיע את השיטה initialize() במודול Python. לדוגמה:
import os
import jax._src.xla_bridge as xb
def initialize():
path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)
כאן מוסבר איך משתמשים ב-xla_bridge.register_plugin
. בשלב זה, זוהי שיטה פרטית. ממשק API ציבורי יושק בעתיד.
אפשר להריץ את השורה הבאה כדי לוודא שהפלאגין רשום, ולהציג שגיאה אם לא ניתן לטעון אותו.
jax.config.update("jax_platforms", "my_plugin")
ל-JAX יכולים להיות כמה קצוות עורפיים או יישומי פלאגין. יש כמה אפשרויות להבטיח שהפלאגין ישמש כקצה העורפי שמוגדר כברירת מחדל:
- אפשרות 1: מריצים את
jax.config.update("jax_platforms", "my_plugin")
בתחילת התוכנית. - אפשרות 2: מגדירים את ENV כ-
JAX_PLATFORMS=my_plugin
. - אפשרות 3: מגדירים עדיפות גבוהה מספיק בקריאה ל-xb.register_plugin (ערך ברירת המחדל הוא 400, שהוא גבוה יותר מערכים אחרים של קצה עורפי קיים). הערה: המערכת תשתמש בקצה העורפי עם העדיפות הגבוהה ביותר רק כאשר
JAX_PLATFORMS=''
. ערך ברירת המחדל שלJAX_PLATFORMS
הוא''
, אבל לפעמים הוא יימחק.
איך בודקים באמצעות JAX
כמה תרחישי בדיקה בסיסיים שאפשר לנסות:
# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2
# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0
# pmap
arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))
# single device: [0]
# 4 devices: [6 7 8 9]
(בקרוב נוסיף הוראות להרצת בדיקות היחידה של Jax ביחס לפלאגין שלכם).
דוגמאות נוספות לפלאגינים של PJRT זמינות במאמר דוגמאות ל-PJRT.