1. 概要
こちらのリンク先の Google Colab のコードセルを順に実行し、FLUX.2 Klein 4B の Transformer への入力を調べました。
2. diffusers の FLUX.2 Klein 4B の Transformer の担う処理
Flux2Transformer2DModel が FLUX.2 Klein 4B の画像生成 Transformer の処理全体を担う class になります。この class は、生成画像に対応する潜在空間の多次元配列データを、乱数で生成した初期データから順に更新していく速度ベクトルを計算します。
こちらのリンク先の下記のフローの条件付き予測を担当しています。指定された回数だけ、多次元配列データを更新するための速度ベクトルを繰り返し計算します。
Prompt ↓ Qwen2Tokenizer ↓ Qwen3ForCausalLM(Text Encoder) ↓ Text Embeddings(固定・全ステップ共有) ↓ Latent(初期ノイズ) ↓ [反復] Flux2Transformer2DModel(条件付き予測) + FlowMatchEulerDiscreteScheduler(更新) ↓ Latent(収束) ↓ AutoencoderKLFlux2.decode ↓ Image
Flux2Transformer2DModel の forward メソッドが、画像を生成する際に実行される Flux2KleinPipeline の Denoising loop 内の下記のスクリプトから呼ばれます。下記のハイライトした二つの行で呼ばれているself.transformer(...)が Flux2Transformer2DModel の forward メソッドの呼び出しになります。
with self.transformer.cache_context("cond"):
noise_pred = self.transformer(
hidden_states=latent_model_input, # (B, image_seq_len, C)
timestep=timestep / 1000,
guidance=None,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids, # B, text_seq_len, 4
img_ids=latent_image_ids, # B, image_seq_len, 4
joint_attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_pred[:, : latents.size(1) :]
if self.do_classifier_free_guidance:
with self.transformer.cache_context("uncond"):
neg_noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep / 1000,
guidance=None,
encoder_hidden_states=negative_prompt_embeds,
txt_ids=negative_text_ids,
img_ids=latent_image_ids,
joint_attention_kwargs=self._attention_kwargs,
return_dict=False,
)[0]
neg_noise_pred = neg_noise_pred[:, : latents.size(1) :]
noise_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred)
FLUX.2 Klein 4B で実行する場合、self.do_classifier_free_guidanceは false なので、一つ目のnoise_pred = self.transformer(...)だけが実行されます。
補足:
Flux2Transformer2DModel は下記のように複数の Mixin を継承しています。このうち、ModelMixin が torch.nn.Module を継承しています。
class Flux2Transformer2DModel(
ModelMixin,
ConfigMixin,
PeftAdapterMixin,
FromOriginalModelMixin,
FluxTransformer2DLoadersMixin,
CacheMixin,
AttentionMixin,
):
先ほどのスクリプトのself.transformer(...)の transformer は Flux2Transformer2DModel のインスタンスです。torch.nn.Module を継承したクラスのインスタンスから、self.transformer(...) のように __call__ メソッドを呼ぶと、torch.nn.Module の __call__ メソッドから forward メソッドが呼ばれるため、Flux2Transformer2DModel の forward メソッドが呼ばれます。
下記のログは、Flux2Transformer2DModel の forward メソッドの先頭にブレイクポイントをセットし、デバッガでスタックトレースを表示した結果です。ハイライトした行が、Flux2KleinPipeline の Denoising loop からのself.transformer(...)の呼び出しになります。self.transformer(...)からいくつかのメソッドを経由して Flux2Transformer2DModel の forward メソッドが呼ばれています。
/content/diffusers/src/diffusers/models/transformers/transformer_flux2.py(825)forward()
824
--> 825 num_txt_tokens = encoder_hidden_states.shape[1]
826
ipdb> where
[... skipping 21 hidden frame(s)]
/tmp/ipython-input-744/1128065702.py(4)<cell line: 0>()
3
----> 4 image = pipe(
5 prompt=prompt,
/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py(124)decorate_context()
123 with ctx_factory():
--> 124 return func(*args, **kwargs)
125
/content/diffusers/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py(843)__call__()
842 with self.transformer.cache_context("cond"):
--> 843 noise_pred = self.transformer(
844 hidden_states=latent_model_input, # (B, image_seq_len, C)
/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py(1776)_wrapped_call_impl()
1775 else:
-> 1776 return self._call_impl(*args, **kwargs)
1777
/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py(1787)_call_impl()
1786 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1787 return forward_call(*args, **kwargs)
1788
/usr/local/lib/python3.12/dist-packages/accelerate/hooks.py(175)new_forward()
174 else:
--> 175 output = module._old_forward(*args, **kwargs)
176 return module._hf_hook.post_forward(module, output)
/content/diffusers/src/diffusers/utils/peft_utils.py(315)wrapper()
314 # Execute the forward pass
--> 315 result = forward_fn(self, *args, **kwargs)
316 return result
> /content/diffusers/src/diffusers/models/transformers/transformer_flux2.py(825)forward()
824
--> 825 num_txt_tokens = encoder_hidden_states.shape[1]
826
3. diffusers の FLUX.2 Klein 4B を動かすときの Transformer への入力
Flux2Transformer2DModel の forward メソッドの引数は、下記の Denoising loop のself.transformer(...)メソッドが渡す引数です。
noise_pred = self.transformer(
hidden_states=latent_model_input, # (B, image_seq_len, C)
timestep=timestep / 1000,
guidance=None,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids, # B, text_seq_len, 4
img_ids=latent_image_ids, # B, image_seq_len, 4
joint_attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
Flux2Transformer2DModel の forward メソッドの引数は下記のようになっています。
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
timestep: torch.LongTensor = None,
img_ids: torch.Tensor = None,
txt_ids: torch.Tensor = None,
guidance: torch.Tensor = None,
joint_attention_kwargs: dict[str, Any] | None = None,
return_dict: bool = True,
) -> torch.Tensor | Transformer2DModelOutput:
こちらのリンク先の Google Colab のコードセルを順に実行し、Flux2Transformer2DModel の forward メソッドの引数をデバッグ実行で確認しました。
3.1. デバッグ実行による調査:テキストプロンプトのみを入力として、height=1024, width=1024 の画像を生成する場合
下記のコードセルを実行したときの Flux2Transformer2DModel の forward メソッドの引数を確認しました。
device = "cuda"
prompt = 'A cat holding a sign that says "Gifu AI Study Group"'
image = pipe(
prompt=prompt,
height=1024,
width=1024,
guidance_scale=1.0,
num_inference_steps=4,
generator=torch.Generator(device=device).manual_seed(0)
).images[0]
image.save("flux-klein.png")
Flux2Transformer2DModel の forward メソッドの引数は下記のようになりました。多次元配列データ torch.Tensor については、各軸のサイズを p hidden_states.shape 等で確認しました。
/content/diffusers/src/diffusers/models/transformers/transformer_flux2.py(825)forward()
824
--> 825 num_txt_tokens = encoder_hidden_states.shape[1]
826
ipdb> p hidden_states.shape
torch.Size([1, 4096, 128])
ipdb> p encoder_hidden_states.shape
torch.Size([1, 512, 7680])
ipdb> p timestep
tensor([1.], device='cuda:0', dtype=torch.bfloat16)
ipdb> p img_ids.shape
torch.Size([1, 4096, 4])
ipdb> p txt_ids.shape
torch.Size([1, 512, 4])
ipdb> p guidance
None
ipdb> p joint_attention_kwargs
None
ipdb> p return_dict
False
hidden_states のサイズは、torch.Size([1, 4096, 128]) でした。生成画像の高さが 1024、幅が 1024 となるように指定しているため、(1024 / 16) x (1024 / 16) = 64 x 64 = 4096 で、2軸目の画像小領域 (高さと幅が 1/8 の潜在空間の隣接する 2 x 2 領域をまとめた領域) の数は 4096 になっています。3軸目のサイズ 128 は、画像小領域に対応する画像データを表すベクトルの次元です。
encoder_hidden_states のサイズは、torch.Size([1, 512, 7680]) でした。これはプロンプト文字列から変換された 512 のトークンベクトルで、各トークンベクトルの次元は 7680 となっています。7680 という次元は Qwen3ForCausalLM の 36 層のうち、(9, 18, 27) で指定される浅い層、中ほどの層、やや深い層の 2560 次元のトークンベクトルを 3つ連結したベクトルの次元になります。
timestep の大きさは 1. となっています。今回の条件では、こちらのリンク先の「5.4. num_inference_steps=4 のとき」に記載した timesteps = [1000.0000, 967.3840, 908.1439, 767.2000] の中の数値を 1000 で割った数値が timestep として順に入ってきます。上記のログでは、Denoising loop の一度目のself.transformer(...)メソッドの呼び出しをチェックしているため、timestep は 1. となっています。timestep の値は、FLUX.2 の画像生成 Transformer のモジュレーション処理で参照されます。
guidance は None です。FLUX.2 Klein 4B を動かすとき、guidance scale は画像生成 Transformer 側では参照されないようです。
joint_attention_kwargs は None です。今回の実行条件では何も指定されていません。
return_dict は false です。下記のスクリプトのように Transformer の出力の多次元配列データを返します。 true の場合は、Transformer2DModelOutput のインスタンスを返します。
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
img_ids は、サイズが torch.Size([1, 4096, 4]) の多次元配列データです。下記のログは img_ids の中身を出力したログです。hidden_states の画像情報に対応する画像小領域の高さ方向と幅方向の index が格納されています。下記のログの例の場合、高さ方向 [0 – 63]、幅方向 [0 – 63]、全部で 64 x 64 = 4096 の 4つの数値からなるデータが格納されています。4つの数値のうち最初の 0 は生成画像に対応するデータであることを表しています。
ipdb> pp img_ids
tensor([[[ 0, 0, 0, 0],
[ 0, 0, 1, 0],
[ 0, 0, 2, 0],
...,
[ 0, 63, 61, 0],
[ 0, 63, 62, 0],
[ 0, 63, 63, 0]]], device='cuda:0')
txt_ids は、サイズが torch.Size([1, 512, 4]) の多次元配列データです。下記のログは txt_ids の中身を出力したログです。encoder_hidden_states のテキストトークンデータに対応するトークンの index が 4つの数値の中の 4番目にセットされています。
ipdb> pp txt_ids
tensor([[[ 0, 0, 0, 0],
[ 0, 0, 0, 1],
[ 0, 0, 0, 2],
...,
[ 0, 0, 0, 509],
[ 0, 0, 0, 510],
[ 0, 0, 0, 511]]], device='cuda:0')
3.2. デバッグ実行による調査:テキストプロンプトのみを入力として、height=256, width=512 の画像を生成する場合
下記のコードセルを実行したときの Flux2Transformer2DModel の forward メソッドの引数を確認しました。
device = "cuda"
prompt = 'A cat holding a sign that says "Gifu AI Study Group"'
image = pipe(
prompt=prompt,
height=256,
width=512,
guidance_scale=1.0,
num_inference_steps=4,
generator=torch.Generator(device=device).manual_seed(0)
).images[0]
image.save("flux-klein.png")
デバッグ実行のログは下記のようになりました。上記 3.1. の height=1024, width=1024 の画像を生成する場合と異なるのは、hidden_states と img_ids です。
> /content/diffusers/src/diffusers/models/transformers/transformer_flux2.py(825)forward()
824
--> 825 num_txt_tokens = encoder_hidden_states.shape[1]
826
ipdb> p hidden_states.shape
torch.Size([1, 512, 128])
ipdb> p encoder_hidden_states.shape
torch.Size([1, 512, 7680])
ipdb> p img_ids.shape
torch.Size([1, 512, 4])
ipdb> p txt_ids.shape
torch.Size([1, 512, 4])
hidden_states のサイズは、torch.Size([1, 512, 128]) でした。生成画像の高さが 256、幅が 512 となるように指定しているため、(256 / 16) x (512 / 16) = 16 x 32 = 512 で、2軸目の画像小領域 (高さと幅が 1/8 の潜在空間の隣接する 2 x 2 領域をまとめた領域) の数は 512 です。3軸目のサイズ 128 は、画像小領域に対応する画像データを表すベクトルの次元です。
img_ids は、サイズが torch.Size([1, 512, 4]) の多次元配列データです。下記のログは img_ids の中身を出力したログです。hidden_states の画像情報に対応する画像小領域の高さ方向と幅方向の index が格納されています。下記のログの例の場合、高さ方向 [0 – 15]、幅方向 [0 – 31]、全部で 16 x 32 = 512 の 4つの数値からなるデータが格納されています。
ipdb> p img_ids
tensor([[[ 0, 0, 0, 0],
[ 0, 0, 1, 0],
[ 0, 0, 2, 0],
...,
[ 0, 15, 29, 0],
[ 0, 15, 30, 0],
[ 0, 15, 31, 0]]], device='cuda:0')
3.3. デバッグ実行による調査:テキストプロンプトと 2枚の画像を入力として、height=1024, width=1024 の画像を生成する場合
下記のコードセルを実行したときの Flux2Transformer2DModel の forward メソッドの引数を確認しました。下記のスクリプトは、Web ページ上の指定した URL の 2枚の画像を参照し、1枚の画像を生成します。1つ目の画像を背景とし、2つ目の画像の鳥が写った写真を生成します。
from diffusers.utils import load_image
device = "cuda"
url = 'https://www.leafwindow.com/wordpress-05/wp-content/uploads/2023/12/IMG_6492-20.jpg'
image1 = load_image(url)
url = 'https://www.leafwindow.com/wordpress-05/wp-content/uploads/2023/02/DSC00022-min-SonyAlpha-%E6%A8%AA.jpg'
image2 = load_image(url)
prompt = """
Use image 1 strictly as the environmental background and scene layout.
Place the bird from image 2 clearly in the foreground as the main subject.
Preserve the bird's shape, colors, and feather details from image 2.
Ensure realistic lighting consistent with the background environment.
Seamless compositing, natural shadows, accurate depth of field,
high detail, professional wildlife photography.
"""
image = pipe(
prompt=prompt,
image=[image1, image2],
height=1024,
width=1024,
guidance_scale=1.0,
num_inference_steps=4,
generator=torch.Generator(device=device).manual_seed(0)
).images[0]
image.save("flux-klein-with-input-image.png")
デバッグ実行のログは下記のようになりました。上記 3.1. の height=1024, width=1024 の画像を生成する場合とサイズが異なるのは、hidden_states と img_ids です。
テキストプロンプトが異なるため、encoder_hidden_states のデータの中身も異なりますが、encoder_hidden_states のサイズは torch.Size([1, 512, 7680]) で、上記 3.1. のときと同じです。
> /content/diffusers/src/diffusers/models/transformers/transformer_flux2.py(825)forward()
824
--> 825 num_txt_tokens = encoder_hidden_states.shape[1]
826
ipdb> p hidden_states.shape
torch.Size([1, 10002, 128])
ipdb> p encoder_hidden_states.shape
torch.Size([1, 512, 7680])
ipdb> p timestep
tensor([1.], device='cuda:0', dtype=torch.bfloat16)
ipdb> p img_ids.shape
torch.Size([1, 10002, 4])
ipdb> p txt_ids.shape
torch.Size([1, 512, 4])
ipdb> p guidance
None
ipdb> p joint_attention_kwargs
None
ipdb> p return_dict
False
hidden_states のサイズは、torch.Size([1, 10002, 128]) でした。上記 3.1. のときと生成画像のサイズは同じですが、hidden_states には生成画像用のサイズ torch.Size([1, 4096, 128]) の多次元配列データに入力として与えた 2枚の画像のデータが連結されています。
img_ids は、サイズが torch.Size([1, 10002, 4]) の多次元配列データです。下記のログは img_ids の中身を出力したログです。上記 3.1. と 3.2. で確認した際は、4つ並んだ数値の中の 1つ目の数値は 0 で固定でした。下記のログでは、最後の 3行のデータの 1つ目の数値が 20 になっています。
ipdb> p img_ids
tensor([[[ 0, 0, 0, 0],
[ 0, 0, 1, 0],
[ 0, 0, 2, 0],
...,
[20, 51, 75, 0],
[20, 51, 76, 0],
[20, 51, 77, 0]]], device='cuda:0')
下記のログは、img_ids のデータの中身をもう少し詳細に確認した際のログの抜粋になります。
ipdb> p img_ids[0, 0:10, :]
tensor([[0, 0, 0, 0],
[0, 0, 1, 0],
[0, 0, 2, 0],
...,
[0, 0, 7, 0],
[0, 0, 8, 0],
[0, 0, 9, 0]], device='cuda:0')
ipdb> p img_ids[0, 4050:4100, :]
tensor([[ 0, 63, 18, 0],
[ 0, 63, 19, 0],
[ 0, 63, 20, 0],
...,
[ 0, 63, 61, 0],
[ 0, 63, 62, 0],
[ 0, 63, 63, 0],
[10, 0, 0, 0],
[10, 0, 1, 0],
[10, 0, 2, 0],
[10, 0, 3, 0]], device='cuda:0')
ipdb> p img_ids[0, 5900:6000, :]
tensor([[10, 48, 28, 0],
[10, 48, 29, 0],
[10, 48, 30, 0],
...,
[10, 49, 34, 0],
[10, 49, 35, 0],
[10, 49, 36, 0],
[20, 0, 0, 0],
[20, 0, 1, 0],
[20, 0, 2, 0],
...,
[20, 0, 51, 0],
[20, 0, 52, 0],
[20, 0, 53, 0]], device='cuda:0')
ipdb> p img_ids[0, 9950:, :]
tensor([[20, 51, 26, 0],
[20, 51, 27, 0],
[20, 51, 28, 0],
...,
[20, 51, 75, 0],
[20, 51, 76, 0],
[20, 51, 77, 0]], device='cuda:0')
最初に、高さ方向 [0 – 63]、幅方向 [0 – 63] の index の生成画像のデータが並んでいます。生成画像のデータの 4つの数値の中の 1つ目の数値は 0 です。
次に、高さ方向 [0 – 49]、幅方向 [0 – 36] の index の 1枚目の入力画像のデータが並んでいます。1枚目の入力画像のデータの 4つの数値の中の 1つ目の数値は 10 です。1枚目の入力画像のサイズは高さ 807、幅 605 です。小数点以下を切り捨てると 807 / 16 = 50、 605 / 16 = 37 で、高さ方向 [0 – 49]、幅方向 [0 – 36] の index の範囲と対応付けられます。
最後に、高さ方向 [0 – 51]、幅方向 [0 – 77] の index の 2枚目の入力画像のデータが並んでいます。2枚目の入力画像のデータの 4つの数値の中の 1つ目の数値は 20 です。2枚目の入力画像のサイズは高さ 1000、幅 1500 です。こちらの範囲は、1000 / 16 = 62、1500 / 16 = 93 よりも限定された範囲となっています。
Flux2KleinPipeline の __call__ メソッドの入力画像を処理するコードに下記のコードがあり、高さと幅の積が 1024 x 1024 を超えるときはリサイズされるようです。
for img in image:
image_width, image_height = img.size
if image_width * image_height > 1024 * 1024:
img = self.image_processor._resize_to_target_area(img, 1024 * 1024)
image_width, image_height = img.size
上記のスクリプトから呼ばれる Flux2ImageProcessor の _resize_to_target_area メソッドは下記のようになっています。
@staticmethod
def _resize_to_target_area(image: PIL.Image.Image, target_area: int = 1024 * 1024) -> PIL.Image.Image:
image_width, image_height = image.size
scale = math.sqrt(target_area / (image_width * image_height))
width = int(image_width * scale)
height = int(image_height * scale)
return image.resize((width, height), PIL.Image.Resampling.LANCZOS)
下記の計算結果のように、上記のメソッドで計算されるリサイズ後の高さと幅の大きさを 16 で割ると高さ方向 [0 – 51]、幅方向 [0 – 77] の index の範囲の大きさになります。
>>> import math >>> scale = math.sqrt(1024 * 1024 / (1500 * 1000)) >>> print(scale) 0.8360924988699915 >>> 1500 * scale 1254.1387483049873 >>> 1000 * scale 836.0924988699915 >>> 1500 * scale / 16 78.3836717690617 >>> 1000 * scale / 16 52.255781179374466
4. 画像生成 Transformer の入力となるテキストトークンデータ encoder_hidden_states について
4.1. Flux2KleinPipeline による encoder_hidden_states の生成について
Flux2Transformer2DModel の forward メソッドの引数の一つ encoder_hidden_states は、テキストトークンデータで、サイズ torch.Size([1, 512, 7680]) の torch.Tensor です。512 はトークンベクトルの数で、7680 はトークベクトルの次元です。
この多次元配列データは、Flux2KleinPipeline の下記の _get_qwen3_prompt_embeds メソッドでプロンプト文字列から生成されます。
@staticmethod
def _get_qwen3_prompt_embeds(
text_encoder: Qwen3ForCausalLM,
tokenizer: Qwen2TokenizerFast,
prompt: str | list[str],
dtype: torch.dtype | None = None,
device: torch.device | None = None,
max_sequence_length: int = 512,
hidden_states_layers: list[int] = (9, 18, 27),
):
...
# Forward pass through the model
output = text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
use_cache=False,
)
# Only use outputs from intermediate layers and stack them
out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
out = out.to(dtype=dtype, device=device)
batch_size, num_channels, seq_len, hidden_dim = out.shape
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
return prompt_embeds
プロンプト文字列を分解した各トークンを Qwen3 のテキストエンコーダに入力し、hidden_states_layers: list[int] = (9, 18, 27)で指定された中間層の出力 3つを合わせて 7680 次元のトークンベクトルを生成しています。ここで参照している Qwen3 は Qwen3 4B モデルで、次の 4.2. に記載したように 36 層の Transformer ブロックで構成されています。
下の図はトークンIDから最初のトークンベクトルを生成する層を 0 層目とし、その後に 36 層の Transformer 層を並べた図になります。hidden_states_layers: list[int] = (9, 18, 27)で指定された中間層の出力は、初期 1/4、中間 1/2、その後の 3/4 に位置する中間層の出力になります。
0 ----------- 9 ----------- 18 ----------- 27 ----------- 36 | early | mid | late |
27 層目以降は Qwen3 で次に出力するトークンを推測するタスクに特化し過ぎていて、意味表現としては劣化していることがあるため、それ以前の中間層の出力を等間隔で 3つ取り、使用しているようです。初期 1/4 番目の層の出力は単語レベルの情報、中間 1/2 番目の層の出力は文脈依存の情報、その後の 3/4 番目の層の出力はより高次の文脈を反映した意味の情報、という位置付けのようです。
4.2. テキストプロンプトの処理を実行する Qwen3ForCausalLM の構成について
Flux2KleinPipeline のインスタンスを pipe という名前で生成後、下記のコードセルを実行すると、
print(pipe.text_encoder)
下記のように text_encoder の構成が出力されます。下記のネットワーク構成は Qwen3 4B モデルの構成になります。
Qwen3ForCausalLM(
(model): Qwen3Model(
(embed_tokens): Embedding(151936, 2560)
(layers): ModuleList(
(0-35): 36 x Qwen3DecoderLayer(
(self_attn): Qwen3Attention(
(q_proj): Linear(in_features=2560, out_features=4096, bias=False)
(k_proj): Linear(in_features=2560, out_features=1024, bias=False)
(v_proj): Linear(in_features=2560, out_features=1024, bias=False)
(o_proj): Linear(in_features=4096, out_features=2560, bias=False)
(q_norm): Qwen3RMSNorm((128,), eps=1e-06)
(k_norm): Qwen3RMSNorm((128,), eps=1e-06)
)
(mlp): Qwen3MLP(
(gate_proj): Linear(in_features=2560, out_features=9728, bias=False)
(up_proj): Linear(in_features=2560, out_features=9728, bias=False)
(down_proj): Linear(in_features=9728, out_features=2560, bias=False)
(act_fn): SiLUActivation()
)
(input_layernorm): Qwen3RMSNorm((2560,), eps=1e-06)
(post_attention_layernorm): Qwen3RMSNorm((2560,), eps=1e-06)
)
)
(norm): Qwen3RMSNorm((2560,), eps=1e-06)
(rotary_emb): Qwen3RotaryEmbedding()
)
(lm_head): Linear(in_features=2560, out_features=151936, bias=False)
)
(embed_tokens): Embedding(151936, 2560) は、プロンプト文字列から得られた語彙数 151936 のテキストトークンIDを 2560 次元のトークンベクトルに変換する行列を表しています。テキストエンコーダでは、まず、この処理が実行されます。
上記の出力の下記の部分は 36 層の Transformer ブロックを表しています。重みが異なる同じ構成の Transformer ブロックが 36 層あり、512 トークンある 2560 次元のトークンベクトル列が順に変換されていきます。
(layers): ModuleList(
(0-35): 36 x Qwen3DecoderLayer(
...
)
)
36 層の Transformer の各ブロックの処理の流れをテキスト形式の図で表すと下記のようになります。x は 2560 次元のトークンベクトルを 512 個並べたトークン列のデータになります。
input x (2560)
│
│
▼
┌──────────────────────────┐
│ input_layernorm │
│ RMSNorm(2560) │
└───────────┬──────────────┘
│
▼
┌─────────────── Self Attention ────────────────┐
│ │
│ Q = q_proj(x) 2560 → 4096 │
│ K = k_proj(x) 2560 → 1024 │
│ V = v_proj(x) 2560 → 1024 │
│ │
│ Q = q_norm(Q) RMSNorm(128 per head) │
│ K = k_norm(K) RMSNorm(128 per head) │
│ │
│ Attention(Q,K,V) │
│ │
│ H = softmax(QKᵀ / √d) V │
│ │
│ attn_out = o_proj(H) 4096 → 2560 │
│ │
└───────────────────┬───────────────────────────┘
│
▼
Residual Add
x + attn_out
│
▼
┌──────────────────────────┐
│ post_attention_layernorm │
│ RMSNorm(2560) │
└───────────┬──────────────┘
│
▼
MLP (SwiGLU)
┌──────────────────────────┐
│ gate = gate_proj(x) │
│ 2560 → 9728 │
│ │
│ up = up_proj(x) │
│ 2560 → 9728 │
│ │
│ gate = SiLU(gate) │
│ │
│ hidden = gate * up │
│ │
│ mlp_out = down_proj │
│ 9728 → 2560 │
└───────────┬──────────────┘
│
▼
Residual Add
x2 + mlp_out
│
▼
output (2560)
Self Attention 層で 2560 次元のトークンベクトルに 2560 x 4096 行列を作用させ、4096 次元の Query heads を計算しています。各 head は 128 次元で、head の数は 4096 / 128 = 32 です。Key と Value の head の数は 1024 / 128 = 8 で、Query の head の数の 1/4 です。GQA (Grouped Query Attention) が使用されていて、4つの異なる Query head が、共通の Key head と Value head を使用しているため、このような数になっています。KV cache の使用量を減らす目的で導入された手法です。
Qwen3ForCausalLM の全体構成の図に戻ってパラメータの数を計算してみます。行列変換に使用する重みと RMSNorm の計算で参照される重みパラメータ (scale パラメータ) の数の和を取ると下記のように約 4B になります。
>>> Embedding = 151936 * 2560 >>> Qwen3Attention = 2560 * 4096 * 2 + 2560 * 1024 * 2 + 128 * 2 >>> Qwen3MLP = 9728 * 2560 * 3 >>> Qwen3DecoderLayer = Qwen3Attention + Qwen3MLP + 2560 * 2 >>> Embedding + Qwen3DecoderLayer * 36 + 2560 4022468096
4.3. Flux2KleinPipeline の _get_qwen3_prompt_embeds(…) の処理内容をデバッグ実行で確認
Flux2KleinPipeline でテキストトークンデータを生成する下記の _get_qwen3_prompt_embeds(…) の先頭にブレイクポイントをセットし、デバッグ実行でローカル変数の内容を確認しました。
@staticmethod
def _get_qwen3_prompt_embeds(
text_encoder: Qwen3ForCausalLM,
tokenizer: Qwen2TokenizerFast,
prompt: str | list[str],
dtype: torch.dtype | None = None,
device: torch.device | None = None,
max_sequence_length: int = 512,
hidden_states_layers: list[int] = (9, 18, 27),
):
...
return prompt_embeds
こちらのリンク先の Google Colab のコードセルを順に実行し、7つ目のコードセルを実行するとセットしたブレイクポイントにヒットします。
プロンプト文字列はそのまま処理されるのではなく、一度、下記のスクリプトでチャット形式のテキストに変換しています。Qwen は、チャット形式の入力を想定しているためです。
messages = [{"role": "user", "content": single_prompt}]
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=False,
)
_get_qwen3_prompt_embeds の上記のコードを処理しているときのローカル変数の内容を、デバッグ実行で確認したログは下記のようになりました。
ipdb> n
> /content/diffusers/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py(228)_get_qwen3_prompt_embeds()
227 for single_prompt in prompt:
--> 228 messages = [{"role": "user", "content": single_prompt}]
229 text = tokenizer.apply_chat_template(
...
ipdb> p single_prompt
'A cat holding a sign that says "Gifu AI Study Group"'
ipdb> n
> /content/diffusers/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py(229)_get_qwen3_prompt_embeds()
228 messages = [{"role": "user", "content": single_prompt}]
--> 229 text = tokenizer.apply_chat_template(
230 messages,
ipdb> p messages
[{'role': 'user', 'content': 'A cat holding a sign that says "Gifu AI Study Group"'}]
...
ipdb> n
> /content/diffusers/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py(229)_get_qwen3_prompt_embeds()
228 messages = [{"role": "user", "content": single_prompt}]
--> 229 text = tokenizer.apply_chat_template(
230 messages,
ipdb> n
> /content/diffusers/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py(235)_get_qwen3_prompt_embeds()
234 )
--> 235 inputs = tokenizer(
236 text,
ipdb> whatis text
<class 'str'>
ipdb> p text
'<|im_start|>user\nA cat holding a sign that says "Gifu AI Study Group"<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n'
single_prompt の内容は、コードセル側でプロンプトとしてセットした A cat holding a sign that says “Gifu AI Study Group” です。tokenizer.apply_chat_template(…) を呼んだ後は、text にセットされた下記のような文字列に変換されています。
<|im_start|>user A cat holding a sign that says "Gifu AI Study Group"<|im_end|> <|im_start|>assistant <think> </think>
tokenize=False としているため、戻り値の text はトークン ID 列ではなく文字列になっています。
add_generation_prompt=True としているため、Assistant の開始トークンを追加しています。「Assistant がこれから文字列を生成する」という状態でテキストエンコーダに文字列を入力し、テキストトークンデータを処理する準備をしています。
次に、下記のコードを実行し、上記の text にセットされた文字列をトークン ID 列に変換しています。
inputs = tokenizer(
text,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=max_sequence_length,
)
_get_qwen3_prompt_embeds の上記の tokenizer の __call__ メソッドの返り値 inputs を、デバッグ実行で確認しました。
下記のログのようにトークン ID の整数値が 512 個並んだ input_ids と、1 と 0 が 512 個並んだ attention_mask がセットされていました。トークン ID の整数値は、151643 以外の整数値が 26 個並んだ後、27 番目以降の整数は全て 151643 となっていました。attention_mask は 1 が 26 個並んだ後、27 番目以降には 0 が並びました。
ipdb> p inputs
{'input_ids': tensor([[151644, 872, 198, 32, 8251, 9963, 264, 1841, 429,
2727, 330, 38, 20850, 15235, 19173, 5737, 1, 151645,
198, 151644, 77091, 198, 151667, 271, 151668, 271, 151643,
151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643,
...
151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643]]),
'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
...
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])}
ipdb> p inputs['input_ids'].shape
torch.Size([1, 512])
padding=”max_length” としているため、max_length=max_sequence_length で指定した長さまで padding されたトークン ID 列が出力されています。今回実行した条件では、max_sequence_length は 512 なため、トークン ID 列の長さは 512 となっています。
truncation=True としているため、max_sequence_length の 512 より長いトークン ID 列になる場合は、513 以降のトークン ID は切り捨てられます。
したがって、どのようなプロンプト文字列を入力してもトークン ID の長さは 512 として処理されます。
上記のログのトークン ID のうち、下記の左側に記載したトークン ID は、右側に記載した特殊トークンに対応しているようです。
151644 → <|im_start|> 151645 → <|im_end|> 151667 → <think> 151668 → </think> 151643 → padding
_get_qwen3_prompt_embeds の下記のコードをデバッグ実行し、トークン ID 列を Qwen3ForCausalLM のテキストエンコーダで処理した結果を確認しました。(トークン ID 列を生成した後の処理になります。)
# Forward pass through the model
output = text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
use_cache=False,
)
# Only use outputs from intermediate layers and stack them
out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
out = out.to(dtype=dtype, device=device)
batch_size, num_channels, seq_len, hidden_dim = out.shape
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
return prompt_embeds
下記のようなログが得られました。
ipdb> p len(output.hidden_states) 37 ipdb> p output.hidden_states[0].shape torch.Size([1, 512, 2560]) ipdb> p output.hidden_states[1].shape torch.Size([1, 512, 2560]) ipdb> p output.hidden_states[36].shape torch.Size([1, 512, 2560])
output_hidden_states=True としているため、中間層の出力も取得しています。
上記のログでは、output = text_encoder(...)の返り値として得られた output.hidden_states の数 len(output.hidden_states) は、37 になっています。
下記のように、最初にトークンIDをトークンベクトルに変換する層の出力があり、その後に 36 層の Tranformer ブロックの出力が続いています。そのため、output.hidden_states の数 len(output.hidden_states) は、37 になっています。
hidden_states = 0 embedding 1 layer0 2 layer1 ... 36 layer35
各層の出力である、多次元配列データのサイズは torch.Size([1, 512, 2560]) となっています。2560 次元のトークンベクトルが 512 並んだ多次元配列になります。
上記 4.1. で触れたように、これらの中間層の出力の中から、hidden_states_layers: list[int] = (9, 18, 27) で指定した 3つの層の出力を取得し、それらを連結した 7680 次元のトークンベクトル列を prompt_embeds として返します。下記のログはその処理を確認したログになります。
ipdb> n
> /content/diffusers/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py(261)_get_qwen3_prompt_embeds()
260
--> 261 batch_size, num_channels, seq_len, hidden_dim = out.shape
262 prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
ipdb> p out.shape
torch.Size([1, 3, 512, 2560])
ipdb> n
> /content/diffusers/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py(262)_get_qwen3_prompt_embeds()
261 batch_size, num_channels, seq_len, hidden_dim = out.shape
--> 262 prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
263
ipdb> n
> /content/diffusers/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py(264)_get_qwen3_prompt_embeds()
263
--> 264 return prompt_embeds
265
ipdb> p prompt_embeds.shape
torch.Size([1, 512, 7680])