How PyTorch NestedTensors, FlashAttention2, and xFormers can Boost Performance and Reduce AI Costs
As generative AI (genAI) models grow in both popularity and scale, so do the computational demands and costs associated with their training and deployment. Optimizing these models is crucial for improving their runtime performance and reducing their operational expenses. At the heart of modern genAI systems is the Transformer architecture and its attention mechanism, which is notably compute-intensive.
In a previous post, we demonstrated how using optimized attention kernels can significantly accelerate the performance of Transformer models. In this post, we continue our exploration by addressing the challenge of variable-length input sequences, an inherent property of real-world data such as documents, code, time series, and more.
The Challenge of Batching Variable-Length Input
In a typical deep learning workload, individual samples are grouped into batches before being copied to the GPU and fed to the AI model. Batching improves computational efficiency and often aids model convergence during training. Usually, batching involves stacking all of the sample tensors along a new dimension, the batch dimension. However, torch.stack requires that all tensors have the same shape, which is not the case with variable-length sequences.
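As a minimal illustration (not part of the benchmark code that follows), the snippet below shows that stacking works when samples share a length but fails once the lengths differ:

# illustration only: torch.stack requires equal-shaped tensors
import torch

# two samples of equal length stack cleanly into a batch
a = torch.randint(0, 1024, (10,))
b = torch.randint(0, 1024, (10,))
batch = torch.stack([a, b])  # shape: [2, 10]

# a sample of a different length breaks the stacking operation
c = torch.randint(0, 1024, (7,))
try:
    torch.stack([a, c])
except RuntimeError as e:
    print(f'torch.stack failed: {e}')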
Padding and its Inefficiencies
The traditional way to address this challenge is to pad the input sequences to a fixed length and then perform the stacking. This solution requires appropriate masking within the model so that the output is not affected by the irrelevant tensor elements. In the case of attention layers, a padding mask indicates which tokens are padding and should not be attended to (e.g., see PyTorch MultiheadAttention). However, padding can waste considerable GPU resources, increasing costs and slowing development. This is especially true for large-scale AI models.
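A minimal sketch of this conventional approach appears below (assuming a fixed length of 1024 and a pad id of 0, mirroring the settings used later in this post). It pads each sample, stacks the padded tensors, and derives the boolean padding mask that the attention layers need:

# illustration only: pad-and-stack with a boolean padding mask
import torch
import torch.nn.functional as F

MAX_LEN, PAD_ID = 1024, 0
seqs = [torch.randint(1, 1024, (n,)) for n in (10, 7, 1024)]

# pad every sequence to the fixed length and stack into a single batch tensor
padded = torch.stack(
    [F.pad(s, (0, MAX_LEN - s.shape[0]), value=PAD_ID) for s in seqs]
)                                     # shape: [3, 1024]

# boolean mask: True for real tokens, False for padding
padding_mask = padded != PAD_ID       # shape: [3, 1024]
print(padded.shape, padding_mask.sum(dim=-1))  # real-token count per sample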
Don’t Pad, Concatenate
One way to avoid padding is to concatenate the sequences along an existing dimension instead of stacking them along a new dimension. Contrary to torch.stack, torch.cat allows inputs of different shapes. The output of the concatenation is a single sequence whose length equals the sum of the lengths of the individual sequences. For this solution to work, our single sequence would need to be supplemented by an attention mask that ensures that each token only attends to other tokens in the same original sequence, in a process known as document masking. Denoting the sum of the lengths of all of the individual sequences by N and adopting "big O" notation, the size of this mask would need to be O(N²), as would the compute complexity of a standard attention layer, making this solution highly inefficient.
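The sketch below (an illustration with assumed toy lengths, not part of the benchmark code) builds such a document mask explicitly as a block-diagonal boolean matrix. Materializing it costs O(N²) memory, which is precisely what the optimized kernels discussed below avoid:

# illustration only: concatenation with an explicit document (block-diagonal) mask
import torch

lengths = [3, 5, 2]                       # assumed toy sequence lengths
seqs = [torch.randint(1, 1024, (n,)) for n in lengths]
concat = torch.cat(seqs)                  # single sequence of length N = 10
N = concat.shape[0]

# document ids: which original sequence each token belongs to
doc_ids = torch.repeat_interleave(
    torch.arange(len(lengths)), torch.tensor(lengths)
)

# token i may attend to token j only if both come from the same original
# sequence; the materialized mask is O(N^2) in memory
doc_mask = doc_ids.unsqueeze(0) == doc_ids.unsqueeze(1)   # shape: [N, N]
print(doc_mask.shape, doc_mask.sum())     # torch.Size([10, 10]), 38 allowed pairs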
Attention Layer Optimization
The solution to this problem comes in the form of specialized attention layers. Contrary to the standard attention layer that computes the full set of O(N²) attention scores only to mask out the irrelevant ones, these optimized attention kernels are designed to calculate only the scores that matter. In this post we will explore several solutions, each with its own distinct characteristics. These include:
- PyTorch NestedTensors,
- FlashAttention2, and
- xFormers' memory_efficient_attention.
Integration into Existing HuggingFace Models
For teams working with pre-trained models, transitioning to these optimizations might seem challenging. We will demonstrate how HuggingFace's APIs simplify this process, enabling developers to integrate these techniques with minimal code changes and effort.
Disclaimers
- Please do not interpret our use of any platforms, libraries, or optimization techniques as an endorsement of their use. The best options for you will depend greatly on the specifics of your own use case.
- Some of the APIs discussed here are in prototype or beta stages and may change in the future.
- The code examples provided are for demonstrative purposes only. We make no claims regarding their accuracy, optimality, or robustness.
Special thanks to Yitzhak Levi and Peleg Nahaliel for their contributions to this post.
To facilitate our discussion we will define a simple generative model (partially inspired by the GPT model defined here). For a more comprehensive guide on building language models, please see one of the many excellent tutorials available online (e.g., here).
Transformer Block
We begin by constructing a basic Transformer block, specifically designed to facilitate experimentation with different attention mechanisms and optimizations. While our block performs the same computation as standard Transformer blocks, we make slight modifications to the usual choice of operators in order to support the possibility of PyTorch NestedTensor inputs (as described here).
# basic imports
import time, functools

# torch imports
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn

# Define Transformer settings
BATCH_SIZE = 32
NUM_HEADS = 16
HEAD_DIM = 64
DIM = NUM_HEADS * HEAD_DIM
DEPTH = 24
NUM_TOKENS = 1024
MAX_SEQ_LEN = 1024
PAD_ID = 0
DEVICE = 'cuda'
class MyAttentionBlock(nn.Module):
    def __init__(
            self,
            attn_fn,
            dim,
            num_heads,
            format=None,
            **kwargs
    ):
        super().__init__()
        self.attn_fn = attn_fn
        self.num_heads = num_heads
        self.dim = dim
        self.head_dim = dim // num_heads
        self.norm1 = nn.LayerNorm(dim, bias=False)
        self.norm2 = nn.LayerNorm(dim, bias=False)
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

        # mlp layers
        self.fc1 = nn.Linear(dim, dim * 4)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(dim * 4, dim)

        self.permute = functools.partial(torch.transpose, dim0=1, dim1=2)
        if format == 'bshd':
            self.permute = nn.Identity()

    def mlp(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x

    def reshape_and_permute(self, x, batch_size):
        x = x.view(batch_size, -1, self.num_heads, self.head_dim)
        return self.permute(x)

    def forward(self, x_in, attn_mask=None):
        batch_size = x_in.size(0)
        x = self.norm1(x_in)
        qkv = self.qkv(x)

        # rather than first reformatting and then splitting the input
        # state, we first split and then reformat q, k, v in order to
        # support PyTorch Nested Tensors
        q, k, v = qkv.chunk(3, -1)
        q = self.reshape_and_permute(q, batch_size)
        k = self.reshape_and_permute(k, batch_size)
        v = self.reshape_and_permute(v, batch_size)

        # call the attn_fn with the input attn_mask
        x = self.attn_fn(q, k, v, attn_mask=attn_mask)

        # reformat output
        x = self.permute(x).reshape(batch_size, -1, self.dim)
        x = self.proj(x)
        x = x + x_in
        x = x + self.mlp(self.norm2(x))
        return x
Transformer Decoder Model
Building on our programmable Transformer block, we construct a standard Transformer decoder model.
class MyDecoder(nn.Module):
    def __init__(
            self,
            block_fn,
            num_tokens,
            dim,
            num_heads,
            num_layers,
            max_seq_len,
            pad_idx=None
    ):
        super().__init__()
        self.num_heads = num_heads
        self.pad_idx = pad_idx
        self.embedding = nn.Embedding(num_tokens, dim, padding_idx=pad_idx)
        self.positional_embedding = nn.Embedding(max_seq_len, dim)
        self.blocks = nn.ModuleList([
            block_fn(
                dim=dim,
                num_heads=num_heads
            )
            for _ in range(num_layers)])
        self.output = nn.Linear(dim, num_tokens)

    def embed_tokens(self, input_ids, position_ids=None):
        x = self.embedding(input_ids)
        if position_ids is None:
            position_ids = torch.arange(input_ids.shape[1],
                                        device=x.device)
        x = x + self.positional_embedding(position_ids)
        return x

    def forward(self, input_ids, position_ids=None, attn_mask=None):
        # Embed tokens and add positional encoding
        x = self.embed_tokens(input_ids, position_ids)
        if self.pad_idx is not None:
            assert attn_mask is None
            # create a padding mask - we assume boolean masking
            attn_mask = (input_ids != self.pad_idx)
            attn_mask = attn_mask.view(BATCH_SIZE, 1, 1, -1) \
                .expand(-1, self.num_heads, -1, -1)

        for b in self.blocks:
            x = b(x, attn_mask)

        logits = self.output(x)
        return logits
Variable-Length Sequence Input
Next, we create a dataset containing sequences of variable lengths, where each sequence is made up of randomly generated tokens. For simplicity, we (arbitrarily) select a fixed distribution for the sequence lengths. In real-world scenarios, the distribution of sequence lengths typically reflects the nature of the data, such as the length of documents or audio segments. Note that the distribution of lengths directly impacts the computational inefficiencies caused by padding.
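For intuition, the short sketch below (a rough back-of-the-envelope estimate under the uniform length distribution used here, not one of the benchmark experiments) approximates how much of each fixed-length padded batch is made up of padding tokens, roughly half in this case:

# illustration only: estimate the padding overhead of fixed-length padding
import torch

MAX_SEQ_LEN = 1024
torch.manual_seed(0)

# sample many lengths from the same uniform distribution used by the dataset
lengths = torch.randint(1, MAX_SEQ_LEN, (100_000,))

pad_fraction = (1 - lengths.float().mean() / MAX_SEQ_LEN).item()
print(f'~{pad_fraction:.0%} of tokens are padding')  # roughly 50%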
# Use random data
class FakeDataset(Dataset):
    def __len__(self):
        return 1000000

    def __getitem__(self, index):
        length = torch.randint(1, MAX_SEQ_LEN, (1,))
        sequence = torch.randint(1, NUM_TOKENS, (length + 1,))
        inputs = sequence[:-1]
        targets = sequence[1:]
        return inputs, targets


def pad_sequence(sequence, length, pad_val):
    return torch.nn.functional.pad(
        sequence,
        (0, length - sequence.shape[0]),
        value=pad_val
    )


def collate_with_padding(batch):
    padded_inputs = []
    padded_targets = []
    for b in batch:
        padded_inputs.append(pad_sequence(b[0], MAX_SEQ_LEN, PAD_ID))
        padded_targets.append(pad_sequence(b[1], MAX_SEQ_LEN, PAD_ID))
    padded_inputs = torch.stack(padded_inputs, dim=0)
    padded_targets = torch.stack(padded_targets, dim=0)
    return {
        'inputs': padded_inputs,
        'targets': padded_targets
    }


def data_to_device(data, device):
    if isinstance(data, dict):
        return {
            key: data_to_device(val, device)
            for key, val in data.items()
        }
    elif isinstance(data, (list, tuple)):
        return type(data)(
            data_to_device(val, device) for val in data
        )
    elif isinstance(data, torch.Tensor):
        return data.to(device=device, non_blocking=True)
    else:
        return data.to(device=device)
Training/Evaluation Loop
Finally, we implement a main function that performs training/evaluation on input sequences of varying length.
def main(
        block_fn,
        data_collate_fn=collate_with_padding,
        pad_idx=None,
        train=True,
        compile=False
):
    torch.random.manual_seed(0)
    device = torch.device(DEVICE)
    torch.set_float32_matmul_precision("high")

    # Create dataset and dataloader
    data_set = FakeDataset()
    data_loader = DataLoader(
        data_set,
        batch_size=BATCH_SIZE,
        collate_fn=data_collate_fn,
        num_workers=12,
        pin_memory=True,
        drop_last=True
    )

    model = MyDecoder(
        block_fn=block_fn,
        num_tokens=NUM_TOKENS,
        dim=DIM,
        num_heads=NUM_HEADS,
        num_layers=DEPTH,
        max_seq_len=MAX_SEQ_LEN,
        pad_idx=pad_idx
    ).to(device)

    if compile:
        model = torch.compile(model)

    # Define loss and optimizer
    criterion = torch.nn.CrossEntropyLoss(ignore_index=PAD_ID)
    optimizer = torch.optim.SGD(model.parameters())

    def train_step(model, inputs, targets,
                   position_ids=None, attn_mask=None):
        with torch.amp.autocast(DEVICE, dtype=torch.bfloat16):
            outputs = model(inputs, position_ids, attn_mask)
            outputs = outputs.view(-1, NUM_TOKENS)
            targets = targets.flatten()
            loss = criterion(outputs, targets)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    @torch.no_grad()
    def eval_step(model, inputs, targets,
                  position_ids=None, attn_mask=None):
        with torch.amp.autocast(DEVICE, dtype=torch.bfloat16):
            outputs = model(inputs, position_ids, attn_mask)
            if outputs.is_nested:
                outputs = outputs.data._values
                targets = targets.data._values
            else:
                outputs = outputs.view(-1, NUM_TOKENS)
                targets = targets.flatten()
            loss = criterion(outputs, targets)
        return loss

    if train:
        model.train()
        step_fn = train_step
    else:
        model.eval()
        step_fn = eval_step

    t0 = time.perf_counter()
    summ = 0
    count = 0
    for step, data in enumerate(data_loader):
        # Copy data to GPU
        data = data_to_device(data, device=device)
        step_fn(model, data['inputs'], data['targets'],
                position_ids=data.get('indices'),
                attn_mask=data.get('attn_mask'))

        # Capture step time
        batch_time = time.perf_counter() - t0
        if step > 20:  # Skip first steps
            summ += batch_time
            count += 1
        t0 = time.perf_counter()
        if step >= 100:
            break

    print(f'average step time: {summ / count}')
PyTorch SDPA with Padding
For our baseline experiments, we configure our Transformer block to use PyTorch's SDPA mechanism. In our experiments, we run both training and evaluation, both with and without torch.compile. These experiments were run on an NVIDIA H100 with CUDA 12.4 and PyTorch 2.5.1.
from torch.nn.functional import scaled_dot_product_attention as sdpa

block_fn = functools.partial(MyAttentionBlock, attn_fn=sdpa)
causal_block_fn = functools.partial(
    MyAttentionBlock,
    attn_fn=functools.partial(sdpa, is_causal=True)
)

for mode in ['eval', 'train']:
    for compile in [False, True]:
        block_func = causal_block_fn \
            if mode == 'train' else block_fn
        print(f'{mode} with padding, '
              f'{"compiled" if compile else "uncompiled"}')
        main(block_fn=block_func,
             pad_idx=PAD_ID,
             train=mode == 'train',
             compile=compile)
Performance Results:
- Evaluation: 132 milliseconds (ms) without torch.compile, 130 ms with torch.compile
- Training: 342 ms without torch.compile, 299 ms with torch.compile
In this section, we will explore several optimization techniques for handling variable-length input sequences in Transformer models.
Padding Optimization
Our first optimization relates not to the attention kernel but to our padding mechanism. Rather than padding the sequences in each batch to a constant length, we pad to the length of the longest sequence in the batch. The following block of code consists of our revised collation function and updated experiments.
def collate_pad_to_longest(batch):
    padded_inputs = []
    padded_targets = []
    max_length = max([b[0].shape[0] for b in batch])
    for b in batch:
        padded_inputs.append(pad_sequence(b[0], max_length, PAD_ID))
        padded_targets.append(pad_sequence(b[1], max_length, PAD_ID))
    padded_inputs = torch.stack(padded_inputs, dim=0)
    padded_targets = torch.stack(padded_targets, dim=0)
    return {
        'inputs': padded_inputs,
        'targets': padded_targets
    }


for mode in ['eval', 'train']:
    for compile in [False, True]:
        block_func = causal_block_fn \
            if mode == 'train' else block_fn
        print(f'{mode} with pad-to-longest, '
              f'{"compiled" if compile else "uncompiled"}')
        main(block_fn=block_func,
             data_collate_fn=collate_pad_to_longest,
             pad_idx=PAD_ID,
             train=mode == 'train',
             compile=compile)
Padding to the longest sequence in each batch results in a slight performance improvement:
- Evaluation: 129 ms without torch.compile, 116 ms with torch.compile
- Training: 337 ms without torch.compile, 294 ms with torch.compile
SDPA with PyTorch NestedTensors
Next, we take advantage of the built-in support for PyTorch NestedTensors in SDPA in evaluation mode. Currently a prototype feature, PyTorch NestedTensors allows for grouping together tensors of varying length. These are sometimes referred to as jagged or ragged tensors. In the code block below, we define a collation function for grouping our sequences into NestedTensors. We also define an indices entry so that we can properly calculate the positional embeddings.
PyTorch NestedTensors are supported by a limited number of PyTorch ops. Working around these limitations can require some creativity. For example, addition between NestedTensors is only supported when they share exactly the same "jagged" shape. In the code below we use a workaround to ensure that the indices entry shares the same shape as the model inputs.
def nested_tensor_collate(batch):
    inputs = torch.nested.as_nested_tensor([b[0] for b in batch],
                                           layout=torch.jagged)
    targets = torch.nested.as_nested_tensor([b[1] for b in batch],
                                            layout=torch.jagged)
    indices = torch.concat([torch.arange(b[0].shape[0]) for b in batch])

    # workaround for creating a NestedTensor with the same "jagged" shape
    xx = torch.empty_like(inputs)
    xx.data._values[:] = indices

    return {
        'inputs': inputs,
        'targets': targets,
        'indices': xx
    }


for compile in [False, True]:
    print(f'eval with nested tensors, '
          f'{"compiled" if compile else "uncompiled"}')
    main(
        block_fn=block_fn,
        data_collate_fn=nested_tensor_collate,
        train=False,
        compile=compile
    )
Although, without torch.compile, the NestedTensor optimization results in a step time of 131 ms, similar to our baseline result, in compiled mode the step time drops to 42 ms for an impressive ~3x improvement.
FlashAttention2
In our previous post we demonstrated the use of FlashAttention and its impact on the performance of a Transformer model. In this post we demonstrate the use of flash_attn_varlen_func from flash-attn (2.7.0), an API designed for use with variable-sized inputs. To use this function, we concatenate all of the sequences in the batch into a single sequence. We also create a cu_seqlens tensor that points to the indices within the concatenated tensor where each of the individual sequences begins. The code block below includes our collation function followed by evaluation and training experiments. Note that flash_attn_varlen_func does not support torch.compile (at the time of this writing).
def collate_concat(batch):
    inputs = torch.concat([b[0] for b in batch]).unsqueeze(0)
    targets = torch.concat([b[1] for b in batch]).unsqueeze(0)
    indices = torch.concat([torch.arange(b[0].shape[0]) for b in batch])
    seqlens = torch.tensor([b[0].shape[0] for b in batch])
    seqlens = torch.cumsum(seqlens, dim=0, dtype=torch.int32)
    cu_seqlens = torch.nn.functional.pad(seqlens, (1, 0))

    return {
        'inputs': inputs,
        'targets': targets,
        'indices': indices,
        'attn_mask': cu_seqlens
    }


from flash_attn import flash_attn_varlen_func

fa_varlen = lambda q, k, v, attn_mask: flash_attn_varlen_func(
    q.squeeze(0),
    k.squeeze(0),
    v.squeeze(0),
    cu_seqlens_q=attn_mask,
    cu_seqlens_k=attn_mask,
    max_seqlen_q=MAX_SEQ_LEN,
    max_seqlen_k=MAX_SEQ_LEN
).unsqueeze(0)

fa_varlen_causal = lambda q, k, v, attn_mask: flash_attn_varlen_func(
    q.squeeze(0),
    k.squeeze(0),
    v.squeeze(0),
    cu_seqlens_q=attn_mask,
    cu_seqlens_k=attn_mask,
    max_seqlen_q=MAX_SEQ_LEN,
    max_seqlen_k=MAX_SEQ_LEN,
    causal=True
).unsqueeze(0)

block_fn = functools.partial(MyAttentionBlock,
                             attn_fn=fa_varlen,
                             format='bshd')

causal_block_fn = functools.partial(MyAttentionBlock,
                                    attn_fn=fa_varlen_causal,
                                    format='bshd')

print('flash-attn eval')
main(
    block_fn=block_fn,
    data_collate_fn=collate_concat,
    train=False
)

print('flash-attn train')
main(
    block_fn=causal_block_fn,
    data_collate_fn=collate_concat,
    train=True,
)
The impact of this optimization is dramatic: 51 ms for evaluation and 160 ms for training, amounting to 2.6x and 2.1x performance boosts compared to our baseline experiment.
xFormers Memory Efficient Attention
In our previous post we demonstrated the use of the memory_efficient_attention operator from xFormers (0.0.28). Here we demonstrate the use of BlockDiagonalMask, specifically designed for input sequences of arbitrary length. The required collation function appears in the code block below, followed by the evaluation and training experiments. Note that torch.compile failed in training mode.
from xformers.ops import fmha
from xformers.ops import memory_efficient_attention as mea


def collate_xformer(batch):
    inputs = torch.concat([b[0] for b in batch]).unsqueeze(0)
    targets = torch.concat([b[1] for b in batch]).unsqueeze(0)
    indices = torch.concat([torch.arange(b[0].shape[0]) for b in batch])
    seqlens = [b[0].shape[0] for b in batch]
    batch_sizes = [1 for b in batch]
    block_diag = fmha.BlockDiagonalMask.from_seqlens(seqlens, device='cpu')
    block_diag._batch_sizes = batch_sizes
    return {
        'inputs': inputs,
        'targets': targets,
        'indices': indices,
        'attn_mask': block_diag
    }


mea_eval = lambda q, k, v, attn_mask: mea(
    q, k, v, attn_bias=attn_mask)

mea_train = lambda q, k, v, attn_mask: mea(
    q, k, v, attn_bias=attn_mask.make_causal())

block_fn = functools.partial(MyAttentionBlock,
                             attn_fn=mea_eval,
                             format='bshd')

causal_block_fn = functools.partial(MyAttentionBlock,
                                    attn_fn=mea_train,
                                    format='bshd')

print('xFormer Attention')
for compile in [False, True]:
    print(f'eval with xFormer Attention, '
          f'{"compiled" if compile else "uncompiled"}')
    main(block_fn=block_fn,
         train=False,
         data_collate_fn=collate_xformer,
         compile=compile)

print('train with xFormer Attention')
main(block_fn=causal_block_fn,
     train=True,
     data_collate_fn=collate_xformer)
The resultant step times were 50 ms and 159 ms for evaluation and training without torch.compile. Evaluation with torch.compile resulted in a step time of 42 ms.
Results
The table below summarizes the results of our optimization methods (average step times in milliseconds; a dash indicates a configuration that was not run or was unsupported):

| Attention setting | Eval | Eval (compiled) | Train | Train (compiled) |
|---|---|---|---|---|
| SDPA, pad to MAX_SEQ_LEN | 132 | 130 | 342 | 299 |
| SDPA, pad to longest in batch | 129 | 116 | 337 | 294 |
| SDPA with NestedTensors | 131 | 42 | - | - |
| FlashAttention2 (flash_attn_varlen_func) | 51 | - | 160 | - |
| xFormers memory_efficient_attention | 50 | 42 | 159 | - |

The best performer for our toy model is xFormers' memory_efficient_attention, which delivered a ~3x performance boost for evaluation and a ~2x boost for training. We caution against drawing any conclusions from these results, as the performance impact of different attention functions can vary significantly depending on the specific model and use case.
The tools and techniques described above are easy to implement when creating a model from scratch. However, these days it is not uncommon for ML developers to adopt existing (pretrained) models and finetune them for their use case. While the optimizations we have described can be integrated without altering the set of model weights and without changing the model behavior, it is not entirely clear what the best way to do this is. In an ideal world, our ML framework would allow us to program the use of an attention mechanism that is optimized for variable-length inputs. In this section we demonstrate how to optimize HuggingFace models for variable-length inputs.
A Toy HuggingFace Model – GPT2LMHeadModel
To facilitate the discussion, we create a toy example in which we train a HuggingFace GPT2LMHeadModel on variable-length sequences. This requires adapting our random dataset and data-padding collation function in accordance with HuggingFace's input specifications.
from transformers import GPT2Config, GPT2LMHeadModel


# Use random data
class HuggingFaceFakeDataset(Dataset):
    def __len__(self):
        return 1000000

    def __getitem__(self, index):
        length = torch.randint(1, MAX_SEQ_LEN, (1,))
        input_ids = torch.randint(1, NUM_TOKENS, (length,))
        labels = input_ids.clone()
        labels[0] = PAD_ID  # ignore first token
        return {
            'input_ids': input_ids,
            'labels': labels
        }


def hf_collate_with_padding(batch):
    padded_inputs = []
    padded_labels = []
    for b in batch:
        input_ids = b['input_ids']
        labels = b['labels']
        padded_inputs.append(pad_sequence(input_ids, MAX_SEQ_LEN, PAD_ID))
        padded_labels.append(pad_sequence(labels, MAX_SEQ_LEN, PAD_ID))
    padded_inputs = torch.stack(padded_inputs, dim=0)
    padded_labels = torch.stack(padded_labels, dim=0)
    return {
        'input_ids': padded_inputs,
        'labels': padded_labels,
        'attention_mask': (padded_inputs != PAD_ID)
    }
Training Function
Our training function instantiates a GPT2LMHeadModel based on the requested GPT2Config and proceeds to train it on our variable-length sequences.
def hf_main(
        config,
        collate_fn=hf_collate_with_padding,
        compile=False
):
    torch.random.manual_seed(0)
    device = torch.device(DEVICE)
    torch.set_float32_matmul_precision("high")

    # Create dataset and dataloader
    data_set = HuggingFaceFakeDataset()
    data_loader = DataLoader(
        data_set,
        batch_size=BATCH_SIZE,
        collate_fn=collate_fn,
        num_workers=12 if DEVICE == "cuda" else 0,
        pin_memory=True,
        drop_last=True
    )

    model = GPT2LMHeadModel(config).to(device)

    if compile:
        model = torch.compile(model)

    # Define loss and optimizer
    criterion = torch.nn.CrossEntropyLoss(ignore_index=PAD_ID)
    optimizer = torch.optim.SGD(model.parameters())

    model.train()

    t0 = time.perf_counter()
    summ = 0
    count = 0
    for step, data in enumerate(data_loader):
        # Copy data to GPU
        data = data_to_device(data, device=device)
        input_ids = data['input_ids']
        labels = data['labels']
        position_ids = data.get('position_ids')
        attn_mask = data.get('attention_mask')
        with torch.amp.autocast(DEVICE, dtype=torch.bfloat16):
            outputs = model(input_ids=input_ids,
                            position_ids=position_ids,
                            attention_mask=attn_mask)
            logits = outputs.logits[..., :-1, :].contiguous()
            labels = labels[..., 1:].contiguous()
            loss = criterion(logits.view(-1, NUM_TOKENS), labels.flatten())

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        # Capture step time
        batch_time = time.perf_counter() - t0
        if step > 20:  # Skip first steps
            summ += batch_time
            count += 1
        t0 = time.perf_counter()
        if step >= 100:
            break

    print(f'average step time: {summ / count}')
SDPA with Padding
In the code block below, we call our training function with the default sequence-padding collator.
config = GPT2Config(
    n_layer=DEPTH,
    n_embd=DIM,
    n_head=NUM_HEADS,
    vocab_size=NUM_TOKENS,
)

for compile in [False, True]:
    print(f"HF GPT2 train with SDPA, compile={compile}")
    hf_main(config=config, compile=compile)
The resultant step times are 815 ms without torch.compile and 440 ms with torch.compile.
FlashAttention2
We now take advantage of HuggingFace's built-in support for FlashAttention2 by setting the attn_implementation parameter to "flash_attention_2". Behind the scenes, HuggingFace will unpad the padded data input and then pass it to the optimized flash_attn_varlen_func function we saw above:
flash_config = GPT2Config(
    n_layer=DEPTH,
    n_embd=DIM,
    n_head=NUM_HEADS,
    vocab_size=NUM_TOKENS,
    attn_implementation='flash_attention_2'
)

print(f"HF GPT2 train with flash")
hf_main(config=flash_config)
The resultant step time is 620 ms, amounting to a 30% boost (in uncompiled mode) with just a simple flick of a switch.
FlashAttention2 with Unpadded Input
Of course, padding the sequences in the collation function only to have them unpadded hardly seems sensible. In a recent update to HuggingFace, support was added for passing concatenated (unpadded) sequences to a select number of models. Unfortunately, (as of the time of this writing) our GPT2 model did not make the cut. However, adding support requires just five small one-line additions to modeling_gpt2.py in order to propagate the sequence position_ids to the flash-attention kernel. The full patch appears in the block below:
@@ -370,0 +371 @@
+ position_ids = None
@@ -444,0 +446 @@
+ position_ids=position_ids
@@ -611,0 +614 @@
+ position_ids=None
@@ -621,0 +625 @@
+ position_ids=position_ids
@@ -1140,0 +1145 @@
+ position_ids=position_ids
We define a collate function that concatenates our sequences and train our HuggingFace model on the unpadded sequences. (See also the built-in DataCollatorWithFlattening utility.)
def collate_flatten(batch):
    input_ids = torch.concat([b['input_ids'] for b in batch]).unsqueeze(0)
    labels = torch.concat([b['labels'] for b in batch]).unsqueeze(0)
    position_ids = [torch.arange(b['input_ids'].shape[0]) for b in batch]
    position_ids = torch.concat(position_ids)

    return {
        'input_ids': input_ids,
        'labels': labels,
        'position_ids': position_ids
    }


print(f"HF GPT2 train with flash, no padding")
hf_main(config=flash_config, collate_fn=collate_flatten)
The resulting step time is 323 ms, 90% faster than running flash-attention on the padded input.
Results
The results of our HuggingFace experiments are summarized in the table below (average step times in milliseconds):

| Experiment | Step time (ms) |
|---|---|
| SDPA with padded input (uncompiled) | 815 |
| SDPA with padded input (compiled) | 440 |
| FlashAttention2 with padded input | 620 |
| FlashAttention2 with unpadded (concatenated) input | 323 |

With little effort, we were able to boost our runtime performance by 2.5x when compared to the uncompiled baseline experiment, and by 36% when compared to the compiled version.
In this section, we demonstrated how the HuggingFace APIs allow us to leverage the optimized kernels in FlashAttention2, significantly boosting the training performance of existing models on sequences of varying length.
As AI models continue to grow in both popularity and complexity, optimizing their performance has become essential for reducing runtime and costs. This is especially true for compute-intensive components such as attention layers. In this post, we have continued our exploration of attention layer optimization and demonstrated new tools and techniques for improving Transformer model performance. For more insights on AI model optimization, be sure to check out the first post in this series as well as our many other posts on this topic.