عادةً ما تتمحور عملية تطوير XLA حول
HLO، الذي يمثّل عملية حسابية وظيفية معزولة
يتم تقديمها إلى المترجم. تتضمّن XLA أدوات متعددة لسطر الأوامر (موضّحة أدناه) تستخدم HLO وتنفّذها أو توفّر مرحلة تجميع وسيطة. إنّ استخدام هذه الأدوات لا يقدّر بثمن لدورة تكرار سريعة، لأنّ HLO يمكن تصوّره وتعديله، وغالبًا ما يكون تغييره وتشغيله بشكل متكرّر أسرع طريقة لفهم أداء XLA أو سلوكه وإصلاحه.compile->modify->run
إنّ أسهل طريقة للحصول على HLO لبرنامج يتم تجميعه باستخدام XLA هي عادةً استخدام متغيّر البيئة XLA_FLAGS:
$ XLA_FLAGS=--xla_dump_to=/tmp/myfolder ./myprogram-entry-point
الذي يخزّن جميع ملفات HLO قبل التحسين في المجلد المحدّد، بالإضافة إلى العديد من العناصر المفيدة الأخرى.
[run_hlo_module] تشغيل وحدات HLO
bazel run //xla/tools:run_hlo_module -- [flags] <filename>
تعمل الأداة run_hlo_module على HLO قبل التحسين، وتتضمّن تلقائيًا تجميع الحِزم وتشغيلها ومقارنتها بتنفيذ المترجم المرجعي. على سبيل المثال، يكون الاستدعاء المعتاد لتشغيل ملف إدخال
computation.hlo على وحدة معالجة الرسومات من NVIDIA والتحقّق من صحته كما يلي:
run_hlo_module --platform=CUDA --reference_platform=Interpreter computation.hlo
تشغيل وحدات HLO متعددة
يمكن استدعاء عدة وحدات HLO لـ run_hlo_module. لتشغيل جميع وحدات hlo من دليل:
bazel run //xla/tools:run_hlo_module -- [flags] /dump/*before_optimizations*
[multihost_hlo_runner] تشغيل وحدات HLO مع دعم SPMD
# Note: Binary name is `hlo_runner_main`.
bazel run //xla/tools/multihost_hlo_runner:hlo_runner_main -- [flags] <filename>
أداة تنفيذ HLO المتعددة المضيفين هي أداة مشابهة جدًا، مع التنبيه إلى أنّها تتوافق مع SPMD، بما في ذلك الاتصال بين المضيفين. راجِع Multi-Host HLO Runner للاطّلاع على التفاصيل.
تشغيل وحدات HLO متعددة مع دعم SPMD
على غرار run_hlo_module، يتيح multihost_hlo_runner أيضًا إمكانية التفعيل باستخدام وحدات متعدّدة.
bazel run //xla/tools/multihost_hlo_runner:hlo_runner_main -- [flags] /dump/*before_optimizations*
[hlo-opt] تجميع وحدة HLO
bazel run //xla/tools:hlo-opt -- --platform=[gpu|cpu|...] [more flags] <filename>
عند تصحيح الأخطاء أو فهم طريقة عمل برنامج التجميع، من المفيد غالبًا الحصول على التوسيع لجهاز معيّن في نقطة معيّنة من مسار التعلّم (سواء كان HLO أو HLO محسّنًا أو TritonIR أو LLVM)، وذلك لإدخال HLO أو StableHLO معيّن.
تتيح hlo-opt مراحل إخراج متعددة، سواء كانت PTX أو HLO بعد عمليات التحسين أو LLVM IR قبل عمليات التحسين أو TritonIR. تعتمد المجموعة الدقيقة من المراحل المتوافقة على النظام الأساسي (على سبيل المثال، PTX خاص بـ NVIDIA)، ويمكن الاطّلاع عليها باستخدام الأمر --list-stages:
hlo-opt --platform=CUDA --list-stages
buffer-assignment
hlo
hlo-backend
html
llvm
llvm-after-optimizations
llvm-before-optimizations
ptx
بعد اختيار مرحلة، يمكن للمستخدم كتابة نتيجة التحويل لمنصة معيّنة إلى بث معيّن:
hlo-opt --platform=cpu --stage=hlo input.hlo
سيؤدي ذلك إلى طباعة التفريغ إلى stdout (أو إلى ملف معيّن إذا تم تحديد -o).
الترجمة البرمجية لوحدة معالجة الرسومات بدون جهاز
لا تحتاج عملية التجميع بدون جهاز إلى الوصول إلى وحدة معالجة الرسومات. توفّر ميزة "التجميع بدون جهاز" طريقة لتحديد مواصفات وحدة معالجة الرسومات على سطر الأوامر (--xla_gpu_target_config_filename) للمراحل التي تتطلّب الوصول إلى وحدة معالجة الرسومات، ما يلغي الحاجة إلى جهاز مزوّد بوحدة معالجة الرسومات.
مثال: ناتج PTX بدون إذن الوصول إلى جهاز وحدة معالجة الرسومات:
hlo-opt --platform=CUDA --stage=llvm --xla_gpu_target_config_filename=/xla/tools/hlo_opt/gpu_specs/a100_pcie_80.txtpb input.hlo
يتم شحن مواصفات وحدات معالجة الرسومات الشائعة مع المجمّع، والملف المقدَّم هو تسلسل نصي لـ device_description.proto:
gpu_device_info {
cuda_compute_capability {
major: 8
minor: 0
}
threads_per_block_limit: 1024
threads_per_warp: 32
shared_memory_per_block: 127152
shared_memory_per_core: 65536
threads_per_core_limit: 2048
core_count: 6192
fpus_per_core: 64
block_dim_limit_x: 2147483647
block_dim_limit_y: 65535
block_dim_limit_z: 65535
memory_bandwidth: 2039000000000
l2_cache_size: 4194304
clock_rate_ghz: 1.1105
device_memory_size: 79050250240
}
platform_name: "CUDA"
يمكنك الاطّلاع على المزيد من مواصفات وحدة معالجة الرسومات على /xla/tools/hlo_opt/gpu_specs
الضبط التلقائي
في بعض الأحيان، قد تتضمّن عملية التجميع ضبطًا تلقائيًا استنادًا إلى عملية التجميع --stage.
لكي تعمل عملية التجميع بدون جهاز، على المستخدم إما
إيقاف الضبط التلقائي باستخدام --xla_gpu_autotune_level=0
أو
تحميل نتائج ضبط تلقائي حالية باستخدام
--xla_gpu_load_autotune_results_from=<filename> (تم الحصول عليها باستخدام
--xla_gpu_dump_autotune_results_to=<filename>).
hlo-opt --platform=CUDA --stage=llvm --xla_gpu_target_config_filename=gpu_specs/a100_pcie_80.txtpb --xla_gpu_load_autotune_results_from=results.textpb input.hlo
ملف الضبط التلقائي هو تسلسل نصي لـ autotune_results.proto، ويكون المثال على النحو التالي:
version: 3
results {
device: "CUDA: 8.0, Cores: 108, GPU clock: 1.41 GHz, Memory bandwidth: 1555 GB/s, L2 cache: 40 MB"
hlo: "{\n tmp_0 = f16[1,16,17,3]{3,2,1,0} parameter(0)\n tmp_1 = f16[16,51]{1,0} bitcast(f16[1,16,17,3]{3,2,1,0} tmp_0)\n tmp_2 = s8[16,17,3]{2,1,0} parameter(1)\n tmp_3 = s8[51,16]{0,1} bitcast(s8[16,17,3]{2,1,0} tmp_2)\n tmp_4 = f16[51,16]{0,1} convert(s8[51,16]{0,1} tmp_3)\n tmp_5 = f16[16,16]{1,0} dot(f16[16,51]{1,0} tmp_1, f16[51,16]{0,1} tmp_4), lhs_contracting_dims={1}, rhs_contracting_dims={0}\n ROOT tmp_6 = f16[1,16,16]{2,1,0} bitcast(f16[16,16]{1,0} tmp_5)\n}"
result {
run_time {
nanos: 31744
}
triton {
block_m: 32
block_n: 32
block_k: 32
split_k: 1
num_stages: 1
num_warps: 4
}
}
}
يمكن تسلسل قاعدة بيانات الضبط التلقائي باستخدام
XLA_FLAGS=--xla_gpu_dump_autotune_results_to=<myfile.pbtxt>
[hlo-opt] تطوير وتصحيح أخطاء بطاقة HLO
# If you are working with hardware independent passes from the
# `xla/hlo/transforms/` directory, prefer light-weight version
# of the `hlo-opt` tool with fewer dependencies:
bazel run //xla/hlo/tools:hlo-opt -- [flags] <filename>
# Otherwise, for hardware independent and CPU, GPU passes use
# the same binary from "Compile HLO Modules" section above:
bazel run //xla/tools:hlo-opt -- [flags] <filename>
تتيح الأداة hlo-opt تنفيذ عمليات فردية
بشكل مستقل عن مراحل تجميع المنصة المحدّدة. يساعد هذا العزل في إجراء عمليات فحص سريعة على وحدة HLO للإدخال وتحديد السبب الرئيسي لحدوث الأعطال.
hlo-opt --passes=schedule-aware-collective-cse input.hlo
تتيح الأداة hlo-opt أيضًا استخدام DebugOptions XLA_FLAGS.
hlo-opt --passes=schedule-aware-collective-cse
--xla_gpu_experimental_collective_cse_distance_threshold=20 input.hlo
استخدِم الخيار--list-passes للحصول على سلسلة اسم البطاقة.
hlo-opt --list-passes
يمكن للمستخدمين إنشاء مسار مخصّص خاص بهم من خلال تحديد أكثر من تمريرة واحدة إلى الخيار --passes.
hlo-opt --passes=pass1,pass2,pass3 input.hlo
المساعدة في تطوير بطاقة HLO الجديدة
- أولاً، اكتب بطاقتك.
سجِّل البطاقة الجديدة في سجلّ بطاقات أدوات
hlo-opt.RegisterPass<FooPass>(FooPassInputOptions)استنادًا إلى نوع البطاقة، اختَر أحد المواقع الجغرافية التالية للتسجيل:
opt_lib.ccالبطاقات التي لا تعتمد على أجهزة معيّنة
cpu_opt.ccبطاقات خاصة بوحدة المعالجة المركزية
gpu_opt.ccبطاقات خاصة بوحدة معالجة الرسومات
compiled_opt.ccعمليات مشتركة بين وحدة المعالجة المركزية ووحدة معالجة الرسومات ووحدة المعالجة الممتدة
لا تنسَ إضافة تبعية الإصدار.ضمِّن تسجيل البطاقة كجزء من طلب السحب(مثال) حتى تصبح البطاقة متاحة للاستخدام لجميع مستخدمي
hlo-opt.أعِد إنشاء الأداة
hlo-opt، وتحقّق من نجاح تسجيل البطاقة باستخدام الخيار--list-passes، ثم استخدِم الخيار--passesلتشغيل البطاقة.$ hlo-opt --passes=foo-pass input.hloلكتابة اختبارات الوحدة لعملية النقل، يُرجى الرجوع إلى https://openxla.org/xla/test_hlo_passes للحصول على مزيد من التفاصيل.
Pass Runtime Measurement
بالنسبة إلى النماذج الكبيرة، يمكن أن تستغرق عمليات التجميع الكاملة بضع دقائق، ما يجعل من الصعب رصد أي تراجع طفيف في الأداء. في المقابل، تتيح عمليات التشغيل الفردية باستخدام hlo-opt قياس الأداء بدقة ورصد الزيادات الصغيرة في وقت التنفيذ الناتجة عن تغييرات الرمز الجديدة بسهولة.
time hlo-opt --passes=reduce-window-rewriter,scatter_simplifier
--xla_reduce_window_rewrite_base_length=128 input.hlo
[hlo-opt] تحويل أشكال وحدات HLO
# Use the light weight version of the `hlo-opt` tool.
bazel run //xla/hlo/tools:hlo-opt -- [flags] <filename>
تحويل HLO Text -> HLO Proto
hlo-opt --emit-proto input.hlo
تحويل HLO Proto أو HLO Proto Binary -> HLO Text
hlo-opt input.pbtxt or input.pb
[ptx-opt] Compiler LLVM Module down to PTX
ستشغّل الأداة مسار تحسين LLVMIR ثم تستدعي CompileToPtx.
bazel run //xla/hlo/tools/ptx-opt -- --arch=9.0 <filename>
يمكن للأداة أيضًا تفريغ LLVMIR بعد كل مسار.
bazel run //xla/hlo/tools/ptx-opt -- --arch=9.0 --xla_dump_to=<path> --xla_gpu_dump_llvmir <filename>
[isolate_hlo] عزل تعليمات HLO التي تتسبّب في مشاكل
إذا كان لديك تفريغ كبير لرمز HLO وكنت تشك في أنّ تعليمات أو قسمًا معيّنًا
ضمن وحدة HLO يتسبّب في حدوث العطل، يمكنك استخدام أداة isolate_hlo.
تستخرج هذه الأداة تعليمات HLO واحدة (والسياق اللازم لها) إلى وحدة HLO جديدة أصغر حجمًا. ويكون هذا مفيدًا للغاية لإنشاء برنامج صغير لإعادة إنتاج الخطأ على مستوى المترجم.
- المستندات والمصدر: تتوفّر أداة
isolate_hloفي مستودع OpenXLA. اطّلِع على الدليلxla/toolsفي رمز المصدر XLA. الاستخدام: أنشئ الأداة من شجرة المصدر XLA. يستغرق ذلك عادةً ملف وحدة إدخال HLO (نص أو بروتوكول)، واسم التعليمات المطلوب استخراجها، ومسار ملف الإخراج.
# Example usage after building XLA: # ./build/tools/isolate_hlo --input=module.hlo --instruction_name=fusion.123 \ # --output=isolated_fusion.123.hlo --input_format=txt --output_format=long_txtراجِع رسالة المساعدة الخاصة بالأداة (
--help) لمعرفة العلامات وخيارات التنسيق المحدّدة.