-
Notifications
You must be signed in to change notification settings - Fork 244
Native Qwen3-Reranker CausalLM support in RerankCalculatorOV #4063
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -60,6 +60,8 @@ class RerankCalculatorOV : public CalculatorBase { | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static const std::string RERANK_MODEL_INPUT_IDS_NAME; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static const std::string RERANK_MODEL_ATTENTION_MASK_NAME; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static const std::string RERANK_MODEL_TOKEN_TYPE_IDS_NAME; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static const std::string RERANK_MODEL_POSITION_IDS_NAME; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static const std::string RERANK_MODEL_BEAM_IDX_NAME; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static constexpr size_t NUMBER_OF_SPECIAL_TOKENS = 4; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mediapipe::Timestamp timestamp{0}; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -151,6 +153,39 @@ class RerankCalculatorOV : public CalculatorBase { | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Validate batch size before tokenizing | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (handler.getDocumentsList().size() > this->max_allowed_chunks) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| throw std::runtime_error("Number of documents exceeds max_allowed_chunks"); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (rerank_session->isQwen3) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Qwen3 reranker: format each query-document pair using the chat template | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Template from openvino-2026.0-genai/tests/python_tests/utils/qwen3_reranker_utils.py | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto batchSize = handler.getDocumentsList().size(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::vector<std::string> data(batchSize); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::string prefix = "<|im_start|>system\nJudge whether the Document meets the requirements " | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "based on the Query and the Instruct provided. Note that the answer can only be " | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "\"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "<Instruct>: Given a web search query, retrieve relevant passages that answer the query\n" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "<Query>: " + handler.getQuery() + "\n"; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::string suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (size_t i = 0; i < batchSize; i++) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| data[i] = prefix + "<Document>: " + handler.getDocumentsList()[i] + suffix; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| chunk_mapping.resize(batchSize); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::iota(chunk_mapping.begin(), chunk_mapping.end(), 0); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto tokens = rerank_session->getTokenizer().encode(data); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (tokens.input_ids.get_shape().size() != 2) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| throw std::runtime_error("Tokens shape invalid."); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (this->max_position_embeddings < tokens.input_ids.get_shape()[1]) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::ostringstream msg; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| msg << "Qwen3 rerank request length of " << tokens.input_ids.get_shape()[1] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| << " tokens exceeds the model context of " << max_position_embeddings; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| throw std::runtime_error(msg.str()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| SPDLOG_LOGGER_DEBUG(rerank_calculator_logger, "Qwen3 rerank: {} documents, {} tokens per sequence", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| batchSize, tokens.input_ids.get_shape()[1]); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return std::make_pair(tokens.input_ids, tokens.attention_mask); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (!rerank_session->addBosToken) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto batchSize = handler.getDocumentsList().size(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::vector<std::string> data(batchSize); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -257,6 +292,27 @@ class RerankCalculatorOV : public CalculatorBase { | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (typeIds.has_value()) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| inferRequest.set_tensor(RERANK_MODEL_TOKEN_TYPE_IDS_NAME, typeIds.value()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // For CausalLM models (e.g. Qwen3 rerankers): set position_ids and beam_idx | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (rerank_session->hasPositionIds) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| size_t batch = input_ids.get_shape()[0]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| size_t seq_len = input_ids.get_shape()[1]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto position_ids = ov::Tensor(ov::element::i64, input_ids.get_shape()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int64_t* pos_data = position_ids.data<int64_t>(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int64_t* attn_data = attention_mask.data<int64_t>(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (size_t b = 0; b < batch; b++) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int64_t pos = 0; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (size_t s = 0; s < seq_len; s++) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| pos_data[b * seq_len + s] = attn_data[b * seq_len + s] ? pos++ : 0; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| inferRequest.set_tensor(RERANK_MODEL_POSITION_IDS_NAME, position_ids); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (rerank_session->hasBeamIdx) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| size_t batch = input_ids.get_shape()[0]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto beam_idx = ov::Tensor(ov::element::i32, {batch}); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::fill_n(beam_idx.data<int32_t>(), batch, 0); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+299
to
+313
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto position_ids = ov::Tensor(ov::element::i64, input_ids.get_shape()); | |
| int64_t* pos_data = position_ids.data<int64_t>(); | |
| int64_t* attn_data = attention_mask.data<int64_t>(); | |
| for (size_t b = 0; b < batch; b++) { | |
| int64_t pos = 0; | |
| for (size_t s = 0; s < seq_len; s++) { | |
| pos_data[b * seq_len + s] = attn_data[b * seq_len + s] ? pos++ : 0; | |
| } | |
| } | |
| inferRequest.set_tensor(RERANK_MODEL_POSITION_IDS_NAME, position_ids); | |
| } | |
| if (rerank_session->hasBeamIdx) { | |
| size_t batch = input_ids.get_shape()[0]; | |
| auto beam_idx = ov::Tensor(ov::element::i32, {batch}); | |
| std::fill_n(beam_idx.data<int32_t>(), batch, 0); | |
| // Derive element types from the compiled model inputs to avoid dtype mismatches. | |
| const ov::element::Type pos_element_type = | |
| inferRequest.get_compiled_model().input(RERANK_MODEL_POSITION_IDS_NAME).get_element_type(); | |
| const ov::element::Type mask_element_type = | |
| inferRequest.get_compiled_model().input(RERANK_MODEL_ATTENTION_MASK_NAME).get_element_type(); | |
| ov::Tensor position_ids(pos_element_type, input_ids.get_shape()); | |
| if (pos_element_type == ov::element::i64) { | |
| int64_t* pos_data = position_ids.data<int64_t>(); | |
| if (mask_element_type == ov::element::i64) { | |
| int64_t* attn_data = attention_mask.data<int64_t>(); | |
| for (size_t b = 0; b < batch; b++) { | |
| int64_t pos = 0; | |
| for (size_t s = 0; s < seq_len; s++) { | |
| pos_data[b * seq_len + s] = attn_data[b * seq_len + s] ? pos++ : 0; | |
| } | |
| } | |
| } else if (mask_element_type == ov::element::i32) { | |
| int32_t* attn_data = attention_mask.data<int32_t>(); | |
| for (size_t b = 0; b < batch; b++) { | |
| int64_t pos = 0; | |
| for (size_t s = 0; s < seq_len; s++) { | |
| pos_data[b * seq_len + s] = attn_data[b * seq_len + s] ? pos++ : 0; | |
| } | |
| } | |
| } else { | |
| throw std::runtime_error("Unsupported attention_mask element type for position_ids generation"); | |
| } | |
| } else if (pos_element_type == ov::element::i32) { | |
| int32_t* pos_data = position_ids.data<int32_t>(); | |
| if (mask_element_type == ov::element::i64) { | |
| int64_t* attn_data = attention_mask.data<int64_t>(); | |
| for (size_t b = 0; b < batch; b++) { | |
| int32_t pos = 0; | |
| for (size_t s = 0; s < seq_len; s++) { | |
| pos_data[b * seq_len + s] = | |
| attn_data[b * seq_len + s] ? static_cast<int32_t>(pos++) : 0; | |
| } | |
| } | |
| } else if (mask_element_type == ov::element::i32) { | |
| int32_t* attn_data = attention_mask.data<int32_t>(); | |
| for (size_t b = 0; b < batch; b++) { | |
| int32_t pos = 0; | |
| for (size_t s = 0; s < seq_len; s++) { | |
| pos_data[b * seq_len + s] = attn_data[b * seq_len + s] ? pos++ : 0; | |
| } | |
| } | |
| } else { | |
| throw std::runtime_error("Unsupported attention_mask element type for position_ids generation"); | |
| } | |
| } else { | |
| throw std::runtime_error("Unsupported position_ids element type in compiled model"); | |
| } | |
| inferRequest.set_tensor(RERANK_MODEL_POSITION_IDS_NAME, position_ids); | |
| } | |
| if (rerank_session->hasBeamIdx) { | |
| size_t batch = input_ids.get_shape()[0]; | |
| const ov::element::Type beam_element_type = | |
| inferRequest.get_compiled_model().input(RERANK_MODEL_BEAM_IDX_NAME).get_element_type(); | |
| ov::Tensor beam_idx(beam_element_type, {batch}); | |
| if (beam_element_type == ov::element::i32) { | |
| std::fill_n(beam_idx.data<int32_t>(), batch, 0); | |
| } else if (beam_element_type == ov::element::i64) { | |
| std::fill_n(beam_idx.data<int64_t>(), batch, 0); | |
| } else { | |
| throw std::runtime_error("Unsupported beam_idx element type in compiled model"); | |
| } |
Copilot
AI
Mar 23, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This PR introduces a new Qwen3-specific request formatting path and new model inputs (position_ids, beam_idx) handling, but there are no unit/functional tests covering these behaviors. Since the repo already has rerank-related tests (e.g. src/test/reranknode_test.cpp, src/test/rerank_chunking_test.cpp), please add coverage that at least validates Qwen3 detection from config.json and that the calculator sets the expected extra tensors / produces a [batch, 1] logits output.
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -23,10 +23,18 @@ | |||||||||||||||||||||||||||||||||||||||||||||||
| #include <string> | ||||||||||||||||||||||||||||||||||||||||||||||||
| #include <unordered_map> | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| #include <openvino/opsets/opset1.hpp> | ||||||||||||||||||||||||||||||||||||||||||||||||
| #include <openvino/opsets/opset8.hpp> | ||||||||||||||||||||||||||||||||||||||||||||||||
| #include <openvino/core/preprocess/pre_post_process.hpp> | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| namespace ovms { | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| struct RerankServable : SidepacketServable { | ||||||||||||||||||||||||||||||||||||||||||||||||
| bool addBosToken = true; | ||||||||||||||||||||||||||||||||||||||||||||||||
| bool isQwen3 = false; | ||||||||||||||||||||||||||||||||||||||||||||||||
| bool hasPositionIds = false; | ||||||||||||||||||||||||||||||||||||||||||||||||
| bool hasBeamIdx = false; | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| RerankServable(const std::string& modelDir, const std::string& targetDevice, const std::string& pluginConfig, const std::string& graphPath) : | ||||||||||||||||||||||||||||||||||||||||||||||||
| SidepacketServable(modelDir, targetDevice, pluginConfig, graphPath) { | ||||||||||||||||||||||||||||||||||||||||||||||||
| std::filesystem::path tokenizerConfigPath = (parsedModelsPath / "tokenizer_config.json"); | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -49,6 +57,120 @@ struct RerankServable : SidepacketServable { | |||||||||||||||||||||||||||||||||||||||||||||||
| addBosToken = false; | ||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| protected: | ||||||||||||||||||||||||||||||||||||||||||||||||
| std::shared_ptr<ov::Model> applyPrePostProcessing(ov::Core& core, std::shared_ptr<ov::Model> model, ov::AnyMap& properties) override { | ||||||||||||||||||||||||||||||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would split that logic. Since we only need graph postprocessing for Qwen3 I would extract detection to another I think it will be more straighforward on the higher level as now we need to get into that function to see early return if it's not qwen3 model |
||||||||||||||||||||||||||||||||||||||||||||||||
| // Detect Qwen3 model type from config.json | ||||||||||||||||||||||||||||||||||||||||||||||||
| std::filesystem::path configPath = parsedModelsPath / "config.json"; | ||||||||||||||||||||||||||||||||||||||||||||||||
| if (std::filesystem::exists(configPath)) { | ||||||||||||||||||||||||||||||||||||||||||||||||
| std::ifstream ifs(configPath.string()); | ||||||||||||||||||||||||||||||||||||||||||||||||
| if (ifs.is_open()) { | ||||||||||||||||||||||||||||||||||||||||||||||||
| rapidjson::Document modelConfig; | ||||||||||||||||||||||||||||||||||||||||||||||||
| rapidjson::IStreamWrapper isw(ifs); | ||||||||||||||||||||||||||||||||||||||||||||||||
| rapidjson::ParseResult parseResult = modelConfig.ParseStream(isw); | ||||||||||||||||||||||||||||||||||||||||||||||||
| if (!parseResult.Code()) { | ||||||||||||||||||||||||||||||||||||||||||||||||
| if (modelConfig.HasMember("model_type") && modelConfig["model_type"].IsString()) { | ||||||||||||||||||||||||||||||||||||||||||||||||
| std::string modelType = modelConfig["model_type"].GetString(); | ||||||||||||||||||||||||||||||||||||||||||||||||
| if (modelType == "qwen3") { | ||||||||||||||||||||||||||||||||||||||||||||||||
| SPDLOG_INFO("Detected Qwen3 reranker model, applying specialized postprocessing"); | ||||||||||||||||||||||||||||||||||||||||||||||||
| isQwen3 = true; | ||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| if (!isQwen3) { | ||||||||||||||||||||||||||||||||||||||||||||||||
| return model; | ||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| // Check model inputs for position_ids and beam_idx | ||||||||||||||||||||||||||||||||||||||||||||||||
| for (const auto& input : model->inputs()) { | ||||||||||||||||||||||||||||||||||||||||||||||||
| if (input.get_any_name() == "position_ids") { | ||||||||||||||||||||||||||||||||||||||||||||||||
| hasPositionIds = true; | ||||||||||||||||||||||||||||||||||||||||||||||||
| SPDLOG_DEBUG("Qwen3 reranker model has position_ids input"); | ||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||
| if (input.get_any_name() == "beam_idx") { | ||||||||||||||||||||||||||||||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could be else if ? |
||||||||||||||||||||||||||||||||||||||||||||||||
| hasBeamIdx = true; | ||||||||||||||||||||||||||||||||||||||||||||||||
| SPDLOG_DEBUG("Qwen3 reranker model has beam_idx input"); | ||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| // Check output shape — only apply postprocessing for CausalLM models (3D output) | ||||||||||||||||||||||||||||||||||||||||||||||||
| ov::PartialShape outputShape = model->get_output_partial_shape(0); | ||||||||||||||||||||||||||||||||||||||||||||||||
| if (outputShape.rank().get_length() == 2) { | ||||||||||||||||||||||||||||||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think I would go for the wider check. If we expect 3D, let's check for 3D, so if |
||||||||||||||||||||||||||||||||||||||||||||||||
| // Already a 2D output (text-classification export) — postprocessing won't help | ||||||||||||||||||||||||||||||||||||||||||||||||
| // because the classification head has random weights | ||||||||||||||||||||||||||||||||||||||||||||||||
| SPDLOG_WARN("Qwen3 reranker has 2D output shape (text-classification export). " | ||||||||||||||||||||||||||||||||||||||||||||||||
| "Re-export with --task text-generation for correct scoring."); | ||||||||||||||||||||||||||||||||||||||||||||||||
| return model; | ||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+101
to
+106
|
||||||||||||||||||||||||||||||||||||||||||||||||
| if (outputShape.rank().get_length() == 2) { | |
| // Already a 2D output (text-classification export) — postprocessing won't help | |
| // because the classification head has random weights | |
| SPDLOG_WARN("Qwen3 reranker has 2D output shape (text-classification export). " | |
| "Re-export with --task text-generation for correct scoring."); | |
| return model; | |
| ov::Rank outputRank = outputShape.rank(); | |
| if (outputRank.is_dynamic()) { | |
| SPDLOG_WARN("Qwen3 reranker output rank is dynamic; skipping specialized postprocessing"); | |
| return model; | |
| } | |
| std::size_t outputRankLength = outputRank.get_length(); | |
| if (outputRankLength == 2) { | |
| // Already a 2D output (text-classification export) — postprocessing won't help | |
| // because the classification head has random weights | |
| SPDLOG_WARN("Qwen3 reranker has 2D output shape (text-classification export). " | |
| "Re-export with --task text-generation for correct scoring."); | |
| return model; | |
| } else if (outputRankLength != 3) { | |
| SPDLOG_WARN("Qwen3 reranker has unexpected output rank {}. Expected 2 (classification) or 3 (CausalLM). " | |
| "Skipping specialized postprocessing.", | |
| outputRankLength); | |
| return model; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure if it's safe to rely on tokenizer->encode as in certain settings it will treat it as an input prompt and add special tokens like <bos><yes_token> and we end up picking wrong tokens.
I can see in GenAI there is a direct vocab read for that:
https://github.com/openvinotoolkit/openvino.genai/blob/716a778fc0ccfa86f1395b186a0cb2ca8ed7ece5/src/cpp/src/rag/text_rerank_pipeline.cpp#L179
I think it would be safer to do it that way.
Copilot
AI
Mar 23, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
std::numeric_limits<int64_t>::max() is used here but <limits> isn’t included in this header. Please include <limits> explicitly to avoid relying on transitive includes (IWYU) and prevent fragile builds.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could be named just axis or sliceAxis right? There is no axis2, so we don't need to differ and value of the constant tells which axis we pick.
Copilot
AI
Mar 23, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The custom postprocess returns a new node without ensuring the output tensor name and element type match what the rest of the rerank pipeline expects. RerankCalculatorOV later fetches inferRequest.get_tensor("logits") and reads it as float*; if PrePostProcessor drops/changes the output name or leaves the output as f16, inference or scoring will break. Please either (a) set the postprocessed result tensor names back to "logits" and convert the output to f32 in the postprocess graph, or (b) update the calculator to query the compiled model’s output name and handle non-f32 element types.
| return diff; // [batch, 1] | |
| // Ensure the final output tensor matches pipeline expectations: | |
| // - element type: f32 (RerankCalculatorOV reads as float*) | |
| // - tensor name: "logits" (queried via inferRequest.get_tensor("logits")) | |
| auto diffF32 = std::make_shared<ov::op::v0::Convert>(diff, ov::element::f32); | |
| diffF32->set_friendly_name("logits"); | |
| return diffF32; // [batch, 1] logits in f32 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this might be valid. Setting the name will ensure the output read will work even if something changes due to graph surgery.
Can you confirm new output node is indeed f32? I wonder if the convert op here is indeed necessary.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the Qwen3 path, the tokenizer outputs are not validated the way the non-Qwen3 path effectively is (via
chunkDocuments()), but later code assumesattention_maskis i64 and usesattention_mask.data<int64_t>()to computeposition_ids. Please add explicit validation oftokens.input_ids/tokens.attention_maskelement types and shapes for the Qwen3 branch to avoid UB if the tokenizer output precision/layout differs.