Registro AVX2 compacto para que os números inteiros selecionados sejam contíguos de acordo com a máscara [duplicado]
Pergunta
Na pergunta Otimizando a compactação de array, a resposta principal afirma:
Os registros SSE/AVX com os conjuntos de instruções mais recentes permitem uma abordagem melhor.Podemos usar o resultado do PMOVMSKB diretamente, transformando-o no registrador de controle de algo como PSHUFB.
Isso é possível com Haswell (AVX2)?Ou requer um dos sabores do AVX512?
Eu tenho um vetor AVX2 contendo int32s e um vetor correspondente do resultado de uma comparação.Quero embaralhá-lo de alguma forma para que os elementos com o msb correspondente definido na máscara (compare true) sejam contíguos na extremidade inferior do vetor.
O melhor que posso ver é obter uma máscara de bits com _mm256_movemask_ps/vmovmskps (sem variante *d?) e então usá-la em uma tabela de pesquisa de vetor 256 AVX2 para obter uma máscara aleatória para a faixa cruzada _mm256_permutevar8x32_epi32/vpermd
Solução
A primeira coisa a fazer é encontrar uma função escalar rápida.Aqui está uma versão que não usa branch.
inline int compact(int *x, int *y, const int n) {
int cnt = 0;
for(int i=0; i<n; i++) {
int cut = x[i]!=0;
y[cnt] = cut*x[i];
cnt += cut;
}
return cnt;
}
O melhor resultado com SIMD provavelmente depende da distribuição de zeros.Se for esparso ou denso.O código a seguir deve funcionar bem para distribuições esparsas ou densas.Por exemplo, longas séries de zeros e diferentes de zeros.Se a distribuição for mais uniforme não sei se esse código terá algum benefício.Mas dará o resultado correto de qualquer maneira.
Aqui está uma versão AVX2 que testei.
int compact_AVX2(int *x, int *y, int n) {
int i =0, cnt = 0;
for(i=0; i<n-8; i+=8) {
__m256i x4 = _mm256_loadu_si256((__m256i*)&x[i]);
__m256i cmp = _mm256_cmpeq_epi32(x4, _mm256_setzero_si256());
int mask = _mm256_movemask_epi8(cmp);
if(mask == -1) continue; //all zeros
if(mask) {
cnt += compact(&x[i],&y[cnt], 8);
}
else {
_mm256_storeu_si256((__m256i*)&y[cnt], x4);
cnt +=8;
}
}
cnt += compact(&x[i], &y[cnt], n-i); // cleanup for n not a multiple of 8
return cnt;
}
Aqui está a versão SSE2 que testei.
int compact_SSE2(int *x, int *y, int n) {
int i =0, cnt = 0;
for(i=0; i<n-4; i+=4) {
__m128i x4 = _mm_loadu_si128((__m128i*)&x[i]);
__m128i cmp = _mm_cmpeq_epi32(x4, _mm_setzero_si128());
int mask = _mm_movemask_epi8(cmp);
if(mask == 0xffff) continue; //all zeroes
if(mask) {
cnt += compact(&x[i],&y[cnt], 4);
}
else {
_mm_storeu_si128((__m128i*)&y[cnt], x4);
cnt +=4;
}
}
cnt += compact(&x[i], &y[cnt], n-i); // cleanup for n not a multiple of 4
return cnt;
}
Aqui está um teste completo
#include <stdio.h>
#include <stdlib.h>
#if defined (__GNUC__) && ! defined (__INTEL_COMPILER)
#include <x86intrin.h>
#else
#include <immintrin.h>
#endif
#define N 50
inline int compact(int *x, int *y, const int n) {
int cnt = 0;
for(int i=0; i<n; i++) {
int cut = x[i]!=0;
y[cnt] = cut*x[i];
cnt += cut;
}
return cnt;
}
int compact_SSE2(int *x, int *y, int n) {
int i =0, cnt = 0;
for(i=0; i<n-4; i+=4) {
__m128i x4 = _mm_loadu_si128((__m128i*)&x[i]);
__m128i cmp = _mm_cmpeq_epi32(x4, _mm_setzero_si128());
int mask = _mm_movemask_epi8(cmp);
if(mask == 0xffff) continue; //all zeroes
if(mask) {
cnt += compact(&x[i],&y[cnt], 4);
}
else {
_mm_storeu_si128((__m128i*)&y[cnt], x4);
cnt +=4;
}
}
cnt += compact(&x[i], &y[cnt], n-i); // cleanup for n not a multiple of 4
return cnt;
}
int compact_AVX2(int *x, int *y, int n) {
int i =0, cnt = 0;
for(i=0; i<n-8; i+=8) {
__m256i x4 = _mm256_loadu_si256((__m256i*)&x[i]);
__m256i cmp = _mm256_cmpeq_epi32(x4, _mm256_setzero_si256());
int mask = _mm256_movemask_epi8(cmp);
if(mask == -1) continue; //all zeros
if(mask) {
cnt += compact(&x[i],&y[cnt], 8);
}
else {
_mm256_storeu_si256((__m256i*)&y[cnt], x4);
cnt +=8;
}
}
cnt += compact(&x[i], &y[cnt], n-i); // cleanup for n not a multiple of 8
return cnt;
}
int main() {
int x[N], y[N];
for(int i=0; i<N; i++) x[i] = rand()%10;
//int cnt = compact_SSE2(x,y,N);
int cnt = compact_AVX2(x,y,N);
for(int i=0; i<N; i++) printf("%d ", x[i]); printf("\n");
for(int i=0; i<cnt; i++) printf("%d ", y[i]); printf("\n");
}