Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 28 additions & 15 deletions packages/transformers/src/pipelines/question-answering.js
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ import { softmax } from '../utils/maths.js';
/**
* @typedef {Object} QuestionAnsweringOutput
* @property {number} score The probability associated to the answer.
* @property {number} [start] The character start index of the answer (in the tokenized version of the input).
* @property {number} [end] The character end index of the answer (in the tokenized version of the input).
* @property {number} start The answer start offset (character index **in `context`**; slice with `context.slice(start, end)`).
* @property {number} end The exclusive end offset of the answer in **`context`** (half-open `[start, end)`).
* @property {string} answer The answer to the question.
*
* @typedef {Object} QuestionAnsweringPipelineOptions Parameters specific to question answering pipelines.
Expand Down Expand Up @@ -48,7 +48,9 @@ import { softmax } from '../utils/maths.js';
* const output = await answerer(question, context);
* // {
* // answer: "a nice puppet",
* // score: 0.5768911502526741
* // score: 0.5768911502526741,
* // start: ...,
* // end: ...,
* // }
* ```
*/
Expand All @@ -61,6 +63,7 @@ export class QuestionAnsweringPipeline
text_pair: context,
padding: true,
truncation: true,
return_offsets_mapping: true,
});
const isBatched = Array.isArray(question);

Expand All @@ -71,7 +74,10 @@ export class QuestionAnsweringPipeline
// TODO: add support for `return_special_tokens_mask`
const { all_special_ids, sep_token_id } = this.tokenizer;

const offset_mapping_batches = inputs.offset_mapping;
/** @type {QuestionAnsweringOutput[][]|QuestionAnsweringOutput[]} */
const batchedResults = [];

for (let j = 0; j < start_logits.dims[0]; ++j) {
const ids = input_ids[j];
const sepIndex = ids.findIndex(
Expand All @@ -81,26 +87,26 @@ export class QuestionAnsweringPipeline
x == sep_token_id,
);

const start = start_logits[j].tolist();
const end = end_logits[j].tolist();
const start_logits_row = start_logits[j].tolist();
const end_logits_row = end_logits[j].tolist();

// Now, we mask out values that can't be in the answer
// NOTE: We keep the cls_token unmasked (some models use it to indicate unanswerable questions)
for (let i = 1; i < start.length; ++i) {
for (let i = 1; i < start_logits_row.length; ++i) {
if (
attention_mask[j] == 0 || // is part of padding
i <= sepIndex || // is before the sep_token
all_special_ids.findIndex((x) => x == ids[i]) !== -1 // Is a special token
) {
// Make sure non-context indexes in the tensor cannot contribute to the softmax
start[i] = -Infinity;
end[i] = -Infinity;
start_logits_row[i] = -Infinity;
end_logits_row[i] = -Infinity;
}
}

// Normalize logits and spans to retrieve the answer
const start_scores = softmax(start).map((x, i) => [x, i]);
const end_scores = softmax(end).map((x, i) => [x, i]);
const start_scores = softmax(start_logits_row).map((x, i) => [x, i]);
const end_scores = softmax(end_logits_row).map((x, i) => [x, i]);

// Mask CLS
start_scores[0][0] = 0;
Expand All @@ -112,21 +118,28 @@ export class QuestionAnsweringPipeline
.map((x) => [x[0][1], x[1][1], x[0][0] * x[1][0]])
.sort((a, b) => b[2] - a[2]);

const rowOffsets = isBatched ? offset_mapping_batches[j] : offset_mapping_batches;

const sampleResults = [];
for (let k = 0; k < Math.min(options.length, top_k); ++k) {
const [start, end, score] = options[k];
const [startTok, endTok, spanScore] = options[k];

const answer_tokens = ids.slice(start, end + 1);
const answer_tokens = ids.slice(startTok, endTok + 1);

const answer = this.tokenizer.decode(answer_tokens, {
skip_special_tokens: true,
});

// TODO add start and end?
// NOTE: HF returns character index
/** @type {number} */
const startChar = rowOffsets[startTok][0];
/** @type {number} */
const endChar = rowOffsets[endTok][1];

sampleResults.push({
answer,
score,
score: spanScore,
start: startChar,
end: endChar,
});
}
if (top_k === 1) {
Expand Down
193 changes: 184 additions & 9 deletions packages/transformers/src/tokenization_utils.js
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,145 @@ const SPECIAL_TOKEN_ATTRIBUTES = [
* @param {string} side Which side to pad the array.
* @private
*/
/**
* Compute character-level `[start, end)` offsets for each token by scanning
* forward through the original text. Tokens that cannot be found (e.g.
* special tokens like [CLS]/[SEP], or subwords after normalization) get
* [0, 0], matching the Python tokenizers convention.
*
* The scan is tried case-sensitively first, then case-insensitively, to
* handle uncased tokenizers that lowercase the input before tokenizing.
*
* BPE/SentencePiece tokenizers prepend continuation-byte prefix characters
* to tokens: `Ġ` (U+0120) by GPT-2's ByteLevel pre-tokenizer and `▁` (U+2581)
* by SentencePiece models (LLaMA, Mistral, T5, …). These characters are not
* present in the original text, so we strip them before searching.
*
* @param {string[]} tokens The token strings produced by the tokenizer.
* @param {string} text The original input text.
* @returns {[number, number][]}
*/
function computeOffsets(tokens, text) {
/** @type {[number, number][]} */
const offsets = [];
const textLower = text.toLowerCase();
let pos = 0;
for (const token of tokens) {
// Strip BPE/SentencePiece continuation-byte prefix characters.
// Ġ (U+0120) is used by GPT-2's ByteLevel pre-tokenizer.
// ▁ (U+2581) is used by SentencePiece (LLaMA, Mistral, T5, …).
const byteLevelSpacePrefix = token.startsWith('\u0120');
const clean = token.replace(/^[\u0120\u2581]+/, '');
if (clean === '') {
offsets.push([0, 0]);
continue;
}
// Try exact match first, then case-insensitive for uncased tokenizers.
let idx = text.indexOf(clean, pos);
if (idx === -1) idx = textLower.indexOf(clean.toLowerCase(), pos);
if (idx === -1) {
offsets.push([0, 0]);
} else {
let start = idx;
// ByteLevel maps leading space to Ġ; HF offset spans include that space in the original text.
if (byteLevelSpacePrefix && idx > 0 && text[idx - 1] === ' ') {
start = idx - 1;
}
const end = idx + clean.length;
offsets.push([start, end]);
pos = end;
}
}
return offsets;
}

/**
* BERT-like layout: indices after the **first** `sep_token` belong to segment 1 (`text_pair`).
* @param {string[]} tokens
* @param {string|null|undefined} sepToken
* @returns {number[]} 0 = first segment, 1 = second segment
*/
function inferPairSegmentIds(tokens, sepToken) {
const segmentIds = new Array(tokens.length).fill(0);
if (sepToken == null) return segmentIds;

/** @type {number[]} */
const sepIndexes = [];
for (let i = 0; i < tokens.length; ++i) {
if (tokens[i] === sepToken) sepIndexes.push(i);
}
if (sepIndexes.length >= 1) {
for (let i = sepIndexes[0] + 1; i < tokens.length; ++i) {
segmentIds[i] = 1;
}
}
return segmentIds;
}

/**
* Character `[start, end)` offsets per token for a **two-part** encode.
* Segment 0 uses `text` (first sequence); segment 1 uses `textPair` (second sequence).
*
* @param {string[]} tokens
* @param {string} text First sequence text.
* @param {string} textPair Second sequence text.
* @param {number[]} segmentIds Same length as `tokens`.
* @returns {[number, number][]}
*/
function computeOffsetsForSequencePair(tokens, text, textPair, segmentIds) {
/** @type {[number, number][]} */
const offsets = [];
const textLower = text.toLowerCase();
const pairLower = textPair.toLowerCase();
let pos0 = 0;
let pos1 = 0;
for (let i = 0; i < tokens.length; ++i) {
const token = tokens[i];
const segment = segmentIds[i];

const byteLevelSpacePrefix = token.startsWith('\u0120');
const clean = token.replace(/^[\u0120\u2581]+/, '');
if (clean === '') {
offsets.push([0, 0]);
continue;
}

const activeText = segment === 1 ? textPair : text;
const activeLower = segment === 1 ? pairLower : textLower;
let pos = segment === 1 ? pos1 : pos0;

let idx = activeText.indexOf(clean, pos);
if (idx === -1) idx = activeLower.indexOf(clean.toLowerCase(), pos);
if (idx === -1) {
offsets.push([0, 0]);
continue;
}

let start = idx;
if (byteLevelSpacePrefix && idx > 0 && activeText[idx - 1] === ' ') {
start = idx - 1;
}
const endExclusive = idx + clean.length;
offsets.push([start, endExclusive]);
if (segment === 1) {
pos1 = endExclusive;
} else {
pos0 = endExclusive;
}
}
return offsets;
}

function padHelper(item, length, value_fn, side) {
for (const key of Object.keys(item)) {
const diff = length - item[key].length;
const value = value_fn(key);
const padData =
diff === 0
? []
: key === 'offset_mapping'
? Array.from({ length: diff }, () => /** @type {[number, number]} */ ([0, 0]))
: new Array(diff).fill(value_fn(key));

const padData = new Array(diff).fill(value);
item[key] = side === 'right' ? mergeArrays(item[key], padData) : mergeArrays(padData, item[key]);
}
}
Expand Down Expand Up @@ -197,6 +330,7 @@ function getSpecialTokens(tokenizer) {
* @property {number|null} [max_length=null] Maximum length of the returned list and optionally padding length.
* @property {TReturnTensor} [return_tensor=true] Whether to return the results as Tensors or arrays.
* @property {boolean|null} [return_token_type_ids=null] Whether to return the token type ids.
* @property {boolean} [return_offsets_mapping=false] Whether to return `[start, end)` character offsets per token (see `offset_mapping`).
*/

/**
Expand Down Expand Up @@ -359,7 +493,8 @@ export class PreTrainedTokenizer
text,
options = {},
) {
const { text_pair = null, add_special_tokens = true, padding = false, return_token_type_ids = null } = options;
const { text_pair = null, add_special_tokens = true, padding = false, return_token_type_ids = null, return_offsets_mapping = false } =
options;
let { truncation = null, max_length = null } = options;
const return_tensor = /** @type {TReturnTensor} */ (options.return_tensor ?? true); // Different to HF

Expand All @@ -380,10 +515,17 @@ export class PreTrainedTokenizer
}

encodedTokens = text.map((t, i) =>
this._encode_plus(t, { text_pair: text_pair[i], add_special_tokens, return_token_type_ids }),
this._encode_plus(t, {
text_pair: text_pair[i],
add_special_tokens,
return_token_type_ids,
return_offsets_mapping,
}),
);
} else {
encodedTokens = text.map((x) => this._encode_plus(x, { add_special_tokens, return_token_type_ids }));
encodedTokens = text.map((x) =>
this._encode_plus(x, { add_special_tokens, return_token_type_ids, return_offsets_mapping }),
);
}
} else {
if (text === null || text === undefined) {
Expand All @@ -397,7 +539,9 @@ export class PreTrainedTokenizer
}

// For single input, we just wrap in an array, and then unwrap later.
encodedTokens = [this._encode_plus(text, { text_pair, add_special_tokens, return_token_type_ids })];
encodedTokens = [
this._encode_plus(text, { text_pair, add_special_tokens, return_token_type_ids, return_offsets_mapping }),
];
}
// At this point, `encodedTokens` is batched, of shape [batch_size, tokens].
// However, array may be jagged. So, we may need pad to max_length.
Expand Down Expand Up @@ -482,6 +626,12 @@ export class PreTrainedTokenizer
const dims = [encodedTokens.length, encodedTokens[0].input_ids.length];

for (const key of Object.keys(encodedTokens[0])) {
if (key === 'offset_mapping') {
const batched = encodedTokens.map((x) => x[key]);
/** @type {unknown} */
result[key] = isBatched ? batched : batched[0];
continue;
}
result[key] = new Tensor(
'int64',
BigInt64Array.from(encodedTokens.flatMap((x) => x[key]).map(BigInt)),
Expand Down Expand Up @@ -524,19 +674,44 @@ export class PreTrainedTokenizer
* @param {string|null} [options.text_pair=null] The optional second text to encode.
* @param {boolean} [options.add_special_tokens=true] Whether or not to add the special tokens associated with the corresponding model.
* @param {boolean|null} [options.return_token_type_ids=null] Whether to return token_type_ids.
* @returns {{input_ids: number[], attention_mask: number[], token_type_ids?: number[]}} An object containing the encoded text.
* @param {boolean} [options.return_offsets_mapping=false] When `text_pair` is set, offsets for each token are computed in the matching sequence: first segment in `text`, second in `text_pair` (HF-style). Single-sequence encode uses `text` only.
* @returns {{input_ids: number[], attention_mask: number[], token_type_ids?: number[], offset_mapping?: [number, number][]}} An object containing the encoded text.
* @private
*/
_encode_plus(text, { text_pair = null, add_special_tokens = true, return_token_type_ids = null } = {}) {
const { ids, attention_mask, token_type_ids } = this._tokenizer.encode(text, {
_encode_plus(
text,
{
text_pair = null,
add_special_tokens = true,
return_token_type_ids = null,
return_offsets_mapping = false,
} = {},
) {
const { ids, attention_mask, token_type_ids, tokens } = this._tokenizer.encode(text, {
text_pair,
add_special_tokens,
return_token_type_ids: return_token_type_ids ?? this.return_token_type_ids,
});
/** @type {[number, number][]|undefined} */
let offset_mapping;
if (return_offsets_mapping) {
if (text_pair != null) {
let segmentIds = token_type_ids
? token_type_ids.map((x) => (typeof x === 'bigint' ? Number(x) : x))
: null;
if (!segmentIds || segmentIds.every((id) => id === 0)) {
segmentIds = inferPairSegmentIds(tokens, this.sep_token ?? undefined);
}
offset_mapping = computeOffsetsForSequencePair(tokens, text, text_pair, segmentIds);
} else {
offset_mapping = computeOffsets(tokens, text);
}
}
return {
input_ids: ids,
attention_mask,
...(token_type_ids ? { token_type_ids } : {}),
...(offset_mapping ? { offset_mapping } : {}),
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ export default () => {
"default (top_k=1)",
async () => {
const output = await pipe("a", "b c");
const target = { score: 0.11395696550607681, /* start: 0, end: 1, */ answer: "b" };
const target = { score: 0.11395696550607681, start: 0, end: 1, answer: "b" };
expect(output).toBeCloseToNested(target, 5);
},
MAX_TEST_EXECUTION_TIME,
Expand All @@ -32,9 +32,9 @@ export default () => {
async () => {
const output = await pipe("a", "b c", { top_k: 3 });
const target = [
{ score: 0.11395696550607681, /* start: 0, end: 1, */ answer: "b" },
{ score: 0.11300431191921234, /* start: 2, end: 3, */ answer: "c" },
{ score: 0.10732574015855789, /* start: 0, end: 3, */ answer: "b c" },
{ score: 0.11395696550607681, start: 0, end: 1, answer: "b" },
{ score: 0.11300431191921234, start: 2, end: 3, answer: "c" },
{ score: 0.10732574015855789, start: 0, end: 3, answer: "b c" },
];
expect(output).toBeCloseToNested(target, 5);
},
Expand Down
Loading