傾印 HLO 計算

HLO 傾印是 HLO 模組在運算不同階段的文字表示法。這項資訊有助於偵錯,而且通常需要納入錯誤報告。這通常是人類可讀的文字檔,列出 HLO 指令和屬性。有時,HLO 模組會傾印為:

  • HloProto:通訊協定緩衝區檔案,這是結構更嚴謹的機器可讀取格式。
  • HloSnapshot:HLO 模組及其輸入內容。如要重播 HLO,有時需要提供饋送至特定運算的實際輸入內容,而非隨機資料。

您可以使用 XLA 標記指定及取得傾印。在大多數情況下,您可以使用環境變數設定這個值。JAX 也提供程式輔助方式,可列印 HLO 傾印。

本機執行

使用環境變數

您可以使用必要旗標設定 XLA_FLAGS 環境變數,取得傾印內容。這適用於 JAX、TensorFlow 和 PyTorch/XLA。

如要將 HLO 模組和其他偵錯資訊傾印至特定目錄,請使用 --xla_dump_to 旗標執行程式:

XLA_FLAGS="--xla_dump_to=DIRECTORY_PATH"

舉例來說,您可以使用 /tmp/tmp/xladump 做為路徑。

根據預設,這會在最佳化管道的開頭和結尾,以文字形式傾印 HLO 模組。

您也可以明確指定格式:

  1. 文字傾印
XLA_FLAGS="--xla_dump_hlo_as_text --xla_dump_to=DIRECTORY_PATH"
  1. HLO proto
XLA_FLAGS="--xla_dump_hlo_as_proto --xla_dump_to=DIRECTORY_PATH"
  1. HLO 快照
XLA_FLAGS="--xla_dump_hlo_snapshots --xla_dump_to=DIRECTORY_PATH"
  1. 使用 graphviz 伺服器算繪圖表 (僅適用於小型圖表)
XLA_FLAGS="--xla_dump_hlo_as_url --xla_dump_to=DIRECTORY_PATH"
  1. 將圖表轉譯為 HTML 檔案 (僅適用於小型圖表)
XLA_FLAGS="--xla_dump_hlo_as_html --xla_dump_to=DIRECTORY_PATH"

如要查看較大的圖表,可以使用 interactive_graphviz 顯示圖表的部分內容。

傾印特定中間階段

除了標準的預先最佳化 / 最終最佳化 HLO 之外,您也可以在特定編譯器傳遞後傾印 HLO 的狀態。

XLA_FLAGS="--xla_dump_hlo_pass_re=regex --xla_dump_to=DIRECTORY_PATH"

系統會傾印名稱符合規則運算式 (regex) 的傳遞項目 HLO 模組。舉例來說,您可以觀察與 SPMD 分區相關的傳遞所產生的 HLO:

XLA_FLAGS="--xla_dump_to=DIRECTORY_PATH --xla_dump_hlo_pass_re=spmd|propagation"

如要在每次 XLA 傳遞後傾印結果 (這會產生大量檔案),可以設定:

XLA_FLAGS="--xla_dump_to=DIRECTORY_PATH --xla_dump_hlo_pass_re=.*"

JAX 專屬選項

以程式輔助方式在 JAX 中

您也可以使用 JAX 的 lowercompile API,以程式輔助方式傾印 HLO,不必傳遞旗標或環境變數。

在本地擷取未經最佳化的原始低階 HLO,方法如下:

jax.jit(f).lower(*args).as_text('hlo')

如要在 HLO 編譯傳遞期間傾印至檔案,請指定:

compilation_args = {
    'xla_dump_to': DIRECTORY_PATH,
    'xla_dump_hlo_pass_re': 'spmd|propagation', # or some other pass filter
    ...
    }

jax.jit(f).lower(*args).compile(compilation_args)

傾印 jaxpr

jaxprs 是 JAX 的程式追蹤記錄中介表示法。如要傾印此內容,請設定下列環境變數:

JAX_DUMP_IR_TO="DIRECTORY_PATH" JAX_DUMP_IR_MODES=jaxpr

詳情請參閱 JAX 說明文件中的「Exporting and serializing staged-out computations: Debugging」。

Google Colab

環境變數

在筆記本的第一個執行儲存格中 (因為環境變數和命令列旗標通常只會處理一次,例如在模組匯入時間或 XLA 後端初始化時間),加入上述 XLA_FLAGS,例如:os.environ

import os
os.environ['XLA_FLAGS'] = "--xla_dump_to=DIRECTORY_PATH"

這會將運算作業傾印至 DIRECTORY_PATH,例如 /tmp。在 Colab 中,前往左側邊欄的「檔案」瀏覽器,即可查看及存取這個目錄。

您可以使用「在本機執行」一節中提及的所有標記。

JAX 專屬選項

與本機執行類似;如要進行即時互動式內省,您可以直接列印計算的預先最佳化 HLO:

def f(x):
    return jax.numpy.sin(jax.numpy.cos(x))

c = jax.jit(f).lower(3.).compiler_ir('hlo')

print(c.as_hlo_text())

您也可以直接列印運算的最佳化 HLO:

def optimized_HLO(f, *args, platform=None):
    print(jax.jit(f).lower(*args).compile().as_text())

def f(x):
    return jax.numpy.sin(jax.numpy.cos(x))

optimized_HLO(f, 1.0)

傾印所有/小型運算作業

如要查看傾印中的所有內容,包括所有小型編譯,請設定 JAX 環境變數:

JAX_COMPILER_DETAILED_LOGGING_MIN_OPS=0

馬賽克

Mosaic 是 Pallas TPU 後端的編譯器,也是實驗性 Pallas GPU 後端的編譯器。如要傾印馬賽克運算,請設定下列旗標:

--xla_mosaic_dump_to=/tmp/mosaic_dumps

或者,將 TPU 初始化引數設為環境變數:

export LIBTPU_INIT_ARGS="--xla_mosaic_dump_to=/tmp/mosaic_dumps"

如要瞭解詳情,請參閱 JAX 說明文件中的 Pallas 和 Mosaic

HLO Dumps 的其他作品

找出合適的計算方式

通常會捨棄許多運算。傾印的檔案會明確命名為 JAX、Tensorflow 或 PyTorch/XLA「運算名稱」,並在記錄中標示,方便您找出相關的 HLO 檔案。例如:

1624325116260738.module_0065.pmap__unnamed_wrapped_function_.186875.before_optimizations.txt

否則,您可以使用 ripgrep 快速找出含有特定符號或計算的模組。

提示:在錯誤報告中加入感興趣的 3 個傾印檔案 (緩衝區指派前後/緩衝區指派)。

HLO 轉換

名為 hlo-opt 的工具,可在 HLOProto 和文字格式之間轉換。如果您只有一種格式,但需要另一種格式進行偵錯,這個做法就非常實用。

瞭解如何使用: XLA 工具說明文件:hlo-opt

重播

您可以使用虛擬資料或輸入快照,在指定的 XLA 後端執行 (重播) 傾印的運算。這是在 XLA 中重現、疊代及偵錯問題的便利方式。

下列指令會使用虛擬資料。如果您已儲存 HLO 快照,可以改為傳遞這些快照,系統會使用快照中的資料。如要在執行快照時仍使用虛擬資料,請傳遞 --force_fake_data 旗標。

CPU 後端:

bazel run -c opt //xla/hlo/tools:run_hlo_module -- --platform=cpu
 /tmp/xladump/module_4561.before_optimizations.txt

GPU 後端:

bazel run -c opt //xla/hlo/tools:run_hlo_module -- --platform=CUDA
 /tmp/xladump/module_4561.before_optimizations.txt