Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 7 additions & 95 deletions src/anima.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,81 +6,13 @@
#include <utility>
#include <vector>

#include "common.hpp"
#include "common_block.hpp"
#include "flux.hpp"
#include "ggml_extend.hpp"
#include "rope.hpp"

namespace Anima {
constexpr int ANIMA_GRAPH_SIZE = 65536;

__STATIC_INLINE__ struct ggml_tensor* patchify_2d(struct ggml_context* ctx,
                                                  struct ggml_tensor* x,
                                                  int64_t patch_size) {
    // Folds every patch_size x patch_size spatial patch into the last axis.
    //   in  (ggml ne order): [W*p, H*p, T, C]
    //   out (ggml ne order): [W,   H,   T, C*p*p]
    // patch_size == 1 returns x untouched; otherwise only T == 1 is accepted.
    if (patch_size == 1) {
        return x;
    }

    const int64_t width    = x->ne[0];
    const int64_t height   = x->ne[1];
    const int64_t frames   = x->ne[2];
    const int64_t channels = x->ne[3];

    GGML_ASSERT(frames == 1);

    const int64_t rows = height / patch_size;
    const int64_t cols = width / patch_size;

    // Both spatial extents have to be exact multiples of the patch size.
    GGML_ASSERT(rows * patch_size == height && cols * patch_size == width);

    const int64_t patch_area = patch_size * patch_size;
    const int64_t n_patches  = cols * rows;

    // With T == 1 the input can be viewed as [W, H, C, N], which is the layout
    // the Flux-style patchify sequence below operates on.
    struct ggml_tensor* t = ggml_reshape_4d(ctx, x, width, height, channels, frames);  // [W, H, C, N]

    // Flux patchify: gather the two in-patch axes next to each other, then move
    // them (together with C) in front of the patch index.
    t = ggml_reshape_4d(ctx, t, patch_size, cols, patch_size, rows * channels * frames);  // [p, w, p, h*C*N]
    t = ggml_ext_cont(ctx, ggml_permute(ctx, t, 0, 2, 1, 3));                              // [p, p, w, h*C*N]
    t = ggml_reshape_4d(ctx, t, patch_area, n_patches, channels, frames);                  // [p*p, h*w, C, N]
    t = ggml_ext_cont(ctx, ggml_permute(ctx, t, 0, 2, 1, 3));                              // [p*p, C, h*w, N]
    t = ggml_reshape_3d(ctx, t, patch_area * channels, n_patches, frames);                 // [C*p*p, h*w, N]

    // Split the patch index back into (w, h) and rotate the folded channel axis
    // to the end so the result is [w, h, T, C*p*p].
    t = ggml_reshape_4d(ctx, t, patch_area * channels, cols, rows, frames);  // [C*p*p, w, h, N]
    return ggml_ext_cont(ctx, ggml_permute(ctx, t, 3, 0, 1, 2));             // [w, h, N, C*p*p]
}

__STATIC_INLINE__ struct ggml_tensor* unpatchify_2d(struct ggml_context* ctx,
                                                    struct ggml_tensor* x,
                                                    int64_t patch_size) {
    // Expands the last axis back into patch_size x patch_size spatial patches.
    //   in  (ggml ne order): [W,   H,   T, C*p*p]
    //   out (ggml ne order): [W*p, H*p, T, C]
    // patch_size == 1 returns x untouched; otherwise only T == 1 is accepted.
    if (patch_size == 1) {
        return x;
    }

    const int64_t cols   = x->ne[0];
    const int64_t rows   = x->ne[1];
    const int64_t frames = x->ne[2];
    const int64_t folded = x->ne[3];  // C * p * p

    GGML_ASSERT(frames == 1);

    const int64_t patch_area = patch_size * patch_size;
    const int64_t channels   = folded / patch_area;

    // The folded axis must hold a whole number of patches per channel.
    GGML_ASSERT(channels * patch_area == folded);

    const int64_t out_w = cols * patch_size;
    const int64_t out_h = rows * patch_size;

    // [w, h, 1, C*p*p] -> [W, H, 1, C]
    struct ggml_tensor* t = ggml_reshape_4d(ctx, x, cols, rows * channels, patch_size, patch_size);  // [w, h*C, p2, p1]
    t = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, t, 2, 0, 3, 1));                               // [p2, w, p1, h*C]
    return ggml_reshape_4d(ctx, t, out_w, out_h, frames, channels);                                   // [W, H, 1, C]
}

__STATIC_INLINE__ struct ggml_tensor* apply_gate(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* gate) {
Expand Down Expand Up @@ -491,7 +423,7 @@ namespace Anima {
int64_t text_embed_dim = 1024;
int64_t num_heads = 16;
int64_t head_dim = 128;
int64_t patch_size = 2;
int patch_size = 2;
int64_t num_layers = 28;
std::vector<int> axes_dim = {44, 42, 42};
int theta = 10000;
Expand Down Expand Up @@ -533,24 +465,10 @@ namespace Anima {
int64_t W = x->ne[0];
int64_t H = x->ne[1];

x = ggml_reshape_4d(ctx->ggml_ctx, x, x->ne[0], x->ne[1], 1, x->ne[2] * x->ne[3]); // [N*C, T, H, W] style

int64_t pad_h = (patch_size - H % patch_size) % patch_size;
int64_t pad_w = (patch_size - W % patch_size) % patch_size;
if (pad_h > 0 || pad_w > 0) {
x = ggml_ext_pad(ctx->ggml_ctx, x, static_cast<int>(pad_w), static_cast<int>(pad_h), 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled);
}

auto padding_mask = ggml_ext_zeros(ctx->ggml_ctx, x->ne[0], x->ne[1], x->ne[2], 1);
x = ggml_concat(ctx->ggml_ctx, x, padding_mask, 3); // concat mask channel

x = patchify_2d(ctx->ggml_ctx, x, patch_size); // [C*4, T, H/2, W/2]
auto padding_mask = ggml_ext_zeros(ctx->ggml_ctx, x->ne[0], x->ne[1], 1, x->ne[3]);
x = ggml_concat(ctx->ggml_ctx, x, padding_mask, 2); // [N, C + 1, H, W]

int64_t w_len = x->ne[0];
int64_t h_len = x->ne[1];
int64_t t_len = x->ne[2];
x = ggml_reshape_3d(ctx->ggml_ctx, x, x->ne[0] * x->ne[1] * x->ne[2], x->ne[3], 1);
x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [N, n_token, C]
x = DiT::pad_and_patchify(ctx, x, patch_size, patch_size); // [N, h*w, (C+1)*ph*pw]

x = x_embedder->forward(ctx, x);

Expand Down Expand Up @@ -586,15 +504,9 @@ namespace Anima {
x = block->forward(ctx, x, encoder_hidden_states, embedded_timestep, temb, image_pe);
}

x = final_layer->forward(ctx, x, embedded_timestep, temb); // [N, n_token, C*4]

x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [n_token, C*4, N]
x = ggml_reshape_4d(ctx->ggml_ctx, x, w_len, h_len, t_len, x->ne[1]); // [C*4, T, H/2, W/2]
x = unpatchify_2d(ctx->ggml_ctx, x, patch_size); // [C, T, H, W]
x = final_layer->forward(ctx, x, embedded_timestep, temb); // [N, h*w, ph*pw*C]

x = ggml_ext_slice(ctx->ggml_ctx, x, 1, 0, H); // [C, T, H, W + pad]
x = ggml_ext_slice(ctx->ggml_ctx, x, 0, 0, W); // [C, T, H, W]
x = ggml_reshape_4d(ctx->ggml_ctx, x, x->ne[0], x->ne[1], x->ne[3], x->ne[2]); // [N, C, H, W]
x = DiT::unpatchify_and_crop(ctx->ggml_ctx, x, H, W, patch_size, patch_size, false); // [N, C, H, W]

return x;
}
Expand Down
6 changes: 3 additions & 3 deletions src/common.hpp → src/common_block.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#ifndef __COMMON_HPP__
#define __COMMON_HPP__
#ifndef __COMMON_BLOCK_HPP__
#define __COMMON_BLOCK_HPP__

#include "ggml_extend.hpp"

Expand Down Expand Up @@ -590,4 +590,4 @@ class VideoResBlock : public ResBlock {
}
};

#endif // __COMMON_HPP__
#endif // __COMMON_BLOCK_HPP__
108 changes: 108 additions & 0 deletions src/common_dit.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
#ifndef __COMMON_DIT_HPP__
#define __COMMON_DIT_HPP__

#include "ggml_extend.hpp"

namespace DiT {
    // Patchify / unpatchify helpers shared by DiT-style models.
    //
    // Shape comments are written outermost-first ([N, C, H, W]), i.e. the
    // reverse of ggml's ne[] order (ne[0] = W ... ne[3] = N).
    //
    // Every helper is `inline`: the definitions live in a header, so without it
    // each translation unit that includes this file would emit its own
    // external definition and violate the one-definition rule.

    // Cuts x into non-overlapping ph x pw patches, one token per patch.
    //   x:       [N, C, H, W]; H must be a multiple of ph and W of pw (asserted)
    //   pw, ph:  patch width and patch height -- width comes FIRST here
    //   returns: [N, h*w, C*ph*pw] if patch_last else [N, h*w, ph*pw*C]
    //            where h = H / ph and w = W / pw
    inline ggml_tensor* patchify(ggml_context* ctx,
                                 ggml_tensor* x,
                                 int pw,
                                 int ph,
                                 bool patch_last = true) {
        // Guard the divisions below against a zero / negative patch size.
        GGML_ASSERT(ph > 0 && pw > 0);

        int64_t N = x->ne[3];
        int64_t C = x->ne[2];
        int64_t H = x->ne[1];
        int64_t W = x->ne[0];
        int64_t h = H / ph;
        int64_t w = W / pw;

        GGML_ASSERT(h * ph == H && w * pw == W);

        x = ggml_reshape_4d(ctx, x, pw, w, ph, h * C * N);     // [N*C*h, ph, w, pw]
        x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3));  // [N*C*h, w, ph, pw]
        x = ggml_reshape_4d(ctx, x, pw * ph, w * h, C, N);     // [N, C, h*w, ph*pw]
        if (patch_last) {
            // Channel-major features: C is the slower axis inside each token.
            x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3));  // [N, h*w, C, ph*pw]
            x = ggml_reshape_3d(ctx, x, pw * ph * C, w * h, N);    // [N, h*w, C*ph*pw]
        } else {
            // Patch-major features: C becomes the fastest axis inside each token.
            x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 2, 0, 1, 3));  // [N, h*w, ph*pw, C]
            x = ggml_reshape_3d(ctx, x, C * pw * ph, w * h, N);              // [N, h*w, ph*pw*C]
        }
        return x;
    }

    // Inverse of patchify().
    //   x:       [N, h*w, C*ph*pw] if patch_last else [N, h*w, ph*pw*C]
    //   h, w:    number of patches along the height / width
    //   ph, pw:  patch height and patch width -- height comes FIRST here
    //   returns: [N, C, h*ph, w*pw]
    inline ggml_tensor* unpatchify(ggml_context* ctx,
                                   ggml_tensor* x,
                                   int64_t h,
                                   int64_t w,
                                   int ph,
                                   int pw,
                                   bool patch_last = true) {
        // Guard the divisions below against a zero / negative patch size.
        GGML_ASSERT(ph > 0 && pw > 0);

        int64_t N = x->ne[2];
        int64_t C = x->ne[0] / ph / pw;
        int64_t H = h * ph;
        int64_t W = w * pw;

        // The feature axis must hold exactly C whole patches.
        GGML_ASSERT(C * ph * pw == x->ne[0]);

        if (patch_last) {
            x = ggml_reshape_4d(ctx, x, pw * ph, C, w * h, N);     // [N, h*w, C, ph*pw]
            x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3));  // [N, C, h*w, ph*pw]
        } else {
            x = ggml_reshape_4d(ctx, x, C, pw * ph, w * h, N);     // [N, h*w, ph*pw, C]
            x = ggml_cont(ctx, ggml_permute(ctx, x, 2, 0, 1, 3));  // [N, C, h*w, ph*pw]
        }

        x = ggml_reshape_4d(ctx, x, pw, ph, w, h * C * N);     // [N*C*h, w, ph, pw]
        x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3));  // [N*C*h, ph, w, pw]
        x = ggml_reshape_4d(ctx, x, W, H, C, N);               // [N, C, h*ph, w*pw]

        return x;
    }

    // Pads x so that ne[1] (H) is a multiple of ph and ne[0] (W) is a multiple
    // of pw. The pad amounts are handed to ggml_ext_pad together with the
    // context's circular_x / circular_y flags; the padding itself is entirely
    // ggml_ext_pad's behaviour.
    inline ggml_tensor* pad_to_patch_size(GGMLRunnerContext* ctx,
                                          ggml_tensor* x,
                                          int ph,
                                          int pw) {
        // Guard the modulo operations below against a zero / negative patch size.
        GGML_ASSERT(ph > 0 && pw > 0);

        int64_t W = x->ne[0];
        int64_t H = x->ne[1];

        // Each pad amount lies in [0, patch size), so the narrowing is lossless.
        int pad_h = static_cast<int>((ph - H % ph) % ph);
        int pad_w = static_cast<int>((pw - W % pw) % pw);
        x = ggml_ext_pad(ctx->ggml_ctx, x, pad_w, pad_h, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled);
        return x;
    }

    // Convenience wrapper: pad_to_patch_size() followed by patchify().
    //   ph, pw:  patch height and patch width -- height comes FIRST here
    //   returns: the token tensor described in patchify()
    inline ggml_tensor* pad_and_patchify(GGMLRunnerContext* ctx,
                                         ggml_tensor* x,
                                         int ph,
                                         int pw,
                                         bool patch_last = true) {
        x = pad_to_patch_size(ctx, x, ph, pw);
        // patchify() takes (pw, ph): width first. Passing (ph, pw) here would
        // silently swap the two for non-square patches.
        x = patchify(ctx->ggml_ctx, x, pw, ph, patch_last);
        return x;
    }

    // Inverse of pad_and_patchify(): rebuilds the padded image, then slices
    // dims 1 and 0 back to the original H and W.
    //   x:       token tensor as produced by pad_and_patchify()
    //   H, W:    height / width of the image BEFORE padding
    //   ph, pw:  patch height and patch width -- height comes FIRST here
    //   returns: [N, C, H, W]
    inline ggml_tensor* unpatchify_and_crop(ggml_context* ctx,
                                            ggml_tensor* x,
                                            int64_t H,
                                            int64_t W,
                                            int ph,
                                            int pw,
                                            bool patch_last = true) {
        // Guard the modulo / division operations below.
        GGML_ASSERT(ph > 0 && pw > 0);

        // Recompute the pad amounts with the same formula as pad_to_patch_size().
        int pad_h = static_cast<int>((ph - H % ph) % ph);
        int pad_w = static_cast<int>((pw - W % pw) % pw);
        int64_t h = ((H + pad_h) / ph);
        int64_t w = ((W + pad_w) / pw);
        x         = unpatchify(ctx, x, h, w, ph, pw, patch_last);  // [N, C, H + pad_h, W + pad_w]
        x         = ggml_ext_slice(ctx, x, 1, 0, H);               // [N, C, H, W + pad_w]
        x         = ggml_ext_slice(ctx, x, 0, 0, W);               // [N, C, H, W]
        return x;
    }
}  // namespace DiT

#endif // __COMMON_DIT_HPP__
3 changes: 1 addition & 2 deletions src/control.hpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
#ifndef __CONTROL_HPP__
#define __CONTROL_HPP__

#include "common.hpp"
#include "ggml_extend.hpp"
#include "common_block.hpp"
#include "model.h"

#define CONTROL_NET_GRAPH_SIZE 1536
Expand Down
Loading
Loading