카테고리: 컴파일 시간: 하드웨어에서 지원되지 않는 RHS DataType
이 오류는 행렬 곱셈(예: jax.lax.dot_general, jax.lax.conv, jax.numpy.matmul 또는 @ 연산자)이 사용 중인 특정 TPU 세대에서 기본적으로 지원되지 않습니다.
샘플 오류 메시지:
INTERNAL: Mosaic failed to compile TPU kernel: Unsupported matmul RHS type on target: 'vector<256x256xi8>'
...
The MLIR operation involved:
%13440 = "tpu.matmul"(%13435, %13437, %13439) <dimension_numbers = #tpu.dot_dimension_numbers<...>
XLA 백엔드: TPU
개요
TPU의 행렬 곱셈 단위 (MXU)는 모든 하드웨어 세대에서 Float32 작업을 기본적으로 지원합니다.
하지만 BFloat16 및 기타 양자화된 데이터 유형(예: Int4, Int8 또는 Float8)은 하드웨어 세대에 따라 다릅니다. 이 오류는 커널이 특정 TPU 세대에 실행할 물리적 회로가 없는 데이터 유형을 사용하여 행렬 곱셈을 MXU에 매핑하려고 할 때 트리거됩니다.
이 오류는 일반적으로 지원되지 않는 유형을 지원되는 유형으로 자동 변환하려고 시도하는 컴파일러의 정규화 패스(예: 소프트웨어 에뮬레이션을 통해)가 유효한 변환 규칙을 찾을 수 없거나 호환성 모드가 사용 중지되어 변환을 수행할 수 없음을 나타냅니다.
디버깅
이 오류를 해결하려면 데이터 유형을 하드웨어 기능과 일치시켜야 합니다. 다음 옵션 중에서 선택하세요.
1. 네이티브 유형으로 변환
가장 신뢰할 수 있는 해결 방법은 matmul 작업 전에 커널 내에서 피연산자를 하드웨어 지원 데이터 유형 (예: TPU v4 이상의 Float32 또는 BFloat16)으로 수동으로 변환하는 것입니다.
- 이유:
Float32은 모든 TPU 세대에서 MXU가 기본적으로 지원하는 범용 데이터 유형입니다. - 트레이드 오프: 여기에는 VPU(벡터 처리 장치) 비용(캐스팅을 실행하는 데 필요한 사이클)이 따르지만 커널이 현재 하드웨어에서 실행된다는 보장이 있습니다.
2. 호환성 모드 확인
일반적으로 컴파일러는 기본적으로 사용 설정된 호환성 모드에서 이러한 유형 불일치 문제를 자동으로 처리할 수 있습니다. --xla_mosaic_compat_mode가 false로 설정되지 않았는지 XLA 구성을 다시 확인합니다.
이는 하드웨어에서 기본적으로 지원하지 않는 작업을 위한 소프트웨어 에뮬레이션 시퀀스를 삽입하는 '폴리필' 역할을 합니다.
호환성 모드에서 사용 설정되는 기능:
- 혼합 정밀도 MatMul: MatMul 전에 정수를
Float32로 확장하는 등 캐스트 작업을 자동으로 삽입하여 정수 피연산자를 부동 소수점 누적기와 혼합할 수 있습니다. - 낮은 정밀도 에뮬레이션: 특정 하드웨어 세대에서 실행 전에
4-bit부동 소수점 (4E2M1FN) 또는8-bit부동 소수점 (8E4M3FN)과 같이 지원되지 않는 유형을BFloat16또는Float32과 같은 지원되는 유형으로 확장하여 에뮬레이션합니다.
이 모드는 에뮬레이션이 MXU가 데이터를 처리하기 전에 데이터 형식을 변환하는 추가 명령어가 필요하므로 최대 성능보다 호환성을 우선시합니다.
3. 하드웨어 업그레이드 또는 지원 요청
알고리즘에 캐스팅이나 에뮬레이션 오버헤드 없이 Int4 또는 Float8과 같은 유형의 네이티브 성능이 엄격하게 필요한 경우 네이티브 지원이 제공되는 최신 TPU 세대에서 실행해야 합니다.
기능 요청: 하드웨어에서 이 작업을 지원한다고 생각되거나 호환성 모드에서도 컴파일러에 유효한 에뮬레이션 경로가 누락된 경우 기능 요청을 제출하세요. 일반적으로 작업의 호환성을 보장합니다. 따라서 커널이 특정 TPU 세대에서 실행되면 향후 모든 세대에서 실행되어야 합니다. 하지만 이전 세대에 대한 에뮬레이션이 보장되지는 않습니다 (일부의 경우 캐스팅이 매우 비쌈).