Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion .gradient/notebook-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -167,4 +167,26 @@ useful-managing-ipu-resources:
generated: true
notebook:
file: managing_ipu_resources.ipynb
timeout: 1000
timeout: 1000

# Packed BERT tests
packed-bert-single-label:
location: ../packed-bert/
generated: true
notebook:
file: packedBERT_single_label_text_classification.ipynb
timeout: 10000

packed-bert-multi-label:
location: ../packed-bert/
generated: true
notebook:
file: packedBERT_multi_label_text_classification.ipynb
timeout: 10000

packed-bert-question-answering:
location: ../packed-bert/
generated: true
notebook:
file: packedBERT_question_answering.ipynb
timeout: 10000
25 changes: 25 additions & 0 deletions .gradient/prepare-datasets.sh
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,30 @@ echo "Starting preparation of datasets"
# symlink exe_cache files
exe_cache_source_dir="${PUBLIC_DATASETS_DIR}/poplar-executables-hf-3-1"
symlink-public-resources "${exe_cache_source_dir}" $POPLAR_EXECUTABLE_CACHE_DIR

# packed bert executables
packed_sl_exe_cache_source_dir="${PUBLIC_DATASETS_DIR}/packed_bert_slseqcls_exe_cache/packed_bert_slseqcls"
symlink-public-resources "${packed_exe_cache_source_dir}" "${POPLAR_EXECUTABLE_CACHE_DIR}/packed_bert_slseqcls_exe_cache"
packed_ml_exe_cache_source_dir="${PUBLIC_DATASETS_DIR}/packed_bert_mlseqcls_exe_cache/packed_bert_mlseqcls"
symlink-public-resources "${packed_exe_cache_source_dir}" "${POPLAR_EXECUTABLE_CACHE_DIR}/packed_bert_mlseqcls_exe_cache"
packed_qa_exe_cache_source_dir="${PUBLIC_DATASETS_DIR}/packed_bert_qa_exe_cache/packed_bert_squad"
symlink-public-resources "${packed_exe_cache_source_dir}" "${POPLAR_EXECUTABLE_CACHE_DIR}/packed_bert_qa_exe_cache"

# packed bert datasets
packed_sl_dataset_source_dir="${PUBLIC_DATASETS_DIR}/packed_bert_slseqcls_dataset_cache"
symlink-public-resources "${packed_exe_cache_source_dir}" "${HF_DATASETS}/packed_bert_slseqcls_dataset_cache"
packed_ml_dataset_source_dir="${PUBLIC_DATASETS_DIR}/packed_bert_mlseqcls_dataset_cache"
symlink-public-resources "${packed_exe_cache_source_dir}" "${POPLAR_EXECUTABLE_CACHE_DIR}/packed_bert_mlseqcls_dataset_cache"
packed_qa_dataset_source_dir="${PUBLIC_DATASETS_DIR}/packed_bert_qa_dataset_cache"
symlink-public-resources "${packed_exe_cache_source_dir}" "${POPLAR_EXECUTABLE_CACHE_DIR}/packed_bert_qa_dataset_cache"

# packed bert inference checkpoints
symlink-public-resources "${PUBLIC_DATASETS_DIR}/bert-base-uncased-sst2" "${CHECKPOINT_DIR}/bert-base-uncased-sst2"
symlink-public-resources "${PUBLIC_DATASETS_DIR}/bert-base-uncased-go_emotions" "${CHECKPOINT_DIR}/bert-base-uncased-go_emotions"
symlink-public-resources "${PUBLIC_DATASETS_DIR}/bert-base-uncased-squad" "${CHECKPOINT_DIR}/bert-base-uncased-squad"



# symlink HF datasets
HF_DATASETS="conll2003 glue imagefolder librispeech_asr squad swag wikitext wmt16 xsum"
for dataset in ${HF_DATASETS}; do
Expand All @@ -50,6 +74,7 @@ for dataset in ${HF_DATASETS}; do
done
# Image classification dataset
symlink-public-resources "${PUBLIC_DATASETS_DIR}/dfki-sentinel-eurosat" "${DATASETS_DIR}/dfki-sentinel-eurosat"

# pre-install the correct version of optimum for this release
python -m pip install "optimum-graphcore>=0.5, <0.6"

Expand Down
29 changes: 29 additions & 0 deletions .gradient/settings.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,32 @@ integrations:
dfki-sentinel-eurosat:
type: dataset
ref: paperspace/ds8p6sv96fl1att:k5j4cob
bert-base-uncased-sst2:
type: dataset
ref: paperspace/dskrqljie6pti8y:mfqq5qk
bert-base-uncased-go_emotions:
type: dataset
ref: paperspace/dsz2f8usk60xbos:n3h8ko3
bert-base-uncased-squad:
type: dataset
ref: paperspace/ds9ogwc0fbfh799:3mv59lg
packed_bert_slseqcls_exe_cache:
type: dataset
ref: paperspace/dsfg0gcuqbr0pfc:0pss84k
packed_bert_mlseqcls_exe_cache:
type: dataset
ref: paperspace/dsevh3ol36qzpz2:1yme9yi
packed_bert_qa_exe_cache:
type: dataset
ref: paperspace/dsson0ib8byvqpf:tcgts2v
packed_bert_slseqcls_dataset_cache:
type: dataset
ref: paperspace/dsuuz3dih9su40i:npvb833
packed_bert_mlseqcls_dataset_cache:
type: dataset
ref: paperspace/dsuxwz4nqbbs07s:jipm3jh
packed_bert_qa_dataset_cache:
type: dataset
ref: paperspace/dssvktzrzcoaumk:5zhp5mf


Empty file added packed-bert/__init__.py
Empty file.
Binary file added packed-bert/images/go_emotions.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file added packed-bert/models/__init__.py
Empty file.
282 changes: 282 additions & 0 deletions packed-bert/models/modeling_bert_packed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,282 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple, Union

import torch
import torch.nn as nn

import poptorch
from optimum.graphcore.models.bert.modeling_bert import BertPipelineMixin
from transformers import BertForQuestionAnswering, BertForSequenceClassification
from transformers.modeling_outputs import QuestionAnsweringModelOutput


class PackedBertPooler(nn.Module):
def __init__(self, config):
super().__init__()
self.max_seq_per_pack = config.max_sequences_per_pack
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()

def forward(self, hidden_states):
"""
We "pool" the model by simply taking the hidden states corresponding
to the last max_sequences_per_pack tokens. Note that the [CLS] tokens
are always located at the end of the pack. When the actual number of
sequences is lower than max_sequences_per_pack, we still slice out
the last max_sequences_per_pack tokens, but we will not use all of
them during loss calculation.
"""
sh = hidden_states.shape
last_tokens_tensors = hidden_states[:, -self.max_seq_per_pack :]
last_reshape = last_tokens_tensors.reshape(sh[0] * self.max_seq_per_pack, sh[2])
# output size: [bs x max_sequences_per_pack, hidden_size]
output = self.dense(last_reshape)
output = self.activation(output)

return output


class PackedBertOutputsForMultiLabel(nn.Module):
"""
This class handles the custom model output phase for multi-label sequence classification.
"""

def __init__(self, config):
super().__init__()
self.max_seq_per_pack = config.max_sequences_per_pack
self.multi_loss = torch.nn.BCEWithLogitsLoss(reduction="none")

def forward(
self,
outputs: Optional[torch.Tensor],
attention_mask: Optional[torch.Tensor],
batch_dim: int,
labels: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor]:
max_labels = torch.max(attention_mask[:, : -self.max_seq_per_pack], dim=-1).values.unsqueeze(1)

# Create a mask corresponding to actual number of seqs in pack, to mask padding
label_mask = torch.arange(0, self.max_seq_per_pack).unsqueeze(0).repeat(batch_dim, 1)
label_mask = torch.where(
label_mask < max_labels,
torch.ones(batch_dim, self.max_seq_per_pack),
torch.zeros(batch_dim, self.max_seq_per_pack),
)
label_mask = label_mask.view(-1).unsqueeze(1)

# Adjust logits to rule out padding
logits = label_mask * outputs.logits

loss = None
if labels is not None:
# Flatten and adjust labels to rule out padding
labels = labels.view(-1, *(labels.size()[2:])).to(torch.float32)
labels = label_mask * labels

# Adjust the loss to rule out the padding and CLS logits
loss = self.multi_loss(logits, labels)
loss *= label_mask

# Take mean over each multi-class pred
loss = torch.sum(loss) / (torch.sum(max_labels) * labels.shape[-1])
loss = poptorch.identity_loss(loss, reduction="none")

logits = logits.reshape([batch_dim, self.max_seq_per_pack, logits.shape[-1]])

return (loss, logits)
else:
return logits


class PipelinedPackedBertForSequenceClassification(BertForSequenceClassification, BertPipelineMixin):
"""
This class supports doing single-label/multi-label sequence-classification tasks with custom outputs.
The problem_type must be passed to differentiate the two methods - multi_label_classification or single_label_classification. Multi-label requires a custom loss implementation to mask labels and logits, unlike single-label.

In both cases:
* The logits need to be reshaped at output to revert them from the 'unpacked' batch dimension to a batch dimension equivalent to that of the labels passed to the model in order for Optimum's trainer class to perform evaluation.

* The attention mask is reshaped from the 'packed' attention mask to an equivalent binary 3D "extended" attention mask for BERT to recognise the sequences within a single packed input as unrelated sequences.
"""

def __init__(self, config):
super().__init__(config)
self.max_seq_per_pack = config.max_sequences_per_pack
self.problem_type = config.problem_type
self.num_labels = config.num_labels

self.bert.pooler = PackedBertPooler(config)
self.multi_label_outputs = PackedBertOutputsForMultiLabel(config)

def parallelize(self):
super().parallelize()
last_ipu = self.ipu_config.ipus_per_replica - 1
self.classifier = poptorch.BeginBlock(self.classifier, "Classifier Output", ipu_id=last_ipu)
return self

def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Tuple[torch.Tensor]:
bs = input_ids.shape[0]
seq_len = input_ids.shape[1]

attention_mask_3d = attention_mask[:, None, :].repeat(1, seq_len, 1)
attention_mask_3d = (attention_mask_3d == attention_mask_3d.transpose(1, 2)) * (attention_mask_3d != 0)

# Manual masking of logits and loss only needed for multi-label, single-label loss allows ignore_index
output = super().forward(
input_ids=input_ids,
attention_mask=attention_mask_3d,
token_type_ids=token_type_ids,
position_ids=position_ids,
labels=labels if labels is not None and self.problem_type == "single_label_classification" else None,
)

if self.problem_type == "single_label_classification":
if labels is not None:
logits = output.logits.reshape([-1, self.max_seq_per_pack, self.num_labels])
output.logits = logits
output = (output.loss, output.logits)
else:
output = output.logits

else:
output = self.multi_label_outputs(
outputs=output, attention_mask=attention_mask, batch_dim=bs, labels=labels
)

return output


class PackedBertOutputsForQA(nn.Module):
"""
This class handles the custom output phase for a question-answering task.
"""

def __init__(self, config):
super().__init__()
# Use the default QA model output formatting class to return outputs in the same form as the base model.
self.output = QuestionAnsweringModelOutput
self.max_sequences_per_pack = config.max_sequences_per_pack

def forward(
self,
final_layer_output: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
start_positions: Optional[torch.Tensor] = None,
end_positions: Optional[torch.Tensor] = None,
) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
# Create unpacking mask to separate packed logits out into sequence-specific logits only
unpacking_mask = attention_mask[:, None, :].repeat(1, self.max_sequences_per_pack, 1)
pack_seq_ids = torch.arange(1, self.max_sequences_per_pack + 1).view(self.max_sequences_per_pack, 1)

unpacking_mask = unpacking_mask == pack_seq_ids

# Expand start logits using mask to isolate logits for each internal sequence in the pack
unpacked_start_logits = final_layer_output.start_logits[:, None, :] * unpacking_mask
unpacked_end_logits = final_layer_output.end_logits[:, None, :] * unpacking_mask

# Calculate loss on logits/labels with initial [bs, mspp, ...] dims collapsed into one [bs*mspp, ...]
total_loss = None
if start_positions is not None and end_positions is not None:
start_positions = start_positions.view(-1)
end_positions = end_positions.view(-1)

unpacked_start_logits = unpacked_start_logits.contiguous()
unpacked_end_logits = unpacked_end_logits.contiguous()

unpacked_start_logits = unpacked_start_logits.view(-1, unpacked_start_logits.shape[-1])
unpacked_end_logits = unpacked_end_logits.view(-1, unpacked_end_logits.shape[-1])

loss_fct = nn.CrossEntropyLoss()
start_loss = loss_fct(unpacked_start_logits, start_positions)
end_loss = loss_fct(unpacked_end_logits, end_positions)

total_loss = (start_loss + end_loss) / 2

return self.output(
loss=total_loss,
start_logits=unpacked_start_logits,
end_logits=unpacked_end_logits,
hidden_states=final_layer_output.hidden_states,
attentions=final_layer_output.attentions,
)


class PipelinedPackedBertForQuestionAnswering(BertForQuestionAnswering, BertPipelineMixin):
"""
This class extends BertForQuestionAnswering with some differences required for packing. The 'packed' attention mask must be extended to a 3D binary "extended" attention mask for BERT to recognise the sequences within a single packed input as unrelated sequences. The output is extended to enable masking for padded labels, and then 'unpacking' the packed hidden state output before performing the loss calculation.
"""

def __init__(self, config):
super().__init__(config)
self.max_seq_per_pack = self.config.max_sequences_per_pack
self.packed_outputs = PackedBertOutputsForQA(config)

def parallelize(self):
super().parallelize()
last_ipu = self.ipu_config.ipus_per_replica - 1
self.qa_outputs = poptorch.BeginBlock(self.qa_outputs, "QA Outputs", ipu_id=last_ipu)
return self

def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
start_positions: Optional[torch.Tensor] = None,
end_positions: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Tuple[torch.Tensor]:
# Create 3D attention mask for sequence specific attention in pack
seq_len = input_ids.shape[1]
packed_attention_mask = attention_mask[:, None, :].repeat(1, seq_len, 1)
packed_attention_mask = (packed_attention_mask == packed_attention_mask.transpose(1, 2)) * (
packed_attention_mask != 0
)

# Run forwards pass through model without labels
final_layer_output = super().forward(
input_ids, attention_mask=packed_attention_mask, token_type_ids=token_type_ids, position_ids=position_ids
)

# Custom PackedBert for SQuAD output, redirect from before loss function in transformers model class.
output = self.packed_outputs(
final_layer_output,
attention_mask=attention_mask,
start_positions=start_positions,
end_positions=end_positions,
)

if start_positions is not None and end_positions is not None:
return poptorch.identity_loss(output.loss, reduction="mean"), output.start_logits, output.end_logits
else:
return output.start_logits, output.end_logits
Loading