הנחיות לגבי דגלים של XLA

במדריך הזה מוצגת מבחר של דגלי XLA חשובים כדי לעזור למשתמשים להבין את היכולות של XLA ולהשתמש בהן בצורה יעילה. בקטעים הבאים מפורטים דגלים שיכולים להשפיע באופן משמעותי על ביצועי זמן הריצה ועל השימוש בזיכרון. אם מתעוררות בעיות כלשהן, כמו קריסות, אחרי שמפעילים דגל, מומלץ לחזור להגדרת ברירת המחדל וליצור בעיה ב-GitHub.

התרעות לגבי ביצועים

הסימונים הבאים עוזרים לשפר את הביצועים בזמן הריצה. ניסוי ההגדרות האלה עשוי להוביל לשיפור משמעותי בביצועים.

דגל תיאור ערכי ברירת מחדל הצעות לערכים ערכים של מועמדים
Pipelining
1. xla_should_allow_loop_variant_parameter_in_chain
2. xla_should_add_loop_invariant_op_in_chain
3. xla_tpu_enable_ici_ag_pipelining
צריך להשתמש ב-3 הדגלים האלה יחד כדי להפעיל צינורות משותפים של פעולות all-gather של ICI(Interchip-Interconnect), וכך ליצור יותר הזדמנויות לביצוע חופף. 1. xla_should_allow_loop_variant_parameter_in_chain=kDisabled
2. xla_should_add_loop_invariant_op_in_chain=kDisabled
3. xla_tpu_enable_ici_ag_pipelining=false
1. xla_should_allow_loop_variant_parameter_in_chain=kEnabled
2. xla_should_add_loop_invariant_op_in_chain=kEnabled
3. xla_tpu_enable_ici_ag_pipelining=true
1. xla_should_allow_loop_variant_parameter_in_chain=kDisabled/kEnabled/kAuto
2. xla_should_add_loop_invariant_op_in_chain=kDisabled/kEnabled/kAuto
3. xla_tpu_enable_ici_ag_pipelining=true/false
v5e/Async
xla_enable_async_all_gather
xla_tpu_enable_async_collective_fusion
xla_tpu_enable_async_collective_fusion_fuse_all_gather
צריך להשתמש ב-3 הדגלים האלה יחד כדי להפעיל פעולות all-gather אסינכרוניות בגרסה v5e. xla_enable_async_all_gather=kAuto
xla_tpu_enable_async_collective_fusion=true
xla_tpu_enable_async_collective_fusion_fuse_all_gather=true
xla_enable_async_all_gather=kAuto
xla_tpu_enable_async_collective_fusion=true
xla_tpu_enable_async_collective_fusion_fuse_all_gather=true
xla_enable_async_all_gather=kDisabled/kEnabled/kAuto
xla_tpu_enable_async_collective_fusion=true/false
xla_tpu_enable_async_collective_fusion_fuse_all_gather=true/false
v5e/Async
xla_tpu_enable_async_collective_fusion
xla_tpu_enable_async_collective_fusion_fuse_all_reduce
צריך להשתמש בשני הדגלים האלה יחד כדי להפעיל פעולות אסינכרוניות של all-reduce ב-v5e. xla_tpu_enable_async_collective_fusion=true
xla_tpu_enable_async_collective_fusion_fuse_all_reduce=false
xla_tpu_enable_async_collective_fusion=true
xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true
xla_tpu_enable_async_collective_fusion=true/false
xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true/false
Async
xla_tpu_enable_async_all_to_all
הדגל הזה מאפשר תקשורת אסינכרונית בין כל המשתתפים. xla_tpu_enable_async_all_to_all=false xla_tpu_enable_async_all_to_all=true xla_tpu_enable_async_all_to_all=true/false
מוגבל על ידי זמן האחזור
xla_all_gather_latency_bound_threshold_in_bytes
הדגל הזה מיועד לפעולות all-gather שמוגבלות לזמן אחזור (כלומר, פעולות קטנות). הפעלת ההגדרה הזו מפעילה אופטימיזציות ספציפיות שיכולות לקצר את זמן הביצוע של פעולות all-gather שמוגבלות על ידי זמן האחזור. בדרך כלל משתמשים בה בעומסי עבודה של הסקה. xla_all_gather_latency_bound_threshold_in_bytes=-1
(שלא מופעל)
4~16Mb(i.e. 4~16 * 1024 * 1024) [0, 9223372036854775807]
מוגבל על ידי זמן האחזור
xla_all_reduce_latency_bound_threshold_in_bytes
הדגל הזה מיועד לפעולות all-gather שמוגבלות לזמן אחזור (כלומר, פעולות קטנות). הפעלת ההגדרה הזו מפעילה אופטימיזציות ספציפיות שיכולות לקצר את זמן הביצוע של פעולות all-reduce שמוגבלות על ידי זמן האחזור. בדרך כלל משתמשים בה בעומסי עבודה של הסקה. xla_all_reduce_latency_bound_threshold_in_bytes=-1
(שלא מופעל)
4~16Mb(i.e. 4~16 * 1024 * 1024) [0, 9223372036854775807]
מוגבל על ידי זמן האחזור
xla_collective_permute_latency_bound_threshold_in_bytes
הדגל הזה מיועד לפעולות all-gather שמוגבלות לזמן אחזור (כלומר, פעולות קטנות). הפעלת ההגדרה הזו מפעילה אופטימיזציות ספציפיות שיכולות לקצר את זמן הביצוע של פעולות שינוי מיקום קולקטיביות שמוגבלות על ידי זמן האחזור. בדרך כלל משתמשים בה בעומסי עבודה של הסקה. xla_collective_permute_latency_bound_threshold_in_bytes=-1
(שלא מופעל)
4~16Mb(i.e. 4~16 * 1024 * 1024) [0, 9223372036854775807]
מוגבל על ידי זמן האחזור
xla_all_to_all_latency_bound_threshold_in_bytes
הדגל הזה מיועד לפעולות all-gather שמוגבלות לזמן אחזור (כלומר, פעולות קטנות). הפעלת האפשרות הזו מפעילה אופטימיזציות ספציפיות שיכולות לקצר את זמן הביצוע של תקשורת all-to-all שמוגבלת על ידי זמן האחזור. בדרך כלל משתמשים בה בעומסי עבודה של הסקה. xla_all_to_all_latency_bound_threshold_in_bytes=-1
(שלא מופעל)
4~16Mb(i.e. 4~16 * 1024 * 1024) [0, 9223372036854775807]
xla_enable_async_collective_permute מבצעת שכתוב של כל הפעולות של החלפה קולקטיבית לגרסאות האסינכרוניות שלהן. אם ההגדרה היא auto, ‏ XLA יכול להפעיל באופן אוטומטי את התכונה 'העברה אסינכרונית של נתונים' על סמך הגדרות או תנאים אחרים. xla_enable_async_collective_permute=kAuto xla_enable_async_collective_permute=kAuto xla_enable_async_collective_permute=kAuto/kEnabled/kDisabled

סימונים בזיכרון

הדגלים שמפורטים בהמשך נועדו לפתור בעיות שקשורות ל-HBM. צריך לשנות את ההגדרות האלה רק אם נתקלים בשגיאות מסוג HBM out of memory (HBM לא זמין) במהלך קומפילציה של מודל. בכל שאר התרחישים, מומלץ להשתמש בערכי ברירת המחדל, כי שינוי שלהם עלול לפגוע בביצועים.

דגל תיאור ערכי ברירת מחדל הצעות לערכים ערכים של מועמדים
Scheduler
xla_latency_hiding_scheduler_rerun
ההגדרה הזו משנה את ההתנהגות של מתזמן הסתרת זמן האחזור. התהליך פועל על ידי הקטנה הדרגתית של מגבלת הזיכרון שהוקצתה לתזמון עם כל הפעלה מחדש של התהליך. xla_latency_hiding_scheduler_rerun=1 xla_latency_hiding_scheduler_rerun=5 0~10(it doesn’t make much sense beyond 10 reruns)
Fusion
xla_tpu_rwb_fusion
הדגל הזה מפעיל סוגים של מיזוגים מסוג reduce+broadcast, ועשוי להפחית את השימוש בזיכרון. xla_tpu_rwb_fusion=true xla_tpu_rwb_fusion=false xla_tpu_rwb_fusion=true/false
Scheduler
xla_memory_scheduler
הדגל הזה מציין את האלגוריתם שמתזמן הזיכרון ישתמש בו כדי לצמצם את צריכת הזיכרון. שימוש באלגוריתם מתקדם יותר עשוי להניב לוח זמנים שצורך פחות זיכרון, אבל על חשבון זמן קומפילציה ארוך יותר. xla_memory_scheduler=kDefault xla_memory_scheduler=kBrkga xla_memory_scheduler=kDefault/kList/kDfs/kPostOrder/kBrkga
Scheduler
xla_tpu_enable_latency_hiding_scheduler
הדגל הזה מפעיל את מתזמן הסתרת זמן האחזור, שמאפשר לנו לבצע פעולות אסינכרוניות במקום פעולות סינכרוניות. השבתת האפשרות הזו מפחיתה את השימוש בזיכרון, אבל גורמת לאובדן של שיפורי הביצועים שמתקבלים מהפעולות האסינכרוניות האלה. xla_tpu_enable_latency_hiding_scheduler=true xla_tpu_enable_latency_hiding_scheduler=false xla_tpu_enable_latency_hiding_scheduler=true/false
SPMD
xla_jf_spmd_threshold_for_windowed_einsum_mib
הדגל הזה מגדיר את הסף התחתון של הגודל המינימלי של הנקודה להפעלת מכפלת מטריצות קולקטיבית. אם מגדירים ערך גבוה יותר, חוסכים בזיכרון אבל מאבדים הזדמנויות לבצע פעולות כפל מטריצות קולקטיביות. xla_jf_spmd_threshold_for_windowed_einsum_mib=-1 10Mb~1Gb (i.e. 10*1024*1024 ~ 1024*1024*1024) [0, 9223372036854775807]

דיווחים נפוצים אחרים

דגל סוג הערות
xla_dump_to מחרוזת (נתיב קובץ) התיקייה שבה ימוקמו קובצי HLO לפני אופטימיזציה וארטיפקטים אחרים (ראו XLA Tools).

התראות TPU XLA

דגל סוג הערות
xla_tpu_enable_data_parallel_all_reduce_opt ערך בוליאני (true/false) אופטימיזציה להגדלת ההזדמנויות לחפיפה של כל הפעולות לצמצום (all-reduce) של DCN (רשת מרכזי נתונים) שמשמשות לפיצול מקביל של נתונים.
xla_tpu_data_parallel_opt_different_sized_ops ערך בוליאני (true/false) מאפשרת צינורות (pipelining) של פעולות מקבילות לנתונים בכמה איטרציות, גם אם גדלי הפלט שלהן לא תואמים למה שאפשר לשמור במקום במשתנים המוערמים. יכול להגביר את העומס על הזיכרון.
xla_tpu_spmd_rng_bit_generator_unsafe ערך בוליאני (true/false) האם להפעיל את RngBitGenerator HLO באופן מחולק, מה שלא בטוח אם מצפים לתוצאות דטרמיניסטיות עם חלוקות שונות בחלקים שונים של החישוב.
xla_tpu_megacore_fusion_allow_ags ערך בוליאני (true/false) מאפשר מיזוג של כל הפעולות של all-gather עם פעולות של convolution או all-reduce.
xla_tpu_enable_ag_backward_pipelining ערך בוליאני (true/false) הצינורות אוספים את כל הנתונים (בשלב הזה, כל הנתונים בקנה מידה גדול) אחורה דרך לולאות הסריקה.

התראות GPU XLA

דגל סוג הערות
xla_gpu_enable_latency_hiding_scheduler ערך בוליאני (true/false) הדגל הזה מאפשר לתזמנים להסתיר את זמן האחזור כדי לחפוף תקשורת אסינכרונית עם חישוב בצורה יעילה. ערך ברירת המחדל הוא False.
xla_gpu_enable_triton_gemm ערך בוליאני (true/false) שימוש בכפל מטריצות שמבוסס על Triton.
xla_gpu_graph_level סימון (0-3) הדגל מדור קודם להגדרת רמת תרשים ה-GPU. שימוש ב-xla_gpu_enable_command_buffer בתרחישי שימוש חדשים. ‫0 = מושבת; 1 = תיעוד של מיזוגים ו-memcpys;‏ 2 = תיעוד של gemms;‏ 3 = תיעוד של convolutions.
xla_gpu_all_reduce_combine_threshold_bytes מספר שלם (בייטים) הדגלים האלה קובעים מתי לשלב כמה פעולות קטנות של AllGather / ReduceScatter / AllReduce לפעולה גדולה אחת של AllGather / ReduceScatter / AllReduce, כדי לצמצם את הזמן שמוקדש לתקשורת בין מכשירים. לדוגמה, עבור ספי AllGather / ReduceScatter בעומס עבודה מבוסס-Transformer, כדאי להגדיר אותם גבוה מספיק כדי לשלב לפחות את המשקל של שכבת Transformer. כברירת מחדל, הערך של combine_threshold_bytes מוגדר ל-256.
xla_gpu_all_gather_combine_threshold_bytes מספר שלם (בייטים) מידע נוסף זמין בקטע xla_gpu_all_reduce_combine_threshold_bytes למעלה.
xla_gpu_reduce_scatter_combine_threshold_bytes מספר שלם (בייטים) מידע נוסף זמין בקטע xla_gpu_all_reduce_combine_threshold_bytes למעלה.
xla_gpu_enable_pipelined_all_gather ערך בוליאני (true/false) הפעלת צינורות (pipelining) של הוראות all-gather.
xla_gpu_enable_pipelined_reduce_scatter ערך בוליאני (true/false) הפעלת צינורות של הוראות reduce-scatter.
xla_gpu_enable_pipelined_all_reduce ערך בוליאני (true/false) הפעלת צינורות (pipelining) של כל ההוראות להפחתה.
xla_gpu_enable_while_loop_double_buffering ערך בוליאני (true/false) הפעלת מאגר כפול ללולאת while.
xla_gpu_enable_all_gather_combine_by_dim ערך בוליאני (true/false) שילוב של פעולות all-gather עם אותו ממד gather או ללא קשר לממד שלהן.
xla_gpu_enable_reduce_scatter_combine_by_dim ערך בוליאני (true/false) לשלב פעולות reduce-scatter עם אותו מאפיין או בלי קשר למאפיין שלהן.