Shardy es un sistema de partición de tensores basado en MLIR para todos los dialectos. Se creó a partir de la colaboración de los equipos de GSPMD y PartIR, y en él se incorpora lo mejor de ambos sistemas, así como la experiencia compartida de los equipos y los usuarios.
Beneficios
- Combina la propagación de GSPMD con la partición incremental de PartIR para brindarles a los usuarios más control y previsibilidad.
- Funciones nuevas impulsadas por la experiencia compartida, p.ej., compatibilidad novedosa con cambios de forma que generan comunicación adicional, a menos que los usuarios sepan cómo evitarlos.
- Mejor usabilidad y depurabilidad para aumentar la velocidad del usuario final, p. ej., con una representación de fragmentación basada en ejes
- Una base de código simple y de código abierto que usa MLIR, con un conjunto más amplio de colaboradores activos (internos, externos y en varias zonas horarias) para brindar asistencia a los usuarios.
Componentes
- Representación de fragmentación: Es una representación de fragmentación basada en ejes que está vinculada a una malla lógica específica (de potencialmente varias mallas) y admite ejes y fragmentaciones de dimensiones de restricción, ejes de división para operaciones como cambiar de forma, prioridades para la partición incremental y mucho más.
- APIs del compilador: Es un conjunto de componentes del compilador que se pueden usar junto con la representación de fragmentación para influir en la propagación de fragmentación.
- División de entrada/salida: Adjunta una división a una entrada o salida de la función principal para indicar que así es como se debe dividir el tensor de entrada/salida cuando se le pasa a la función o se muestra desde ella.
- Restricción de fragmentación: Adjunta una fragmentación a un tensor intermedio (p.ej., el resultado de un matmul) para indicar que así es como se debe fragmentar ese tensor o un subconjunto de sus usos.
- Shard As/Like: Agrupa varios tensores por un ID para indicar que se deben dividir de la misma manera.
- Cálculo manual: Encierra un subcálculo que se particiona de forma manual con un subconjunto de ejes de malla, en los que se especifican las particiones a lo largo de esos ejes manuales para todas las entradas y salidas, y dentro del subcálculo, los tipos de tensores son locales en relación con esas particiones.
- Propagación de fragmentación: Es un algoritmo de propagación que combina las prioridades del usuario y las restricciones de fragmentación con modelos de costos y heurísticas del compilador:
- Prioridades definidas por el usuario, p.ej., hacer paralelismo por lotes y, luego, ZeRO
- Prioridades basadas en operaciones, p. ej., operaciones por elemento primero y, luego, operaciones de multiplicación de matrices, etcétera
- Heurísticas más detalladas, p.ej., prefiere dimensiones de lotes.
- Particionador SPMD: Es un componente que reduce las decisiones de propagación de fragmentación particionando el programa en un programa SPMD y agregando el movimiento o el formato de datos necesarios y las operaciones colectivas en el proceso.
- A corto plazo, la implementación inicial usará el particionador de SPMD de GSPMD actual.
- A largo plazo, planeamos crear un nuevo particionador SPMD basado en MLIR.
Repositorio de código
El proyecto Shardy está en desarrollo activo, y buscamos comentarios de la comunidad de código abierto. El código de Shardy está disponible en https://github.com/openxla/shardy.