تمثيل التقسيم إلى أجزاء

الخلفية

الغرض من تمثيل التجزئة هو تحديد كيفية تقسيم مصفوفة كثيفة بالاستناد إلى مجموعة من الأجهزة المتاحة.

يمكن أن يكون تمثيل التجزئة إما:

  • يحدّدها المستخدم يدويًا على أنّها قيود تقسيم على المدخلات أو المخرجات أو الوسائط.
  • يتم تحويلها لكل عملية في عملية نشر التجزئة.

نظرة عامة

البنية الأساسية

الشبكة المنطقية هي عرض متعدد الأبعاد للأجهزة، ويتم تحديدها من خلال قائمة بأسماء محور وأحجامه.

يكون تمثيل التجزئة المقترَح مرتبطًا بشبكة منطقية معيّنة من خلال اسمها، ولا يمكنه الإشارة إلا إلى أسماء المحاور من تلك الشبكة. تُحدِّد عملية تقسيم مصفوفة كثيفة المحاور (لشبكة منطقية معيّنة) التي يتم تقسيم كل سمة من سمات المصفوفة الكثيفة على طولها، مع ترتيبها من الأكبر إلى الأصغر. يتم تكرار المتجه على طول جميع المحاور الأخرى للشبكة.

لنستكشف تمثيل التجزئة باستخدام مصفوفة بسيطة من الترتيب 2 و4 أجهزة.

نعيد أولاً تشكيل الأجهزة الأربعة [0, 1, 2, 3] إلى صفيف ثنائي الأبعاد [[0, 1], [2, 3]] لإنشاء شبكة تتضمّن محورَين:

@mesh_xy = <["x"=2, "y"=2]>

يمكننا بعد ذلك تقسيم مصفوفة [[a, b], [c, d]] من الترتيب 2 على النحو التالي:

تمثيل التجزئة لمتجه من الترتيب 2

المكوّنات الرئيسية الأخرى

  • السمات المفتوحة/المغلقة: يمكن أن تكون السمات إما مفتوحة - يمكن تقسيمها إلى شرائح إضافية على المحاور المتاحة، أو مغلقة - تكون ثابتة ولا يمكن تغييرها.
  • الأعمدة التي تتم إعادة نسخها بشكل صريح: تتم إعادة نسخ جميع الأبعاد التي لا تُستخدَم لتقسيم سمة بشكل ضمني، ولكن يمكن أن يحدِّد التقسيم الأبعاد التي تتم إعادة نسخها بشكل صريح، وبالتالي لا يمكن استخدامها لتقسيم سمة لاحقًا.
  • تقسيم المحاور والمحاور الفرعية: يمكن تقسيم محور الشبكة (الكامل) إلى محاور فرعية متعددة يمكن استخدامها بشكلٍ فردي لتقسيم سمة أو تكرارها بشكلٍ صريح.
  • شبكات منطقية متعدّدة: يمكن ربط مجموعات sharding مختلفة بشبكات منطقية مختلفة، والتي يمكن أن تتضمّن محاور مختلفة أو حتى ترتيبًا مختلفًا لأرقام تعريف الأجهزة المنطقية.
  • الأولويات: لتقسيم برنامج بشكل تدريجي، يمكن إرفاق الأولويات بتقسيم السمات، ما يحدّد الترتيب الذي سيتم فيه نشر قيود تقسيم السمات لكل سمة في الوحدة.
  • قابلية تقسيم الأبعاد: يمكن تقسيم قياس باستخدام محاور لا يقسم ناتج أحجامها قياس.

التصميم التفصيلي

نوسّع البنية الأساسية وكل مكوّن رئيسي في هذا القسم.

البنية الأساسية

تُعلمنا تقسيمات السمات لكل سمة من سمات المصفوفة، والتي يتم تقسيمها على طول محورين (أو محورَين فرعيَّين) من الرئيسي إلى الفرعي. تتمّ إعادة إنشاء جميع المحاور الأخرى التي لا تقسم سمة بشكل ضمني (أو بشكل صريح).

سنبدأ بمثال بسيط ونوسّعه أثناء وصفنا لمزيد من الميزات.

@mesh_xy = <["x"=2, "y"=4, "z"=2]>

// The 1st tensor dimension is sharded along axis "x" and the 2nd tensor dimension is
// sharded along axis "z" then further along axis "y". The local shape of this tensor (i.e. the shape on a single device), would be tensor<2x1xf32>.
sharding<@mesh_xy, [{"x"}, {"z", "y"}]> : tensor<4x8xf32>

القيم الثابتة

  • يجب أن يتطابق عدد تقسيمات السمات مع ترتيب المصفوفة.
  • يجب أن تكون جميع أسماء المحاور متوفّرة في الشبكة المُشار إليها.
  • لا يمكن أن تظهر المحاور أو المحاور الفرعية إلا مرة واحدة في تمثيل التجزئة (يؤدي كل محور إلى تجزئة سمة أو يتم تكرار محوره صراحةً).

السمات المفتوحة/المغلقة

يمكن أن تكون كل سمة من سمات مصفوفة تانسور مفتوحة أو مغلقة.

فتح

تكون السمة المفتوحة متاحة للنشر لتقسيمها بشكل أكبر على طول محورين إضافيين، أي أنّه ليس من الضروري أن يكون تقسيم السمة المحدّد هو التقسيم النهائي لسمة معيّنة. يشبه ذلك (ولكن ليس مطابقًا تمامًا)

إذا كانت السمة مفتوحة، نضيف ? بعد المحاور التي تم تقسيم السمة حسبها (راجِع المثال أدناه).

النشاط التجاري مغلق

السمة المغلقة هي سمة غير متاحة للنشر لإضافة المزيد من التحليل إلى تحليلها، أي أنّ تقسيم السمة المحدّد هو التقسيم النهائي لهذه السمة ولا يمكن تغييره. ومن حالات الاستخدام الشائعة لهذا الإجراء، الطريقة التي لا تعدّل بها GSPMD (عادةً) وسيطات الإدخال/الإخراج الخاصة بالوحدة، أو الطريقة التي تجعل فيها in_shardings التي حدّدها المستخدم ثابتة باستخدام jax.jit، إذ لا يمكن تغييرها.

يمكننا توسيع المثال أعلاه للحصول على سمة مفتوحة وسمة مغلقة.

@mesh_xy = <["x"=2, "y"=4, "z"=2]>

// The 1st dimension is closed, therefore it can't be further sharded and {"x"}
// will remain its sharding. The 2nd dimension is open, and can therefore be
// further sharded during propagation, e.g. by "y".
sharding<@mesh_xy, [{"x"}, {"z", ?}]> : tensor<4x8xf32>

المحاور التي تم تكرارها صراحةً

مجموعة صريحة من المحاور التي يتم تكرار مصفوفة تينسور عليها على الرغم من أنّه يمكن تحديد أنّ مصفوفة لاتّصال لا يتم تقسيمها على محور يتم نسخها ضمنيًا عليه (مثل jax.sharding.PartitionSpec اليوم)، فإنّ تحديدها بشكل صريح يضمن أنّ الانتشار لا يمكنه استخدام هذه المحاور لمزيد من تقسيم سمة مفتوحة باستخدام هذه المحاور. من خلال النسخ التلقائي، يمكن تقسيم ملف ملف tensor بشكل أكبر. ولكن مع التكرار الصريح، لا يمكن لأيّ عنصر تقسيم المصفوفة على طول هذا المحور.

لا يؤثّر ترتيب المحاور المكرّرة في طريقة تخزين بيانات مصفوفة تينسور. ولكن من أجل الاتساق فقط، سيتم تخزين المحاور بالترتيب الذي يتم فيه تحديدها في شبكة المستوى الأعلى. على سبيل المثال، إذا كانت الشبكة:

@mesh_xy = <["c"=2, "a"=2, "b"=2]>

ونريد تكرار محورَي "a" و"c" بشكل صريح، ويجب أن يكون الترتيب على النحو التالي:

replicated={"c", "a"}

يمكننا توسيع مثالنا أعلاه للحصول على محور مكرّر بشكل صريح.

@mesh_xyz = <["x"=2, "y"=4, "z"=2]>

// Since "y" is explicitly replicated, it can't be used to shard the 2nd
// dimension that is open. However, "z" is implicitly replicated so it can be
// used to shard that dimension. The local shape of this tensor (i.e. the
// shape on a single device), would // be tensor<2x8xf32>.
sharding<@mesh_xyz, [{"x"}, {?}], replicated={"y"}> : tensor<4x8xf32>

تقسيم المحاور والمحاور الفرعية

يتم إنشاء شبكة منطقية من n محور من خلال إعادة تشكيل صفيف أحادي الأبعاد من الأجهزة إلى صفيف ذي n أبعاد، حيث يشكّل كلّ سمة محورًا باسم يحدّده المستخدم.

يمكن تنفيذ العملية نفسها في المُجمِّع لتقسيم محور بحجم k إلى m محورًا فرعيًا، وذلك من خلال إعادة تشكيل الشبكة من [...,k,...] إلى [...,k1,...,km,...].

الحافز

لفهم السبب وراء تقسيم المحاور، سنلقي نظرة على المثال التالي:

@mesh_x = <["x"=4]>

%arg0 : tensor<8xf32> {sdy.sharding=<@mesh_x, [{"x"}]>}
%0 = reshape %arg0 : (tensor<8xf32>) -> tensor<2x4xf32>

نريد تقسيم نتيجة إعادة التنسيق بطريقة تتجنّب التواصل (أي إبقاء البيانات في مكانها). بما أنّ حجم "x" هو أكبر من السمة الأولى للنتيجة، علينا تقسيم المحور إلى محورَين فرعيَّين "x.0" و"x.1" بحجم 2 لكل منهما، وتقسيم السمة الأولى على "x.0" والسمة الثانية على "x.1".

تقسيم إدخال/إخراج الدالة

من الممكن أن تصبح مدخلات أو مخرجات الدالة الرئيسية مجزّأة على طول محور فرعي أثناء النشر. يمكن أن يشكّل ذلك مشكلة لبعض الأطر، حيث لا يمكننا التعبير عن هذه التقسيمات لعرضها على المستخدم (على سبيل المثال، في JAX، لا يمكننا التعبير عن المحاور الفرعية باستخدام jax.sharding.NamedSharding).

لدينا بعض الخيارات للتعامل مع هذه الحالات:

  • السماح بتقسيم البيانات وعرضها بتنسيق مختلف (مثل jax.sharding.PositionalSharding بدلاً من jax.sharding.NamedSharding في JAX)
  • لا تسمح بالأعمدة الفرعية التي تجمع كل الإدخالات أو المخرجات.

نسمح حاليًا باستخدام المحاور الفرعية في الإدخالات/المخرجات في مسار النشر. يُرجى إعلامنا إذا أردت طريقة لإيقاف هذه الميزة.

التمثيل

بالطريقة نفسها التي يمكننا بها الإشارة إلى محاور كاملة معيّنة من الشبكة باستخدام اسمها، يمكننا الإشارة إلى محاور فرعية معيّنة حسب حجمها ومنتج جميع أحجام المحور الفرعي (الذي يحمل اسم المحور نفسه) على يمينها (أي المحاور الرئيسية بالنسبة إليها) .

لاستخراج محور فرعي محدّد بحجم k من محور كامل "x" بحجم n، نعيد تشكيل الحجم n (في الشبكة) بفعالية إلى [m, k, n/(m*k)] ونستخدم السمة الثانية كمحور فرعي. وبالتالي، يمكن تحديد محور فرعي باستخدام رقمين، m وk، ونستخدم الرمز الموجز التالي للإشارة إلى المحاور الفرعية: "x":(m)k.

  • m>=1 هو الحجم المُسبَق لهذا المحور الفرعي (يجب أن يكونm مقسومًا عليه n). الحجم المُسبَق هو حاصل ضرب جميع أحجام المحاور الفرعية على يمين هذا المحور الفرعي (إذا كان يساوي 1، يعني ذلك أنّه لا يتوفّر أي محور فرعي، وإذا كان أكبر من 1، يعني ذلك أنّه يتطابق مع محور فرعي واحد أو أكثر).

  • k>1 هو الحجم الفعلي لهذا المحور الفرعي (يجب أن يكون k مقسومًا عليه n).

  • n/(m*k) هو مقاس المشاركة. وهو حاصل ضرب جميع أحجام المحاور الفرعية على يسار هذا المحور الفرعي (إذا كان يساوي 1، يعني ذلك أنّه لا يوجد أي محور فرعي، وإذا كان أكبر من 1، يعني ذلك أنّه يتوافق مع محور فرعي واحد أو عدّة محاور فرعية).

ومع ذلك، لا يُحدث عدد المحاور الفرعية الأخرى فرقًا عند استخدام محور فرعي محدّد "x":(m)k، ولا يلزم الإشارة إلى أي محور فرعي آخر في عملية تقسيم مصفوفة تينسور إذا لم يكن يقسّم سمة أو يتم نسخه صراحةً.

بالعودة إلى المثال في قسم "الدافع"، يمكننا تقسيم النتيجة على النحو التالي:

@mesh_x = <["x"=4]>

%arg0 : tensor<8xf32> {sdy.sharding=<@mesh_x, [{"x"}]>}
%0 = reshape %arg0 {sdy.sharding_per_value=<[<@mesh_x, [{"x":(1)2}, {"x":(2)2}]>]>}
    : (tensor<8xf32>) -> tensor<2x4xf32>

في ما يلي مثال آخر على محور مُقسَّم يتم فيه استخدام بعض محاوره الفرعية فقط.

@mesh_xyz = <["x"=2, "y"=8, "z"=2]>

// Axis "y" is effectively split into 3 sub-axes denoted as
//   "y":(1)2, "y":(2)2, "y":(4)2
// in order, but only "y":(2)2 is used, to shard the 2nd dimension. The local
// shape of this tensor (i.e. the shape on a single device), would be
// tensor<2x4xf32>.
sharding<@mesh_xyz, [{"x"}, {"y":(2)2}]> : tensor<4x8xf32>

وبالمثل، فإنّ القسمتَين التاليتَين متساويتان من الناحية الدلالية. يمكننا اعتبار mesh_xy عملية تقسيم mesh_full.

@mesh_full = <"devices"=8>
@mesh_xy = <"x"=4, "y"=2>

sharding<@mesh_xy, [{"x"},{ "y"}]> : tensor<4x4xf32>
sharding<@mesh_full, [{"devices":(1)4}, {"devices":(4)2}]> : tensor<4x4xf32>

المحاور الفرعية التي تم تكرارها صراحةً

بالإضافة إلى المحاور الفرعية المستخدَمة لتقسيم السمة، يمكن أيضًا وضع علامة عليها بأنّها مكرّرة صراحةً. نسمح بذلك في التمثيل لأنّ المحاور الفرعية تتصرّف تمامًا مثل المحاور الكاملة، أي عند تقسيم سمة على طول محور فرعي من محور "x"، تتمّ تكرار المحاور الفرعية الأخرى للمحور "x" بشكل ضمني، وبالتالي يمكن تكرارها بشكل صريح للإشارة إلى أنّه يجب أن يظلّ محور فرعي مكرّرًا ولا يمكن استخدامه لتقسيم سمة.

على سبيل المثال:

@mesh_xyz = <["x"=2, "y"=8, "z"=2]>

// Sub-axis "y":(1)2 is explicitly replicated and "y":(4)2 is implicitly replicated.
sharding<@mesh_xyz, [{"x"}, {"y":(2)2}], replicated={"y":(1)2}> : tensor<4x8xf32>

يجب ترتيب المحور الفرعي المكرّر للمحور الكامل نفسه بترتيب تصاعدي حسب الحجم المحدّد مسبقًا، على سبيل المثال:

replicated={"y":(4)2, "x", "y":(1)2} ~> replicated={"x", "y":(1)2, "y":(4)2}

القيم الثابتة

  • يجب ألا تتداخل المحاور الفرعية المُشار إليها في تقسيم مصفوفة، على سبيل المثال، يتداخل "x":(1)4 و"x":(2)4.

  • يجب أن تكون المحاور الفرعية المُشار إليها في تقسيم مصفوفة تانسور أكبر قدر ممكن، أي إذا كان تقسيم السمة يتضمّن محورَين فرعيَّين متجاورَين أ و ب بالترتيب، أو تم تكرار المحاور الفرعية أ و ب بشكل صريح، يجب ألّا تكون متتالية، على سبيل المثال، "x":(1)2 و"x":(2)4 لأنّه يمكن استبدالهما ب"x":(1)8 واحد.

شبكات منطقية متعددة

الشبكة المنطقية هي عرض متعدد الأبعاد للأجهزة. قد نحتاج إلى عدة طرق لعرض الأجهزة لتمثيل تقسيماتنا، خاصةً عند تحديد الأجهزة بشكل عشوائي.

على سبيل المثال، jax.sharding.PositionalSharding لا يحتوي على شبكة منطقية مشتركة واحدة. تتيح ميزة GSPMD حاليًا ذلك باستخدام HloSharding، حيث يمكن أن يكون التمثيل قائمة مرتبة للأجهزة وأحجام السمات، ولكن لا يمكن تمثيل ذلك باستخدام تقسيم المحور أعلاه.

نتغلب على هذا القيد ونتعامل مع الحالات الشاذة الحالية من خلال تحديد شبكات منطقية متعددة في أعلى مستوى من البرنامج. يمكن أن تحتوي كل شبكة على عدد مختلف من المحاور بأسماء مختلفة، بالإضافة إلى تحديد عشوائي لها للمجموعة نفسها من الأجهزة، أي أنّ كل شبكة تشير إلى المجموعة نفسها من الأجهزة (حسب معرّفها المنطقي الفريد) ولكن بترتيب عشوائي، على غرار تمثيل GSPMD.

يرتبط كل تمثيل لتقسيم البيانات بشبكة منطقية معيّنة، وبالتالي لن يشير سوى إلى المحاور من تلك الشبكة.

يمكن أن يستخدم إجراء تم تعيينه لشبكة منطقية واحدة مصفوفة تم تعيينها لشبكة مختلفة، وذلك من خلال إعادة تقسيم المصفوفة بشكل عفوي لمطابقة الشبكة الوجهة. في GSPMD، يتم عادةً إجراء ذلك لحلّ تعارضات الشبكات.

في ما يلي مثالان:

يمكن للمستخدمين تحديد شبكات متعددة باستخدام محاور مُسمّاة مختلفة (مثلاً من خلال jax.sharding.NamedSharding)، والتي تتضمّن ترتيب الأجهزة نفسه. في هذا المثال، <@mesh_0, "b"> متطابقة مع <@mesh_1, "z">..

@mesh_0 = {<["a"=4, "b"=2]>, device_ids=[0, 1, 2, 3, 4, 5, 6, 7]}
@mesh_1 = {<["x"=2, "y"=2, "z"=2]>, device_ids=[0, 1, 2, 3, 4, 5, 6, 7]}

الأولويات

الأولوية هي طريقة لمنح الأولوية لقرارات تقسيم ونشر معيّنة على غيرها، كما تسمح بتقسيم برنامج بشكل تدريجي.

الأولويات هي قيم مرتبطة ببعض أو كل سمات تمثيل التجزئة (لا تتضمّن المحاور المكرّرة أولويات).

على سبيل المثال:

@mesh_xy = <["w"=6, "x"=2, "y"=4, "z"=2]>

//                                    |-> y is implicitly p0
%arg4 : sharding<@mesh_xy, [{"x"}p1, {"y"}, {"z",?}p2], replicated={} }>

تمنح الأولويات للمستخدمين إمكانية التحكّم بشكل أدق في النشر، مثل توازُم الدفعات أولاً، ثمّ توازُم megatron، وأخيراً تجزئة ZeRO. يتيح ذلك ضمانات قوية بشأن ما يتم تقسيمه ويسمح بإمكانية تصحيح الأخطاء بشكل أفضل من خلال توفُّر استراتيجيات تقسيم أكثر دقة (يمكنك الاطّلاع على شكل البرنامج بعد استخدام Megatron فقط بشكل منفصل).

نسمح بإرفاق الأولوية بكل عملية تقسيم سمة (0 تلقائيًا)، ما يشير إلى أنّه سيتم نشر جميع عمليات التقسيم ذات الأولوية <i على البرنامج بالكامل قبل عمليات التقسيم ذات الأولوية i.

حتى إذا كانت شريحة البيانات تتضمّن سمة مفتوحة ذات أولوية أقل، مثلاً {"z",?}p2، لن يتم إلغاء تقسيم مصفوفة آخر بأولوية أعلى أثناء الانتشار. ومع ذلك، يمكن تقسيم هذه السمة المفتوحة إلى أقسام أخرى بعد نشر كل القسمات التي لها أولوية أعلى.

بعبارة أخرى، لا تتعلق الأولويات بتحديد عملية تقسيم السمات التي هي أكثر أهمية من غيرها، بل هي الترتيب الذي يجب أن تنتشر به مجموعات مختلفة من عمليات تقسيم السمات إلى البرنامج بأكمله، وكيفية حلّ التعارضات في المتسلسلات غير المُشارَك عليها.

القيم الثابتة

  • تبدأ الأولويات من 0 (أعلى أولوية) وتزداد (للسماح للمستخدمين بإضافة وإزالة الأولويات بسهولة، نسمح بفواصل بين الأولويات، على سبيل المثال، يتم استخدام p0 و p2 ولكن لا يتم استخدام p1).

  • تقسيم سمة مغلقة فارغة (أي {})، يجب ألا يكون لها أولوية، لأنّ ذلك لن يؤثر في أيّ شيء.

قابلية تقسيم شرائح السمات

من الممكن تقسيم سمة بحجم d على محاور يكون حاصل ضرب أحجامها هو n، بحيث لا يكون d قابلاً للقسمة على n (ما قد يتطلّب في الممارسة ملء السمة).

على سبيل المثال:

@mesh_xy = <["x"=8, "y"=2, "z"=3]>

sharding<@mesh_xy, [{"x"}, {"y"}, {"z"}]> : tensor<7x3x8xf32>

القواعد اللغوية

يتم تعريف كل شبكة منطقية على النحو التالي:

@mesh_name = <mesh_axis_1,...,mesh_axis_n>

mesh_axis ::= axis_name=axis_size

axis_name ::= str
axis_size ::= int

سيكون تمثيل التجزئة بالبنية التالية لمكثف رتبته r:

sharding<@mesh_name, dim_shardings, replicated=replicated_axes}

mesh_name ::= str
dim_shardings ::= [dim_sharding_1,...,dim_sharding_r]
replicated_axes ::= {axis_1,...,axis_m}

dim_sharding ::=
  {axis_1,...,axis_k} |  // closed dimension
  {axis_1,...,axis_k,?}  // open dimension

axis ::=
  axis_name  |   // a full axis
  sub_axis             // a sub axis

axis_name ::= str

sub_axis ::= axis_name:(pre_size)size
pre_size ::= int
size ::= int