Coverage Report

Created: 2026-04-12 06:40

/src/llama.cpp/ggml/src/ggml-backend-meta.cpp
Line
Count
Source
1
#include "ggml.h"
2
#include "ggml-impl.h"
3
#include "ggml-backend.h"
4
#include "ggml-backend-impl.h"
5
#include "ggml-alloc.h"
6
#include "ggml-cpp.h"
7
8
// TODO: tmp
9
#include "ggml-ext.h"
10
11
#include <algorithm>
12
#include <cassert>
13
#include <cmath>
14
#include <cstddef>
15
#include <cstdint>
16
#include <cstring>
17
#include <map>
18
#include <memory>
19
#include <string>
20
#include <tuple>
21
#include <utility>
22
#include <vector>
23
24
struct ggml_backend_meta_device;
25
struct ggml_backend_meta_buffer_type;
26
struct ggml_backend_meta_buffer;
27
struct ggml_backend_meta;
28
29
0
const char * ggml_backend_meta_split_axis_name(enum ggml_backend_meta_split_axis split_axis) {
30
0
    switch (split_axis) {
31
0
        case GGML_BACKEND_SPLIT_AXIS_0:
32
0
            return "0";
33
0
        case GGML_BACKEND_SPLIT_AXIS_1:
34
0
            return "1";
35
0
        case GGML_BACKEND_SPLIT_AXIS_2:
36
0
            return "2";
37
0
        case GGML_BACKEND_SPLIT_AXIS_3:
38
0
            return "3";
39
0
        case GGML_BACKEND_SPLIT_AXIS_MIRRORED:
40
0
            return "MIRRORED";
41
0
        case GGML_BACKEND_SPLIT_AXIS_PARTIAL:
42
0
            return "PARTIAL";
43
0
        case GGML_BACKEND_SPLIT_AXIS_NONE:
44
0
            return "NONE";
45
0
        case GGML_BACKEND_SPLIT_AXIS_UNKNOWN:
46
0
            return "UNKNOWN";
47
0
        default:
48
0
            GGML_ABORT("fatal error");
49
0
    }
50
0
}
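// Editorial note (not part of the measured source): as the rest of this file uses these values, the
// split axes appear to mean the following; this is a reading of the code, not an authoritative spec:
//   GGML_BACKEND_SPLIT_AXIS_0..3     - the tensor is partitioned along ne[0]..ne[3] across the simple devices
//   GGML_BACKEND_SPLIT_AXIS_MIRRORED - every simple device holds a full copy of the tensor
//   GGML_BACKEND_SPLIT_AXIS_PARTIAL  - every device holds a same-shaped partial result that still needs a reduction
//   GGML_BACKEND_SPLIT_AXIS_NONE     - no split state assigned yet
//   GGML_BACKEND_SPLIT_AXIS_UNKNOWN  - the split state cannot be represented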
51
52
//
53
// meta backend device
54
//
55
56
struct ggml_backend_meta_device_context {
57
    std::vector<ggml_backend_dev_t>     simple_devs;
58
    ggml_backend_meta_get_split_state_t get_split_state;
59
    void *                              get_split_state_ud;
60
61
    std::string name;
62
    std::string description;
63
64
    ggml_backend_meta_device_context(
65
            std::vector<ggml_backend_dev_t> simple_devs, ggml_backend_meta_get_split_state_t get_split_state, void * get_split_state_ud) :
66
0
            simple_devs(std::move(simple_devs)), get_split_state(get_split_state), get_split_state_ud(get_split_state_ud) {
67
0
        name        = std::string("Meta(");
68
0
        description = std::string("Meta(");
69
0
        for (size_t i = 0; i < simple_devs.size(); i++) {
70
0
            if (i > 0) {
71
0
                name        += ",";
72
0
                description += ",";
73
0
            }
74
0
            name        += ggml_backend_dev_name       (simple_devs[i]);
75
0
            description += ggml_backend_dev_description(simple_devs[i]);
76
0
        }
77
0
        name        += ")";
78
0
        description += ")";
79
0
    }
80
81
0
    bool operator<(const ggml_backend_meta_device_context & other) const {
82
0
        return std::tie(simple_devs, get_split_state, get_split_state_ud)
83
0
            < std::tie(other.simple_devs, other.get_split_state, other.get_split_state_ud);
84
0
    }
85
};
86
87
static bool ggml_backend_dev_is_meta(ggml_backend_dev_t dev);
88
89
0
static const char * ggml_backend_meta_device_get_name(ggml_backend_dev_t dev) {
90
0
    GGML_ASSERT(ggml_backend_dev_is_meta(dev));
91
0
    const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context;
92
0
    return meta_dev_ctx->name.c_str();
93
0
}
94
95
0
static const char * ggml_backend_meta_device_get_description(ggml_backend_dev_t dev) {
96
0
    GGML_ASSERT(ggml_backend_dev_is_meta(dev));
97
0
    const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context;
98
0
    return meta_dev_ctx->description.c_str();
99
0
}
100
101
0
static void ggml_backend_meta_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
102
0
    GGML_ASSERT(ggml_backend_dev_is_meta(dev));
103
0
    const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context;
104
0
    *free  = 0;
105
0
    *total = 0;
106
0
    for (ggml_backend_dev_t dev : meta_dev_ctx->simple_devs) {
107
0
        size_t tmp_free, tmp_total;
108
0
        ggml_backend_dev_memory(dev, &tmp_free, &tmp_total);
109
0
        *free  += tmp_free;
110
0
        *total += tmp_total;
111
0
    }
112
0
}
113
114
0
static enum ggml_backend_dev_type ggml_backend_meta_device_get_type(ggml_backend_dev_t dev) {
115
0
    return GGML_BACKEND_DEVICE_TYPE_META;
116
117
0
    GGML_UNUSED(dev);
118
0
}
119
120
0
static void ggml_backend_meta_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
121
0
    GGML_ASSERT(ggml_backend_dev_is_meta(dev));
122
0
    const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context;
123
124
    // TODO replace placeholders
125
0
    props->name        = ggml_backend_meta_device_get_name(dev);
126
0
    props->description = ggml_backend_meta_device_get_description(dev);
127
0
    props->type        = ggml_backend_meta_device_get_type(dev);
128
0
    props->device_id   = 0;
129
130
0
    ggml_backend_meta_device_get_memory(dev, &props->memory_free, &props->memory_total);
131
132
0
    props->caps = {
133
0
        /* .async                 = */ true,
134
0
        /* .host_buffer           = */ false, // Not implemented.
135
0
        /* .buffer_from_host_ptr  = */ false, // Not implemented.
136
0
        /* .events                = */ false, // Not implemented.
137
0
    };
138
0
    for (ggml_backend_dev_t simple_dev : meta_dev_ctx->simple_devs) {
139
0
        ggml_backend_dev_props tmp_props;
140
0
        ggml_backend_dev_get_props(simple_dev, &tmp_props);
141
0
        props->caps.async                = props->caps.async                && tmp_props.caps.async;
142
0
        props->caps.host_buffer          = props->caps.host_buffer          && tmp_props.caps.host_buffer;
143
0
        props->caps.buffer_from_host_ptr = props->caps.buffer_from_host_ptr && tmp_props.caps.buffer_from_host_ptr;
144
0
        props->caps.events               = props->caps.events               && tmp_props.caps.events;
145
0
    }
146
0
}
147
148
static ggml_backend_t ggml_backend_meta_device_init_backend(ggml_backend_dev_t dev, const char * params);
149
150
static ggml_backend_buffer_type_t ggml_backend_meta_device_get_buffer_type(ggml_backend_dev_t dev);
151
152
static ggml_backend_buffer_type_t ggml_backend_meta_device_get_host_buffer_type(ggml_backend_dev_t dev);
153
154
0
static bool ggml_backend_meta_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
155
0
    GGML_ASSERT(ggml_backend_dev_is_meta(dev));
156
0
    const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context;
157
0
    return std::all_of(meta_dev_ctx->simple_devs.begin(), meta_dev_ctx->simple_devs.end(),
158
0
        [op](ggml_backend_dev_t simple_dev) { return ggml_backend_dev_supports_op(simple_dev, op); });
159
0
}
160
161
0
static bool ggml_backend_meta_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
162
0
    GGML_ASSERT(ggml_backend_dev_is_meta(dev));
163
0
    ggml_backend_dev_t dev_buft = ggml_backend_buft_get_device(buft);
164
0
    if (!ggml_backend_dev_is_meta(dev_buft)) {
165
0
        return false;
166
0
    }
167
0
    const ggml_backend_meta_device_context * meta_dev_ctx      = (const ggml_backend_meta_device_context *) dev->context;
168
0
    const ggml_backend_meta_device_context * meta_buft_dev_ctx = (const ggml_backend_meta_device_context *) dev_buft->context;
169
0
    if (meta_dev_ctx->simple_devs.size() != meta_buft_dev_ctx->simple_devs.size()) {
170
0
        return false;
171
0
    }
172
0
    for (size_t i = 0; i < meta_dev_ctx->simple_devs.size(); i++) {
173
0
        if (meta_dev_ctx->simple_devs[i] != meta_buft_dev_ctx->simple_devs[i]) {
174
0
            return false;
175
0
        }
176
0
    }
177
0
    return true;
178
0
}
179
180
static const ggml_backend_device_i ggml_backend_meta_device_iface = {
181
    /* .get_name             = */ ggml_backend_meta_device_get_name,
182
    /* .get_description      = */ ggml_backend_meta_device_get_description,
183
    /* .get_memory           = */ ggml_backend_meta_device_get_memory,
184
    /* .get_type             = */ ggml_backend_meta_device_get_type,
185
    /* .get_props            = */ ggml_backend_meta_device_get_props,
186
    /* .init_backend         = */ ggml_backend_meta_device_init_backend,
187
    /* .get_buffer_type      = */ ggml_backend_meta_device_get_buffer_type,
188
    /* .get_host_buffer_type = */ ggml_backend_meta_device_get_host_buffer_type,
189
    /* .buffer_from_host_ptr = */ nullptr,
190
    /* .supports_op          = */ ggml_backend_meta_device_supports_op,
191
    /* .supports_buft        = */ ggml_backend_meta_device_supports_buft,
192
    /* .offload_op           = */ nullptr,
193
    /* .event_new            = */ nullptr,
194
    /* .event_free           = */ nullptr,
195
    /* .event_synchronize    = */ nullptr,
196
};
197
198
0
static bool ggml_backend_dev_is_meta(ggml_backend_dev_t dev) {
199
0
    return dev != nullptr && dev->iface.get_name == ggml_backend_meta_device_iface.get_name;
200
0
}
201
202
0
static size_t ggml_backend_meta_dev_n_devs(ggml_backend_dev_t meta_dev) {
203
0
    GGML_ASSERT(ggml_backend_dev_is_meta(meta_dev));
204
0
    const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) meta_dev->context;
205
0
    return meta_dev_ctx->simple_devs.size();
206
0
}
207
208
0
static ggml_backend_dev_t ggml_backend_meta_dev_simple_dev(ggml_backend_dev_t meta_dev, size_t index) {
209
0
    GGML_ASSERT(ggml_backend_dev_is_meta(meta_dev));
210
0
    const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) meta_dev->context;
211
0
    GGML_ASSERT(index < meta_dev_ctx->simple_devs.size());
212
0
    return meta_dev_ctx->simple_devs[index];
213
0
}
214
215
ggml_backend_dev_t ggml_backend_meta_device(
216
0
        ggml_backend_dev_t * devs, size_t n_devs, ggml_backend_meta_get_split_state_t get_split_state, void * get_split_state_ud) {
217
0
    GGML_ASSERT(n_devs <= GGML_BACKEND_META_MAX_DEVICES);
218
    // TODO: this is not thread-safe - needs to be fixed
219
0
    static std::vector<std::unique_ptr<ggml_backend_meta_device_context>>         ctxs;
220
0
    static std::map<ggml_backend_meta_device_context, struct ggml_backend_device> meta_devs;
221
222
0
    std::vector<ggml_backend_dev_t> simple_devs;
223
0
    simple_devs.reserve(n_devs);
224
0
    for (size_t i = 0; i < n_devs; i++) {
225
0
        simple_devs.push_back(devs[i]);
226
0
    }
227
0
    ggml_backend_meta_device_context ctx(simple_devs, get_split_state, get_split_state_ud);
228
229
0
    {
230
0
        auto it = meta_devs.find(ctx);
231
0
        if (it != meta_devs.end()) {
232
0
            return &it->second;
233
0
        }
234
0
    }
235
0
    ctxs.push_back(std::make_unique<ggml_backend_meta_device_context>(ctx));
236
237
0
    struct ggml_backend_device meta_dev = {
238
0
        /*iface  =*/ ggml_backend_meta_device_iface,
239
0
        /*reg    =*/ nullptr,
240
0
        /*ctx    =*/ ctxs.back().get(),
241
0
    };
242
243
0
    auto result = meta_devs.emplace(*ctxs.back(), meta_dev);
244
0
    return &result.first->second;
245
0
}
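// Editorial note (not part of the measured source): a minimal usage sketch for the factory above,
// assuming at least two simple devices are registered with ggml; the split-state callback and user
// data passed in here are hypothetical caller-provided values, not ggml API symbols.
static ggml_backend_dev_t ggml_backend_meta_device_example(
        ggml_backend_meta_get_split_state_t my_get_split_state, void * my_userdata) {
    ggml_backend_dev_t simple[2] = { ggml_backend_dev_get(0), ggml_backend_dev_get(1) };
    // Repeated calls with identical arguments return the same cached device object, because the
    // (devices, callback, userdata) tuple is the key of the static map used above.
    return ggml_backend_meta_device(simple, /*n_devs =*/ 2, my_get_split_state, my_userdata);
}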
246
247
//
248
// meta backend buffer type
249
//
250
251
struct ggml_backend_meta_buffer_type_context {
252
    std::vector<ggml_backend_buffer_type_t> simple_bufts;
253
254
    std::string name;
255
256
0
    ggml_backend_meta_buffer_type_context(std::vector<ggml_backend_buffer_type_t> simple_bufts) : simple_bufts(std::move(simple_bufts)) {
257
0
        name = "Meta(";
258
0
        for (size_t i = 0; i < simple_bufts.size(); i++) {
259
0
            if (i > 0) {
260
0
                name += ",";
261
0
            }
262
0
            name += ggml_backend_buft_name(simple_bufts[i]);
263
0
        }
264
0
        name += ")";
265
0
    }
266
267
0
    bool operator<(const ggml_backend_meta_buffer_type_context & other) const {
268
0
        return simple_bufts < other.simple_bufts;
269
0
    }
270
};
271
272
0
static size_t ggml_backend_meta_buft_n_bufts(ggml_backend_buffer_type_t meta_buft) {
273
0
    GGML_ASSERT(ggml_backend_buft_is_meta(meta_buft));
274
0
    const ggml_backend_meta_buffer_type_context * meta_buft_ctx = (const ggml_backend_meta_buffer_type_context *) meta_buft->context;
275
0
    return meta_buft_ctx->simple_bufts.size();
276
0
}
277
278
0
static const char * ggml_backend_meta_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
279
0
    GGML_ASSERT(ggml_backend_buft_is_meta(buft));
280
0
    const ggml_backend_meta_buffer_type_context * meta_buft_ctx = (const ggml_backend_meta_buffer_type_context *) buft->context;
281
0
    return meta_buft_ctx->name.c_str();
282
0
}
283
284
0
static ggml_backend_buffer_type_t ggml_backend_meta_buft_simple_buft(ggml_backend_buffer_type_t meta_buft, size_t index) {
285
0
    GGML_ASSERT(ggml_backend_buft_is_meta(meta_buft));
286
0
    const ggml_backend_meta_buffer_type_context * meta_buft_ctx = (const ggml_backend_meta_buffer_type_context *) meta_buft->context;
287
0
    GGML_ASSERT(index < meta_buft_ctx->simple_bufts.size());
288
0
    return meta_buft_ctx->simple_bufts[index];
289
0
}
290
291
static ggml_backend_buffer_t ggml_backend_meta_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size);
292
293
0
static size_t ggml_backend_meta_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
294
0
    const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft);
295
0
    size_t max_alignment = 1;
296
0
    for (size_t i = 0; i < n_simple_bufts; i++) {
297
0
        const size_t alignment = ggml_backend_buft_get_alignment(ggml_backend_meta_buft_simple_buft(buft, i));
298
0
        max_alignment = std::max(max_alignment, alignment);
299
0
        GGML_ASSERT(max_alignment % alignment == 0);
300
0
    }
301
0
    return max_alignment;
302
0
}
303
304
0
static size_t ggml_backend_meta_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
305
0
    const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft);
306
0
    size_t max_size = SIZE_MAX;
307
0
    for (size_t i = 0; i < n_simple_bufts; i++) {
308
0
        max_size = std::min(max_size, ggml_backend_buft_get_max_size(ggml_backend_meta_buft_simple_buft(buft, i)));
309
0
    }
310
0
    return max_size;
311
0
}
312
313
0
static size_t ggml_backend_meta_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
314
0
    const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft);
315
0
    size_t max_alloc_size = 0;
316
0
    for (size_t i = 0; i < n_simple_bufts; i++) {
317
0
        const size_t alloc_size = ggml_backend_buft_get_alloc_size(ggml_backend_meta_buft_simple_buft(buft, i), tensor);
318
0
        max_alloc_size = std::max(max_alloc_size, alloc_size);
319
0
    }
320
0
    return max_alloc_size;
321
0
}
322
323
0
static bool ggml_backend_meta_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
324
0
    const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft);
325
0
    for (size_t i = 0; i < n_simple_bufts; i++) {
326
0
        if (!ggml_backend_buft_is_host(ggml_backend_meta_buft_simple_buft(buft, i))) {
327
0
            return false;
328
0
        }
329
0
    }
330
0
    return true;
331
0
}
332
333
static const struct ggml_backend_buffer_type_i ggml_backend_meta_buffer_type_iface = {
334
    /* .get_name         = */ ggml_backend_meta_buffer_type_get_name,
335
    /* .alloc_buffer     = */ ggml_backend_meta_buffer_type_alloc_buffer,
336
    /* .get_alignment    = */ ggml_backend_meta_buffer_type_get_alignment,
337
    /* .get_max_size     = */ ggml_backend_meta_buffer_type_get_max_size,
338
    /* .get_alloc_size   = */ ggml_backend_meta_buffer_type_get_alloc_size,
339
    /* .is_host          = */ ggml_backend_meta_buffer_type_is_host,
340
};
341
342
0
bool ggml_backend_buft_is_meta(ggml_backend_buffer_type_t buft) {
343
0
    return buft != nullptr && buft->iface.get_name == ggml_backend_meta_buffer_type_iface.get_name;
344
0
}
345
346
0
static ggml_backend_buffer_type_t ggml_backend_meta_device_get_buffer_type(ggml_backend_dev_t dev) {
347
0
    static std::map<ggml_backend_dev_t, struct ggml_backend_buffer_type> meta_bufts;
348
0
    GGML_ASSERT(ggml_backend_dev_is_meta(dev));
349
0
    {
350
0
        auto it = meta_bufts.find(dev);
351
0
        if (it != meta_bufts.end()) {
352
0
            return &it->second;
353
0
        }
354
0
    }
355
356
0
    const size_t n_devs = ggml_backend_meta_dev_n_devs(dev);
357
0
    std::vector<ggml_backend_buffer_type_t> simple_bufts;
358
0
    simple_bufts.reserve(n_devs);
359
0
    for (size_t i = 0; i < n_devs; i++) {
360
0
        simple_bufts.push_back(ggml_backend_dev_buffer_type(ggml_backend_meta_dev_simple_dev(dev, i)));
361
0
    }
362
0
    ggml_backend_meta_buffer_type_context * buft_ctx = new ggml_backend_meta_buffer_type_context(simple_bufts);
363
364
0
    struct ggml_backend_buffer_type meta_buft = {
365
0
        /*iface  =*/ ggml_backend_meta_buffer_type_iface,
366
0
        /*device =*/ dev,
367
0
        /*ctx    =*/ buft_ctx,
368
0
    };
369
0
    auto result = meta_bufts.emplace(dev, meta_buft);
370
0
    return &result.first->second;
371
0
}
372
373
0
static ggml_backend_buffer_type_t ggml_backend_meta_device_get_host_buffer_type(ggml_backend_dev_t dev) {
374
0
    GGML_ASSERT(ggml_backend_dev_is_meta(dev));
375
0
    const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context;
376
377
0
    ggml_backend_buffer_type_t host_buft = nullptr;
378
0
    for (ggml_backend_dev_t simple_dev : meta_dev_ctx->simple_devs) {
379
0
        ggml_backend_buffer_type_t simple_host_buft = ggml_backend_dev_host_buffer_type(simple_dev);
380
0
        if (simple_host_buft == nullptr) {
381
0
            return nullptr;
382
0
        }
383
0
        if (host_buft == nullptr) {
384
0
            host_buft = simple_host_buft;
385
0
        } else if (host_buft != simple_host_buft) {
386
            // if different simple devices have different host buffer types,
387
            // we cannot provide a single host buffer type for the meta device
388
0
            return nullptr;
389
0
        }
390
0
    }
391
0
    return host_buft;
392
0
}
393
394
//
395
// meta backend buffer
396
//
397
398
struct ggml_backend_meta_buffer_context {
399
    static constexpr size_t nbtc = GGML_TENSOR_SIZE - sizeof(ggml_tensor::padding);
400
401
    std::map<std::pair<const ggml_tensor *, bool>, std::pair<ggml_backend_meta_split_state, char[nbtc]>> split_state_cache;
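    // Editorial note (added comment describing the mechanism used below): the char[nbtc] snapshot stores
    // the non-padding bytes of the keyed tensor at the time its split state was computed; when a graph
    // tensor object is reused with different contents, the memcmp in ggml_backend_meta_get_split_state
    // sees the mismatch and flushes this cache.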
402
    std::map<          const ggml_tensor *,        std::vector<ggml_tensor *>>                           simple_tensors;
403
404
    struct buffer_config {
405
        ggml_context          * ctx;
406
        ggml_backend_buffer_t   buf;
407
408
0
        buffer_config(ggml_context * ctx, ggml_backend_buffer_t buf) : ctx(ctx), buf(buf) {}
409
    };
410
    std::vector<buffer_config> buf_configs;
411
412
    int debug;
413
414
0
    ggml_backend_meta_buffer_context() {
415
0
        const char * GGML_META_DEBUG = getenv("GGML_META_DEBUG");
416
0
        debug = GGML_META_DEBUG ? atoi(GGML_META_DEBUG) : 0;
417
0
    }
418
};
419
420
0
static void ggml_backend_meta_buffer_free_buffer(ggml_backend_buffer_t buffer) {
421
0
    GGML_ASSERT(ggml_backend_buffer_is_meta(buffer));
422
0
    ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) buffer->context;
423
0
    for (auto & [ctx, buf] : buf_ctx->buf_configs) {
424
0
        ggml_backend_buffer_free(buf);
425
0
        ggml_free(ctx);
426
0
    }
427
0
    delete buf_ctx;
428
0
}
429
430
0
static size_t ggml_backend_meta_buffer_n_bufs(ggml_backend_buffer_t meta_buf) {
431
0
    GGML_ASSERT(ggml_backend_buffer_is_meta(meta_buf));
432
0
    ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) meta_buf->context;
433
0
    return buf_ctx->buf_configs.size();
434
0
}
435
436
0
static ggml_backend_buffer_t ggml_backend_meta_buffer_simple_buffer(ggml_backend_buffer_t meta_buf, size_t index) {
437
0
    GGML_ASSERT(ggml_backend_buffer_is_meta(meta_buf));
438
0
    ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) meta_buf->context;
439
0
    GGML_ASSERT(index < buf_ctx->buf_configs.size());
440
0
    return buf_ctx->buf_configs[index].buf;
441
0
}
442
443
0
static struct ggml_tensor * ggml_backend_meta_buffer_simple_tensor(const struct ggml_tensor * tensor, size_t index) {
444
0
    GGML_ASSERT(ggml_backend_buffer_is_meta(tensor->buffer));
445
0
    ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context;
446
0
    GGML_ASSERT(index < buf_ctx->buf_configs.size());
447
448
0
    auto it = buf_ctx->simple_tensors.find(tensor);
449
0
    if (it == buf_ctx->simple_tensors.end()) {
450
0
        return nullptr;
451
0
    }
452
0
    return it->second[index];
453
0
}
454
455
0
static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struct ggml_tensor * tensor, bool assume_sync) {
456
0
    const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(tensor->buffer);
457
0
    ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context;
458
459
0
    auto split_states_equal = [&](const ggml_backend_meta_split_state & a, const ggml_backend_meta_split_state & b) -> bool {
460
0
        if (a.axis != b.axis) {
461
0
            return false;
462
0
        }
463
0
        for (size_t j = 0; j < n_bufs; j++) {
464
0
            int64_t sum_a = 0;
465
0
            for (size_t s = 0; s < a.n_segments; s++) {
466
0
                sum_a += a.ne[s*n_bufs + j];
467
0
            }
468
0
            int64_t sum_b = 0;
469
0
            for (size_t s = 0; s < b.n_segments; s++) {
470
0
                sum_b += b.ne[s*n_bufs + j];
471
0
            }
472
0
            if (sum_a != sum_b) {
473
0
                return false;
474
0
            }
475
0
        }
476
0
        return true;
477
0
    };
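    // Editorial note (added comment, an assumption based on the indexing used throughout this file):
    // split_state.ne[] appears to be laid out segment-major, i.e. ne[s*n_bufs + j] is the extent that
    // segment s assigns to simple buffer j along the split axis. For example, with n_bufs == 2 and
    // n_segments == 2, a tensor with 16 elements along the split axis could use ne = {4, 4, 5, 3}:
    // buffer 0 receives 4 + 5 = 9 elements, buffer 1 receives 4 + 3 = 7, and the lambda above compares
    // only these per-buffer sums.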
478
479
0
    auto handle_generic = [&](const std::vector<ggml_backend_meta_split_state> & src_ss, bool scalar_only) -> ggml_backend_meta_split_state {
480
0
        ggml_backend_meta_split_state ret = {GGML_BACKEND_SPLIT_AXIS_NONE, {0}, 1};
481
0
        for (size_t i = 0; i < GGML_MAX_SRC; i++) {
482
0
            if (tensor->src[i] == nullptr || tensor->src[i] == tensor) {
483
0
                continue;
484
0
            }
485
0
            if (ret.axis == GGML_BACKEND_SPLIT_AXIS_NONE) {
486
0
                ret = src_ss[i];
487
0
            } else if (!split_states_equal(src_ss[i], ret)) {
488
0
                ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1};
489
0
                break;
490
0
            }
491
0
        }
492
0
        if (ret.axis == GGML_BACKEND_SPLIT_AXIS_NONE) {
493
0
            ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1};
494
0
        }
495
0
        if (scalar_only && ret.axis >= 0 && ret.axis < GGML_MAX_DIMS) {
496
0
            ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1};
497
0
        }
498
0
        GGML_ASSERT(ret.axis != GGML_BACKEND_SPLIT_AXIS_UNKNOWN);
499
0
        return ret;
500
0
    };
501
502
    // Some ops process data on a per-row basis:
503
0
    auto handle_per_row = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
504
0
        GGML_ASSERT(src_ss[0].axis != GGML_BACKEND_SPLIT_AXIS_0);
505
0
        return src_ss[0];
506
0
    };
507
508
    // Some ops broadcast the src1 data across src0:
509
0
    auto handle_bin_bcast = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
510
0
        if (src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS &&
511
0
                tensor->src[1]->ne[src_ss[0].axis] == 1 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
512
0
            return src_ss[0];
513
0
        }
514
0
        if (src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && (src_ss[0].axis == src_ss[1].axis ||
515
0
           (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && (src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL)))) {
516
0
            return src_ss[0]; // GGML_OP_ADD_ID
517
0
        }
518
0
        GGML_ASSERT(tensor->src[2] == nullptr || src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED);
519
0
        return handle_generic(src_ss, /*scalar_only =*/ false);
520
0
    };
521
522
0
    auto handle_concat = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
523
0
        const ggml_backend_meta_split_axis concat_axis = ggml_backend_meta_split_axis(ggml_get_op_params_i32(tensor, 0));
524
0
        if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[1].axis >= 0 && src_ss[1].axis < GGML_MAX_DIMS) {
525
0
            GGML_ASSERT(concat_axis != src_ss[1].axis);
526
0
            return src_ss[1];
527
0
        }
528
0
        if (src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS) {
529
0
            GGML_ASSERT(concat_axis != src_ss[0].axis);
530
0
            return src_ss[0];
531
0
        }
532
0
        if (src_ss[0].axis == src_ss[1].axis && src_ss[0].axis != concat_axis) {
533
0
            return src_ss[0];
534
0
        }
535
0
        return handle_generic(src_ss, /*scalar_only =*/ true);
536
0
    };
537
538
0
    auto handle_mul_mat = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
539
0
        if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
540
0
            return {GGML_BACKEND_SPLIT_AXIS_MIRRORED, {0}, 1};
541
0
        }
542
0
        if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_1 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
543
0
            ggml_backend_meta_split_state ret = src_ss[0];
544
0
            ret.axis = GGML_BACKEND_SPLIT_AXIS_0;
545
0
            ret.n_segments = 1;
546
0
            return ret;
547
0
        }
548
0
        if (src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_1 && src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
549
0
            ggml_backend_meta_split_state ret = src_ss[1];
550
0
            ret.n_segments = 1;
551
0
            return ret;
552
0
        }
553
0
        if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_0 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_0) {
554
0
            GGML_ASSERT(split_states_equal(src_ss[0], src_ss[1]));
555
0
            return {assume_sync ? GGML_BACKEND_SPLIT_AXIS_MIRRORED : GGML_BACKEND_SPLIT_AXIS_PARTIAL, {0}, 1};
556
0
        }
557
0
        GGML_ABORT("fatal error");
558
        //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1};
559
0
    };
560
561
0
    auto handle_cpy = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
562
0
        if (src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS) {
563
0
            int64_t ne_split_src = tensor->src[0]->ne[0];
564
0
            for (int dim = 1; dim <= src_ss[0].axis; dim++) {
565
0
                ne_split_src *= tensor->src[0]->ne[dim];
566
0
            }
567
0
            int64_t ne_split_dst = 1;
568
0
            for (int dim = 0; dim < GGML_MAX_DIMS; dim++) {
569
0
                ne_split_dst *= tensor->ne[dim];
570
0
                if (ne_split_dst == ne_split_src) {
571
0
                    return {ggml_backend_meta_split_axis(dim), {0}, 1};
572
0
                }
573
0
            }
574
0
        }
575
0
        return handle_generic(src_ss, /*scalar_only =*/ false);
576
0
    };
577
578
0
    auto handle_reshape = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
579
0
        switch (src_ss[0].axis) {
580
0
            case GGML_BACKEND_SPLIT_AXIS_0:
581
0
            case GGML_BACKEND_SPLIT_AXIS_1:
582
0
            case GGML_BACKEND_SPLIT_AXIS_2:
583
0
            case GGML_BACKEND_SPLIT_AXIS_3: {
584
0
                GGML_ASSERT(!ggml_is_permuted(tensor) && !ggml_is_permuted(tensor->src[0]));
585
0
                if (src_ss[0].axis == ggml_n_dims(tensor->src[0]) - 1) {
586
0
                    return {ggml_backend_meta_split_axis(ggml_n_dims(tensor) - 1), {0}, 1};
587
0
                }
588
0
                std::vector<int64_t> base_ne_in;
589
0
                base_ne_in.reserve(GGML_MAX_DIMS - src_ss[0].axis);
590
0
                {
591
0
                    base_ne_in.push_back(1);
592
0
                    int dim = 0;
593
0
                    for (; dim <= src_ss[0].axis; dim++) {
594
0
                        base_ne_in[0] *= tensor->src[0]->ne[dim];
595
0
                    }
596
0
                    for (; dim <= GGML_MAX_DIMS; dim++) {
597
0
                        base_ne_in.push_back(base_ne_in.back() * tensor->src[0]->ne[dim]);
598
0
                    }
599
0
                }
600
0
                int64_t base_ne_out = 1;
601
0
                for (int dim = 0; dim < GGML_MAX_DIMS; dim++) {
602
0
                    const int64_t base_ne_out_next = base_ne_out *= tensor->ne[dim];
603
0
                    for (const int64_t & bni : base_ne_in) {
604
0
                        if (bni == base_ne_out_next) {
605
0
                            return {ggml_backend_meta_split_axis(dim), {0}, 1};
606
0
                        }
607
0
                    }
608
0
                    if (base_ne_out_next > base_ne_in[0]) {
609
0
                        GGML_ASSERT(dim + 1 < GGML_MAX_DIMS);
610
0
                        return {ggml_backend_meta_split_axis(dim + 1), {0}, 1};
611
0
                    }
612
0
                    base_ne_out = base_ne_out_next;
613
0
                }
614
0
                GGML_ABORT("shape mismatch for %s", ggml_op_name(tensor->op));
615
0
            }
616
0
            case GGML_BACKEND_SPLIT_AXIS_MIRRORED:
617
0
            case GGML_BACKEND_SPLIT_AXIS_PARTIAL: {
618
0
                return src_ss[0];
619
0
            }
620
0
            default: {
621
0
                GGML_ABORT("fatal error");
622
                //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1};
623
0
            }
624
0
        }
625
0
    };
626
627
0
    auto handle_view = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
628
0
        if (ggml_is_contiguous(tensor) && ggml_is_contiguous(tensor->src[0])) {
629
0
            return handle_reshape(src_ss);
630
0
        }
631
0
        const int axis = src_ss[0].axis;
632
0
        {
633
0
            bool all_strides_the_same = true;
634
0
            for (int dim = 0; dim < GGML_MAX_DIMS; dim++) {
635
0
                if (tensor->ne[dim] == 1 && tensor->src[0]->ne[dim] == 1) {
636
0
                    continue;
637
0
                }
638
0
                if (tensor->nb[dim] != tensor->src[0]->nb[dim]) {
639
0
                    all_strides_the_same = false;
640
0
                    break;
641
0
                }
642
0
            }
643
0
            if (all_strides_the_same) {
644
0
                return src_ss[0];
645
0
            }
646
0
        }
647
0
        if (!ggml_is_permuted(tensor) && !ggml_is_permuted(tensor->src[0]) && axis >= 0 && axis < GGML_MAX_DIMS-1) {
648
0
            for (int dim = 0; dim < GGML_MAX_DIMS-1; dim++) {
649
0
                if (tensor->nb[dim+1] == tensor->src[0]->nb[axis+1]) {
650
0
                    return {ggml_backend_meta_split_axis(dim), {0}, 1};
651
0
                }
652
0
            }
653
0
            GGML_ABORT("fatal error");
654
0
        }
655
0
        if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED || src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL) {
656
0
            return src_ss[0];
657
0
        }
658
0
        GGML_ABORT("view of permuted tensor not implemented");
659
        //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1};
660
0
    };
661
662
0
    auto handle_permute = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
663
0
        switch (src_ss[0].axis) {
664
0
            case GGML_BACKEND_SPLIT_AXIS_0:
665
0
            case GGML_BACKEND_SPLIT_AXIS_1:
666
0
            case GGML_BACKEND_SPLIT_AXIS_2:
667
0
            case GGML_BACKEND_SPLIT_AXIS_3: {
668
0
                return {ggml_backend_meta_split_axis(tensor->op_params[src_ss[0].axis]), {0}, 1};
669
0
            }
670
0
            case GGML_BACKEND_SPLIT_AXIS_MIRRORED:
671
0
            case GGML_BACKEND_SPLIT_AXIS_PARTIAL: {
672
0
                return src_ss[0];
673
0
            }
674
0
            default: {
675
0
                GGML_ABORT("fatal error");
676
                //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1};
677
0
            }
678
0
        }
679
0
    };
680
681
0
    auto handle_transpose = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
682
0
        switch (src_ss[0].axis) {
683
0
            case GGML_BACKEND_SPLIT_AXIS_0:
684
0
            case GGML_BACKEND_SPLIT_AXIS_1: {
685
0
                return {ggml_backend_meta_split_axis(int(src_ss[0].axis) ^ 1), {0}, 1};
686
0
            }
687
0
            case GGML_BACKEND_SPLIT_AXIS_2:
688
0
            case GGML_BACKEND_SPLIT_AXIS_3:
689
0
            case GGML_BACKEND_SPLIT_AXIS_MIRRORED:
690
0
            case GGML_BACKEND_SPLIT_AXIS_PARTIAL: {
691
0
                return src_ss[0];
692
0
            }
693
0
            default: {
694
0
                GGML_ABORT("fatal error");
695
                //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1};
696
0
            }
697
0
        }
698
0
    };
699
700
0
    auto handle_get_rows = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
701
0
        if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_0 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
702
0
            return src_ss[0];
703
0
        }
704
0
        return handle_generic(src_ss, /*scalar_only =*/ true);
705
0
    };
706
707
0
    auto handle_set_rows = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
708
0
        GGML_ASSERT(src_ss[0].axis != GGML_BACKEND_SPLIT_AXIS_1);
709
0
        GGML_ASSERT(src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED);
710
0
        GGML_ASSERT(split_states_equal(src_ss[0], src_ss[2]));
711
0
        return src_ss[0];
712
0
    };
713
714
0
    auto handle_rope = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
715
0
        GGML_ASSERT(src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED);
716
0
        return src_ss[0];
717
0
    };
718
719
0
    auto handle_pad = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
720
0
        if (src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS) {
721
0
            GGML_ASSERT(tensor->op_params[2*src_ss[0].axis + 0] == 0);
722
0
            GGML_ASSERT(tensor->op_params[2*src_ss[0].axis + 1] == 0);
723
0
        }
724
0
        return src_ss[0];
725
0
    };
726
727
0
    auto handle_flash_attn_ext = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
728
0
        GGML_ASSERT(                             src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_2);
729
0
        GGML_ASSERT(                             src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_2);
730
0
        GGML_ASSERT(                             src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_2);
731
0
        GGML_ASSERT(tensor->src[4] == nullptr || src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED);
732
0
        GGML_ASSERT(tensor->src[4] == nullptr || src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_0);
733
0
        return {GGML_BACKEND_SPLIT_AXIS_1, {0}, 1};
734
0
    };
735
736
0
    auto handle_ssm_conv = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
737
0
        if (src_ss[0].axis == src_ss[1].axis) {
738
0
            if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_0) {
739
0
                return {GGML_BACKEND_SPLIT_AXIS_1, {0}, 1};
740
0
            }
741
0
            if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_1) {
742
0
                return {GGML_BACKEND_SPLIT_AXIS_0, {0}, 1};
743
0
            }
744
0
        }
745
0
        return handle_generic(src_ss, /*scalar_only =*/ false);
746
0
    };
747
748
0
    auto handle_gated_delta_net = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
749
0
        if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED &&
750
0
            src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED &&
751
0
            src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
752
0
            return src_ss[0];
753
0
        }
754
0
        GGML_ASSERT(src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_1);
755
0
        GGML_ASSERT(src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_1);
756
0
        GGML_ASSERT(src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_1);
757
0
        GGML_ASSERT(src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_1);
758
0
        GGML_ASSERT(src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_1);
759
0
        GGML_ASSERT(src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_2);
760
0
        return {GGML_BACKEND_SPLIT_AXIS_0, {0}, 1};
761
0
    };
762
763
0
    auto calculate_split_state = [&]() -> ggml_backend_meta_split_state {
764
0
        if (ggml_nelements(tensor) == 0) {
765
0
            return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1};
766
0
        }
767
0
        if (ggml_backend_buffer_get_usage(tensor->buffer) != GGML_BACKEND_BUFFER_USAGE_COMPUTE && tensor->view_src == nullptr) {
768
0
            ggml_backend_dev_t dev = ggml_backend_buft_get_device(ggml_backend_buffer_get_type(tensor->buffer));
769
0
            const ggml_backend_meta_device_context * dev_ctx = (const ggml_backend_meta_device_context *) dev->context;
770
0
            ggml_backend_meta_split_state ret = dev_ctx->get_split_state(tensor, dev_ctx->get_split_state_ud);
771
0
            if (ret.axis >= 0 && ret.axis <= GGML_MAX_DIMS) {
772
0
                const int64_t granularity = ret.axis == GGML_BACKEND_SPLIT_AXIS_0 ? ggml_blck_size(tensor->type) : 1;
773
0
                int64_t ne_sum = 0;
774
0
                for (size_t sj = 0; sj < ret.n_segments*n_bufs; sj++) {
775
0
                    GGML_ASSERT(ret.ne[sj] % granularity == 0);
776
0
                    ne_sum += ret.ne[sj];
777
0
                }
778
0
                GGML_ASSERT(ne_sum == tensor->ne[ret.axis]);
779
0
            }
780
0
            return ret;
781
0
        }
782
783
0
        std::vector<ggml_backend_meta_split_state> src_ss(GGML_MAX_SRC, {GGML_BACKEND_SPLIT_AXIS_NONE, {0}, 1});
784
0
        for (size_t i = 0; i < GGML_MAX_SRC; i++) {
785
0
            if (tensor->src[i] == nullptr || tensor->src[i] == tensor) {
786
0
                src_ss[i] = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1};
787
0
                continue;
788
0
            }
789
0
            src_ss[i] = ggml_backend_meta_get_split_state(tensor->src[i], /*assume_sync =*/ true);
790
0
            GGML_ASSERT(src_ss[i].axis != GGML_BACKEND_SPLIT_AXIS_UNKNOWN);
791
0
        }
792
793
0
        ggml_backend_meta_split_state split_state;
794
0
        switch (tensor->op) {
795
0
            case GGML_OP_NONE: {
796
0
                split_state = {GGML_BACKEND_SPLIT_AXIS_MIRRORED, {0}, 1};
797
0
            } break;
798
0
            case GGML_OP_DUP: {
799
0
                split_state = handle_generic(src_ss, /*scalar_only =*/ true);
800
0
            } break;
801
0
            case GGML_OP_ADD:
802
0
            case GGML_OP_ADD_ID: {
803
0
                split_state = handle_bin_bcast(src_ss);
804
0
            } break;
805
0
            case GGML_OP_ADD1:
806
0
            case GGML_OP_ACC: {
807
0
                split_state = handle_generic(src_ss, /*scalar_only =*/ true);
808
0
            } break;
809
0
            case GGML_OP_SUB:
810
0
            case GGML_OP_MUL:
811
0
            case GGML_OP_DIV: {
812
0
                split_state = handle_bin_bcast(src_ss);
813
0
            } break;
814
0
            case GGML_OP_SQR:
815
0
            case GGML_OP_SQRT:
816
0
            case GGML_OP_LOG:
817
0
            case GGML_OP_SIN:
818
0
            case GGML_OP_COS: {
819
0
                split_state = handle_generic(src_ss, /*scalar_only =*/ false);
820
0
            } break;
821
0
            case GGML_OP_SUM: {
822
0
                split_state = handle_generic(src_ss, /*scalar_only =*/ true);
823
0
            } break;
824
0
            case GGML_OP_SUM_ROWS:
825
0
            case GGML_OP_CUMSUM:
826
0
            case GGML_OP_MEAN:
827
0
            case GGML_OP_ARGMAX:
828
0
            case GGML_OP_COUNT_EQUAL: {
829
0
                split_state = handle_per_row(src_ss);
830
0
            } break;
831
0
            case GGML_OP_REPEAT:
832
0
            case GGML_OP_REPEAT_BACK: {
833
0
                split_state = handle_generic(src_ss, /*scalar_only =*/ false);
834
0
            } break;
835
0
            case GGML_OP_CONCAT: {
836
0
                split_state = handle_concat(src_ss);
837
0
            } break;
838
0
            case GGML_OP_SILU_BACK: {
839
0
                split_state = handle_generic(src_ss, /*scalar_only =*/ false);
840
0
            } break;
841
0
            case GGML_OP_NORM:
842
0
            case GGML_OP_RMS_NORM:
843
0
            case GGML_OP_RMS_NORM_BACK:
844
0
            case GGML_OP_GROUP_NORM:
845
0
            case GGML_OP_L2_NORM: {
846
0
                split_state = handle_per_row(src_ss);
847
0
            } break;
848
0
            case GGML_OP_MUL_MAT:
849
0
            case GGML_OP_MUL_MAT_ID: {
850
0
                split_state = handle_mul_mat(src_ss);
851
0
            } break;
852
0
            case GGML_OP_OUT_PROD: {
853
0
                split_state = handle_generic(src_ss, /*scalar_only =*/ true);
854
0
            } break;
855
0
            case GGML_OP_SCALE: {
856
0
                split_state = handle_generic(src_ss, /*scalar_only =*/ false);
857
0
            } break;
858
0
            case GGML_OP_SET: {
859
0
                split_state = handle_generic(src_ss, /*scalar_only =*/ true);
860
0
            } break;
861
0
            case GGML_OP_CPY: {
862
0
                split_state = handle_cpy(src_ss);
863
0
            } break;
864
0
            case GGML_OP_CONT:
865
0
            case GGML_OP_RESHAPE: {
866
0
                split_state = handle_reshape(src_ss);
867
0
            } break;
868
0
            case GGML_OP_VIEW: {
869
0
                split_state = handle_view(src_ss);
870
0
            } break;
871
0
            case GGML_OP_PERMUTE: {
872
0
                split_state = handle_permute(src_ss);
873
0
            } break;
874
0
            case GGML_OP_TRANSPOSE: {
875
0
                split_state = handle_transpose(src_ss);
876
0
            } break;
877
0
            case GGML_OP_GET_ROWS: {
878
0
                split_state = handle_get_rows(src_ss);
879
0
            } break;
880
0
            case GGML_OP_GET_ROWS_BACK: {
881
0
                split_state = handle_generic(src_ss, /*scalar_only =*/ true);
882
0
            } break;
883
0
            case GGML_OP_SET_ROWS: {
884
0
                split_state = handle_set_rows(src_ss);
885
0
            } break;
886
0
            case GGML_OP_DIAG:
887
0
            case GGML_OP_DIAG_MASK_INF:
888
0
            case GGML_OP_DIAG_MASK_ZERO: {
889
0
                split_state = handle_generic(src_ss, /*scalar_only =*/ true);
890
0
            } break;
891
0
            case GGML_OP_SOFT_MAX:
892
0
            case GGML_OP_SOFT_MAX_BACK: {
893
0
                split_state = handle_generic(src_ss, /*scalar_only =*/ false);
894
0
            } break;
895
0
            case GGML_OP_ROPE: {
896
0
                split_state = handle_rope(src_ss);
897
0
            } break;
898
0
            case GGML_OP_ROPE_BACK: {
899
0
                split_state = handle_generic(src_ss, /*scalar_only =*/ true);
900
0
            } break;
901
0
            case GGML_OP_CLAMP: {
902
0
                split_state = handle_generic(src_ss, /*scalar_only =*/ false);
903
0
            } break;
904
0
            case GGML_OP_CONV_TRANSPOSE_1D:
905
0
            case GGML_OP_IM2COL:
906
0
            case GGML_OP_IM2COL_BACK:
907
0
            case GGML_OP_IM2COL_3D:
908
0
            case GGML_OP_CONV_2D:
909
0
            case GGML_OP_CONV_3D:
910
0
            case GGML_OP_CONV_2D_DW:
911
0
            case GGML_OP_CONV_TRANSPOSE_2D:
912
0
            case GGML_OP_POOL_1D:
913
0
            case GGML_OP_POOL_2D:
914
0
            case GGML_OP_POOL_2D_BACK:
915
0
            case GGML_OP_UPSCALE: {
916
0
                split_state = handle_generic(src_ss, /*scalar_only =*/ true);
917
0
            } break;
918
0
            case GGML_OP_PAD: {
919
0
                split_state = handle_pad(src_ss);
920
0
            } break;
921
0
            case GGML_OP_PAD_REFLECT_1D:
922
0
            case GGML_OP_ROLL:
923
0
            case GGML_OP_ARANGE:
924
0
            case GGML_OP_TIMESTEP_EMBEDDING: {
925
0
                split_state = handle_generic(src_ss, /*scalar_only =*/ true);
926
0
            } break;
927
0
            case GGML_OP_ARGSORT:
928
0
            case GGML_OP_TOP_K: {
929
0
                split_state = handle_per_row(src_ss);
930
0
            } break;
931
0
            case GGML_OP_LEAKY_RELU: {
932
0
                split_state = handle_generic(src_ss, /*scalar_only =*/ false);
933
0
            } break;
934
0
            case GGML_OP_TRI: {
935
0
                split_state = handle_generic(src_ss, /*scalar_only =*/ true);
936
0
            } break;
937
0
            case GGML_OP_FILL: {
938
0
                split_state = handle_generic(src_ss, /*scalar_only =*/ false);
939
0
            } break;
940
0
            case GGML_OP_FLASH_ATTN_EXT: {
941
0
                split_state = handle_flash_attn_ext(src_ss);
942
0
            } break;
943
0
            case GGML_OP_FLASH_ATTN_BACK: {
944
0
                split_state = handle_generic(src_ss, /*scalar_only =*/ true);
945
0
            } break;
946
0
            case GGML_OP_SSM_CONV: {
947
0
                split_state = handle_ssm_conv(src_ss);
948
0
            } break;
949
0
            case GGML_OP_SSM_SCAN:
950
0
            case GGML_OP_WIN_PART:
951
0
            case GGML_OP_WIN_UNPART:
952
0
            case GGML_OP_GET_REL_POS:
953
0
            case GGML_OP_ADD_REL_POS:
954
0
            case GGML_OP_RWKV_WKV6:
955
0
            case GGML_OP_GATED_LINEAR_ATTN:
956
0
            case GGML_OP_RWKV_WKV7:
957
0
            case GGML_OP_SOLVE_TRI: {
958
0
                split_state = handle_generic(src_ss, /*scalar_only =*/ true);
959
0
            } break;
960
0
            case GGML_OP_GATED_DELTA_NET: {
961
0
                split_state = handle_gated_delta_net(src_ss);
962
0
            } break;
963
0
            case GGML_OP_UNARY: {
964
0
                split_state = handle_generic(src_ss, /*scalar_only =*/ false);
965
0
            } break;
966
0
            case GGML_OP_MAP_CUSTOM1:
967
0
            case GGML_OP_MAP_CUSTOM2:
968
0
            case GGML_OP_MAP_CUSTOM3:
969
0
            case GGML_OP_CUSTOM: {
970
0
                split_state = handle_generic(src_ss, /*scalar_only =*/ true);
971
0
            } break;
972
0
            case GGML_OP_CROSS_ENTROPY_LOSS:
973
0
            case GGML_OP_CROSS_ENTROPY_LOSS_BACK: {
974
0
                split_state = handle_per_row(src_ss);
975
0
            } break;
976
0
            case GGML_OP_OPT_STEP_ADAMW:
977
0
            case GGML_OP_OPT_STEP_SGD:
978
0
            case GGML_OP_GLU: {
979
0
                split_state = handle_generic(src_ss, /*scalar_only =*/ false);
980
0
            } break;
981
0
            default: {
982
0
                GGML_ABORT("ggml op not implemented: %s", ggml_op_name(tensor->op));
983
0
                split_state = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1};
984
0
            } break;
985
0
        }
986
0
        if (split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS) {
987
0
            bool first_src_split_by_axis = true;
988
0
            const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(tensor->buffer);
989
990
0
            for (size_t i = 0; i < GGML_MAX_SRC; i++) {
991
0
                if (tensor->src[i] == nullptr || src_ss[i].axis < 0 || src_ss[i].axis >= GGML_MAX_DIMS) {
992
0
                    continue;
993
0
                }
994
0
                if (first_src_split_by_axis) {
995
0
                    for (size_t j = 0; j < n_bufs; j++) {
996
                        // Take over ratio from src:
997
0
                        for (size_t s = 0; s < src_ss[i].n_segments; s++) {
998
0
                            split_state.ne[s*n_bufs + j] = 0;
999
0
                        }
1000
0
                        for (size_t s = 0; s < src_ss[i].n_segments; s++) {
1001
0
                            split_state.ne[j] += src_ss[i].ne[s*n_bufs + j];
1002
0
                        }
1003
0
                        split_state.ne[j] *= tensor->ne[split_state.axis];
1004
0
                        if (split_state.ne[j] != 0 || tensor->src[i]->ne[src_ss[i].axis] != 0) {
1005
0
                            GGML_ASSERT(split_state.ne[j] % tensor->src[i]->ne[src_ss[i].axis] == 0);
1006
0
                            split_state.ne[j] /= tensor->src[i]->ne[src_ss[i].axis];
1007
0
                        }
1008
0
                    }
1009
0
                } else {
1010
0
                    for (size_t j = 0; j < n_bufs; j++) {
1011
0
                        int64_t sum = 0;
1012
0
                        for (size_t s = 0; s < src_ss[i].n_segments; s++) {
1013
0
                            sum += src_ss[i].ne[s*n_bufs + j];
1014
0
                        }
1015
                        // Assert that ratio is consistent:
1016
0
                        GGML_ASSERT(split_state.ne[j] * tensor->src[i]->ne[src_ss[i].axis]
1017
0
                                               == sum * tensor->ne[split_state.axis]);
1018
0
                    }
1019
0
                }
1020
0
                first_src_split_by_axis = false;
1021
0
            }
1022
0
            GGML_ASSERT(!first_src_split_by_axis);
1023
0
        }
1024
0
        return split_state;
1025
0
    };
1026
1027
0
    const std::pair key = std::make_pair(tensor, assume_sync);
1028
0
    auto it = buf_ctx->split_state_cache.find(key);
1029
0
    if (it != buf_ctx->split_state_cache.end() && memcmp(it->second.second, (const char *) tensor, sizeof(it->second.second)) != 0) {
1030
0
        buf_ctx->split_state_cache.clear();
1031
0
        it = buf_ctx->split_state_cache.end();
1032
0
    }
1033
1034
0
    if (it == buf_ctx->split_state_cache.end()) {
1035
0
        buf_ctx->split_state_cache[key].first = calculate_split_state();
1036
0
        memcpy(buf_ctx->split_state_cache[key].second, tensor, sizeof(buf_ctx->split_state_cache[key].second));
1037
0
        if (buf_ctx->debug > 0) {
1038
0
            std::string srcs_info;
1039
0
            for (size_t i = 0; i < GGML_MAX_SRC; i++) {
1040
0
                if (tensor->src[i] == nullptr) {
1041
0
                    continue;
1042
0
                }
1043
0
                if (!srcs_info.empty()) {
1044
0
                    srcs_info += ", ";
1045
0
                }
1046
0
                const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor->src[i], true);
1047
0
                const char * axis_name = ggml_backend_meta_split_axis_name(split_state.axis);
1048
0
                std::string ne_info;
1049
0
                for (size_t j = 0; j < n_bufs; j++) {
1050
0
                    if (!ne_info.empty()) {
1051
0
                        ne_info += ", ";
1052
0
                    }
1053
0
                    ne_info += std::to_string(split_state.ne[j]);
1054
0
                }
1055
0
                srcs_info += std::string(tensor->src[i]->name) + "[" + ggml_op_name(tensor->src[i]->op) + ", " + axis_name + ", {" + ne_info + "}]";
1056
0
            }
1057
0
            std::string ne_info;
1058
0
            for (size_t j = 0; j < n_bufs; j++) {
1059
0
                if (!ne_info.empty()) {
1060
0
                    ne_info += ", ";
1061
0
                }
1062
0
                ne_info += std::to_string(buf_ctx->split_state_cache[key].first.ne[j]);
1063
0
            }
1064
0
            GGML_LOG_DEBUG("SPLIT_STATE: {%s} -> %s[%s, %s, {%s}]\n", srcs_info.c_str(), tensor->name, ggml_op_name(tensor->op),
1065
0
                ggml_backend_meta_split_axis_name(buf_ctx->split_state_cache[key].first.axis), ne_info.c_str());
1066
0
        }
1067
0
    }
1068
1069
0
    ggml_backend_meta_split_state ret = buf_ctx->split_state_cache[key].first;
1070
0
    GGML_ASSERT(ret.axis != GGML_BACKEND_SPLIT_AXIS_NONE);
1071
#ifndef NDEBUG
1072
    if (ret.axis >= 0 && ret.axis < GGML_MAX_DIMS) {
1073
        int64_t ne_ret = 0;
1074
        for (size_t sj = 0; sj < ret.n_segments*n_bufs; sj++) {
1075
            ne_ret += ret.ne[sj];
1076
        }
1077
        assert(ne_ret == tensor->ne[int(ret.axis)]);
1078
    }
1079
#endif // NDEBUG
1080
0
    return ret;
1081
0
}
1082
1083
0
static void * ggml_backend_meta_buffer_get_base(ggml_backend_buffer_t buffer) {
1084
0
    GGML_UNUSED(buffer);
1085
0
    return (void *) 0x1000000000000000; // FIXME
1086
0
}
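// Editorial note (not part of the measured source): since a meta buffer owns no storage of its own, the
// base returned above is a dummy constant, and tensor->data values inside a meta buffer are only synthetic
// offsets. ggml_backend_meta_buffer_init_tensor below recovers a real per-device pointer by subtracting
// this dummy base from tensor->data and adding the result to each simple buffer's actual base.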
1087
1088
0
static enum ggml_status ggml_backend_meta_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
1089
0
    GGML_ASSERT(ggml_backend_buffer_is_meta(buffer));
1090
0
    ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) buffer->context;
1091
0
    const size_t n_simple_bufs = ggml_backend_meta_buffer_n_bufs(buffer);
1092
1093
0
    const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ true);
1094
0
    GGML_ASSERT(ggml_nelements(tensor) == 0 || split_state.axis != GGML_BACKEND_SPLIT_AXIS_UNKNOWN);
1095
0
    GGML_ASSERT(split_state.n_segments <= 16);
1096
1097
0
    int split_dim = split_state.axis;
1098
0
    int64_t ne[GGML_MAX_DIMS];
1099
0
    size_t  nb[GGML_MAX_DIMS];
1100
0
    for (size_t k = 0; k < GGML_MAX_DIMS; k++) {
1101
0
        ne[k] = tensor->ne[k];
1102
0
        nb[k] = tensor->nb[k];
1103
0
    }
1104
1105
0
    std::vector<ggml_tensor *> simple_tensors;
1106
0
    simple_tensors.reserve(n_simple_bufs);
1107
0
    for (size_t j = 0; j < n_simple_bufs; j++) {
1108
0
        ggml_context          * simple_ctx = buf_ctx->buf_configs[j].ctx;
1109
0
        ggml_backend_buffer_t   simple_buf = buf_ctx->buf_configs[j].buf;
1110
1111
0
        if (split_dim >= 0 && split_dim < GGML_MAX_DIMS) {
1112
            // TODO: the following assert fails for llama-parallel even though the results are correct:
1113
            // GGML_ASSERT(ggml_is_contiguously_allocated(tensor));
1114
0
            ne[split_dim] = 0;
1115
0
            for (size_t s = 0; s < split_state.n_segments; s++) {
1116
0
                ne[split_dim] += split_state.ne[s*n_simple_bufs + j];
1117
0
            }
1118
0
            for (int i = 0; i < GGML_MAX_DIMS; i++) {
1119
0
                if (tensor->nb[i] > tensor->nb[split_dim]) {
1120
0
                    nb[i] = tensor->nb[i] * ne[split_dim]/tensor->ne[split_dim];
1121
0
                }
1122
0
            }
1123
0
        }
1124
1125
0
        ggml_tensor * t_ij = ggml_new_tensor(simple_ctx, tensor->type, GGML_MAX_DIMS, ne);
1126
0
        t_ij->op = tensor->op;
1127
0
        for (int i = 0; i < GGML_MAX_DIMS; i++) {
1128
0
            t_ij->nb[i] = nb[i];
1129
0
        }
1130
0
        t_ij->flags = tensor->flags;
1131
0
        memcpy(t_ij->op_params, tensor->op_params, sizeof(tensor->op_params));
1132
0
        ggml_set_name(t_ij, tensor->name);
1133
0
        t_ij->buffer = simple_buf;
1134
0
        t_ij->view_src = tensor->view_src;
1135
0
        t_ij->view_offs = tensor->view_offs;
1136
0
        if (t_ij->view_src != nullptr && ggml_backend_buffer_is_meta(t_ij->view_src->buffer)) {
1137
0
            t_ij->view_src = ggml_backend_meta_buffer_simple_tensor(tensor->view_src, j);
1138
0
            if (t_ij->view_offs > 0 && split_dim >= 0 && split_dim < GGML_MAX_DIMS) {
1139
0
                GGML_ASSERT(ne[split_dim] != 0 && tensor->ne[split_dim] != 0);
1140
0
                const int split_dim_view_src = ggml_backend_meta_get_split_state(tensor->view_src, /*assume_sync =*/ true).axis;
1141
0
                GGML_ASSERT(split_dim_view_src >= 0 && split_dim_view_src < GGML_MAX_DIMS);
1142
1143
                // The offset can be internal to the data split; in those cases the view offset should not be scaled.
1144
                // If, however, the offset is larger than the data split, then it needs to be scaled proportionally.
1145
0
                bool split_internal_offset = t_ij->view_offs <= tensor->view_src->nb[split_dim_view_src];
1146
0
                for (int i = 0; i < GGML_MAX_DIMS; i++) {
1147
0
                    const size_t dim_size = tensor->ne[i] * tensor->nb[i];
1148
0
                    if (tensor->view_offs <= dim_size && dim_size < tensor->nb[split_dim]) {
1149
0
                        split_internal_offset = true;
1150
0
                        break;
1151
0
                    }
1152
0
                }
1153
0
                if (!split_internal_offset) {
1154
0
                    t_ij->view_offs = t_ij->view_offs * ne[split_dim]/tensor->ne[split_dim];
1155
0
                }
1156
0
            }
1157
0
        }
1158
0
        if (t_ij->view_src != nullptr) {
1159
0
            t_ij->data = (char *) t_ij->view_src->data + t_ij->view_offs;
1160
0
        } else if (simple_buf != nullptr) {
1161
0
            t_ij->data = (char *) ggml_backend_buffer_get_base(simple_buf)
1162
0
                + size_t(tensor->data) - size_t(ggml_backend_buffer_get_base(buffer));
1163
0
        }
1164
0
        t_ij->extra = tensor->extra;
1165
0
        for (int i = 0; i < GGML_MAX_SRC; i++) {
1166
0
            t_ij->src[i] = tensor->src[i];
1167
0
            if (tensor->src[i] == tensor) {
1168
0
                t_ij->src[i] = t_ij;
1169
0
            } else if (t_ij->src[i] != nullptr && ggml_backend_buffer_is_meta(t_ij->src[i]->buffer)) {
1170
0
                t_ij->src[i] = ggml_backend_meta_buffer_simple_tensor(tensor->src[i], j);
1171
0
            }
1172
0
        }
1173
1174
0
        simple_tensors.push_back(t_ij);
1175
0
    }
1176
0
    buf_ctx->simple_tensors[tensor] = simple_tensors;
1177
1178
0
    return GGML_STATUS_SUCCESS;
1179
0
}
1180
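The function above builds, for every simple buffer, a shadow tensor whose extent along the split axis is the sum of that device's segments and whose data pointer keeps its offset inside the meta buffer but is re-based onto the simple buffer. A minimal standalone sketch of those two calculations, with all concrete sizes and addresses assumed for illustration:

// Minimal sketch (hypothetical values, not part of this file): how the per-device extent
// along the split axis and the re-based data pointer are derived in init_tensor.
#include <cstdint>
#include <cstdio>

int main() {
    // Assume a tensor with ne[0] = 4096 split along axis 0 across 2 devices,
    // one segment per device, 2048 elements each (split_state.ne, n_segments = 1, n_bufs = 2):
    const int64_t ne_segments[2] = {2048, 2048};
    int64_t ne0_dev0 = 0;
    for (size_t s = 0; s < 1; s++) {
        ne0_dev0 += ne_segments[s*2 + 0]; // same loop shape as in init_tensor, device j = 0
    }
    // Data pointer: the tensor's offset inside the meta buffer is preserved
    // relative to the simple buffer's real base address.
    const uintptr_t meta_base   = 0x1000000000000000ull; // placeholder returned by get_base
    const uintptr_t tensor_data = meta_base + 0x4000;    // assigned during allocation
    const uintptr_t simple_base = 0x7f0000000000ull;     // hypothetical device buffer base
    const uintptr_t simple_data = simple_base + (tensor_data - meta_base);
    printf("ne0 on device 0: %lld, simple data offset: 0x%llx\n",
           (long long) ne0_dev0, (unsigned long long) (simple_data - simple_base));
    return 0;
}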
1181
0
static void ggml_backend_meta_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
1182
0
    const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(buffer);
1183
0
    GGML_ASSERT(ggml_is_contiguous(tensor));
1184
1185
0
    const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false);
1186
1187
0
    if (split_state.n_segments != 1) {
1188
0
        GGML_ASSERT(split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS);
1189
0
        GGML_ASSERT(offset == 0);
1190
0
        GGML_ASSERT(size == ggml_nbytes(tensor));
1191
0
        GGML_ASSERT(tensor->ne[3] == 1);
1192
0
        size_t offset_data = 0;
1193
0
        std::vector<size_t> simple_offsets(n_bufs, 0);
1194
0
        if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_0) {
1195
0
            GGML_ASSERT(tensor->ne[2] == 1);
1196
0
            const int64_t blck_size = ggml_blck_size(tensor->type);
1197
0
            for (size_t s = 0; s < split_state.n_segments; s++) {
1198
0
                for (size_t j = 0; j < n_bufs; j++) {
1199
0
                    ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
1200
0
                    GGML_ASSERT(split_state.ne[s*n_bufs + j] % blck_size == 0);
1201
0
                    const size_t nbytes = split_state.ne[s*n_bufs + j]/blck_size * tensor->nb[0];
1202
0
                    ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data, simple_offsets[j], nbytes,
1203
0
                        tensor->ne[1], simple_tensor->nb[1], tensor->nb[1]);
1204
0
                    offset_data       += nbytes;
1205
0
                    simple_offsets[j] += nbytes;
1206
0
                }
1207
0
            }
1208
0
            GGML_ASSERT(offset_data*tensor->ne[1] == size);
1209
0
            return;
1210
0
        }
1211
0
        GGML_ASSERT(split_state.axis == GGML_BACKEND_SPLIT_AXIS_1);
1212
0
        for (size_t s = 0; s < split_state.n_segments; s++) {
1213
0
            for (size_t j = 0; j < n_bufs; j++) {
1214
0
                ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
1215
0
                const size_t nbytes = split_state.ne[s*n_bufs + j] * tensor->nb[1];
1216
0
                ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data, simple_offsets[j], nbytes,
1217
0
                    tensor->ne[2], simple_tensor->nb[2], tensor->nb[2]);
1218
0
                offset_data       += nbytes;
1219
0
                simple_offsets[j] += nbytes;
1220
0
            }
1221
0
        }
1222
0
        GGML_ASSERT(offset_data*tensor->ne[2] == size);
1223
0
        return;
1224
0
    }
1225
1226
0
    switch (split_state.axis) {
1227
0
        case GGML_BACKEND_SPLIT_AXIS_0:
1228
0
        case GGML_BACKEND_SPLIT_AXIS_1:
1229
0
        case GGML_BACKEND_SPLIT_AXIS_2: {
1230
            // Exploit that the tensor is contiguous to splice it with the simple tensors in "chunks".
1231
0
            const size_t chunk_size_full = tensor->nb[split_state.axis + 1];
1232
0
            GGML_ASSERT(offset % chunk_size_full == 0);
1233
0
            GGML_ASSERT(size   % chunk_size_full == 0);
1234
0
            const int64_t i_start =  offset        /chunk_size_full;
1235
0
            const int64_t i_stop  = (offset + size)/chunk_size_full;
1236
0
            size_t offset_j = 0;
1237
0
            for (size_t j = 0; j < n_bufs; j++) {
1238
0
                ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
1239
0
                const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1];
1240
0
                const size_t simple_offset = i_start * chunk_size_j;
1241
0
                ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_j, simple_offset, chunk_size_j, i_stop - i_start, chunk_size_j, chunk_size_full);
1242
0
                offset_j += chunk_size_j;
1243
0
            }
1244
0
            GGML_ASSERT(offset_j == chunk_size_full);
1245
0
        } break;
1246
0
        case GGML_BACKEND_SPLIT_AXIS_MIRRORED: {
1247
0
            for (size_t j = 0; j < n_bufs; j++) {
1248
0
                ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
1249
0
                ggml_backend_tensor_set(simple_tensor, data, offset, size);
1250
0
            }
1251
0
        } break;
1252
0
        case GGML_BACKEND_SPLIT_AXIS_PARTIAL: {
1253
0
            GGML_ASSERT(tensor->type == GGML_TYPE_F32);
1254
0
            const int64_t ne = ggml_nelements(tensor);
1255
0
            std::vector<float> tmp;
1256
0
            tmp.reserve(ne);
1257
0
            for (int64_t i = 0; i < ne; i++) {
1258
0
                tmp.push_back(((const float *) data)[i] / n_bufs);
1259
0
            }
1260
0
            for (size_t j = 0; j < n_bufs; j++) {
1261
0
                ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
1262
0
                ggml_backend_tensor_set(simple_tensor, tmp.data(), offset, size);
1263
0
            }
1264
0
        } break;
1265
0
        default: {
1266
0
            GGML_ABORT("fatal error");
1267
0
        }
1268
0
    }
1269
0
}
1270
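For the single-segment case the splice works on whole chunks: one chunk is tensor->nb[axis + 1] bytes, each device owns a proportional slice of every chunk, and a strided 2D copy scatters (set) or gathers (get) the slices; for the PARTIAL axis each device instead stores data/n_bufs so that summing the device copies reproduces the original. A small standalone sketch of the chunk arithmetic under assumed sizes:

// Minimal sketch (assumed sizes, not part of this file): the chunk arithmetic used by
// set_tensor/get_tensor for an F32 tensor of ne = [4096, 32] split along axis 0
// across 2 devices with 2048 columns each.
#include <cassert>
#include <cstddef>

int main() {
    const size_t nb0 = 4;                        // sizeof(float)
    const size_t chunk_size_full = 4096 * nb0;   // tensor->nb[1], one full row
    const size_t chunk_size_dev  = 2048 * nb0;   // simple_tensor->nb[1], one row slice

    const size_t offset = 0;
    const size_t size   = 32 * chunk_size_full;  // whole tensor
    assert(offset % chunk_size_full == 0 && size % chunk_size_full == 0);

    const size_t i_start = offset / chunk_size_full;          // first row
    const size_t i_stop  = (offset + size) / chunk_size_full; // one past the last row

    // Device j copies (i_stop - i_start) rows of chunk_size_dev bytes, reading the host
    // data with stride chunk_size_full and writing with stride chunk_size_dev, starting
    // at data + j*chunk_size_dev (the running offset_j in the real code).
    size_t offset_j = 0;
    for (int j = 0; j < 2; j++) {
        offset_j += chunk_size_dev;
    }
    assert(offset_j == chunk_size_full); // the device slices tile exactly one full chunk
    (void) i_start; (void) i_stop;
    return 0;
}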
1271
0
static void ggml_backend_meta_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
1272
0
    const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(buffer);
1273
0
    GGML_ASSERT(ggml_is_contiguous(tensor));
1274
1275
0
    const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false);
1276
0
    GGML_ASSERT(split_state.n_segments == 1);
1277
1278
0
    switch (split_state.axis) {
1279
0
        case GGML_BACKEND_SPLIT_AXIS_0:
1280
0
        case GGML_BACKEND_SPLIT_AXIS_1:
1281
0
        case GGML_BACKEND_SPLIT_AXIS_2: {
1282
            // Exploit that the tensor is contiguous to splice it with the simple tensors in "chunks".
1283
0
            const size_t chunk_size_full = tensor->nb[split_state.axis + 1];
1284
0
            GGML_ASSERT(offset % chunk_size_full == 0);
1285
0
            GGML_ASSERT(size   % chunk_size_full == 0);
1286
0
            const int64_t i_start =  offset        /chunk_size_full;
1287
0
            const int64_t i_stop  = (offset + size)/chunk_size_full;
1288
0
            size_t offset_j = 0;
1289
0
            for (size_t j = 0; j < n_bufs; j++){
1290
0
                const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
1291
0
                const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1];
1292
0
                const size_t simple_offset = i_start * chunk_size_j;
1293
0
                ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_j, simple_offset, chunk_size_j, i_stop - i_start, chunk_size_j, chunk_size_full);
1294
0
                offset_j += chunk_size_j;
1295
0
            }
1296
0
            GGML_ASSERT(offset_j == chunk_size_full);
1297
0
        } break;
1298
0
        case GGML_BACKEND_SPLIT_AXIS_MIRRORED: {
1299
            // TODO: another simple backend may be better
1300
0
            const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, 0);
1301
0
            ggml_backend_tensor_get(simple_tensor, data, offset, size);
1302
0
        } break;
1303
0
        default: {
1304
0
            GGML_ABORT("fatal error");
1305
0
        }
1306
0
    }
1307
0
}
1308
1309
0
static void ggml_backend_meta_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
1310
0
    const size_t n_buffers = ggml_backend_meta_buffer_n_bufs(buffer);
1311
0
    for (size_t i = 0; i < n_buffers; i++) {
1312
0
        ggml_backend_buffer_clear(ggml_backend_meta_buffer_simple_buffer(buffer, i), value);
1313
0
    }
1314
0
}
1315
1316
0
static void ggml_backend_meta_buffer_reset(ggml_backend_buffer_t buffer) {
1317
0
    const size_t n_buffers = ggml_backend_meta_buffer_n_bufs(buffer);
1318
0
    for (size_t i = 0; i < n_buffers; i++) {
1319
0
        ggml_backend_buffer_reset(ggml_backend_meta_buffer_simple_buffer(buffer, i));
1320
0
    }
1321
0
}
1322
1323
static const ggml_backend_buffer_i ggml_backend_meta_buffer_iface = {
1324
    /* .free_buffer     = */ ggml_backend_meta_buffer_free_buffer,
1325
    /* .get_base        = */ ggml_backend_meta_buffer_get_base,
1326
    /* .init_tensor     = */ ggml_backend_meta_buffer_init_tensor,
1327
    /* .memset_tensor   = */ nullptr, // TODO implement
1328
    /* .set_tensor      = */ ggml_backend_meta_buffer_set_tensor,
1329
    /* .get_tensor      = */ ggml_backend_meta_buffer_get_tensor,
1330
    /* .set_tensor_2d   = */ nullptr,
1331
    /* .get_tensor_2d   = */ nullptr,
1332
    /* .cpy_tensor      = */ nullptr,
1333
    /* .clear           = */ ggml_backend_meta_buffer_clear,
1334
    /* .reset           = */ ggml_backend_meta_buffer_reset,
1335
};
1336
1337
0
bool ggml_backend_buffer_is_meta(ggml_backend_buffer_t buf) {
1338
0
    return buf != nullptr && buf->iface.free_buffer == ggml_backend_meta_buffer_iface.free_buffer;
1339
0
}
1340
1341
0
static ggml_backend_buffer_t ggml_backend_meta_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
1342
0
    const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft);
1343
1344
0
    ggml_init_params params = {
1345
0
        /*.mem_size   =*/ 1024*1024*1024, // FIXME
1346
0
        /*.mem_buffer =*/ nullptr,
1347
0
        /*.no_alloc   =*/ true,
1348
0
    };
1349
1350
0
    ggml_backend_meta_buffer_context * buf_ctx = new ggml_backend_meta_buffer_context();
1351
0
    size_t max_size = 0;
1352
0
    buf_ctx->buf_configs.reserve(n_simple_bufts);
1353
0
    for (size_t i = 0; i < n_simple_bufts; i++) {
1354
0
        ggml_backend_buffer_t simple_buf = ggml_backend_buft_alloc_buffer(ggml_backend_meta_buft_simple_buft(buft, i), size);
1355
0
        max_size = std::max(max_size, ggml_backend_buffer_get_size(simple_buf));
1356
0
        buf_ctx->buf_configs.emplace_back(ggml_init(params), simple_buf);
1357
0
    }
1358
1359
0
    return ggml_backend_buffer_init(buft, ggml_backend_meta_buffer_iface, buf_ctx, max_size);
1360
0
}
1361
1362
0
struct ggml_backend_buffer * ggml_backend_meta_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft) {
1363
0
    const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft);
1364
1365
0
    ggml_init_params params = {
1366
0
        /*.mem_size   =*/ 1024*1024*1024, // FIXME
1367
0
        /*.mem_buffer =*/ nullptr,
1368
0
        /*.no_alloc   =*/ true,
1369
0
    };
1370
1371
0
    ggml_backend_meta_buffer_context * meta_buf_ctx = new ggml_backend_meta_buffer_context();
1372
0
    meta_buf_ctx->buf_configs.reserve(n_simple_bufts);
1373
0
    for (size_t i = 0; i < n_simple_bufts; i++) {
1374
0
        meta_buf_ctx->buf_configs.emplace_back(ggml_init(params), nullptr);
1375
0
    }
1376
1377
0
    ggml_backend_buffer_t meta_buf = ggml_backend_buffer_init(buft, ggml_backend_meta_buffer_iface, meta_buf_ctx, 0);
1378
0
    for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
1379
0
        t->buffer = meta_buf;
1380
0
        ggml_backend_meta_buffer_init_tensor(meta_buf, t);
1381
0
        t->data = (void *) 0x2000000000000000; // FIXME
1382
0
    }
1383
0
    for (size_t i = 0; i < n_simple_bufts; i++) {
1384
0
        meta_buf_ctx->buf_configs[i].buf = ggml_backend_alloc_ctx_tensors_from_buft(
1385
0
            meta_buf_ctx->buf_configs[i].ctx, ggml_backend_meta_buft_simple_buft(buft, i));
1386
0
        meta_buf->size = std::max(meta_buf->size, ggml_backend_buffer_get_size(meta_buf_ctx->buf_configs[i].buf));
1387
0
    }
1388
0
    return meta_buf;
1389
0
}
1390
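A hedged usage sketch of the allocator above: build a no-alloc ggml context, create the tensors, allocate them through the meta buffer type, and upload the data once through the meta buffer (its set_tensor then splits or mirrors the bytes per tensor). The name upload_weight_example, the incoming meta_buft, and the weight dimensions are assumptions; only functions defined in this file or in the generic ggml API are called, and the includes at the top of this file are assumed:

static ggml_backend_buffer_t upload_weight_example(ggml_backend_buffer_type_t meta_buft,
                                                   const float * w, int64_t n_rows, int64_t n_cols) {
    ggml_init_params params = {
        /*.mem_size   =*/ 8*ggml_tensor_overhead(),
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ true, // tensor data lives in the backend buffers, not in the context
    };
    ggml_context * ctx = ggml_init(params);
    ggml_tensor * weight = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_cols, n_rows);

    // One simple buffer per device plus shadow tensors for every tensor in ctx:
    ggml_backend_buffer_t buf = ggml_backend_meta_alloc_ctx_tensors_from_buft(ctx, meta_buft);

    // The meta buffer's set_tensor scatters, mirrors, or scales the data per split state:
    ggml_backend_tensor_set(weight, w, 0, ggml_nbytes(weight));

    // ctx is intentionally kept alive here; it owns the tensor metadata that buf refers to.
    return buf;
}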
1391
//
1392
// meta backend
1393
//
1394
1395
0
static ggml_guid_t ggml_backend_meta_guid() {
1396
0
    static ggml_guid guid = {0xf1, 0x0e, 0x34, 0xcf, 0x9c, 0x6f, 0x43, 0xcb, 0x96, 0x92, 0xbe, 0x8e, 0xbb, 0x71, 0x3f, 0xda};
1397
0
    return &guid;
1398
0
}
1399
1400
struct ggml_backend_meta_context {
1401
    struct cgraph_config {
1402
        ggml_cgraph * cgraph_main = nullptr;
1403
        int           offset      = 0; // Node offset vs. original graph
1404
1405
        std::vector<ggml_cgraph *> cgraphs_aux;
1406
    };
1407
    struct backend_config {
1408
        ggml_backend_t backend;
1409
1410
        std::vector<cgraph_config> cgraphs;
1411
        std::vector<ggml_tensor *> nodes;
1412
        ggml_backend_buffer_ptr    buf;
1413
1414
0
        backend_config(ggml_backend_t backend) : backend(backend) {}
1415
    };
1416
    std::string                 name;
1417
    std::vector<backend_config> backend_configs;
1418
    ggml_context_ptr            ctx;
1419
    std::vector<ggml_cgraph *>  cgraphs_aux;
1420
    std::vector<ggml_tensor *>  nodes_aux;
1421
    int                         max_nnodes    = 0;
1422
    size_t                      max_tmp_size  = 0;
1423
    size_t                      max_subgraphs = 0;
1424
1425
0
    ggml_backend_meta_context(ggml_backend_dev_t meta_dev, const char * params) {
1426
0
        const size_t n_devs = ggml_backend_meta_dev_n_devs(meta_dev);
1427
0
        name = "Meta(";
1428
0
        backend_configs.reserve(n_devs);
1429
0
        for (size_t i = 0; i < n_devs; i++) {
1430
0
            ggml_backend_dev_t simple_dev = ggml_backend_meta_dev_simple_dev(meta_dev, i);
1431
0
            if (i > 0) {
1432
0
                name += ",";
1433
0
            }
1434
0
            name += ggml_backend_dev_name(simple_dev);
1435
0
            backend_configs.emplace_back(ggml_backend_dev_init(simple_dev, params));
1436
0
        }
1437
0
        name += ")";
1438
0
    }
1439
1440
0
    ~ggml_backend_meta_context() {
1441
0
        for (auto & bc : backend_configs) {
1442
0
            ggml_backend_free(bc.backend);
1443
0
        }
1444
0
    }
1445
1446
0
    size_t n_reduce_steps() const {
1447
0
        return std::ceil(std::log2(backend_configs.size()));
1448
0
    }
1449
};
1450
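n_reduce_steps() is the number of pairwise exchange rounds the fallback allreduce needs: ceil(log2(n_backends)). A quick standalone check of the values this formula yields for small device counts:

// Quick check of n_reduce_steps() = ceil(log2(n_backends)) for small device counts:
// 1 device -> 0 steps, 2 -> 1, 3 -> 2, 4 -> 2, 8 -> 3.
#include <cassert>
#include <cmath>
#include <cstddef>

int main() {
    auto n_reduce_steps = [](size_t n_backends) -> size_t {
        return (size_t) std::ceil(std::log2((double) n_backends));
    };
    assert(n_reduce_steps(1) == 0);
    assert(n_reduce_steps(2) == 1);
    assert(n_reduce_steps(3) == 2);
    assert(n_reduce_steps(4) == 2);
    assert(n_reduce_steps(8) == 3);
    return 0;
}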
1451
0
static const char * ggml_backend_meta_get_name(ggml_backend_t backend) {
1452
0
    GGML_ASSERT(ggml_backend_is_meta(backend));
1453
0
    const ggml_backend_meta_context * backend_ctx = (const ggml_backend_meta_context *) backend->context;
1454
0
    return backend_ctx->name.c_str();
1455
0
}
1456
1457
0
static void ggml_backend_meta_free(ggml_backend_t backend) {
1458
0
    GGML_ASSERT(ggml_backend_is_meta(backend));
1459
0
    ggml_backend_meta_context * backend_ctx = (ggml_backend_meta_context *) backend->context;
1460
0
    delete backend_ctx;
1461
0
    delete backend;
1462
0
}
1463
1464
0
static void ggml_backend_meta_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
1465
0
    const size_t n_backends = ggml_backend_meta_n_backends(backend);
1466
0
    GGML_ASSERT(offset == 0);
1467
0
    GGML_ASSERT(ggml_is_contiguous(tensor));
1468
1469
0
    const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false);
1470
0
    GGML_ASSERT(split_state.n_segments == 1);
1471
1472
0
    switch (split_state.axis) {
1473
0
        case GGML_BACKEND_SPLIT_AXIS_0:
1474
0
        case GGML_BACKEND_SPLIT_AXIS_1:
1475
0
        case GGML_BACKEND_SPLIT_AXIS_2: {
1476
            // Exploit that the tensor is contiguous to splice it with the simple tensors in "chunks".
1477
0
            const size_t chunk_size_full = tensor->nb[split_state.axis + 1];
1478
0
            GGML_ASSERT(offset % chunk_size_full == 0);
1479
0
            GGML_ASSERT(size   % chunk_size_full == 0);
1480
0
            const int64_t i_start =  offset        /chunk_size_full;
1481
0
            const int64_t i_stop  = (offset + size)/chunk_size_full;
1482
0
            size_t offset_j = 0;
1483
0
            for (size_t j = 0; j < n_backends; j++){
1484
0
                ggml_backend_t simple_backend = ggml_backend_meta_simple_backend(backend, j);
1485
0
                ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
1486
0
                const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1];
1487
0
                ggml_backend_tensor_set_2d_async(simple_backend, simple_tensor, (const char *) data + offset_j, offset, chunk_size_j,
1488
0
                    i_stop - i_start, chunk_size_j, chunk_size_full);
1489
0
                offset_j += chunk_size_j;
1490
0
            }
1491
0
            GGML_ASSERT(offset_j == chunk_size_full);
1492
0
        } break;
1493
0
        case GGML_BACKEND_SPLIT_AXIS_MIRRORED: {
1494
0
            for (size_t j = 0; j < n_backends; j++) {
1495
0
                ggml_backend_tensor_set_async(
1496
0
                    ggml_backend_meta_simple_backend(backend, j), ggml_backend_meta_buffer_simple_tensor(tensor, j), data, offset, size);
1497
0
            }
1498
0
        } break;
1499
0
        default: {
1500
0
            GGML_ABORT("fatal error");
1501
0
        }
1502
0
    }
1503
0
}
1504
1505
0
static void ggml_backend_meta_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
1506
0
    const size_t n_backends = ggml_backend_meta_n_backends(backend);
1507
0
    GGML_ASSERT(offset == 0);
1508
0
    GGML_ASSERT(ggml_is_contiguous(tensor));
1509
1510
0
    const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false);
1511
0
    GGML_ASSERT(split_state.n_segments == 1);
1512
1513
0
    switch (split_state.axis) {
1514
0
        case GGML_BACKEND_SPLIT_AXIS_0:
1515
0
        case GGML_BACKEND_SPLIT_AXIS_1:
1516
0
        case GGML_BACKEND_SPLIT_AXIS_2: {
1517
            // Exploit that the tensor is contiguous to splice it with the simple tensors in "chunks".
1518
0
            const size_t chunk_size_full = tensor->nb[split_state.axis + 1];
1519
0
            GGML_ASSERT(offset % chunk_size_full == 0);
1520
0
            GGML_ASSERT(size   % chunk_size_full == 0);
1521
0
            const int64_t i_start =  offset        /chunk_size_full;
1522
0
            const int64_t i_stop  = (offset + size)/chunk_size_full;
1523
0
            size_t offset_j = 0;
1524
0
            for (size_t j = 0; j < n_backends; j++){
1525
0
                ggml_backend_t simple_backend = ggml_backend_meta_simple_backend(backend, j);
1526
0
                const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
1527
0
                const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1];
1528
0
                ggml_backend_tensor_get_2d_async(simple_backend, simple_tensor, (char *) data + offset_j, offset, chunk_size_j,
1529
0
                    i_stop - i_start, chunk_size_j, chunk_size_full);
1530
0
                offset_j += chunk_size_j;
1531
0
            }
1532
0
            GGML_ASSERT(offset_j == chunk_size_full);
1533
0
        } break;
1534
0
        case GGML_BACKEND_SPLIT_AXIS_MIRRORED: {
1535
            // TODO: another simple backend may be better
1536
0
            ggml_backend_t simple_backend = ggml_backend_meta_simple_backend(backend, 0);
1537
0
            const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, 0);
1538
0
            ggml_backend_tensor_get_async(simple_backend, simple_tensor, data, offset, size);
1539
0
        } break;
1540
0
        default: {
1541
0
            GGML_ABORT("fatal error");
1542
0
        }
1543
0
    }
1544
0
}
1545
1546
0
static void ggml_backend_meta_synchronize(ggml_backend_t backend) {
1547
0
    const size_t n_backends = ggml_backend_meta_n_backends(backend);
1548
0
    for (size_t i = 0; i < n_backends; i++) {
1549
0
        ggml_backend_synchronize(ggml_backend_meta_simple_backend(backend, i));
1550
0
    }
1551
0
}
1552
1553
0
static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
1554
0
    GGML_ASSERT(cgraph->grads == nullptr);
1555
0
    const size_t n_backends = ggml_backend_meta_n_backends(backend);
1556
0
    ggml_backend_meta_context * backend_ctx = (ggml_backend_meta_context *) backend->context;
1557
1558
0
    bool max_nnodes_raised = false;
1559
0
    if (cgraph->n_nodes > backend_ctx->max_nnodes) {
1560
0
        for (size_t j = 0; j < n_backends; j++) {
1561
0
            auto & bcj = backend_ctx->backend_configs[j];
1562
0
            bcj.nodes.resize(cgraph->n_nodes);
1563
0
            bcj.cgraphs.resize(cgraph->n_nodes);
1564
0
        }
1565
0
        backend_ctx->max_nnodes = cgraph->n_nodes;
1566
0
        max_nnodes_raised = true;
1567
0
    }
1568
0
    for (size_t j = 0; j < n_backends; j++) {
1569
0
        auto & bcj = backend_ctx->backend_configs[j];
1570
1571
0
        for (int i = 0; i < cgraph->n_nodes; i++) {
1572
0
            ggml_tensor * node = cgraph->nodes[i];
1573
0
            if (node->view_src != nullptr && node->view_src->op == GGML_OP_NONE && ggml_backend_buffer_is_host(node->view_src->buffer)) {
1574
                // FIXME s_copy_main is on the CPU and its view seems to be incorrectly added to the graph nodes.
1575
                // For regular usage this doesn't matter since it's a no-op, but trying to call ggml_backend_meta_buffer_simple_tensor results in a crash.
1576
0
                bcj.nodes[i] = node;
1577
0
                continue;
1578
0
            }
1579
0
            bcj.nodes[i] = ggml_backend_meta_buffer_simple_tensor(node, j);
1580
0
            GGML_ASSERT(bcj.nodes[i]);
1581
0
        }
1582
0
    }
1583
1584
0
    size_t n_subgraphs  = 0;
1585
0
    size_t max_tmp_size = 0;
1586
0
    {
1587
        // For MoE models, it may make sense to delay the AllReduce in order to reduce I/O:
1588
0
        auto get_i_delayed = [&](const int i) -> int {
1589
0
            int id = i; // i_delayed
1590
0
            int idr = i; // i_delayed return, last safe return value
1591
1592
0
            ggml_tensor * node = cgraph->nodes[id];
1593
0
            int32_t n_used = ggml_node_get_use_count(cgraph, id);
1594
0
            if (id + 1 >= cgraph->n_nodes) {
1595
0
                return idr;
1596
0
            }
1597
0
            {
1598
0
                ggml_tensor * next = cgraph->nodes[id+1];
1599
0
                if (next->op == GGML_OP_ADD_ID && next->src[0] == node &&
1600
0
                        ggml_backend_meta_get_split_state(next->src[1], false).axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL &&
1601
0
                        ggml_backend_meta_get_split_state(next->src[2], false).axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
1602
0
                    node = next;
1603
0
                    id++;
1604
0
                    idr = id;
1605
0
                    n_used = ggml_node_get_use_count(cgraph, id);
1606
0
                }
1607
0
            }
1608
0
            if (id + 1 >= cgraph->n_nodes) {
1609
0
                return idr;
1610
0
            }
1611
0
            {
1612
0
                ggml_tensor * next = cgraph->nodes[id+1];
1613
0
                if (next->op == GGML_OP_MUL && next->src[0] == node &&
1614
0
                        ggml_backend_meta_get_split_state(next->src[1], false).axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
1615
0
                    node = next;
1616
0
                    id++;
1617
0
                    idr = id;
1618
0
                    n_used = ggml_node_get_use_count(cgraph, id);
1619
0
                }
1620
0
            }
1621
1622
0
            if (n_used != node->ne[1] || id + 2*n_used-1 >= cgraph->n_nodes) {
1623
0
                return idr;
1624
0
            }
1625
0
            for (int32_t k = 0; k < n_used; k++) {
1626
0
                ggml_tensor * next = cgraph->nodes[id+1];
1627
0
                if (next->op != GGML_OP_VIEW || next->view_src != node || next->view_offs != k*node->nb[1] ||
1628
0
                        next->ne[0] != node->ne[0] || next->ne[1] != node->ne[2] || next->nb[1] != node->nb[2] ||
1629
0
                        ggml_node_get_use_count(cgraph, id+1) != 1) {
1630
0
                    return idr;
1631
0
                }
1632
0
                id++;
1633
0
            }
1634
0
            {
1635
0
                ggml_tensor * next = cgraph->nodes[id+1];
1636
0
                if (next->op != GGML_OP_ADD || next->src[0] != cgraph->nodes[id - (n_used-1)] ||
1637
0
                        next->src[1] != cgraph->nodes[id - (n_used-2)] || ggml_node_get_use_count(cgraph, id+1) != 1) {
1638
0
                    return idr;
1639
0
                }
1640
0
                id++;
1641
0
            }
1642
0
            for (int32_t k = 0; k < n_used - 2; k++) {
1643
0
                ggml_tensor * next = cgraph->nodes[id+1];
1644
0
                if (next->op != GGML_OP_ADD || next->src[0] != cgraph->nodes[id] ||
1645
0
                        next->src[1] != cgraph->nodes[id - (n_used-2)] || ggml_node_get_use_count(cgraph, id+1) != 1) {
1646
0
                    return idr;
1647
0
                }
1648
0
                id++;
1649
0
            }
1650
0
            idr = id;
1651
0
            return idr;
1652
0
        };
1653
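        // Sketch of the node pattern get_i_delayed tries to match before placing the
        // AllReduce for a PARTIAL node (a typical MoE expert block, described loosely):
        //
        //   node               the PARTIAL result that would normally be reduced here
        //   [ADD_ID]           optional: src[0] == node, src[1] PARTIAL, src[2] MIRRORED
        //   [MUL]              optional: src[0] == node, src[1] MIRRORED
        //   VIEW x n_used      one 2D view per consumer of the node (n_used == node->ne[1])
        //   ADD                sums the first two views
        //   ADD x (n_used-2)   chain adding each remaining view
        //
        // If the whole chain is present, the subgraph boundary (and thus the AllReduce) is
        // moved to the final ADD, so a single reduction covers the accumulated result.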
1654
0
        int i_start = 0;
1655
0
        for (int i = 0; i < cgraph->n_nodes; i++) {
1656
0
            ggml_tensor * node = cgraph->nodes[i];
1657
0
            if (node->view_src != nullptr && node->view_src->op == GGML_OP_NONE && ggml_backend_buffer_is_host(node->view_src->buffer)) {
1658
0
                continue;
1659
0
            }
1660
0
            const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(node, /*assume_sync =*/ false);
1661
0
            if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL) {
1662
0
                max_tmp_size = std::max(max_tmp_size, ggml_nbytes(node));
1663
0
            }
1664
0
            const bool new_subgraph = i + 1 == cgraph->n_nodes || split_state.axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL;
1665
0
            if (!new_subgraph) {
1666
0
                continue;
1667
0
            }
1668
1669
0
            i = get_i_delayed(i);
1670
1671
0
            for (size_t j = 0; j < n_backends; j++) {
1672
0
                auto & bcj = backend_ctx->backend_configs[j];
1673
0
                bcj.cgraphs[n_subgraphs].offset = i_start;
1674
0
            }
1675
0
            n_subgraphs++;
1676
0
            i_start = i + 1;
1677
0
        }
1678
0
        GGML_ASSERT(i_start == cgraph->n_nodes);
1679
0
    }
1680
1681
0
    if (max_tmp_size > backend_ctx->max_tmp_size) {
1682
0
        for (size_t j = 0; j < n_backends; j++) {
1683
0
            auto & bcj = backend_ctx->backend_configs[j];
1684
0
            bcj.buf.reset(ggml_backend_alloc_buffer(bcj.backend, max_tmp_size));
1685
0
        }
1686
0
        backend_ctx->max_tmp_size = max_tmp_size;
1687
0
    }
1688
1689
1690
0
    if (max_nnodes_raised || n_subgraphs > backend_ctx->max_subgraphs) {
1691
0
        backend_ctx->max_subgraphs = std::max(backend_ctx->max_subgraphs, n_subgraphs);
1692
0
        const size_t n_reduce_steps = backend_ctx->n_reduce_steps();
1693
0
        const size_t n_nodes_per_device = 2 * n_reduce_steps; // tmp + ADD per step
1694
0
        const size_t n_cgraphs_per_device = n_reduce_steps;    // 1 ADD graph per step
1695
0
        const size_t mem_per_device_graphs_main = backend_ctx->max_subgraphs*ggml_graph_overhead_custom(backend_ctx->max_nnodes, cgraph->grads);
1696
0
        const size_t mem_per_device_graphs_aux = n_cgraphs_per_device*backend_ctx->max_subgraphs*ggml_graph_overhead_custom(1, cgraph->grads);
1697
0
        const size_t mem_per_device_nodes_aux = n_nodes_per_device*backend_ctx->max_subgraphs*ggml_tensor_overhead();
1698
0
        ggml_init_params params = {
1699
0
            /*.mem_size   =*/ n_backends * (mem_per_device_graphs_main + mem_per_device_graphs_aux + mem_per_device_nodes_aux),
1700
0
            /*.mem_buffer =*/ nullptr,
1701
0
            /*.no_alloc   =*/ true,
1702
0
        };
1703
0
        backend_ctx->ctx.reset(ggml_init(params));
1704
0
        for (size_t j = 0; j < n_backends; j++) {
1705
0
            auto & bcj = backend_ctx->backend_configs[j];
1706
0
            for (size_t i = 0; i < n_subgraphs; i++) {
1707
0
                bcj.cgraphs[i].cgraph_main = ggml_new_graph_custom(backend_ctx->ctx.get(), cgraph->n_nodes, /*grads =*/ false);
1708
0
            }
1709
0
        }
1710
0
        backend_ctx->cgraphs_aux.resize(n_backends*n_cgraphs_per_device*backend_ctx->max_subgraphs);
1711
0
        for (size_t k = 0; k < backend_ctx->cgraphs_aux.size(); k++) {
1712
0
            backend_ctx->cgraphs_aux[k] = ggml_new_graph_custom(backend_ctx->ctx.get(), 1, cgraph->grads);
1713
0
        }
1714
0
        backend_ctx->nodes_aux.resize(n_backends*n_nodes_per_device*backend_ctx->max_subgraphs);
1715
0
        for (size_t k = 0; k < backend_ctx->nodes_aux.size(); k++) {
1716
0
            backend_ctx->nodes_aux[k] = ggml_new_tensor_1d(backend_ctx->ctx.get(), GGML_TYPE_F32, 1);
1717
0
        }
1718
0
    }
1719
1720
0
    for (size_t j = 0; j < n_backends; j++) {
1721
0
        auto & bcj = backend_ctx->backend_configs[j];
1722
0
        for (size_t i_graph = 0; i_graph < n_subgraphs; i_graph++) {
1723
0
            ggml_cgraph * cgraph_ij = bcj.cgraphs[i_graph].cgraph_main;
1724
0
            const size_t i_node_start = bcj.cgraphs[i_graph].offset;
1725
0
            const size_t i_node_stop = i_graph + 1 < n_subgraphs ? bcj.cgraphs[i_graph + 1].offset : cgraph->n_nodes;
1726
0
            cgraph_ij->n_nodes = i_node_stop - i_node_start;
1727
0
            ggml_hash_set_reset(&cgraph_ij->visited_hash_set);
1728
0
            for (size_t i_node = i_node_start; i_node < i_node_stop; i_node++) {
1729
0
                ggml_tensor * node_ij = bcj.nodes[i_node];
1730
0
                cgraph_ij->nodes[i_node - i_node_start] = node_ij;
1731
0
                const size_t hash_pos_orig = ggml_hash_find(&cgraph->visited_hash_set, cgraph->nodes[i_node]);
1732
0
                const size_t hash_pos_ij = ggml_hash_insert(&cgraph_ij->visited_hash_set, node_ij);
1733
0
                cgraph_ij->use_counts[hash_pos_ij] = cgraph->use_counts[hash_pos_orig];
1734
0
            }
1735
0
        }
1736
0
    }
1737
1738
0
    size_t iga = 0; // i graph aux
1739
0
    size_t ina = 0; // i node aux
1740
1741
    // FIXME usage_counts
1742
0
    auto get_cgraph_aux = [&]() -> ggml_cgraph * {
1743
0
        ggml_cgraph * ret = backend_ctx->cgraphs_aux[iga++];
1744
0
        return ret;
1745
0
    };
1746
0
    auto get_node_aux = [&](ggml_tensor * t) -> ggml_tensor * {
1747
0
        ggml_tensor * ret = backend_ctx->nodes_aux[ina++];
1748
0
        memset(ret, 0, sizeof(ggml_tensor));
1749
0
        ret->op   = GGML_OP_NONE;
1750
0
        ret->type = t->type;
1751
0
        for (size_t k = 0; k < GGML_MAX_DIMS; k++) {
1752
0
            ret->ne[k] = t->ne[k];
1753
0
            ret->nb[k] = t->nb[k];
1754
0
        }
1755
0
        return ret;
1756
0
    };
1757
1758
    // Preferentially use a backend-specific allreduce_tensor_async (e.g. NCCL for CUDA); use a generic fallback if unavailable:
1759
0
    auto allreduce_fallback = [&](size_t i) -> ggml_status {
1760
0
        std::vector<ggml_cgraph *> step_cgraphs(n_backends, nullptr);
1761
1762
0
        for (size_t offset_j = 1; offset_j < n_backends; offset_j *= 2) {
1763
0
            std::fill(step_cgraphs.begin(), step_cgraphs.end(), nullptr);
1764
1765
0
            for (size_t j = 0; j < n_backends; j++) {
1766
0
                const size_t j_other = j ^ offset_j;
1767
0
                if (j_other > j) {
1768
0
                    continue;
1769
0
                }
1770
1771
0
                auto & bcj1 = backend_ctx->backend_configs[j];
1772
0
                auto & bcj2 = backend_ctx->backend_configs[j_other];
1773
1774
0
                ggml_tensor * node1 = bcj1.cgraphs[i].cgraph_main->nodes[bcj1.cgraphs[i].cgraph_main->n_nodes - 1];
1775
0
                ggml_tensor * node2 = bcj2.cgraphs[i].cgraph_main->nodes[bcj2.cgraphs[i].cgraph_main->n_nodes - 1];
1776
0
                GGML_ASSERT(ggml_is_contiguous(node1));
1777
0
                GGML_ASSERT(ggml_is_contiguous(node2));
1778
1779
                // Tmp tensors to receive P2P copies
1780
0
                ggml_tensor * node_tmp_1 = get_node_aux(node1);
1781
0
                node_tmp_1->buffer = bcj1.buf.get();
1782
0
                node_tmp_1->data = ggml_backend_buffer_get_base(bcj1.buf.get());
1783
1784
0
                ggml_tensor * node_tmp_2 = get_node_aux(node2);
1785
0
                node_tmp_2->buffer = bcj2.buf.get();
1786
0
                node_tmp_2->data = ggml_backend_buffer_get_base(bcj2.buf.get());
1787
1788
                // 2 P2P copies: exchange full buffers
1789
0
                ggml_backend_tensor_copy_async(bcj1.backend, bcj2.backend, node1, node_tmp_2);
1790
0
                ggml_backend_tensor_copy_async(bcj2.backend, bcj1.backend, node2, node_tmp_1);
1791
1792
                // Local ADD: node1 += tmp1 (in-place via view)
1793
0
                ggml_tensor * node_red_1 = get_node_aux(node1);
1794
0
                node_red_1->view_src = node1->view_src == nullptr ? node1 : node1->view_src;
1795
0
                node_red_1->view_offs = node1->view_offs;
1796
0
                node_red_1->op = GGML_OP_ADD;
1797
0
                node_red_1->src[0] = node1;
1798
0
                node_red_1->src[1] = node_tmp_1;
1799
0
                node_red_1->flags |= GGML_TENSOR_FLAG_COMPUTE;
1800
0
                ggml_backend_view_init(node_red_1);
1801
1802
                // Local ADD: node2 += tmp2 (in-place via view)
1803
0
                ggml_tensor * node_red_2 = get_node_aux(node2);
1804
0
                node_red_2->view_src = node2->view_src == nullptr ? node2 : node2->view_src;
1805
0
                node_red_2->view_offs = node2->view_offs;
1806
0
                node_red_2->op = GGML_OP_ADD;
1807
0
                node_red_2->src[0] = node2;
1808
0
                node_red_2->src[1] = node_tmp_2;
1809
0
                node_red_2->flags |= GGML_TENSOR_FLAG_COMPUTE;
1810
0
                ggml_backend_view_init(node_red_2);
1811
1812
                // Build 1-node cgraphs for the ADD ops
1813
0
                ggml_cgraph * cgraph_aux_1 = get_cgraph_aux();
1814
0
                cgraph_aux_1->nodes[0] = node_red_1;
1815
0
                cgraph_aux_1->n_nodes = 1;
1816
0
                step_cgraphs[j] = cgraph_aux_1;
1817
1818
0
                ggml_cgraph * cgraph_aux_2 = get_cgraph_aux();
1819
0
                cgraph_aux_2->nodes[0] = node_red_2;
1820
0
                cgraph_aux_2->n_nodes = 1;
1821
0
                step_cgraphs[j_other] = cgraph_aux_2;
1822
0
            }
1823
1824
            // Execute local ADDs for this step
1825
0
            for (size_t j = 0; j < n_backends; j++) {
1826
0
                if (step_cgraphs[j] == nullptr) {
1827
0
                    continue;
1828
0
                }
1829
0
                auto & bcj = backend_ctx->backend_configs[j];
1830
0
                const ggml_status status = ggml_backend_graph_compute_async(bcj.backend, step_cgraphs[j]);
1831
0
                if (status != GGML_STATUS_SUCCESS) {
1832
0
                    return status;
1833
0
                }
1834
0
            }
1835
0
        }
1836
0
        return GGML_STATUS_SUCCESS;
1837
0
    };
1838
1839
1840
0
    for (size_t i = 0; i < n_subgraphs; i++) {
1841
0
        for (size_t j = 0; j < n_backends; j++) {
1842
0
            auto & bcj = backend_ctx->backend_configs[j];
1843
0
            const ggml_status status = ggml_backend_graph_compute_async(bcj.backend, bcj.cgraphs[i].cgraph_main);
1844
0
            if (status != GGML_STATUS_SUCCESS) {
1845
0
                return status;
1846
0
            }
1847
0
        }
1848
1849
0
        if (n_backends > 1 && i < n_subgraphs - 1) {
1850
0
            bool backend_allreduce_success = false;
1851
0
            ggml_backend_allreduce_tensor_t allreduce_tensor = (ggml_backend_allreduce_tensor_t) ggml_backend_reg_get_proc_address(
1852
0
                ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_ctx->backend_configs[0].backend)), "ggml_backend_allreduce_tensor");
1853
0
            if (allreduce_tensor) {
1854
0
                std::vector<ggml_backend_t> backends;
1855
0
                backends.reserve(n_backends);
1856
0
                std::vector<ggml_tensor *> nodes;
1857
0
                nodes.reserve(n_backends);
1858
0
                for (size_t j = 0; j < n_backends; j++) {
1859
0
                    auto & bcj = backend_ctx->backend_configs[j];
1860
0
                    backends.push_back(bcj.backend);
1861
0
                    ggml_cgraph * cgraph_ij = bcj.cgraphs[i].cgraph_main;
1862
0
                    nodes.push_back(cgraph_ij->nodes[cgraph_ij->n_nodes-1]);
1863
0
                }
1864
0
                backend_allreduce_success = allreduce_tensor(backends.data(), nodes.data(), n_backends);
1865
0
            }
1866
1867
0
            if (!backend_allreduce_success) {
1868
0
                const ggml_status status = allreduce_fallback(i);
1869
0
                if (status != GGML_STATUS_SUCCESS) {
1870
0
                    return status;
1871
0
                }
1872
0
            }
1873
0
        }
1874
0
    }
1875
0
    return GGML_STATUS_SUCCESS;
1876
0
}
1877
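The fallback allreduce above pairs devices with a XOR butterfly: in the step with stride offset_j, device j exchanges the last node of its subgraph with device j ^ offset_j via two async P2P copies and then adds the received copy in place, so after ceil(log2(n)) steps every device holds the full sum (for power-of-two device counts). A standalone sketch of the pairing and the resulting values, simulated with one float per device:

// Standalone sketch (not part of the backend): the XOR butterfly used by the fallback
// allreduce, simulated with one float per "device" for 4 devices.
#include <cassert>
#include <cstddef>
#include <vector>

int main() {
    std::vector<float> vals = {1.0f, 2.0f, 3.0f, 4.0f}; // one partial sum per device
    const size_t n = vals.size();

    for (size_t offset_j = 1; offset_j < n; offset_j *= 2) {
        std::vector<float> received(n);
        // "P2P copies": every device receives its partner's current value.
        for (size_t j = 0; j < n; j++) {
            received[j] = vals[j ^ offset_j];
        }
        // "Local ADD": every device adds what it received.
        for (size_t j = 0; j < n; j++) {
            vals[j] += received[j];
        }
    }
    for (size_t j = 0; j < n; j++) {
        assert(vals[j] == 10.0f); // 1+2+3+4 on every device after 2 steps
    }
    return 0;
}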
1878
static const ggml_backend_i ggml_backend_meta_i = {
1879
    /* .get_name                = */ ggml_backend_meta_get_name,
1880
    /* .free                    = */ ggml_backend_meta_free,
1881
    /* .set_tensor_async        = */ ggml_backend_meta_set_tensor_async,
1882
    /* .get_tensor_async        = */ ggml_backend_meta_get_tensor_async,
1883
    /* .get_tensor_2d_async     = */ nullptr,
1884
    /* .set_tensor_2d_async     = */ nullptr,
1885
    /* .cpy_tensor_async        = */ nullptr,
1886
    /* .synchronize             = */ ggml_backend_meta_synchronize,
1887
    /* .graph_plan_create       = */ nullptr,
1888
    /* .graph_plan_free         = */ nullptr,
1889
    /* .graph_plan_update       = */ nullptr,
1890
    /* .graph_plan_compute      = */ nullptr,
1891
    /* .graph_compute           = */ ggml_backend_meta_graph_compute,
1892
    /* .event_record            = */ nullptr,
1893
    /* .event_wait              = */ nullptr,
1894
    /* .graph_optimize          = */ nullptr,
1895
};
1896
1897
0
bool ggml_backend_is_meta(ggml_backend_t backend) {
1898
0
    return backend != nullptr && backend->iface.get_name == ggml_backend_meta_i.get_name;
1899
0
}
1900
1901
0
static ggml_backend_t ggml_backend_meta_device_init_backend(ggml_backend_dev_t dev, const char * params) {
1902
0
    ggml_backend_meta_context * backend_ctx = new ggml_backend_meta_context(dev, params);
1903
1904
0
    ggml_backend_t backend = new struct ggml_backend;
1905
0
    backend->guid    = ggml_backend_meta_guid();
1906
0
    backend->iface   = ggml_backend_meta_i;
1907
0
    backend->device  = dev;
1908
0
    backend->context = backend_ctx;
1909
0
    return backend;
1910
0
}
1911
1912
0
size_t ggml_backend_meta_n_backends(ggml_backend_t meta_backend) {
1913
0
    GGML_ASSERT(ggml_backend_is_meta(meta_backend));
1914
0
    const ggml_backend_meta_context * backend_ctx = (const ggml_backend_meta_context *) meta_backend->context;
1915
0
    return backend_ctx->backend_configs.size();
1916
0
}
1917
1918
0
ggml_backend_t ggml_backend_meta_simple_backend(ggml_backend_t meta_backend, size_t index) {
1919
0
    GGML_ASSERT(ggml_backend_is_meta(meta_backend));
1920
0
    const ggml_backend_meta_context * backend_ctx = (const ggml_backend_meta_context *) meta_backend->context;
1921
0
    return backend_ctx->backend_configs[index].backend;
1922
0
}
1923