Transmissão

Este documento descreve a semântica de transmissão do XLA.

O que é transmissão?

A transmissão é o processo de fazer com que matrizes com formas diferentes tenham formas compatíveis para operações aritméticas. A terminologia é emprestada da transmissão do NumPy.

A transmissão pode ser necessária para operações entre matrizes multidimensionais de classificações diferentes ou entre matrizes multidimensionais com formas diferentes, mas compatíveis. Considere a adição X+v, em que X é uma matriz (uma matriz com duas dimensões) e v é um vetor (uma matriz com uma dimensão). Para realizar a adição elemento a elemento, o XLA precisa "transmitir" o vetor v para o mesmo número de dimensões da matriz X, replicando v um determinado número de vezes. O comprimento do vetor precisa corresponder a pelo menos uma das dimensões da matriz.

Exemplo:

|1 2 3| + |7 8 9|
|4 5 6|

As dimensões da matriz são (2,3), e a dimensão do vetor é (3). O vetor é transmitido por replicação nas linhas para gerar:

|1 2 3| + |7 8 9| = |8  10 12|
|4 5 6|   |7 8 9|   |11 13 15|

No NumPy, isso é chamado de transmissão.

Princípios

A linguagem XLA é o mais estrita e explícita possível, evitando recursos implícitos "mágicos". Esses recursos podem facilitar um pouco algumas computações, mas ao custo de mais pressupostos incorporados ao código do usuário, que serão difíceis de mudar a longo prazo. Se necessário, recursos mágicos implícitos podem ser adicionados em wrappers no nível do cliente.

Em relação à transmissão, o XLA exige especificações explícitas de transmissão em operações entre matrizes de diferentes classificações. Isso é diferente do NumPy, que infere a especificação quando possível.

Transmissão de uma matriz de dimensão menor para uma de dimensão maior

Escalares sempre podem ser transmitidos em arrays sem uma especificação explícita de dimensões de transmissão. Uma operação binária elemento a elemento entre um escalar e uma matriz significa aplicar a operação com o escalar a cada elemento na matriz. Por exemplo, adicionar um escalar a uma matriz significa produzir uma matriz em que cada elemento é uma soma do escalar e do elemento correspondente da matriz de entrada.

|1 2 3| + 7 = |8  9  10|
|4 5 6|       |11 12 13|

A maioria das necessidades de transmissão pode ser capturada usando uma tupla de dimensões em uma operação binária. Quando as entradas da operação têm classificações diferentes, essa tupla de transmissão especifica quais dimensões na matriz de dimensão maior correspondem à matriz de dimensão menor.

Considere o exemplo anterior. Em vez de adicionar um escalar a uma matriz (2,3), adicione um vetor de dimensão (3) a uma matriz de dimensões (2,3). Sem especificar a transmissão, essa operação é inválida. Para solicitar corretamente a adição de matriz-vetor, especifique a dimensão de transmissão como (1). Isso significa que a dimensão do vetor é correspondente à dimensão 1 da matriz. Em 2D, se a dimensão 0 representa linhas e a dimensão 1 representa colunas, isso significa que cada elemento do vetor se torna uma coluna de um tamanho que corresponde ao número de linhas na matriz:

|7 8 9| ==> |7 8 9|
            |7 8 9|

Como um exemplo mais complexo, considere adicionar um vetor de três elementos (dimensão (3)) a uma matriz 3x3 (dimensões (3,3)). Há duas maneiras de fazer a transmissão neste exemplo:

(1) Uma dimensão de transmissão de 1 pode ser usada. Cada elemento do vetor se torna uma coluna, e o vetor é duplicado para cada linha na matriz.

|7 8 9| ==> |7 8 9|
            |7 8 9|
            |7 8 9|

(2) Uma dimensão de transmissão de 0 pode ser usada. Cada elemento do vetor se torna uma linha, e o vetor é duplicado para cada coluna na matriz.

 |7| ==> |7 7 7|
 |8|     |8 8 8|
 |9|     |9 9 9|

As dimensões de transmissão podem ser uma tupla que descreve como uma forma de dimensão menor é transmitida para uma forma de dimensão maior. Por exemplo, dado um cuboide 2x3x4 e uma matriz 3x4, uma tupla de transmissão (1,2) significa corresponder a matriz às dimensões 1 e 2 do cuboide.

Esse tipo de transmissão é usado nas operações binárias em XlaBuilder, se o argumento broadcast_dimensions for fornecido. Por exemplo, consulte XlaBuilder::Add. No código-fonte da XLA, esse tipo de transmissão às vezes é chamado de transmissão "InDim".

Definição formal

O atributo de transmissão permite corresponder uma matriz de dimensão menor a uma de dimensão maior especificando quais dimensões da matriz de dimensão maior corresponder. Por exemplo, para uma matriz com dimensões MxNxPxQ, um vetor com dimensão T pode ser correspondido da seguinte maneira:

          MxNxPxQ

dim 3:          T
dim 2:        T
dim 1:      T
dim 0:    T

Em cada caso, T precisa ser igual à dimensão correspondente da matriz de dimensão maior. Os valores do vetor são transmitidos da dimensão correspondente para todas as outras dimensões.

Para corresponder uma matriz TxV à matriz MxNxPxQ, um par de dimensões de transmissão é usado:

          MxNxPxQ
dim 2,3:      T V
dim 1,2:    T V
dim 0,3:  T     V
etc...

A ordem das dimensões na tupla de transmissão precisa ser a ordem em que as dimensões da matriz de menor dimensão devem corresponder às dimensões da matriz de maior dimensão. O primeiro elemento na tupla especifica qual dimensão na matriz de dimensão superior precisa corresponder à dimensão 0 na matriz de dimensão inferior. O segundo elemento na tupla especifica qual dimensão na matriz de dimensão superior precisa corresponder à dimensão 1 na matriz de dimensão inferior, e assim por diante. A ordem das dimensões de transmissão precisa ser estritamente crescente. Por exemplo, no exemplo anterior, é ilegal corresponder V a N e T a P, assim como corresponder V a P e N.

Transmissão de matrizes de dimensões semelhantes com dimensões degeneradas

Um problema relacionado é a transmissão de duas matrizes que têm o mesmo número de dimensões, mas tamanhos diferentes. Assim como no NumPy, isso só é possível quando as matrizes são compatíveis. Duas matrizes são compatíveis quando todas as dimensões delas são compatíveis. Duas dimensões são compatíveis se:

  • Elas são iguais ou
  • Uma delas é 1 (uma dimensão "degenerada")

Quando duas matrizes compatíveis são encontradas, a forma resultante tem o máximo das duas entradas em todos os índices de dimensão.

Exemplos:

  1. (2,1) e (2,3) são transmitidos para (2,3).
  2. (1,2,5) e (7,2,5) são transmitidos para (7,2,5).
  3. (7,2,5) e (7,1,5) são transmitidos para (7,2,5).
  4. (7,2,5) e (7,2,6) são incompatíveis e não podem ser transmitidos.

Um caso especial surge e também é compatível quando cada uma das matrizes de entrada tem uma dimensão degenerada em um índice diferente. Nesse caso, o resultado é uma "operação externa": (2,1) e (1,3) são transmitidos para (2,3). Para mais exemplos, consulte a documentação do NumPy sobre transmissão.

Composição de transmissão

A transmissão de uma matriz de dimensão inferior para uma de dimensão superior e a transmissão usando dimensões degeneradas podem ser realizadas na mesma operação binária. Por exemplo, um vetor de tamanho 4 e uma matriz de tamanho 1x2 podem ser adicionados usando dimensões de transmissão de valor (0):

|1 2 3 4| + [5 6]    // [5 6] is a 1x2 matrix, not a vector.

Primeiro, o vetor é transmitido até duas dimensões (matriz) usando as dimensões de transmissão. O valor único (0) nas dimensões de transmissão indica que a dimensão zero do vetor corresponde à dimensão zero da matriz. Isso produz uma matriz de tamanho 4xM, em que o valor M é escolhido para corresponder ao tamanho da dimensão correspondente na matriz 1x2. Portanto, uma matriz 4x2 é produzida:

|1 1| + [5 6]
|2 2|
|3 3|
|4 4|

Em seguida, a "transmissão de dimensão degenerada" transmite a dimensão zero da matriz 1x2 para corresponder ao tamanho da dimensão correspondente do lado direito:

|1 1| + |5 6|     |6  7|
|2 2| + |5 6|  =  |7  8|
|3 3| + |5 6|     |8  9|
|4 4| + |5 6|     |9 10|

Um exemplo mais complicado é uma matriz de tamanho 1x2 adicionada a uma matriz de tamanho 4x3x1 usando dimensões de transmissão de (1, 2). Primeiro, a matriz 1x2 é transmitida até três dimensões usando as dimensões de transmissão para produzir uma matriz intermediária Mx1x2, em que o tamanho da dimensão M é determinado pelo tamanho do operando maior (a matriz 4x3x1), produzindo uma matriz intermediária 4x1x2. O M está na dimensão 0 (a mais à esquerda) porque as dimensões 1 e 2 são mapeadas para as dimensões da matriz 1x2 original, já que as dimensões de transmissão são (1, 2). Essa matriz intermediária pode ser adicionada à matriz 4x3x1 usando a transmissão de dimensões degeneradas para produzir um resultado de matriz 4x3x2.