DistilBERT を SST-2 のデータセットで fine-tuning

1. 概要

2025年6月22日(日)に開催予定の第4回 岐阜AI勉強会の準備の一環で、DistilBERT を SST-2 のデータセットで fine-tuning する Google Colab 用のノートブックを作成しました。
しばらくこちらのブログを更新していなかったので、作成したノートブックについてのブログページを用意することにしました。

2. DistilBERT を SST-2 のデータセットで fine-tuning する Google Colab のノートブックについて

DistilBERT を SST-2 のデータセットで fine-tuning する Google Colab のノートブックをこちらのリンク
に用意しました。

2.1. SST-2 のデータセットの読み込み

下記の Python Script は SST-2 の train, validation データセットを pandas のライブラリで読み込みます。
SST-2 のデータセットは映画レビューの文章が positive か negative かに分類したデータセットです。

import pandas as pd

splits = {'train': 'data/train-00000-of-00001.parquet', 'validation': 'data/validation-00000-of-00001.parquet', 'test': 'data/test-00000-of-00001.parquet'}
df_train = pd.read_parquet("hf://datasets/stanfordnlp/sst2/" + splits["train"])
df_validation = pd.read_parquet("hf://datasets/stanfordnlp/sst2/" + splits["validation"])
2.1.1. SST-2 の train データセット

print 文で train データセットの中身を確認すると下記のようになっています。

print(df_train)

label が 0 の行は映画レビューの文章 (sentence) が positive なデータ、label が 1 の行は映画レビューの文章が negative なデータです。67,349 のデータがあります。下記の print 文の出力例では、長い映画レビューの文章は途中までしか表示されていません。

         idx                                           sentence  label
0          0       hide new secretions from the parental units       0
1          1               contains no wit , only labored gags       0
2          2  that loves its characters and communicates som...      1
3          3  remains utterly satisfied to remain the same t...      0
4          4  on the worst revenge-of-the-nerds clichés the ...      0
...      ...                                                ...    ...
67344  67344                               a delightful comedy       1
67345  67345                   anguish , anger and frustration       0
67346  67346  at achieving the modest , crowd-pleasing goals...      1
67347  67347                                  a patient viewer       1
67348  67348  this new jangle of noise , mayhem and stupidit...      0

[67349 rows x 3 columns]
2.1.2. SST-2 の validation データセット

validation データセットの中身も train データセットと同様です。872 のデータがあります。

print(df_validation)
     idx                                           sentence  label
0      0    it 's a charming and often affecting journey .       1
1      1                 unflinchingly bleak and desperate       0
2      2  allows us to hope that nolan is poised to emba...      1
3      3  the acting , costumes , music , cinematography...      1
4      4                  it 's slow -- very , very slow .       0
..   ...                                                ...    ...
867  867              has all the depth of a wading pool .       0
868  868              a movie with a real anarchic flair .       1
869  869  a subject like this should inspire reaction in...      0
870  870  ... is an arthritic attempt at directing by ca...      0
871  871  looking aristocratic , luminous yet careworn i...      1

[872 rows x 3 columns]
2.1.3. SST-2 の test データセット

下記の Python Script を実行し、test データセットの中身も確認しました。

df_test = pd.read_parquet("hf://datasets/stanfordnlp/sst2/" + splits["test"])
print(df_test)

下記の例のように label が全て -1 になっていたため、今回用意した Google Colab のノートブックでは使用しないことにしました。

       idx                                           sentence  label
0        0             uneasy mishmash of styles and genres .     -1
1        1  this film 's relationship to actual tension is...     -1
2        2  by the end of no such thing the audience , lik...     -1
3        3  director rob marshall went out gunning to make...     -1
4        4  lathan and diggs have considerable personal ch...     -1
...    ...                                                ...    ...
1816  1816  it risks seeming slow and pretentious , becaus...     -1
1817  1817  take care of my cat offers a refreshingly diff...     -1
1818  1818  davis has filled out his cast with appealing f...     -1
1819  1819  it represents better-than-average movie-making...     -1
1820  1820  dazzling and sugar-sweet , a blast of shallow ...     -1

[1821 rows x 3 columns]
2.2. train データセットを training 用と test 用に分割

今回用意した Google Colab のノートブックでは正解ラベル付きの test 用データも参照したかったので、下記の Python Script のように train データセットを training 用と test 用に分割することにしました。
下記のコードの例では test 用データセットのサイズを元のデータの 1/100 にしています。

from sklearn.model_selection import train_test_split

df_splitted_train, df_splitted_test = train_test_split(df_train, test_size=0.01)
2.2.1. 分割して用意した training 用データの確認

print 文で分割して用意した training 用データの中身を確認します。

print(df_splitted_train)

下記のように順番をシャッフルされた 66,675 のデータが表示されます。train データセットには 67,349 のデータが用意されているため、その約 99/100 のデータになります。

         idx                                           sentence  label
3031    3031                                30 seconds of plot       0
61271  61271  if it pared down its plots and characters to a...      0
39336  39336  gives the lie to many clichés and showcases a ...      1
44060  44060                        spends a bit too much time       0
212      212  if it is n't entirely persuasive , it does giv...      1
...      ...                                                ...    ...
37194  37194                        its provocative conclusion       1
6265    6265  an action film disguised as a war tribute is d...      0
54886  54886  goes to absurd lengths to duck the very issues...      0
860      860                             a perfect performance       1
15795  15795     constantly pulling the rug from underneath us       1

[66675 rows x 3 columns]
2.2.2. 分割して用意した test 用データの確認

print 文で分割して用意した test 用データの中身を確認します。

print(df_splitted_test)

下記の例のように順番をシャッフルされた 674 のデータが表示されます。train データセットには 67,349 のデータが用意されているため、その約 1/100 のデータになります。

         idx                                           sentence  label
66730  66730  with outtakes in which most of the characters ...      0
29890  29890                               enigma is well-made       1
45801  45801  is ) so stoked to make an important film about...      0
29352  29352  the closest thing to the experience of space t...      1
19858  19858                                 lose their luster       0
...      ...                                                ...    ...
26242  26242           , you will enjoy seeing how both evolve       1
3242    3242                      strange and beautiful film .       1
22756  22756  full of detail about the man and his country ,...      1
52906  52906                                        devastated       0
66781  66781                             both deserve better .       0

[674 rows x 3 columns]
2.3. fine-tuning 前の DistilBERT のモデルと tokenizer のインスタンスの取得

下記の Python Script を実行します。

from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast

# Load tokenizer and model
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased")
2.4. pandas の形式のデータセットを Dataset のインスタンスに変換

下記の Python Script で panads の形式のデータセットを Hugging Face datasets ライブラリの Dataset のインスタンスに変換します。
training 用のデータセットは、分割された training 用のデータ (df_splitted_train) から変換しています。
validation データセットは分割せずそのまま変換しています。

from datasets import Dataset

# Convert from pandas to Dataset
train_dataset = Dataset.from_pandas(df_splitted_train)
validation_dataset = Dataset.from_pandas(df_validation)
2.5. 映画レビューの文章を Token に変換

下記の Python Script で training と validation 用のデータセットの映画レビューの文章 (sentence) を DistilBERT の Token 配列に変換しています。Token の長さは最大でも 128 までとしています。SST-2 の DistilBERT 用の Token の最大長は 66 なのでこの条件を満たしています。Token の長さが 128 に満たない場合は padding するようにしています。

# Tokenize
def tokenize(batch):
    return tokenizer(batch["sentence"], padding="max_length", truncation=True, max_length=128)

# Convert datasets to tokenized format
tokenized_train_dataset = train_dataset.map(tokenize, batched=True)
tokenized_validation_dataset = validation_dataset.map(tokenize, batched=True)
2.6. Token に変換したデータの確認

下記の Python Script を実行し、映画レビューの文章がどのような Token に変換されているかを確認しました。

check_counter = 0
for example in tokenized_train_dataset:
    print(example)
    check_counter += 1
    if check_counter == 10:
        break

tokenized_train_dataset の中の 10 のデータの中身を print 文で出力しています。下記の出力結果のスクロールバーを横にスライドさせると ‘input_ids’, ‘attention_mask’ 等が表示されます。

{'idx': 12774, 'sentence': 'its make-believe promise of life ', 'label': 0, '__index_level_0__': 12774, 'input_ids': [101, 2049, 2191, 1011, 2903, 4872, 1997, 2166, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}
{'idx': 13049, 'sentence': "'s a very very strong `` b + . '' ", 'label': 1, '__index_level_0__': 13049, 'input_ids': [101, 1005, 1055, 1037, 2200, 2200, 2844, 1036, 1036, 1038, 1009, 1012, 1005, 1005, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}
{'idx': 43834, 'sentence': 'the road warrior ', 'label': 1, '__index_level_0__': 43834, 'input_ids': [101, 1996, 2346, 6750, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}
{'idx': 32075, 'sentence': 'another first-rate performance ', 'label': 1, '__index_level_0__': 32075, 'input_ids': [101, 2178, 2034, 1011, 3446, 2836, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}
{'idx': 13119, 'sentence': 'come from a family that eats , meddles , argues , laughs , kibbitzes and fights together ', 'label': 1, '__index_level_0__': 13119, 'input_ids': [101, 2272, 2013, 1037, 2155, 2008, 20323, 1010, 19960, 27822, 1010, 9251, 1010, 11680, 1010, 11382, 10322, 8838, 2229, 1998, 9590, 2362, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}
{'idx': 38414, 'sentence': 'ellen pompeo pulls off the feat with aplomb ', 'label': 1, '__index_level_0__': 38414, 'input_ids': [101, 9155, 13433, 8737, 8780, 8005, 2125, 1996, 8658, 2007, 9706, 21297, 2497, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}
{'idx': 66216, 'sentence': 'strong itch to explore more ', 'label': 1, '__index_level_0__': 66216, 'input_ids': [101, 2844, 2009, 2818, 2000, 8849, 2062, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}
{'idx': 16102, 'sentence': "'s depressing to see how far herzog has fallen . ", 'label': 0, '__index_level_0__': 16102, 'input_ids': [101, 1005, 1055, 2139, 24128, 2000, 2156, 2129, 2521, 2014, 28505, 2038, 5357, 1012, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}
{'idx': 55769, 'sentence': 'its visual imagination is breathtaking ', 'label': 1, '__index_level_0__': 55769, 'input_ids': [101, 2049, 5107, 9647, 2003, 3052, 17904, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}
{'idx': 53622, 'sentence': 'of weird performances and direction ', 'label': 0, '__index_level_0__': 53622, 'input_ids': [101, 1997, 6881, 4616, 1998, 3257, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}

‘input_ids’ で指定されているのが Token の配列です。
[101, 2049, 2191, 1011, 2903, 4872, 1997, 2166, 102, 0, 0, 0, …, 0] のように 101 で始まっています。また、末尾は 102, 0, 0, 0, …, 0 で終わっています。101 は DistilBERT の [CLS] Token です。それに続く整数は文章内の単語等に対応しています。末尾の 102 は DistilBERT の [SEP] Token です。それに続く 0 は Token 配列の長さが 128 となるように padding するために置かれた整数です。

‘attention_mask’ で指定されているのはマスクをかける領域です。
[1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, …, 0] のように ‘input_ids’ において 0 で padding した領域は 0 でマスクをかけて無視するようにしています。マスクをかけない領域の値は 1 となっています。

2.7. fine-tuning の実行

下記の Python Script を実行し、fine-tuning を実行します。TrainingArguments には training 条件のパラメータを渡しています。

  • 下記の例ではエポック数は 3 となっています。DistilBERT の fine-tuning では 2 – 4 程度の値で良いようです。
  • learning rate は 1e-6 にしました。2e-5, 1e-5 等、一桁大きな値もセットしましたが、1e-6 にしたほうが test データの正解率は高くなりました。
  • training のバッチサイズ (per_device_train_batch_size) は 16, 32 あたりの値が使用されるようです。少し試して 32 より 16 にしたほうが test データの正解率が高くなりそうでしたので 16 にしました。
from transformers import Trainer, TrainingArguments

# Training arguments
training_args = TrainingArguments(
    output_dir="./distilbert-sst2",
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    learning_rate=1e-6,
    weight_decay=0.01,
    eval_strategy="epoch",
    logging_strategy="steps",
    logging_steps=100,
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    warmup_steps=500,
    gradient_accumulation_steps=1,
    fp16=True  # if available
)

# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_validation_dataset,
    processing_class=tokenizer,
)

# Train
trainer.train()

# Save the model and tokenizer
trainer.save_model('./fine_tuned_model_lr_1e_minus6_batch16')

Google Colab で T4 GPU を選択して実行しても上記の Python Script の fine-tuning のコードを実行するには10分から15分程度の時間を要します。

上記の Python Script を実行すると下記の例のような出力が得られます。乱数を使用しているため、実行するたびに数値は異なる値になるかと思います。

Epoch	Training Loss	Validation Loss
1	0.265100	0.301155
2	0.254000	0.286838
3	0.253200	0.288327

Training Loss は training データを対象としたロス関数の出力値です。Epoch 数が増えるにつれて減少しています。

Validation Loss は training データには含まれない validation 用のデータを対象としたロス関数の出力値です。上記の例では 2 Epoch 目の training が完了した時点で最小値 0.286838 を取り、3 Epoch 目が完了した時点では 0.288327 となっています。Training Loss が減少しているにもかかわらず、Validation Loss が増加しているため 3 Epoch 目では overfitting を起こしていることになります。

上記の Python Script では load_best_model_at_end=True と metric_for_best_model=”eval_loss” を指定しているため、Validation Loss が最も小さくなったときのモデルを fine-tuning したモデルとして採用するようにしています。

2.8. fine-tuning した DistilBERT で test データを分類

下記の Python Script を実行し、fine-tuning した DistilBERT で test データを分類しました。使用した test データは上記 2.2. で train データセットから test 用に分割したデータです。test データは上記 2.7. の fine-tuning で training にも validation にも使用されていないデータになります。

下記の Python Script を実行して得られた test データの分類の正解率は 0.9050 になりました。乱数を使用しているため、この値は上記 2.7. までを再実行すると異なる値になります。私が確認した際には 0.87 から 0.92 程度の値になりました。

from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import torch
from sklearn.metrics import accuracy_score

# Load model and tokenizer
fine_tuned_tokenizer = DistilBertTokenizer.from_pretrained("./fine_tuned_model_lr_1e_minus6_batch16")
fine_tuned_model = DistilBertForSequenceClassification.from_pretrained("./fine_tuned_model_lr_1e_minus6_batch16")
fine_tuned_model.eval()

# Convert from pandas to Dataset
test_dataset = Dataset.from_pandas(df_splitted_test)

# Predict in batches
preds, labels = [], []
for example in test_dataset:
    inputs = tokenizer(example["sentence"], return_tensors="pt", padding="max_length", truncation=True, max_length=128)
    with torch.no_grad():
        outputs = fine_tuned_model(**inputs)
    logits = outputs.logits
    prediction = torch.argmax(logits, dim=-1).item()
    preds.append(prediction)
    labels.append(example["label"])

# Compute accuracy
acc = accuracy_score(labels, preds)
print(f"Validation Accuracy: {acc:.4f}")
Validation Accuracy: 0.9050

返信を残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

CAPTCHA