

[Paper] HTS-AT: A Hierarchical Token-Semantic Audio Transformer for Sound Classification and Detection

Remaining Problems in AST

 

1. AST takes a long time to train and consumes a large amount of GPU memory

- because the transformer takes the audio spectrogram as complete sequential data

- pretraining takes too long, but without it AST can only achieve baseline performance on the AudioSet dataset, which raises questions about its learning efficiency on audio data

 

2. AST uses a class-token (CLS) to predict labels

- unable to predict the start and end times of events in audio samples

 

HTS-AT

a hierarchical audio transformer with a token-semantic module for audio classification

 

- achieves or equals state-of-the-art (SOTA) results on the AudioSet, ESC-50, and Speech Command V2 datasets

- without pretraining, the model can still achieve performance only 1~2% lower than the best results

- takes fewer parameters (31M vs 87M), less GPU memory, and less training time (80 hrs vs 600 hrs) than AST to achieve its best performance

- enables the audio transformer to produce event localization results with only weakly-labeled data, achieving better performance than previous CNN-based models

 

1. Hierarchical Transformer with Window Attention

 

Model Architecture

 

Encode the Audio Spectrogram

 

- an audio mel-spectrogram is cut into different patch tokens with a Patch-Embed CNN of kernel size (P × P) and sent into the transformer

- different from images, the width and height of an audio mel-spectrogram denote different information (i.e., the time frames and the frequency bins)

- the length of the time axis is usually much longer than that of the frequency axis

→ to better capture the relationship among frequency bins of the same time frame, first split the mel-spectrogram into patch windows w1, w2, ..., wn and then split the patches inside each window

- the order of tokens follows time → frequency → window

- with this order, patches with different frequency bins at the same time frame are organized adjacently in the input sequence (a minimal sketch of this step follows below)
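
To make the splitting and token ordering concrete, here is a minimal PyTorch sketch of the patch-embedding step. The patch size, number of time windows, and embedding dimension are illustrative assumptions, not the paper's exact settings.

```python
import torch
import torch.nn as nn

class PatchEmbed(nn.Module):
    """Sketch of the Patch-Embed CNN plus window-wise token ordering.
    patch_size, num_windows, and embed_dim are illustrative assumptions."""
    def __init__(self, patch_size=4, in_chans=1, embed_dim=96, num_windows=4):
        super().__init__()
        self.num_windows = num_windows
        # Patch-Embed CNN: convolution with kernel size and stride (P x P)
        self.proj = nn.Conv2d(in_chans, embed_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, spec):
        # spec: (B, 1, T, F) mel-spectrogram, T time frames x F frequency bins
        x = self.proj(spec)                # (B, D, T/P, F/P)
        B, D, t, f = x.shape
        w = self.num_windows
        # split the time axis into contiguous patch windows w1..wn, then flatten
        # with frequency as the fastest-varying index, so patches of different
        # frequency bins at the same time frame sit next to each other
        x = x.reshape(B, D, w, t // w, f)  # (B, D, window, time, freq)
        x = x.permute(0, 2, 3, 4, 1)       # (B, window, time, freq, D)
        return x.reshape(B, t * f, D)      # (B, N, D) token sequence
```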

 

Patch-Merge and Window Attention

 

- the patch tokens are sent into several groups of transformer-encoder blocks

- at the end of each group, a Patch-Merge layer is implemented to reduce the sequence size

- as illustrated, the shape of the patch tokens is reduced by a factor of 8 after 4 network groups (a Patch-Merge sketch follows below)
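
Below is a minimal PyTorch sketch of a Patch-Merge layer in the Swin style, assumed here to follow the usual pattern of concatenating each 2×2 neighborhood of tokens and linearly projecting the channels; three such merges halve each spatial side three times, consistent with the 8× reduction mentioned above.

```python
import torch
import torch.nn as nn

class PatchMerge(nn.Module):
    """Sketch of a Swin-style Patch-Merge layer: concatenate each 2x2
    neighborhood of tokens and project 4D channels down to 2D channels."""
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(4 * dim)
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)

    def forward(self, x):                      # x: (B, H, W, D) token grid
        x = torch.cat([x[:, 0::2, 0::2],       # top-left token of each 2x2 block
                       x[:, 1::2, 0::2],       # bottom-left
                       x[:, 0::2, 1::2],       # top-right
                       x[:, 1::2, 1::2]], -1)  # bottom-right -> (B, H/2, W/2, 4D)
        return self.reduction(self.norm(x))    # (B, H/2, W/2, 2D)
```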

 

- for each transformer block inside a group, a window attention mechanism is adopted to reduce the computation

- first split the patch tokens (in 2D format) into non-overlapping (M x M) attention windows aw1, aw2, ..., awk

- then only compute the attention matrix inside each M x M attention window

- as a result, there are only k window attention (WA) matrices instead of a whole global attention (GA) matrix

- as the network goes deeper, the Patch-Merge layer will merge adjacent windows, so the attention relation is calculated in a larger space

- in the code implementation, the Swin Transformer block with shifted window attention, a more efficient window attention mechanism, is used → this also makes it possible to use a Swin Transformer vision model pretrained on images in the experiment stage (a simplified window-attention sketch follows below)
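
Here is a simplified PyTorch sketch of plain (non-shifted) window attention: the token grid is partitioned into M × M windows and attention is computed only inside each window. The shifted windows, relative position bias, and MLP of the actual Swin Transformer block are omitted, and the layer sizes are assumptions for illustration.

```python
import torch
import torch.nn as nn

def window_partition(x, M):
    """Split a 2D token grid (B, H, W, D) into non-overlapping M x M windows."""
    B, H, W, D = x.shape
    x = x.reshape(B, H // M, M, W // M, M, D)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, M * M, D)   # (k*B, M*M, D)

class WindowAttentionBlock(nn.Module):
    """Attention restricted to each M x M window: k small attention matrices
    instead of one global attention matrix."""
    def __init__(self, dim, num_heads=4, window_size=8):
        super().__init__()
        self.M = window_size
        self.norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, x):                # x: (B, H, W, D), H and W divisible by M
        B, H, W, D = x.shape
        windows = window_partition(self.norm(x), self.M)         # (k*B, M*M, D)
        out, _ = self.attn(windows, windows, windows)            # per-window attention
        out = out.reshape(B, H // self.M, W // self.M, self.M, self.M, D)
        out = out.permute(0, 1, 3, 2, 4, 5).reshape(B, H, W, D)  # merge windows back
        return x + out                                           # residual connection
```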

 

2. Token Semantic Module

 

AST uses a class-token (CLS) to predict the classification label, which prevents it from further indicating the start and end times of events, as realized in CNN-based models

In the final layer output, each token contains information about its corresponding time frames and frequency bins

→ convert tokens into activation maps for each label class

 

For strongly-labeled datasets, let the model directly calculate the loss in specific time ranges

For weakly-labeled datasets, leverage the transformer's strong capability to capture relations among tokens to locate events

 

In HTS-AT, the output structure is modified by adding a token-semantic CNN layer after the final transformer block.

It has a kernel size of (3, F/8P) and a padding size of (1, 0) to integrate all frequency bins and map the channel size 8D into the event classes C

→ the output (T/8P, C) is regarded as an event presence map

(where T is the number of time frames of the mel-spectrogram and F the number of frequency bins)
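
A minimal PyTorch sketch of such a token-semantic head is shown below, assuming the final-block tokens have been reshaped back into a 2D (time × frequency) grid with 8D channels; the argument names are illustrative.

```python
import torch
import torch.nn as nn

class TokenSemanticHead(nn.Module):
    """Sketch of the token-semantic CNN layer: kernel (3, F/8P), padding (1, 0),
    mapping 8D channels to C event classes."""
    def __init__(self, in_channels, freq_tokens, num_classes):
        super().__init__()
        # the kernel spans all F/8P frequency tokens, so each output position
        # integrates every frequency bin of a small neighborhood of time frames
        self.conv = nn.Conv2d(in_channels, num_classes,
                              kernel_size=(3, freq_tokens), padding=(1, 0))

    def forward(self, tokens_2d):
        # tokens_2d: (B, 8D, T/8P, F/8P) output of the final transformer block
        x = self.conv(tokens_2d)                 # (B, C, T/8P, 1)
        return x.squeeze(-1).transpose(1, 2)     # (B, T/8P, C) event presence map
```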

 

Finally, the feature map is averaged into a final vector (1, C) to compute the binary cross-entropy loss against the ground-truth labels
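
A short sketch of this clip-level pooling and loss, assuming the presence map holds logits that are averaged over time before a BCE-with-logits loss (the exact placement of the sigmoid may differ in the authors' code):

```python
import torch
import torch.nn.functional as F

# presence_map: (B, T/8P, C) logits from the token-semantic head sketched above
presence_map = torch.randn(2, 32, 527)           # dummy values, illustrative shapes
targets = torch.randint(0, 2, (2, 527)).float()  # (B, C) multi-hot clip labels

clip_logits = presence_map.mean(dim=1)           # average over time -> (B, C)
loss = F.binary_cross_entropy_with_logits(clip_logits, targets)
```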

 

Overall, HTS-AT combines the power of transformers for capturing audio semantics with the localization ability of CNNs to address audio classification tasks that require both event identification and temporal localization