1. 概要
こちらのリンク先の Google Colab のコードセルを実行し、FLUX.2 [klein] 4B の画像生成 Transformer の構成を調べました。
2. diffusers の FLUX.2 Klein 4B の Transformer の担う処理
Flux2Transformer2DModel が FLUX.2 Klein 4B の画像生成 Transformer の処理全体を担う class になります。この class は、生成画像に対応する潜在空間の多次元配列データを、乱数で生成した初期データから順に更新していく速度ベクトルを計算します。
こちらのリンク先の下記のフローの条件付き予測を担当しています。指定された回数だけ、多次元配列データを更新するための速度ベクトルを繰り返し計算します。
Prompt ↓ Qwen2Tokenizer ↓ Qwen3ForCausalLM(Text Encoder) ↓ Text Embeddings(固定・全ステップ共有) ↓ Latent(初期ノイズ) ↓ [反復] Flux2Transformer2DModel(条件付き予測) + FlowMatchEulerDiscreteScheduler(更新) ↓ Latent(収束) ↓ AutoencoderKLFlux2.decode ↓ Image
こちらのリンク先で調べた入力を受け取り、多次元配列データを更新するための速度ベクトルを繰り返し計算します。
3. FLUX.2 [klein] 4B の画像生成 Transformer の構成
こちらのリンク先の Google Colab のコードセルを 4. まで順に実行し、Flux2KleinPipeline のインスタンスを pipe という変数名で生成しました。
その後、下記のコードセルを実行すると、
print(pipe.transformer)
下記の出力が得られました。この後、構成要素を一つずつ見ていきます。
Flux2Transformer2DModel(
(pos_embed): Flux2PosEmbed()
(time_guidance_embed): Flux2TimestepGuidanceEmbeddings(
(time_proj): Timesteps()
(timestep_embedder): TimestepEmbedding(
(linear_1): Linear(in_features=256, out_features=3072, bias=False)
(act): SiLU()
(linear_2): Linear(in_features=3072, out_features=3072, bias=False)
)
)
(double_stream_modulation_img): Flux2Modulation(
(linear): Linear(in_features=3072, out_features=18432, bias=False)
(act_fn): SiLU()
)
(double_stream_modulation_txt): Flux2Modulation(
(linear): Linear(in_features=3072, out_features=18432, bias=False)
(act_fn): SiLU()
)
(single_stream_modulation): Flux2Modulation(
(linear): Linear(in_features=3072, out_features=9216, bias=False)
(act_fn): SiLU()
)
(x_embedder): Linear(in_features=128, out_features=3072, bias=False)
(context_embedder): Linear(in_features=7680, out_features=3072, bias=False)
(transformer_blocks): ModuleList(
(0-4): 5 x Flux2TransformerBlock(
(norm1): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
(norm1_context): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
(attn): Flux2Attention(
(to_q): Linear(in_features=3072, out_features=3072, bias=False)
(to_k): Linear(in_features=3072, out_features=3072, bias=False)
(to_v): Linear(in_features=3072, out_features=3072, bias=False)
(norm_q): RMSNorm((128,), eps=1e-06, elementwise_affine=True)
(norm_k): RMSNorm((128,), eps=1e-06, elementwise_affine=True)
(to_out): ModuleList(
(0): Linear(in_features=3072, out_features=3072, bias=False)
(1): Dropout(p=0.0, inplace=False)
)
(norm_added_q): RMSNorm((128,), eps=1e-06, elementwise_affine=True)
(norm_added_k): RMSNorm((128,), eps=1e-06, elementwise_affine=True)
(add_q_proj): Linear(in_features=3072, out_features=3072, bias=False)
(add_k_proj): Linear(in_features=3072, out_features=3072, bias=False)
(add_v_proj): Linear(in_features=3072, out_features=3072, bias=False)
(to_add_out): Linear(in_features=3072, out_features=3072, bias=False)
)
(norm2): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
(ff): Flux2FeedForward(
(linear_in): Linear(in_features=3072, out_features=18432, bias=False)
(act_fn): Flux2SwiGLU(
(gate_fn): SiLU()
)
(linear_out): Linear(in_features=9216, out_features=3072, bias=False)
)
(norm2_context): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
(ff_context): Flux2FeedForward(
(linear_in): Linear(in_features=3072, out_features=18432, bias=False)
(act_fn): Flux2SwiGLU(
(gate_fn): SiLU()
)
(linear_out): Linear(in_features=9216, out_features=3072, bias=False)
)
)
)
(single_transformer_blocks): ModuleList(
(0-19): 20 x Flux2SingleTransformerBlock(
(norm): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
(attn): Flux2ParallelSelfAttention(
(to_qkv_mlp_proj): Linear(in_features=3072, out_features=27648, bias=False)
(mlp_act_fn): Flux2SwiGLU(
(gate_fn): SiLU()
)
(norm_q): RMSNorm((128,), eps=1e-06, elementwise_affine=True)
(norm_k): RMSNorm((128,), eps=1e-06, elementwise_affine=True)
(to_out): Linear(in_features=12288, out_features=3072, bias=False)
)
)
)
(norm_out): AdaLayerNormContinuous(
(silu): SiLU()
(linear): Linear(in_features=3072, out_features=6144, bias=False)
(norm): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
)
(proj_out): Linear(in_features=3072, out_features=128, bias=False)
)
3.1. (pos_embed): Flux2PosEmbed()
下記の構成要素は RoPE を適用するための sin, cos 列の値を計算するクラスのインスタンスです。ニューラルネットワークの学習で得られる重みパラメータは持ちません。
(pos_embed): Flux2PosEmbed()
self.pos_embed(…) は、Flux2Transformer2DModel の forward メソッド内の下記のコードで実行され、forward メソッドの引数 img_ids と txt_ids を参照し、Attention 層で Query と Key を回転させる際に参照される sin, cos 列を計算します。
image_rotary_emb = self.pos_embed(img_ids)
text_rotary_emb = self.pos_embed(txt_ids)
concat_rotary_emb = (
torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
)
3.2. (time_guidance_embed): Flux2TimestepGuidanceEmbeddings(…)
下記の構成要素は画像生成 Transformer のモジュレーション処理で参照される temb (timestep embedding) を計算する time_guidance_embed を表しています。
(time_guidance_embed): Flux2TimestepGuidanceEmbeddings(
(time_proj): Timesteps()
(timestep_embedder): TimestepEmbedding(
(linear_1): Linear(in_features=256, out_features=3072, bias=False)
(act): SiLU()
(linear_2): Linear(in_features=3072, out_features=3072, bias=False)
)
)
Flux2Transformer2DModel の forward メソッド内では下記のコードで self.time_guidance_embed(timestep, guidance)が呼ばれています。その計算結果 temb は、Double Stream Transformer Block 内のモジュレーション、Single Stream Transformer Block 内のモジュレーション、出力層の Adaptive LayerNorm で参照されます。
# 1. Calculate timestep embedding and modulation parameters
timestep = timestep.to(hidden_states.dtype) * 1000
if guidance is not None:
guidance = guidance.to(hidden_states.dtype) * 1000
temb = self.time_guidance_embed(timestep, guidance)
double_stream_mod_img = self.double_stream_modulation_img(temb)
double_stream_mod_txt = self.double_stream_modulation_txt(temb)
single_stream_mod = self.single_stream_modulation(temb)
...
# 6. Output layers
hidden_states = self.norm_out(hidden_states, temb)
こちらのリンク先で調べたように Flux2KleinPipeline を使用する場合は、guidance は None です。そのため、timestep のみを参照して temb が計算されます。
timestep から temb を生成する処理ですが、まず、下記のコードで diffusers の Timesteps クラスのインスタンスを用意します。下記の in_channels は 256 です。
self.time_proj = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
次に、下記のコードでその forward メソッドを呼び、0 以上 1000 以下の値の timestep から 256 次元のベクトルを生成します。
timesteps_proj = self.time_proj(timestep)
上記のコードで生成される 256 次元のベクトルは、下記の式で表されます。下記の式の $t$ は timestep の値です。ベクトル要素 $\cos(\omega_i \cdot t)$, $\sin(\omega_i \cdot t)$ の個数の合計が 256 となるようにしています。
$$
\begin{aligned}
\mathrm{emb}(t) &=
\left[ \cos(\omega_0 \cdot t), \ldots, \cos(\omega_{127} \cdot t), \sin(\omega_0 \cdot t), \ldots, \sin(\omega_{127} \cdot t) \right] \\
\omega_i &= 10000^{-i/128}
\end{aligned}
$$
上記の式のベクトル $\mathrm{emb}(t)$ は、例えば $t$ = 100 と $t$ = 101 のときのベクトルが、$t$ = 100 と $t$ = 500 のように離れた $t$ の値のベクトルよりも近くなる性質を持っています。
w = $\omega_i$ として、i と w の関係を表にすると下記のようになります。i の値が 0 から 127 まで増加するにつれて角周波数 w の値は指数関数的に小さくなっていきます。幅広い範囲の角周波数をベクトル成分に持たせることで、$t$ が近いときは i が小さな高周波の要素によって違いが生じ、$t$ が 0 と 1000 のように大きく離れているときは i が大きな低周波の成分に少しずつ変化していく違いを持たせることができます。
| i | w | | --- | ------- | | 0 | 1 | | 32 | 1/10 | | 64 | 1/100 | | 96 | 1/1000 | | 128 | 1/10000 |
こちらのリンク先の Attention Is All You Need の2017年の Transformer の論文の 3.5 Positional Encoding の式の $pos$ を $t$、$d_{\mathrm{model}}$ を 256 に変え、ベクトル要素の並び順を、cos と sin を交互に並べるのではなく、$\cos(\omega_0 \cdot t)$, $\ldots$, $\cos(\omega_{127} \cdot t)$ の後ろに $\sin(\omega_0 \cdot t)$, $\ldots$, $\sin(\omega_{127} \cdot t)$ を並べるようにした式が、先ほどの $\mathrm{emb}(t)$ の式になります。
$$
\begin{aligned}
\mathrm{PE}(pos, 2i) &= \sin\left(\frac{pos}{10000^{2i/d_{\mathrm{model}}}}\right) \\
\mathrm{PE}(pos, 2i+1) &= \cos\left(\frac{pos}{10000^{2i/d_{\mathrm{model}}}}\right)
\end{aligned}
$$
Flux2TimestepGuidanceEmbeddings の後段の下記のハイライトした TimestepEmbedding は、先ほどの 256 次元のベクトル $\mathrm{emb}(t)$ に行列を作用させて 3072 次元のベクトルに変換します。その後、非線形変換 SiLU() を適用し、最後に 3072 x 3072 の行列を作用させて 3072 次元のベクトルを出力します。
(time_guidance_embed): Flux2TimestepGuidanceEmbeddings(
(time_proj): Timesteps()
(timestep_embedder): TimestepEmbedding(
(linear_1): Linear(in_features=256, out_features=3072, bias=False)
(act): SiLU()
(linear_2): Linear(in_features=3072, out_features=3072, bias=False)
)
)
3.3. Flux2Modulation によるモジュレーション用の shift, scale, gate ベクトルの計算
下記の構成要素は、上記 3.2. の構成要素の出力である timestep を元に生成された 3072 次元の temb (timestep embedding) から、Double Stream Transformer Block 内のモジュレーション、Single Stream Transformer Block 内のモジュレーションで参照される shift, scale, gate ベクトルを計算します。
画像生成 Transformer の処理に timestep 情報を埋め込む目的で使用されます。
(double_stream_modulation_img): Flux2Modulation(
(linear): Linear(in_features=3072, out_features=18432, bias=False)
(act_fn): SiLU()
)
(double_stream_modulation_txt): Flux2Modulation(
(linear): Linear(in_features=3072, out_features=18432, bias=False)
(act_fn): SiLU()
)
(single_stream_modulation): Flux2Modulation(
(linear): Linear(in_features=3072, out_features=9216, bias=False)
(act_fn): SiLU()
)
Flux2Modulation は下記のスクリプトのように 3072 次元の temb (timestep embedding) ベクトルの要素に SiLU() を適用した後、学習済みの重みを持つ行列を作用させています。
class Flux2Modulation(nn.Module):
def __init__(self, dim: int, mod_param_sets: int = 2, bias: bool = False):
super().__init__()
self.mod_param_sets = mod_param_sets
self.linear = nn.Linear(dim, dim * 3 * self.mod_param_sets, bias=bias)
self.act_fn = nn.SiLU()
def forward(self, temb: torch.Tensor) -> torch.Tensor:
mod = self.act_fn(temb)
mod = self.linear(mod)
return mod
3.3.1. Double Stream Transformer Block のモジュレーション処理
double_stream_modulation_img, double_stream_modulation_txt の行列変換は、3072 次元の temb ベクトルを入力として、その 6 倍の 18432 次元のベクトルを出力しています。
これは、Double Stream Transformer Block の構成要素である Flux2TransformerBlock の forward 処理で下記のように参照されています。
# Modulation parameters shape: [1, 1, self.dim]
(shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = Flux2Modulation.split(temb_mod_img, 2)
(c_shift_msa, c_scale_msa, c_gate_msa), (c_shift_mlp, c_scale_mlp, c_gate_mlp) = Flux2Modulation.split(
temb_mod_txt, 2
)
# Img stream
norm_hidden_states = self.norm1(hidden_states)
norm_hidden_states = (1 + scale_msa) * norm_hidden_states + shift_msa
# Conditioning txt stream
norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states)
norm_encoder_hidden_states = (1 + c_scale_msa) * norm_encoder_hidden_states + c_shift_msa
# Attention on concatenated img + txt stream
attention_outputs = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
**joint_attention_kwargs,
)
attn_output, context_attn_output = attention_outputs
# Process attention outputs for the image stream (`hidden_states`).
attn_output = gate_msa * attn_output
hidden_states = hidden_states + attn_output
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
ff_output = self.ff(norm_hidden_states)
hidden_states = hidden_states + gate_mlp * ff_output
# Process attention outputs for the text stream (`encoder_hidden_states`).
context_attn_output = c_gate_msa * context_attn_output
encoder_hidden_states = encoder_hidden_states + context_attn_output
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp
context_ff_output = self.ff_context(norm_encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output
6 倍の 18432 次元のベクトルは、画像用の shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp の 6 つの 3072 次元のベクトルに分割されます。テキスト用のモジュレーションも同様で、c_shift_msa, c_scale_msa, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp の 6 つの 3072 次元のベクトルに分割されます。
shift_msa, scale_msa は Attention 処理の前の Layer Normalization 後に、下記のコードで画像データのベクトルを (1 + scale_msa) 倍し、それに shift_mas を加える形で変換しています。この計算は 3072 次元のベクトル要素ごとに適用されます。テキストデータについても同様です。
norm_hidden_states = (1 + scale_msa) * norm_hidden_states + shift_msa
Attention 処理の後、Attention 層の出力を下記のスクリプトで gate_msa 倍しています。この計算もベクトル要素ごとに適用しています。
attn_output = gate_msa * attn_output
下記のスクリプトのように、self.ff(...)で適用される MLP 層 (SwiGLUが使用されている) の処理の前後でも同様の計算をしています。下記のスクリプトは画像データに MLP 層の処理を適用するコードですが、テキストデータについても同様です。
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
ff_output = self.ff(norm_hidden_states)
hidden_states = hidden_states + gate_mlp * ff_output
上記のように、Double Stream Transformer Block では、Attention 層の前後と MLP 層の前後で参照される shift, scale, gate ベクトルがあり、double_stream_modulation_img, double_stream_modulation_txt の行列変換は、3072 次元の temb ベクトルを入力として、その 6 倍の 18432 次元のベクトルを出力しています。
3.3.2. Single Stream Transformer Block のモジュレーション処理
single_stream_modulation の行列変換は、3072 次元の temb ベクトルを入力として、その 3 倍の 9216 次元のベクトルを出力しています。
これは、Single Stream Transformer Block の構成要素である Flux2SingleTransformerBlock の forward 処理で下記のように参照されています。
mod_shift, mod_scale, mod_gate = Flux2Modulation.split(temb_mod, 1)[0]
norm_hidden_states = self.norm(hidden_states)
norm_hidden_states = (1 + mod_scale) * norm_hidden_states + mod_shift
joint_attention_kwargs = joint_attention_kwargs or {}
attn_output = self.attn(
hidden_states=norm_hidden_states,
image_rotary_emb=image_rotary_emb,
**joint_attention_kwargs,
)
hidden_states = hidden_states + mod_gate * attn_output
Double Stream Transformer Block のモジュレーション処理と異なり、self.attn(...)の前後でのみ mod_shift, mod_scale, mod_gate が参照されるため、6 つではなく 3 つの 3072 次元のベクトルを用意しています。Single Stream Transformer Block の上記のスクリプトの self.attn(...)では、Attention 層と MLP 層の計算が並列に処理されます。
また、Single Stream Transformer Block では各画像小領域 (高さと幅が 1/8 の潜在空間の隣接する 2 x 2 領域をまとめた領域) に対応する画像データの 3072 次元のベクトルとテキストデータの 3072 次元のトークンベクトルをまとめて共通の重み行列で処理しています。そのため、画像とテキストのために異なる shift, scale, gate を計算することもしていません。
3.4. 画像データとテキストデータを 3072 次元のベクトルに変換する行列
下記の構成要素は、Flux2Transformer2DModel の forward(…) メソッドが受け取った画像小領域とテキストトークンのベクトルを 3072 次元のベクトルに変換する行列です。
(x_embedder): Linear(in_features=128, out_features=3072, bias=False) (context_embedder): Linear(in_features=7680, out_features=3072, bias=False)
Flux2Transformer2DModel の forward(…) メソッドの下記のコードで、入力として受け取った二つの多次元配列 hidden_states と encoder_hidden_states の最後の軸の次元を 3072 次元に変換します。
# 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states)
hidden_states = self.x_embedder(hidden_states)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
hidden_states のサイズは torch.Size([1, 4096, 128]) のようになっていて、4096 は生成画像のサイズや入力画像のサイズに応じて変化します。画像小領域の数に対応しています。128 は、潜在空間において各画像小領域を表すデータの次元です。self.x_embedder(hidden_states)は、これを 128 次元から 3072 次元に変換します。
encoder_hidden_states のサイズは torch.Size([1, 512, 7680]) のようになっていて、512 はプロンプト文字列から生成されたテキストトークン列の長さです。長さが 512 に満たないときはパディングし、512 を超えるときは切り捨てているため、この長さ 512 は固定です。7680 は、Qwen3 4B の中間層を含む 37 層 (index は 0 – 36) の出力のうち、index が 9, 18, 27 の中間層の 2560 次元のトークンベクトルの出力を連結したベクトルの次元です。self.context_embedder(encoder_hidden_states)は、これを 7680 次元から 3072 次元に変換します。
3.5. Double Stream Transformer Blocks
Double Stream Transformer Blocks は、下記の構成要素に対応しています。Single Stream Transformer Blocks での処理の前に実行されます。FLUX.2 Klein 4B では、下記のように 5 Block 適用されます。
(transformer_blocks): ModuleList(
(0-4): 5 x Flux2TransformerBlock(
...
)
)
下の図は、Double Stream Transformer Block の処理フローです。context_states は、Flux2TransformerBlock の Python スクリプト内では encoder_hidden_states と記載されています。
画像トークン列に対応する hidden_states とテキストトークン列に対応する context_states とで、別々の重み行列を参照して Query, Key, Value の head が計算されます。head による Attention の計算の結果をまとめた出力を得るときにも、画像とテキストとで別々の重み行列を使用します。Query head と Key head に適用される RMSNorm で参照される scale ベクトルの値にも別々の重みを使用しています。
MLP 層の SwiGLU の計算の重みも画像とテキストで別々の重みを使用しています。
ただし、Attention の計算で Query head, Key head, Value head を参照する際には、画像とテキストの head を区別なく扱って計算しています。画像とテキストのデータが、互いの計算結果に影響を与えることになります。
┌───────────────────────────────────────────────┐
│ Flux2TransformerBlock │
└───────────────────────────────────────────────┘
hidden_states (image tokens) context_states (text tokens)
│ │
│ │
▼ ▼
┌────────────────┐ ┌────────────────────┐
│ LayerNorm │ │ LayerNorm │
│ (norm1) │ │ (norm1_context) │
└────────────────┘ └────────────────────┘
│ │
└──────────────┬───────────────────────┘
│
▼
┌─────────────────────┐
│ Flux2Attention │
│ (Joint Attention) │
│ │
│ hidden <-> context │
└─────────────────────┘
│
┌─────────────┴─────────────┐
│ │
▼ ▼
hidden attention out context attention out
│ │
▼ ▼
hidden residual add context residual add
│ │
▼ ▼
┌────────────────┐ ┌────────────────────┐
│ LayerNorm │ │ LayerNorm │
│ (norm2) │ │ (norm2_context) │
└────────────────┘ └────────────────────┘
│ │
▼ ▼
┌────────────────┐ ┌────────────────────┐
│ MLP(FeedFoward)│ │ MLP(FeedForward) │
│ (ff) │ │ (ff_context) │
│ │ │ │
│ 3072 → 18432 │ │ 3072 → 18432 │
│ SwiGLU │ │ SwiGLU │
│ 9216 → 3072 │ │ 9216 → 3072 │
└────────────────┘ └────────────────────┘
│ │
▼ ▼
hidden residual add context residual add
│ │
▼ ▼
hidden_states_out context_states_out
上の図の Flux2Attention 内部の処理フローは下記のようになります。
Flux2Attention
hidden_states context_states
│ │
│ │
▼ ▼
┌───────────────┐ ┌───────────────┐
│ to_q │ │ add_q_proj │
│ to_k │ │ add_k_proj │
│ to_v │ │ add_v_proj │
└───────────────┘ └───────────────┘
│ │
▼ ▼
RMSNorm(q,k) RMSNorm(q,k)
│ │
└──────────────┬─────────────────┘
│
▼
Attention
(scaled dot-product)
│
▼
┌────────────────────────┐
│ hidden attention out │
│ context attention out │
└────────────────────────┘
│
┌─────────────┴─────────────┐
▼ ▼
to_out Linear to_add_out Linear
3.6. Single Stream Transformer Blocks
Single Stream Transformer Blocks は、下記の構成要素に対応しています。Double Stream Transformer Blocks での処理の後に実行されます。FLUX.2 Klein 4B では、下記のように 20 Block 適用されます。
(single_transformer_blocks): ModuleList(
(0-19): 20 x Flux2SingleTransformerBlock(
(norm): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
(attn): Flux2ParallelSelfAttention(
(to_qkv_mlp_proj): Linear(in_features=3072, out_features=27648, bias=False)
(mlp_act_fn): Flux2SwiGLU(
(gate_fn): SiLU()
)
(norm_q): RMSNorm((128,), eps=1e-06, elementwise_affine=True)
(norm_k): RMSNorm((128,), eps=1e-06, elementwise_affine=True)
(to_out): Linear(in_features=12288, out_features=3072, bias=False)
)
)
)
下の図は、Single Stream Transformer Block の処理フローです。Single Stream Transformer Block では各画像小領域に対応する 3072 次元の画像トークンベクトルとテキストデータの 3072 次元のテキストトークンベクトルを共通の重みで同じように処理しています。そのため、一つのトークンベクトル列 hidden_states のみを入力とする処理フローになっています。
また、Attention 層に MLP 層が続く通常の Transformer ブロックとは異なり、Self Attention 層と MLP 層の処理を並列実行しています。
┌──────────────────────────────────────────────┐
│ Flux2SingleTransformerBlock │
└──────────────────────────────────────────────┘
hidden_states (3072)
│
▼
┌────────────────────┐
│ LayerNorm │
│ norm │
└────────────────────┘
│
▼
┌──────────────────────────────┐
│ Linear : to_qkv_mlp_proj │
│ 3072 → 27648 │
└──────────────────────────────┘
│
▼
┌──────────────────┐
│ split tensor │
│ (Q,K,V | MLP) │
└─────────┬────────┘
│
┌────────────────┴────────────────┐
│ │
▼ ▼
┌─────────────────────┐ ┌─────────────────────┐
│ Self Attention │ │ MLP │
│ │ │ (SwiGLU) │
│ Q,K,V │ │ │
│ │ │ │ gate: SiLU │
│ ├─ norm_q (RMSNorm) │ │ │
│ └─ norm_k (RMSNorm) │ │ │
│ │ │ │
│ attention output │ │ MLP output │
└──────────┬──────────┘ └──────────┬──────────┘
│ │
└──────────────┬──────────────────┘
▼
┌──────────────────────┐
│ concat(attn, mlp) │
│ 12288 │
└──────────┬───────────┘
▼
┌──────────────────────┐
│ Linear : to_out │
│ 12288 → 3072 │
└──────────┬───────────┘
▼
residual add
│
▼
hidden_states_out
Single Stream Transformer Block の処理フローでは、下記の行列変換により、3072 次元のトークンベクトル列を 9倍の 27648 次元のトークンベクトル列に変換します。
(to_qkv_mlp_proj): Linear(in_features=3072, out_features=27648, bias=False)
27648 次元のベクトルのうち、Query heads, Key heads, Value heads のそれぞれに 3072 次元が割り当てられます。残り 18432 次元は SwiGLU の計算に割り当てられ、下記の図のように二分割されて処理されます。
MLP branch (SwiGLU)
input (MLP part from Linear)
dim = 18432
│
▼
┌─────────────┐
│ split │
│ 9216 | 9216 │
└──────┬──────┘
│
┌─────────┴─────────┐
│ │
▼ ▼
tensor A tensor B
dim = 9216 dim = 9216
│ │
│ │
│ ┌──────────┐
│ │ SiLU │
│ │ gate_fn │
│ └────┬─────┘
│ │
│ SiLU(B)
│ │
└─────────┬─────────┘
│
▼
elementwise multiplication
A ⊙ SiLU(B)
│
▼
SwiGLU output
dim = 9216
上記の 9216 次元の出力ベクトル列を Self Attention 層の 3072 次元の出力ベクトル列と結合し、12288 次元のベクトル列にした後、下記の行列を作用させて 3072 次元のベクトル列に戻します。
(to_out): Linear(in_features=12288, out_features=3072, bias=False)
3.7. Single Stream Transformer Blocks の後の処理
Single Stream Transformer Blocks の 20 ブロックの処理の後、Flux2Transformer2DModel の forward メソッドの下記のコードが実行され、潜在空間における画像データを更新する速度ベクトル列 (多次元配列) が出力されます。
# Remove text tokens from concatenated stream
hidden_states = hidden_states[:, num_txt_tokens:, ...]
# 6. Output layers
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
hidden_states = hidden_states[:, num_txt_tokens:, ...]は、2軸目の画像トークン列の前に連結されたテキストトークン列を取り除き、画像トークン列のみにしています。
その後、下記の層の処理を実行し、128 次元の画像更新用の速度ベクトル列 (多次元配列) output を計算します。
(norm_out): AdaLayerNormContinuous(
(silu): SiLU()
(linear): Linear(in_features=3072, out_features=6144, bias=False)
(norm): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
)
(proj_out): Linear(in_features=3072, out_features=128, bias=False)
上記の AdaLayerNormContinuous 層とその後の 3072 次元から 128 次元への行列変換の処理フローを表したのが下の図になります。output として得られるのは画像小領域を 128 次元のベクトルで表したベクトル列を更新する速度ベクトル列になります。128 次元の速度ベクトル列は画像小領域の数だけ存在します。
入力画像がある場合は、output にはそれに対応するデータも存在しますが、それらは呼び出し元の Flux2KleinPipeline の __call__ メソッドで破棄されます。
hidden_states
(3072)
│
│
▼
┌─────────────────────────┐
│ AdaLayerNormContinuous │
└─────────────────────────┘
1. modulation parameter の生成
────────────────────────────────
timestep embedding
(3072)
│
▼
SiLU
│
▼
Linear (3072 → 6144)
│
▼
┌───────────────┐
│ split │
│ │
▼ ▼
scale (3072) shift (3072)
2. LayerNorm
────────────────────────────────
hidden_states (3072)
│
▼
LayerNorm
(elementwise_affine = False)
│
▼
normalized_states
3. modulation
────────────────────────────────
normalized_states
│
▼
normalized_states * (1 + scale) + shift
│
▼
modulated_states (3072)
4. 出力 projection
────────────────────────────────
modulated_states
│
▼
Linear (3072 → 128)
│
▼
output
(dim = 128)