שילוב פלאגין של PJRT

רקע

PJRT הוא ה-API האחיד של המכשיר, שאנחנו רוצים להוסיף לסביבה העסקית של למידת המכונה. החזון לטווח ארוך הוא: (1) frameworks (JAX, TF וכו') יקראו ל-PJRT, שיש בו יישומים אטומים למסגרות של מכשיר ספציפיים; (2) כל מכשיר מתמקד בהטמעת ממשקי API של PJRT, ויכול להיות אטום למסגרות.

המסמך הזה מתמקד בהמלצות בנושא שילוב עם PJRT ובדרכים לבדוק את השילוב של PJRT עם JAX.

איך מבצעים שילוב עם PJRT

שלב 1: מטמיעים ממשק API של PJRT C

אפשרות א': אפשר להטמיע את PJRT C API ישירות.

אפשרות ב': אם אתם יכולים לפתח מול קוד C++ במאגר ה-xla (באמצעות פיצול או bazel), אתם יכולים גם להטמיע את PJRT C++ API ולהשתמש ב-wrapper של C←C++:

  1. מטמיעים לקוח C++ PJRT שעובר בירושה מלקוח PJRT הבסיסי (וממחלקות PJRT קשורות). לפניכם כמה דוגמאות של לקוח C++ PJRT: pjrt_stream_executor_client.h, tfrt_cpu_pjrt_client.h.
  2. מטמיעים כמה שיטות C API שאינן חלק מלקוח C++ PJRT:
    • PJRT_Client_Create. בהמשך מופיע פסאודו קוד לדוגמה (בהנחה ש-GetPluginPjRtClient מחזיר לקוח C++ PJRT שהוטמע למעלה):
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"

namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
  std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
  args->client = pjrt::CreateWrapperClient(std::move(client));
  return nullptr;
}
}  // namespace my_plugin

חשוב לשים לב ש-PJRT_Client_Create יכול לקבל את האפשרויות שהועברו מה-framework. כאן יש דוגמה לאופן שבו לקוח GPU משתמש בתכונה הזו.

באמצעות ה-wrapper, לא צריך להטמיע את שאר ממשקי ה-API של C.

שלב 2: מטמיעים את GetPjRtApi

צריך להטמיע שיטה GetPjRtApi שמחזירה PJRT_Api* שמכיל מצביעי פונקציות אל הטמעות של PJRT C API. דוגמה להטמעה דרך wrapper (בדומה ל-pjrt_c_api_cpu.cc):

const PJRT_Api* GetPjrtApi() {
  static const PJRT_Api pjrt_api =
      pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
  return &pjrt_api;
}

שלב 3: בדיקת הטמעות של API C

אפשר לקרוא ל-RegisterPjRtCApiTestFactory כדי להריץ קבוצה קטנה של בדיקות עבור התנהגויות בסיסיות של PJRT C API.

איך להשתמש בפלאגין PJRT מ-JAX

שלב 1: הגדרת JAX

אפשר להשתמש ב-JAX ללילה

pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>

או ליצור JAX מהמקור.

בשלב זה, צריך להתאים את גרסת jaxlib לגרסת PJRT C API. בדרך כלל מספיק להשתמש בגרסת jaxlib ללילה מאותו יום שבו חלה ההתחייבות של TF שעבורה אתה בונה את הפלאגין, למשל.

pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

אפשר גם לבנות ג'קסליב מהמקור בדיוק במסגרת ה-XLSA שאתם בונים לפיה (instructions).

בקרוב נתחיל לתמוך בתאימות לממשקי ABI.

שלב 2: שימוש במרחב השמות של jax_Plugins או הגדרה של record_point

יש שתי דרכים לגלות את הפלאגין על ידי JAX.

  1. שימוש בחבילות מרחב שמות (ref). יש להגדיר מודול ייחודי גלובלי במסגרת חבילת מרחב השמות של jax_plugins (כלומר, פשוט ליצור ספריית jax_plugins ולהגדיר את המודול מתחתיה). הנה מבנה ספרייה לדוגמה:
jax_plugins/
  my_plugin/
    __init__.py
    my_plugin.so
  1. שימוש במטא-נתונים של חבילה (ref). אם יוצרים חבילה באמצעות pyproject.toml או setup.py, מפרסמים את שם מודול הפלאגין על ידי הכללת נקודת כניסה תחת הקבוצה jax_plugins שמפנה לשם המודול המלא. לפניכם דוגמה דרך pyproject.toml או setup.py:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'

# use setup.py
entry_points={
  "jax_plugins": [
    "my_plugin = my_plugin",
  ],
}

לפניכם דוגמאות לאופן ההטמעה של openxla-pjrt-Plugin באמצעות אפשרות 2: https://github.com/openxla/openxla-pjrt-plugin/pull/119, https://github.com/openxla/openxla-pjrt-plugin/pull/120

שלב 3: הטמעה של שיטת אתחול()

עליך ליישם שיטת אתחול() במודול python כדי לרשום את הפלאגין. לדוגמה:

import os
import jax._src.xla_bridge as xb

def initialize():
  path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
  xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)

כאן מוסבר איך להשתמש ב-xla_bridge.register_plugin. זו שיטה פרטית כרגע. בעתיד יושק ממשק API ציבורי.

אפשר להריץ את השורה הבאה כדי לוודא שהפלאגין רשום ולהעלות הודעת שגיאה אם לא ניתן לטעון אותו.

jax.config.update("jax_platforms", "my_plugin")

ל-JAX יכולים להיות כמה קצוות עורפיים/יישומי פלאגין. יש כמה דרכים לוודא שהפלאגין ישמש כקצה העורפי כברירת המחדל:

  • אפשרות 1: מריצים את jax.config.update("jax_platforms", "my_plugin") בתחילת התוכנה.
  • אפשרות 2: מגדירים את ENV JAX_PLATFORMS=my_plugin.
  • אפשרות 3: הגדירו עדיפות גבוהה מספיק בזמן הקריאה ל-xb.register_Plugin (ערך ברירת המחדל הוא 400, שהוא גבוה יותר מקצוות עורפיים אחרים). לתשומת ליבך, הקצה העורפי עם העדיפות הגבוהה ביותר ישמש רק כאשר JAX_PLATFORMS=''. ערך ברירת המחדל של JAX_PLATFORMS הוא '' אבל לפעמים הוא יוחלף.

איך לבצע בדיקה באמצעות JAX

כמה מקרי בדיקה בסיסיים שכדאי לנסות:

# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2

# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0

# pmap

arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))

# single device: [0]

# 4 devices: [6 7 8 9]

(בקרוב נוסיף הוראות להפעלת הבדיקות של יחידת jax בפלאגין שלך!)

דוגמה: פלאגין JAX CUDA

  1. הטמעת PJRT C API דרך wrapper (pjrt_c_api_gpu.h).
  2. מגדירים את נקודת הכניסה לחבילה (setup.py).
  3. מטמיעים שיטת startize() (__init__.py).
  4. אפשר לבדוק את זה בכל בדיקות jax ל-CUDA. ```