Created
June 13, 2024 05:41
-
-
Save laksjdjf/742fa0a17415f809bfce737351667102 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
================================================================================================================================================================ | |
Layer (type (var_name)) Input Shape Output Shape Param # Kernel Shape | |
================================================================================================================================================================ | |
SD3Transformer2DModel (SD3Transformer2DModel) -- [1, 16, 128, 128] -- -- | |
├─PatchEmbed (pos_embed) [1, 16, 128, 128] [1, 4096, 1536] -- -- | |
│ └─Conv2d (proj) [1, 16, 128, 128] [1, 1536, 64, 64] 99,840 [2, 2] | |
├─CombinedTimestepTextProjEmbeddings (time_text_embed) [1] [1, 1536] -- -- | |
│ └─Timesteps (time_proj) [1] [1, 256] -- -- | |
│ └─TimestepEmbedding (timestep_embedder) [1, 256] [1, 1536] -- -- | |
│ │ └─Linear (linear_1) [1, 256] [1, 1536] 394,752 -- | |
│ │ └─SiLU (act) [1, 1536] [1, 1536] -- -- | |
│ │ └─Linear (linear_2) [1, 1536] [1, 1536] 2,360,832 -- | |
│ └─PixArtAlphaTextProjection (text_embedder) [1, 2048] [1, 1536] -- -- | |
│ │ └─Linear (linear_1) [1, 2048] [1, 1536] 3,147,264 -- | |
│ │ └─SiLU (act_1) [1, 1536] [1, 1536] -- -- | |
│ │ └─Linear (linear_2) [1, 1536] [1, 1536] 2,360,832 -- | |
├─Linear (context_embedder) [1, 154, 4096] [1, 154, 1536] 6,292,992 -- | |
├─ModuleList (transformer_blocks) -- -- -- -- | |
│ └─JointTransformerBlock (0) -- [1, 154, 1536] -- -- | |
│ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
│ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
│ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ └─Attention (attn) -- [1, 4096, 1536] -- -- | |
│ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─ModuleList (to_out) -- -- -- -- | |
│ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─ModuleList (net) -- -- -- -- | |
│ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- -- | |
│ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 -- | |
│ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- -- | |
│ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 -- | |
│ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ │ └─ModuleList (net) -- -- -- -- | |
│ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- -- | |
│ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 -- | |
│ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- -- | |
│ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 -- | |
│ └─JointTransformerBlock (1) -- [1, 154, 1536] -- -- | |
│ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
│ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
│ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ └─Attention (attn) -- [1, 4096, 1536] -- -- | |
│ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─ModuleList (to_out) -- -- -- -- | |
│ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─ModuleList (net) -- -- -- -- | |
│ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- -- | |
│ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 -- | |
│ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- -- | |
│ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 -- | |
│ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ │ └─ModuleList (net) -- -- -- -- | |
│ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- -- | |
│ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 -- | |
│ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- -- | |
│ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 -- | |
│ └─JointTransformerBlock (2) -- [1, 154, 1536] -- -- | |
│ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
│ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
│ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ └─Attention (attn) -- [1, 4096, 1536] -- -- | |
│ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─ModuleList (to_out) -- -- -- -- | |
│ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─ModuleList (net) -- -- -- -- | |
│ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- -- | |
│ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 -- | |
│ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- -- | |
│ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 -- | |
│ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ │ └─ModuleList (net) -- -- -- -- | |
│ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- -- | |
│ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 -- | |
│ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- -- | |
│ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 -- | |
│ └─JointTransformerBlock (3) -- [1, 154, 1536] -- -- | |
│ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
│ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
│ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ └─Attention (attn) -- [1, 4096, 1536] -- -- | |
│ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─ModuleList (to_out) -- -- -- -- | |
│ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─ModuleList (net) -- -- -- -- | |
│ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- -- | |
│ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 -- | |
│ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- -- | |
│ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 -- | |
│ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ │ └─ModuleList (net) -- -- -- -- | |
│ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- -- | |
│ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 -- | |
│ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- -- | |
│ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 -- | |
│ └─JointTransformerBlock (4) -- [1, 154, 1536] -- -- | |
│ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
│ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
│ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ └─Attention (attn) -- [1, 4096, 1536] -- -- | |
│ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─ModuleList (to_out) -- -- -- -- | |
│ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─ModuleList (net) -- -- -- -- | |
│ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- -- | |
│ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 -- | |
│ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- -- | |
│ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 -- | |
│ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ │ └─ModuleList (net) -- -- -- -- | |
│ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- -- | |
│ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 -- | |
│ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- -- | |
│ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 -- | |
│ └─JointTransformerBlock (5) -- [1, 154, 1536] -- -- | |
│ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
│ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
│ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ └─Attention (attn) -- [1, 4096, 1536] -- -- | |
│ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─ModuleList (to_out) -- -- -- -- | |
│ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─ModuleList (net) -- -- -- -- | |
│ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- -- | |
│ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 -- | |
│ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- -- | |
│ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 -- | |
│ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ │ └─ModuleList (net) -- -- -- -- | |
│ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- -- | |
│ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 -- | |
│ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- -- | |
│ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 -- | |
│ └─JointTransformerBlock (6) -- [1, 154, 1536] -- -- | |
│ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
│ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
│ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ └─Attention (attn) -- [1, 4096, 1536] -- -- | |
│ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─ModuleList (to_out) -- -- -- -- | |
│ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─ModuleList (net) -- -- -- -- | |
│ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- -- | |
│ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 -- | |
│ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- -- | |
│ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 -- | |
│ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ │ └─ModuleList (net) -- -- -- -- | |
│ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- -- | |
│ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 -- | |
│ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- -- | |
│ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 -- | |
│ └─JointTransformerBlock (7) -- [1, 154, 1536] -- -- | |
│ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
│ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
│ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ └─Attention (attn) -- [1, 4096, 1536] -- -- | |
│ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─ModuleList (to_out) -- -- -- -- | |
│ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─ModuleList (net) -- -- -- -- | |
│ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- -- | |
│ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 -- | |
│ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- -- | |
│ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 -- | |
│ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ │ └─ModuleList (net) -- -- -- -- | |
│ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- -- | |
│ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 -- | |
│ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- -- | |
│ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 -- | |
│ └─JointTransformerBlock (8) -- [1, 154, 1536] -- -- | |
│ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
│ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
│ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ └─Attention (attn) -- [1, 4096, 1536] -- -- | |
│ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─ModuleList (to_out) -- -- -- -- | |
│ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─ModuleList (net) -- -- -- -- | |
│ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- -- | |
│ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 -- | |
│ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- -- | |
│ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 -- | |
│ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ │ └─ModuleList (net) -- -- -- -- | |
│ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- -- | |
│ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 -- | |
│ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- -- | |
│ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 -- | |
│ └─JointTransformerBlock (9) -- [1, 154, 1536] -- -- | |
│ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
│ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
│ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ └─Attention (attn) -- [1, 4096, 1536] -- -- | |
│ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─ModuleList (to_out) -- -- -- -- | |
│ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─ModuleList (net) -- -- -- -- | |
│ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- -- | |
│ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 -- | |
│ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- -- | |
│ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 -- | |
│ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ │ └─ModuleList (net) -- -- -- -- | |
│ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- -- | |
│ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 -- | |
│ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- -- | |
│ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 -- | |
│ └─JointTransformerBlock (10) -- [1, 154, 1536] -- -- | |
│ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
│ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
│ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ └─Attention (attn) -- [1, 4096, 1536] -- -- | |
│ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─ModuleList (to_out) -- -- -- -- | |
│ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─ModuleList (net) -- -- -- -- | |
│ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- -- | |
│ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 -- | |
│ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- -- | |
│ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 -- | |
│ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ │ └─ModuleList (net) -- -- -- -- | |
│ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- -- | |
│ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 -- | |
│ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- -- | |
│ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 -- | |
│ └─JointTransformerBlock (11) -- [1, 154, 1536] -- -- | |
│ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
│ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
│ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ └─Attention (attn) -- [1, 4096, 1536] -- -- | |
│ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─ModuleList (to_out) -- -- -- -- | |
│ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─ModuleList (net) -- -- -- -- | |
│ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- -- | |
│ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 -- | |
│ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- -- | |
│ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 -- | |
│ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ │ └─ModuleList (net) -- -- -- -- | |
│ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- -- | |
│ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 -- | |
│ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- -- | |
│ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 -- | |
│ └─JointTransformerBlock (12) -- [1, 154, 1536] -- -- | |
│ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
│ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
│ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ └─Attention (attn) -- [1, 4096, 1536] -- -- | |
│ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─ModuleList (to_out) -- -- -- -- | |
│ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─ModuleList (net) -- -- -- -- | |
│ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- -- | |
│ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 -- | |
│ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- -- | |
│ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 -- | |
│ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ │ └─ModuleList (net) -- -- -- -- | |
│ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- -- | |
│ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 -- | |
│ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- -- | |
│ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 -- | |
│ └─JointTransformerBlock (13) -- [1, 154, 1536] -- -- | |
│ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
│ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
│ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ └─Attention (attn) -- [1, 4096, 1536] -- -- | |
│ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─ModuleList (to_out) -- -- -- -- | |
│ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─ModuleList (net) -- -- -- -- | |
│ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- -- | |
│ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 -- | |
│ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- -- | |
│ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 -- | |
│ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ │ └─ModuleList (net) -- -- -- -- | |
│ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- -- | |
│ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 -- | |
│ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- -- | |
│ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 -- | |
│ └─JointTransformerBlock (14) -- [1, 154, 1536] -- -- | |
│ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
│ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
│ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ └─Attention (attn) -- [1, 4096, 1536] -- -- | |
│ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─ModuleList (to_out) -- -- -- -- | |
│ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─ModuleList (net) -- -- -- -- | |
│ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- -- | |
│ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 -- | |
│ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- -- | |
│ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 -- | |
│ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ │ └─ModuleList (net) -- -- -- -- | |
│ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- -- | |
│ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 -- | |
│ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- -- | |
│ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 -- | |
│ └─JointTransformerBlock (15) -- [1, 154, 1536] -- -- | |
│ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
│ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
│ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ └─Attention (attn) -- [1, 4096, 1536] -- -- | |
│ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─ModuleList (to_out) -- -- -- -- | |
│ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─ModuleList (net) -- -- -- -- | |
│ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- -- | |
│ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 -- | |
│ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- -- | |
│ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 -- | |
│ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ │ └─ModuleList (net) -- -- -- -- | |
│ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- -- | |
│ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 -- | |
│ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- -- | |
│ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 -- | |
│ └─JointTransformerBlock (16) -- [1, 154, 1536] -- -- | |
│ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
│ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
│ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ └─Attention (attn) -- [1, 4096, 1536] -- -- | |
│ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─ModuleList (to_out) -- -- -- -- | |
│ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─ModuleList (net) -- -- -- -- | |
│ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- -- | |
│ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 -- | |
│ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- -- | |
│ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 -- | |
│ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ │ └─ModuleList (net) -- -- -- -- | |
│ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- -- | |
│ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 -- | |
│ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- -- | |
│ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 -- | |
│ └─JointTransformerBlock (17) -- [1, 154, 1536] -- -- | |
│ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
│ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
│ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ └─Attention (attn) -- [1, 4096, 1536] -- -- | |
│ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─ModuleList (to_out) -- -- -- -- | |
│ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─ModuleList (net) -- -- -- -- | |
│ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- -- | |
│ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 -- | |
│ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- -- | |
│ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 -- | |
│ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ │ └─ModuleList (net) -- -- -- -- | |
│ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- -- | |
│ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 -- | |
│ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- -- | |
│ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 -- | |
│ └─JointTransformerBlock (18) -- [1, 154, 1536] -- -- | |
│ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
│ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
│ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ └─Attention (attn) -- [1, 4096, 1536] -- -- | |
│ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─ModuleList (to_out) -- -- -- -- | |
│ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─ModuleList (net) -- -- -- -- | |
│ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- -- | |
│ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 -- | |
│ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- -- | |
│ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 -- | |
│ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ │ └─ModuleList (net) -- -- -- -- | |
│ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- -- | |
│ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 -- | |
│ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- -- | |
│ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 -- | |
│ └─JointTransformerBlock (19) -- [1, 154, 1536] -- -- | |
│ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
│ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
│ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ └─Attention (attn) -- [1, 4096, 1536] -- -- | |
│ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─ModuleList (to_out) -- -- -- -- | |
│ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─ModuleList (net) -- -- -- -- | |
│ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- -- | |
│ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 -- | |
│ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- -- | |
│ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 -- | |
│ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ │ └─ModuleList (net) -- -- -- -- | |
│ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- -- | |
│ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 -- | |
│ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- -- | |
│ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 -- | |
│ └─JointTransformerBlock (20) -- [1, 154, 1536] -- -- | |
│ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
│ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
│ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ └─Attention (attn) -- [1, 4096, 1536] -- -- | |
│ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─ModuleList (to_out) -- -- -- -- | |
│ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─ModuleList (net) -- -- -- -- | |
│ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- -- | |
│ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 -- | |
│ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- -- | |
│ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 -- | |
│ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ │ └─ModuleList (net) -- -- -- -- | |
│ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- -- | |
│ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 -- | |
│ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- -- | |
│ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 -- | |
│ └─JointTransformerBlock (21) -- [1, 154, 1536] -- -- | |
│ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
│ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
│ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ └─Attention (attn) -- [1, 4096, 1536] -- -- | |
│ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─ModuleList (to_out) -- -- -- -- | |
│ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─ModuleList (net) -- -- -- -- | |
│ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- -- | |
│ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 -- | |
│ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- -- | |
│ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 -- | |
│ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ │ └─ModuleList (net) -- -- -- -- | |
│ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- -- | |
│ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 -- | |
│ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- -- | |
│ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 -- | |
│ └─JointTransformerBlock (22) -- [1, 154, 1536] -- -- | |
│ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
│ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
│ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ └─Attention (attn) -- [1, 4096, 1536] -- -- | |
│ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─ModuleList (to_out) -- -- -- -- | |
│ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─ModuleList (net) -- -- -- -- | |
│ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- -- | |
│ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 -- | |
│ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- -- | |
│ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 -- | |
│ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ │ └─ModuleList (net) -- -- -- -- | |
│ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- -- | |
│ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 -- | |
│ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- -- | |
│ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 -- | |
│ └─JointTransformerBlock (23) -- -- -- -- | |
│ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
│ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ └─AdaLayerNormContinuous (norm1_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
│ │ │ └─Linear (linear) [1, 1536] [1, 3072] 4,721,664 -- | |
│ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- -- | |
│ │ └─Attention (attn) -- [1, 4096, 1536] -- -- | |
│ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
│ │ │ └─ModuleList (to_out) -- -- -- -- | |
│ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
│ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ │ │ └─ModuleList (net) -- -- -- -- | |
│ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- -- | |
│ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 -- | |
│ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- -- | |
│ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 -- | |
├─AdaLayerNormContinuous (norm_out) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
│ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
│ └─Linear (linear) [1, 1536] [1, 3072] 4,721,664 -- | |
│ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
├─Linear (proj_out) [1, 4096, 1536] [1, 4096, 64] 98,368 -- | |
================================================================================================================================================================ | |
Total params: 2,028,328,000 | |
Trainable params: 2,028,328,000 | |
Non-trainable params: 0 | |
Total mult-adds (G): 2.44 | |
================================================================================================================================================================ | |
Input size (MB): 1.79 | |
Forward/backward pass size (MB): 5663.46 | |
Params size (MB): 4056.66 | |
Estimated Total Size (MB): 9721.90 | |
================================================================================================================================================================ |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment