Представление сегментирования

Фон

Цель представления сегментирования — указать, как сегментируется тензор относительно набора доступных устройств.

Представление шардинга может быть:

  • Вручную указывается пользователем как ограничения сегментирования на входах, выходах или промежуточных звеньях.
  • Трансформируется за операцию в процессе распространения шардинга.

Обзор

Базовая структура

Логическая сетка — это многомерное представление устройств, определяемое списком имен и размеров осей.

Предлагаемое представление сегментирования привязано к определенной логической сетке по ее имени и может ссылаться только на имена осей из этой сетки. Шардинг тензора определяет, по каким осям (конкретной логической сетки) сегментируется каждое измерение тензора, в порядке от большего к меньшему. Тензор копируется вдоль всех остальных осей сетки.

Давайте рассмотрим представление сегментирования с помощью простого тензора ранга 2 и устройств 4.

Сначала мы преобразуем 4 устройства [0, 1, 2, 3] в двумерный массив [[0, 1], [2, 3]] чтобы создать сетку с двумя осями:

@mesh_xy = <["x"=2, "y"=2]>

Затем мы можем сегментировать следующий тензор ранга 2 [[a, b], [c, d]] следующим образом:

Шардинговое представление тензора 2-го ранга

Другие ключевые компоненты

  • Открытые/закрытые измерения — размеры могут быть открытыми — их можно дополнительно сегментировать по доступным осям; или закрытые – фиксированы и не могут быть изменены.
  • Явно реплицированные оси — все оси, которые не используются для сегментирования измерения, реплицируются неявно, но при сегментировании могут быть указаны оси, которые реплицируются явно и поэтому не могут быть использованы для сегментирования измерения в дальнейшем.
  • Разделение осей и подоси — (полную) ось сетки можно разделить на несколько подосей, которые можно индивидуально использовать для сегментирования измерения или явно реплицировать.
  • Несколько логических сеток — разные сегменты могут быть привязаны к разным логическим сеткам, которые могут иметь разные оси или даже разный порядок идентификаторов логических устройств.
  • Приоритеты — для поэтапного разделения программы приоритеты можно прикрепить к сегментам измерений, которые определяют, в каком порядке ограничения сегментирования каждого измерения будут распространяться по всему модулю.
  • Делимость измерения : измерение может быть сегментировано по осям, произведение размеров которых не делит размер измерения.

Детальный проект

В этом разделе мы раскрываем базовую структуру и каждый ключевой компонент.

Базовая структура

Шардинги измерений сообщают нам для каждого измерения тензора, по каким осям (или подосям ) он сегментируется от главного к второстепенному. Все остальные оси, которые не сегментируют измерение, реплицируются неявно (или реплицируются явно ).

Мы начнем с простого примера и будем расширять его по мере описания дополнительных функций.

@mesh_xy = <["x"=2, "y"=4, "z"=2]>

// The 1st tensor dimension is sharded along axis "x" and the 2nd tensor dimension is
// sharded along axis "z" then further along axis "y". The local shape of this tensor (i.e. the shape on a single device), would be tensor<2x1xf32>.
sharding<@mesh_xy, [{"x"}, {"z", "y"}]> : tensor<4x8xf32>

Инварианты

  • Количество сегментов измерения должно соответствовать рангу тензора.
  • Все имена осей должны существовать в сетке, на которую ссылаются.
  • Оси или подоси могут появляться в представлении сегментирования только один раз (каждая из них либо сегментирует измерение, либо явно реплицируется ).

Открытые/закрытые размеры

Каждое измерение тензора может быть открытым или закрытым.

Открыть

Открытое измерение открыто для распространения с целью дальнейшего сегментирования его по дополнительным осям, т. е. указанное сегментирование измерения не обязательно должно быть окончательным сегментированием этого измерения. Это похоже (но не совсем то же самое) на

Если измерение открыто, мы добавляем ? по осям, по которым уже сегментировано измерение (см. пример ниже).

Закрыто

Закрытое измерение — это измерение, которое недоступно для распространения с целью добавления дальнейшего сегментирования, т. е. указанное сегментирование измерения является окончательным сегментированием этого измерения, и его нельзя изменить. Распространенным примером использования этого является то, что GSPMD (обычно) не изменяет аргументы ввода/вывода модуля или как с помощью jax.jit пользователь, указанный in_shardings , является статичным - они не могут измениться.

Мы можем расширить приведенный выше пример, добавив в него открытое и закрытое измерение.

@mesh_xy = <["x"=2, "y"=4, "z"=2]>

// The 1st dimension is closed, therefore it can't be further sharded and {"x"}
// will remain its sharding. The 2nd dimension is open, and can therefore be
// further sharded during propagation, e.g. by "y".
sharding<@mesh_xy, [{"x"}, {"z", ?}]> : tensor<4x8xf32>

Явно реплицированные оси

Явный набор осей, на которых реплицируется тензор. Хотя можно определить, что тензор, не сегментированный по оси, неявно реплицируется на нем (как сегодня jax.sharding.PartitionSpec ), его явное использование гарантирует, что распространение не сможет использовать эти оси для дальнейшего сегментирования открытого измерения с помощью этих осей. При неявной репликации тензор можно дополнительно разбить. Но при явной репликации ничто не может разделить тензор по этой оси.

Порядок реплицируемых осей не влияет на то, как хранятся данные тензора. Но исключительно для обеспечения единообразия оси будут храниться в том порядке, в котором они указаны в сетке верхнего уровня. Например, если сетка:

@mesh_xy = <["c"=2, "a"=2, "b"=2]>

И мы хотим, чтобы оси "a" и "c" были явно реплицированы, порядок должен быть таким:

replicated={"c", "a"}

Мы можем расширить приведенный выше пример, чтобы иметь явно реплицированную ось.

@mesh_xyz = <["x"=2, "y"=4, "z"=2]>

// Since "y" is explicitly replicated, it can't be used to shard the 2nd
// dimension that is open. However, "z" is implicitly replicated so it can be
// used to shard that dimension. The local shape of this tensor (i.e. the
// shape on a single device), would // be tensor<2x8xf32>.
sharding<@mesh_xyz, [{"x"}, {?}], replicated={"y"}> : tensor<4x8xf32>

Разделение осей и подоси

Логическая сетка из n осей создается путем преобразования одномерного массива устройств в n-мерный массив, где каждое измерение образует ось с определяемым пользователем именем.

Тот же процесс можно выполнить в компиляторе, чтобы разбить ось размера k на m подосей, изменив форму сетки из [...,k,...] в [...,k1,...,km,...] .

Мотивация

Чтобы понять мотивацию разделения осей, мы рассмотрим следующий пример:

@mesh_x = <["x"=4]>

%arg0 : tensor<8xf32> {sdy.sharding=<@mesh_x, [{"x"}]>}
%0 = reshape %arg0 : (tensor<8xf32>) -> tensor<2x4xf32>

Мы хотим сегментировать результат изменения формы таким образом, чтобы избежать обмена данными (т. е. сохранить данные там, где они есть). Поскольку размер "x" больше первого измерения результата, нам нужно разделить ось на две подоси "x.0" и "x.1" размером 2 каждая и разделить первое измерение на "x.0" и 2-е измерение на "x.1" .

Функция Шардинг ввода/вывода

Вполне возможно, что во время распространения вход или выход основной функции будет сегментирован вдоль подоси. Это может быть проблемой для некоторых фреймворков, где мы не можем выразить такие сегменты для возврата пользователю (например, в JAX мы не можем выразить подоси с помощью jax.sharding.NamedSharding ).

У нас есть несколько вариантов решения таких случаев:

  • Разрешите и верните сегментирование в другом формате (например, jax.sharding.PositionalSharding вместо jax.sharding.NamedSharding в JAX).
  • Запретить и собрать все подоси, которые сегментируют ввод/вывод.

В настоящее время мы разрешаем подоси на входах/выходах в конвейере распространения. Дайте нам знать, если вам нужен способ отключить это.

Представительство

Точно так же, как мы можем ссылаться на определенные полные оси сетки по их имени, мы можем ссылаться на определенные подоси по их размеру и произведению размеров всех подосей (с тем же именем оси) слева от них (т. главное для них).

Чтобы извлечь конкретную подось размера k из полной оси "x" размера n , мы эффективно изменяем размер n (в сетке) на [m, k, n/(m*k)] и используем второй размер в качестве подоси. Таким образом, подось может быть задана двумя числами, m и k , и для обозначения подосей мы используем следующее краткое обозначение: "x":(m)k .

  • m>=1 — это предварительный размер этой подоси ( m должен быть делителем n ). Предварительный размер — это произведение всех размеров подоси слева от этой подоси (которые являются основными для нее) (если он равен 1, это означает, что их нет. Если больше 1, это соответствует одному или нескольким подосям). -оси).

  • k>1фактический размер этой подоси ( k должен быть делителем n ).

  • n/(m*k)размер сообщения . Это произведение всех размеров подоси справа от этой подоси (которые являются второстепенными по отношению к ней) (если равно 1, это означает, что их нет, если больше 1, это соответствует одной или нескольким подосям) .

Однако количество других подосей не имеет значения при использовании конкретной подоси "x":(m)k , и на любую другую подось не нужно ссылаться при сегментировании тензора, если она не сегментирует измерение и не реплицируется явно.

Возвращаясь к примеру из раздела «Мотивация» , мы можем сегментировать результат следующим образом:

@mesh_x = <["x"=4]>

%arg0 : tensor<8xf32> {sdy.sharding=<@mesh_x, [{"x"}]>}
%0 = reshape %arg0 {sdy.sharding_per_value=<[<@mesh_x, [{"x":(1)2}, {"x":(2)2}]>]>}
    : (tensor<8xf32>) -> tensor<2x4xf32>

Вот еще один пример разделенной оси, в которой используются только некоторые из ее подосей.

@mesh_xyz = <["x"=2, "y"=8, "z"=2]>

// Axis "y" is effectively split into 3 sub-axes denoted as
//   "y":(1)2, "y":(2)2, "y":(4)2
// in order, but only "y":(2)2 is used, to shard the 2nd dimension. The local
// shape of this tensor (i.e. the shape on a single device), would be
// tensor<2x4xf32>.
sharding<@mesh_xyz, [{"x"}, {"y":(2)2}]> : tensor<4x8xf32>

Аналогично, следующие два шардинга семантически эквивалентны. Мы можем думать о mesh_xy как о разделении mesh_full .

@mesh_full = <"devices"=8>
@mesh_xy = <"x"=4, "y"=2>

sharding<@mesh_xy, [{"x"},{ "y"}]> : tensor<4x4xf32>
sharding<@mesh_full, [{"devices":(1)4}, {"devices":(4)2}]> : tensor<4x4xf32>

Явно реплицированные подоси

Помимо подосей, используемых для сегментирования измерений, они также могут быть помечены как явно реплицированные. Мы допускаем это в представлении, потому что подоси ведут себя так же, как полные оси, т. е. когда вы сегментируете измерение вдоль подоси оси "x" , другие подоси "x" неявно копируются и, следовательно, могут быть явно реплицируется, чтобы указать, что подось должна оставаться реплицированной и не может использоваться для сегментирования измерения.

Например:

@mesh_xyz = <["x"=2, "y"=8, "z"=2]>

// Sub-axis "y":(1)2 is explicitly replicated and "y":(4)2 is implicitly replicated.
sharding<@mesh_xyz, [{"x"}, {"y":(2)2}], replicated={"y":(1)2}> : tensor<4x8xf32>

Реплицированные подоси одной и той же полной оси следует располагать в порядке возрастания их предварительного размера, например:

replicated={"y":(4)2, "x", "y":(1)2} ~> replicated={"x", "y":(1)2, "y":(4)2}

Инварианты

  • Подоси, на которые ссылаются в тензорном сегментировании, не должны перекрываться, например "x":(1)4 и "x":(2)4 перекрываются.

  • Подоси, на которые ссылаются в тензорном сегментировании, должны быть как можно больше, т. е. если сегментирование измерения имеет две соседние по порядку подоси A и B или подоси A и B явно реплицируются, они не должны быть последовательными, например "x":(1)2 и "x":(2)4 поскольку их можно заменить одним "x":(1)8 .

Несколько логических сеток

Одна логическая сетка представляет собой многомерное представление устройств. Нам может потребоваться несколько представлений устройств для представления наших сегментов, особенно для произвольных назначений устройств.

Например, jax.sharding.PositionalSharding не имеет одной общей логической сетки . GSPMD в настоящее время поддерживает это с помощью HloSharding, где представление может представлять собой упорядоченный список устройств и размеров измерений, но это невозможно представить с помощью разделения осей, описанного выше.

Мы преодолеваем это ограничение и обрабатываем существующие крайние случаи, определяя несколько логических сеток на верхнем уровне программы. Каждая сетка может иметь разное количество осей с разными именами, а также свое произвольное назначение для одного и того же набора устройств, т.е. каждая сетка относится к одному и тому же набору устройств (по их уникальному логическому идентификатору), но с произвольным порядком. аналогично представлению GSPMD.

Каждое представление сегментирования связано с определенной логической сеткой, поэтому оно будет ссылаться только на оси этой сетки.

Тензор, назначенный одной логической сетке, может использоваться операцией, назначенной другой сетке, путем наивного изменения тензора в соответствии с целевой сеткой. В GSPMD это обычно делается для разрешения конфликтующих сеток.

Ниже мы приводим два примера:

Пользователи могут указать несколько сеток с разными именованными осями (например, через jax.sharding.NamedSharding ), которые имеют одинаковый порядок устройств. В этом примере <@mesh_0, "b"> идентично <@mesh_1, "z">.

@mesh_0 = {<["a"=4, "b"=2]>, device_ids=[0, 1, 2, 3, 4, 5, 6, 7]}
@mesh_1 = {<["x"=2, "y"=2, "z"=2]>, device_ids=[0, 1, 2, 3, 4, 5, 6, 7]}

Приоритеты

Приоритет — это способ определения приоритета определенных решений по секционированию и распространению над другими, а также позволяет выполнять постепенное секционирование программы.

Приоритеты — это значения, прикрепленные к некоторым или всем измерениям представления сегментирования (реплицируемые оси не имеют приоритетов).

Например:

@mesh_xy = <["w"=6, "x"=2, "y"=4, "z"=2]>

//                                    |-> y is implicitly p0
%arg4 : sharding<@mesh_xy, [{"x"}p1, {"y"}, {"z",?}p2], replicated={} }>

Приоритеты дают пользователям более детальный контроль над распространением, например, сначала пакетный параллелизм, затем мегатрон и, наконец, сегментирование ZeRO. Это дает надежные гарантии относительно того, что секционировано, и обеспечивает лучшую отладку за счет более детальных стратегий сегментирования (можно увидеть, как программа работает только с мегатроном в изоляции).

Мы разрешаем присвоить приоритет каждому сегментированию измерения (по умолчанию 0), что указывает на то, что все сегменты с приоритетом <i будут распространяться на всю программу раньше сегментов с приоритетом i .

Даже если сегментирование имеет открытое измерение с более низким приоритетом, например, {"z",?}p2 , оно не будет переопределено другим тензорным сегментированием с более высоким приоритетом во время распространения. Однако такое открытое измерение может быть дополнительно сегментировано после распространения всех сегментов с более высоким приоритетом.

Другими словами, приоритеты НЕ связаны с тем, какое сегментирование измерений важнее другого - это порядок, в котором отдельные группы сегментирования измерений должны распространяться на всю программу, и то, как следует разрешать конфликты на промежуточных, неаннотированных тензорах.

Инварианты

  • Приоритеты начинаются с 0 (самый высокий приоритет) и увеличиваются (чтобы пользователи могли легко добавлять и удалять приоритеты, мы допускаем промежутки между приоритетами, например, используются p0 и p2, а p1 — нет).

  • Пустой сегмент закрытого измерения (т. е. {} ) не должен иметь приоритета, поскольку это не будет иметь никакого эффекта.

Делимость сегментирования измерений

Измерение размера d может быть сегментировано по осям, произведение размеров которых равно n , так что d не делится на n (что на практике потребовало бы заполнения измерения).

Например:

@mesh_xy = <["x"=8, "y"=2, "z"=3]>

sharding<@mesh_xy, [{"x"}, {"y"}, {"z"}]> : tensor<7x3x8xf32>

Грамматика

Каждая логическая сетка определяется следующим образом:

@mesh_name = <mesh_axis_1,...,mesh_axis_n>

mesh_axis ::= axis_name=axis_size

axis_name ::= str
axis_size ::= int

Представление шардинга будет иметь следующую структуру для тензора ранга r:

sharding<@mesh_name, dim_shardings, replicated=replicated_axes}

mesh_name ::= str
dim_shardings ::= [dim_sharding_1,...,dim_sharding_r]
replicated_axes ::= {axis_1,...,axis_m}

dim_sharding ::=
  {axis_1,...,axis_k} |  // closed dimension
  {axis_1,...,axis_k,?}  // open dimension

axis ::=
  axis_name  |   // a full axis
  sub_axis             // a sub axis

axis_name ::= str

sub_axis ::= axis_name:(pre_size)size
pre_size ::= int
size ::= int