From e72f41bf0ab2433f0f4a7c3fdf189ba744578a87 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Sat, 27 Jun 2026 00:47:09 +0200 Subject: [PATCH 1/2] feat: support Qwen2D VAE --- src/model/vae/wan_vae.hpp | 157 ++++++++++++++++++++++++++++++-------- 1 file changed, 127 insertions(+), 30 deletions(-) diff --git a/src/model/vae/wan_vae.hpp b/src/model/vae/wan_vae.hpp index 8a845c7ca..8ecb98c31 100644 --- a/src/model/vae/wan_vae.hpp +++ b/src/model/vae/wan_vae.hpp @@ -113,6 +113,26 @@ namespace WAN { } }; + + class Conv2dBut3d : public Conv2d { + public: + using Conv2d::Conv2d; + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) { + ggml_tensor* x_swapped = ggml_permute(ctx->ggml_ctx, x, 0, 1, 3, 2); + x_swapped = ggml_cont(ctx->ggml_ctx, x_swapped); + + ggml_tensor* out = Conv2d::forward(ctx, x_swapped); + + ggml_tensor* out_swapped = ggml_permute(ctx->ggml_ctx, out, 0, 1, 3, 2); + + out_swapped = ggml_cont(ctx->ggml_ctx, out_swapped); + + return out_swapped; + } + }; + + class Resample : public GGMLBlock { protected: int64_t dim; @@ -338,19 +358,32 @@ namespace WAN { protected: int64_t in_dim; int64_t out_dim; + bool is_2D; public: - ResidualBlock(int64_t in_dim, int64_t out_dim) - : in_dim(in_dim), out_dim(out_dim) { + ResidualBlock(int64_t in_dim, int64_t out_dim, bool is_2D = false) + : in_dim(in_dim), out_dim(out_dim), is_2D(is_2D) { blocks["residual.0"] = std::shared_ptr(new RMS_norm(in_dim)); // residual.1 is nn.SiLU() - blocks["residual.2"] = std::shared_ptr(new CausalConv3d(in_dim, out_dim, {3, 3, 3}, {1, 1, 1}, {1, 1, 1})); + if (is_2D) { + blocks["residual.2"] = std::shared_ptr(new Conv2dBut3d(in_dim, out_dim, {3, 3}, {1, 1}, {1, 1})); + } else { + blocks["residual.2"] = std::shared_ptr(new CausalConv3d(in_dim, out_dim, {3, 3, 3}, {1, 1, 1}, {1, 1, 1})); + } blocks["residual.3"] = std::shared_ptr(new RMS_norm(out_dim)); // residual.4 is nn.SiLU() // residual.5 is nn.Dropout() - blocks["residual.6"] = std::shared_ptr(new CausalConv3d(out_dim, out_dim, {3, 3, 3}, {1, 1, 1}, {1, 1, 1})); + if (is_2D) { + blocks["residual.6"] = std::shared_ptr(new Conv2dBut3d(out_dim, out_dim, {3, 3}, {1, 1}, {1, 1})); + } else { + blocks["residual.6"] = std::shared_ptr(new CausalConv3d(out_dim, out_dim, {3, 3, 3}, {1, 1, 1}, {1, 1, 1})); + } if (in_dim != out_dim) { - blocks["shortcut"] = std::shared_ptr(new CausalConv3d(in_dim, out_dim, {1, 1, 1})); + if (is_2D) { + blocks["shortcut"] = std::shared_ptr(new Conv2dBut3d(in_dim, out_dim, {1, 1})); + } else { + blocks["shortcut"] = std::shared_ptr(new CausalConv3d(in_dim, out_dim, {1, 1, 1})); + } } } @@ -363,9 +396,15 @@ namespace WAN { GGML_ASSERT(b == 1); ggml_tensor* h = x; if (in_dim != out_dim) { - auto shortcut = std::dynamic_pointer_cast(blocks["shortcut"]); + if (is_2D) { + auto shortcut = std::dynamic_pointer_cast(blocks["shortcut"]); + + h = shortcut->forward(ctx, x); + } else { + auto shortcut = std::dynamic_pointer_cast(blocks["shortcut"]); - h = shortcut->forward(ctx, x); + h = shortcut->forward(ctx, x); + } } for (int i = 0; i < 7; i++) { @@ -385,8 +424,13 @@ namespace WAN { cache_x, 2); } + if (is_2D) { + auto layer = std::dynamic_pointer_cast(blocks["residual." + std::to_string(i)]); - x = layer->forward(ctx, x, feat_cache[idx]); + x = layer->forward(ctx, x); + } else { + x = layer->forward(ctx, x, feat_cache[idx]); + } feat_cache[idx] = cache_x; feat_idx += 1; } @@ -466,21 +510,23 @@ namespace WAN { protected: int mult; bool up_flag; + bool is_2D = false; public: Up_ResidualBlock(int64_t in_dim, int64_t out_dim, int mult, bool temperal_upsample = false, - bool up_flag = false) - : mult(mult), up_flag(up_flag) { + bool up_flag = false, + bool is_2D = false) + : mult(mult), up_flag(up_flag), is_2D(is_2D) { if (up_flag) { blocks["avg_shortcut"] = std::shared_ptr(new DupUp3D(in_dim, out_dim, temperal_upsample ? 2 : 1, up_flag ? 2 : 1)); } int i = 0; for (; i < mult; i++) { - blocks["upsamples." + std::to_string(i)] = std::shared_ptr(new ResidualBlock(in_dim, out_dim)); + blocks["upsamples." + std::to_string(i)] = std::shared_ptr(new ResidualBlock(in_dim, out_dim, is_2D)); in_dim = out_dim; } if (up_flag) { @@ -758,6 +804,7 @@ namespace WAN { std::vector dim_mult; int num_res_blocks; std::vector temperal_upsample; + bool is_2D = false; public: Decoder3d(int64_t dim = 128, @@ -765,13 +812,15 @@ namespace WAN { std::vector dim_mult = {1, 2, 4, 4}, int num_res_blocks = 2, std::vector temperal_upsample = {true, true, false}, - bool wan2_2 = false) + bool wan2_2 = false, + bool is_2D = false) : dim(dim), z_dim(z_dim), dim_mult(dim_mult), num_res_blocks(num_res_blocks), temperal_upsample(temperal_upsample), - wan2_2(wan2_2) { + wan2_2(wan2_2), + is_2D(is_2D) { // attn_scales is always [] std::vector dims = {dim_mult[dim_mult.size() - 1] * dim}; for (int i = static_cast(dim_mult.size()) - 1; i >= 0; i--) { @@ -779,12 +828,16 @@ namespace WAN { } // init block - blocks["conv1"] = std::shared_ptr(new CausalConv3d(z_dim, dims[0], {3, 3, 3}, {1, 1, 1}, {1, 1, 1})); + if(is_2D){ + blocks["conv1"] = std::shared_ptr(new Conv2dBut3d(z_dim, dims[0], {3, 3}, {1, 1}, {1, 1})); + }else{ + blocks["conv1"] = std::shared_ptr(new CausalConv3d(z_dim, dims[0], {3, 3, 3}, {1, 1, 1}, {1, 1, 1})); + } // middle blocks - blocks["middle.0"] = std::shared_ptr(new ResidualBlock(dims[0], dims[0])); + blocks["middle.0"] = std::shared_ptr(new ResidualBlock(dims[0], dims[0], is_2D)); blocks["middle.1"] = std::shared_ptr(new AttentionBlock(dims[0])); - blocks["middle.2"] = std::shared_ptr(new ResidualBlock(dims[0], dims[0])); + blocks["middle.2"] = std::shared_ptr(new ResidualBlock(dims[0], dims[0], is_2D)); // upsample blocks int index = 0; @@ -799,7 +852,8 @@ namespace WAN { out_dim, num_res_blocks + 1, t_up_flag, - i != dim_mult.size() - 1)); + i != dim_mult.size() - 1, + is_2D)); blocks["upsamples." + std::to_string(index++)] = block; } else { @@ -807,7 +861,7 @@ namespace WAN { in_dim = in_dim / 2; } for (int j = 0; j < num_res_blocks + 1; j++) { - auto block = std::shared_ptr(new ResidualBlock(in_dim, out_dim)); + auto block = std::shared_ptr(new ResidualBlock(in_dim, out_dim, is_2D)); blocks["upsamples." + std::to_string(index++)] = block; in_dim = out_dim; } @@ -822,12 +876,13 @@ namespace WAN { // output blocks blocks["head.0"] = std::shared_ptr(new RMS_norm(out_dim)); + int64_t final_dim = wan2_2 ? 12 : 3; // head.1 is nn.SiLU() - if (wan2_2) { - blocks["head.2"] = std::shared_ptr(new CausalConv3d(out_dim, 12, {3, 3, 3}, {1, 1, 1}, {1, 1, 1})); + if (is_2D) { + blocks["head.2"] = std::shared_ptr(new Conv2dBut3d(out_dim, final_dim, {3, 3}, {1, 1}, {1, 1})); } else { - blocks["head.2"] = std::shared_ptr(new CausalConv3d(out_dim, 3, {3, 3, 3}, {1, 1, 1}, {1, 1, 1})); + blocks["head.2"] = std::shared_ptr(new CausalConv3d(out_dim, final_dim, {3, 3, 3}, {1, 1, 1}, {1, 1, 1})); } } @@ -847,7 +902,10 @@ namespace WAN { auto head_2 = std::dynamic_pointer_cast(blocks["head.2"]); // conv1 - if (feat_cache.size() > 0) { + if (is_2D) { + auto conv1 = std::dynamic_pointer_cast(blocks["conv1"]); + x = conv1->forward(ctx, x); + } else if (feat_cache.size() > 0) { int idx = feat_idx; auto cache_x = ggml_ext_slice(ctx->ggml_ctx, x, 2, -CACHE_T, x->ne[2]); if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) { @@ -902,7 +960,10 @@ namespace WAN { // head x = head_0->forward(ctx, x); x = ggml_silu(ctx->ggml_ctx, x); - if (feat_cache.size() > 0) { + if (is_2D) { + auto head_2 = std::dynamic_pointer_cast(blocks["head.2"]); + x = head_2->forward(ctx, x); + } else if (feat_cache.size() > 0) { int idx = feat_idx; auto cache_x = ggml_ext_slice(ctx->ggml_ctx, x, 2, -CACHE_T, x->ne[2]); if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) { @@ -935,6 +996,7 @@ namespace WAN { int num_res_blocks = 2; std::vector temperal_upsample = {true, true, false}; std::vector temperal_downsample = {false, true, true}; + bool is_2D = false; int _conv_num = 33; int _conv_idx = 0; @@ -951,8 +1013,8 @@ namespace WAN { } public: - WanVAE(bool decode_only = true, bool wan2_2 = false) - : decode_only(decode_only), wan2_2(wan2_2) { + WanVAE(bool decode_only = true, bool wan2_2 = false, bool is_2D = false) + : decode_only(decode_only), wan2_2(wan2_2), is_2D(is_2D) { // attn_scales is always [] if (wan2_2) { dim = 160; @@ -962,12 +1024,28 @@ namespace WAN { _conv_num = 34; _enc_conv_num = 26; } + + if(is_2D){ + temperal_upsample = {false, false, false}; + temperal_downsample = {false, false, false}; + // TODO : encode 2D + decode_only = true; + } + if (!decode_only) { blocks["encoder"] = std::shared_ptr(new Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, temperal_downsample, wan2_2)); - blocks["conv1"] = std::shared_ptr(new CausalConv3d(z_dim * 2, z_dim * 2, {1, 1, 1})); + if (is_2D) { + blocks["conv1"] = std::shared_ptr(new Conv2dBut3d(z_dim * 2, z_dim * 2, {1, 1})); + } else { + blocks["conv1"] = std::shared_ptr(new CausalConv3d(z_dim * 2, z_dim * 2, {1, 1, 1})); + } + } + blocks["decoder"] = std::shared_ptr(new Decoder3d(dec_dim, z_dim, dim_mult, num_res_blocks, temperal_upsample, wan2_2, is_2D)); + if (is_2D) { + blocks["conv2"] = std::shared_ptr(new Conv2dBut3d(z_dim, z_dim, {1, 1})); + } else { + blocks["conv2"] = std::shared_ptr(new CausalConv3d(z_dim, z_dim, {1, 1, 1})); } - blocks["decoder"] = std::shared_ptr(new Decoder3d(dec_dim, z_dim, dim_mult, num_res_blocks, temperal_upsample, wan2_2)); - blocks["conv2"] = std::shared_ptr(new CausalConv3d(z_dim, z_dim, {1, 1, 1})); } static ggml_tensor* patchify(ggml_context* ctx, @@ -1073,7 +1151,13 @@ namespace WAN { auto conv2 = std::dynamic_pointer_cast(blocks["conv2"]); int64_t iter_ = z->ne[2]; - auto x = conv2->forward(ctx, z); + auto x = z; + if(is_2D){ + auto conv2 = std::dynamic_pointer_cast(blocks["conv2"]); + x = conv2->forward(ctx, z); + } else { + x = conv2->forward(ctx, z); + } // sd::ggml_graph_cut::mark_graph_cut(x, "wan_vae.decode.prelude", "x"); ggml_tensor* out; for (int i = 0; i < iter_; i++) { @@ -1129,7 +1213,20 @@ namespace WAN { bool decode_only = false, SDVersion version = VERSION_WAN2, std::shared_ptr weight_manager = nullptr) - : VAE(version, backend, prefix, weight_manager), decode_only(decode_only), ae(decode_only, version == VERSION_WAN2_2_TI2V) { + : VAE(version, backend, prefix, weight_manager), decode_only(decode_only) { + bool is_2D = false; + for (const auto& [name, tensor_storage] : tensor_storage_map) { + if (ends_with(name, "decoder.conv1.weight")) { + if (tensor_storage.ne[2] > 3) { + is_2D = true; + } + break; + } + } + if (is_2D) { + LOG_DEBUG("USING 2D VAE"); + } + ae = WanVAE(decode_only, version == VERSION_WAN2_2_TI2V, is_2D); ae.init(params_ctx, tensor_storage_map, prefix); } From ed2dfbee0e9aecca1222ba7591332668915a7835 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Sat, 27 Jun 2026 11:25:52 +0200 Subject: [PATCH 2/2] encode support --- src/model/vae/wan_vae.hpp | 58 +++++++++++++++++++++++++-------------- 1 file changed, 38 insertions(+), 20 deletions(-) diff --git a/src/model/vae/wan_vae.hpp b/src/model/vae/wan_vae.hpp index 8ecb98c31..1d52fff74 100644 --- a/src/model/vae/wan_vae.hpp +++ b/src/model/vae/wan_vae.hpp @@ -456,13 +456,14 @@ namespace WAN { int64_t out_dim, int mult, bool temperal_downsample = false, - bool down_flag = false) + bool down_flag = false, + bool is_2D = false) : mult(mult), down_flag(down_flag) { blocks["avg_shortcut"] = std::shared_ptr(new AvgDown3D(in_dim, out_dim, temperal_downsample ? 2 : 1, down_flag ? 2 : 1)); int i = 0; for (; i < mult; i++) { - blocks["downsamples." + std::to_string(i)] = std::shared_ptr(new ResidualBlock(in_dim, out_dim)); + blocks["downsamples." + std::to_string(i)] = std::shared_ptr(new ResidualBlock(in_dim, out_dim, is_2D)); in_dim = out_dim; } if (down_flag) { @@ -510,7 +511,6 @@ namespace WAN { protected: int mult; bool up_flag; - bool is_2D = false; public: Up_ResidualBlock(int64_t in_dim, @@ -519,7 +519,7 @@ namespace WAN { bool temperal_upsample = false, bool up_flag = false, bool is_2D = false) - : mult(mult), up_flag(up_flag), is_2D(is_2D) { + : mult(mult), up_flag(up_flag) { if (up_flag) { blocks["avg_shortcut"] = std::shared_ptr(new DupUp3D(in_dim, out_dim, temperal_upsample ? 2 : 1, up_flag ? 2 : 1)); } @@ -638,6 +638,7 @@ namespace WAN { std::vector dim_mult; int num_res_blocks; std::vector temperal_downsample; + bool is_2D = false; public: Encoder3d(int64_t dim = 128, @@ -645,23 +646,26 @@ namespace WAN { std::vector dim_mult = {1, 2, 4, 4}, int num_res_blocks = 2, std::vector temperal_downsample = {false, true, true}, - bool wan2_2 = false) + bool wan2_2 = false, + bool is_2D = false) : dim(dim), z_dim(z_dim), dim_mult(dim_mult), num_res_blocks(num_res_blocks), temperal_downsample(temperal_downsample), - wan2_2(wan2_2) { + wan2_2(wan2_2), + is_2D(is_2D) { // attn_scales is always [] std::vector dims = {dim}; for (int u : dim_mult) { dims.push_back(dim * u); } - if (wan2_2) { - blocks["conv1"] = std::shared_ptr(new CausalConv3d(12, dims[0], {3, 3, 3}, {1, 1, 1}, {1, 1, 1})); + int64_t input_dim = wan2_2 ? 12 : 3; + if (is_2D) { + blocks["conv1"] = std::shared_ptr(new Conv2dBut3d(input_dim, dims[0], {3, 3}, {1, 1}, {1, 1})); } else { - blocks["conv1"] = std::shared_ptr(new CausalConv3d(3, dims[0], {3, 3, 3}, {1, 1, 1}, {1, 1, 1})); + blocks["conv1"] = std::shared_ptr(new CausalConv3d(input_dim, dims[0], {3, 3, 3}, {1, 1, 1}, {1, 1, 1})); } int index = 0; @@ -676,12 +680,13 @@ namespace WAN { out_dim, num_res_blocks, t_down_flag, - i != dim_mult.size() - 1)); + i != dim_mult.size() - 1, + is_2D)); blocks["downsamples." + std::to_string(index++)] = block; } else { for (int j = 0; j < num_res_blocks; j++) { - auto block = std::shared_ptr(new ResidualBlock(in_dim, out_dim)); + auto block = std::shared_ptr(new ResidualBlock(in_dim, out_dim, is_2D)); blocks["downsamples." + std::to_string(index++)] = block; in_dim = out_dim; } @@ -694,13 +699,17 @@ namespace WAN { } } - blocks["middle.0"] = std::shared_ptr(new ResidualBlock(out_dim, out_dim)); + blocks["middle.0"] = std::shared_ptr(new ResidualBlock(out_dim, out_dim, is_2D)); blocks["middle.1"] = std::shared_ptr(new AttentionBlock(out_dim)); - blocks["middle.2"] = std::shared_ptr(new ResidualBlock(out_dim, out_dim)); + blocks["middle.2"] = std::shared_ptr(new ResidualBlock(out_dim, out_dim, is_2D)); blocks["head.0"] = std::shared_ptr(new RMS_norm(out_dim)); // head.1 is nn.SiLU() - blocks["head.2"] = std::shared_ptr(new CausalConv3d(out_dim, z_dim, {3, 3, 3}, {1, 1, 1}, {1, 1, 1})); + if (is_2D) { + blocks["head.2"] = std::shared_ptr(new Conv2dBut3d(out_dim, z_dim, {3, 3}, {1, 1}, {1, 1})); + } else { + blocks["head.2"] = std::shared_ptr(new CausalConv3d(out_dim, z_dim, {3, 3, 3}, {1, 1, 1}, {1, 1, 1})); + } } ggml_tensor* forward(GGMLRunnerContext* ctx, @@ -719,7 +728,10 @@ namespace WAN { auto head_2 = std::dynamic_pointer_cast(blocks["head.2"]); // conv1 - if (feat_cache.size() > 0) { + if (is_2D) { + auto conv1 = std::dynamic_pointer_cast(blocks["conv1"]); + x = conv1->forward(ctx, x); + } else if (feat_cache.size() > 0) { int idx = feat_idx; auto cache_x = ggml_ext_slice(ctx->ggml_ctx, x, 2, -CACHE_T, x->ne[2]); if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) { @@ -774,7 +786,10 @@ namespace WAN { // head x = head_0->forward(ctx, x); x = ggml_silu(ctx->ggml_ctx, x); - if (feat_cache.size() > 0) { + if (is_2D) { + auto head_2 = std::dynamic_pointer_cast(blocks["head.2"]); + x = head_2->forward(ctx, x); + } else if (feat_cache.size() > 0) { int idx = feat_idx; auto cache_x = ggml_ext_slice(ctx->ggml_ctx, x, 2, -CACHE_T, x->ne[2]); if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) { @@ -1028,12 +1043,10 @@ namespace WAN { if(is_2D){ temperal_upsample = {false, false, false}; temperal_downsample = {false, false, false}; - // TODO : encode 2D - decode_only = true; } if (!decode_only) { - blocks["encoder"] = std::shared_ptr(new Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, temperal_downsample, wan2_2)); + blocks["encoder"] = std::shared_ptr(new Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, temperal_downsample, wan2_2, is_2D)); if (is_2D) { blocks["conv1"] = std::shared_ptr(new Conv2dBut3d(z_dim * 2, z_dim * 2, {1, 1})); } else { @@ -1132,7 +1145,12 @@ namespace WAN { out = ggml_concat(ctx->ggml_ctx, out, out_, 2); } } - out = conv1->forward(ctx, out); + if (is_2D) { + auto conv1 = std::dynamic_pointer_cast(blocks["conv1"]); + out = conv1->forward(ctx, out); + } else { + out = conv1->forward(ctx, out); + } auto mu = ggml_ext_chunk(ctx->ggml_ctx, out, 2, 3)[0]; // sd::ggml_graph_cut::mark_graph_cut(mu, "wan_vae.encode.final", "mu"); clear_cache();