Skip to content

Commit b714767

Browse files
ggerganov authored and committed
Add ggml_roll (ggml/1274)
* ggml : add ggml_roll
* use set/get_op_params & std::min
1 parent d860dd9 commit b714767

File tree

5 files changed

+117
-2
lines changed

5 files changed

+117
-2
lines changed

ggml/include/ggml.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,7 @@ extern "C" {
489489
GGML_OP_UPSCALE, // nearest interpolate
490490
GGML_OP_PAD,
491491
GGML_OP_PAD_REFLECT_1D,
492+
GGML_OP_ROLL,
492493
GGML_OP_ARANGE,
493494
GGML_OP_TIMESTEP_EMBEDDING,
494495
GGML_OP_ARGSORT,
@@ -1801,6 +1802,17 @@ extern "C" {
18011802
int p0,
18021803
int p1);
18031804

1805+
// Move tensor elements by an offset given for each dimension. Elements that
1806+
// are shifted beyond the last position are wrapped around to the beginning.
1807+
GGML_API struct ggml_tensor * ggml_roll(
1808+
struct ggml_context * ctx,
1809+
struct ggml_tensor * a,
1810+
int shift0,
1811+
int shift1,
1812+
int shift2,
1813+
int shift3);
1814+
1815+
18041816
// Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
18051817
// timesteps: [N,]
18061818
// return: [N, dim]

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1890,6 +1890,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
18901890
{
18911891
ggml_compute_forward_pad_reflect_1d(params, tensor);
18921892
} break;
1893+
case GGML_OP_ROLL:
1894+
{
1895+
ggml_compute_forward_roll(params, tensor);
1896+
} break;
18931897
case GGML_OP_ARANGE:
18941898
{
18951899
ggml_compute_forward_arange(params, tensor);
@@ -2214,6 +2218,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
22142218
case GGML_OP_UPSCALE:
22152219
case GGML_OP_PAD:
22162220
case GGML_OP_PAD_REFLECT_1D:
2221+
case GGML_OP_ROLL:
22172222
case GGML_OP_ARANGE:
22182223
case GGML_OP_TIMESTEP_EMBEDDING:
22192224
case GGML_OP_ARGSORT:

ggml/src/ggml-cpu/ops.cpp

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6793,6 +6793,73 @@ void ggml_compute_forward_pad_reflect_1d(
67936793
}
67946794
}
67956795

6796+
// ggml_compute_forward_roll
6797+
6798+
// Wrap an index into [0, ne). The input is assumed to be at most one
// period out of range in either direction (|offset| < ne), which the
// ggml_roll() asserts guarantee.
static int64_t ggml_wrap_index(int64_t i, int64_t ne) {
    return i < 0 ? i + ne
         : i >= ne ? i - ne
         : i;
}
6806+
6807+
static void ggml_compute_forward_roll_f32(
6808+
const ggml_compute_params * params,
6809+
ggml_tensor * dst) {
6810+
6811+
const ggml_tensor * src0 = dst->src[0];
6812+
const float * src_data = (const float *) src0->data;
6813+
float * dst_data = (float *) dst->data;
6814+
6815+
GGML_TENSOR_UNARY_OP_LOCALS
6816+
6817+
const int s0 = ggml_get_op_params_i32(dst, 0);
6818+
const int s1 = ggml_get_op_params_i32(dst, 1);
6819+
const int s2 = ggml_get_op_params_i32(dst, 2);
6820+
const int s3 = ggml_get_op_params_i32(dst, 3);
6821+
6822+
const int64_t total = ne1 * ne2 * ne3;
6823+
const int64_t per_thread = (total + params->nth) / params->nth;
6824+
const int64_t start = params->ith * per_thread;
6825+
const int64_t end = std::min(start + per_thread, total);
6826+
6827+
for (int64_t i = start; i < end; ++i) {
6828+
const int64_t i1 = i % ne1;
6829+
const int64_t i2 = (i / ne1) % ne2;
6830+
const int64_t i3 = i / (ne2 * ne1);
6831+
float * dst_row = dst_data + (i3*nb3 + i2*nb2 + i1*nb1) / sizeof(float);
6832+
6833+
const int64_t i01 = ggml_wrap_index(i1 - s1, ne01);
6834+
const int64_t i02 = ggml_wrap_index(i2 - s2, ne02);
6835+
const int64_t i03 = ggml_wrap_index(i3 - s3, ne03);
6836+
const float * src_row = src_data + (i03*nb03 + i02*nb02 + i01*nb01) / sizeof(float);
6837+
6838+
const int64_t s = ggml_wrap_index(-s0, ne00);
6839+
const int64_t n = ne00 - s;
6840+
ggml_vec_cpy_f32(n, dst_row, src_row + s);
6841+
ggml_vec_cpy_f32(s, dst_row + n, src_row);
6842+
}
6843+
}
6844+
6845+
void ggml_compute_forward_roll(
6846+
const ggml_compute_params * params,
6847+
ggml_tensor * dst) {
6848+
6849+
const ggml_tensor * src0 = dst->src[0];
6850+
6851+
switch (src0->type) {
6852+
case GGML_TYPE_F32:
6853+
{
6854+
ggml_compute_forward_roll_f32(params, dst);
6855+
} break;
6856+
default:
6857+
{
6858+
GGML_ABORT("fatal error");
6859+
}
6860+
}
6861+
}
6862+
67966863
// ggml_compute_forward_arange
67976864

67986865
static void ggml_compute_forward_arange_f32(

ggml/src/ggml-cpu/ops.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ void ggml_compute_forward_pool_2d_back(const struct ggml_compute_params * params
7272
void ggml_compute_forward_upscale(const struct ggml_compute_params * params, struct ggml_tensor * dst);
7373
void ggml_compute_forward_pad(const struct ggml_compute_params * params, struct ggml_tensor * dst);
7474
void ggml_compute_forward_pad_reflect_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
75+
void ggml_compute_forward_roll(const struct ggml_compute_params * params, struct ggml_tensor * dst);
7576
void ggml_compute_forward_arange(const struct ggml_compute_params * params, struct ggml_tensor * dst);
7677
void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst);
7778
void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst);

ggml/src/ggml.c

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -955,6 +955,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
955955
"UPSCALE",
956956
"PAD",
957957
"PAD_REFLECT_1D",
958+
"ROLL",
958959
"ARANGE",
959960
"TIMESTEP_EMBEDDING",
960961
"ARGSORT",
@@ -985,7 +986,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
985986
"OPT_STEP_ADAMW",
986987
};
987988

988-
static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
989+
static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
989990

990991
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
991992
"none",
@@ -1050,6 +1051,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
10501051
"upscale(x)",
10511052
"pad(x)",
10521053
"pad_reflect_1d(x)",
1054+
"roll(x)",
10531055
"arange(start, stop, step)",
10541056
"timestep_embedding(timesteps, dim, max_period)",
10551057
"argsort(x)",
@@ -1080,7 +1082,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
10801082
"adamw(x)",
10811083
};
10821084

1083-
static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
1085+
static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
10841086

10851087
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
10861088

@@ -4341,6 +4343,34 @@ struct ggml_tensor * ggml_pad_reflect_1d(
43414343
return result;
43424344
}
43434345

4346+
// ggml_roll
4347+
4348+
struct ggml_tensor * ggml_roll(
4349+
struct ggml_context * ctx,
4350+
struct ggml_tensor * a,
4351+
int shift0,
4352+
int shift1,
4353+
int shift2,
4354+
int shift3) {
4355+
GGML_ASSERT(a->nb[0] == ggml_type_size(a->type));
4356+
GGML_ASSERT(abs(shift0) < a->ne[0]);
4357+
GGML_ASSERT(abs(shift1) < a->ne[1]);
4358+
GGML_ASSERT(abs(shift2) < a->ne[2]);
4359+
GGML_ASSERT(abs(shift3) < a->ne[3]);
4360+
4361+
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
4362+
4363+
ggml_set_op_params_i32(result, 0, shift0);
4364+
ggml_set_op_params_i32(result, 1, shift1);
4365+
ggml_set_op_params_i32(result, 2, shift2);
4366+
ggml_set_op_params_i32(result, 3, shift3);
4367+
4368+
result->op = GGML_OP_ROLL;
4369+
result->src[0] = a;
4370+
4371+
return result;
4372+
}
4373+
43444374
// ggml_arange
43454375

43464376
struct ggml_tensor * ggml_arange(

0 commit comments

Comments
 (0)