Remaining Problems in AST
1. AST takes a long time to train and consumes a large amount of GPU memory
- because the transformer takes the audio spectrogram as one complete, long sequence of patch tokens
- pretraining takes too long, but without it AST can only achieve baseline performance on the AudioSet dataset, which raises questions about its learning efficiency on audio data
2. AST uses a class-token (CLS) to predict labels
- unable to predict the start and end times of events in audio samples
HTS-AT
a hierarchical audio transformer with a token-semantic module for audio classification
- achieves or equals SOTA results on the AudioSet, ESC-50, and Speech Command V2 datasets
- even without pretraining, the model can still achieve performance only 1~2% lower than the best results
- takes fewer parameters (31M vs 87M), less GPU memory, and less training time (80 hrs vs 600 hrs) than AST to achieve the best performance
- enables the audio transformer to produce event localization results with only weakly-labeled data, achieving better performance than previous CNN-based models
1. Hierarchical Transformer with Window Attention
Encode the Audio Spectrogram
- an audio mel-spectrogram is cut into different patch tokens with a Patch-Embed CNN of kernel size (P × P) and sent into the transformer
- different from images, the width and the height of an audio mel-spectrogram denote different information (i.e. the time and the frequency bin)
- the length of the time axis is usually much longer than that of the frequency bins
→ to better capture the relationship among frequency bins of the same time frame, first split the mel-spectrogram into patch windows w1, w2, ..., wn and then split the patches inside each window
- the order of tokens follows time → frequency → window
- with this order, patches with different frequency bins at the same time frame will be organized adjacently in the input sequence (see the sketch below)
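To make the token ordering concrete, here is a minimal PyTorch sketch; the class name `PatchEmbed`, the `window_len` argument, and the default sizes are illustrative assumptions, not the paper's code:

```python
import torch
import torch.nn as nn

class PatchEmbed(nn.Module):
    """Illustrative Patch-Embed: a CNN with kernel and stride P turns a
    (B, 1, T, F) mel-spectrogram into D-dimensional patch tokens."""
    def __init__(self, patch_size=4, embed_dim=96):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv2d(1, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, mel, window_len=64):
        tokens = []
        # first split the spectrogram along time into patch windows w1..wn
        for w in mel.split(window_len, dim=2):
            x = self.proj(w)                # (B, D, t, f) patch grid of this window
            x = x.permute(0, 2, 3, 1)       # (B, t, f, D): time slow, frequency fast
            tokens.append(x.flatten(1, 2))  # same-time-frame patches are now adjacent
        return torch.cat(tokens, dim=1)     # token order: window -> time -> frequency

# e.g. a 256-frame, 64-bin spectrogram -> 4 windows x 256 tokens of dim 96
emb = PatchEmbed()(torch.randn(2, 1, 256, 64))  # (2, 1024, 96)
```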
Patch-Merge and Window Attention
- the patch tokens are sent into several groups of transformer-encoder blocks
- at the end of each group, a Patch-Merge layer is implemented to reduce the sequence size
- as illustrated, the shape of the patch tokens is reduced by 8 times after 4 network groups
- for each transformer block inside a group, a window attention mechanism is adopted to reduce computation
- first split the patch tokens (in 2D format) into non-overlapping (M x M) attention windows aw1, aw2, ..., awk
- then only compute the attention matrix inside each M x M attention window
- as a result, there are only k window attention (WA) matrices instead of one whole global attention (GA) matrix
- as the network goes deeper, the Patch-Merge layer will merge adjacent windows, so the attention relation is calculated over a larger space
- in the code implementation, the Swin transformer block with shifted window attention, a more efficient window attention mechanism, is used → this also makes it possible to use a Swin-transformer-pretrained vision model in the experiment stage (both pieces are sketched below)
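Below is a minimal sketch of the two pieces above, window partitioning and a Swin-style Patch-Merge; the helper names are my own, and the real implementation uses the Swin transformer's shifted-window blocks:

```python
import torch
import torch.nn as nn

def window_partition(x, M):
    """Split a (B, H, W, D) patch-token map into non-overlapping M x M
    attention windows; attention is then computed only within each window."""
    B, H, W, D = x.shape
    x = x.view(B, H // M, M, W // M, M, D)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, M * M, D)

class PatchMerge(nn.Module):
    """Swin-style Patch-Merge: concatenate each 2x2 group of neighboring
    tokens and project 4D -> 2D, halving both sides of the token map."""
    def __init__(self, dim):
        super().__init__()
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)

    def forward(self, x):                           # x: (B, H, W, D)
        x = torch.cat([x[:, 0::2, 0::2], x[:, 1::2, 0::2],
                       x[:, 0::2, 1::2], x[:, 1::2, 1::2]], dim=-1)
        return self.reduction(x)                    # (B, H/2, W/2, 2D)

# the k windows land in the batch dimension, so one batched attention call
# computes the k small WA matrices instead of one global attention matrix
windows = window_partition(torch.randn(2, 16, 16, 96), M=8)  # (8, 64, 96)
attn = nn.MultiheadAttention(96, num_heads=4, batch_first=True)
out, _ = attn(windows, windows, windows)
```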
2. Token Semantic Module
AST uses a class-token (CLS) to predict the classification label, which keeps it from further indicating the start and end times of events, as CNN-based models can
In the final layer output, each token contains information about its corresponding time frames and frequency bins
→ convert tokens into activation maps for each label class
For strongly-labeled datasets, let the model directly calculate the loss over specific time ranges
For weakly-labeled datasets, leverage the transformer's strong capability to capture relations among tokens to localize events
In HTS-AT, the output structure is modified by adding a token-semantic CNN layer after the final transformer block.
It has a kernel size (3, F/8P) and a padding size (1, 0) to integrate all frequency bins and map the channel size 8D into the event classes C
→ the output (T/8P, C) is regarded as an event presence map
(where T represents the number of time frames of the audio spectrogram)
Finally, average the feature map into a final vector (1, C) to compute the binary cross-entropy loss with the ground-truth labels (see the sketch below)
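Putting the head together, here is a hedged PyTorch sketch; the class and argument names are illustrative (e.g. 527 is AudioSet's class count), not the paper's code. It assumes the final-block tokens have already been reshaped back into a 2D map:

```python
import torch
import torch.nn as nn

class TokenSemanticHead(nn.Module):
    """Sketch of the token-semantic CNN: input is the final-block token map
    reshaped to (B, 8D, T/8P, F/8P); output is an event presence map over
    time plus a clip-level prediction for the BCE loss."""
    def __init__(self, dim, freq_bins, num_classes):
        super().__init__()
        # kernel (3, F/8P) with padding (1, 0): spans all frequency bins
        # and maps the 8D channels to the C event classes
        self.conv = nn.Conv2d(dim, num_classes,
                              kernel_size=(3, freq_bins), padding=(1, 0))

    def forward(self, x):                             # x: (B, 8D, T/8P, F/8P)
        presence = self.conv(x).squeeze(-1)           # (B, C, T/8P) presence map
        clip = torch.sigmoid(presence.mean(dim=2))    # averaged over time: (B, C)
        return presence, clip

head = TokenSemanticHead(dim=768, freq_bins=8, num_classes=527)
presence, clip = head(torch.randn(2, 768, 32, 8))     # presence map and (2, 527)
# weakly-labeled training only needs the clip-level BCE against 0/1 labels
loss = nn.functional.binary_cross_entropy(clip, torch.rand(2, 527).round())
```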
Overall, HTS-AT combines the power of transformers for capturing audio semantics with the localization ability of CNNs to address audio classification tasks that require both event identification and temporal localization