Adding support for the granite multilingual embeddings R2 (ibm-granite/granite-embedding-{97,311}m-multilingual-r2 models) #22716
|
@gabe-l-hart here is the PR. |
|
Thanks @hansolosan! I'll take a first pass review in the next day or two and notify maintainers once we're ready for final review. |
gabe-l-hart
left a comment
There was a problem hiding this comment.
I think it would be good to make the hparam more flexible for future models that need it.
// FFN gated activation flavor (used by ModernBert/derivatives that may use
// SwiGLU instead of the default GeGLU). The graph for those archs reads
// this to pick LLM_FFN_SWIGLU vs LLM_FFN_GEGLU.
bool ffn_act_swiglu = false;
There was a problem hiding this comment.
NIT: Most model-specific hparams live towards the bottom of the field declaration. This affects how structs are initialized, and while this repo never uses direct initialization for hparams, other tools that use this header might (yes, that violates encapsulation, but it's the internet).
There was a problem hiding this comment.
Less-NIT: In the GGUF, this looks like it's represented as a string, but here it's a bool which limits the future usability. I think it would be cleaner to use llm_ffn_op_type (declared in llama-graph.h, so available here). This would also avoid the need for the ternary above.
If we go that route, we could also align the name as ffn_op. Further even, we could add a helper in llama-graph.* to do the enum <-> string mapping so it's centralized and reusable.
There was a problem hiding this comment.
Hi Gabe - looks like llama-graph.h includes this file (llama-hparams.h) - we could move the field llm_ffn_op_type to llama-arch.h and have it included from here.
There was a problem hiding this comment.
Ah, very good point. There are a number of enums declared in llama-hparams.h so I think the best solution would be to move the declaration of llm_ffn_op_type to llama-hparams.h (I think llama-arch.h is reserved for GGUF KVs and tensor names). This kind of header rearranging definitely requires thoughts from the maintainers, though.
cc @CISC @ggerganov for thoughts on moving llm_ffn_op_type
There was a problem hiding this comment.
I don't think that's a good idea, this should probably be handled in llama-model.cpp just like rope_scaling_type.
There was a problem hiding this comment.
@gabe-l-hart done - I moved the structure to llama-hparams.h and made the type change. I also checked the regex change - it has a beneficial speed effect in several languages (Bengali, Hindi, Telugu, and Thai), and no measurable regression in the others (ratios between 0.998 and 1.057):
Pre-tokenization Regex Benchmark (C++ std::regex, collapsed-byte approach)
Old: without \p{M} in lookaheads (pre-PR)
New: with \p{M} in lookaheads (post-PR fix)
Files: 17, samples/file: 10000
Results by language
| Lang | Samples | Bytes | Cpts | Time_old | MB/s | Time_new | MB/s | Ratio | Mismatches |
|---|---|---|---|---|---|---|---|---|---|
| ar | 10000 | 4660367 | 2615782 | 901.9 ms | 4.93 | 889.0 ms | 5.00 | 1.015 | 3219 |
| bn | 10000 | 11936213 | 4492159 | 2666.8 ms | 4.27 | 1493.0 ms | 7.62 | 1.786 | 9933 |
| en | 10000 | 3335006 | 3320259 | 978.0 ms | 3.25 | 975.1 ms | 3.26 | 1.003 | 85 |
| es | 10000 | 3629385 | 3559269 | 1016.9 ms | 3.40 | 1016.5 ms | 3.40 | 1.000 | 2 |
| fa | 10000 | 5241562 | 2910480 | 1070.0 ms | 4.67 | 1062.4 ms | 4.70 | 1.007 | 3202 |
| fi | 10000 | 3384955 | 3270612 | 854.5 ms | 3.78 | 853.3 ms | 3.78 | 1.001 | 4 |
| fr | 10000 | 4180418 | 4042979 | 1122.1 ms | 3.55 | 1122.2 ms | 3.55 | 1.000 | 6 |
| hi | 10000 | 11837824 | 4607603 | 2688.4 ms | 4.20 | 1612.5 ms | 7.00 | 1.667 | 9951 |
| id | 10000 | 4760430 | 4752414 | 1218.5 ms | 3.73 | 1218.6 ms | 3.73 | 1.000 | 41 |
| ja | 10000 | 3634512 | 1311999 | 397.8 ms | 8.71 | 398.1 ms | 8.71 | 0.999 | 3 |
| ko | 10000 | 4279015 | 1832329 | 753.1 ms | 5.42 | 753.1 ms | 5.42 | 1.000 | 4 |
| ru | 10000 | 4993773 | 2768805 | 891.2 ms | 5.34 | 893.1 ms | 5.33 | 0.998 | 398 |
| sw | 10000 | 2242871 | 2237672 | 605.3 ms | 3.53 | 605.1 ms | 3.53 | 1.000 | 11 |
| te | 10000 | 10521870 | 3988310 | 2368.7 ms | 4.24 | 1237.4 ms | 8.11 | 1.914 | 9620 |
| th | 10000 | 11484151 | 4093380 | 1584.5 ms | 6.91 | 938.9 ms | 11.66 | 1.688 | 9930 |
| yo | 10000 | 1786712 | 1433300 | 467.2 ms | 3.65 | 441.9 ms | 3.86 | 1.057 | 3172 |
| zh | 10000 | 5184752 | 1830799 | 556.8 ms | 8.88 | 556.7 ms | 8.88 | 1.000 | 6 |
| TOTAL | 170000 | 97093816 | | 20141.7 ms | 4.60 | 16067.1 ms | 5.76 | 1.254 | 49587 |
Ratio > 1.0 means old is slower; < 1.0 means new is slower.
Mismatches = samples where old/new produce different pre-token counts.
==============================================================================
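The MB/s and Ratio columns follow directly from the Bytes and Time columns. As a quick sanity check, the Python sketch below recomputes the bn row with the same formulas the benchmark program in this comment uses (bytes / 1024² divided by seconds, and time_old / time_new). The function names are illustrative and not part of the PR.

```python
def mb_per_s(total_bytes: int, time_ms: float) -> float:
    """Throughput in MiB per second, using the benchmark's formula."""
    return (total_bytes / (1024.0 * 1024.0)) / (time_ms / 1000.0)


def ratio(time_old_ms: float, time_new_ms: float) -> float:
    """Values above 1.0 mean the old regex is slower than the new one."""
    return time_old_ms / time_new_ms


# Values copied from the "bn" row of the table above.
bn_bytes, bn_old_ms, bn_new_ms = 11936213, 2666.8, 1493.0

print(f"old:   {mb_per_s(bn_bytes, bn_old_ms):.2f} MB/s")
print(f"new:   {mb_per_s(bn_bytes, bn_new_ms):.2f} MB/s")
print(f"ratio: {ratio(bn_old_ms, bn_new_ms):.3f}")
```

Because throughput is computed over raw UTF-8 bytes, scripts with three-byte characters (bn, hi, te, th) report a higher MB/s than their codepoint counts alone would suggest.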
This is the program I ran:
// Benchmark pre-tokenization: old regex (no \p{M}) vs new regex (with \p{M})
// Uses the same collapsed-regex + std::regex approach as llama.cpp's unicode.cpp
//
// Build:
// g++ -O2 -std=c++17 -I src -o bench_pretokenize_cpp bench_pretokenize_cpp.cpp \
// src/unicode.cpp src/unicode-data.cpp -lpthread
//
// Run:
// ./bench_pretokenize_cpp --fof files.txt --num-samples 1000
//
// The --fof file contains one path per line (files can be .jsonl or .jsonl.bz2).
// Each JSONL line must have a "text" field.
// Language is inferred from the parent directory (e.g., "/data/corpora/en/wiki_00.jsonl.bz2" -> "en"),
// falling back to the filename without extensions.
#include "unicode.h"
#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <map>
#include <regex>
#include <string>
#include <vector>
// ---------- regex patterns ----------
// Old GPT-4o approximation (before PR: no \p{M} in lookaheads)
static const char * REGEX_OLD =
"[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+"
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|"
"[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*"
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|"
"\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
// New GPT-4o approximation (after PR: \p{M} added to lookaheads)
static const char * REGEX_NEW =
"[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}\\p{M}])([^a-z]))*((?=[\\p{L}\\p{M}])([^A-Z]))+"
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|"
"[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}\\p{M}])([^a-z]))+((?=[\\p{L}\\p{M}])([^A-Z]))*"
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|"
"\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
// ---------- collapsed regex logic (mirrors unicode.cpp) ----------
static const std::map<std::string, int> k_ucat_enum = {
{ "\\p{N}", unicode_cpt_flags::NUMBER },
{ "\\p{L}", unicode_cpt_flags::LETTER },
{ "\\p{P}", unicode_cpt_flags::PUNCTUATION },
{ "\\p{M}", unicode_cpt_flags::ACCENT_MARK },
{ "\\p{S}", unicode_cpt_flags::SYMBOL },
{ "\\p{Lu}", unicode_cpt_flags::LETTER },
{ "\\p{Ll}", unicode_cpt_flags::LETTER },
{ "\\p{Lt}", unicode_cpt_flags::LETTER },
{ "\\p{Lm}", unicode_cpt_flags::LETTER },
{ "\\p{Lo}", unicode_cpt_flags::LETTER },
};
static const std::map<int, int> k_ucat_cpt = {
{ unicode_cpt_flags::NUMBER, 0xD1 },
{ unicode_cpt_flags::LETTER, 0xD2 },
{ unicode_cpt_flags::PUNCTUATION, 0xD3 },
{ unicode_cpt_flags::ACCENT_MARK, 0xD4 },
{ unicode_cpt_flags::SYMBOL, 0xD5 },
};
static const std::map<int, std::string> k_ucat_map = {
{ unicode_cpt_flags::NUMBER, "\x30-\x39" },
{ unicode_cpt_flags::LETTER, "\x41-\x5A\x61-\x7A" },
{ unicode_cpt_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" },
{ unicode_cpt_flags::ACCENT_MARK, "" },
{ unicode_cpt_flags::SYMBOL, "\\\x24\\\x2B\x3C-\x3E\x5E\x60\\\x7C" },
};
static std::string collapse_text(const std::vector<uint32_t> & cpts) {
std::string out(cpts.size(), '\0');
for (size_t i = 0; i < cpts.size(); ++i) {
if (cpts[i] < 128) {
out[i] = (char)cpts[i];
continue;
}
const auto flags = unicode_cpt_flags_from_cpt(cpts[i]);
if (flags.is_whitespace) {
out[i] = (char)0x0B;
} else {
auto it = k_ucat_cpt.find(flags.category_flag());
if (it != k_ucat_cpt.end()) {
out[i] = (char)it->second;
} else {
out[i] = (char)0xD0;
}
}
}
return out;
}
static std::string collapse_regex(const std::string & regex_expr) {
std::string out;
bool inside = false;
for (size_t i = 0; i < regex_expr.size(); ++i) {
if (regex_expr[i] == '[' && (i == 0 || regex_expr[i - 1] != '\\')) {
out += '[';
inside = true;
continue;
}
if (inside && regex_expr[i] == ']' && regex_expr[i - 1] != '\\') {
out += ']';
inside = false;
continue;
}
if (regex_expr[i] == '\\' && i + 3 < regex_expr.size() &&
regex_expr[i + 1] == 'p' && regex_expr[i + 2] == '{') {
size_t closing = regex_expr.find('}', i + 3);
if (closing != std::string::npos && closing <= i + 10) {
const std::string pat = regex_expr.substr(i, closing - i + 1);
auto it = k_ucat_enum.find(pat);
if (it != k_ucat_enum.end()) {
if (!inside) out += '[';
out += (char)k_ucat_cpt.at(it->second);
out += k_ucat_map.at(it->second);
if (!inside) out += ']';
i = closing;
continue;
}
}
}
out += regex_expr[i];
}
return out;
}
// ---------- JSON text extraction ----------
// Minimal: extract "text" field value from a JSONL line.
// Handles escaped quotes within the value.
static bool extract_text_field(const std::string & line, std::string & out) {
// Find "text" key
const char * patterns[] = { "\"text\":", "\"text\" :" };
size_t key_end = std::string::npos;
for (auto p : patterns) {
size_t pos = line.find(p);
if (pos != std::string::npos) {
key_end = pos + strlen(p);
break;
}
}
if (key_end == std::string::npos) return false;
// Skip whitespace after colon
while (key_end < line.size() && (line[key_end] == ' ' || line[key_end] == '\t')) key_end++;
if (key_end >= line.size() || line[key_end] != '"') return false;
// Parse string value (handle escapes)
out.clear();
size_t i = key_end + 1;
while (i < line.size()) {
if (line[i] == '\\' && i + 1 < line.size()) {
char c = line[i + 1];
switch (c) {
case '"': out += '"'; break;
case '\\': out += '\\'; break;
case '/': out += '/'; break;
case 'n': out += '\n'; break;
case 'r': out += '\r'; break;
case 't': out += '\t'; break;
case 'b': out += '\b'; break;
case 'f': out += '\f'; break;
case 'u': {
// \uXXXX
if (i + 5 < line.size()) {
char hex[5] = {line[i+2], line[i+3], line[i+4], line[i+5], 0};
uint32_t cp = (uint32_t)strtoul(hex, nullptr, 16);
// Handle surrogate pairs
if (cp >= 0xD800 && cp <= 0xDBFF && i + 11 < line.size() &&
line[i+6] == '\\' && line[i+7] == 'u') {
char hex2[5] = {line[i+8], line[i+9], line[i+10], line[i+11], 0};
uint32_t cp2 = (uint32_t)strtoul(hex2, nullptr, 16);
cp = 0x10000 + ((cp - 0xD800) << 10) + (cp2 - 0xDC00);
i += 6; // skip the second \uXXXX
}
out += unicode_cpt_to_utf8(cp);
i += 4; // skip XXXX (loop will skip \u)
}
break;
}
default: out += c; break;
}
i += 2;
} else if (line[i] == '"') {
return true;
} else {
out += line[i];
i++;
}
}
return false; // unterminated string
}
// ---------- language detection from path ----------
// Extract language from the parent directory of the file.
// e.g., "/data/corpora/en/wiki_00.jsonl.bz2" -> "en"
// "/data/ru/corpus.jsonl" -> "ru"
// Falls back to filename-based heuristic if no parent directory.
static std::string detect_language(const std::string & path) {
// Find the last directory component before the filename
size_t last_slash = path.find_last_of('/');
if (last_slash != std::string::npos && last_slash > 0) {
size_t prev_slash = path.find_last_of('/', last_slash - 1);
size_t dir_start = (prev_slash != std::string::npos) ? prev_slash + 1 : 0;
std::string dir = path.substr(dir_start, last_slash - dir_start);
if (!dir.empty()) {
return dir;
}
}
// Fallback: use filename without extensions
std::string base = (last_slash != std::string::npos) ? path.substr(last_slash + 1) : path;
while (true) {
size_t dot = base.find_last_of('.');
if (dot == std::string::npos) break;
std::string ext = base.substr(dot);
if (ext == ".jsonl" || ext == ".json" || ext == ".bz2" || ext == ".gz" ||
ext == ".zst" || ext == ".txt" || ext == ".xz") {
base = base.substr(0, dot);
} else {
break;
}
}
return base;
}
// ---------- file reading (plain or bz2) ----------
struct line_reader {
FILE * fp = nullptr;
bool is_pipe = false;
bool open(const std::string & path) {
// Check if bz2
bool bz2 = (path.size() > 4 && path.substr(path.size() - 4) == ".bz2");
bool gz = (path.size() > 3 && path.substr(path.size() - 3) == ".gz");
bool zst = (path.size() > 4 && path.substr(path.size() - 4) == ".zst");
bool xz = (path.size() > 3 && path.substr(path.size() - 3) == ".xz");
if (bz2) {
std::string cmd = "bzcat '" + path + "'";
fp = popen(cmd.c_str(), "r");
is_pipe = true;
} else if (gz) {
std::string cmd = "zcat '" + path + "'";
fp = popen(cmd.c_str(), "r");
is_pipe = true;
} else if (zst) {
std::string cmd = "zstdcat '" + path + "'";
fp = popen(cmd.c_str(), "r");
is_pipe = true;
} else if (xz) {
std::string cmd = "xzcat '" + path + "'";
fp = popen(cmd.c_str(), "r");
is_pipe = true;
} else {
fp = fopen(path.c_str(), "r");
is_pipe = false;
}
return fp != nullptr;
}
bool getline(std::string & line) {
line.clear();
if (!fp) return false;
char buf[8192];
while (fgets(buf, sizeof(buf), fp)) {
line += buf;
if (!line.empty() && line.back() == '\n') {
line.pop_back();
return true;
}
}
return !line.empty();
}
void close() {
if (fp) {
if (is_pipe) pclose(fp); else fclose(fp);
fp = nullptr;
}
}
~line_reader() { close(); }
};
// ---------- per-language stats ----------
struct lang_stats {
std::string lang;
size_t n_samples = 0;
size_t total_bytes = 0;
size_t total_cpts = 0;
size_t tokens_old = 0;
size_t tokens_new = 0;
double time_old_ms = 0;
double time_new_ms = 0;
size_t mismatches = 0; // samples where old != new token count
};
// ---------- main ----------
static void usage(const char * prog) {
fprintf(stderr,
"Usage: %s --fof <file-of-files> [--num-samples N]\n"
"\n"
" --fof FILE File containing one input path per line.\n"
" Input files are JSONL (optionally .bz2/.gz/.zst/.xz).\n"
" Each line must have a \"text\" field.\n"
" --num-samples N Number of samples to process per file (default: 1000).\n"
"\n"
"Language is inferred from the parent directory of each input file (or from its filename).\n"
"Reports pre-tokenization speed with old vs new regex, grouped by language.\n",
prog);
}
int main(int argc, char ** argv) {
std::string fof_path;
int num_samples = 1000;
for (int i = 1; i < argc; ++i) {
if (strcmp(argv[i], "--fof") == 0 && i + 1 < argc) {
fof_path = argv[++i];
} else if (strcmp(argv[i], "--num-samples") == 0 && i + 1 < argc) {
num_samples = atoi(argv[++i]);
} else if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "--help") == 0) {
usage(argv[0]);
return 0;
} else {
fprintf(stderr, "Unknown argument: %s\n", argv[i]);
usage(argv[0]);
return 1;
}
}
if (fof_path.empty()) {
fprintf(stderr, "Error: --fof is required.\n\n");
usage(argv[0]);
return 1;
}
// Read file-of-files
std::vector<std::string> input_files;
{
std::ifstream fof(fof_path);
if (!fof.is_open()) {
fprintf(stderr, "Error: cannot open fof file: %s\n", fof_path.c_str());
return 1;
}
std::string line;
while (std::getline(fof, line)) {
// Trim whitespace
while (!line.empty() && (line.back() == '\r' || line.back() == ' ' || line.back() == '\t'))
line.pop_back();
while (!line.empty() && (line.front() == ' ' || line.front() == '\t'))
line.erase(line.begin());
if (!line.empty() && line[0] != '#') {
input_files.push_back(line);
}
}
}
if (input_files.empty()) {
fprintf(stderr, "Error: no files in fof.\n");
return 1;
}
// Compile collapsed regexes
std::string regex_old_collapsed = collapse_regex(REGEX_OLD);
std::string regex_new_collapsed = collapse_regex(REGEX_NEW);
std::regex re_old(regex_old_collapsed, std::regex_constants::optimize | std::regex_constants::nosubs);
std::regex re_new(regex_new_collapsed, std::regex_constants::optimize | std::regex_constants::nosubs);
printf("==============================================================================\n");
printf("Pre-tokenization Regex Benchmark (C++ std::regex, collapsed-byte approach)\n");
printf(" Old: without \\p{M} in lookaheads (pre-PR)\n");
printf(" New: with \\p{M} in lookaheads (post-PR fix)\n");
printf(" Files: %zu, samples/file: %d\n", input_files.size(), num_samples);
printf("==============================================================================\n\n");
// Accumulate per-language stats
std::map<std::string, lang_stats> stats_map;
// Group files by language, accumulate collapsed texts per language
std::map<std::string, std::vector<std::string>> lang_collapsed; // lang -> collapsed texts
std::map<std::string, size_t> lang_bytes; // lang -> total raw bytes
for (const auto & path : input_files) {
std::string lang = detect_language(path);
size_t slash = path.find_last_of('/');
std::string display = (slash != std::string::npos) ? path.substr(slash + 1) : path;
fprintf(stderr, " Reading: %s (lang=%s) ...", display.c_str(), lang.c_str());
fflush(stderr);
line_reader reader;
if (!reader.open(path)) {
fprintf(stderr, " FAILED to open, skipping.\n");
continue;
}
std::string line, text;
int count = 0;
while (count < num_samples && reader.getline(line)) {
if (!extract_text_field(line, text)) continue;
if (text.empty()) continue;
count++;
lang_bytes[lang] += text.size();
auto cpts = unicode_cpts_from_utf8(text);
lang_collapsed[lang].push_back(collapse_text(cpts));
}
reader.close();
fprintf(stderr, " %d samples\n", count);
}
// Now benchmark each language with warmup
const size_t WARMUP_SAMPLES = 200;
for (auto & [lang, collapsed_texts] : lang_collapsed) {
auto & st = stats_map[lang];
st.lang = lang;
st.n_samples = collapsed_texts.size();
st.total_bytes = lang_bytes[lang];
// Count total codepoints (collapsed text length == codepoint count)
for (const auto & c : collapsed_texts) st.total_cpts += c.size();
// Warmup phase: run both regexes on first min(WARMUP_SAMPLES, n) texts
size_t n_warmup = std::min(WARMUP_SAMPLES, collapsed_texts.size());
fprintf(stderr, " Benchmarking lang=%s (%zu samples, warmup=%zu) ...",
lang.c_str(), collapsed_texts.size(), n_warmup);
fflush(stderr);
for (size_t i = 0; i < n_warmup; ++i) {
const auto & collapsed = collapsed_texts[i];
{
auto it = std::sregex_iterator(collapsed.begin(), collapsed.end(), re_old);
auto end = std::sregex_iterator();
size_t c = 0;
for (; it != end; ++it) c++;
(void)c;
}
{
auto it = std::sregex_iterator(collapsed.begin(), collapsed.end(), re_new);
auto end = std::sregex_iterator();
size_t c = 0;
for (; it != end; ++it) c++;
(void)c;
}
}
// Timed pass: run through all samples
for (const auto & collapsed : collapsed_texts) {
// Bench old regex
size_t tokens_old_sample;
{
auto t0 = std::chrono::high_resolution_clock::now();
auto it = std::sregex_iterator(collapsed.begin(), collapsed.end(), re_old);
auto end = std::sregex_iterator();
size_t c = 0;
for (; it != end; ++it) c++;
auto t1 = std::chrono::high_resolution_clock::now();
tokens_old_sample = c;
st.tokens_old += c;
st.time_old_ms += std::chrono::duration<double, std::milli>(t1 - t0).count();
}
// Bench new regex
size_t tokens_new_sample;
{
auto t0 = std::chrono::high_resolution_clock::now();
auto it = std::sregex_iterator(collapsed.begin(), collapsed.end(), re_new);
auto end = std::sregex_iterator();
size_t c = 0;
for (; it != end; ++it) c++;
auto t1 = std::chrono::high_resolution_clock::now();
tokens_new_sample = c;
st.tokens_new += c;
st.time_new_ms += std::chrono::duration<double, std::milli>(t1 - t0).count();
}
// Check mismatch
if (tokens_old_sample != tokens_new_sample) {
st.mismatches++;
}
}
fprintf(stderr, " done\n");
}
// Print results
printf("\n");
printf("==============================================================================\n");
printf("Results by language\n");
printf("==============================================================================\n\n");
printf(" %-8s %8s %10s %10s | %10s %8s | %10s %8s | %7s %s\n",
"Lang", "Samples", "Bytes", "Cpts",
"Time_old", "MB/s",
"Time_new", "MB/s",
"Ratio", "Mismatches");
printf(" %-8s %8s %10s %10s | %10s %8s | %10s %8s | %7s %s\n",
"--------", "--------", "----------", "----------",
"----------", "--------",
"----------", "--------",
"-------", "----------");
// Sort languages alphabetically
std::vector<std::string> langs;
for (auto & [k, v] : stats_map) langs.push_back(k);
std::sort(langs.begin(), langs.end());
double total_time_old = 0, total_time_new = 0;
size_t total_bytes = 0, total_samples = 0, total_mismatches = 0;
for (const auto & lang : langs) {
const auto & st = stats_map[lang];
double mbs_old = (st.total_bytes / (1024.0 * 1024.0)) / (st.time_old_ms / 1000.0);
double mbs_new = (st.total_bytes / (1024.0 * 1024.0)) / (st.time_new_ms / 1000.0);
double ratio = st.time_old_ms / st.time_new_ms;
printf(" %-8s %8zu %10zu %10zu | %8.1f ms %6.2f | %8.1f ms %6.2f | %7.3f %zu\n",
st.lang.c_str(), st.n_samples, st.total_bytes, st.total_cpts,
st.time_old_ms, mbs_old,
st.time_new_ms, mbs_new,
ratio, st.mismatches);
total_time_old += st.time_old_ms;
total_time_new += st.time_new_ms;
total_bytes += st.total_bytes;
total_samples += st.n_samples;
total_mismatches += st.mismatches;
}
printf(" %-8s %8s %10s %10s | %10s %8s | %10s %8s | %7s %s\n",
"--------", "--------", "----------", "----------",
"----------", "--------",
"----------", "--------",
"-------", "----------");
double total_mbs_old = (total_bytes / (1024.0 * 1024.0)) / (total_time_old / 1000.0);
double total_mbs_new = (total_bytes / (1024.0 * 1024.0)) / (total_time_new / 1000.0);
double total_ratio = total_time_old / total_time_new;
printf(" %-8s %8zu %10zu %10s | %8.1f ms %6.2f | %8.1f ms %6.2f | %7.3f %zu\n",
"TOTAL", total_samples, total_bytes, "",
total_time_old, total_mbs_old,
total_time_new, total_mbs_new,
total_ratio, total_mismatches);
printf("\n");
printf(" Ratio > 1.0 means old is slower; < 1.0 means new is slower.\n");
printf(" Mismatches = samples where old/new produce different pre-token counts.\n");
printf("\n==============================================================================\n");
return 0;
}
Input files look like this (this is the MIRACL training data):
{"docid":"545#19","title":"ดาราศาสตร์","text":"คลื่นวิทยุที่แผ่จากวัตถุดาราศาสตร์จำนวนหนึ่งอาจอยู่ในรูปของการแผ่รังสีความร้อน โดยมากแล้วการแผ่คลื่นวิทยุที่ตรวจจับได้บนโลกมักอยู่ในรูปแบบของการแผ่รังสีซิงโครตรอน ซึ่งเกิดจากการที่อิเล็กตรอนเคลื่อนที่เป็นคาบรอบเส้นแรงสนามแม่เหล็ก นอกจากนี้สเปกตรัมที่เกิดจากแก๊สระหว่างดาว โดยเฉพาะอย่างยิ่งเส้นสเปกตรัมของไฮโดรเจนที่ 21 เซนติเมตร จะสามารถสังเกตได้ในช่วงคลื่นวิทยุ"}
There was a problem hiding this comment.
For completeness, the old regex without \p{M} causes excessive backtracking on diacritical marks in std::regex.
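To illustrate why marks matter here, the Python sketch below shows that combining marks belong to Unicode category M rather than L, so a letter-only class stops at every mark. It is a simplified model (maximal runs of "word" characters), not the actual pre-tokenizer regex, and it does not reproduce the std::regex backtracking cost itself. The helper name split_runs is hypothetical.

```python
import unicodedata
from itertools import groupby


def split_runs(text: str, include_marks: bool) -> list[str]:
    """Split text into maximal runs of 'word' and 'non-word' characters.

    A character counts as 'word' if its general category starts with "L"
    (letter) or, when include_marks is True, with "M" (mark).
    """
    def is_word(ch: str) -> bool:
        cat = unicodedata.category(ch)
        return cat.startswith("L") or (include_marks and cat.startswith("M"))

    return ["".join(group) for _, group in groupby(text, key=is_word)]


# Thai syllable: base consonant U+0E17 followed by two combining marks.
thai = "\u0e17\u0e35\u0e48"

print([unicodedata.category(ch) for ch in thai])
print(len(split_runs(thai, include_marks=False)))  # letters only
print(len(split_runs(thai, include_marks=True)))   # letters and marks
```

In this model the letters-only variant breaks the syllable apart at the first mark, which is consistent with the high mismatch counts reported for bn, hi, te, and th.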
There was a problem hiding this comment.
It depends. If we expect to always be able to map directly to llm_ffn_op_type, we don't need to. However, there might be cases where we can't or don't want to do that, in which case a separate enum might be beneficial; it could, for example, contain values with some high bit reserved for special ops that are handled outside of build_ffn/_moe.
..or maybe that just overcomplicates things for no reason, not sure? :)
@CISC Looking a little more, I think the code as it stands now is a pretty good compromise. It is extensible to future uses where models need to specify the ffn activation beyond geglu vs swiglu, but it doesn't impact the public header.
It does look like all other enum types that are used in llama_hparams are declared in the public header (here). That seems like it's not strictly necessary since llama_hparams itself is not public and therefore public API users wouldn't need to use values of llm_ffn_op_type directly. It looks like the reason for the other enums being public is that the user can set them via llama_context_params. It seems like this wouldn't be necessary for ffn_op_type since this is just a property of the specific model instances?
There was a problem hiding this comment.
@gabe-l-hart let me know what I should change in the code, thanks.
There was a problem hiding this comment.
@gabe-l-hart I did the changes - please let me know if you find anything out of place :).
|
I've confirmed that the inference is working as intended. Here was my process:
Conversion
(cd ~/models && hf download ibm-granite/granite-embedding-97m-multilingual-r2 --local-dir ibm-granite/granite-embedding-97m-multilingual-r2)
python convert_hf_to_gguf.py ~/models/ibm-granite/granite-embedding-97m-multilingual-r2/
Baseline w/ Sentence Transformers
I used this script to compare the results of running with Sentence Transformers against llama.cpp:
granite_embed.py
from sentence_transformers import SentenceTransformer
import numpy as np
import subprocess
import shlex
import sys
model_path = "/Users/ghart/models/ibm-granite/granite-embedding-97m-multilingual-r2"
lcpp_model = f"{model_path}/granite-embedding-97M-multilingual-r2-BF16.gguf"
lcpp_exe = "./build/bin/llama-embedding"
if len(sys.argv) > 1:
    model_path = sys.argv[1]
if len(sys.argv) > 2:
    lcpp_model = sys.argv[2]
if len(sys.argv) > 3:
    lcpp_exe = sys.argv[3]
model = SentenceTransformer(model_path)
input_queries = [
"hello world",
"tell me a story about a developer and their dog",
"123sfg this is a r@nd0m t35t",
]
def cosine_similarity(vector_a: np.ndarray, vector_b: np.ndarray) -> float:
    vector_a = np.asarray(vector_a)
    vector_b = np.asarray(vector_b)
    numerator = np.dot(vector_a, vector_b)
    denominator_a = np.linalg.norm(vector_a)
    denominator_b = np.linalg.norm(vector_b)
    if denominator_a == 0 or denominator_b == 0: return 0.0
    cosine_sim = numerator / (denominator_a * denominator_b)
    return cosine_sim
for query in input_queries:
    print("### BASELINE ###")
    embedding = model.encode([query])
    print("Embedding shape:", embedding.shape)
    print("Embedding vector:", embedding[:, :8])
    print("### llama.cpp ###")
    cmd = f"{lcpp_exe} -m {lcpp_model} -p \"{query}\" --temp 0 --embd-normalize -1"
    print(f"llama.cpp command: {cmd}")
    proc = subprocess.Popen(
        shlex.split(cmd),
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    out, _ = proc.communicate()
    vals = out.decode("utf-8").split(":")[-1]
    vals = [
        float(v) for v in vals.split()
        if v.strip()
    ]
    lcpp_emb = np.array(vals)
    print("llama.cpp Embedding shape:", lcpp_emb.shape)
    print("llama.cpp Embedding vector:", lcpp_emb[:8])
    print()
    cos_sim = cosine_similarity(embedding, lcpp_emb)
    print(f"COSINE SIMILARITY: {cos_sim}")
    print("--------------------------------")
    print()
Results w/out branch
Results w/ branch |
gabe-l-hart
left a comment
There was a problem hiding this comment.
@CISC @ggerganov I think the current changeset is good as is, but would like to get your review and opinion on the best placement for the llm_ffn_op_type enum. I also think it's ready to release CI.
LLAMA_SWA_TYPE_SYMMETRIC = 3,
};

enum llm_ffn_op_type {
There was a problem hiding this comment.
The biggest question on this PR is whether moving the enum feels ok. I think this move is cleaner than following the pattern of llama_rope_scaling_type which is declared in the public llama.h header. The enum needs to move so it can be used within the non-public llama_hparams struct, but there's still no concrete reason why an external user of the library would need to access the enum, so I think we should not move it to the public API yet. This is a weak opinion weakly held! I'm definitely open to other views.
There was a problem hiding this comment.
TBC, I did not mean that it should be moved to llama.h, I meant that it should be mapped to the appropriate ffn_op_type in llama-model.cpp, similarly to how rope_scaling_type is handled.
There was a problem hiding this comment.
Ah, yes, that makes sense. That would replace the logic here to do the mapping just for this model. @hansolosan this should be fine with keeping the enum definition here in llama-hparams.h.
There was a problem hiding this comment.
OK, I moved it back to llama-graph.h.
gabe-l-hart
left a comment
There was a problem hiding this comment.
Thanks for the clarification @CISC
gabe-l-hart
left a comment
There was a problem hiding this comment.
We still need to revert the ignore_merges line in the gpt4o block. Also a couple more NIT suggestions. Getting close!
gabe-l-hart
left a comment
There was a problem hiding this comment.
Looks good to me! @CISC over to you
CISC
left a comment
There was a problem hiding this comment.
Rebase and adjust accordingly to refactoring.
Otherwise generally looks good, @ggerganov any feedback on moving llm_ffn_op_type?
{ "gelu", LLM_FFN_GEGLU },
{ "geglu", LLM_FFN_GEGLU },
{ "silu", LLM_FFN_SWIGLU },
{ "swish", LLM_FFN_SWIGLU },
{ "swiglu", LLM_FFN_SWIGLU },
{ "relu", LLM_FFN_RELU },
{ "reglu", LLM_FFN_REGLU },
There was a problem hiding this comment.
How do we genuinely distinguish whether the appropriate mapping is gated or non-gated?
There was a problem hiding this comment.
@CISC For ModernBERT, the gated assumption is structural, not inferred from hidden_activation. The FFN up-projection is loaded with shape {n_embd, 2 * n_ff} (code in modern-bert.cpp:57), which is only valid for a GLU-style FFN. A non-gated variant requires a shape {n_embd, n_ff} and would fail at load time with a tensor shape mismatch; hidden_activation just selects the gate nonlinearity (SiLU -> SwiGLU, GELU -> GeGLU) — the FFN is always gated.
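The structural argument can be sketched in a few lines of pure Python: with a fused up-projection of shape {n_embd, 2 * n_ff}, the projected vector is split into two halves and one half gates the other, so the activation string only selects the nonlinearity. This is a toy model under stated assumptions, not the llama.cpp graph code: the function names are hypothetical, and which half receives the activation is assumed here.

```python
import math


def silu(x: float) -> float:
    return x / (1.0 + math.exp(-x))


def gelu(x: float) -> float:
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def matvec(x, w):
    """x has n_in entries; w has n_in rows of n_out columns."""
    return [sum(xi * row[j] for xi, row in zip(x, w)) for j in range(len(w[0]))]


def gated_ffn(x, w_up, w_down, act):
    n_ff = len(w_down)                 # w_down has shape {n_ff, n_embd}
    h = matvec(x, w_up)                # w_up must have shape {n_embd, 2 * n_ff}
    if len(h) != 2 * n_ff:
        raise ValueError("up-projection is not {n_embd, 2 * n_ff}")
    # Assumption: the first half is activated, the second half is the linear gate.
    activated, gate = h[:n_ff], h[n_ff:]
    hidden = [act(a) * g for a, g in zip(activated, gate)]
    return matvec(hidden, w_down)


# Toy sizes: n_embd = 2, n_ff = 1.
w_up = [[1.0, 0.0], [0.0, 1.0]]       # shape {2, 2 * 1}
w_down = [[1.0, 1.0]]                 # shape {1, 2}
x = [1.0, 3.0]

print(gated_ffn(x, w_up, w_down, silu))  # SwiGLU-style
print(gated_ffn(x, w_up, w_down, gelu))  # GeGLU-style
```

Passing an up-projection with only n_ff columns raises the ValueError, which mirrors the load-time shape mismatch described above.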
There was a problem hiding this comment.
@hansolosan I think we may be waiting on an answer to this
There was a problem hiding this comment.
I'm not sure I understand the question - can you clarify?
There was a problem hiding this comment.
I think this is a question for @hansolosan (I'm not deep enough in these models to understand it either).
There was a problem hiding this comment.
The question is how/when do we know to map, for example, gelu to LLM_FFN_GEGLU and relu to LLM_FFN_RELU? I.e., gated vs. non-gated.
There was a problem hiding this comment.
@CISC - sorry, I replied last week but forgot to send it, so the answer is out of sequence now :).
There was a problem hiding this comment.
@CISC For ModernBERT, the gated assumption is structural, not inferred from hidden_activation. The FFN up-projection is loaded with shape {n_embd, 2 * n_ff} (code in modern-bert.cpp:57), which is only valid for a GLU-style FFN. A non-gated variant requires a shape {n_embd, n_ff} and would fail at load time with a tensor shape mismatch; hidden_activation just selects the gate nonlinearity (SiLU -> SwiGLU, GELU -> GeGLU) — the FFN is always gated.
Right, but then what's up with relu? :)
Also, there are a lot of models with separate up/gate, or no gate at all, just thinking about future reusability...
There was a problem hiding this comment.
Fair question - ReGLU (LLM_FFN_REGLU) exists as an enum value, but no model under src/models/ uses it.
The gated/non-gated distinction in llm_ffn_op_type_from_string() matters for gelu and silu, where gated variants dominate. For relu, a model with hidden_activation: "relu" in its config is probably non-gated (classic transformer / BERT-era). So the current relu -> LLM_FFN_RELU mapping might be the right default if a non-ModernBert arch ever wires this up. The map's asymmetry (gelu->GeGLU, silu->SwiGLU, relu->RELU non-gated) is probably appropriate.
Given this helper's implicit contract ("HF activation string -> FFN op for a gated arch"), we could:
- Leave the map as-is — works for ModernBert, and ReGLU doesn't pair with ReLU anyway.
- Rename/scope the helper to make the assumption explicit (e.g. llm_ffn_op_type_from_string_gated), accepting that non-gated arches get their own helper later.
- Add a bool gated parameter and branch internally — premature until a non-gated consumer exists.
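For comparison, the third option could look roughly like the Python model below. It is only a sketch of the branching logic, not the C++ helper from the PR: FfnOp and ffn_op_from_string are hypothetical names, the numeric values other than NONE = 0 are placeholders, and returning NONE for unknown strings is an assumption.

```python
from enum import Enum, auto


class FfnOp(Enum):
    NONE = 0        # sentinel, mirroring LLM_FFN_NONE = 0 in the PR
    GELU = auto()   # remaining values are placeholders for illustration
    SILU = auto()
    RELU = auto()
    GEGLU = auto()
    SWIGLU = auto()
    REGLU = auto()


# Plain activation names: (non-gated op, gated op).
_BASE = {
    "gelu": (FfnOp.GELU, FfnOp.GEGLU),
    "silu": (FfnOp.SILU, FfnOp.SWIGLU),
    "swish": (FfnOp.SILU, FfnOp.SWIGLU),
    "relu": (FfnOp.RELU, FfnOp.REGLU),
}

# Names that already say "GLU" are gated regardless of the flag.
_EXPLICIT_GATED = {
    "geglu": FfnOp.GEGLU,
    "swiglu": FfnOp.SWIGLU,
    "reglu": FfnOp.REGLU,
}


def ffn_op_from_string(name: str, gated: bool) -> FfnOp:
    key = name.strip().lower()
    if key in _EXPLICIT_GATED:
        return _EXPLICIT_GATED[key]
    if key in _BASE:
        plain, glu = _BASE[key]
        return glu if gated else plain
    return FfnOp.NONE  # assumption: unknown strings map to the sentinel


print(ffn_op_from_string("silu", gated=True))
print(ffn_op_from_string("relu", gated=False))
```

With an explicit flag, the relu asymmetry disappears: the caller's architecture decides gating, and the string only picks the nonlinearity.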
gabe-l-hart
left a comment
There was a problem hiding this comment.
This still looks correct after the conversion refactor and rebase
|
I've just rebased, so it would be ready for merging. Please let me know if there's anything I should do. |
@CISC @ggerganov I think we're good on this (with the possible exception of the question about gated vs non-gated). Can we release CI? |
…gual-r2 embedding models:
* Added a version of the gpt4o tokenizer that has a fixed regex (better handling of marks), and different token merging setting for the 97m model
* Reused gemma4 tokenizer for the 311m model
…ite Embedding Multilingual R2
* added new GGUF key <arch>.hidden_activation (LLM_KV_HIDDEN_ACT) + writer
* added a forward declaration of llm_ffn_op_type to llama-hparams.h
* added llm_ffn_op in hparams
* added LLM_FFN_NONE = 0 sentinel to llm_ffn_op_type (value-initialization), modern-bert: explicitly assigns LLM_FFN_GEGLU before reading GGUF (unchanged).
* centralized hidden_act mapping in llama-model.cpp, added llm_ffn_op_type_from_string() helper, mirroring rope_scaling_type/llama_rope_scaling_type_from_string()
* modern-bert reads the GGUF key (when present) and uses the resulting op in its FFN graph
* Added the hashes for the granite embedding multilingual R2 models
* Set the hidden_activation in the GGUF if the field is present in config.json (such as for the granite embedding models)
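The optional-key flow described in these commits (write the key only when config.json has the field, then start from the default and override on load) can be modelled as below. This is a hedged Python sketch, not the converter or loader code: both function names are hypothetical, the arch string "modern-bert" is assumed for the key prefix, and plain dicts stand in for config.json and the GGUF metadata.

```python
DEFAULT_ACT = "gelu"  # stands in for the GeGLU default assigned before reading GGUF


def gguf_metadata_from_config(config: dict, arch: str) -> dict:
    """Converter side: emit <arch>.hidden_activation only if the field exists."""
    metadata = {}
    act = config.get("hidden_activation")
    if act is not None:
        metadata[f"{arch}.hidden_activation"] = act
    return metadata


def activation_for_model(metadata: dict, arch: str) -> str:
    """Loader side: keep the default unless the key is present."""
    return metadata.get(f"{arch}.hidden_activation", DEFAULT_ACT)


arch = "modern-bert"  # assumed arch name used as the key prefix

old_style = gguf_metadata_from_config({}, arch)
new_style = gguf_metadata_from_config({"hidden_activation": "silu"}, arch)

print(activation_for_model(old_style, arch))
print(activation_for_model(new_style, arch))
```

Keeping the key optional means GGUF files converted before this change, which lack the key, still load with the previous default.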
Overview
This PR adds support for two recently released granite multilingual embedding models based on the ModernBERT architecture. It links the tokenizers properly and lets the 97m model use a different activation function (SiLU/SwiGLU) instead of the default GeGLU.
Additional information
The models are available here: https://huggingface.co/ibm-granite/granite-embedding-97m-multilingual-r2 and https://huggingface.co/ibm-granite/granite-embedding-311m-multilingual-r2. In retrieval scores on the MMTEB leaderboard, the 97m model is 8 points better than the next model under 100M parameters, and the 311m model ranks second in the <500M parameters category.
Requirements
I am not an AI agent :).