Dump HLO Computations

קובץ HLO dump הוא ייצוג טקסטואלי של מודולי HLO בשלבים שונים של החישוב. הוא שימושי לניפוי באגים, ולעתים קרובות צריך לכלול אותו בדוחות על באגים. בדרך כלל זה קובץ טקסט שקל לקרוא, שמפרט את ההוראות של HLO ואת המאפיינים שלהן. לפעמים, מודולים של HLO נשמרים כ:

  • HloProto: קובצי מאגר פרוטוקולים, שהם פורמט מובנה יותר וקריא למחשבים.
  • HloSnapshot: מודול HLO בתוספת הקלטים שלו. כדי להפעיל מחדש HLO, לפעמים צריך את הקלט בפועל שמוזן לחישוב מסוים ולא נתונים אקראיים.

אפשר להשתמש בדגלי XLA כדי לציין ולשלוף קובצי dump. ברוב המקרים, אפשר להגדיר אותו באמצעות משתנה סביבה. ‫JAX מציעה גם דרך תכנותית להדפסת ה-HLO dump.

הרצה מקומית

שימוש במשתני סביבה

אפשר להגדיר את משתנה הסביבה XLA_FLAGS עם הדגלים הנדרשים כדי לקבל קובצי dump. הפתרון הזה אפשרי ב-JAX, ב-TensorFlow וב-PyTorch/XLA.

כדי להעביר מודולים של HLO ומידע אחר לניפוי באגים לספרייה ספציפית, מריצים את התוכנית עם הדגל --xla_dump_to:

XLA_FLAGS="--xla_dump_to=DIRECTORY_PATH"

לדוגמה, אפשר להשתמש ב-/tmp או ב-/tmp/xladump כנתיבים.

כברירת מחדל, המערכת מציגה את מודולי ה-HLO כטקסט, בתחילת צינור האופטימיזציה ובסופו.

אפשר גם לציין במפורש את הפורמט:

  1. העתקת טקסט
XLA_FLAGS="--xla_dump_hlo_as_text --xla_dump_to=DIRECTORY_PATH"
  1. קובצי proto של HLO
XLA_FLAGS="--xla_dump_hlo_as_proto --xla_dump_to=DIRECTORY_PATH"
  1. תמונות מצב של HLO
XLA_FLAGS="--xla_dump_hlo_snapshots --xla_dump_to=DIRECTORY_PATH"
  1. עיבוד גרף באמצעות שרת Graphviz (פועל היטב רק עבור גרפים קטנים)
XLA_FLAGS="--xla_dump_hlo_as_url --xla_dump_to=DIRECTORY_PATH"
  1. עיבוד הגרף לקובץ HTML (מתאים רק לגרפים קטנים)
XLA_FLAGS="--xla_dump_hlo_as_html --xla_dump_to=DIRECTORY_PATH"

בתרשימים גדולים יותר, אפשר להשתמש באפשרות interactive_graphviz כדי להמחיש חלקים מהתרשים.

הצגת כרטיסי ביניים ספציפיים

בנוסף ל-HLOs רגילים שעברו אופטימיזציה מראש או אופטימיזציה סופית, אפשר גם להציג את הסטטוס של HLOs אחרי שלב מסוים בהידור.

XLA_FLAGS="--xla_dump_hlo_pass_re=regex --xla_dump_to=DIRECTORY_PATH"

מודולים של HLO יימחקו עבור המעברים שהשמות שלהם תואמים לביטוי הרגולרי (regex). לדוגמה, אפשר לראות את האופטימיזציות ברמה גבוהה (HLO) שנובעות ממעברים שקשורים לחלוקת SPMD באמצעות:

XLA_FLAGS="--xla_dump_to=DIRECTORY_PATH --xla_dump_hlo_pass_re=spmd|propagation"

כדי ליצור קובץ פלט אחרי כל מעבר XLA (הפעולה הזו תיצור הרבה קבצים), אפשר להגדיר:

XLA_FLAGS="--xla_dump_to=DIRECTORY_PATH --xla_dump_hlo_pass_re=.*"

אפשרויות ספציפיות ל-JAX

באופן פרוגרמטי ב-JAX

במקום להעביר דגלים או משתני סביבה, אפשר גם להשתמש בממשקי ה-API‏ lower ו-compile של JAX כדי להציג את ה-HLO באופן פרוגרמטי.

כדי לאחזר באופן מקומי את ה-HLO המקורי הלא אופטימלי שהומר לאותיות קטנות, מריצים את הפקודה:

jax.jit(f).lower(*args).as_text('hlo')

כדי לבצע dump לקבצים במהלך שלבי הקומפילציה של HLO, מציינים:

compilation_args = {
    'xla_dump_to': DIRECTORY_PATH,
    'xla_dump_hlo_pass_re': 'spmd|propagation', # or some other pass filter
    ...
    }

jax.jit(f).lower(*args).compile(compilation_args)

Dump jaxprs

jaxprs הם ייצוג הביניים של JAX למעקב אחר תוכניות. כדי להציג את הנתונים האלה, מגדירים את משתני הסביבה:

JAX_DUMP_IR_TO="DIRECTORY_PATH" JAX_DUMP_IR_MODES=jaxpr

מידע נוסף זמין במסמכי JAX בנושא ייצוא וסריאליזציה של חישובים שהועברו: ניפוי באגים.

Google Colab

משתני סביבה

בתא הראשון שמופעל במחברת (כי משתני סביבה ודגלים של שורת פקודה בדרך כלל מעובדים רק פעם אחת, למשל בזמן ייבוא מודול או בזמן אתחול של XLA backend), מוסיפים את XLA_FLAGS שצוין למעלה עם os.environ, לדוגמה:

import os
os.environ['XLA_FLAGS'] = "--xla_dump_to=DIRECTORY_PATH"

החישוב יישמר בקובץ DIRECTORY_PATH, לדוגמה /tmp. ב-Colab, עוברים לדפדפן Files (קבצים) בסרגל הצד הימני כדי להציג את הספרייה הזו ולגשת אליה.

אפשר להשתמש בכל הדגלים שמוזכרים בקטע 'הרצה מקומית'.

אפשרויות ספציפיות ל-JAX

בדומה להרצה מקומית, כדי לבצע בדיקה בזמן אמת של חישוב אינטראקטיבי, אפשר להדפיס ישירות את ה-HLO של חישוב שעבר אופטימיזציה מראש:

def f(x):
    return jax.numpy.sin(jax.numpy.cos(x))

c = jax.jit(f).lower(3.).compiler_ir('hlo')

print(c.as_hlo_text())

אפשר גם להדפיס ישירות את ה-HLO שעבר אופטימיזציה של חישוב:

def optimized_HLO(f, *args, platform=None):
    print(jax.jit(f).lower(*args).compile().as_text())

def f(x):
    return jax.numpy.sin(jax.numpy.cos(x))

optimized_HLO(f, 1.0)

הצגת כל החישובים או חלק מהם

אם רוצים לראות הכול ב-dump, כולל כל הקומפילציות הקטנות, צריך להגדיר את משתנה הסביבה של JAX:

JAX_COMPILER_DETAILED_LOGGING_MIN_OPS=0

פסיפס

‫Mosaic הוא קומפיילר ל-backend של Pallas TPU ול-backend הניסיוני של Pallas GPU. כדי להוציא את נתוני החישוב של הפסיפס, מגדירים את הדגל הבא:

--xla_mosaic_dump_to=/tmp/mosaic_dumps

אפשר גם להגדיר את ארגומנטי האתחול של ה-TPU כמשתנה סביבה:

export LIBTPU_INIT_ARGS="--xla_mosaic_dump_to=/tmp/mosaic_dumps"

מידע נוסף זמין במאמרי העזרה של JAX בנושא Pallas ו-Mosaic.

מידע נוסף על קובצי HLO Dumps

איך למצוא את החישוב הנכון

בדרך כלל, הרבה חישובים מושלכים. השמות של הקבצים שנוצרו הם השמות של החישובים ב-JAX, ב-Tensorflow או ב-PyTorch/XLA שמופיעים ביומנים, כך שקל לזהות את קובצי ה-HLO הרלוונטיים. לדוגמה:

1624325116260738.module_0065.pmap__unnamed_wrapped_function_.186875.before_optimizations.txt

אפשר גם להשתמש ב-ripgrep כדי לזהות במהירות באיזה מודול נמצאים סמלים או חישובים מסוימים.

טיפ: כדאי לכלול בדוחות על באגים את 3 הקבצים שנוצרו לפני/אחרי/הקצאת המאגר שמעניינים אתכם.

המרה מ-HLO

כלי שנקרא hlo-opt שיכול לתרגם בין פורמטים של HLOProto וטקסט. האפשרות הזו שימושית במקרים שבהם יש לכם פורמט אחד, אבל אתם צריכים את הפורמט השני לצורך ניפוי באגים.

מידע נוסף על השימוש בכלי: מסמכי התיעוד של XLA Tooling: hlo-opt.

הפעלה מחדש

אפשר להריץ (להפעיל מחדש) את החישובים שנוצרו ב-backend ספציפי של XLA עם נתונים פיקטיביים או תמונות מצב של קלט. זו דרך נוחה לשחזר בעיות ב-XLA, לחזור על הפעולות ולנפות באגים.

הפקודות הבאות משתמשות בנתונים פיקטיביים. אם שמרתם תמונות מצב של HLO, תוכלו להעביר אותן במקום זאת, והמערכת תשתמש בנתונים מתמונת המצב. כדי להשתמש בנתונים פיקטיביים בזמן הפעלת הצילום, מעבירים את הדגל --force_fake_data.

קצה עורפי של המעבד:

bazel run -c opt //xla/hlo/tools:run_hlo_module -- --platform=cpu
 /tmp/xladump/module_4561.before_optimizations.txt

קצה עורפי של GPU:

bazel run -c opt //xla/hlo/tools:run_hlo_module -- --platform=CUDA
 /tmp/xladump/module_4561.before_optimizations.txt