Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions packages/transformers/src/backends/onnx.js
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,25 @@ if (ORT_SYMBOL in globalThis) {
// @ts-ignore
const InferenceSession = ONNX.InferenceSession;

/**
* Returns the list of devices available in the current environment, sorted by priority/performance.
*
* **Example:** Check available devices before loading a model.
* ```javascript
* import { get_available_devices } from '@huggingface/transformers';
*
* const devices = get_available_devices();
* // Node.js (Windows): ['dml', 'webgpu', 'cpu']
* // Node.js (Linux x64): ['cuda', 'webgpu', 'cpu']
* // Browser (WebGPU): ['webgpu', 'wasm']
* // Browser (no WebGPU): ['wasm']
* ```
* @returns {import("../utils/devices.js").DeviceType[]} The list of available devices.
*/
export function get_available_devices() {
return [...supportedDevices];
}

/**
* Map a device to the execution providers to use for the given device.
* @param {import("../utils/devices.js").DeviceType|"auto"|null} [device=null] (Optional) The device to run the inference on.
Expand Down
2 changes: 1 addition & 1 deletion packages/transformers/src/models/modeling_utils.js
Original file line number Diff line number Diff line change
Expand Up @@ -1325,7 +1325,7 @@ export async function decoder_forward(self, model_inputs, is_encoder_decoder = f
// logits will be calculated. During generation, the default is 1 because only the logits of the last
// prompt token are needed for generation. For long sequences, the logits for the entire sequence may
// use a lot of memory so, setting `num_logits_to_keep=1` will reduce memory footprint significantly.
new_model_inputs.num_logits_to_keep = new Tensor('int64', [0n], []);
new_model_inputs.num_logits_to_keep = new Tensor('int64', [1n], []);
}

// Unpack the `past_key_values` object into model inputs
Expand Down
24 changes: 21 additions & 3 deletions packages/transformers/src/pipelines/token-classification.js
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,10 @@ export class TokenClassificationPipeline
}

const isBatched = Array.isArray(texts);
const textList = isBatched ? texts : [texts];

// Run tokenization
const model_inputs = this.tokenizer(isBatched ? texts : [texts], {
const model_inputs = this.tokenizer(textList, {
padding: true,
truncation: true,
});
Expand All @@ -136,26 +137,40 @@ export class TokenClassificationPipeline
for (let i = 0; i < logits.dims[0]; ++i) {
const ids = model_inputs.input_ids[i].tolist();
const batch = logits[i];
const text = textList[i];

const tokens = [];
let charOffset = 0;
for (let j = 0; j < batch.dims[0]; ++j) {
const tokenData = batch[j];
const topScoreIndex = max(tokenData.data)[1];

const entity = id2label ? id2label[topScoreIndex] : `LABEL_${topScoreIndex}`;
if (ignore_labels.includes(entity)) continue;

// TODO add option to keep special tokens?
const word = this.tokenizer.decode([ids[j]], { skip_special_tokens: true });
if (word === '') continue; // Was a special token.

// Locate this token's character span in the original text by
// scanning forward from where the previous token ended.
const idx = text.indexOf(word, charOffset);
let start, end;
if (idx !== -1) {
start = idx;
end = idx + word.length;
charOffset = end;
}

if (ignore_labels.includes(entity)) continue;

const scores = softmax(tokenData.data);
tokens.push({
entity,
score: scores[topScoreIndex],
index: j,
word,
// TODO: Add support for start and end
start,
end,
});
}

Expand Down Expand Up @@ -218,10 +233,13 @@ function groupEntities(tokens, ids, tokenizer) {
scoreSum += tokens[i].score;
groupIds.push(ids[tokens[i].index]);
}
const charStart = tokens[start].start;
const charEnd = tokens[end - 1].end;
return {
entity_group: tag,
score: scoreSum / (end - start),
word: tokenizer.decode(groupIds, { skip_special_tokens: true }),
...(charStart !== undefined ? { start: charStart, end: charEnd } : {}),
};
});
}
67 changes: 58 additions & 9 deletions packages/transformers/src/tokenization_utils.js
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,42 @@ const SPECIAL_TOKEN_ATTRIBUTES = [
* @param {string} side Which side to pad the array.
* @private
*/
/**
* Compute character-level [start, end) offsets for each token by scanning
* forward through the original text. Tokens that cannot be found (e.g.
* special tokens like [CLS]/[SEP], or subwords after normalization) get
* [0, 0], matching the Python tokenizers convention.
*
* The scan is tried case-sensitively first, then case-insensitively, to
* handle uncased tokenizers that lowercase the input before tokenizing.
*
* @param {string[]} tokens The token strings produced by the tokenizer.
* @param {string} text The original input text.
* @returns {[number, number][]}
*/
function computeOffsets(tokens, text) {
/** @type {[number, number][]} */
const offsets = [];
const textLower = text.toLowerCase();
let pos = 0;
for (const token of tokens) {
if (token === '') {
offsets.push([0, 0]);
continue;
}
// Try exact match first, then case-insensitive for uncased tokenizers.
let idx = text.indexOf(token, pos);
if (idx === -1) idx = textLower.indexOf(token.toLowerCase(), pos);
if (idx === -1) {
offsets.push([0, 0]);
} else {
offsets.push([idx, idx + token.length]);
pos = idx + token.length;
}
}
return offsets;
}

function padHelper(item, length, value_fn, side) {
for (const key of Object.keys(item)) {
const diff = length - item[key].length;
Expand Down Expand Up @@ -197,6 +233,7 @@ function getSpecialTokens(tokenizer) {
* @property {number|null} [max_length=null] Maximum length of the returned list and optionally padding length.
* @property {TReturnTensor} [return_tensor=true] Whether to return the results as Tensors or arrays.
* @property {boolean|null} [return_token_type_ids=null] Whether to return the token type ids.
* @property {boolean} [return_offsets_mapping=false] Whether to return character-level [start, end) offsets for each token.
*/

/**
Expand Down Expand Up @@ -359,7 +396,7 @@ export class PreTrainedTokenizer
text,
options = {},
) {
const { text_pair = null, add_special_tokens = true, padding = false, return_token_type_ids = null } = options;
const { text_pair = null, add_special_tokens = true, padding = false, return_token_type_ids = null, return_offsets_mapping = false } = options;
let { truncation = null, max_length = null } = options;
const return_tensor = /** @type {TReturnTensor} */ (options.return_tensor ?? true); // Different to HF

Expand All @@ -380,10 +417,10 @@ export class PreTrainedTokenizer
}

encodedTokens = text.map((t, i) =>
this._encode_plus(t, { text_pair: text_pair[i], add_special_tokens, return_token_type_ids }),
this._encode_plus(t, { text_pair: text_pair[i], add_special_tokens, return_token_type_ids, return_offsets_mapping }),
);
} else {
encodedTokens = text.map((x) => this._encode_plus(x, { add_special_tokens, return_token_type_ids }));
encodedTokens = text.map((x) => this._encode_plus(x, { add_special_tokens, return_token_type_ids, return_offsets_mapping }));
}
} else {
if (text === null || text === undefined) {
Expand All @@ -397,7 +434,7 @@ export class PreTrainedTokenizer
}

// For single input, we just wrap in an array, and then unwrap later.
encodedTokens = [this._encode_plus(text, { text_pair, add_special_tokens, return_token_type_ids })];
encodedTokens = [this._encode_plus(text, { text_pair, add_special_tokens, return_token_type_ids, return_offsets_mapping })];
}
// At this point, `encodedTokens` is batched, of shape [batch_size, tokens].
// However, array may be jagged. So, we may need pad to max_length.
Expand Down Expand Up @@ -444,7 +481,7 @@ export class PreTrainedTokenizer
padHelper(
encodedTokens[i],
max_length,
(key) => (key === 'input_ids' ? this.pad_token_id : 0),
(key) => (key === 'input_ids' ? this.pad_token_id : key === 'offset_mapping' ? [0, 0] : 0),
this.padding_side,
);
}
Expand All @@ -454,6 +491,12 @@ export class PreTrainedTokenizer

const result = {};

// offset_mapping is a number[][] — it cannot be tensorized.
// Extract it before the tensor loop and re-attach as a plain array.
const offsetMappings = return_offsets_mapping
? encodedTokens.map((x) => { const v = x.offset_mapping; delete x.offset_mapping; return v; })
: null;

if (return_tensor) {
if (!(padding && truncation)) {
// Not, guaranteed that all items have same length, so
Expand Down Expand Up @@ -502,7 +545,11 @@ export class PreTrainedTokenizer
}
}

return /** @type {BatchEncoding<BatchEncodingItem<TText, TReturnTensor>>} */ (result);
if (offsetMappings) {
result.offset_mapping = isBatched ? offsetMappings : offsetMappings[0];
}

return /** @type {BatchEncoding<BatchEncodingItem<TText, TReturnTensor>>} */ (/** @type {unknown} */ (result));
}

/**
Expand All @@ -524,11 +571,12 @@ export class PreTrainedTokenizer
* @param {string|null} [options.text_pair=null] The optional second text to encode.
* @param {boolean} [options.add_special_tokens=true] Whether or not to add the special tokens associated with the corresponding model.
* @param {boolean|null} [options.return_token_type_ids=null] Whether to return token_type_ids.
* @returns {{input_ids: number[], attention_mask: number[], token_type_ids?: number[]}} An object containing the encoded text.
* @param {boolean} [options.return_offsets_mapping=false] Whether to return character-level [start, end) offsets for each token.
* @returns {{input_ids: number[], attention_mask: number[], token_type_ids?: number[], offset_mapping?: [number, number][]}} An object containing the encoded text.
* @private
*/
_encode_plus(text, { text_pair = null, add_special_tokens = true, return_token_type_ids = null } = {}) {
const { ids, attention_mask, token_type_ids } = this._tokenizer.encode(text, {
_encode_plus(text, { text_pair = null, add_special_tokens = true, return_token_type_ids = null, return_offsets_mapping = false } = {}) {
const { ids, attention_mask, token_type_ids, tokens } = this._tokenizer.encode(text, {
text_pair,
add_special_tokens,
return_token_type_ids: return_token_type_ids ?? this.return_token_type_ids,
Expand All @@ -537,6 +585,7 @@ export class PreTrainedTokenizer
input_ids: ids,
attention_mask,
...(token_type_ids ? { token_type_ids } : {}),
...(return_offsets_mapping ? { offset_mapping: computeOffsets(tokens, text) } : {}),
};
}

Expand Down
3 changes: 3 additions & 0 deletions packages/transformers/src/transformers.js
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ export { DynamicCache } from './cache_utils.js';
// Cache and file management
export { ModelRegistry } from './utils/model_registry/ModelRegistry.js';

// Device utilities
export { get_available_devices } from './backends/onnx.js';

// Expose common types used across the library for developers to access
/**
* @typedef {import('./utils/hub.js').PretrainedModelOptions} PretrainedModelOptions
Expand Down
Loading