Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 127 additions & 30 deletions src/model/vae/wan_vae.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,26 @@ namespace WAN {
}
};


// Conv2d variant that wraps the inherited Conv2d::forward between two
// identical axis exchanges: ggml dims 2 and 3 of the input are swapped before
// the convolution, and dims 2 and 3 of the result are swapped back afterwards.
// NOTE(review): presumably this lets a tensor laid out for the 3D (video) path
// be convolved frame-by-frame by treating the frame axis as the Conv2d batch
// axis — confirm against the layout Conv2d::forward expects.
class Conv2dBut3d : public Conv2d {
public:
    // Reuse every Conv2d constructor unchanged.
    using Conv2d::Conv2d;

    // ctx: runner context; its ggml_ctx is used to create the permute/cont nodes.
    // x:   input tensor.
    // Returns the Conv2d result with dims 2 and 3 exchanged, made contiguous.
    ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) {
        // Exchange dims 2 and 3 of `t`, then materialize the permuted view
        // with ggml_cont so the consumer receives a contiguous tensor.
        const auto swap_dims_2_3 = [ctx](ggml_tensor* t) {
            return ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, t, 0, 1, 3, 2));
        };

        // Order of graph ops: permute -> cont -> Conv2d -> permute -> cont.
        return swap_dims_2_3(Conv2d::forward(ctx, swap_dims_2_3(x)));
    }
};


class Resample : public GGMLBlock {
protected:
int64_t dim;
Expand Down Expand Up @@ -338,19 +358,32 @@ namespace WAN {
protected:
int64_t in_dim;
int64_t out_dim;
bool is_2D;

public:
ResidualBlock(int64_t in_dim, int64_t out_dim)
: in_dim(in_dim), out_dim(out_dim) {
ResidualBlock(int64_t in_dim, int64_t out_dim, bool is_2D = false)
: in_dim(in_dim), out_dim(out_dim), is_2D(is_2D) {
blocks["residual.0"] = std::shared_ptr<GGMLBlock>(new RMS_norm(in_dim));
// residual.1 is nn.SiLU()
blocks["residual.2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(in_dim, out_dim, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
if (is_2D) {
blocks["residual.2"] = std::shared_ptr<GGMLBlock>(new Conv2dBut3d(in_dim, out_dim, {3, 3}, {1, 1}, {1, 1}));
} else {
blocks["residual.2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(in_dim, out_dim, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
}
blocks["residual.3"] = std::shared_ptr<GGMLBlock>(new RMS_norm(out_dim));
// residual.4 is nn.SiLU()
// residual.5 is nn.Dropout()
blocks["residual.6"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(out_dim, out_dim, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
if (is_2D) {
blocks["residual.6"] = std::shared_ptr<GGMLBlock>(new Conv2dBut3d(out_dim, out_dim, {3, 3}, {1, 1}, {1, 1}));
} else {
blocks["residual.6"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(out_dim, out_dim, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
}
if (in_dim != out_dim) {
blocks["shortcut"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(in_dim, out_dim, {1, 1, 1}));
if (is_2D) {
blocks["shortcut"] = std::shared_ptr<GGMLBlock>(new Conv2dBut3d(in_dim, out_dim, {1, 1}));
} else {
blocks["shortcut"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(in_dim, out_dim, {1, 1, 1}));
}
}
}

Expand All @@ -363,9 +396,15 @@ namespace WAN {
GGML_ASSERT(b == 1);
ggml_tensor* h = x;
if (in_dim != out_dim) {
auto shortcut = std::dynamic_pointer_cast<CausalConv3d>(blocks["shortcut"]);
if (is_2D) {
auto shortcut = std::dynamic_pointer_cast<Conv2dBut3d>(blocks["shortcut"]);

h = shortcut->forward(ctx, x);
} else {
auto shortcut = std::dynamic_pointer_cast<CausalConv3d>(blocks["shortcut"]);

h = shortcut->forward(ctx, x);
h = shortcut->forward(ctx, x);
}
}

for (int i = 0; i < 7; i++) {
Expand All @@ -385,8 +424,13 @@ namespace WAN {
cache_x,
2);
}
if (is_2D) {
auto layer = std::dynamic_pointer_cast<Conv2dBut3d>(blocks["residual." + std::to_string(i)]);

x = layer->forward(ctx, x, feat_cache[idx]);
x = layer->forward(ctx, x);
} else {
x = layer->forward(ctx, x, feat_cache[idx]);
}
feat_cache[idx] = cache_x;
feat_idx += 1;
}
Expand Down Expand Up @@ -466,21 +510,23 @@ namespace WAN {
protected:
int mult;
bool up_flag;
bool is_2D = false;

public:
Up_ResidualBlock(int64_t in_dim,
int64_t out_dim,
int mult,
bool temperal_upsample = false,
bool up_flag = false)
: mult(mult), up_flag(up_flag) {
bool up_flag = false,
bool is_2D = false)
: mult(mult), up_flag(up_flag), is_2D(is_2D) {
if (up_flag) {
blocks["avg_shortcut"] = std::shared_ptr<GGMLBlock>(new DupUp3D(in_dim, out_dim, temperal_upsample ? 2 : 1, up_flag ? 2 : 1));
}

int i = 0;
for (; i < mult; i++) {
blocks["upsamples." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new ResidualBlock(in_dim, out_dim));
blocks["upsamples." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new ResidualBlock(in_dim, out_dim, is_2D));
in_dim = out_dim;
}
if (up_flag) {
Expand Down Expand Up @@ -758,33 +804,40 @@ namespace WAN {
std::vector<int> dim_mult;
int num_res_blocks;
std::vector<bool> temperal_upsample;
bool is_2D = false;

public:
Decoder3d(int64_t dim = 128,
int64_t z_dim = 4,
std::vector<int> dim_mult = {1, 2, 4, 4},
int num_res_blocks = 2,
std::vector<bool> temperal_upsample = {true, true, false},
bool wan2_2 = false)
bool wan2_2 = false,
bool is_2D = false)
: dim(dim),
z_dim(z_dim),
dim_mult(dim_mult),
num_res_blocks(num_res_blocks),
temperal_upsample(temperal_upsample),
wan2_2(wan2_2) {
wan2_2(wan2_2),
is_2D(is_2D) {
// attn_scales is always []
std::vector<int64_t> dims = {dim_mult[dim_mult.size() - 1] * dim};
for (int i = static_cast<int>(dim_mult.size()) - 1; i >= 0; i--) {
dims.push_back(dim * dim_mult[i]);
}

// init block
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(z_dim, dims[0], {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
if(is_2D){
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new Conv2dBut3d(z_dim, dims[0], {3, 3}, {1, 1}, {1, 1}));
}else{
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(z_dim, dims[0], {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
}

// middle blocks
blocks["middle.0"] = std::shared_ptr<GGMLBlock>(new ResidualBlock(dims[0], dims[0]));
blocks["middle.0"] = std::shared_ptr<GGMLBlock>(new ResidualBlock(dims[0], dims[0], is_2D));
blocks["middle.1"] = std::shared_ptr<GGMLBlock>(new AttentionBlock(dims[0]));
blocks["middle.2"] = std::shared_ptr<GGMLBlock>(new ResidualBlock(dims[0], dims[0]));
blocks["middle.2"] = std::shared_ptr<GGMLBlock>(new ResidualBlock(dims[0], dims[0], is_2D));

// upsample blocks
int index = 0;
Expand All @@ -799,15 +852,16 @@ namespace WAN {
out_dim,
num_res_blocks + 1,
t_up_flag,
i != dim_mult.size() - 1));
i != dim_mult.size() - 1,
is_2D));

blocks["upsamples." + std::to_string(index++)] = block;
} else {
if (i == 1 || i == 2 || i == 3) {
in_dim = in_dim / 2;
}
for (int j = 0; j < num_res_blocks + 1; j++) {
auto block = std::shared_ptr<GGMLBlock>(new ResidualBlock(in_dim, out_dim));
auto block = std::shared_ptr<GGMLBlock>(new ResidualBlock(in_dim, out_dim, is_2D));
blocks["upsamples." + std::to_string(index++)] = block;
in_dim = out_dim;
}
Expand All @@ -822,12 +876,13 @@ namespace WAN {

// output blocks
blocks["head.0"] = std::shared_ptr<GGMLBlock>(new RMS_norm(out_dim));
int64_t final_dim = wan2_2 ? 12 : 3;
// head.1 is nn.SiLU()
if (wan2_2) {
blocks["head.2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(out_dim, 12, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
if (is_2D) {
blocks["head.2"] = std::shared_ptr<GGMLBlock>(new Conv2dBut3d(out_dim, final_dim, {3, 3}, {1, 1}, {1, 1}));

} else {
blocks["head.2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(out_dim, 3, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
blocks["head.2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(out_dim, final_dim, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
}
}

Expand All @@ -847,7 +902,10 @@ namespace WAN {
auto head_2 = std::dynamic_pointer_cast<CausalConv3d>(blocks["head.2"]);

// conv1
if (feat_cache.size() > 0) {
if (is_2D) {
auto conv1 = std::dynamic_pointer_cast<Conv2dBut3d>(blocks["conv1"]);
x = conv1->forward(ctx, x);
} else if (feat_cache.size() > 0) {
int idx = feat_idx;
auto cache_x = ggml_ext_slice(ctx->ggml_ctx, x, 2, -CACHE_T, x->ne[2]);
if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) {
Expand Down Expand Up @@ -902,7 +960,10 @@ namespace WAN {
// head
x = head_0->forward(ctx, x);
x = ggml_silu(ctx->ggml_ctx, x);
if (feat_cache.size() > 0) {
if (is_2D) {
auto head_2 = std::dynamic_pointer_cast<Conv2dBut3d>(blocks["head.2"]);
x = head_2->forward(ctx, x);
} else if (feat_cache.size() > 0) {
int idx = feat_idx;
auto cache_x = ggml_ext_slice(ctx->ggml_ctx, x, 2, -CACHE_T, x->ne[2]);
if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) {
Expand Down Expand Up @@ -935,6 +996,7 @@ namespace WAN {
int num_res_blocks = 2;
std::vector<bool> temperal_upsample = {true, true, false};
std::vector<bool> temperal_downsample = {false, true, true};
bool is_2D = false;

int _conv_num = 33;
int _conv_idx = 0;
Expand All @@ -951,8 +1013,8 @@ namespace WAN {
}

public:
WanVAE(bool decode_only = true, bool wan2_2 = false)
: decode_only(decode_only), wan2_2(wan2_2) {
WanVAE(bool decode_only = true, bool wan2_2 = false, bool is_2D = false)
: decode_only(decode_only), wan2_2(wan2_2), is_2D(is_2D) {
// attn_scales is always []
if (wan2_2) {
dim = 160;
Expand All @@ -962,12 +1024,28 @@ namespace WAN {
_conv_num = 34;
_enc_conv_num = 26;
}

if(is_2D){
temperal_upsample = {false, false, false};
temperal_downsample = {false, false, false};
// TODO : encode 2D
decode_only = true;
}

if (!decode_only) {
blocks["encoder"] = std::shared_ptr<GGMLBlock>(new Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, temperal_downsample, wan2_2));
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(z_dim * 2, z_dim * 2, {1, 1, 1}));
if (is_2D) {
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new Conv2dBut3d(z_dim * 2, z_dim * 2, {1, 1}));
} else {
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(z_dim * 2, z_dim * 2, {1, 1, 1}));
}
}
blocks["decoder"] = std::shared_ptr<GGMLBlock>(new Decoder3d(dec_dim, z_dim, dim_mult, num_res_blocks, temperal_upsample, wan2_2, is_2D));
if (is_2D) {
blocks["conv2"] = std::shared_ptr<GGMLBlock>(new Conv2dBut3d(z_dim, z_dim, {1, 1}));
} else {
blocks["conv2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(z_dim, z_dim, {1, 1, 1}));
}
blocks["decoder"] = std::shared_ptr<GGMLBlock>(new Decoder3d(dec_dim, z_dim, dim_mult, num_res_blocks, temperal_upsample, wan2_2));
blocks["conv2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(z_dim, z_dim, {1, 1, 1}));
}

static ggml_tensor* patchify(ggml_context* ctx,
Expand Down Expand Up @@ -1073,7 +1151,13 @@ namespace WAN {
auto conv2 = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv2"]);

int64_t iter_ = z->ne[2];
auto x = conv2->forward(ctx, z);
auto x = z;
if(is_2D){
auto conv2 = std::dynamic_pointer_cast<Conv2dBut3d>(blocks["conv2"]);
x = conv2->forward(ctx, z);
} else {
x = conv2->forward(ctx, z);
}
// sd::ggml_graph_cut::mark_graph_cut(x, "wan_vae.decode.prelude", "x");
ggml_tensor* out;
for (int i = 0; i < iter_; i++) {
Expand Down Expand Up @@ -1129,7 +1213,20 @@ namespace WAN {
bool decode_only = false,
SDVersion version = VERSION_WAN2,
std::shared_ptr<RunnerWeightManager> weight_manager = nullptr)
: VAE(version, backend, prefix, weight_manager), decode_only(decode_only), ae(decode_only, version == VERSION_WAN2_2_TI2V) {
: VAE(version, backend, prefix, weight_manager), decode_only(decode_only) {
bool is_2D = false;
for (const auto& [name, tensor_storage] : tensor_storage_map) {
if (ends_with(name, "decoder.conv1.weight")) {
if (tensor_storage.ne[2] > 3) {
is_2D = true;
}
break;
}
}
if (is_2D) {
LOG_DEBUG("USING 2D VAE");
}
ae = WanVAE(decode_only, version == VERSION_WAN2_2_TI2V, is_2D);
ae.init(params_ctx, tensor_storage_map, prefix);
}

Expand Down
Loading