1. 概要
2025年6月22日(日)に開催予定の第4回 岐阜AI勉強会の準備の一環で、DistilBERT を SST-2 のデータセットで fine-tuning する Google Colab 用のノートブックを作成しました。
しばらくこちらのブログを更新していなかったので、作成したノートブックについてのブログページを用意することにしました。
2. DistilBERT を SST-2 のデータセットで fine-tuning する Google Colab のノートブックについて
DistilBERT を SST-2 のデータセットで fine-tuning する Google Colab のノートブックをこちらのリンク
に用意しました。
2.1. SST-2 のデータセットの読み込み
下記の Python Script は SST-2 の train, validation データセットを pandas のライブラリで読み込みます。
SST-2 のデータセットは映画レビューの文章が positive か negative かに分類したデータセットです。
import pandas as pd splits = {'train': 'data/train-00000-of-00001.parquet', 'validation': 'data/validation-00000-of-00001.parquet', 'test': 'data/test-00000-of-00001.parquet'} df_train = pd.read_parquet("hf://datasets/stanfordnlp/sst2/" + splits["train"]) df_validation = pd.read_parquet("hf://datasets/stanfordnlp/sst2/" + splits["validation"])
2.1.1. SST-2 の train データセット
print 文で train データセットの中身を確認すると下記のようになっています。
print(df_train)
label が 0 の行は映画レビューの文章 (sentence) が positive なデータ、label が 1 の行は映画レビューの文章が negative なデータです。67,349 のデータがあります。下記の print 文の出力例では、長い映画レビューの文章は途中までしか表示されていません。
idx sentence label 0 0 hide new secretions from the parental units 0 1 1 contains no wit , only labored gags 0 2 2 that loves its characters and communicates som... 1 3 3 remains utterly satisfied to remain the same t... 0 4 4 on the worst revenge-of-the-nerds clichés the ... 0 ... ... ... ... 67344 67344 a delightful comedy 1 67345 67345 anguish , anger and frustration 0 67346 67346 at achieving the modest , crowd-pleasing goals... 1 67347 67347 a patient viewer 1 67348 67348 this new jangle of noise , mayhem and stupidit... 0 [67349 rows x 3 columns]
2.1.2. SST-2 の validation データセット
validation データセットの中身も train データセットと同様です。872 のデータがあります。
print(df_validation)
idx sentence label 0 0 it 's a charming and often affecting journey . 1 1 1 unflinchingly bleak and desperate 0 2 2 allows us to hope that nolan is poised to emba... 1 3 3 the acting , costumes , music , cinematography... 1 4 4 it 's slow -- very , very slow . 0 .. ... ... ... 867 867 has all the depth of a wading pool . 0 868 868 a movie with a real anarchic flair . 1 869 869 a subject like this should inspire reaction in... 0 870 870 ... is an arthritic attempt at directing by ca... 0 871 871 looking aristocratic , luminous yet careworn i... 1 [872 rows x 3 columns]
2.1.3. SST-2 の test データセット
下記の Python Script を実行し、test データセットの中身も確認しました。
df_test = pd.read_parquet("hf://datasets/stanfordnlp/sst2/" + splits["test"]) print(df_test)
下記の例のように label が全て -1 になっていたため、今回用意した Google Colab のノートブックでは使用しないことにしました。
idx sentence label 0 0 uneasy mishmash of styles and genres . -1 1 1 this film 's relationship to actual tension is... -1 2 2 by the end of no such thing the audience , lik... -1 3 3 director rob marshall went out gunning to make... -1 4 4 lathan and diggs have considerable personal ch... -1 ... ... ... ... 1816 1816 it risks seeming slow and pretentious , becaus... -1 1817 1817 take care of my cat offers a refreshingly diff... -1 1818 1818 davis has filled out his cast with appealing f... -1 1819 1819 it represents better-than-average movie-making... -1 1820 1820 dazzling and sugar-sweet , a blast of shallow ... -1 [1821 rows x 3 columns]
2.2. train データセットを training 用と test 用に分割
今回用意した Google Colab のノートブックでは正解ラベル付きの test 用データも参照したかったので、下記の Python Script のように train データセットを training 用と test 用に分割することにしました。
下記のコードの例では test 用データセットのサイズを元のデータの 1/100 にしています。
from sklearn.model_selection import train_test_split df_splitted_train, df_splitted_test = train_test_split(df_train, test_size=0.01)
2.2.1. 分割して用意した training 用データの確認
print 文で分割して用意した training 用データの中身を確認します。
print(df_splitted_train)
下記のように順番をシャッフルされた 66,675 のデータが表示されます。train データセットには 67,349 のデータが用意されているため、その約 99/100 のデータになります。
idx sentence label 3031 3031 30 seconds of plot 0 61271 61271 if it pared down its plots and characters to a... 0 39336 39336 gives the lie to many clichés and showcases a ... 1 44060 44060 spends a bit too much time 0 212 212 if it is n't entirely persuasive , it does giv... 1 ... ... ... ... 37194 37194 its provocative conclusion 1 6265 6265 an action film disguised as a war tribute is d... 0 54886 54886 goes to absurd lengths to duck the very issues... 0 860 860 a perfect performance 1 15795 15795 constantly pulling the rug from underneath us 1 [66675 rows x 3 columns]
2.2.2. 分割して用意した test 用データの確認
print 文で分割して用意した test 用データの中身を確認します。
print(df_splitted_test)
下記の例のように順番をシャッフルされた 674 のデータが表示されます。train データセットには 67,349 のデータが用意されているため、その約 1/100 のデータになります。
idx sentence label 66730 66730 with outtakes in which most of the characters ... 0 29890 29890 enigma is well-made 1 45801 45801 is ) so stoked to make an important film about... 0 29352 29352 the closest thing to the experience of space t... 1 19858 19858 lose their luster 0 ... ... ... ... 26242 26242 , you will enjoy seeing how both evolve 1 3242 3242 strange and beautiful film . 1 22756 22756 full of detail about the man and his country ,... 1 52906 52906 devastated 0 66781 66781 both deserve better . 0 [674 rows x 3 columns]
2.3. fine-tuning 前の DistilBERT のモデルと tokenizer のインスタンスの取得
下記の Python Script を実行します。
from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast # Load tokenizer and model tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased") model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased")
2.4. pandas の形式のデータセットを Dataset のインスタンスに変換
下記の Python Script で panads の形式のデータセットを Hugging Face datasets ライブラリの Dataset のインスタンスに変換します。
training 用のデータセットは、分割された training 用のデータ (df_splitted_train) から変換しています。
validation データセットは分割せずそのまま変換しています。
from datasets import Dataset # Convert from pandas to Dataset train_dataset = Dataset.from_pandas(df_splitted_train) validation_dataset = Dataset.from_pandas(df_validation)
2.5. 映画レビューの文章を Token に変換
下記の Python Script で training と validation 用のデータセットの映画レビューの文章 (sentence) を DistilBERT の Token 配列に変換しています。Token の長さは最大でも 128 までとしています。SST-2 の DistilBERT 用の Token の最大長は 66 なのでこの条件を満たしています。Token の長さが 128 に満たない場合は padding するようにしています。
# Tokenize def tokenize(batch): return tokenizer(batch["sentence"], padding="max_length", truncation=True, max_length=128) # Convert datasets to tokenized format tokenized_train_dataset = train_dataset.map(tokenize, batched=True) tokenized_validation_dataset = validation_dataset.map(tokenize, batched=True)
2.6. Token に変換したデータの確認
下記の Python Script を実行し、映画レビューの文章がどのような Token に変換されているかを確認しました。
check_counter = 0 for example in tokenized_train_dataset: print(example) check_counter += 1 if check_counter == 10: break
tokenized_train_dataset の中の 10 のデータの中身を print 文で出力しています。下記の出力結果のスクロールバーを横にスライドさせると ‘input_ids’, ‘attention_mask’ 等が表示されます。
{'idx': 12774, 'sentence': 'its make-believe promise of life ', 'label': 0, '__index_level_0__': 12774, 'input_ids': [101, 2049, 2191, 1011, 2903, 4872, 1997, 2166, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]} {'idx': 13049, 'sentence': "'s a very very strong `` b + . '' ", 'label': 1, '__index_level_0__': 13049, 'input_ids': [101, 1005, 1055, 1037, 2200, 2200, 2844, 1036, 1036, 1038, 1009, 1012, 1005, 1005, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]} {'idx': 43834, 'sentence': 'the road warrior ', 'label': 1, '__index_level_0__': 43834, 'input_ids': [101, 1996, 2346, 6750, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]} {'idx': 32075, 'sentence': 'another first-rate performance ', 'label': 1, '__index_level_0__': 32075, 'input_ids': [101, 2178, 2034, 1011, 3446, 2836, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]} {'idx': 13119, 'sentence': 'come from a family that eats , meddles , argues , laughs , kibbitzes and fights together ', 'label': 1, '__index_level_0__': 13119, 'input_ids': [101, 2272, 2013, 1037, 2155, 2008, 20323, 1010, 19960, 27822, 1010, 9251, 1010, 11680, 1010, 11382, 10322, 8838, 2229, 1998, 9590, 2362, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]} {'idx': 38414, 'sentence': 'ellen pompeo pulls off the feat with aplomb ', 'label': 1, '__index_level_0__': 38414, 'input_ids': [101, 9155, 13433, 8737, 8780, 8005, 2125, 1996, 8658, 2007, 9706, 21297, 2497, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]} {'idx': 66216, 'sentence': 'strong itch to explore more ', 'label': 1, '__index_level_0__': 66216, 'input_ids': [101, 2844, 2009, 2818, 2000, 8849, 2062, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]} {'idx': 16102, 'sentence': "'s depressing to see how far herzog has fallen . ", 'label': 0, '__index_level_0__': 16102, 'input_ids': [101, 1005, 1055, 2139, 24128, 2000, 2156, 2129, 2521, 2014, 28505, 2038, 5357, 1012, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]} {'idx': 55769, 'sentence': 'its visual imagination is breathtaking ', 'label': 1, '__index_level_0__': 55769, 'input_ids': [101, 2049, 5107, 9647, 2003, 3052, 17904, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]} {'idx': 53622, 'sentence': 'of weird performances and direction ', 'label': 0, '__index_level_0__': 53622, 'input_ids': [101, 1997, 6881, 4616, 1998, 3257, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}
‘input_ids’ で指定されているのが Token の配列です。
[101, 2049, 2191, 1011, 2903, 4872, 1997, 2166, 102, 0, 0, 0, …, 0] のように 101 で始まっています。また、末尾は 102, 0, 0, 0, …, 0 で終わっています。101 は DistilBERT の [CLS] Token です。それに続く整数は文章内の単語等に対応しています。末尾の 102 は DistilBERT の [SEP] Token です。それに続く 0 は Token 配列の長さが 128 となるように padding するために置かれた整数です。
‘attention_mask’ で指定されているのはマスクをかける領域です。
[1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, …, 0] のように ‘input_ids’ において 0 で padding した領域は 0 でマスクをかけて無視するようにしています。マスクをかけない領域の値は 1 となっています。
2.7. fine-tuning の実行
下記の Python Script を実行し、fine-tuning を実行します。TrainingArguments には training 条件のパラメータを渡しています。
- 下記の例ではエポック数は 3 となっています。DistilBERT の fine-tuning では 2 – 4 程度の値で良いようです。
- learning rate は 1e-6 にしました。2e-5, 1e-5 等、一桁大きな値もセットしましたが、1e-6 にしたほうが test データの正解率は高くなりました。
- training のバッチサイズ (per_device_train_batch_size) は 16, 32 あたりの値が使用されるようです。少し試して 32 より 16 にしたほうが test データの正解率が高くなりそうでしたので 16 にしました。
from transformers import Trainer, TrainingArguments # Training arguments training_args = TrainingArguments( output_dir="./distilbert-sst2", num_train_epochs=3, per_device_train_batch_size=16, per_device_eval_batch_size=64, learning_rate=1e-6, weight_decay=0.01, eval_strategy="epoch", logging_strategy="steps", logging_steps=100, save_strategy="epoch", load_best_model_at_end=True, metric_for_best_model="eval_loss", warmup_steps=500, gradient_accumulation_steps=1, fp16=True # if available ) # Trainer trainer = Trainer( model=model, args=training_args, train_dataset=tokenized_train_dataset, eval_dataset=tokenized_validation_dataset, processing_class=tokenizer, ) # Train trainer.train() # Save the model and tokenizer trainer.save_model('./fine_tuned_model_lr_1e_minus6_batch16')
Google Colab で T4 GPU を選択して実行しても上記の Python Script の fine-tuning のコードを実行するには10分から15分程度の時間を要します。
上記の Python Script を実行すると下記の例のような出力が得られます。乱数を使用しているため、実行するたびに数値は異なる値になるかと思います。
Epoch Training Loss Validation Loss 1 0.265100 0.301155 2 0.254000 0.286838 3 0.253200 0.288327
Training Loss は training データを対象としたロス関数の出力値です。Epoch 数が増えるにつれて減少しています。
Validation Loss は training データには含まれない validation 用のデータを対象としたロス関数の出力値です。上記の例では 2 Epoch 目の training が完了した時点で最小値 0.286838 を取り、3 Epoch 目が完了した時点では 0.288327 となっています。Training Loss が減少しているにもかかわらず、Validation Loss が増加しているため 3 Epoch 目では overfitting を起こしていることになります。
上記の Python Script では load_best_model_at_end=True と metric_for_best_model=”eval_loss” を指定しているため、Validation Loss が最も小さくなったときのモデルを fine-tuning したモデルとして採用するようにしています。
2.8. fine-tuning した DistilBERT で test データを分類
下記の Python Script を実行し、fine-tuning した DistilBERT で test データを分類しました。使用した test データは上記 2.2. で train データセットから test 用に分割したデータです。test データは上記 2.7. の fine-tuning で training にも validation にも使用されていないデータになります。
下記の Python Script を実行して得られた test データの分類の正解率は 0.9050 になりました。乱数を使用しているため、この値は上記 2.7. までを再実行すると異なる値になります。私が確認した際には 0.87 から 0.92 程度の値になりました。
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification import torch from sklearn.metrics import accuracy_score # Load model and tokenizer fine_tuned_tokenizer = DistilBertTokenizer.from_pretrained("./fine_tuned_model_lr_1e_minus6_batch16") fine_tuned_model = DistilBertForSequenceClassification.from_pretrained("./fine_tuned_model_lr_1e_minus6_batch16") fine_tuned_model.eval() # Convert from pandas to Dataset test_dataset = Dataset.from_pandas(df_splitted_test) # Predict in batches preds, labels = [], [] for example in test_dataset: inputs = tokenizer(example["sentence"], return_tensors="pt", padding="max_length", truncation=True, max_length=128) with torch.no_grad(): outputs = fine_tuned_model(**inputs) logits = outputs.logits prediction = torch.argmax(logits, dim=-1).item() preds.append(prediction) labels.append(example["label"]) # Compute accuracy acc = accuracy_score(labels, preds) print(f"Validation Accuracy: {acc:.4f}")
Validation Accuracy: 0.9050