Multi-label text classification



  1. The Problem

  2. Binary Cross Entropy Loss

  3. Solution

    1. Process and Prepare Text
    2. Save Data
    3. Train with BCE loss
    4. Evaluate and Understand Misclassifications
  4. Links

The Problem

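A multi-label classifier maps an input x to any subset of C possible labels, so a single input can receive several labels at once. It is like solving C separate binary classification problems, one per label. With C = 1 we are back in ordinary binary classification territory.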
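A minimal sketch of the "one binary problem per label" view, assuming a model that outputs one logit per class (the function name, class names and the 0.5 threshold are illustrative). Each class gets its own sigmoid and its own yes/no decision, so several labels can fire at once and the probabilities need not sum to 1, unlike a softmax over mutually exclusive classes.

```python
import math

def predict_labels(logits, classes, threshold=0.5):
    """Apply an independent sigmoid to each class logit and keep every class at or above the threshold."""
    probs = [1.0 / (1.0 + math.exp(-z)) for z in logits]
    labels = [c for c, p in zip(classes, probs) if p >= threshold]
    return labels, probs

# Hypothetical logits for one input and three classes.
labels, probs = predict_labels([2.0, -1.0, 0.7], ["sports", "politics", "tech"])
print(labels)  # ['sports', 'tech']
```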

Let’s assume we have a dataset in which each row contains a text and a label field holding a comma-separated list of labels. Our assignment is to build a multi-label classifier on this dataset.

To build a multi-label classifier, we first need to understand binary cross-entropy with logits loss (BCEWithLogitsLoss in PyTorch).

Binary Cross Entropy Loss

This loss takes a multi-hot encoded label array as its target, with one row per sample and one column per class. An entry is 1 if the class applies to the sample and 0 otherwise, and a sample can have several entries set to 1. See this notebook for a practical example.
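A minimal sketch of turning the comma-separated label column into this multi-hot array in plain Python. `multi_hot` and the example labels are hypothetical; in practice scikit-learn's `MultiLabelBinarizer` can do the same job once the strings are split into label lists.

```python
def multi_hot(label_strings, classes=None):
    """Turn comma-separated label strings into (classes, multi-hot matrix as a list of rows)."""
    parsed = [[lab.strip() for lab in s.split(",") if lab.strip()] for s in label_strings]
    if classes is None:
        # Derive a stable, sorted class order from the data itself.
        classes = sorted({lab for labs in parsed for lab in labs})
    index = {c: i for i, c in enumerate(classes)}
    matrix = []
    for labs in parsed:
        row = [0] * len(classes)
        for lab in labs:
            row[index[lab]] = 1  # raises KeyError for a label missing from `classes`
        matrix.append(row)
    return classes, matrix

classes, Y = multi_hot(["sports,politics", "tech", "politics"])
print(classes)  # ['politics', 'sports', 'tech']
print(Y)        # [[1, 1, 0], [0, 0, 1], [1, 0, 0]]
```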

The loss is defined as:

$$
\ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad l_n = - w_n \left[ y_n \cdot \log \sigma(x_n) + (1 - y_n) \cdot \log (1 - \sigma(x_n)) \right]
$$

The element-wise losses are then reduced to a single value:

$$
\ell(x, y) = \begin{cases} \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} \end{cases}
$$

The default reduction method in PyTorch is ‘mean’.
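A plain-Python sketch of the unweighted formula and the two reductions, translated directly from the equations (it is not PyTorch's implementation, and the helper names are hypothetical). For a multi-label target the loss is an N × C matrix, and 'mean' averages over all N × C entries. This direct translation can overflow or hit log(0) for very large-magnitude logits.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def bce_elementwise(logits, targets):
    """Unweighted l_{n,c} for every (sample, class) entry, straight from the formula."""
    return [
        [-(y * math.log(sigmoid(x)) + (1 - y) * math.log(1 - sigmoid(x)))
         for x, y in zip(row_x, row_y)]
        for row_x, row_y in zip(logits, targets)
    ]

def reduce_loss(losses, reduction="mean"):
    """Collapse the N x C loss matrix to one number, mirroring reduction='mean' or 'sum'."""
    flat = [value for row in losses for value in row]
    if reduction == "mean":
        return sum(flat) / len(flat)  # averages over all N * C entries
    if reduction == "sum":
        return sum(flat)
    raise ValueError(f"unsupported reduction: {reduction}")

logits = [[2.0, -1.0, 0.0], [-0.5, 1.5, 3.0]]  # N = 2 samples, C = 3 classes
targets = [[1, 0, 1], [0, 1, 1]]               # multi-hot labels
L = bce_elementwise(logits, targets)
print(reduce_loss(L, "mean"), reduce_loss(L, "sum"))
```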

Weighted case:

$$
\ell_c(x, y) = L_c = \{l_{1,c},\dots,l_{N,c}\}^\top, \quad l_{n,c} = - w_{n,c} \left[ p_c y_{n,c} \cdot \log \sigma(x_{n,c}) + (1 - y_{n,c}) \cdot \log (1 - \sigma(x_{n,c})) \right]
$$

There are two weights here. $w_{n,c}$ is a sample weight that rescales each sample’s loss, but per-sample weights are hard to obtain in practice. A more practical choice is the positive class weight $p_c$, assigned to each class. Setting $p_c$ per class lets us give greater importance to a minority class. In general, $p_c > 1$ increases recall and $p_c < 1$ increases precision.
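To see why $p_c$ shifts the precision/recall balance, here is a sketch of one entry of the weighted loss in plain Python (the helper name is hypothetical). $p_c$ only multiplies the positive-target term, so raising it makes missed positives more expensive while leaving negatives untouched. The model is pushed to predict the class more often, which tends to raise recall at some cost to precision.

```python
import math

def weighted_bce(x, y, pos_weight=1.0, sample_weight=1.0):
    """One entry: -w * [p * y * log(s) + (1 - y) * log(1 - s)], with s = sigmoid(x)."""
    s = 1.0 / (1.0 + math.exp(-x))
    return -sample_weight * (pos_weight * y * math.log(s) + (1 - y) * math.log(1 - s))

# A positive example the model currently scores low (logit -1.0):
base = weighted_bce(-1.0, 1)                     # p_c = 1
boosted = weighted_bce(-1.0, 1, pos_weight=3.0)  # p_c = 3 triples the penalty
# A negative example is unaffected by p_c:
neg_same = weighted_bce(1.0, 0, pos_weight=3.0) == weighted_bce(1.0, 0)
print(base, boosted, neg_same)
```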

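Computationally, there is also a detail worth knowing: BCEWithLogitsLoss takes raw logits rather than probabilities because combining the sigmoid and the log into one operation lets it use the log-sum-exp trick, which prevents numerical overflow when converting between probabilities and log probabilities. The details are for another post.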
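A sketch of the stable form, derived by hand from the weighted formula with $w_{n,c} = 1$ rather than copied from PyTorch's source. Using $\log \sigma(x) = -\operatorname{softplus}(-x)$ and $\log(1 - \sigma(x)) = -x - \operatorname{softplus}(-x)$, the per-entry loss becomes $(1 - y)\,x + (1 + (p_c - 1)\,y)\,\operatorname{softplus}(-x)$, and softplus is evaluated with the log-sum-exp rearrangement $\max(z, 0) + \log(1 + e^{-|z|})$ so nothing overflows. The function names are hypothetical.

```python
import math

def softplus(z):
    """log(1 + exp(z)) = logsumexp(0, z), computed as max(z, 0) + log1p(exp(-|z|))."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))

def stable_weighted_bce(x, y, pos_weight=1.0):
    # log sigmoid(x) = -softplus(-x) and log(1 - sigmoid(x)) = -(x + softplus(-x)),
    # so -[p*y*log s + (1-y)*log(1-s)] = (1 - y)*x + (1 + (p - 1)*y)*softplus(-x).
    return (1 - y) * x + (1 + (pos_weight - 1) * y) * softplus(-x)

# Extreme logits that would overflow a naive sigmoid-then-log version:
print(stable_weighted_bce(1000.0, 0))                   # 1000.0
print(stable_weighted_bce(-1000.0, 1, pos_weight=2.0))  # 2000.0
```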


Solution

Process and Prepare Text

In this step we process the text, including any preprocessing and tokenization, and add a column to the original data indicating each row’s data split. Remember to stratify the splits, especially when the data is imbalanced, so that the model trains on examples from even the minority classes.
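Proper multi-label stratification (for example iterative stratification) is more involved than single-label stratification. The sketch below uses a simpler heuristic, assumed here for illustration: each row is grouped by its rarest label and every group is split proportionally, writing a `split` column as the paragraph describes. `add_split_column`, the `labels`/`split` keys and the 10% fractions are hypothetical. Very small groups round down to zero evaluation rows, so rare classes stay in training.

```python
import random
from collections import Counter, defaultdict

def add_split_column(rows, label_key="labels", val_frac=0.1, test_frac=0.1, seed=0):
    """Heuristic multi-label stratification: stratify each row by its rarest label."""
    counts = Counter(lab for r in rows for lab in r[label_key])
    groups = defaultdict(list)
    for i, r in enumerate(rows):
        labs = r[label_key]
        # Rarest label (ties broken alphabetically); unlabeled rows form their own group.
        key = min(labs, key=lambda lab: (counts[lab], lab)) if labs else "<none>"
        groups[key].append(i)
    rng = random.Random(seed)
    for idxs in groups.values():
        rng.shuffle(idxs)
        n_test = round(len(idxs) * test_frac)
        n_val = round(len(idxs) * val_frac)
        for j, i in enumerate(idxs):
            if j < n_test:
                rows[i]["split"] = "test"
            elif j < n_test + n_val:
                rows[i]["split"] = "val"
            else:
                rows[i]["split"] = "train"
    return rows

# 100 rows: every row is "common", every tenth row is also "rare".
rows = [{"text": f"doc {i}", "labels": ["common"] + (["rare"] if i % 10 == 0 else [])}
        for i in range(100)]
add_split_column(rows)
print(Counter(r["split"] for r in rows))
```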

Here you might want to think about:

  • What is my problem domain, and what preprocessing does it call for? For example, I may want to clean up URLs and remove digits and punctuation.
  • What kind of text representation do I want to use? This might be dependent on your latency budgets, importance of case sensitivity for your domain, the training budget, length of text, languages involved etc.
  • What’s the quality of my data? Are my multi-labels well propagated? Are there missing multi-labels?
  • To prevent overfitting, do I need strong negative samples that have token overlap with my positive samples?
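As an example of the preprocessing mentioned in the first question, here is a minimal cleaning function that replaces URLs and digits with spaces, strips ASCII punctuation and collapses whitespace. The regular expressions and the name `clean_text` are assumptions for illustration. Lowercasing is left to the uncased tokenizer, and stripping punctuation also turns "don't" into "dont", which may or may not suit your domain.

```python
import re
import string

URL_RE = re.compile(r"https?://\S+|www\.\S+")
DIGIT_RE = re.compile(r"\d+")
PUNCT_TABLE = str.maketrans("", "", string.punctuation)  # ASCII punctuation only

def clean_text(text):
    """Remove URLs, digits and punctuation, then normalize whitespace."""
    text = URL_RE.sub(" ", text)
    text = DIGIT_RE.sub(" ", text)
    text = text.translate(PUNCT_TABLE)
    return " ".join(text.split())

print(clean_text("Check https://example.com now!! Price: 42 dollars."))
# Check now Price dollars
```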

We can start off with a DistilBERT uncased transformer model.

Save Data

Once the preprocessing and tokenization are written, we save the data to S3 in a format that works with the SageMaker Hugging Face Estimator.
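The post does not specify the on-disk format, so this sketch assumes one JSON Lines file per split, which a training script can read back easily (for example with the Hugging Face `datasets` JSON loader). `save_splits` and the `text`/`labels`/`split` keys are hypothetical. Uploading the directory to S3 (for example with the AWS CLI or boto3) and passing its S3 URI to the estimator’s `fit` call are not shown.

```python
import json
import tempfile
from pathlib import Path

def save_splits(rows, out_dir):
    """Write one JSON Lines file per split value (e.g. train.jsonl, test.jsonl)."""
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    by_split = {}
    for row in rows:
        record = {k: v for k, v in row.items() if k != "split"}  # drop the split column
        by_split.setdefault(row["split"], []).append(record)
    for split, records in by_split.items():
        with open(out_dir / f"{split}.jsonl", "w", encoding="utf-8") as f:
            for record in records:
                f.write(json.dumps(record, ensure_ascii=False) + "\n")
    return sorted(by_split)

rows = [
    {"text": "team wins the final", "labels": ["sports"], "split": "train"},
    {"text": "new tax bill passes", "labels": ["politics"], "split": "train"},
    {"text": "chip maker lobbies congress", "labels": ["tech", "politics"], "split": "test"},
]
with tempfile.TemporaryDirectory() as tmp:
    splits = save_splits(rows, tmp)
    train_lines = (Path(tmp) / "train.jsonl").read_text(encoding="utf-8").splitlines()
print(splits, len(train_lines))
```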

Train with BCE loss

After deriving class weights, we fine-tune DistilBERT for 2-3 epochs. In cases of extreme imbalance, setting a class weight too small (e.g. 0.2) generally does not help the model learn from that class’s samples, so I recommend keeping every class weight above 0.5. If the weight calculation gives one class a very high weight, it may be worth reducing it to bring it closer in range to the other weights. A few rounds of experimentation can help select class weights that perform well.
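The post does not say how the class weights are derived. One common starting point, assumed here, is the ratio the PyTorch documentation suggests for `pos_weight`: the number of negatives divided by the number of positives for each class. The sketch then clips the result, using the 0.5 floor recommended above and a hypothetical cap of 10 so that a very rare class does not get an outsized weight. `derive_pos_weights` is a hypothetical helper.

```python
def derive_pos_weights(Y, min_weight=0.5, max_weight=10.0):
    """Per-class pos_weight = negatives / positives, clipped to [min_weight, max_weight].

    Y is a multi-hot matrix given as a list of rows, one column per class.
    """
    n_samples = len(Y)
    weights = []
    for c in range(len(Y[0])):
        pos = sum(row[c] for row in Y)
        neg = n_samples - pos
        raw = neg / pos if pos else max_weight  # no positives at all: fall back to the cap
        weights.append(min(max(raw, min_weight), max_weight))
    return weights

Y = [[1, 1], [0, 1], [0, 1], [0, 1]]
print(derive_pos_weights(Y))  # [3.0, 0.5]
```

The returned list can be passed to PyTorch as `torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(weights))`. Both bounds are natural knobs for the experimentation loop described above.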

Evaluate and Understand Misclassifications

Finally, we evaluate the model and investigate its misclassifications using integrated gradients attribution scores. We focus on tokens the model has overfit to and make sure the training data contains these tokens across classes, for example as negative samples in other classes.
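A toy sketch of both steps in plain Python, standing in for the real model. `misclassified` lists the (sample, class, predicted, actual) entries whose thresholded sigmoid disagrees with the multi-hot target. `integrated_gradients` computes integrated gradients attributions for a hypothetical linear-logistic model $F(x) = \sigma(w \cdot x + b)$ with a midpoint Riemann sum from an all-zeros baseline. For DistilBERT the same idea is applied to token embeddings (libraries such as Captum implement it); the toy version only makes the arithmetic visible.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def misclassified(logits, targets, threshold=0.5):
    """Return (sample, class, predicted, actual) for every wrong thresholded prediction."""
    errors = []
    for n, (row_x, row_y) in enumerate(zip(logits, targets)):
        for c, (x, y) in enumerate(zip(row_x, row_y)):
            pred = int(sigmoid(x) >= threshold)
            if pred != y:
                errors.append((n, c, pred, y))
    return errors

def integrated_gradients(w, b, x, baseline=None, steps=200):
    """IG for F(x) = sigmoid(w . x + b): (x_i - x'_i) * average of dF/dx_i along the path."""
    if baseline is None:
        baseline = [0.0] * len(x)
    avg_grad = [0.0] * len(x)
    for k in range(steps):
        alpha = (k + 0.5) / steps  # midpoint of the k-th interval
        point = [x0 + alpha * (xi - x0) for xi, x0 in zip(x, baseline)]
        s = sigmoid(sum(wi * pi for wi, pi in zip(w, point)) + b)
        for i, wi in enumerate(w):
            avg_grad[i] += s * (1 - s) * wi / steps  # dF/dx_i = sigmoid'(z) * w_i
    return [(xi - x0) * g for xi, x0, g in zip(x, baseline, avg_grad)]

print(misclassified([[2.0, -1.0], [0.5, -3.0]], [[1, 1], [0, 0]]))
attr = integrated_gradients([2.0, -1.0, 0.5], -0.3, [1.0, 1.0, 2.0])
print(attr, sum(attr))
```

A useful sanity check is completeness: the attributions should sum to approximately $F(x) - F(x')$, the change in the model’s output between the baseline and the input.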