โครงสร้างของ XLA Op
ลองดูตัวอย่าง HLO
add.936 = bf16[8,1,1280,16384]{3,2,0,1:T(8,128)(2,1)}
add(exponential.183, broadcast.3115)
ซึ่งประกอบด้วยคอมโพเนนต์ต่อไปนี้
- ชื่อการดำเนินการ:
add.936- นี่คือชื่อที่ไม่ซ้ำกันของการดำเนินการ
- รูปร่าง:
bf16[8,1,1280,16384]- นี่คือรูปร่างเอาต์พุตของ Op โดย dtype คือ bf16 และรูปร่างคือ
[8,1,1280,16384]
- นี่คือรูปร่างเอาต์พุตของ Op โดย dtype คือ bf16 และรูปร่างคือ
- เลย์เอาต์ (พร้อมการเรียง):
3,2,0,1:T(8,128)(2,1)- ซึ่งอธิบายวิธีจัดเก็บอาร์เรย์ในหน่วยความจำ
3,2,0,1แสดงถึง ลำดับของแกนในหน่วยความจำ (เช่น คอลัมน์หลัก แถวหลัก ฯลฯ) และT(8,128)(2,1)แสดงถึงการแบ่งไทล์และการเว้นวรรคที่ใช้ - เลย์เอาต์เป็นข้อมูลที่ไม่บังคับ หากไม่ได้ระบุ จะไม่มีการเรียงต่อกัน และระบบจะถือว่า มิติข้อมูลเรียงจากมิติข้อมูลหลักสุดไปมิติข้อมูลย่อยสุด
- ซึ่งอธิบายวิธีจัดเก็บอาร์เรย์ในหน่วยความจำ
- การดำเนินการ:
add- การดำเนินการที่กำลังดำเนินการ ในที่นี้คือ เพิ่ม ซึ่งมีการกล่าวถึงในชื่อ Op ด้วย
- อาร์กิวเมนต์:
exponential.183,broadcast.3115- การดำเนินการนี้ใช้อาร์กิวเมนต์ 2 รายการซึ่งระบุด้วยชื่อที่ไม่ซ้ำกัน
มาดูอีกตัวอย่างหนึ่งกัน นั่นคือการผสาน Op
%fusion.3 = bf16[32,32,4096]{2,1,0:T(8,128)(2,1)S(1)}
fusion(bf16[32,32,8192]{2,1,0:T(8,128)(2,1)S(1)} %fusion.32),
kind=kCustom, calls=%all-reduce-scatter.3
นอกเหนือจากคอมโพเนนต์ที่อธิบายไว้ก่อนหน้านี้แล้ว ยังประกอบด้วย
- แอตทริบิวต์:
kindและcalls- ซึ่งจะให้ข้อมูลเพิ่มเติมเกี่ยวกับการดำเนินการที่กำลังทำอยู่ ในกรณีนี้คือการผสาน
- ตำแหน่งในหน่วยความจำ (ตัวระบุพื้นที่หน่วยความจำ):
S(1)- ซึ่งระบุพื้นที่หน่วยความจำ/ตำแหน่งที่จัดเก็บอาร์เรย์
S(1)ที่นี่หมายถึงอาร์เรย์นี้อยู่ใน VMEM (ใน TPU)
- ซึ่งระบุพื้นที่หน่วยความจำ/ตำแหน่งที่จัดเก็บอาร์เรย์
- รายละเอียดรูปร่างและเลย์เอาต์สำหรับอาร์กิวเมนต์อินพุต
%fusion.32
ส่วนต่อไปนี้จะอธิบายรูปร่าง เลย์เอาต์ และ ตัวระบุพื้นที่หน่วยความจำ ดูข้อมูลเพิ่มเติมเกี่ยวกับ การเรียงหน้าต่างได้ในเลย์เอาต์แบบเรียง
รูปร่าง
XLA ShapeProto proto
(xla_data.proto)
อธิบายจำนวนมิติข้อมูล ขนาด และประเภทข้อมูลของอาร์เรย์ N มิติ (เรียกสั้นๆ ว่าอาร์เรย์)
คำศัพท์ สัญกรณ์ และแบบแผน
จำนวนมิติข้อมูลจริงของอาร์เรย์คือจำนวนมิติข้อมูลที่มีขนาดมากกว่า 1
มิติข้อมูลจะมีหมายเลขตั้งแต่
0ถึงN-1สำหรับอาร์เรย์Nมิติ ขนาดของมิติข้อมูลต้องเป็นจำนวนเต็มที่ไม่ติดลบ โดยเฉพาะขนาด 0 จะ ใช้ได้ หมายเลขมิติข้อมูลเป็นป้ายกำกับที่กำหนดขึ้นเพื่อความสะดวก ลำดับของหมายเลขมิติข้อมูลเหล่านี้ไม่ได้หมายถึงการเรียงลำดับย่อย/หลัก ที่เฉพาะเจาะจงในเลย์เอาต์ของรูปร่าง เลย์เอาต์จะกำหนดโดยLayoutProtoโปรโตตามธรรมเนียมแล้ว มิติข้อมูลจะแสดงตามลำดับที่เพิ่มขึ้นของหมายเลขมิติข้อมูล เช่น สำหรับอาร์เรย์ 3 มิติที่มีขนาด
[A x B x C]มิติที่ 0 มีขนาดAมิติที่ 1 มีขนาดBและมิติที่ 2 มีขนาดCยูทิลิตีบางอย่างใน XLA ยังรองรับการจัดทำดัชนีเชิงลบแบบ Python ด้วย โดยมิติข้อมูล -1 คือมิติข้อมูลสุดท้าย (เทียบเท่ากับ
N-1สำหรับอาร์เรย์Nมิติ) เช่น สำหรับอาร์เรย์ 3 มิติที่อธิบายไว้ข้างต้น มิติข้อมูล -1 มีขนาดCมิติข้อมูล -2 มีขนาดBและอื่นๆอาร์เรย์ 2, 3 และ 4 มิติมักมีตัวอักษรเฉพาะ ที่เชื่อมโยงกับมิติข้อมูล เช่น สำหรับอาร์เรย์ 2 มิติ
- มิติข้อมูล 0:
y - มิติข้อมูล 1:
x
สำหรับอาร์เรย์ 3 มิติ ให้ทำดังนี้
- มิติข้อมูล 0:
z - มิติข้อมูล 1:
y - มิติข้อมูล 2:
x
สำหรับอาร์เรย์ 4 มิติ
- มิติข้อมูล 0:
p - มิติข้อมูล 1:
z - มิติข้อมูล 2:
y - มิติข้อมูล 3:
x
- มิติข้อมูล 0:
ฟังก์ชันใน XLA API ที่ใช้มิติข้อมูลจะเรียงตามลำดับที่เพิ่มขึ้นของ หมายเลขมิติข้อมูล ซึ่งตรงกับการจัดลำดับที่ใช้เมื่อส่งมิติข้อมูลเป็น
initializer_listเช่นShapeUtil::MakeShape(F32, {A, B, C, D})จะสร้างรูปร่างที่มีอาร์เรย์ขนาดมิติข้อมูลซึ่งประกอบด้วยลำดับ
[A, B, C, D]
เลย์เอาต์
LayoutProto proto อธิบายวิธีแสดงอาร์เรย์ในหน่วยความจำ โดยจะ
มีฟิลด์ต่อไปนี้
message LayoutProto {
repeated int64 minor_to_major;
int64 tail_padding_alignment_in_elements;
...
}
การจัดเรียงมิติข้อมูลจากเล็กไปใหญ่
ช่องที่ต้องกรอกมีเพียงminor_to_major ฟิลด์นี้อธิบาย
ลำดับจากน้อยไปมากของมิติข้อมูลภายในรูปร่าง ค่าใน
minor_to_major คือการจัดลำดับมิติข้อมูลของอาร์เรย์ (0 ถึง N-1
สำหรับอาร์เรย์ N มิติ) โดยค่าแรกคือมิติข้อมูลที่เล็กที่สุด
จนถึงค่าสุดท้ายซึ่งเป็นมิติข้อมูลที่ใหญ่ที่สุด มิติข้อมูลที่เล็กที่สุด
คือมิติข้อมูลที่เปลี่ยนแปลงเร็วที่สุดเมื่อเลื่อนดู
องค์ประกอบของอาร์เรย์ที่จัดวางในหน่วยความจำเชิงเส้น
ตัวอย่างเช่น ลองพิจารณาอาร์เรย์ 2 มิติต่อไปนี้ที่มีขนาด [2 x 3]
a b c
d e f
ในที่นี้ มิติข้อมูล 0 คือขนาด 2 และมิติข้อมูล 1 คือขนาด 3 หากminor_to_majorฟิลด์ในเลย์เอาต์เป็น [0, 1] มิติข้อมูล 0 จะเป็นมิติข้อมูลที่เล็กที่สุด และมิติข้อมูล 1 จะเป็นมิติข้อมูลที่ใหญ่ที่สุด ซึ่ง
สอดคล้องกับเลย์เอาต์ต่อไปนี้ในหน่วยความจำเชิงเส้น
a d b e c f
ลำดับมิติข้อมูลจากเล็กไปใหญ่ของ 0 ถึง N-1 จะคล้ายกับคอลัมน์หลัก
(สำหรับ 2 มิติ) สมมติว่ามิติข้อมูลมีการจัดเรียงแบบ Monotonic อีกวิธีหนึ่ง
ที่เราอาจอ้างอิงเลย์เอาต์นี้ในโค้ดก็คือ "dim 0 เป็นแบบย่อย"
ในทางกลับกัน หากminor_to_majorฟิลด์ในเลย์เอาต์เป็น [1, 0] เลย์เอาต์ในหน่วยความจำเชิงเส้นจะเป็นดังนี้
a b c d e f
ลำดับมิติข้อมูลจากเล็กไปใหญ่ของ N-1 ถึง 0 สำหรับอาร์เรย์ N มิติ
จะคล้ายกับ แถวหลัก (สำหรับ 2 มิติ) หากสมมติว่ามิติข้อมูลมีการจัดเรียงแบบโมโนโทน
อีกวิธีที่เราอาจอ้างอิงถึงเลย์เอาต์นี้ในโค้ดก็คือ
เพียงแค่ "dim 0 เป็นหลัก"
การเรียงลำดับจากรุ่นย่อยไปรุ่นหลักเริ่มต้น
เลย์เอาต์เริ่มต้นสำหรับรูปร่างที่สร้างใหม่คือ "ลำดับมิติข้อมูลคือ
จากมิติข้อมูลหลักไปยังมิติข้อมูลรอง" (เช่น [N-1, ..., 0])
Padding
ฟิลด์ tail_padding_alignment_in_elements จะกำหนดการจัดแนวของอาร์เรย์ tiled ในแง่ของจำนวนองค์ประกอบ หลังจาก
ใช้การปูกระเบื้องแล้ว ระบบจะเพิ่มองค์ประกอบที่มีการเว้นวรรคที่ส่วนท้ายของเลย์เอาต์จนกว่า
จำนวนองค์ประกอบทั้งหมดจะเป็นค่าที่คูณด้วยค่านี้
การจัดทำดัชนีในอาร์เรย์
คลาส IndexUtil ใน
index_util.h
มีเครื่องมือสำหรับแปลงระหว่างดัชนีหลายมิติและดัชนีเชิงเส้น
เมื่อกำหนดรูปร่างและเลย์เอาต์ ดัชนีหลายมิติประกอบด้วยint64
ดัชนีสำหรับแต่ละมิติข้อมูล ดัชนีเชิงเส้นคือค่า int64 ค่าเดียวซึ่ง
จัดทำดัชนีลงในบัฟเฟอร์ที่เก็บอาร์เรย์ ดู shape_util.h และ
layout_util.h ในไดเรกทอรีเดียวกันเพื่อดูยูทิลิตีที่จะช่วยให้การสร้างและการ
ปรับแต่งรูปร่างและเลย์เอาต์เป็นเรื่องง่าย
ตัวระบุพื้นที่ความทรงจำ
ใน HLO คุณอาจใส่คำอธิบายประกอบอาร์เรย์แต่ละรายการด้วยตัวระบุพื้นที่หน่วยความจำ ซึ่งเขียนเป็น S(n)
S(0)(มักจะไม่มี) หมายถึงหน่วยความจำแบนด์วิดท์สูง (HBM) ของอุปกรณ์S(1)แสดงถึงหน่วยความจำเสมือน (VMEM) ในอุปกรณ์S(2),S(3)ฯลฯ สอดคล้องกับพื้นที่หน่วยความจำเพิ่มเติมที่เฉพาะเจาะจงของอุปกรณ์S(5)แสดงหน่วยความจำของโฮสต์