En este documento, se describe la semántica de transmisión de XLA.
¿Qué es la transmisión?
La transmisión es el proceso de hacer que los arrays con diferentes formas tengan formas compatibles para las operaciones aritméticas. La terminología se tomó de la transmisión de NumPy.
La transmisión puede ser necesaria para las operaciones entre arrays multidimensionales de diferentes rangos o entre arrays multidimensionales con formas diferentes, pero compatibles. Considera la suma X+v, en la que X es una matriz (un array con 2 dimensiones) y v es un vector (un array con 1 dimensión). Para realizar la suma de elementos, XLA necesita "transmitir" el vector v a la misma cantidad de dimensiones que la matriz X, replicando v una cierta cantidad de veces. La longitud del vector debe coincidir con al menos una de las dimensiones de la matriz.
Por ejemplo:
|1 2 3| + |7 8 9|
|4 5 6|
Las dimensiones de la matriz son (2,3), y la dimensión del vector es (3). El vector se transmite replicándolo en las filas para obtener lo siguiente:
|1 2 3| + |7 8 9| = |8 10 12|
|4 5 6| |7 8 9| |11 13 15|
En NumPy, esto se denomina transmisión.
Principios
El lenguaje XLA es lo más estricto y explícito posible, y evita las funciones "mágicas" implícitas. Estas funciones pueden facilitar un poco la definición de algunos cálculos, pero a costa de más suposiciones incorporadas en el código del usuario que serán difíciles de cambiar a largo plazo. Si es necesario, se pueden agregar funciones mágicas implícitas en los wrappers a nivel del cliente.
En cuanto a la transmisión, XLA requiere especificaciones explícitas de transmisión en las operaciones entre arrays de diferentes rangos. Esto es diferente de NumPy, que infiere la especificación cuando es posible.
Transmisión de un array de menor dimensión a uno de mayor dimensión
Los escalares siempre se pueden transmitir a través de arrays sin una especificación explícita de las dimensiones de transmisión. Una operación binaria por elementos entre un escalar y un array significa aplicar la operación con el escalar a cada elemento del array. Por ejemplo, agregar un escalar a una matriz significa producir una matriz en la que cada elemento es una suma del escalar y el elemento correspondiente de la matriz de entrada.
|1 2 3| + 7 = |8 9 10|
|4 5 6| |11 12 13|
La mayoría de las necesidades de transmisión se pueden capturar con una tupla de dimensiones en una operación binaria. Cuando las entradas de la operación tienen diferentes rangos, esta tupla de transmisión especifica qué dimensiones del array de mayor dimensión deben coincidir con el array de menor dimensión.
Considera el ejemplo anterior. En lugar de agregar un escalar a una matriz (2,3), agrega un vector de dimensión (3) a una matriz de dimensiones (2,3). Si no se especifica la transmisión, esta operación no es válida. Para solicitar correctamente la suma de una matriz y un vector, especifica que la dimensión de transmisión sea (1), lo que significa que la dimensión del vector coincide con la dimensión 1 de la matriz. En 2D, si la dimensión 0 representa las filas y la dimensión 1 representa las columnas, esto significa que cada elemento del vector se convierte en una columna de un tamaño que coincide con la cantidad de filas de la matriz:
|7 8 9| ==> |7 8 9|
|7 8 9|
Como ejemplo más complejo, considera agregar un vector de 3 elementos (dimensión (3)) a una matriz de 3 x 3 (dimensiones (3,3)). En este ejemplo, la transmisión puede ocurrir de dos maneras:
(1) Se puede usar una dimensión de transmisión de 1. Cada elemento del vector se convierte en una columna, y el vector se duplica para cada fila de la matriz.
|7 8 9| ==> |7 8 9|
|7 8 9|
|7 8 9|
(2) Se puede usar una dimensión de transmisión de 0. Cada elemento del vector se convierte en una fila, y el vector se duplica para cada columna de la matriz.
|7| ==> |7 7 7|
|8| |8 8 8|
|9| |9 9 9|
Las dimensiones de transmisión pueden ser una tupla que describa cómo se transmite una forma de menor dimensión a una de mayor dimensión. Por ejemplo, dado un cuboide de 2x3x4 y una matriz de 3x4, una tupla de transmisión (1,2) significa que la matriz coincide con las dimensiones 1 y 2 del cuboide.
Este tipo de transmisión se usa en las operaciones binarias en XlaBuilder si se proporciona el argumento broadcast_dimensions. Por ejemplo, consulta XlaBuilder::Add.
En el código fuente de XLA, este tipo de transmisión a veces se denomina transmisión "InDim".
Definición formal
El atributo de transmisión permite hacer coincidir un array de menor dimensión con uno de mayor dimensión especificando qué dimensiones del array de mayor dimensión deben coincidir. Por ejemplo, para un array con dimensiones MxNxPxQ, un vector con dimensión T se puede correlacionar de la siguiente manera:
MxNxPxQ
dim 3: T
dim 2: T
dim 1: T
dim 0: T
En cada caso, T debe ser igual a la dimensión coincidente del array de mayor dimensión. Luego, los valores del vector se transmiten desde la dimensión coincidente a todas las demás dimensiones.
Para hacer coincidir una matriz de TxV con el array de MxNxPxQ, se usa un par de dimensiones de transmisión:
MxNxPxQ
dim 2,3: T V
dim 1,2: T V
dim 0,3: T V
etc...
El orden de las dimensiones en la tupla de transmisión debe ser el orden en el que se espera que las dimensiones del array de menor dimensión coincidan con las dimensiones del array de mayor dimensión. El primer elemento de la tupla especifica qué dimensión del array de mayor dimensión debe coincidir con la dimensión 0 del array de menor dimensión. El segundo elemento de la tupla especifica qué dimensión del array de mayor dimensión debe coincidir con la dimensión 1 del array de menor dimensión, y así sucesivamente. El orden de las dimensiones de transmisión debe ser estrictamente creciente. Por ejemplo, en el ejemplo anterior, es ilegal hacer coincidir V con N y T con P. También es ilegal hacer coincidir V con P y N.
Transmisión de arrays de dimensiones similares con dimensiones degeneradas
Un problema relacionado es la transmisión de dos arrays que tienen la misma cantidad de dimensiones, pero diferentes tamaños de dimensión. Al igual que con NumPy, esto solo es posible cuando los arrays son compatibles. Dos arrays son compatibles cuando todas sus dimensiones son compatibles. Dos dimensiones son compatibles si se cumplen las siguientes condiciones:
- Son iguales.
- Una de ellas es 1 (una dimensión "").
Cuando se encuentran dos arrays compatibles, la forma del resultado tiene el máximo de las dos entradas en cada índice de dimensión.
Ejemplos:
- (2,1) y (2,3) se transmiten a (2,3).
- (1,2,5) y (7,2,5) se transmiten a (7,2,5).
- (7,2,5) y (7,1,5) se transmiten a (7,2,5).
- (7,2,5) y (7,2,6) son incompatibles y no se pueden transmitir.
Se presenta un caso especial, que también se admite, en el que cada uno de los arrays de entrada tiene una dimensión en un índice diferente. En este caso, el resultado es una "operación externa": (2,1) y (1,3) se transmiten a (2,3). Para obtener más ejemplos, consulta la documentación de NumPy sobre la transmisión.
Composición de la transmisión
La transmisión de un array de menor dimensión a uno de mayor dimensión y la transmisión con dimensiones degeneradas se pueden realizar en la misma operación binaria. Por ejemplo, se pueden sumar un vector de tamaño 4 y una matriz de tamaño 1 x 2 con dimensiones de transmisión de valor (0):
|1 2 3 4| + [5 6] // [5 6] is a 1x2 matrix, not a vector.
Primero, el vector se transmite hasta 2 dimensiones (matriz) con las dimensiones de transmisión. El valor único (0) en las dimensiones de transmisión indica que la dimensión cero del vector coincide con la dimensión cero de la matriz. Esto produce una matriz de tamaño 4xM, en la que el valor M se elige para que coincida con el tamaño de la dimensión correspondiente en el array de 1x2. Por lo tanto, se produce una matriz de 4x2:
|1 1| + [5 6]
|2 2|
|3 3|
|4 4|
Luego, la "transmisión de dimensiones degeneradas" transmite la dimensión cero de la matriz de 1 x 2 para que coincida con el tamaño de la dimensión correspondiente del lado derecho:
|1 1| + |5 6| |6 7|
|2 2| + |5 6| = |7 8|
|3 3| + |5 6| |8 9|
|4 4| + |5 6| |9 10|
Un ejemplo más complicado es una matriz de tamaño 1x2 agregada a un array de tamaño 4x3x1 con dimensiones de transmisión de (1, 2). Primero, la matriz de 1x2 se transmite hasta 3 dimensiones usando las dimensiones de transmisión para producir un array intermedio de Mx1x2, en el que el tamaño de la dimensión M se determina según el tamaño del operando más grande (el array de 4x3x1), lo que produce un array intermedio de 4x1x2. La letra M se encuentra en la dimensión 0 (la dimensión más a la izquierda) porque las dimensiones 1 y 2 se asignan a las dimensiones de la matriz original de 1 x 2, ya que las dimensiones de transmisión son (1, 2). Este array intermedio se puede agregar a la matriz de 4 x 3 x 1 con la transmisión de dimensiones degeneradas para producir un resultado de array de 4 x 3 x 2.