This post was originally part of another that was dedicated to exploring small language models. But, because it turned out to be a significant tangent, I decided to make it separate. Here I dive into the T5 architecture, sentence representations, and ultimately create a new class for FLAN-T5 that is an encoder-only classifier.
For background, the original T5 model from Google (blog post here and paper here) is an encoder-decoder model pre-trained on a large, cleaned version of the Common Crawl corpus. The encoder-decoder architecture makes T5 well suited to tasks like language translation, text summarization, and question answering. T5's capabilities were taken another step further with FLAN-T5, which added instruction fine-tuning. Because FLAN-T5 has been both pre-trained and fine-tuned on a variety of language tasks, it is a good candidate for a downstream task like natural language sequence classification. And in the aforementioned post (todo: link), I use it for exactly that - specifically a binary question-answering task.
Now, Hugging Face does offer a sequence classifier for T5, but internally it uses the full encoder-decoder. After looking into the PyTorch code for that class, it appears that the first hidden state of the decoder is used as the sentence representation that gets fed to the classification head. And while I can see how the first hidden state of the decoder is a learned representation of the input sequence, I wondered whether just the encoder could be used. If so, this could roughly halve the number of parameters in the model. In a real-time inference setting, that might allow the model to fit on smaller hardware and improve inference latency - both important considerations for production machine learning.
However, the thought of removing the decoder brings up an obvious question: what sentence representation do we use instead (i.e. which vector do we feed to the classification layer)? Do we use the last hidden state of the encoder? Do we take the mean of all hidden states (masking out pad tokens)? BERT conveniently prepends a classification token [CLS] to every sequence, and BERT was also pre-trained with a classification task (next sentence prediction) which allowed this token to capture an effective sentence representation. T5, on the other hand, doesn't have a beginning-of-sentence token and, even with FLAN, it wasn't pre-trained with a classification task like BERT.
The following snippet shows that end-of-sentence (</s>) and padding (<pad>) are the only default special tokens.
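(The snippet itself isn't shown here; a minimal reconstruction, assuming the google/flan-t5-small checkpoint and the same tokenizer used throughout the rest of this post, might look like the following.)

from transformers import AutoTokenizer

# flan-t5-small is an assumption, but it matches the hidden size of 512 seen later on
model_id = "google/flan-t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# encode a short sentence with padding and decode it, keeping special tokens visible
encoded = tokenizer("Where is my BOS token?", max_length=16, return_tensors="pt", padding="max_length")
print(tokenizer.decode(encoded["input_ids"][0], skip_special_tokens=False))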
Where is my BOS token?</s><pad><pad><pad><pad><pad><pad><pad><pad>
How about using the end-of-sentence token? Well, if there are multiple input sequences it's not so straightforward. Take the question-answering case, for example, where we have both a question input and a passage input. Notice below that a </s> token is also placed between the two inputs.
question ="do good samaritan laws protect those who help at an accident"passage ="Good Samaritan laws offer legal protection to people who..."encoded_inputs = tokenizer.encode_plus( question, passage, max_length=32, return_tensors="pt", padding="max_length",)decoded_inputs = tokenizer.decode(encoded_inputs["input_ids"][0], skip_special_tokens=False)print(decoded_inputs)
do good samaritan laws protect those who help at an accident</s> Good Samaritan laws offer legal protection to people who...</s><pad>
So the </s> token might not be an effective sentence representation for a task like binary question answering. And while we could prepend a new classifier token to the input, that would probably require even more fine-tuning. Then what to use? Fortunately, this was already looked into (even before the FLAN release) in the following paper. The results show that mean pooling is an effective sentence representation, and may even perform better than using the first hidden state of the decoder.
With all this in mind, let’s sketch out the data flow of an encoder-only classifier for T5.
from transformers.models.t5.configuration_t5 import T5Config
from transformers.models.t5.modeling_t5 import T5EncoderModel
from transformers.models.t5.modeling_t5 import T5ClassificationHead

config = T5Config.from_pretrained(model_id)
config.n_positions = 64  # limit input dim
encoder = T5EncoderModel(config)

# classifier = torch.nn.Linear(config.hidden_size, config.num_labels)
config.num_labels = 2
classification_head = T5ClassificationHead(config)

inputs = tokenizer.encode_plus(
    question,
    passage,
    max_length=config.n_positions,
    return_tensors="pt",
    padding="max_length",
    truncation=True,
)
encoder_outputs = encoder(
    inputs.input_ids,
    attention_mask=inputs.attention_mask,
    return_dict=True,
)
sequence_output = encoder_outputs.last_hidden_state
sequence_output.shape  # batch size x input dim x hidden dim
torch.Size([1, 64, 512])
Before taking the mean of all final hidden states, I’ll have to mask out padding tokens. We should be able to use the input attention mask for this.
import torch

# sanity check - last EOS token should match up with input attention mask
eos_mask = inputs.input_ids.eq(config.eos_token_id)
idx_last_eos_token = torch.where(eos_mask.squeeze())[0][-1]
print(f"last EOS token index: {idx_last_eos_token}")
print(f"attn mask at last EOS token: {inputs.attention_mask.squeeze()[idx_last_eos_token]}")
print(f"attn mask after last EOS token: {inputs.attention_mask.squeeze()[idx_last_eos_token + 1]}")
assert inputs.attention_mask.squeeze()[: idx_last_eos_token + 1].sum() == idx_last_eos_token + 1

valid_sequence_output = sequence_output * inputs.attention_mask.unsqueeze(-1)  # attn mask 2D -> 3D
sentence_representation = valid_sequence_output.sum(dim=1)  # sum across inputs
sentence_representation /= inputs.attention_mask.sum(dim=1).unsqueeze(-1)  # mean of valid
sentence_representation.shape  # batch size x hidden dim
last EOS token index: 30
attn mask at last EOS token: 1
attn mask after last EOS token: 0
torch.Size([1, 512])
And finally the logits after feeding the sentence representation to the classification layer.
logits = classification_head(sentence_representation)
logits  # batch size x num labels
Time to create a new class in a similar format to the encoder-decoder sequence classifier. Again, this uses mean pooling of the encoder hidden states as the sentence representation. In addition, this class can load pre-trained FLAN weights (via the .from_pretrained method) and then be further fine-tuned on downstream sequence classification tasks.
from transformers.models.t5.modeling_t5 import T5PreTrainedModel
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import SequenceClassifierOutput


class T5EncoderForSequenceClassification(T5PreTrainedModel):
    def __init__(self, config: T5Config):
        super().__init__(config)
        self.num_labels = config.num_labels
        # note that the name of the encoder needs to be 'transformer'
        # to properly load all pre-trained weights
        self.transformer = T5EncoderModel(config)
        # should probably add dropout here if training
        self.classifier = T5ClassificationHead(config)
        self.post_init()  # init weights

    def forward(self, input_ids, attention_mask, labels=None) -> SequenceClassifierOutput:
        encoder_outputs = self.transformer(input_ids, attention_mask=attention_mask)
        sequence_output = encoder_outputs.last_hidden_state

        # mean pooling over the valid (non-padding) positions
        sentence_representation = (sequence_output * attention_mask.unsqueeze(-1)).sum(dim=1)
        sentence_representation /= attention_mask.sum(dim=1).unsqueeze(-1)

        logits = self.classifier(sentence_representation)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
Now to verify that pre-trained FLAN weights can be loaded into the encoder of the new class, and, for comparison, into Hugging Face's encoder-decoder T5ForSequenceClassification. Notice that both messages below indicate that only the classification head weights were newly initialized (success!).
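(The loading code isn't shown; it was presumably something along these lines, again assuming the google/flan-t5-small checkpoint.)

from transformers import T5ForSequenceClassification

encoder_clf = T5EncoderForSequenceClassification.from_pretrained(model_id, num_labels=2)  # new encoder-only class
enc_dec_clf = T5ForSequenceClassification.from_pretrained(model_id, num_labels=2)  # Hugging Face's encoder-decoder classifier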
Some weights of T5EncoderForSequenceClassification were not initialized from the model checkpoint at google/flan-t5-small and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of T5ForSequenceClassification were not initialized from the model checkpoint at google/flan-t5-small and are newly initialized: ['classification_head.dense.bias', 'classification_head.dense.weight', 'classification_head.out_proj.bias', 'classification_head.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
As another sanity check, we can compare the number of trainable parameters of the encoder-decoder classifier and the encoder-only classifier.
import numpy as np

def torch_trainable_params(model):
    trainable_parameters = filter(lambda p: p.requires_grad, model.parameters())
    num_params = sum([np.prod(p.size()) for p in trainable_parameters])
    print(f"{model.__class__.__name__} trainable params: {num_params:,}")

print("\nnum params:")
torch_trainable_params(encoder_clf)
torch_trainable_params(enc_dec_clf)
num params:
T5EncoderForSequenceClassification trainable params: 35,596,482
T5ForSequenceClassification trainable params: 60,775,298
Now for an important question: how large is the new model in memory, and how does it compare?
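(The measurement code isn't included above; below is a minimal sketch of one common way to estimate a PyTorch model's in-memory size by summing the bytes of its parameters and buffers. The exact numbers depend on how the original measurement was taken.)

def torch_model_size_mb(model):
    # bytes used by parameters plus registered buffers
    param_bytes = sum(p.nelement() * p.element_size() for p in model.parameters())
    buffer_bytes = sum(b.nelement() * b.element_size() for b in model.buffers())
    print(f"{model.__class__.__name__} model size: {(param_bytes + buffer_bytes) / 1024**2:.3f} MB")

print("\nmodel sizes:")
torch_model_size_mb(encoder_clf)
torch_model_size_mb(enc_dec_clf)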
model sizes:
T5EncoderForSequenceClassification model size: 135.790 MB
T5ForSequenceClassification model size: 231.839 MB
And finally, how much does using the encoder-only model reduce inference latency?
# from boolq dataset
question = "is elder scrolls online the same as skyrim"
passage = """
As with other games in The Elder Scrolls series, the game is set on the continent of Tamriel. The events of the game occur a millennium before those of The Elder Scrolls V: Skyrim and around 800 years before The Elder Scrolls III: Morrowind and The Elder Scrolls IV: Oblivion. It has a broadly similar structure to Skyrim, with two separate conflicts progressing at the same time, one with the fate of the world in the balance, and one where the prize is supreme power on Tamriel. In The Elder Scrolls Online, the first struggle is against the Daedric Prince Molag Bal, who is attempting to meld the plane of Mundus with his realm of Coldharbour, and the second is to capture the vacant imperial throne, contested by three alliances of the mortal races. The player character has been sacrificed to Molag Bal, and Molag Bal has stolen their soul, the recovery of which is the primary game objective.
"""

inputs = tokenizer.encode_plus(
    question,
    passage,
    max_length=512,
    return_tensors="pt",
    padding="max_length",
    truncation=True,
)
inputs_dict = {"input_ids": inputs.input_ids, "attention_mask": inputs.attention_mask}
# note this is on colab CPU
enc_only_time = %timeit -o encoder_clf(**inputs_dict)
364 ms ± 5.22 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
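(The corresponding timing line for the encoder-decoder classifier isn't shown; presumably it was something like the following.)

enc_dec_time = %timeit -o enc_dec_clf(**inputs_dict)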
910 ms ± 99 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
So yes, this modification both decreased model size and improved inference latency! Whether or not the encoder-only model is just as performant (in terms of accuracy) as the encoder-decoder model is left to another post.