-
-
Save t-mat/d6664e97c8b6407088356b3867d43537 to your computer and use it in GitHub Desktop.
SIMD functions to apply toupper/tolower to each character in a string
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Created by easyaspi314. Released into the public domain. | |
// test:$ gcc -msse2 -DTEST change_case_simd.c && ./a.out
// > cl /EHsc /DTEST change_case_simd.c && .\change_case_simd.exe | |
#include <ctype.h> | |
#include <stdio.h> | |
#include <stdlib.h> | |
#include <string.h> | |
#ifdef DEMONSTRATE_BASIC_ALGORITHM | |
#include <stdint.h> | |
#include <immintrin.h> | |
// | |
// To reduce the number of instructions, we can use the following classic | |
// comparison technique: | |
// | |
// int8_t src = ...; // input UTF-8 byte | |
// int8_t e = (int8_t) (src - 'A' + 0x80); | |
// int8_t u = (int8_t) ('Z' - 'A' + 0x80 + 1); | |
// bool isUpper1 = e < u; | |
// | |
// note: u = -102 = (0x5a - 0x41 + 0x80 + 1) | |
// | |
// This trick works because (in int8_t) it is equivalent to | |
// | |
// isUpper1 = (-128 <= e && e < -102) | |
// | |
// We can confirm it in the following table: | |
// | |
// | src | src-'A'+0x80 | -102 > src-'A'+0x80 | | |
// | ------- | ------------ | ------------------- | | |
// | 0xc0 | 0xff = -1 | false | | |
// | ... | ... | false | | |
// | '['=0x5b | 0x9a = -102 | false | | |
// | 'Z'=0x5a | 0x99 = -103 | true | | |
// | 'Y'=0x59 | 0x98 = -104 | true | | |
// | ... | ... | true | | |
// | 'B'=0x42 | 0x81 = -127 | true | | |
// | 'A'=0x41 | 0x80 = -128 | true | | |
// | '@'=0x40 | 0x7f = +127 = -129 & 255 | false | | |
// | ... | ... | false | | |
// | 0x01 | 0x40 = +64 | false | | |
// | 0x00 | 0x3f = +63 | false | | |
// | 0xff | 0x3e = +62 | false | | |
// | ... | ... | false | | |
// | 0xc1 | 0x00 = +0 | false | | |
// | |
/*
 * Scalar model of the vector lowercase step, applied to one byte.
 *
 * The input is biased so that 'A'..'Z' become the 26 smallest signed 8-bit
 * values (-128..-103). A single signed comparison then picks out exactly the
 * uppercase letters, and the resulting all-ones / all-zeros mask gates the
 * ('a' - 'A') offset that is added back to the input. Every other byte is
 * returned unchanged.
 */
static int8_t CharToLower(int8_t src) {
    const int8_t bias      = (int8_t)(0x80 - 'A');
    const int8_t limit     = (int8_t)('Z' - 'A' + 1 + 0x80);
    const int8_t caseDelta = (int8_t)('a' - 'A');

    /* The sum is narrowed back to 8 bits; the trick relies on that wrapping. */
    const int8_t shifted = src + bias;

    int8_t mask = 0;
    if (shifted < limit) {
        mask = (int8_t)0xff;
    }
    return src + (mask & caseDelta);
}
/*
 * Scalar model of the vector uppercase step, applied to one byte.
 *
 * Mirror image of the lowercase variant: the input is biased so that
 * 'a'..'z' land on -128..-103, one signed comparison builds a 0xff / 0x00
 * mask for the lowercase letters, and the masked ('A' - 'a') offset is added
 * to the input. Bytes outside 'a'..'z' come back unchanged.
 */
static int8_t CharToUpper(int8_t src) {
    const int8_t bias      = (int8_t)(0x80 - 'a');
    const int8_t limit     = (int8_t)('z' - 'a' + 1 + 0x80);
    const int8_t caseDelta = (int8_t)('A' - 'a');

    /* The sum is narrowed back to 8 bits; the trick relies on that wrapping. */
    const int8_t shifted = src + bias;
    const int8_t mask    = (shifted < limit) ? (int8_t)0xff : (int8_t)0;

    const int8_t result = src + (mask & caseDelta);
    return result;
}
// Pseudo SIMD | |
/*
 * A 128-bit value that can be viewed as sixteen unsigned bytes, sixteen
 * signed bytes, or one __m128i vector, so the vector algorithm can be
 * spelled out lane by lane in plain C.
 */
union U128 {
    uint8_t u8[16];
    int8_t  i8[16];
    __m128i m128;
};
typedef union U128 U128;
/*
 * Lane-by-lane model of the 16-byte lowercase kernel.
 *
 * Each of the sixteen signed bytes in `src` goes through the same
 * bias / signed-compare / mask / add sequence as the single-byte version:
 * bytes in 'A'..'Z' get ('a' - 'A') added, all other bytes are copied as is.
 * Returns the converted 16 bytes; `src` is taken by value and not modified.
 */
static U128 CharToLowerU128(U128 src) {
    const int8_t bias      = (int8_t)(0x80 - 'A');
    const int8_t limit     = (int8_t)('Z' - 'A' + 1 + 0x80);
    const int8_t caseDelta = (int8_t)('a' - 'A');

    U128 out;
    for (int lane = 0; lane < 16; ++lane) {
        const int8_t in      = src.i8[lane];
        const int8_t shifted = in + bias;                     /* 'A'..'Z' -> -128..-103 */
        const int8_t mask    = (shifted < limit) ? (int8_t)0xff : (int8_t)0;
        out.i8[lane] = in + (mask & caseDelta);
    }
    return out;
}
/*
 * Lane-by-lane model of the 16-byte uppercase kernel.
 *
 * Each of the sixteen signed bytes in `src` is biased so that 'a'..'z' map to
 * -128..-103, compared against the upper limit to build a 0xff / 0x00 mask,
 * and then has the masked ('A' - 'a') offset added. Bytes outside 'a'..'z'
 * are copied unchanged. Returns the converted 16 bytes.
 */
static U128 CharToUpperU128(U128 src) {
    const int8_t bias      = (int8_t)(0x80 - 'a');
    const int8_t limit     = (int8_t)('z' - 'a' + 1 + 0x80);
    const int8_t caseDelta = (int8_t)('A' - 'a');

    U128 out;
    int lane = 0;
    while (lane < 16) {
        const int8_t in      = src.i8[lane];
        const int8_t shifted = in + bias;                     /* 'a'..'z' -> -128..-103 */
        int8_t mask = 0;
        if (shifted < limit) {
            mask = (int8_t)0xff;
        }
        out.i8[lane] = in + (mask & caseDelta);
        ++lane;
    }
    return out;
}
#endif | |
#if defined(__SSE2__) || defined(_M_X64) | |
#include <immintrin.h> | |
// SSE2 routine that converts every uppercase ASCII letter in the
// NUL-terminated string `str` to lowercase, in place.
//
// Whole 16-byte chunks are handled with unaligned vector loads/stores and an
// ASCII-only range test, so bytes outside 'A'..'Z' (including bytes >= 0x80)
// pass through unchanged. The remaining (length % 16) bytes are converted with
// libc tolower(), which follows the current locale.
void StringToLower(char *str)
{
    const __m128i asciiAofs80 = _mm_set1_epi8(-'A' + 0x80);
    const __m128i AtoZp1ofs80 = _mm_set1_epi8((signed char) ('Z' - 'A' + 1 + 0x80));
    const __m128i diff        = _mm_set1_epi8('a' - 'A');

    size_t len = strlen(str);
    while (len >= 16) {
        __m128i src = _mm_loadu_si128((const __m128i *) str);
        __m128i a1 = _mm_add_epi8(src, asciiAofs80);   // shift so 'A'..'Z' map to -128..-103
        __m128i a2 = _mm_cmpgt_epi8(AtoZp1ofs80, a1);  // 0xff in lanes holding 'A'..'Z', else 0x00
        __m128i a3 = _mm_and_si128(a2, diff);          // 'a'-'A' in those lanes, else 0x00
        __m128i a4 = _mm_add_epi8(src, a3);            // apply the case offset
        _mm_storeu_si128((__m128i *) str, a4);
        len -= 16;
        str += 16;
    }
    while (len-- > 0) {
        // <ctype.h> functions need an unsigned char value (or EOF); a plain
        // char >= 0x80 may be negative, which would be undefined behaviour.
        *str = (char) tolower((unsigned char) *str);
        ++str;
    }
}
// SSE2 routine that converts every lowercase ASCII letter in the
// NUL-terminated string `str` to uppercase, in place.
//
// Whole 16-byte chunks are handled with unaligned vector loads/stores and an
// ASCII-only range test, so bytes outside 'a'..'z' (including bytes >= 0x80)
// pass through unchanged. The remaining (length % 16) bytes are converted with
// libc toupper(), which follows the current locale.
void StringToUpper(char *str)
{
    const __m128i asciiAofs80 = _mm_set1_epi8(-'a' + 0x80);
    const __m128i AtoZp1ofs80 = _mm_set1_epi8((signed char) ('z' - 'a' + 1 + 0x80));
    const __m128i diff        = _mm_set1_epi8('A' - 'a');

    size_t len = strlen(str);
    while (len >= 16) {
        __m128i src = _mm_loadu_si128((const __m128i *) str);
        __m128i a1 = _mm_add_epi8(src, asciiAofs80);   // shift so 'a'..'z' map to -128..-103
        __m128i a2 = _mm_cmpgt_epi8(AtoZp1ofs80, a1);  // 0xff in lanes holding 'a'..'z', else 0x00
        __m128i a3 = _mm_and_si128(a2, diff);          // 'A'-'a' in those lanes, else 0x00
        __m128i a4 = _mm_add_epi8(src, a3);            // apply the case offset
        _mm_storeu_si128((__m128i *) str, a4);
        len -= 16;
        str += 16;
    }
    while (len-- > 0) {
        // <ctype.h> functions need an unsigned char value (or EOF); a plain
        // char >= 0x80 may be negative, which would be undefined behaviour.
        *str = (char) toupper((unsigned char) *str);
        ++str;
    }
}
#else | |
/* Just go scalar. */ | |
/*
 * Scalar fallback: converts the NUL-terminated string `str` to lowercase in
 * place, one byte at a time, using libc tolower().
 */
void StringToLower(char *str)
{
    for (; *str != '\0'; ++str) {
        /* <ctype.h> functions need an unsigned char value (or EOF); a plain
         * char >= 0x80 may be negative, which would be undefined behaviour. */
        *str = (char) tolower((unsigned char) *str);
    }
}
/*
 * Scalar fallback: converts the NUL-terminated string `str` to uppercase in
 * place, one byte at a time, using libc toupper().
 */
void StringToUpper(char *str)
{
    for (; *str != '\0'; ++str) {
        /* <ctype.h> functions need an unsigned char value (or EOF); a plain
         * char >= 0x80 may be negative, which would be undefined behaviour. */
        *str = (char) toupper((unsigned char) *str);
    }
}
#endif | |
#ifdef TEST | |
#include <assert.h> | |
// Reference implementation: lowercases the first `strLengthInBytes` bytes of
// `str` in place with libc tolower(). Used to produce the expected result the
// SIMD routines are compared against.
static void libc_str_lower(char* str, size_t strLengthInBytes) {
    for(size_t i = 0; i < strLengthInBytes; ++i) {
        // Go through unsigned char: passing a negative plain char to a
        // <ctype.h> function is undefined behaviour.
        str[i] = (char) tolower((unsigned char) str[i]);
    }
}
// Reference implementation: uppercases the first `strLengthInBytes` bytes of
// `str` in place with libc toupper(). Used to produce the expected result the
// SIMD routines are compared against.
static void libc_str_upper(char* str, size_t strLengthInBytes) {
    for(size_t i = 0; i < strLengthInBytes; ++i) {
        // Go through unsigned char: passing a negative plain char to a
        // <ctype.h> function is undefined behaviour.
        str[i] = (char) toupper((unsigned char) str[i]);
    }
}
// Prints `dataSizeInBytes` bytes starting at `data` to stdout, 16 per row.
// Each row starts with the offset of its first byte in hex. Bytes in the
// range 0x21..0x7e are shown as the character followed by a space; all other
// bytes are shown as two hex digits, so every cell is two columns wide.
static void hexDump(const void* data, size_t dataSizeInBytes) {
    const unsigned char* const bytes = (const unsigned char*) data;
    for(size_t i = 0; i < dataSizeInBytes; ++i) {
        if(i % 16 == 0) {
            if(i != 0) {
                printf("\n");
            }
            // %x expects an unsigned int argument.
            printf("0x%04x: ", (unsigned) i);
        }
        const unsigned char c = bytes[i];
        if(c >= 0x21 && c < 0x7f) {
            printf("%c ", c);
        } else {
            printf("%02x", c);
        }
        // Separate cells with a space; finish the dump with a newline.
        if(i + 1 >= dataSizeInBytes) {
            printf("\n");
        } else {
            printf(" ");
        }
    }
}
// Test driver.
//
// 1. Runs StringToLower / StringToUpper on two fixed strings and compares the
//    results with hard-coded expectations.
// 2. Runs both routines on a 256-byte string holding every byte value except
//    0x00 (position 0 holds ' ' instead) and compares against the libc-based
//    reference helpers.
// 3. With DEMONSTRATE_BASIC_ALGORITHM defined, runs the same comparison for
//    the scalar and 16-lane demonstration functions.
//
// Each mismatch prints a FAIL line (plus hex dumps for the byte-sweep tests).
// Returns EXIT_SUCCESS when no check failed, EXIT_FAILURE otherwise.
int main(void)
{
    int errorCount = 0;
    {
        //            0123456789abcdef0123456789abcdef
        char str[] = "Hello world 12345 HI ABXYZ abxyz";
        StringToLower(str);
        puts(str);
        if(strcmp(str, "hello world 12345 hi abxyz abxyz") != 0) {
            errorCount += 1;
            printf("%s(%d) : FAIL\n", __FILE__, __LINE__);
        }
        StringToUpper(str);
        puts(str);
        if(strcmp(str, "HELLO WORLD 12345 HI ABXYZ ABXYZ") != 0) {
            errorCount += 1;
            printf("%s(%d) : FAIL\n", __FILE__, __LINE__);
        }
    }
    {
        //            0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef
        char str[] = "Hello world 12345 HI ABCDEFGHIJKLMNOPQRSTUVWXYZ abcdefghijklmnopqrstuvwxyz #PAD#";
        char lwr[] = "hello world 12345 hi abcdefghijklmnopqrstuvwxyz abcdefghijklmnopqrstuvwxyz #pad#";
        char upr[] = "HELLO WORLD 12345 HI ABCDEFGHIJKLMNOPQRSTUVWXYZ ABCDEFGHIJKLMNOPQRSTUVWXYZ #PAD#";
        StringToLower(str);
        puts(str);
        if(strcmp(str, lwr) != 0) {
            errorCount += 1;
            printf("%s(%d) : FAIL\n", __FILE__, __LINE__);
        }
        StringToUpper(str);
        puts(str);
        if(strcmp(str, upr) != 0) {
            errorCount += 1;
            printf("%s(%d) : FAIL\n", __FILE__, __LINE__);
        }
    }
    // All possible chars except 0x00
    {
        char strSrc[256 + 1];
        // Position 0 cannot hold 0x00 (it would terminate the string), so it
        // holds a space instead; positions 1..255 hold their own index.
        for(size_t i = 0; i < sizeof(strSrc) - 1; ++i) {
            strSrc[i] = (char) (i == 0 ? ' ' : i);
        }
        strSrc[256] = 0;
        for(int iLoop = 0; iLoop <= 1; ++iLoop) {
            char libcStr[sizeof(strSrc)];
            char simdStr[sizeof(strSrc)];
            memcpy(libcStr, strSrc, sizeof(libcStr));
            memcpy(simdStr, strSrc, sizeof(simdStr));
            switch(iLoop) {
            case 0:
                libc_str_lower(libcStr, sizeof(libcStr) - 1);
                StringToLower(simdStr);
                printf("StringToLower:\n");
                break;
            case 1:
                libc_str_upper(libcStr, sizeof(libcStr) - 1);
                StringToUpper(simdStr);
                printf("StringToUpper:\n");
                break;
            default:
                assert(0);
                break;
            }
            if(memcmp(libcStr, simdStr, sizeof(libcStr)) != 0) {
                errorCount += 1;
                printf("\n%s(%d) : FAIL\n", __FILE__, __LINE__);
                printf("Expected:\n");
                hexDump(libcStr, sizeof(libcStr)-1);
                printf("Actual:\n");
                hexDump(simdStr, sizeof(simdStr)-1);
            } else {
                hexDump(simdStr, sizeof(simdStr)-1);
            }
        }
#ifdef DEMONSTRATE_BASIC_ALGORITHM
        for(int iLoop = 0; iLoop <= 3; ++iLoop) {
            char libcStr[sizeof(strSrc)];
            char simdStr[sizeof(strSrc)];
            memcpy(libcStr, strSrc, sizeof(libcStr));
            memcpy(simdStr, strSrc, sizeof(simdStr));
            switch(iLoop) {
            case 0:
                // Only the scalar model is exercised here, so a defect in
                // CharToLower cannot be hidden by another routine.
                libc_str_lower(libcStr, sizeof(libcStr) - 1);
                for(size_t i = 0; i < sizeof(simdStr)-1; ++i) {
                    simdStr[i] = (char) CharToLower((int8_t) simdStr[i]);
                }
                printf("CharToLower:\n");
                break;
            case 1:
                libc_str_upper(libcStr, sizeof(libcStr) - 1);
                for(size_t i = 0; i < sizeof(simdStr)-1; ++i) {
                    simdStr[i] = (char) CharToUpper((int8_t) simdStr[i]);
                }
                printf("CharToUpper:\n");
                break;
            case 2:
                libc_str_lower(libcStr, sizeof(libcStr) - 1);
                for(size_t i = 0; i < sizeof(simdStr)-1; i += 16) {
                    // memcpy instead of dereferencing a __m128i pointer: the
                    // char buffer has no 16-byte alignment guarantee.
                    U128 src;
                    memcpy(&src, &simdStr[i], sizeof(src));
                    U128 dst = CharToLowerU128(src);
                    memcpy(&simdStr[i], &dst, sizeof(dst));
                }
                printf("CharToLowerU128:\n");
                break;
            case 3:
                libc_str_upper(libcStr, sizeof(libcStr) - 1);
                for(size_t i = 0; i < sizeof(simdStr)-1; i += 16) {
                    U128 src;
                    memcpy(&src, &simdStr[i], sizeof(src));
                    U128 dst = CharToUpperU128(src);
                    memcpy(&simdStr[i], &dst, sizeof(dst));
                }
                printf("CharToUpperU128:\n");
                break;
            default:
                assert(0);
                break;
            }
            if(memcmp(libcStr, simdStr, sizeof(libcStr)) != 0) {
                errorCount += 1;
                printf("\n%s(%d) : FAIL\n", __FILE__, __LINE__);
                printf("Expected:\n");
                hexDump(libcStr, sizeof(libcStr)-1);
                printf("Actual:\n");
                hexDump(simdStr, sizeof(simdStr)-1);
            } else {
                hexDump(simdStr, sizeof(simdStr)-1);
            }
        }
#endif
    }
    return errorCount == 0 ? EXIT_SUCCESS : EXIT_FAILURE;
}
#endif |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment