ทิ้งการคำนวณ HLO

การดัมพ์ HLO คือการแสดงโมดูล HLO ในรูปแบบข้อความในขั้นตอนต่างๆ ของการคำนวณ ซึ่งมีประโยชน์สำหรับการแก้ไขข้อบกพร่อง และคุณมักจะต้องรวมไว้ในรายงานข้อบกพร่อง โดยปกติจะเป็นไฟล์ข้อความที่มนุษย์อ่านได้ ซึ่งแสดงรายการ คำสั่ง HLO และพร็อพเพอร์ตี้ของคำสั่งเหล่านั้น บางครั้ง ระบบจะทิ้งโมดูล HLO เป็น

  • HloProto: ไฟล์ Protocol Buffer ซึ่งเป็นรูปแบบที่มีโครงสร้างมากขึ้น และเครื่องอ่านได้
  • HloSnapshot: โมดูล HLO พร้อมอินพุต สำหรับการเล่น HLO ซ้ำ บางครั้งคุณต้องใช้ข้อมูลจริงที่ป้อนให้กับการคำนวณที่กำหนดแทนที่จะใช้ข้อมูลแบบสุ่ม

คุณใช้แฟล็ก XLA เพื่อระบุและรับการดัมพ์ได้ ในกรณีส่วนใหญ่ คุณสามารถตั้งค่าได้ ด้วยตัวแปรสภาพแวดล้อม นอกจากนี้ JAX ยังมีวิธีแบบเป็นโปรแกรมในการพิมพ์ การดัมพ์ HLO ด้วย

การดำเนินการในเครื่อง

การใช้ตัวแปรสภาพแวดล้อม

คุณตั้งค่าตัวแปรสภาพแวดล้อม XLA_FLAGS ด้วยแฟล็กที่จำเป็นเพื่อรับ การทิ้งได้ ตัวเลือกนี้ใช้ได้กับ JAX, TensorFlow และ PyTorch/XLA

หากต้องการส่งออกโมดูล HLO และข้อมูลการแก้ไขข้อบกพร่องอื่นๆ ไปยังไดเรกทอรีที่เฉพาะเจาะจง ให้เรียกใช้ โปรแกรมด้วยแฟล็ก --xla_dump_to ดังนี้

XLA_FLAGS="--xla_dump_to=DIRECTORY_PATH"

เช่น คุณใช้ /tmp หรือ /tmp/xladump เป็นเส้นทางได้

โดยค่าเริ่มต้น การดำเนินการนี้จะทิ้งโมดูล HLO เป็นข้อความที่จุดเริ่มต้นและจุดสิ้นสุดของ ไปป์ไลน์การเพิ่มประสิทธิภาพ

นอกจากนี้ คุณยังระบุรูปแบบอย่างชัดเจนได้ด้วย

  1. การดัมพ์ข้อความ
XLA_FLAGS="--xla_dump_hlo_as_text --xla_dump_to=DIRECTORY_PATH"
  1. Proto ของ HLO
XLA_FLAGS="--xla_dump_hlo_as_proto --xla_dump_to=DIRECTORY_PATH"
  1. สแนปชอต HLO
XLA_FLAGS="--xla_dump_hlo_snapshots --xla_dump_to=DIRECTORY_PATH"
  1. การแสดงกราฟด้วยเซิร์ฟเวอร์ Graphviz (ใช้ได้ดีกับกราฟขนาดเล็กเท่านั้น)
XLA_FLAGS="--xla_dump_hlo_as_url --xla_dump_to=DIRECTORY_PATH"
  1. แสดงผลกราฟเป็นไฟล์ HTML (ใช้ได้ดีกับกราฟขนาดเล็กเท่านั้น)
XLA_FLAGS="--xla_dump_hlo_as_html --xla_dump_to=DIRECTORY_PATH"

สำหรับกราฟขนาดใหญ่ คุณสามารถใช้ interactive_graphviz เพื่อแสดงภาพบางส่วนของกราฟได้

Dump Specific Intermediate Passes

นอกจาก HLO ที่เพิ่มประสิทธิภาพล่วงหน้า / เพิ่มประสิทธิภาพขั้นสุดท้ายมาตรฐานแล้ว คุณยัง ส่งออกสถานะของ HLO หลังจากคอมไพเลอร์ผ่านการทำงานที่เฉพาะเจาะจงได้ด้วย

XLA_FLAGS="--xla_dump_hlo_pass_re=regex --xla_dump_to=DIRECTORY_PATH"

ระบบจะส่งออกโมดูล HLO สำหรับการส่งผ่านที่มีชื่อตรงกับนิพจน์ทั่วไป (regex) เช่น คุณสามารถสังเกต HLO ที่เกิดจากพาส ที่เกี่ยวข้องกับการแบ่งพาร์ติชัน SPMD ได้โดยใช้สิ่งต่อไปนี้

XLA_FLAGS="--xla_dump_to=DIRECTORY_PATH --xla_dump_hlo_pass_re=spmd|propagation"

หากต้องการทิ้งผลลัพธ์หลังจากผ่าน XLA ทุกครั้ง (ซึ่งจะทำให้มีไฟล์จำนวนมาก) ให้ตั้งค่าดังนี้

XLA_FLAGS="--xla_dump_to=DIRECTORY_PATH --xla_dump_hlo_pass_re=.*"

ตัวเลือกเฉพาะ JAX

แบบเป็นโปรแกรมใน JAX

คุณยังสามารถดัมพ์ HLO โดยใช้ API lower และ compile ของ JAX ได้ด้วย แทนที่จะส่งแฟล็กหรือตัวแปรสภาพแวดล้อม

ดึงข้อมูล HLO ที่ลดระดับลงซึ่งยังไม่ได้เพิ่มประสิทธิภาพจากเครื่องด้วยคำสั่งต่อไปนี้

jax.jit(f).lower(*args).as_text('hlo')

หากต้องการส่งออกไปยังไฟล์ระหว่างการส่งผ่านการคอมไพล์ HLO ให้ระบุสิ่งต่อไปนี้

compilation_args = {
    'xla_dump_to': DIRECTORY_PATH,
    'xla_dump_hlo_pass_re': 'spmd|propagation', # or some other pass filter
    ...
    }

jax.jit(f).lower(*args).compile(compilation_args)

Dump jaxprs

jaxprs คือการแสดงผลระดับกลางของ JAX สำหรับร่องรอยของโปรแกรม หากต้องการส่งออกข้อมูลนี้ ให้ตั้งค่าตัวแปรสภาพแวดล้อมต่อไปนี้

JAX_DUMP_IR_TO="DIRECTORY_PATH" JAX_DUMP_IR_MODES=jaxpr

ดูข้อมูลเพิ่มเติมในเอกสารประกอบ JAX เกี่ยวกับการส่งออกและการทำให้การคำนวณที่จัดเตรียมไว้เป็นอนุกรม: การแก้ไขข้อบกพร่อง

Google Colab

ตัวแปรสภาพแวดล้อม

ในเซลล์แรกที่ดำเนินการใน Notebook (เนื่องจากโดยปกติแล้วตัวแปรสภาพแวดล้อมและ แฟล็กบรรทัดคำสั่งจะได้รับการประมวลผลเพียงครั้งเดียว เช่น ในเวลาที่นำเข้าโมดูล หรือเวลาเริ่มต้นแบ็กเอนด์ XLA) ให้เพิ่ม XLA_FLAGS ที่อธิบายไว้ข้างต้นพร้อมกับ os.environ เช่น

import os
os.environ['XLA_FLAGS'] = "--xla_dump_to=DIRECTORY_PATH"

ซึ่งจะส่งการคำนวณไปยัง DIRECTORY_PATH เช่น /tmp ใน Colab ให้ไปที่เบราว์เซอร์ "ไฟล์" ในแถบด้านข้างซ้ายเพื่อดูและเข้าถึงไดเรกทอรีนี้

คุณใช้ Flag ทั้งหมดที่กล่าวถึงในส่วนการดำเนินการในเครื่องได้

ตัวเลือกเฉพาะ JAX

เช่นเดียวกับการดำเนินการในเครื่อง สำหรับการตรวจสอบแบบสดและแบบอินเทอร์แอกทีฟ คุณสามารถพิมพ์ HLO ที่มีการเพิ่มประสิทธิภาพล่วงหน้าของการคำนวณได้โดยตรงโดยใช้คำสั่งต่อไปนี้

def f(x):
    return jax.numpy.sin(jax.numpy.cos(x))

c = jax.jit(f).lower(3.).compiler_ir('hlo')

print(c.as_hlo_text())

นอกจากนี้ คุณยังพิมพ์ HLO ที่เพิ่มประสิทธิภาพของการคำนวณได้โดยตรงด้วย

def optimized_HLO(f, *args, platform=None):
    print(jax.jit(f).lower(*args).compile().as_text())

def f(x):
    return jax.numpy.sin(jax.numpy.cos(x))

optimized_HLO(f, 1.0)

การทิ้งการคำนวณทั้งหมด/บางส่วน

หากต้องการดูทุกอย่างในไฟล์ที่ส่งออก รวมถึงการคอมไพล์ขนาดเล็กทั้งหมด ให้ตั้งค่าตัวแปรสภาพแวดล้อม JAX ดังนี้

JAX_COMPILER_DETAILED_LOGGING_MIN_OPS=0

โมเสค

Mosaic เป็นคอมไพเลอร์สำหรับแบ็กเอนด์ Pallas TPU และแบ็กเอนด์ Pallas GPU แบบทดลอง หากต้องการทิ้งการคำนวณโมเสก ให้ตั้งค่าสถานะต่อไปนี้

--xla_mosaic_dump_to=/tmp/mosaic_dumps

หรือตั้งค่าอาร์กิวเมนต์ init ของ TPU เป็นตัวแปรสภาพแวดล้อม

export LIBTPU_INIT_ARGS="--xla_mosaic_dump_to=/tmp/mosaic_dumps"

ดูข้อมูลเพิ่มเติมได้ในเอกสารประกอบ JAX เกี่ยวกับ Pallas และ Mosaic

เพิ่มเติมเกี่ยวกับ HLO Dumps

การค้นหาการคำนวณที่เหมาะสม

โดยปกติแล้ว ระบบจะทิ้งการคำนวณจำนวนมาก ไฟล์ที่ดัมพ์จะมีชื่อที่ระบุอย่างชัดเจน พร้อมด้วย "ชื่อการคำนวณ" ของ JAX, TensorFlow หรือ PyTorch/XLA ที่ระบุไว้ ในบันทึก ซึ่งช่วยให้ระบุไฟล์ HLO ที่เกี่ยวข้องได้ง่าย เช่น

1624325116260738.module_0065.pmap__unnamed_wrapped_function_.186875.before_optimizations.txt

หรือคุณจะใช้ ripgrep เพื่อระบุโมดูลที่มีสัญลักษณ์หรือการคำนวณที่เฉพาะเจาะจงได้อย่างรวดเร็วก็ได้

เคล็ดลับ: ใส่ไฟล์ที่ดัมพ์ก่อน/หลัง/การกำหนดบัฟเฟอร์ 3 ไฟล์ที่สนใจ ในรายงานข้อบกพร่อง

Conversion ของ HLO

เครื่องมือที่ชื่อ hlo-opt ซึ่งแปลระหว่างรูปแบบ HLOProto กับรูปแบบข้อความได้ ซึ่งมีประโยชน์ในกรณีที่คุณมีรูปแบบหนึ่ง แต่ต้องการอีกรูปแบบหนึ่งเพื่อ การแก้ไขข้อบกพร่อง

ดูวิธีใช้ได้ที่ เอกสารประกอบเกี่ยวกับเครื่องมือ XLA: hlo-opt

เล่นซ้ำ

คุณสามารถเรียกใช้ (เล่นซ้ำ) การคำนวณที่ดัมพ์ในแบ็กเอนด์ XLA ที่ระบุด้วย ข้อมูลจำลองหรือสแนปชอตอินพุต ซึ่งเป็นวิธีที่สะดวกในการทำซ้ำ ทำซ้ำ และแก้ไขข้อบกพร่องใน XLA

คำสั่งต่อไปนี้ใช้ข้อมูลจำลอง หากบันทึกภาพรวม HLO ไว้ คุณสามารถส่งภาพรวมเหล่านั้นแทนได้ และระบบจะใช้ข้อมูลจากภาพรวม หากต้องการใช้ข้อมูลจำลองขณะเรียกใช้สแนปชอต ให้ส่งแฟล็ก --force_fake_data

แบ็กเอนด์ของ CPU:

bazel run -c opt //xla/hlo/tools:run_hlo_module -- --platform=cpu
 /tmp/xladump/module_4561.before_optimizations.txt

แบ็กเอนด์ GPU:

bazel run -c opt //xla/hlo/tools:run_hlo_module -- --platform=CUDA
 /tmp/xladump/module_4561.before_optimizations.txt