Last active: October 12, 2024 23:02
-
-
Save python273/c709de026ce43684292c29cf1f43e7ee to your computer and use it in GitHub Desktop.
https://contest.com/docs/ML-Competition-2023-r2 https://contest.com/ml2023-r2 https://t.me/contest/353
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include "tglang.h" | |
#include "weights.h" | |
#include <stdlib.h> | |
#include <string.h> | |
#include <stdint.h> | |
#include <stdio.h> | |
#include <math.h> | |
static inline float silu(float x) { | |
return x / (1.0 + expf(-x)); | |
} | |
int probes_binary_search(int l, int r, uint32_t x) { | |
if (r >= l) { | |
int mid = l + (r - l) / 2; | |
if (probes[mid] == x) return mid; | |
if (probes[mid] > x) return probes_binary_search(l, mid - 1, x); | |
return probes_binary_search(mid + 1, r, x); | |
} | |
return -1; | |
} | |
/*
 * qsort comparator for uint32_t values.
 * Returns -1, 0 or 1 for less / equal / greater. A subtraction-based compare
 * would overflow on the full 32-bit unsigned range, so it is avoided.
 */
int compare_slices(const void* a, const void* b) {
  const uint32_t lhs = *(const uint32_t*)a;
  const uint32_t rhs = *(const uint32_t*)b;
  return (lhs > rhs) - (lhs < rhs);
}
/*
 * Skip a leading language-tag line such as "python\n" or "c++\n".
 * Returns a pointer just past the first matching tag, or `text` unchanged
 * when none matches. Only exact "<tag>\n" lines match, so "pythonic" is
 * left alone.
 */
const char *remove_prefixes(const char *text) {
  static const char *const kLanguageTags[] = {
    "python\n", "javascript\n", "java\n", "lua\n", "bash\n", "csharp\n", "html\n",
    "js\n", "php\n", "kotlin\n", "dart\n", "c\n", "sql\n", "css\n", "cpp\n",
    "c++\n", "rust\n",
  };
  const size_t tag_count = sizeof kLanguageTags / sizeof *kLanguageTags;
  const char *const *const end = kLanguageTags + tag_count;

  for (const char *const *tag = kLanguageTags; tag != end; ++tag) {
    const size_t tag_len = strlen(*tag);
    if (strncmp(text, *tag, tag_len) == 0) {
      return text + tag_len;  /* first match wins */
    }
  }
  return text;
}
enum TglangLanguage tglang_detect_programming_language(const char *text) { | |
float l1[l1_OUT] = {0.0}; | |
float l2[l2_OUT] = {0.0}; | |
float l3[l3_OUT] = {0.0}; | |
float l4[l4_OUT] = {0.0}; | |
float classifier[classifier_OUT] = {0.0}; | |
while (*text == '\n' || *text == ' ') { ++text; } | |
text = remove_prefixes(text); | |
while (*text == '\n' || *text == ' ') { ++text; } | |
size_t text_len = strlen(text); | |
// printf("len %ld\n", text_len); | |
if (text_len < 6) return TGLANG_LANGUAGE_OTHER; | |
if (text_len > 4096*4) text_len = 4096*4; | |
// we need unique 4 byte slices from text | |
uint32_t slices[4096*4]; | |
uint32_t slice = 0x0A0A0A0A; | |
size_t slices_index = 0; | |
for (; slices_index < text_len; slices_index++) { | |
unsigned char c = (unsigned char)text[slices_index]; | |
if (c == '\t') c = ' '; | |
if (c == '\r') c = ' '; | |
slice = slice << 8 | c; | |
slices[slices_index] = slice; | |
} | |
qsort(slices, slices_index, sizeof(slices[0]), compare_slices); | |
uint32_t prev_slice = -1; | |
uint8_t found_slice = 0; | |
for (size_t i = 0; i < slices_index; i++) { | |
uint32_t slice = slices[i]; | |
if (slice == prev_slice) { continue; } | |
prev_slice = slice; | |
int ind = probes_binary_search(0, (sizeof(probes) / sizeof(probes[0]))-1, slice); | |
if (ind == -1) continue; | |
found_slice = 1; | |
// printf("%d ", v); | |
for (int j = 0; j < l1_OUT; j++) { | |
l1[j] += l1_weight[ind][j]; | |
} | |
} | |
if (found_slice == 0) return TGLANG_LANGUAGE_OTHER; | |
// printf("\n"); | |
for (int i = 0; i < l1_OUT; i++) { | |
l1[i] += l1_bias[i]; | |
} | |
for (int j = 0; j < l2_IN; j++) { | |
for (int i = 0; i < l2_OUT; i++) { | |
l2[i] += l1[j] * l2_weight[j][i]; | |
} | |
} | |
for (int i = 0; i < l2_OUT; i++) { | |
l2[i] = silu(l2[i] + l2_bias[i]); | |
} | |
for (int j = 0; j < l3_IN; j++) { | |
for (int i = 0; i < l3_OUT; i++) { | |
l3[i] += l2[j] * l3_weight[j][i]; | |
} | |
} | |
for (int i = 0; i < l3_OUT; i++) { | |
l2[i] += silu(l3[i] + l3_bias[i]); | |
} | |
for (int j = 0; j < l4_IN; j++) { | |
for (int i = 0; i < l4_OUT; i++) { | |
l4[i] += l2[j] * l4_weight[j][i]; | |
} | |
} | |
for (int i = 0; i < l4_OUT; i++) { | |
l2[i] += silu(l4[i] + l4_bias[i]); | |
} | |
for (int j = 0; j < classifier_IN; j++) { | |
for (int i = 0; i < classifier_OUT; i++) { | |
classifier[i] += l2[j] * classifier_weight[j][i]; | |
} | |
} | |
// for (int i = 0; i < classifier_OUT; i++) { | |
// printf("%f ", classifier[i]); | |
// } | |
// printf("\n"); | |
size_t argmax = 0; | |
for (int i = 1; i < classifier_OUT; i++) { | |
if (classifier[i] > classifier[argmax]) { | |
argmax = i; | |
} | |
} | |
// printf("%ld\n", argmax); | |
return class_mapping[argmax]; | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# for reference, actual definition slightly different in weights init
# Number of real probe ids; id PROBES_NUM itself is reserved for padding.
PROBES_NUM = 177606
d_l1 = 32  # width of the summed probe embedding
d_model = 512  # width of the SiLU MLP / residual stream


class Net(nn.Module):
    """Bag-of-probes classifier that the C inference code mirrors.

    Each input row is a sequence of probe ids, padded with ``probes_num``.
    Their embeddings are summed, a learned bias is added, and the result goes
    through a SiLU layer, two residual SiLU layers, and a bias-free linear
    classifier.

    Args:
        probes_num: number of real probe ids; id ``probes_num`` is padding.
        emb_dim: embedding width (defaults to ``d_l1``).
        hidden_dim: MLP / residual width (defaults to ``d_model``).
        num_labels: number of output classes; ``None`` means ``LABEL_LEN``.

    Submodule names (emb, emb_bias, l2, l3, l4, classifier) are kept stable
    because the state dict and the exported weights use them.
    """

    def __init__(self, probes_num=PROBES_NUM, emb_dim=d_l1, hidden_dim=d_model, num_labels=None):
        super().__init__()
        if num_labels is None:
            # LABEL_LEN presumably comes from the training script; resolved at call time.
            num_labels = LABEL_LEN
        # padding_idx keeps the padding row at zero, so padded positions add nothing to the sum.
        self.emb = nn.Embedding(probes_num + 1, emb_dim, padding_idx=probes_num)
        self.emb_bias = nn.Parameter(torch.zeros([emb_dim]))
        self.l2 = nn.Linear(emb_dim, hidden_dim)
        self.l3 = nn.Linear(hidden_dim, hidden_dim)
        self.l4 = nn.Linear(hidden_dim, hidden_dim)
        self.classifier = nn.Linear(hidden_dim, num_labels, bias=False)

    def forward(self, x):
        """Return class logits for probe-id tensor ``x``.

        ``x`` is expected to be (batch, seq) integer ids, because the
        embeddings are summed over dim 1. Returns (batch, num_labels).
        """
        x = self.emb(x)
        x = x.sum(dim=1)
        x = x + self.emb_bias
        x = F.silu(self.l2(x))
        x = x + F.silu(self.l3(x))  # residual block
        x = x + F.silu(self.l4(x))  # residual block
        x = self.classifier(x)
        return x
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ctypes
import timeit

HEADER_PATH = 'tglang.h'
LIBRARY_PATH = './build/libtglang.so'
SAMPLE_PATH = '../train_model.py'
# The C detector only looks at the first 4096*4 bytes, so benchmark exactly that much.
MAX_TEXT_BYTES = 4096 * 4
BENCH_RUNS = 10000

# Set by main(); module-level so detect_programming_language(text) keeps its
# original single-argument form.
libtglang = None
TglangLanguage = None


def parse_enum_names(header_text, enum_name='TglangLanguage'):
    """Return the enumerator names of ``enum <enum_name> { ... }`` in order.

    Blank or whitespace-only lines and ``//`` comments are ignored, so list
    positions line up with the enum's implicit integer values. Explicit
    ``= value`` initializers are not supported.

    Raises:
        ValueError: if the enum is not present in ``header_text``.
    """
    _, found, rest = header_text.partition(f'enum {enum_name} {{')
    if not found:
        raise ValueError(f'enum {enum_name} not found in header')
    body = rest.partition('}')[0]
    names = []
    for line in body.splitlines():
        # Strip before testing emptiness: a whitespace-only line must not become an entry.
        name = line.partition('//')[0].strip().removesuffix(',').strip()
        if name:
            names.append(name)
    return names


def load_library(path=LIBRARY_PATH):
    """Load the detector shared library and declare the C signature."""
    lib = ctypes.CDLL(path)
    detect = lib.tglang_detect_programming_language
    detect.argtypes = [ctypes.c_char_p]
    detect.restype = ctypes.c_int
    return lib


def detect_programming_language(text, lib=None, names=None):
    """Classify ``text`` (bytes) and return the enum name.

    Falls back to the module-level library and name list loaded by main().
    """
    lib = libtglang if lib is None else lib
    names = TglangLanguage if names is None else names
    return names[lib.tglang_detect_programming_language(text)]


def main():
    """Load the detector, sanity-check it, and time it on a real source file."""
    global libtglang, TglangLanguage
    with open(HEADER_PATH, encoding='utf-8') as f:
        TglangLanguage = parse_enum_names(f.read())
    libtglang = load_library()
    print(libtglang)
    print(detect_programming_language(b'SELECT * FROM users;'))

    with open(SAMPLE_PATH, 'rb') as f:
        sample = f.read()
    print(len(sample))
    sample = sample[:MAX_TEXT_BYTES]
    print(len(sample))
    print(detect_programming_language(sample))

    # A separate name for the timing result avoids rebinding `sample` while the lambda still reads it.
    elapsed = timeit.timeit(lambda: detect_programming_language(sample), number=BENCH_RUNS)
    print(f'{elapsed / BENCH_RUNS} {BENCH_RUNS / elapsed}')


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.