รูปร่างและเลย์เอาต์

โครงสร้างของ XLA Op

ลองดูตัวอย่าง HLO

add.936 = bf16[8,1,1280,16384]{3,2,0,1:T(8,128)(2,1)}
          add(exponential.183, broadcast.3115)

ซึ่งประกอบด้วยคอมโพเนนต์ต่อไปนี้

  • ชื่อการดำเนินการ: add.936
    • นี่คือชื่อที่ไม่ซ้ำกันของการดำเนินการ
  • รูปร่าง: bf16[8,1,1280,16384]
    • นี่คือรูปร่างเอาต์พุตของ Op โดย dtype คือ bf16 และรูปร่างคือ [8,1,1280,16384]
  • เลย์เอาต์ (พร้อมการเรียง): 3,2,0,1:T(8,128)(2,1)
    • ซึ่งอธิบายวิธีจัดเก็บอาร์เรย์ในหน่วยความจำ 3,2,0,1 แสดงถึง ลำดับของแกนในหน่วยความจำ (เช่น คอลัมน์หลัก แถวหลัก ฯลฯ) และ T(8,128)(2,1) แสดงถึงการแบ่งไทล์และการเว้นวรรคที่ใช้
    • เลย์เอาต์เป็นข้อมูลที่ไม่บังคับ หากไม่ได้ระบุ จะไม่มีการเรียงต่อกัน และระบบจะถือว่า มิติข้อมูลเรียงจากมิติข้อมูลหลักสุดไปมิติข้อมูลย่อยสุด
  • การดำเนินการ: add
    • การดำเนินการที่กำลังดำเนินการ ในที่นี้คือ เพิ่ม ซึ่งมีการกล่าวถึงในชื่อ Op ด้วย
  • อาร์กิวเมนต์: exponential.183, broadcast.3115
    • การดำเนินการนี้ใช้อาร์กิวเมนต์ 2 รายการซึ่งระบุด้วยชื่อที่ไม่ซ้ำกัน

มาดูอีกตัวอย่างหนึ่งกัน นั่นคือการผสาน Op

%fusion.3 = bf16[32,32,4096]{2,1,0:T(8,128)(2,1)S(1)}
            fusion(bf16[32,32,8192]{2,1,0:T(8,128)(2,1)S(1)} %fusion.32),
            kind=kCustom, calls=%all-reduce-scatter.3

นอกเหนือจากคอมโพเนนต์ที่อธิบายไว้ก่อนหน้านี้แล้ว ยังประกอบด้วย

  • แอตทริบิวต์: kind และ calls
    • ซึ่งจะให้ข้อมูลเพิ่มเติมเกี่ยวกับการดำเนินการที่กำลังทำอยู่ ในกรณีนี้คือการผสาน
  • ตำแหน่งในหน่วยความจำ (ตัวระบุพื้นที่หน่วยความจำ): S(1)
    • ซึ่งระบุพื้นที่หน่วยความจำ/ตำแหน่งที่จัดเก็บอาร์เรย์ S(1) ที่นี่หมายถึงอาร์เรย์นี้อยู่ใน VMEM (ใน TPU)
  • รายละเอียดรูปร่างและเลย์เอาต์สำหรับอาร์กิวเมนต์อินพุต %fusion.32

ส่วนต่อไปนี้จะอธิบายรูปร่าง เลย์เอาต์ และ ตัวระบุพื้นที่หน่วยความจำ ดูข้อมูลเพิ่มเติมเกี่ยวกับ การเรียงหน้าต่างได้ในเลย์เอาต์แบบเรียง

รูปร่าง

XLA ShapeProto proto (xla_data.proto) อธิบายจำนวนมิติข้อมูล ขนาด และประเภทข้อมูลของอาร์เรย์ N มิติ (เรียกสั้นๆ ว่าอาร์เรย์)

คำศัพท์ สัญกรณ์ และแบบแผน

  • จำนวนมิติข้อมูลจริงของอาร์เรย์คือจำนวนมิติข้อมูลที่มีขนาดมากกว่า 1

  • มิติข้อมูลจะมีหมายเลขตั้งแต่ 0 ถึง N-1 สำหรับอาร์เรย์ N มิติ ขนาดของมิติข้อมูลต้องเป็นจำนวนเต็มที่ไม่ติดลบ โดยเฉพาะขนาด 0 จะ ใช้ได้ หมายเลขมิติข้อมูลเป็นป้ายกำกับที่กำหนดขึ้นเพื่อความสะดวก ลำดับของหมายเลขมิติข้อมูลเหล่านี้ไม่ได้หมายถึงการเรียงลำดับย่อย/หลัก ที่เฉพาะเจาะจงในเลย์เอาต์ของรูปร่าง เลย์เอาต์จะกำหนดโดย LayoutProto โปรโต

  • ตามธรรมเนียมแล้ว มิติข้อมูลจะแสดงตามลำดับที่เพิ่มขึ้นของหมายเลขมิติข้อมูล เช่น สำหรับอาร์เรย์ 3 มิติที่มีขนาด [A x B x C] มิติที่ 0 มีขนาด A มิติที่ 1 มีขนาด B และมิติที่ 2 มีขนาด C

    ยูทิลิตีบางอย่างใน XLA ยังรองรับการจัดทำดัชนีเชิงลบแบบ Python ด้วย โดยมิติข้อมูล -1 คือมิติข้อมูลสุดท้าย (เทียบเท่ากับ N-1 สำหรับอาร์เรย์ N มิติ) เช่น สำหรับอาร์เรย์ 3 มิติที่อธิบายไว้ข้างต้น มิติข้อมูล -1 มีขนาด C มิติข้อมูล -2 มีขนาด B และอื่นๆ

  • อาร์เรย์ 2, 3 และ 4 มิติมักมีตัวอักษรเฉพาะ ที่เชื่อมโยงกับมิติข้อมูล เช่น สำหรับอาร์เรย์ 2 มิติ

    • มิติข้อมูล 0: y
    • มิติข้อมูล 1: x

    สำหรับอาร์เรย์ 3 มิติ ให้ทำดังนี้

    • มิติข้อมูล 0: z
    • มิติข้อมูล 1: y
    • มิติข้อมูล 2: x

    สำหรับอาร์เรย์ 4 มิติ

    • มิติข้อมูล 0: p
    • มิติข้อมูล 1: z
    • มิติข้อมูล 2: y
    • มิติข้อมูล 3: x
  • ฟังก์ชันใน XLA API ที่ใช้มิติข้อมูลจะเรียงตามลำดับที่เพิ่มขึ้นของ หมายเลขมิติข้อมูล ซึ่งตรงกับการจัดลำดับที่ใช้เมื่อส่งมิติข้อมูลเป็น initializer_list เช่น

    ShapeUtil::MakeShape(F32, {A, B, C, D})

    จะสร้างรูปร่างที่มีอาร์เรย์ขนาดมิติข้อมูลซึ่งประกอบด้วยลำดับ [A, B, C, D]

เลย์เอาต์

LayoutProto proto อธิบายวิธีแสดงอาร์เรย์ในหน่วยความจำ โดยจะ มีฟิลด์ต่อไปนี้

message LayoutProto {
  repeated int64 minor_to_major;
  int64 tail_padding_alignment_in_elements;
  ...
}

การจัดเรียงมิติข้อมูลจากเล็กไปใหญ่

ช่องที่ต้องกรอกมีเพียงminor_to_major ฟิลด์นี้อธิบาย ลำดับจากน้อยไปมากของมิติข้อมูลภายในรูปร่าง ค่าใน minor_to_major คือการจัดลำดับมิติข้อมูลของอาร์เรย์ (0 ถึง N-1 สำหรับอาร์เรย์ N มิติ) โดยค่าแรกคือมิติข้อมูลที่เล็กที่สุด จนถึงค่าสุดท้ายซึ่งเป็นมิติข้อมูลที่ใหญ่ที่สุด มิติข้อมูลที่เล็กที่สุด คือมิติข้อมูลที่เปลี่ยนแปลงเร็วที่สุดเมื่อเลื่อนดู องค์ประกอบของอาร์เรย์ที่จัดวางในหน่วยความจำเชิงเส้น

ตัวอย่างเช่น ลองพิจารณาอาร์เรย์ 2 มิติต่อไปนี้ที่มีขนาด [2 x 3]

a b c
d e f

ในที่นี้ มิติข้อมูล 0 คือขนาด 2 และมิติข้อมูล 1 คือขนาด 3 หากminor_to_majorฟิลด์ในเลย์เอาต์เป็น [0, 1] มิติข้อมูล 0 จะเป็นมิติข้อมูลที่เล็กที่สุด และมิติข้อมูล 1 จะเป็นมิติข้อมูลที่ใหญ่ที่สุด ซึ่ง สอดคล้องกับเลย์เอาต์ต่อไปนี้ในหน่วยความจำเชิงเส้น

a d b e c f

ลำดับมิติข้อมูลจากเล็กไปใหญ่ของ 0 ถึง N-1 จะคล้ายกับคอลัมน์หลัก (สำหรับ 2 มิติ) สมมติว่ามิติข้อมูลมีการจัดเรียงแบบ Monotonic อีกวิธีหนึ่ง ที่เราอาจอ้างอิงเลย์เอาต์นี้ในโค้ดก็คือ "dim 0 เป็นแบบย่อย"

ในทางกลับกัน หากminor_to_majorฟิลด์ในเลย์เอาต์เป็น [1, 0] เลย์เอาต์ในหน่วยความจำเชิงเส้นจะเป็นดังนี้

a b c d e f

ลำดับมิติข้อมูลจากเล็กไปใหญ่ของ N-1 ถึง 0 สำหรับอาร์เรย์ N มิติ จะคล้ายกับ แถวหลัก (สำหรับ 2 มิติ) หากสมมติว่ามิติข้อมูลมีการจัดเรียงแบบโมโนโทน อีกวิธีที่เราอาจอ้างอิงถึงเลย์เอาต์นี้ในโค้ดก็คือ เพียงแค่ "dim 0 เป็นหลัก"

การเรียงลำดับจากรุ่นย่อยไปรุ่นหลักเริ่มต้น

เลย์เอาต์เริ่มต้นสำหรับรูปร่างที่สร้างใหม่คือ "ลำดับมิติข้อมูลคือ จากมิติข้อมูลหลักไปยังมิติข้อมูลรอง" (เช่น [N-1, ..., 0])

Padding

ฟิลด์ tail_padding_alignment_in_elements จะกำหนดการจัดแนวของอาร์เรย์ tiled ในแง่ของจำนวนองค์ประกอบ หลังจาก ใช้การปูกระเบื้องแล้ว ระบบจะเพิ่มองค์ประกอบที่มีการเว้นวรรคที่ส่วนท้ายของเลย์เอาต์จนกว่า จำนวนองค์ประกอบทั้งหมดจะเป็นค่าที่คูณด้วยค่านี้

การจัดทำดัชนีในอาร์เรย์

คลาส IndexUtil ใน index_util.h มีเครื่องมือสำหรับแปลงระหว่างดัชนีหลายมิติและดัชนีเชิงเส้น เมื่อกำหนดรูปร่างและเลย์เอาต์ ดัชนีหลายมิติประกอบด้วยint64 ดัชนีสำหรับแต่ละมิติข้อมูล ดัชนีเชิงเส้นคือค่า int64 ค่าเดียวซึ่ง จัดทำดัชนีลงในบัฟเฟอร์ที่เก็บอาร์เรย์ ดู shape_util.h และ layout_util.h ในไดเรกทอรีเดียวกันเพื่อดูยูทิลิตีที่จะช่วยให้การสร้างและการ ปรับแต่งรูปร่างและเลย์เอาต์เป็นเรื่องง่าย

ตัวระบุพื้นที่ความทรงจำ

ใน HLO คุณอาจใส่คำอธิบายประกอบอาร์เรย์แต่ละรายการด้วยตัวระบุพื้นที่หน่วยความจำ ซึ่งเขียนเป็น S(n)

  • S(0) (มักจะไม่มี) หมายถึงหน่วยความจำแบนด์วิดท์สูง (HBM) ของอุปกรณ์
  • S(1) แสดงถึงหน่วยความจำเสมือน (VMEM) ในอุปกรณ์
  • S(2), S(3) ฯลฯ สอดคล้องกับพื้นที่หน่วยความจำเพิ่มเติมที่เฉพาะเจาะจงของอุปกรณ์
  • S(5) แสดงหน่วยความจำของโฮสต์