Coverage Report

Created: 2026-06-22 06:47

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/llama.cpp/ggml/src/ggml-backend-meta.cpp
Line
Count
Source
1
#include "ggml.h"
2
#include "ggml-impl.h"
3
#include "ggml-backend.h"
4
#include "ggml-backend-impl.h"
5
#include "ggml-alloc.h"
6
#include "ggml-cpp.h"
7
8
#include <algorithm>
9
#include <cassert>
10
#include <cmath>
11
#include <cstddef>
12
#include <cstdint>
13
#include <cstring>
14
#include <map>
15
#include <memory>
16
#include <set>
17
#include <string>
18
#include <tuple>
19
#include <utility>
20
#include <vector>
21
22
struct ggml_backend_meta_device;
23
struct ggml_backend_meta_buffer_type;
24
struct ggml_backend_meta_buffer;
25
struct ggml_backend_meta;
26
27
0
const char * ggml_backend_meta_split_axis_name(enum ggml_backend_meta_split_axis split_axis) {
28
0
    switch (split_axis) {
29
0
        case GGML_BACKEND_SPLIT_AXIS_0:
30
0
            return "0";
31
0
        case GGML_BACKEND_SPLIT_AXIS_1:
32
0
            return "1";
33
0
        case GGML_BACKEND_SPLIT_AXIS_2:
34
0
            return "2";
35
0
        case GGML_BACKEND_SPLIT_AXIS_3:
36
0
            return "3";
37
0
        case GGML_BACKEND_SPLIT_AXIS_MIRRORED:
38
0
            return "MIRRORED";
39
0
        case GGML_BACKEND_SPLIT_AXIS_PARTIAL:
40
0
            return "PARTIAL";
41
0
        case GGML_BACKEND_SPLIT_AXIS_NONE:
42
0
            return "NONE";
43
0
        case GGML_BACKEND_SPLIT_AXIS_UNKNOWN:
44
0
            return "UNKNOWN";
45
0
        default:
46
0
            GGML_ABORT("fatal error");
47
0
    }
48
0
}
49
50
//
51
// meta backend device
52
//
53
54
struct ggml_backend_meta_device_context {
55
    std::vector<ggml_backend_dev_t>     simple_devs;
56
    ggml_backend_meta_get_split_state_t get_split_state;
57
    void *                              get_split_state_ud;
58
59
    std::string name;
60
    std::string description;
61
62
    ggml_backend_meta_device_context(
63
            std::vector<ggml_backend_dev_t> simple_devs, ggml_backend_meta_get_split_state_t get_split_state, void * get_split_state_ud) :
64
0
            simple_devs(std::move(simple_devs)), get_split_state(get_split_state), get_split_state_ud(get_split_state_ud) {
65
0
        name        = std::string("Meta(");
66
0
        description = std::string("Meta(");
67
0
        for (size_t i = 0; i < simple_devs.size(); i++) {
68
0
            if (i > 0) {
69
0
                name        += ",";
70
0
                description += ",";
71
0
            }
72
0
            name        += ggml_backend_dev_name       (simple_devs[i]);
73
0
            description += ggml_backend_dev_description(simple_devs[i]);
74
0
        }
75
0
        name        += ")";
76
0
        description += ")";
77
0
    }
78
79
0
    bool operator<(const ggml_backend_meta_device_context & other) const {
80
0
        return std::tie(simple_devs, get_split_state, get_split_state_ud)
81
0
            < std::tie(other.simple_devs, other.get_split_state, other.get_split_state_ud);
82
0
    }
83
};
84
85
static bool ggml_backend_dev_is_meta(ggml_backend_dev_t dev);
86
87
0
// Device interface callback: returns the name string stored in the meta device context.
// The pointer refers to storage owned by that context.
static const char * ggml_backend_meta_device_get_name(ggml_backend_dev_t dev) {
    GGML_ASSERT(ggml_backend_dev_is_meta(dev));
    const auto * ctx = static_cast<const ggml_backend_meta_device_context *>(dev->context);
    return ctx->name.c_str();
}
92
93
0
// Device interface callback: returns the description string stored in the meta device context.
// The pointer refers to storage owned by that context.
static const char * ggml_backend_meta_device_get_description(ggml_backend_dev_t dev) {
    GGML_ASSERT(ggml_backend_dev_is_meta(dev));
    const auto * ctx = static_cast<const ggml_backend_meta_device_context *>(dev->context);
    return ctx->description.c_str();
}
98
99
0
// Device interface callback: reports the memory of the meta device as the sum
// over all wrapped devices.
//   free  - receives the summed free memory
//   total - receives the summed total memory
static void ggml_backend_meta_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
    GGML_ASSERT(ggml_backend_dev_is_meta(dev));
    const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context;
    *free  = 0;
    *total = 0;
    // The loop variable is deliberately not called `dev` so that it does not
    // shadow the meta device parameter.
    for (ggml_backend_dev_t simple_dev : meta_dev_ctx->simple_devs) {
        // Zero-initialized so that a device which leaves an output untouched
        // contributes nothing instead of an indeterminate value.
        size_t tmp_free  = 0;
        size_t tmp_total = 0;
        ggml_backend_dev_memory(simple_dev, &tmp_free, &tmp_total);
        *free  += tmp_free;
        *total += tmp_total;
    }
}
111
112
0
// Device interface callback: every meta device reports GGML_BACKEND_DEVICE_TYPE_META.
// The device argument is not inspected.
static enum ggml_backend_dev_type ggml_backend_meta_device_get_type(ggml_backend_dev_t dev) {
    GGML_UNUSED(dev);
    return GGML_BACKEND_DEVICE_TYPE_META;
}
117
118
0
// Device interface callback: fills `props` for the meta device.
// Name, description, type and memory come from the sibling callbacks; each
// capability flag is the logical AND of its initial value and the matching
// flag of every wrapped device.
static void ggml_backend_meta_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
    GGML_ASSERT(ggml_backend_dev_is_meta(dev));
    const auto * ctx = static_cast<const ggml_backend_meta_device_context *>(dev->context);

    // TODO replace placeholders
    props->name        = ggml_backend_meta_device_get_name(dev);
    props->description = ggml_backend_meta_device_get_description(dev);
    props->type        = ggml_backend_meta_device_get_type(dev);
    props->device_id   = 0;

    ggml_backend_meta_device_get_memory(dev, &props->memory_free, &props->memory_total);

    // Only `async` starts out true; the other three start out false, so the
    // AND-reduction below can never turn them on.
    props->caps = {
        /* .async                 = */ true,
        /* .host_buffer           = */ false, // Not implemented.
        /* .buffer_from_host_ptr  = */ false, // Not implemented.
        /* .events                = */ false, // Not implemented.
    };
    for (ggml_backend_dev_t simple_dev : ctx->simple_devs) {
        ggml_backend_dev_props simple_props;
        ggml_backend_dev_get_props(simple_dev, &simple_props);

        auto & caps = props->caps;
        caps.async                = caps.async                && simple_props.caps.async;
        caps.host_buffer          = caps.host_buffer          && simple_props.caps.host_buffer;
        caps.buffer_from_host_ptr = caps.buffer_from_host_ptr && simple_props.caps.buffer_from_host_ptr;
        caps.events               = caps.events               && simple_props.caps.events;
    }
}
145
146
static ggml_backend_t ggml_backend_meta_device_init_backend(ggml_backend_dev_t dev, const char * params);
147
148
static ggml_backend_buffer_type_t ggml_backend_meta_device_get_buffer_type(ggml_backend_dev_t dev);
149
150
static ggml_backend_buffer_type_t ggml_backend_meta_device_get_host_buffer_type(ggml_backend_dev_t dev);
151
152
0
// Device interface callback: the meta device supports `op` only if every
// wrapped device does. Evaluation stops at the first device that rejects it;
// with no wrapped devices the result is true.
static bool ggml_backend_meta_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
    GGML_ASSERT(ggml_backend_dev_is_meta(dev));
    const auto * ctx = static_cast<const ggml_backend_meta_device_context *>(dev->context);
    for (ggml_backend_dev_t simple_dev : ctx->simple_devs) {
        if (!ggml_backend_dev_supports_op(simple_dev, op)) {
            return false;
        }
    }
    return true;
}
158
159
0
// Device interface callback: a buffer type is supported only if it belongs to
// a meta device that wraps exactly the same devices, in the same order, as `dev`.
static bool ggml_backend_meta_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
    GGML_ASSERT(ggml_backend_dev_is_meta(dev));

    ggml_backend_dev_t buft_dev = ggml_backend_buft_get_device(buft);
    if (!ggml_backend_dev_is_meta(buft_dev)) {
        return false;
    }

    const auto * own_ctx  = static_cast<const ggml_backend_meta_device_context *>(dev->context);
    const auto * buft_ctx = static_cast<const ggml_backend_meta_device_context *>(buft_dev->context);

    // Vector equality compares the sizes first and then the elements pairwise.
    return own_ctx->simple_devs == buft_ctx->simple_devs;
}
177
178
// Function table of the meta device. Entries are positional; the trailing
// comment on each line names the interface slot it fills. Slots set to
// nullptr have no meta implementation in this table.
static const ggml_backend_device_i ggml_backend_meta_device_iface = {
    ggml_backend_meta_device_get_name,             // .get_name
    ggml_backend_meta_device_get_description,      // .get_description
    ggml_backend_meta_device_get_memory,           // .get_memory
    ggml_backend_meta_device_get_type,             // .get_type
    ggml_backend_meta_device_get_props,            // .get_props
    ggml_backend_meta_device_init_backend,         // .init_backend
    ggml_backend_meta_device_get_buffer_type,      // .get_buffer_type
    ggml_backend_meta_device_get_host_buffer_type, // .get_host_buffer_type
    nullptr,                                       // .buffer_from_host_ptr
    ggml_backend_meta_device_supports_op,          // .supports_op
    ggml_backend_meta_device_supports_buft,        // .supports_buft
    nullptr,                                       // .offload_op
    nullptr,                                       // .event_new
    nullptr,                                       // .event_free
    nullptr,                                       // .event_synchronize
};
195
196
0
// A device counts as a meta device when it is non-null and its get_name
// callback is the one installed in the meta device function table.
static bool ggml_backend_dev_is_meta(ggml_backend_dev_t dev) {
    if (dev == nullptr) {
        return false;
    }
    return dev->iface.get_name == ggml_backend_meta_device_iface.get_name;
}
199
200
0
// Number of devices wrapped by `meta_dev` (asserts that it is a meta device).
static size_t ggml_backend_meta_dev_n_devs(ggml_backend_dev_t meta_dev) {
    GGML_ASSERT(ggml_backend_dev_is_meta(meta_dev));
    const auto * ctx = static_cast<const ggml_backend_meta_device_context *>(meta_dev->context);
    return ctx->simple_devs.size();
}
205
206
0
// Returns the wrapped device at position `index` of `meta_dev`.
// Asserts that `meta_dev` is a meta device and that `index` is in range.
static ggml_backend_dev_t ggml_backend_meta_dev_simple_dev(ggml_backend_dev_t meta_dev, size_t index) {
    GGML_ASSERT(ggml_backend_dev_is_meta(meta_dev));
    const auto * ctx = static_cast<const ggml_backend_meta_device_context *>(meta_dev->context);
    const std::vector<ggml_backend_dev_t> & wrapped = ctx->simple_devs;
    GGML_ASSERT(index < wrapped.size());
    return wrapped[index];
}
212
213
// Returns the meta device for the given device list and split-state callback.
// Devices are cached in function-local statics keyed by
// (device list, callback, user data): a repeated call with the same arguments
// returns the same pointer, otherwise a new entry is created.
ggml_backend_dev_t ggml_backend_meta_device(
        ggml_backend_dev_t * devs, size_t n_devs, ggml_backend_meta_get_split_state_t get_split_state, void * get_split_state_ud) {
    GGML_ASSERT(n_devs <= GGML_BACKEND_META_MAX_DEVICES);
    // TODO: this is not thread-safe - needs to be fixed
    static std::vector<std::unique_ptr<ggml_backend_meta_device_context>>         ctxs;
    static std::map<ggml_backend_meta_device_context, struct ggml_backend_device> meta_devs;

    // Build the lookup key from a copy of the caller's device array.
    const std::vector<ggml_backend_dev_t> simple_devs(devs, devs + n_devs);
    const ggml_backend_meta_device_context key(simple_devs, get_split_state, get_split_state_ud);

    const auto cached = meta_devs.find(key);
    if (cached != meta_devs.end()) {
        return &cached->second;
    }

    // The device struct stores a raw context pointer, so the context it points
    // to is kept alive separately in `ctxs`.
    ctxs.push_back(std::make_unique<ggml_backend_meta_device_context>(key));
    ggml_backend_meta_device_context * owned_ctx = ctxs.back().get();

    struct ggml_backend_device meta_dev = {
        /*iface  =*/ ggml_backend_meta_device_iface,
        /*reg    =*/ nullptr,
        /*ctx    =*/ owned_ctx,
    };

    return &meta_devs.emplace(*owned_ctx, meta_dev).first->second;
}
244
245
//
246
// meta backend buffer type
247
//
248
249
struct ggml_backend_meta_buffer_type_context {
250
    std::vector<ggml_backend_buffer_type_t> simple_bufts;
251
252
    std::string name;
253
254
0
    ggml_backend_meta_buffer_type_context(std::vector<ggml_backend_buffer_type_t> simple_bufts) : simple_bufts(std::move(simple_bufts)) {
255
0
        name = "Meta(";
256
0
        for (size_t i = 0; i < simple_bufts.size(); i++) {
257
0
            if (i > 0) {
258
0
                name += ",";
259
0
            }
260
0
            name += ggml_backend_buft_name(simple_bufts[i]);
261
0
        }
262
0
        name += ")";
263
0
    }
264
265
0
    bool operator<(const ggml_backend_meta_buffer_type_context & other) const {
266
0
        return simple_bufts < other.simple_bufts;
267
0
    }
268
};
269
270
0
// Number of buffer types wrapped by `meta_buft` (asserts that it is a meta buffer type).
static size_t ggml_backend_meta_buft_n_bufts(ggml_backend_buffer_type_t meta_buft) {
    GGML_ASSERT(ggml_backend_buft_is_meta(meta_buft));
    const auto * ctx = static_cast<const ggml_backend_meta_buffer_type_context *>(meta_buft->context);
    return ctx->simple_bufts.size();
}
275
276
0
// Buffer type interface callback: returns the name string stored in the meta
// buffer type context. The pointer refers to storage owned by that context.
static const char * ggml_backend_meta_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
    GGML_ASSERT(ggml_backend_buft_is_meta(buft));
    const auto * ctx = static_cast<const ggml_backend_meta_buffer_type_context *>(buft->context);
    return ctx->name.c_str();
}
281
282
0
// Returns the wrapped buffer type at position `index` of `meta_buft`.
// Asserts that `meta_buft` is a meta buffer type and that `index` is in range.
static ggml_backend_buffer_type_t ggml_backend_meta_buft_simple_buft(ggml_backend_buffer_type_t meta_buft, size_t index) {
    GGML_ASSERT(ggml_backend_buft_is_meta(meta_buft));
    const auto * ctx = static_cast<const ggml_backend_meta_buffer_type_context *>(meta_buft->context);
    const std::vector<ggml_backend_buffer_type_t> & wrapped = ctx->simple_bufts;
    GGML_ASSERT(index < wrapped.size());
    return wrapped[index];
}
288
289
static ggml_backend_buffer_t ggml_backend_meta_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size);
290
291
0
// Buffer type interface callback: returns the largest alignment among the
// wrapped buffer types (at least 1) and asserts that this value is a multiple
// of every individual alignment.
static size_t ggml_backend_meta_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
    const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft);

    std::vector<size_t> alignments;
    alignments.reserve(n_simple_bufts);

    size_t max_alignment = 1;
    for (size_t i = 0; i < n_simple_bufts; i++) {
        alignments.push_back(ggml_backend_buft_get_alignment(ggml_backend_meta_buft_simple_buft(buft, i)));
        max_alignment = std::max(max_alignment, alignments.back());
    }

    // Validate against the FINAL maximum. Checking against the running maximum
    // is order-dependent: e.g. 64 followed by 96 would pass although 96 is not
    // a multiple of 64. The zero check avoids a modulo by zero.
    for (const size_t alignment : alignments) {
        GGML_ASSERT(alignment != 0 && max_alignment % alignment == 0);
    }
    return max_alignment;
}
301
302
0
// Buffer type interface callback: the smallest max size reported by any
// wrapped buffer type, or SIZE_MAX when there are none.
static size_t ggml_backend_meta_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
    size_t smallest = SIZE_MAX;
    const size_t count = ggml_backend_meta_buft_n_bufts(buft);
    for (size_t idx = 0; idx != count; ++idx) {
        const size_t current = ggml_backend_buft_get_max_size(ggml_backend_meta_buft_simple_buft(buft, idx));
        if (current < smallest) {
            smallest = current;
        }
    }
    return smallest;
}
310
311
0
// Buffer type interface callback: the largest allocation size any wrapped
// buffer type reports for `tensor`, or 0 when there are none.
static size_t ggml_backend_meta_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
    size_t largest = 0;
    const size_t count = ggml_backend_meta_buft_n_bufts(buft);
    for (size_t idx = 0; idx != count; ++idx) {
        const size_t current = ggml_backend_buft_get_alloc_size(ggml_backend_meta_buft_simple_buft(buft, idx), tensor);
        if (current > largest) {
            largest = current;
        }
    }
    return largest;
}
320
321
0
// Buffer type interface callback: true only if every wrapped buffer type is a
// host buffer type. Stops at the first one that is not; true when there are none.
static bool ggml_backend_meta_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
    const size_t count = ggml_backend_meta_buft_n_bufts(buft);
    size_t idx = 0;
    while (idx < count) {
        ggml_backend_buffer_type_t simple_buft = ggml_backend_meta_buft_simple_buft(buft, idx);
        if (!ggml_backend_buft_is_host(simple_buft)) {
            return false;
        }
        ++idx;
    }
    return true;
}
330
331
static const struct ggml_backend_buffer_type_i ggml_backend_meta_buffer_type_iface = {
332
    /* .get_name         = */ ggml_backend_meta_buffer_type_get_name,
333
    /* .alloc_buffer     = */ ggml_backend_meta_buffer_type_alloc_buffer,
334
    /* .get_alignment    = */ ggml_backend_meta_buffer_type_get_alignment,
335
    /* .get_max_size     = */ ggml_backend_meta_buffer_type_get_max_size,
336
    /* .get_alloc_size   = */ ggml_backend_meta_buffer_type_get_alloc_size,
337
    /* .is_host          = */ ggml_backend_meta_buffer_type_is_host,
338
};
339
340
0
// A buffer type counts as a meta buffer type when it is non-null and its
// get_name callback is the one installed in the meta buffer type function table.
bool ggml_backend_buft_is_meta(ggml_backend_buffer_type_t buft) {
    if (buft == nullptr) {
        return false;
    }
    return buft->iface.get_name == ggml_backend_meta_buffer_type_iface.get_name;
}
343
344
0
// Device interface callback: returns the meta buffer type belonging to `dev`.
// Buffer types are created on first request and cached per device in a
// function-local static map. The context allocated with `new` is never
// deleted in this function.
static ggml_backend_buffer_type_t ggml_backend_meta_device_get_buffer_type(ggml_backend_dev_t dev) {
    static std::map<ggml_backend_dev_t, struct ggml_backend_buffer_type> meta_bufts;
    GGML_ASSERT(ggml_backend_dev_is_meta(dev));

    const auto cached = meta_bufts.find(dev);
    if (cached != meta_bufts.end()) {
        return &cached->second;
    }

    // Collect the default buffer type of every wrapped device, in device order.
    const size_t n_devs = ggml_backend_meta_dev_n_devs(dev);
    std::vector<ggml_backend_buffer_type_t> simple_bufts;
    simple_bufts.reserve(n_devs);
    for (size_t idx = 0; idx < n_devs; ++idx) {
        ggml_backend_dev_t simple_dev = ggml_backend_meta_dev_simple_dev(dev, idx);
        simple_bufts.push_back(ggml_backend_dev_buffer_type(simple_dev));
    }

    struct ggml_backend_buffer_type meta_buft = {
        /*iface  =*/ ggml_backend_meta_buffer_type_iface,
        /*device =*/ dev,
        /*ctx    =*/ new ggml_backend_meta_buffer_type_context(simple_bufts),
    };
    return &meta_bufts.emplace(dev, meta_buft).first->second;
}
370
371
0
// Device interface callback: returns the host buffer type shared by all
// wrapped devices. Returns nullptr if any wrapped device has no host buffer
// type, if two wrapped devices disagree, or if there are no wrapped devices.
static ggml_backend_buffer_type_t ggml_backend_meta_device_get_host_buffer_type(ggml_backend_dev_t dev) {
    GGML_ASSERT(ggml_backend_dev_is_meta(dev));
    const auto * ctx = static_cast<const ggml_backend_meta_device_context *>(dev->context);

    ggml_backend_buffer_type_t shared = nullptr;
    for (ggml_backend_dev_t simple_dev : ctx->simple_devs) {
        const ggml_backend_buffer_type_t candidate = ggml_backend_dev_host_buffer_type(simple_dev);
        if (candidate == nullptr) {
            return nullptr;
        }
        // if different simple devices have different host buffer types,
        // we cannot provide a single host buffer type for the meta device
        if (shared != nullptr && shared != candidate) {
            return nullptr;
        }
        shared = candidate;
    }
    return shared;
}
391
392
//
393
// meta backend buffer
394
//
395
396
// Container to hold the tensor slices per simple ggml backend buffer.
397
struct ggml_backend_meta_simple_tensor_container {
398
    std::vector<ggml_context_ptr> ctxs;
399
    std::map<const ggml_tensor *, std::vector<ggml_tensor *>> simple_tensors;
400
401
0
    ggml_backend_meta_simple_tensor_container(const ggml_init_params & params, const int n_simple) {
402
0
        ctxs.reserve(n_simple);
403
0
        for (int i = 0; i < n_simple; i++) {
404
0
            ctxs.emplace_back(ggml_init(params));
405
0
        }
406
0
    }
407
0
    ggml_backend_meta_simple_tensor_container() {}
408
};
409
410
struct ggml_backend_meta_buffer_context {
411
    // FIXME
412
    // Most tensors can simply be stored statically in their own buffer.
413
    // Externally created views however also need a mapping to simple tensors but they use the buffer of the view source.
414
    // If external views are simply using that buffer they will slowly deplete its memory.
415
    // Current solution: rotating set of 2 "compute" containers to hold external views, works correctly for llama.cpp.
416
    // Long-term: tie the lifetime of external views to the meta backend executing the graph instead,
417
    //     currently not possible due to graph-external operations in the backend scheduler.
418
    ggml_backend_meta_simple_tensor_container stc_static;
419
    ggml_backend_meta_simple_tensor_container stc_compute[2];
420
    int stc_compute_index      = 0;
421
    int stc_compute_index_next = 0;
422
    std::vector<ggml_backend_buffer_ptr> bufs;
423
424
    // FIXME
425
    // The size of the split state cache is unbounded and can theoretically grow infinitely large.
426
    // However, it is also expensive to build and clearing it on every rebuild in ggml_backend_meta_graph_compute is too expensive.
427
    static constexpr size_t nbtc = GGML_TENSOR_SIZE - sizeof(ggml_tensor::padding);
428
    std::map<std::pair<const ggml_tensor *, bool>, std::pair<ggml_backend_meta_split_state, char[nbtc]>> split_state_cache;
429
430
    int debug;
431
432
    ggml_backend_meta_buffer_context(
433
            ggml_backend_meta_simple_tensor_container & stc_static,
434
            ggml_backend_meta_simple_tensor_container & stc_compute_0,
435
            ggml_backend_meta_simple_tensor_container & stc_compute_1,
436
            const std::vector<ggml_backend_buffer_t> & bufs)
437
0
            : stc_static(std::move(stc_static)), stc_compute{std::move(stc_compute_0), std::move(stc_compute_1)} {
438
0
        this->bufs.reserve(bufs.size());
439
0
        for (ggml_backend_buffer_t buf : bufs) {
440
0
            this->bufs.emplace_back(buf);
441
0
        }
442
0
        const char * GGML_META_DEBUG = getenv("GGML_META_DEBUG");
443
0
        debug = GGML_META_DEBUG ? atoi(GGML_META_DEBUG) : 0;
444
0
    }
445
446
0
    ggml_backend_meta_simple_tensor_container & get_simple_tensor_container(const ggml_tensor * tensor) {
447
0
        if (stc_static.simple_tensors.find(tensor) != stc_static.simple_tensors.end()) {
448
0
            return stc_static;
449
0
        }
450
0
        return stc_compute[stc_compute_index];
451
0
    }
452
};
453
454
0
// Destroys the meta buffer context attached to `buffer`
// (asserts that `buffer` is a meta buffer).
static void ggml_backend_meta_buffer_free_buffer(ggml_backend_buffer_t buffer) {
    GGML_ASSERT(ggml_backend_buffer_is_meta(buffer));
    delete static_cast<ggml_backend_meta_buffer_context *>(buffer->context);
}
459
460
0
// Number of simple buffers held by `meta_buf` (asserts that it is a meta buffer).
static size_t ggml_backend_meta_buffer_n_bufs(ggml_backend_buffer_t meta_buf) {
    GGML_ASSERT(ggml_backend_buffer_is_meta(meta_buf));
    auto * ctx = static_cast<ggml_backend_meta_buffer_context *>(meta_buf->context);
    return ctx->bufs.size();
}
465
466
0
// Returns the raw handle of the simple buffer at position `index` of `meta_buf`.
// Asserts that `meta_buf` is a meta buffer and that `index` is in range.
static ggml_backend_buffer_t ggml_backend_meta_buffer_simple_buffer(ggml_backend_buffer_t meta_buf, size_t index) {
    GGML_ASSERT(ggml_backend_buffer_is_meta(meta_buf));
    auto * ctx = static_cast<ggml_backend_meta_buffer_context *>(meta_buf->context);
    const std::vector<ggml_backend_buffer_ptr> & held = ctx->bufs;
    GGML_ASSERT(index < held.size());
    return held[index].get();
}
472
473
0
static struct ggml_tensor * ggml_backend_meta_buffer_simple_tensor(const struct ggml_tensor * tensor, size_t index) {
474
0
    GGML_ASSERT(ggml_backend_buffer_is_meta(tensor->buffer));
475
0
    ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context;
476
0
    GGML_ASSERT(index < buf_ctx->bufs.size());
477
478
0
    ggml_backend_meta_simple_tensor_container & stc = buf_ctx->get_simple_tensor_container(tensor);
479
0
    auto it = stc.simple_tensors.find(tensor);
480
0
    if (it == stc.simple_tensors.end()) {
481
0
        return nullptr;
482
0
    }
483
0
    return it->second[index];
484
0
}
485
486
static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struct ggml_tensor * tensor, bool assume_sync);
487
488
static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(
489
0
        ggml_backend_meta_simple_tensor_container & stc, const struct ggml_tensor * tensor, bool assume_sync) {
490
    // FIXME Currently this function preserves/erases the information in n_segments and nr in an inconsistent way.
491
    // Since the operations in question are developed specifically for llama.cpp this currently does not manifest as a bug there.
492
    // However, in a broader ggml context with arbitrary ggml graphs this can lead to unexpected results.
493
0
    const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(tensor->buffer);
494
0
    ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context;
495
496
0
    auto split_states_equal = [&](const ggml_backend_meta_split_state & a, const ggml_backend_meta_split_state & b) -> bool {
497
0
        if (a.axis != b.axis) {
498
0
            return false;
499
0
        }
500
0
        for (size_t j = 0; j < n_bufs; j++) {
501
0
            int64_t sum_a = 0;
502
0
            for (size_t s = 0; s < a.n_segments; s++) {
503
0
                sum_a += a.ne[s*n_bufs + j] * a.nr[s];
504
0
            }
505
0
            int64_t sum_b = 0;
506
0
            for (size_t s = 0; s < b.n_segments; s++) {
507
0
                sum_b += b.ne[s*n_bufs + j] * b.nr[s];
508
0
            }
509
0
            if (sum_a != sum_b) {
510
0
                return false;
511
0
            }
512
0
        }
513
0
        return true;
514
0
    };
515
516
0
    auto handle_generic = [&](const std::vector<ggml_backend_meta_split_state> & src_ss, bool scalar_only) -> ggml_backend_meta_split_state {
517
0
        ggml_backend_meta_split_state ret = {GGML_BACKEND_SPLIT_AXIS_NONE, {0}, {1}, 1};
518
0
        for (size_t i = 0; i < GGML_MAX_SRC; i++) {
519
0
            if (tensor->src[i] == nullptr || tensor->src[i] == tensor) {
520
0
                continue;
521
0
            }
522
0
            if (ret.axis == GGML_BACKEND_SPLIT_AXIS_NONE) {
523
0
                ret = src_ss[i];
524
0
            } else if (!split_states_equal(src_ss[i], ret)) {
525
0
                ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1};
526
0
                break;
527
0
            }
528
0
        }
529
0
        if (ret.axis == GGML_BACKEND_SPLIT_AXIS_NONE) {
530
0
            ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1};
531
0
        }
532
0
        if (scalar_only && ret.axis >= 0 && ret.axis < GGML_MAX_DIMS) {
533
0
            ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1};
534
0
        }
535
0
        GGML_ASSERT(ret.axis != GGML_BACKEND_SPLIT_AXIS_UNKNOWN);
536
0
        return ret;
537
0
    };
538
539
    // Some ops process data on a per-row basis:
540
0
    auto handle_per_row = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
541
0
        GGML_ASSERT(src_ss[0].axis != GGML_BACKEND_SPLIT_AXIS_0);
542
0
        return src_ss[0];
543
0
    };
544
545
    // Some ops broadcast the src1 data across src0:
546
0
    auto handle_bin_bcast = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
547
0
        if (src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS &&
548
0
                tensor->src[1]->ne[src_ss[0].axis] == 1 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
549
0
            return src_ss[0];
550
0
        }
551
0
        if (src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && (src_ss[0].axis == src_ss[1].axis ||
552
0
           (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && (src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL)))) {
553
0
            return src_ss[0]; // GGML_OP_ADD_ID
554
0
        }
555
0
        GGML_ASSERT(tensor->src[2] == nullptr || src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED);
556
0
        return handle_generic(src_ss, /*scalar_only =*/ false);
557
0
    };
558
559
0
    auto handle_concat = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
560
0
        const ggml_backend_meta_split_axis concat_axis = ggml_backend_meta_split_axis(ggml_get_op_params_i32(tensor, 0));
561
0
        if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[1].axis >= 0 && src_ss[1].axis < GGML_MAX_DIMS) {
562
0
            GGML_ASSERT(concat_axis != src_ss[1].axis);
563
0
            return src_ss[1];
564
0
        }
565
0
        if (src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS) {
566
0
            GGML_ASSERT(concat_axis != src_ss[0].axis);
567
0
            return src_ss[0];
568
0
        }
569
0
        if (src_ss[0].axis == src_ss[1].axis && src_ss[0].axis != concat_axis) {
570
0
            return src_ss[0];
571
0
        }
572
0
        return handle_generic(src_ss, /*scalar_only =*/ true);
573
0
    };
574
575
0
    auto handle_mul_mat = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
576
0
        if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
577
0
            return {GGML_BACKEND_SPLIT_AXIS_MIRRORED, {0}, {1}, 1};
578
0
        }
579
0
        if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_1 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
580
0
            ggml_backend_meta_split_state ret = src_ss[0];
581
0
            ret.axis = GGML_BACKEND_SPLIT_AXIS_0;
582
0
            ret.nr[0] = 1;
583
0
            ret.n_segments = 1;
584
0
            return ret;
585
0
        }
586
0
        if (src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_1 && src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
587
0
            return src_ss[1];
588
0
        }
589
0
        if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_0 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_0) {
590
0
            GGML_ASSERT(split_states_equal(src_ss[0], src_ss[1]));
591
0
            return {assume_sync ? GGML_BACKEND_SPLIT_AXIS_MIRRORED : GGML_BACKEND_SPLIT_AXIS_PARTIAL, {0}, {1}, 1};
592
0
        }
593
0
        GGML_ABORT("fatal error");
594
        //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1};
595
0
    };
596
597
0
    auto handle_reshape = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
598
0
        switch (src_ss[0].axis) {
599
0
            case GGML_BACKEND_SPLIT_AXIS_0:
600
0
            case GGML_BACKEND_SPLIT_AXIS_1:
601
0
            case GGML_BACKEND_SPLIT_AXIS_2:
602
0
            case GGML_BACKEND_SPLIT_AXIS_3: {
603
0
                GGML_ASSERT(src_ss[0].n_segments == 1);
604
0
                if (src_ss[0].axis == ggml_n_dims(tensor->src[0]) - 1 && src_ss[0].nr[0] == 1) {
605
0
                    return {ggml_backend_meta_split_axis(ggml_n_dims(tensor) - 1), {0}, {1}, 1};
606
0
                }
607
0
                int64_t base_ne_in = tensor->src[0]->ne[0];
608
0
                for (int dim = 1; dim <= src_ss[0].axis; dim++) {
609
0
                    base_ne_in *= tensor->src[0]->ne[dim];
610
0
                }
611
0
                base_ne_in /= src_ss[0].nr[0];
612
0
                int64_t base_ne_out = 1;
613
0
                for (int dim = 0; dim < GGML_MAX_DIMS; dim++) {
614
0
                    const int64_t base_ne_out_next = base_ne_out *= tensor->ne[dim];
615
0
                    if (base_ne_out_next % base_ne_in == 0) {
616
0
                        return {ggml_backend_meta_split_axis(dim), {0}, {uint32_t(base_ne_out_next/base_ne_in)}, 1};
617
0
                    }
618
0
                    if (base_ne_out_next > base_ne_in) {
619
0
                        GGML_ASSERT(src_ss[0].n_segments == 1);
620
0
                        GGML_ASSERT(src_ss[0].nr[0]      == 1);
621
0
                        return {ggml_backend_meta_split_axis(dim), {0}, {1}, 1};
622
0
                    }
623
0
                    base_ne_out = base_ne_out_next;
624
0
                }
625
0
                GGML_ABORT("shape mismatch for %s", ggml_op_name(tensor->op));
626
0
            }
627
0
            case GGML_BACKEND_SPLIT_AXIS_MIRRORED:
628
0
            case GGML_BACKEND_SPLIT_AXIS_PARTIAL: {
629
0
                return src_ss[0];
630
0
            }
631
0
            default: {
632
0
                GGML_ABORT("fatal error");
633
                //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1};
634
0
            }
635
0
        }
636
0
    };
637
638
0
    // Split-state rule used for GGML_OP_CPY: when the first source is split along a real tensor
    // dimension the result is derived like a reshape, otherwise the generic rule is applied.
    auto handle_cpy = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
        const bool src0_split_by_dim = src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS;
        if (!src0_split_by_dim) {
            return handle_generic(src_ss, /*scalar_only =*/ false);
        }
        return handle_reshape(src_ss);
    };
644
645
0
    // Split-state rule used for GGML_OP_VIEW. Tries, in order:
    //   1. contiguous view of a contiguous source -> same as reshape
    //   2. strides equal on every dimension that is not size 1 in both tensors -> source state unchanged
    //   3. neither tensor permuted and source split on a dimension below the last one
    //      -> find the view dimension whose next-higher stride matches the source's
    //   4. mirrored / partial source -> source state unchanged
    // Anything else aborts.
    auto handle_view = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
        const ggml_tensor * src0 = tensor->src[0];

        if (ggml_is_contiguous(tensor) && ggml_is_contiguous(src0)) {
            return handle_reshape(src_ss);
        }

        // Dimensions that have size 1 in both tensors are ignored when comparing strides.
        bool strides_differ = false;
        for (int dim = 0; dim < GGML_MAX_DIMS && !strides_differ; dim++) {
            const bool trivial_dim = tensor->ne[dim] == 1 && src0->ne[dim] == 1;
            strides_differ = !trivial_dim && tensor->nb[dim] != src0->nb[dim];
        }
        if (!strides_differ) {
            return src_ss[0];
        }

        const int  axis         = src_ss[0].axis;
        const bool any_permuted = ggml_is_permuted(tensor) || ggml_is_permuted(src0);
        if (!any_permuted && axis >= 0 && axis < GGML_MAX_DIMS-1) {
            const size_t nb_above_split = src0->nb[axis+1];
            for (int dim = 0; dim < GGML_MAX_DIMS-1; dim++) {
                if (tensor->nb[dim+1] == nb_above_split) {
                    return {ggml_backend_meta_split_axis(dim), {0}, {1}, 1};
                }
            }
            GGML_ABORT("fatal error");
        }

        const bool mirrored_or_partial =
            src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED || src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL;
        if (mirrored_or_partial) {
            return src_ss[0];
        }
        GGML_ABORT("view of permuted tensor not implemented");
    };
679
680
0
    // Split-state rule used for GGML_OP_PERMUTE: a dimension split is moved to the axis stored in
    // op_params at the source axis index; mirrored / partial states pass through unchanged.
    auto handle_permute = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
        const auto src_axis = src_ss[0].axis;

        if (src_axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED || src_axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL) {
            return src_ss[0];
        }

        const bool split_by_dim =
            src_axis == GGML_BACKEND_SPLIT_AXIS_0 || src_axis == GGML_BACKEND_SPLIT_AXIS_1 ||
            src_axis == GGML_BACKEND_SPLIT_AXIS_2 || src_axis == GGML_BACKEND_SPLIT_AXIS_3;
        if (!split_by_dim) {
            GGML_ABORT("fatal error");
        }

        GGML_ASSERT(src_ss[0].n_segments == 1 || src_ss[0].nr[0] == 1);
        const auto dst_axis = ggml_backend_meta_split_axis(tensor->op_params[src_axis]);
        return {dst_axis, {0}, {src_ss[0].nr[0]}, 1};
    };
699
700
0
    // Split-state rule used for GGML_OP_TRANSPOSE: a split on axis 0 or 1 is swapped to the other of
    // the two; splits on axes 2/3 as well as mirrored / partial states are returned unchanged.
    auto handle_transpose = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
        const auto src_axis = src_ss[0].axis;

        if (src_axis == GGML_BACKEND_SPLIT_AXIS_0 || src_axis == GGML_BACKEND_SPLIT_AXIS_1) {
            GGML_ASSERT(src_ss[0].n_segments == 1 || src_ss[0].nr[0] == 1);
            const auto swapped_axis = ggml_backend_meta_split_axis(int(src_ss[0].axis) ^ 1); // 0 <-> 1
            return {swapped_axis, {0}, {src_ss[0].nr[0]}, 1};
        }

        const bool unaffected =
            src_axis == GGML_BACKEND_SPLIT_AXIS_2        || src_axis == GGML_BACKEND_SPLIT_AXIS_3 ||
            src_axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED || src_axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL;
        if (unaffected) {
            return src_ss[0];
        }

        GGML_ABORT("fatal error");
    };
719
720
0
    // Split-state rule used for GGML_OP_GET_ROWS: a first source split on axis 0 combined with a
    // mirrored second source keeps the first source's state; everything else uses the generic
    // rule in scalar-only mode.
    auto handle_get_rows = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
        const bool src0_split_axis_0 = src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_0;
        const bool src1_mirrored     = src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED;
        if (!(src0_split_axis_0 && src1_mirrored)) {
            return handle_generic(src_ss, /*scalar_only =*/ true);
        }
        return src_ss[0];
    };
726
727
0
    // Split-state rule used for GGML_OP_SET_ROWS. Requires that src 0 is not split on axis 1, that
    // src 1 is mirrored and that src 0 and src 2 share the same split state; the result then
    // inherits the state of src 0.
    auto handle_set_rows = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
        const ggml_backend_meta_split_state & inherited = src_ss[0];

        GGML_ASSERT(src_ss[0].axis != GGML_BACKEND_SPLIT_AXIS_1);
        GGML_ASSERT(src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED);
        GGML_ASSERT(split_states_equal(src_ss[0], src_ss[2]));

        return inherited;
    };
733
734
0
    // Split-state rule used for GGML_OP_ROPE: src 1 must be mirrored, the result inherits the
    // state of src 0.
    auto handle_rope = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
        const ggml_backend_meta_split_state & inherited = src_ss[0];
        GGML_ASSERT(src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED);
        return inherited;
    };
738
739
0
    // Split-state rule used for GGML_OP_PAD: the result inherits the state of src 0. When src 0 is
    // split along a tensor dimension, both op_params entries belonging to that dimension
    // (indices 2*axis and 2*axis + 1) must be zero.
    auto handle_pad = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
        const bool split_by_dim = src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS;
        if (!split_by_dim) {
            return src_ss[0];
        }
        GGML_ASSERT(tensor->op_params[2*src_ss[0].axis + 0] == 0);
        GGML_ASSERT(tensor->op_params[2*src_ss[0].axis + 1] == 0);
        return src_ss[0];
    };
746
747
0
    // Split-state rule used for GGML_OP_FLASH_ATTN_EXT.
    //
    // Requirements on the sources:
    //   - src 0, 1, 2 must all be split on axis 2
    //   - src 3, if present, must be mirrored
    //   - src 4, if present, must be split on axis 0
    // The result is always split on axis 1.
    //
    // src 3 and src 4 are optional, so each assertion is guarded by the presence of the source it
    // actually inspects (an absent source has an UNKNOWN split state and must not be checked).
    // NOTE(review): presumably src 3 is the mask and src 4 the attention sinks - confirm against ggml.h.
    auto handle_flash_attn_ext = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
        GGML_ASSERT(                             src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_2);
        GGML_ASSERT(                             src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_2);
        GGML_ASSERT(                             src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_2);
        GGML_ASSERT(tensor->src[3] == nullptr || src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED);
        GGML_ASSERT(tensor->src[4] == nullptr || src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_0);
        return {GGML_BACKEND_SPLIT_AXIS_1, {0}, {1}, 1};
    };
755
756
0
    // Split-state rule used for GGML_OP_SSM_CONV: when both sources are split on the same axis and
    // that axis is 0 or 1, the result is split on the other of the two axes. All remaining
    // combinations use the generic rule.
    auto handle_ssm_conv = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
        const bool same_axis = src_ss[0].axis == src_ss[1].axis;

        if (same_axis && src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_0) {
            return {GGML_BACKEND_SPLIT_AXIS_1, {0}, {1}, 1};
        }
        if (same_axis && src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_1) {
            return {GGML_BACKEND_SPLIT_AXIS_0, {0}, {1}, 1};
        }
        return handle_generic(src_ss, /*scalar_only =*/ false);
    };
767
768
0
    // Split-state rule used for GGML_OP_GATED_DELTA_NET: if the first six sources are all mirrored
    // the result is mirrored as well (state of src 0). Otherwise src 0..4 must be split on axis 1,
    // src 5 on axis 0, 1 or 2, and the result is split on axis 0.
    auto handle_gated_delta_net = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
        bool all_mirrored = true;
        for (size_t i = 0; i < 6 && all_mirrored; i++) {
            all_mirrored = src_ss[i].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED;
        }
        if (all_mirrored) {
            return src_ss[0];
        }

        GGML_ASSERT(src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_1);
        GGML_ASSERT(src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_1);
        GGML_ASSERT(src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_1);
        GGML_ASSERT(src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_1);
        GGML_ASSERT(src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_1);
        // state shape is [S_v, S_v, H_v, n_seqs] (s0 only); the heads dim is its own axis 2,
        // so a head-aligned split on the input cache lands on axis 2 here.
        GGML_ASSERT(src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_2 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_1 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_0);

        return {GGML_BACKEND_SPLIT_AXIS_0, {0}, {1}, 1};
    };
784
785
0
    // Computes the split state of `tensor` without consulting the cache.
    //
    //   1. Tensors without elements get the UNKNOWN state.
    //   2. Tensors in a non-compute buffer that are not views take the state reported by the
    //      device's get_split_state callback; for a dimension split the per-buffer sizes are
    //      validated against the tensor shape.
    //   3. Everything else is derived from the split states of the sources via the per-op handler,
    //      after which the per-buffer sizes are taken over (scaled) from the first source that is
    //      split along a dimension and cross-checked against all further such sources.
    auto calculate_split_state = [&]() -> ggml_backend_meta_split_state {
        if (ggml_nelements(tensor) == 0) {
            return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1};
        }
        if (ggml_backend_buffer_get_usage(tensor->buffer) != GGML_BACKEND_BUFFER_USAGE_COMPUTE && tensor->view_src == nullptr) {
            ggml_backend_dev_t dev = ggml_backend_buft_get_device(ggml_backend_buffer_get_type(tensor->buffer));
            const ggml_backend_meta_device_context * dev_ctx = (const ggml_backend_meta_device_context *) dev->context;
            ggml_backend_meta_split_state ret = dev_ctx->get_split_state(tensor, dev_ctx->get_split_state_ud);
            // Only real tensor dimensions can be validated against tensor->ne; the upper bound is
            // exclusive so that tensor->ne[ret.axis] below never indexes past the last dimension.
            if (ret.axis >= 0 && ret.axis < GGML_MAX_DIMS) {
                // A split along axis 0 must not cut through a block of the tensor type.
                const int64_t granularity = ret.axis == GGML_BACKEND_SPLIT_AXIS_0 ? ggml_blck_size(tensor->type) : 1;
                int64_t ne_sum = 0;
                for (size_t s = 0; s < ret.n_segments; s++) {
                    for (size_t j = 0; j < n_bufs; j++) {
                        GGML_ASSERT(ret.ne[s*n_bufs + j] % granularity == 0);
                        ne_sum += ret.ne[s*n_bufs + j] * ret.nr[s];
                    }
                }
                // All slices together have to cover the split dimension exactly.
                GGML_ASSERT(ne_sum == tensor->ne[ret.axis]);
            }
            return ret;
        }

        // Collect the split states of all sources; missing sources and self-references are UNKNOWN.
        std::vector<ggml_backend_meta_split_state> src_ss(GGML_MAX_SRC, {GGML_BACKEND_SPLIT_AXIS_NONE, {0}, {1}, 1});
        for (size_t i = 0; i < GGML_MAX_SRC; i++) {
            if (tensor->src[i] == nullptr || tensor->src[i] == tensor) {
                src_ss[i] = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1};
                continue;
            }
            src_ss[i] = ggml_backend_meta_get_split_state(stc, tensor->src[i], /*assume_sync =*/ true);
            GGML_ASSERT(src_ss[i].axis != GGML_BACKEND_SPLIT_AXIS_UNKNOWN);
        }

        // Dispatch to the handler for the tensor's op.
        ggml_backend_meta_split_state split_state;
        switch (tensor->op) {
            case GGML_OP_NONE: {
                split_state = {GGML_BACKEND_SPLIT_AXIS_MIRRORED, {0}, {1}, 1};
            } break;
            case GGML_OP_DUP: {
                split_state = handle_generic(src_ss, /*scalar_only =*/ true);
            } break;
            case GGML_OP_ADD:
            case GGML_OP_ADD_ID: {
                split_state = handle_bin_bcast(src_ss);
            } break;
            case GGML_OP_ADD1:
            case GGML_OP_ACC: {
                split_state = handle_generic(src_ss, /*scalar_only =*/ true);
            } break;
            case GGML_OP_SUB:
            case GGML_OP_MUL:
            case GGML_OP_DIV: {
                split_state = handle_bin_bcast(src_ss);
            } break;
            case GGML_OP_SQR:
            case GGML_OP_SQRT:
            case GGML_OP_LOG:
            case GGML_OP_SIN:
            case GGML_OP_COS: {
                split_state = handle_generic(src_ss, /*scalar_only =*/ false);
            } break;
            case GGML_OP_SUM: {
                split_state = handle_generic(src_ss, /*scalar_only =*/ true);
            } break;
            case GGML_OP_SUM_ROWS:
            case GGML_OP_CUMSUM:
            case GGML_OP_MEAN:
            case GGML_OP_ARGMAX:
            case GGML_OP_COUNT_EQUAL: {
                split_state = handle_per_row(src_ss);
            } break;
            case GGML_OP_REPEAT:
            case GGML_OP_REPEAT_BACK: {
                split_state = handle_generic(src_ss, /*scalar_only =*/ false);
            } break;
            case GGML_OP_CONCAT: {
                split_state = handle_concat(src_ss);
            } break;
            case GGML_OP_SILU_BACK: {
                split_state = handle_generic(src_ss, /*scalar_only =*/ false);
            } break;
            case GGML_OP_NORM:
            case GGML_OP_RMS_NORM:
            case GGML_OP_RMS_NORM_BACK:
            case GGML_OP_GROUP_NORM:
            case GGML_OP_L2_NORM: {
                split_state = handle_per_row(src_ss);
            } break;
            case GGML_OP_MUL_MAT:
            case GGML_OP_MUL_MAT_ID: {
                split_state = handle_mul_mat(src_ss);
            } break;
            case GGML_OP_OUT_PROD: {
                split_state = handle_generic(src_ss, /*scalar_only =*/ true);
            } break;
            case GGML_OP_SCALE: {
                split_state = handle_generic(src_ss, /*scalar_only =*/ false);
            } break;
            case GGML_OP_SET: {
                split_state = handle_generic(src_ss, /*scalar_only =*/ true);
            } break;
            case GGML_OP_CPY: {
                split_state = handle_cpy(src_ss);
            } break;
            case GGML_OP_CONT:
            case GGML_OP_RESHAPE: {
                split_state = handle_reshape(src_ss);
            } break;
            case GGML_OP_VIEW: {
                split_state = handle_view(src_ss);
            } break;
            case GGML_OP_PERMUTE: {
                split_state = handle_permute(src_ss);
            } break;
            case GGML_OP_TRANSPOSE: {
                split_state = handle_transpose(src_ss);
            } break;
            case GGML_OP_GET_ROWS: {
                split_state = handle_get_rows(src_ss);
            } break;
            case GGML_OP_GET_ROWS_BACK: {
                split_state = handle_generic(src_ss, /*scalar_only =*/ true);
            } break;
            case GGML_OP_SET_ROWS: {
                split_state = handle_set_rows(src_ss);
            } break;
            case GGML_OP_DIAG:
            case GGML_OP_DIAG_MASK_INF:
            case GGML_OP_DIAG_MASK_ZERO: {
                split_state = handle_generic(src_ss, /*scalar_only =*/ true);
            } break;
            case GGML_OP_SOFT_MAX:
            case GGML_OP_SOFT_MAX_BACK: {
                split_state = handle_generic(src_ss, /*scalar_only =*/ false);
            } break;
            case GGML_OP_ROPE: {
                split_state = handle_rope(src_ss);
            } break;
            case GGML_OP_ROPE_BACK: {
                split_state = handle_generic(src_ss, /*scalar_only =*/ true);
            } break;
            case GGML_OP_CLAMP: {
                split_state = handle_generic(src_ss, /*scalar_only =*/ false);
            } break;
            case GGML_OP_CONV_TRANSPOSE_1D:
            case GGML_OP_IM2COL:
            case GGML_OP_IM2COL_BACK:
            case GGML_OP_IM2COL_3D:
            case GGML_OP_CONV_2D:
            case GGML_OP_CONV_3D:
            case GGML_OP_CONV_2D_DW:
            case GGML_OP_CONV_TRANSPOSE_2D:
            case GGML_OP_POOL_1D:
            case GGML_OP_POOL_2D:
            case GGML_OP_POOL_2D_BACK:
            case GGML_OP_UPSCALE: {
                split_state = handle_generic(src_ss, /*scalar_only =*/ true);
            } break;
            case GGML_OP_PAD: {
                split_state = handle_pad(src_ss);
            } break;
            case GGML_OP_PAD_REFLECT_1D:
            case GGML_OP_ROLL:
            case GGML_OP_ARANGE:
            case GGML_OP_TIMESTEP_EMBEDDING: {
                split_state = handle_generic(src_ss, /*scalar_only =*/ true);
            } break;
            case GGML_OP_ARGSORT:
            case GGML_OP_TOP_K: {
                split_state = handle_per_row(src_ss);
            } break;
            case GGML_OP_LEAKY_RELU: {
                split_state = handle_generic(src_ss, /*scalar_only =*/ false);
            } break;
            case GGML_OP_TRI: {
                split_state = handle_generic(src_ss, /*scalar_only =*/ true);
            } break;
            case GGML_OP_FILL: {
                split_state = handle_generic(src_ss, /*scalar_only =*/ false);
            } break;
            case GGML_OP_FLASH_ATTN_EXT: {
                split_state = handle_flash_attn_ext(src_ss);
            } break;
            case GGML_OP_FLASH_ATTN_BACK: {
                split_state = handle_generic(src_ss, /*scalar_only =*/ true);
            } break;
            case GGML_OP_SSM_CONV: {
                split_state = handle_ssm_conv(src_ss);
            } break;
            case GGML_OP_SSM_SCAN:
            case GGML_OP_WIN_PART:
            case GGML_OP_WIN_UNPART:
            case GGML_OP_GET_REL_POS:
            case GGML_OP_ADD_REL_POS:
            case GGML_OP_RWKV_WKV6:
            case GGML_OP_GATED_LINEAR_ATTN:
            case GGML_OP_RWKV_WKV7:
            case GGML_OP_SOLVE_TRI: {
                split_state = handle_generic(src_ss, /*scalar_only =*/ true);
            } break;
            case GGML_OP_GATED_DELTA_NET: {
                split_state = handle_gated_delta_net(src_ss);
            } break;
            case GGML_OP_UNARY: {
                split_state = handle_generic(src_ss, /*scalar_only =*/ false);
            } break;
            case GGML_OP_MAP_CUSTOM1:
            case GGML_OP_MAP_CUSTOM2:
            case GGML_OP_MAP_CUSTOM3:
            case GGML_OP_CUSTOM: {
                split_state = handle_generic(src_ss, /*scalar_only =*/ true);
            } break;
            case GGML_OP_CROSS_ENTROPY_LOSS:
            case GGML_OP_CROSS_ENTROPY_LOSS_BACK: {
                split_state = handle_per_row(src_ss);
            } break;
            case GGML_OP_OPT_STEP_ADAMW:
            case GGML_OP_OPT_STEP_SGD:
            case GGML_OP_GLU: {
                split_state = handle_generic(src_ss, /*scalar_only =*/ false);
            } break;
            default: {
                GGML_ABORT("ggml op not implemented: %s", ggml_op_name(tensor->op));
                split_state = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1};
            } break;
        }

        // For a dimension split the handlers only decide the axis; the per-buffer slice sizes are
        // derived here from the sources that are themselves split along a dimension.
        if (split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS) {
            bool first_src_split_by_axis = true;
            const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(tensor->buffer);

            for (size_t i = 0; i < GGML_MAX_SRC; i++) {
                if (tensor->src[i] == nullptr || src_ss[i].axis < 0 || src_ss[i].axis >= GGML_MAX_DIMS) {
                    continue;
                }
                if (first_src_split_by_axis) {
                    for (size_t j = 0; j < n_bufs; j++) {
                        // Take over ratio from src:
                        for (size_t s = 0; s < src_ss[i].n_segments; s++) {
                            split_state.ne[s*n_bufs + j] = 0;
                        }
                        for (size_t s = 0; s < src_ss[i].n_segments; s++) {
                            split_state.ne[j] += src_ss[i].ne[s*n_bufs + j] * src_ss[i].nr[s];
                        }
                        // Rescale from the extent of the source's split dimension to the extent of
                        // the destination's split dimension.
                        split_state.ne[j] *= tensor->ne[split_state.axis];
                        if (split_state.ne[j] != 0 || tensor->src[i]->ne[src_ss[i].axis] != 0) {
                            const int64_t div = tensor->src[i]->ne[src_ss[i].axis] * split_state.nr[0];
                            GGML_ASSERT(split_state.ne[j] % div == 0);
                            split_state.ne[j] /= div;
                        }
                    }
                } else {
                    GGML_ASSERT(split_state.n_segments == 1);
                    for (size_t j = 0; j < n_bufs; j++) {
                        // Assert that ratio is consistent:
                        int64_t sum = 0;
                        for (size_t s = 0; s < src_ss[i].n_segments; s++) {
                            sum += src_ss[i].ne[s*n_bufs + j] * src_ss[i].nr[s];
                        }
                        GGML_ASSERT(split_state.ne[j]*split_state.nr[0] * tensor->src[i]->ne[src_ss[i].axis]
                                                                 == sum * tensor->ne[split_state.axis]);
                    }
                }
                first_src_split_by_axis = false;
            }
            // At least one source must have provided the slice sizes.
            GGML_ASSERT(!first_src_split_by_axis);
        }
        return split_state;
    };
1052
1053
0
    const std::pair key = std::make_pair(tensor, assume_sync);
1054
0
    auto it = buf_ctx->split_state_cache.find(key);
1055
0
    if (it != buf_ctx->split_state_cache.end() && memcmp(it->second.second, (const char *) tensor, sizeof(it->second.second)) != 0) {
1056
0
        buf_ctx->split_state_cache.clear();
1057
0
        it = buf_ctx->split_state_cache.end();
1058
0
    }
1059
1060
0
    if (it == buf_ctx->split_state_cache.end()) {
1061
0
        buf_ctx->split_state_cache[key].first = calculate_split_state();
1062
0
        memcpy(buf_ctx->split_state_cache[key].second, tensor, sizeof(buf_ctx->split_state_cache[key].second));
1063
0
        if (buf_ctx->debug > 0) {
1064
0
            std::string srcs_info;
1065
0
            for (size_t i = 0; i < GGML_MAX_SRC; i++) {
1066
0
                if (tensor->src[i] == nullptr) {
1067
0
                    continue;
1068
0
                }
1069
0
                if (!srcs_info.empty()) {
1070
0
                    srcs_info += ", ";
1071
0
                }
1072
0
                const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor->src[0], true);
1073
0
                GGML_ASSERT(split_state.n_segments == 1);
1074
0
                const char * axis_name = ggml_backend_meta_split_axis_name(split_state.axis);
1075
0
                std::string ne_info;
1076
0
                for (size_t j = 0; j < n_bufs; j++) {
1077
0
                    if (!ne_info.empty()) {
1078
0
                        ne_info += ", ";
1079
0
                    }
1080
0
                    ne_info += std::to_string(split_state.ne[j]) + "x" + std::to_string(split_state.nr[0]);
1081
0
                }
1082
0
                srcs_info += std::string(tensor->src[i]->name) + "[" + ggml_op_name(tensor->src[i]->op) + ", " + axis_name + ", {" + ne_info + "}]";
1083
0
            }
1084
0
            std::string ne_info;
1085
0
            for (size_t j = 0; j < n_bufs; j++) {
1086
0
                if (!ne_info.empty()) {
1087
0
                    ne_info += ", ";
1088
0
                }
1089
0
                const ggml_backend_meta_split_state & ss = buf_ctx->split_state_cache[key].first;
1090
0
                ne_info += std::to_string(ss.ne[j]) + "x" + std::to_string(ss.nr[0]);
1091
0
            }
1092
0
            GGML_LOG_DEBUG("SPLIT_STATE: {%s} -> %s[%s, %s, {%s}]\n", srcs_info.c_str(), tensor->name, ggml_op_name(tensor->op),
1093
0
                ggml_backend_meta_split_axis_name(buf_ctx->split_state_cache[key].first.axis), ne_info.c_str());
1094
0
        }
1095
0
    }
1096
1097
0
    ggml_backend_meta_split_state ret = buf_ctx->split_state_cache[key].first;
1098
0
    GGML_ASSERT(ret.axis != GGML_BACKEND_SPLIT_AXIS_NONE);
1099
#ifndef NDEBUG
1100
    if (ret.axis >= 0 && ret.axis < GGML_MAX_DIMS) {
1101
        int64_t ne_ret = 0;
1102
        for (size_t s = 0; s < ret.n_segments; s++) {
1103
            for (size_t j = 0; j < n_bufs; j++) {
1104
                ne_ret += ret.ne[s*n_bufs + j] * ret.nr[s];
1105
            }
1106
        }
1107
        assert(ne_ret == tensor->ne[int(ret.axis)]);
1108
    }
1109
#endif // NDEBUG
1110
0
    return ret;
1111
0
}
1112
1113
0
static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struct ggml_tensor * tensor, bool assume_sync) {
1114
0
    GGML_ASSERT(ggml_backend_buffer_is_meta(tensor->buffer));
1115
0
    ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context;
1116
0
    return ggml_backend_meta_get_split_state(buf_ctx->get_simple_tensor_container(tensor), tensor, assume_sync);
1117
0
}
1118
1119
0
// Returns a fixed placeholder address as the base of a meta buffer; the buffer argument is ignored.
static void * ggml_backend_meta_buffer_get_base(ggml_backend_buffer_t buffer) {
    GGML_UNUSED(buffer);
    void * const placeholder_base = (void *) 0x1000000000000000; // FIXME
    return placeholder_base;
}
1123
1124
0
// Creates, for every simple buffer behind the meta buffer of `tensor`, a per-buffer counterpart
// of `tensor` and registers the resulting list in `stc.simple_tensors[tensor]`.
//
// When the tensor is split along a dimension, each counterpart gets its own extent on that
// dimension (sum of its segments) and strides above the split dimension are scaled accordingly.
// Views and sources that live in meta buffers are redirected to their own per-buffer counterparts.
// Counterparts for which some dimension-split source has a zero-sized slice get the COMPUTE flag
// cleared. Always returns GGML_STATUS_SUCCESS.
static enum ggml_status ggml_backend_meta_buffer_init_tensor_impl(ggml_backend_meta_simple_tensor_container & stc, ggml_tensor * tensor) {
    GGML_ASSERT(ggml_backend_buffer_is_meta(tensor->buffer));
    auto * buf_ctx = static_cast<ggml_backend_meta_buffer_context *>(tensor->buffer->context);
    const size_t n_simple_bufs = ggml_backend_meta_buffer_n_bufs(tensor->buffer);

    const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(stc, tensor, /*assume_sync =*/ true);
    GGML_ASSERT(ggml_nelements(tensor) == 0 || split_state.axis != GGML_BACKEND_SPLIT_AXIS_UNKNOWN);
    GGML_ASSERT(split_state.n_segments <= 16);

    const int  split_dim    = split_state.axis;
    const bool split_by_dim = split_dim >= 0 && split_dim < GGML_MAX_DIMS;

    // Shape and strides of the per-buffer tensors, starting out as a copy of the meta tensor's.
    int64_t ne[GGML_MAX_DIMS];
    size_t  nb[GGML_MAX_DIMS];
    for (size_t k = 0; k < GGML_MAX_DIMS; k++) {
        ne[k] = tensor->ne[k];
        nb[k] = tensor->nb[k];
    }

    std::vector<ggml_tensor *> simple_tensors;
    simple_tensors.reserve(n_simple_bufs);

    for (size_t j = 0; j < n_simple_bufs; j++) {
        ggml_context          * simple_ctx = stc.ctxs[j].get();
        ggml_backend_buffer_t   simple_buf = buf_ctx->bufs[j].get();

        if (split_by_dim) {
            // TODO: the following assert fails for llama-parallel even though the results are correct:
            // GGML_ASSERT(ggml_is_contiguously_allocated(tensor));

            // Extent of buffer j on the split dimension: sum over all of its segments.
            ne[split_dim] = 0;
            for (size_t s = 0; s < split_state.n_segments; s++) {
                ne[split_dim] += split_state.ne[s*n_simple_bufs + j] * split_state.nr[s];
            }
            // Strides larger than the split dimension's stride shrink proportionally.
            for (int i = 0; i < GGML_MAX_DIMS; i++) {
                if (tensor->nb[i] > tensor->nb[split_dim]) {
                    nb[i] = tensor->nb[i] * ne[split_dim]/tensor->ne[split_dim];
                }
            }
        }

        ggml_tensor * t_ij = ggml_new_tensor(simple_ctx, tensor->type, GGML_MAX_DIMS, ne);
        t_ij->op = tensor->op;
        for (int i = 0; i < GGML_MAX_DIMS; i++) {
            t_ij->nb[i] = nb[i];
        }
        t_ij->flags = tensor->flags;
        memcpy(t_ij->op_params, tensor->op_params, sizeof(tensor->op_params));
        ggml_set_name(t_ij, tensor->name);
        t_ij->buffer    = simple_buf;
        t_ij->view_src  = tensor->view_src;
        t_ij->view_offs = tensor->view_offs;

        // A view into another meta tensor has to point at that tensor's counterpart for buffer j.
        if (t_ij->view_src != nullptr && ggml_backend_buffer_is_meta(t_ij->view_src->buffer)) {
            t_ij->view_src = ggml_backend_meta_buffer_simple_tensor(tensor->view_src, j);
            if (t_ij->view_offs > 0 && split_by_dim) {
                GGML_ASSERT(tensor->ne[split_dim] != 0);
                const int split_dim_view_src = ggml_backend_meta_get_split_state(tensor->view_src, /*assume_sync =*/ true).axis;
                GGML_ASSERT(split_dim_view_src >= 0 && split_dim_view_src < GGML_MAX_DIMS);

                // The offset can be internal to the data split, in those cases the view offset should not be scaled.
                // If however, the offset is larger than the data split then it needs to be scaled proportionally.
                bool split_internal_offset = t_ij->view_offs <= tensor->view_src->nb[split_dim_view_src];
                for (int i = 0; i < GGML_MAX_DIMS && !split_internal_offset; i++) {
                    const size_t dim_size = tensor->ne[i] * tensor->nb[i];
                    split_internal_offset = tensor->view_offs <= dim_size && dim_size < tensor->nb[split_dim];
                }
                if (!split_internal_offset) {
                    t_ij->view_offs = t_ij->view_offs * ne[split_dim]/tensor->ne[split_dim];
                }
            }
        }

        // Data pointer: relative to the view source if there is one, otherwise at the same offset
        // from the simple buffer's base as the meta tensor has from the meta buffer's base.
        if (t_ij->view_src != nullptr) {
            t_ij->data = (char *) t_ij->view_src->data + t_ij->view_offs;
        } else if (simple_buf != nullptr) {
            t_ij->data = (char *) ggml_backend_buffer_get_base(simple_buf)
                + size_t(tensor->data) - size_t(ggml_backend_buffer_get_base(tensor->buffer));
        }
        t_ij->extra = tensor->extra;

        // Sources: self-references map to the new tensor, meta sources to their counterparts,
        // everything else is copied as-is.
        for (int i = 0; i < GGML_MAX_SRC; i++) {
            ggml_tensor * src_i = tensor->src[i];
            if (src_i == tensor) {
                src_i = t_ij;
            } else if (src_i != nullptr && ggml_backend_buffer_is_meta(src_i->buffer)) {
                src_i = ggml_backend_meta_buffer_simple_tensor(tensor->src[i], j);
            }
            t_ij->src[i] = src_i;
        }

        simple_tensors.push_back(t_ij);
    }

    // If one of the sources has a zero-sized slice, disable the computation:
    for (int i = 0; i < GGML_MAX_SRC; i++) {
        const bool meta_src = tensor->src[i] != nullptr && ggml_backend_buffer_is_meta(tensor->src[i]->buffer);
        if (!meta_src) {
            continue;
        }

        const ggml_backend_meta_split_state split_state_src = ggml_backend_meta_get_split_state(tensor->src[i], /*assume_sync =*/ true);
        const bool src_split_by_dim = split_state_src.axis >= 0 && split_state_src.axis < GGML_MAX_DIMS;
        if (!src_split_by_dim) {
            continue;
        }

        for (size_t j = 0; j < n_simple_bufs; j++) {
            int64_t ne_sum = 0;
            for (size_t s = 0; s < split_state_src.n_segments; s++) {
                ne_sum += split_state_src.ne[s*n_simple_bufs + j] * split_state_src.nr[s];
            }
            if (ne_sum == 0) {
                simple_tensors[j]->flags &= ~GGML_TENSOR_FLAG_COMPUTE;
            }
        }
    }

    stc.simple_tensors[tensor] = simple_tensors;

    return GGML_STATUS_SUCCESS;
}
1238
1239
0
// init_tensor hook of the meta buffer interface.
// Promotes the pending compute-container index to the active one and then forwards the tensor,
// together with the container returned by get_simple_tensor_container(), to
// ggml_backend_meta_buffer_init_tensor_impl. Returns whatever status the impl reports.
static enum ggml_status ggml_backend_meta_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
    GGML_ASSERT(ggml_backend_buffer_is_meta(buffer));

    auto * meta_ctx = static_cast<ggml_backend_meta_buffer_context *>(buffer->context);

    // Switch the active index first; the container lookup below runs afterwards, as before.
    meta_ctx->stc_compute_index = meta_ctx->stc_compute_index_next;

    return ggml_backend_meta_buffer_init_tensor_impl(meta_ctx->get_simple_tensor_container(tensor), tensor);
}
1245
1246
0
// set_tensor hook of the meta buffer interface.
//
// Copies `size` bytes from host memory `data` into the contiguous meta `tensor`, starting at byte
// `offset` of the tensor, by distributing the bytes over the per-buffer "simple" tensors according
// to the tensor's split state:
//   - more than one segment, or a repeated segment, split along axis 0 or 1:
//       the source is walked segment by segment, repetition by repetition, buffer by buffer, and
//       each piece is written with one strided 2D copy.
//   - a single, non-repeated segment split along axis 0, 1 or 2:
//       every nb[axis + 1]-sized slab of the source is cut into one chunk per buffer.
//   - MIRRORED: every simple tensor receives the same bytes.
//   - PARTIAL (F32 only): every simple tensor receives the values divided by the number of
//       buffers (presumably so that the per-buffer values add up to the written data).
//
// `offset` and `size` must be multiples of the slab size of the chosen path (asserted below).
// Any other split state aborts.
static void ggml_backend_meta_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
    const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(buffer);
    GGML_ASSERT(ggml_is_contiguous(tensor));

    const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false);

    if (split_state.n_segments != 1 || split_state.nr[0] != 1) {
        GGML_ASSERT(split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS);
        GGML_ASSERT(split_state.nr[0] != 0);
        GGML_ASSERT(tensor->ne[3] == 1);

        size_t offset_data = 0;                        // read position inside one source row/plane
        std::vector<size_t> simple_offsets(n_bufs, 0); // write position inside each simple row/plane
        if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_0) {
            GGML_ASSERT(tensor->ne[2] == 1);

            const size_t row_stride = tensor->nb[1];
            GGML_ASSERT(offset % row_stride == 0);
            GGML_ASSERT(size   % row_stride == 0);
            const int64_t row_start = offset / row_stride;
            const int64_t row_count = size   / row_stride;
            GGML_ASSERT(row_start + row_count <= tensor->ne[1]);

            const int64_t blck_size = ggml_blck_size(tensor->type);
            for (size_t s = 0; s < split_state.n_segments; s++) {
                for (size_t r = 0; r < split_state.nr[s]; r++) {
                    for (size_t j = 0; j < n_bufs; j++) {
                        ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
                        // Pieces are counted in elements; convert to bytes via whole blocks.
                        GGML_ASSERT(split_state.ne[s*n_bufs + j] % blck_size == 0);
                        const size_t nbytes = split_state.ne[s*n_bufs + j]/blck_size * tensor->nb[0];
                        ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data,
                            simple_offsets[j] + row_start * simple_tensor->nb[1], nbytes,
                            row_count, simple_tensor->nb[1], tensor->nb[1]);
                        offset_data       += nbytes;
                        simple_offsets[j] += nbytes;
                    }
                }
            }
            // All pieces together must cover exactly the rows that were requested.
            GGML_ASSERT(offset_data*row_count == size);
            return;
        }
        GGML_ASSERT(split_state.axis == GGML_BACKEND_SPLIT_AXIS_1);

        const size_t row_stride = tensor->nb[2];
        GGML_ASSERT(offset % row_stride == 0);
        GGML_ASSERT(size   % row_stride == 0);
        const int64_t row_start = offset / row_stride;
        const int64_t row_count = size   / row_stride;
        GGML_ASSERT(row_start + row_count <= tensor->ne[2]);

        for (size_t s = 0; s < split_state.n_segments; s++) {
            for (size_t r = 0; r < split_state.nr[s]; r++) {
                for (size_t j = 0; j < n_bufs; j++) {
                    ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
                    const size_t nbytes = split_state.ne[s*n_bufs + j] * tensor->nb[1];
                    ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data,
                        simple_offsets[j] + row_start * simple_tensor->nb[2], nbytes,
                        row_count, simple_tensor->nb[2], tensor->nb[2]);
                    offset_data       += nbytes;
                    simple_offsets[j] += nbytes;
                }
            }
        }
        GGML_ASSERT(offset_data*row_count == size);
        return;
    }

    switch (split_state.axis) {
        case GGML_BACKEND_SPLIT_AXIS_0:
        case GGML_BACKEND_SPLIT_AXIS_1:
        case GGML_BACKEND_SPLIT_AXIS_2: {
            // Exploit that tensors are contiguous to splice it with simple tensors as "chunks".
            const size_t chunk_size_full = tensor->nb[split_state.axis + 1];
            GGML_ASSERT(offset % chunk_size_full == 0);
            GGML_ASSERT(size   % chunk_size_full == 0);
            const int64_t i_start =  offset        /chunk_size_full;
            const int64_t i_stop  = (offset + size)/chunk_size_full;
            size_t offset_j = 0;
            for (size_t j = 0; j < n_bufs; j++) {
                ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
                const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1];
                if (chunk_size_j == 0) {
                    continue; // this buffer holds a zero-sized slice, nothing to write
                }
                const size_t simple_offset = i_start * chunk_size_j;
                ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_j, simple_offset, chunk_size_j, i_stop - i_start, chunk_size_j, chunk_size_full);
                offset_j += chunk_size_j;
            }
            GGML_ASSERT(offset_j == chunk_size_full);
        } break;
        case GGML_BACKEND_SPLIT_AXIS_MIRRORED: {
            for (size_t j = 0; j < n_bufs; j++) {
                ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
                ggml_backend_tensor_set(simple_tensor, data, offset, size);
            }
        } break;
        case GGML_BACKEND_SPLIT_AXIS_PARTIAL: {
            GGML_ASSERT(tensor->type == GGML_TYPE_F32);
            // `data` holds exactly the `size` bytes being written, not the whole tensor, so only
            // that many values may be read and scaled. Using the tensor's full element count here
            // would read past the end of `data` for any write smaller than the tensor.
            GGML_ASSERT(size % sizeof(float) == 0);
            const size_t  n_vals = size / sizeof(float);
            const float * src    = (const float *) data;
            std::vector<float> tmp;
            tmp.reserve(n_vals);
            for (size_t i = 0; i < n_vals; i++) {
                tmp.push_back(src[i] / n_bufs);
            }
            for (size_t j = 0; j < n_bufs; j++) {
                ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
                ggml_backend_tensor_set(simple_tensor, tmp.data(), offset, size);
            }
        } break;
        default: {
            GGML_ABORT("fatal error");
        }
    }
}
1360
1361
0
// get_tensor hook of the meta buffer interface.
//
// Gathers `size` bytes of the contiguous meta `tensor`, starting at byte `offset`, into host
// memory `data` by reading from the per-buffer simple tensors:
//   - several or repeated segments (axis 0 or 1 only): one strided 2D read per
//     (segment, repetition, buffer) piece,
//   - a single plain segment along axis 0, 1 or 2: one strided 2D read per buffer,
//   - MIRRORED: a plain read from simple tensor 0.
// Any other split state aborts.
static void ggml_backend_meta_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
    const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(buffer);
    GGML_ASSERT(ggml_is_contiguous(tensor));

    const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false);

    char * const dst = (char *) data;

    const bool single_plain_segment = split_state.n_segments == 1 && split_state.nr[0] == 1;
    if (!single_plain_segment) {
        GGML_ASSERT(split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS);
        GGML_ASSERT(split_state.nr[0] != 0);
        GGML_ASSERT(tensor->ne[3] == 1);

        size_t offset_data = 0;                   // write position inside one destination row/plane
        std::vector<size_t> shard_pos(n_bufs, 0); // read position inside each simple row/plane

        if (split_state.axis != GGML_BACKEND_SPLIT_AXIS_0) {
            // Only axis 1 is supported besides axis 0: planes of the tensor are interleaved.
            GGML_ASSERT(split_state.axis == GGML_BACKEND_SPLIT_AXIS_1);

            const size_t row_stride = tensor->nb[2];
            GGML_ASSERT(offset % row_stride == 0);
            GGML_ASSERT(size   % row_stride == 0);
            const int64_t row_start = offset / row_stride;
            const int64_t row_count = size   / row_stride;
            GGML_ASSERT(row_start + row_count <= tensor->ne[2]);

            for (size_t s = 0; s < split_state.n_segments; s++) {
                for (size_t r = 0; r < split_state.nr[s]; r++) {
                    for (size_t j = 0; j < n_bufs; j++) {
                        const ggml_tensor * shard = ggml_backend_meta_buffer_simple_tensor(tensor, j);
                        const size_t piece_bytes  = split_state.ne[s*n_bufs + j] * tensor->nb[1];
                        const size_t shard_offset = shard_pos[j] + row_start * shard->nb[2];
                        ggml_backend_tensor_get_2d(shard, dst + offset_data, shard_offset, piece_bytes,
                            row_count, shard->nb[2], tensor->nb[2]);
                        offset_data  += piece_bytes;
                        shard_pos[j] += piece_bytes;
                    }
                }
            }
            GGML_ASSERT(offset_data*row_count == size);
            return;
        }

        // Axis 0: rows of the tensor are interleaved.
        GGML_ASSERT(tensor->ne[2] == 1);

        const size_t row_stride = tensor->nb[1];
        GGML_ASSERT(offset % row_stride == 0);
        GGML_ASSERT(size   % row_stride == 0);
        const int64_t row_start = offset / row_stride;
        const int64_t row_count = size   / row_stride;
        GGML_ASSERT(row_start + row_count <= tensor->ne[1]);

        const int64_t blck_size = ggml_blck_size(tensor->type);
        for (size_t s = 0; s < split_state.n_segments; s++) {
            for (size_t r = 0; r < split_state.nr[s]; r++) {
                for (size_t j = 0; j < n_bufs; j++) {
                    const ggml_tensor * shard = ggml_backend_meta_buffer_simple_tensor(tensor, j);
                    // Element counts are turned into bytes in units of whole blocks.
                    GGML_ASSERT(split_state.ne[s*n_bufs + j] % blck_size == 0);
                    const size_t piece_bytes  = split_state.ne[s*n_bufs + j]/blck_size * tensor->nb[0];
                    const size_t shard_offset = shard_pos[j] + row_start * shard->nb[1];
                    ggml_backend_tensor_get_2d(shard, dst + offset_data, shard_offset, piece_bytes,
                        row_count, shard->nb[1], tensor->nb[1]);
                    offset_data  += piece_bytes;
                    shard_pos[j] += piece_bytes;
                }
            }
        }
        GGML_ASSERT(offset_data*row_count == size);
        return;
    }

    const auto axis = split_state.axis;
    if (axis == GGML_BACKEND_SPLIT_AXIS_0 || axis == GGML_BACKEND_SPLIT_AXIS_1 || axis == GGML_BACKEND_SPLIT_AXIS_2) {
        // Exploit that tensors are contiguous to splice it with simple tensors as "chunks".
        const size_t chunk_size_full = tensor->nb[split_state.axis + 1];
        GGML_ASSERT(offset % chunk_size_full == 0);
        GGML_ASSERT(size   % chunk_size_full == 0);
        const int64_t chunk_begin =  offset        /chunk_size_full;
        const int64_t chunk_end   = (offset + size)/chunk_size_full;

        size_t offset_j = 0;
        for (size_t j = 0; j < n_bufs; j++) {
            const ggml_tensor * shard = ggml_backend_meta_buffer_simple_tensor(tensor, j);
            const size_t chunk_size_j = shard->nb[split_state.axis + 1];
            if (chunk_size_j != 0) {
                const size_t shard_offset = chunk_begin * chunk_size_j;
                ggml_backend_tensor_get_2d(shard, dst + offset_j, shard_offset, chunk_size_j, chunk_end - chunk_begin, chunk_size_j, chunk_size_full);
                offset_j += chunk_size_j;
            }
        }
        GGML_ASSERT(offset_j == chunk_size_full);
    } else if (axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
        // TODO other simple backend may be better
        const ggml_tensor * shard = ggml_backend_meta_buffer_simple_tensor(tensor, 0);
        ggml_backend_tensor_get(shard, data, offset, size);
    } else {
        GGML_ABORT("fatal error");
    }
}
1461
1462
0
// clear hook of the meta buffer interface: clears every simple buffer with the byte `value`.
static void ggml_backend_meta_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
    const size_t n_buffers = ggml_backend_meta_buffer_n_bufs(buffer);

    size_t i = 0;
    while (i < n_buffers) {
        ggml_backend_buffer_t simple_buf = ggml_backend_meta_buffer_simple_buffer(buffer, i);
        ggml_backend_buffer_clear(simple_buf, value);
        ++i;
    }
}
1468
1469
0
// reset hook of the meta buffer interface: resets each simple buffer held by the meta buffer.
static void ggml_backend_meta_buffer_reset(ggml_backend_buffer_t buffer) {
    GGML_ASSERT(ggml_backend_buffer_is_meta(buffer));

    const auto * meta_ctx = static_cast<const ggml_backend_meta_buffer_context *>(buffer->context);
    const size_t n_simple = meta_ctx->bufs.size();

    for (size_t idx = 0; idx != n_simple; ++idx) {
        ggml_backend_buffer_t simple_buf = ggml_backend_meta_buffer_simple_buffer(buffer, idx);
        ggml_backend_buffer_reset(simple_buf);
    }
}
1476
1477
// Function table installed on every meta buffer. Hooks set to nullptr are not provided by the
// meta buffer. The free_buffer entry is also what ggml_backend_buffer_is_meta() compares against.
static const ggml_backend_buffer_i ggml_backend_meta_buffer_iface = {
    /* .free_buffer     = */ ggml_backend_meta_buffer_free_buffer,
    /* .get_base        = */ ggml_backend_meta_buffer_get_base,
    /* .init_tensor     = */ ggml_backend_meta_buffer_init_tensor,
    /* .memset_tensor   = */ nullptr, // TODO implement
    /* .set_tensor      = */ ggml_backend_meta_buffer_set_tensor,
    /* .get_tensor      = */ ggml_backend_meta_buffer_get_tensor,
    /* .set_tensor_2d   = */ nullptr,
    /* .get_tensor_2d   = */ nullptr,
    /* .cpy_tensor      = */ nullptr,
    /* .clear           = */ ggml_backend_meta_buffer_clear,
    /* .reset           = */ ggml_backend_meta_buffer_reset,
};
1490
1491
0
// Returns true if `buf` is non-null and uses the meta buffer's free_buffer callback, which is
// how meta buffers are recognised. A null `buf` yields false.
bool ggml_backend_buffer_is_meta(ggml_backend_buffer_t buf) {
    if (buf == nullptr) {
        return false;
    }
    const bool uses_meta_free = buf->iface.free_buffer == ggml_backend_meta_buffer_iface.free_buffer;
    return uses_meta_free;
}
1494
1495
0
// alloc_buffer hook of the meta buffer type.
// Allocates one buffer of `size` bytes from every simple buffer type and wraps them in a single
// meta buffer whose reported size is the largest of the simple buffer sizes. A failed simple
// allocation triggers an assertion.
static ggml_backend_buffer_t ggml_backend_meta_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
    const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft);

    // Metadata-only contexts (no_alloc) used by the two compute tensor containers.
    const ggml_init_params params = {
        /*.mem_size   =*/ 1024*1024*ggml_tensor_overhead(), // FIXME
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ true,
    };

    ggml_backend_meta_simple_tensor_container stc_static;
    ggml_backend_meta_simple_tensor_container stc_compute_0(params, n_simple_bufts);
    ggml_backend_meta_simple_tensor_container stc_compute_1(params, n_simple_bufts);

    std::vector<ggml_backend_buffer_t> bufs;
    bufs.reserve(n_simple_bufts);

    size_t largest = 0;
    for (size_t idx = 0; idx < n_simple_bufts; ++idx) {
        ggml_backend_buffer_type_t simple_buft = ggml_backend_meta_buft_simple_buft(buft, idx);
        bufs.push_back(ggml_backend_buft_alloc_buffer(simple_buft, size));
        GGML_ASSERT(bufs.back() != nullptr);

        const size_t simple_size = ggml_backend_buffer_get_size(bufs.back());
        if (largest < simple_size) {
            largest = simple_size;
        }
    }

    auto * meta_ctx = new ggml_backend_meta_buffer_context(stc_static, stc_compute_0, stc_compute_1, bufs);
    return ggml_backend_buffer_init(buft, ggml_backend_meta_buffer_iface, meta_ctx, largest);
}
1519
1520
0
struct ggml_backend_buffer * ggml_backend_meta_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft) {
1521
0
    const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft);
1522
1523
0
    constexpr size_t compute_headroom = 16; // Maximum number of views per statically allocated tensor that can be created between evals.
1524
0
    const ggml_init_params params_static = {
1525
0
        /*.mem_size   =*/ ggml_get_mem_size(ctx),
1526
0
        /*.mem_buffer =*/ nullptr,
1527
0
        /*.no_alloc   =*/ true,
1528
0
    };
1529
0
    const ggml_init_params params_compute = {
1530
0
        /*.mem_size   =*/ compute_headroom*ggml_get_mem_size(ctx),
1531
0
        /*.mem_buffer =*/ nullptr,
1532
0
        /*.no_alloc   =*/ true,
1533
0
    };
1534
0
    ggml_backend_meta_simple_tensor_container stc_static   (params_static,  n_simple_bufts);
1535
0
    ggml_backend_meta_simple_tensor_container stc_compute_0(params_compute, n_simple_bufts);
1536
0
    ggml_backend_meta_simple_tensor_container stc_compute_1(params_compute, n_simple_bufts);
1537
1538
0
    std::vector<ggml_backend_buffer_t> bufs(n_simple_bufts, nullptr);
1539
0
    ggml_backend_meta_buffer_context * meta_buf_ctx = new ggml_backend_meta_buffer_context(stc_static, stc_compute_0, stc_compute_1, bufs);
1540
1541
0
    ggml_backend_buffer_t meta_buf = ggml_backend_buffer_init(buft, ggml_backend_meta_buffer_iface, meta_buf_ctx, 0);
1542
0
    for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
1543
0
        t->buffer = meta_buf;
1544
0
        ggml_backend_meta_buffer_init_tensor_impl(meta_buf_ctx->stc_static, t);
1545
0
        t->data = (void *) 0x2000000000000000; // FIXME
1546
0
    }
1547
0
    for (size_t i = 0; i < n_simple_bufts; i++) {
1548
0
        ggml_context * ctx = meta_buf_ctx->stc_static.ctxs[i].get();
1549
0
        ggml_backend_buffer_type_t simple_buft = ggml_backend_meta_buft_simple_buft(buft, i);
1550
1551
        // If a ggml_context only has zero-sized tensors, ggml_backend_alloc_ctx_tensors_from_buft returns NULL.
1552
        // For those edge cases, allocate a dummy buffer instead.
1553
0
        bool any_nonzero_slice = false;
1554
0
        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
1555
0
            if (ggml_nelements(t) != 0) {
1556
0
                any_nonzero_slice = true;
1557
0
                break;
1558
0
            }
1559
0
        }
1560
0
        if (any_nonzero_slice) {
1561
0
            meta_buf_ctx->bufs[i].reset(ggml_backend_alloc_ctx_tensors_from_buft(ctx, simple_buft));
1562
0
        } else {
1563
0
            meta_buf_ctx->bufs[i].reset(ggml_backend_buft_alloc_buffer(simple_buft, 0));
1564
0
            for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
1565
0
                t->buffer = meta_buf_ctx->bufs[i].get();
1566
0
            }
1567
0
        }
1568
0
        GGML_ASSERT(meta_buf_ctx->bufs[i]);
1569
0
        meta_buf->size = std::max(meta_buf->size, ggml_backend_buffer_get_size(meta_buf_ctx->bufs[i].get()));
1570
0
    }
1571
0
    return meta_buf;
1572
0
}
1573
1574
//
1575
// meta backend
1576
//
1577
1578
0
// Returns the address of the function-local static GUID of the meta backend.
static ggml_guid_t ggml_backend_meta_guid() {
    static ggml_guid meta_guid = {
        0xf1, 0x0e, 0x34, 0xcf, 0x9c, 0x6f, 0x43, 0xcb,
        0x96, 0x92, 0xbe, 0x8e, 0xbb, 0x71, 0x3f, 0xda,
    };
    return &meta_guid;
}
1582
1583
struct ggml_backend_meta_context {
1584
    struct cgraph_config {
1585
        ggml_cgraph * cgraph_main = nullptr;
1586
        int           offset      = 0; // Node offset vs. original graph
1587
1588
        std::vector<ggml_cgraph *> cgraphs_aux;
1589
    };
1590
    struct backend_config {
1591
        ggml_backend_t backend;
1592
1593
        std::vector<cgraph_config>           cgraphs;
1594
        std::vector<ggml_tensor *>           nodes;
1595
        std::vector<ggml_backend_buffer_ptr> bufs;
1596
1597
0
        backend_config(ggml_backend_t backend, const size_t n_reduce_steps) : backend(backend) {
1598
0
            bufs.resize(n_reduce_steps);
1599
0
        }
1600
    };
1601
    std::string                 name;
1602
    std::vector<backend_config> backend_configs;
1603
    ggml_context_ptr            ctx;
1604
    std::vector<ggml_cgraph *>  cgraphs_aux;
1605
    std::vector<ggml_tensor *>  nodes_aux;
1606
    size_t                      n_reduce_steps;
1607
    int                         max_nnodes    = 0;
1608
    size_t                      max_tmp_size  = 0;
1609
    size_t                      max_subgraphs = 0;
1610
    size_t                      n_subgraphs   = 0;
1611
    uint64_t                    uid           = 0;
1612
1613
    void *                               comm_ctx       = nullptr;
1614
    ggml_backend_comm_allreduce_tensor_t comm_allreduce = nullptr;
1615
1616
0
    ggml_backend_meta_context(ggml_backend_dev_t meta_dev, const char * params) {
1617
0
        const size_t n_devs = ggml_backend_meta_dev_n_devs(meta_dev);
1618
0
        n_reduce_steps = std::ceil(std::log2(n_devs));
1619
0
        name = "Meta(";
1620
0
        std::vector<ggml_backend_t> simple_backends;
1621
0
        backend_configs.reserve(n_devs);
1622
0
        simple_backends.reserve(n_devs);
1623
0
        for (size_t i = 0; i < n_devs; i++) {
1624
0
            ggml_backend_dev_t simple_dev = ggml_backend_meta_dev_simple_dev(meta_dev, i);
1625
0
            if (i > 0) {
1626
0
                name += ",";
1627
0
            }
1628
0
            name += ggml_backend_dev_name(simple_dev);
1629
0
            simple_backends.push_back(ggml_backend_dev_init(simple_dev, params));
1630
0
            backend_configs.emplace_back(simple_backends.back(), n_reduce_steps);
1631
0
        }
1632
0
        name += ")";
1633
1634
0
        if (n_devs > 1) {
1635
0
            ggml_backend_comm_init_t comm_init = (ggml_backend_comm_init_t) ggml_backend_reg_get_proc_address(
1636
0
                ggml_backend_dev_backend_reg(ggml_backend_get_device(simple_backends[0])), "ggml_backend_comm_init");
1637
0
            if (comm_init != nullptr) {
1638
0
                comm_ctx = comm_init(simple_backends.data(), simple_backends.size());
1639
0
            }
1640
0
        }
1641
0
        if (comm_ctx != nullptr) {
1642
0
            comm_allreduce = (ggml_backend_comm_allreduce_tensor_t)
1643
0
                ggml_backend_reg_get_proc_address(ggml_backend_dev_backend_reg(
1644
0
                    ggml_backend_get_device(simple_backends[0])), "ggml_backend_comm_allreduce_tensor");
1645
0
            GGML_ASSERT(comm_allreduce != nullptr);
1646
0
        }
1647
0
    }
1648
1649
0
    ~ggml_backend_meta_context() {
1650
0
        if (comm_ctx != nullptr) {
1651
0
            ggml_backend_comm_free_t comm_free = (ggml_backend_comm_free_t) ggml_backend_reg_get_proc_address(
1652
0
                ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_configs[0].backend)), "ggml_backend_comm_free");
1653
0
            GGML_ASSERT(comm_free != nullptr);
1654
0
            comm_free(comm_ctx);
1655
0
        }
1656
0
        for (auto & bc : backend_configs) {
1657
0
            ggml_backend_free(bc.backend);
1658
0
        }
1659
0
    }
1660
};
1661
1662
0
// get_name hook of the meta backend: returns the name stored in the backend's context.
// The returned pointer refers to the context's std::string.
static const char * ggml_backend_meta_get_name(ggml_backend_t backend) {
    GGML_ASSERT(ggml_backend_is_meta(backend));
    const auto * meta_ctx = static_cast<const ggml_backend_meta_context *>(backend->context);
    const std::string & meta_name = meta_ctx->name;
    return meta_name.c_str();
}
1667
1668
0
// free hook of the meta backend: destroys the meta context and then the backend object itself.
static void ggml_backend_meta_free(ggml_backend_t backend) {
    GGML_ASSERT(ggml_backend_is_meta(backend));

    auto * meta_ctx = static_cast<ggml_backend_meta_context *>(backend->context);
    delete meta_ctx;

    delete backend;
}
1674
1675
0
// set_tensor_async hook of the meta backend.
//
// Writes `size` bytes from `data` into the contiguous meta `tensor` through the asynchronous
// setters of the simple backends. Only `offset == 0` and a single, non-repeated segment are
// supported (asserted):
//   - split along axis 0, 1 or 2: every nb[axis + 1]-sized slab of the source is cut into one
//     chunk per simple backend,
//   - MIRRORED: every simple backend receives the same bytes.
// Any other split state aborts.
static void ggml_backend_meta_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
    const size_t n_backends = ggml_backend_meta_n_backends(backend);
    GGML_ASSERT(offset == 0);
    GGML_ASSERT(ggml_is_contiguous(tensor));

    const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false);
    GGML_ASSERT(split_state.n_segments == 1);
    GGML_ASSERT(split_state.nr[0]      == 1);

    const auto axis = split_state.axis;

    if (axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
        for (size_t j = 0; j < n_backends; j++) {
            ggml_backend_t simple_backend = ggml_backend_meta_simple_backend(backend, j);
            ggml_tensor *  shard          = ggml_backend_meta_buffer_simple_tensor(tensor, j);
            ggml_backend_tensor_set_async(simple_backend, shard, data, offset, size);
        }
        return;
    }

    if (axis != GGML_BACKEND_SPLIT_AXIS_0 && axis != GGML_BACKEND_SPLIT_AXIS_1 && axis != GGML_BACKEND_SPLIT_AXIS_2) {
        GGML_ABORT("fatal error");
    }

    // Exploit that tensors are contiguous to splice it with simple tensors as "chunks".
    const char * const src = (const char *) data;
    const size_t chunk_size_full = tensor->nb[split_state.axis + 1];
    GGML_ASSERT(offset % chunk_size_full == 0);
    GGML_ASSERT(size   % chunk_size_full == 0);
    const int64_t chunk_begin =  offset        /chunk_size_full;
    const int64_t chunk_end   = (offset + size)/chunk_size_full;

    size_t offset_j = 0;
    for (size_t j = 0; j < n_backends; j++) {
        ggml_backend_t simple_backend = ggml_backend_meta_simple_backend(backend, j);
        ggml_tensor *  shard          = ggml_backend_meta_buffer_simple_tensor(tensor, j);
        const size_t chunk_size_j = shard->nb[split_state.axis + 1];
        if (chunk_size_j != 0) {
            ggml_backend_tensor_set_2d_async(simple_backend, shard, src + offset_j, offset, chunk_size_j,
                chunk_end - chunk_begin, chunk_size_j, chunk_size_full);
            offset_j += chunk_size_j;
        }
    }
    GGML_ASSERT(offset_j == chunk_size_full);
}
1719
1720
0
// get_tensor_async hook of the meta backend.
//
// Reads `size` bytes of the contiguous meta `tensor` into `data` through the asynchronous getters
// of the simple backends. Only `offset == 0` and a single, non-repeated segment are supported
// (asserted):
//   - split along axis 0, 1 or 2: each simple backend contributes one chunk to every
//     nb[axis + 1]-sized slab of the destination,
//   - MIRRORED: the data is read from simple backend 0.
// Any other split state aborts.
static void ggml_backend_meta_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
    const size_t n_backends = ggml_backend_meta_n_backends(backend);
    GGML_ASSERT(offset == 0);
    GGML_ASSERT(ggml_is_contiguous(tensor));

    const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false);
    GGML_ASSERT(split_state.n_segments == 1);
    GGML_ASSERT(split_state.nr[0]      == 1);

    const auto axis = split_state.axis;

    if (axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
        // TODO other simple backend may be better
        ggml_backend_t      simple_backend = ggml_backend_meta_simple_backend(backend, 0);
        const ggml_tensor * shard          = ggml_backend_meta_buffer_simple_tensor(tensor, 0);
        ggml_backend_tensor_get_async(simple_backend, shard, data, offset, size);
        return;
    }

    if (axis != GGML_BACKEND_SPLIT_AXIS_0 && axis != GGML_BACKEND_SPLIT_AXIS_1 && axis != GGML_BACKEND_SPLIT_AXIS_2) {
        GGML_ABORT("fatal error");
    }

    // Exploit that tensors are contiguous to splice it with simple tensors as "chunks".
    char * const dst = (char *) data;
    const size_t chunk_size_full = tensor->nb[split_state.axis + 1];
    GGML_ASSERT(offset % chunk_size_full == 0);
    GGML_ASSERT(size   % chunk_size_full == 0);
    const int64_t chunk_begin =  offset        /chunk_size_full;
    const int64_t chunk_end   = (offset + size)/chunk_size_full;

    size_t offset_j = 0;
    for (size_t j = 0; j < n_backends; j++) {
        ggml_backend_t      simple_backend = ggml_backend_meta_simple_backend(backend, j);
        const ggml_tensor * shard          = ggml_backend_meta_buffer_simple_tensor(tensor, j);
        const size_t chunk_size_j = shard->nb[split_state.axis + 1];
        if (chunk_size_j != 0) {
            ggml_backend_tensor_get_2d_async(simple_backend, shard, dst + offset_j, offset, chunk_size_j,
                chunk_end - chunk_begin, chunk_size_j, chunk_size_full);
            offset_j += chunk_size_j;
        }
    }
    GGML_ASSERT(offset_j == chunk_size_full);
}
1764
1765
0
// synchronize hook of the meta backend: synchronizes each simple backend in index order.
static void ggml_backend_meta_synchronize(ggml_backend_t backend) {
    const size_t n_backends = ggml_backend_meta_n_backends(backend);

    size_t idx = 0;
    while (idx < n_backends) {
        ggml_backend_t simple_backend = ggml_backend_meta_simple_backend(backend, idx);
        ggml_backend_synchronize(simple_backend);
        ++idx;
    }
}
1771
1772
0
// Computes `cgraph` on the meta backend.
//
// The graph is cut into subgraphs; a subgraph ends at a node whose split state is PARTIAL (possibly delayed, see
//     get_i_delayed below) or at the last node of the graph.
// Every subgraph is submitted asynchronously to each simple backend. After every subgraph except the last one the
//     final nodes of the per-backend subgraphs are reduced across backends, either via backend_ctx->comm_allreduce
//     or via the generic fallback implemented below.
//
// The per-backend subgraphs are only rebuilt if the UID of `cgraph` is 0 or differs from the UID stored in the context.
//
// Returns GGML_STATUS_SUCCESS, or the first non-success status returned by ggml_backend_graph_compute_async.
// Gradients are not supported (cgraph->grads must be nullptr).
static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
    GGML_ASSERT(cgraph->grads == nullptr);
    const size_t n_backends = ggml_backend_meta_n_backends(backend);
    ggml_backend_meta_context * backend_ctx = (ggml_backend_meta_context *) backend->context;

    // If the previous cgraph had a defined UID it can be used to skip rebuilding the subgraphs per simple backend.
    const bool needs_rebuild = (cgraph->uid == 0) || (cgraph->uid != backend_ctx->uid);

    // Grow the per-backend node and subgraph tables if this graph has more nodes than any graph seen so far.
    bool max_nnodes_raised = false;
    if (cgraph->n_nodes > backend_ctx->max_nnodes) {
        for (size_t j = 0; j < n_backends; j++) {
            auto & bcj = backend_ctx->backend_configs[j];
            bcj.nodes.resize(cgraph->n_nodes);
            bcj.cgraphs.resize(cgraph->n_nodes);
        }
        backend_ctx->max_nnodes = cgraph->n_nodes;
        max_nnodes_raised = true;
        assert(needs_rebuild);
    }

    if (needs_rebuild) {
        // Collect all meta buffers referenced by the graph and reset the simple tensor container that is
        //     going to be used next (the index is flipped between 0 and 1).
        std::set<ggml_backend_buffer_t> used_buffers;
        for (int i = 0; i < cgraph->n_leafs; i++) {
            if (ggml_backend_buffer_is_meta(cgraph->leafs[i]->buffer)) {
                used_buffers.emplace(cgraph->leafs[i]->buffer);
            }
        }
        for (int i = 0; i < cgraph->n_nodes; i++) {
            if (ggml_backend_buffer_is_meta(cgraph->nodes[i]->buffer)) {
                used_buffers.emplace(cgraph->nodes[i]->buffer);
            }
        }
        for (ggml_backend_buffer_t buf : used_buffers) {
            ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) buf->context;
            buf_ctx->stc_compute_index_next = buf_ctx->stc_compute_index ^ 1;
            ggml_backend_meta_simple_tensor_container & stc = buf_ctx->stc_compute[buf_ctx->stc_compute_index_next];
            for (ggml_context_ptr & ctx : stc.ctxs) {
                ggml_reset(ctx.get());
            }
            stc.simple_tensors.clear();
        }
        size_t n_subgraphs  = 0;
        size_t max_tmp_size = 0;

        // Map every node of the meta graph to its simple tensor on each backend.
        for (size_t j = 0; j < n_backends; j++) {
            auto & bcj = backend_ctx->backend_configs[j];

            for (int i = 0; i < cgraph->n_nodes; i++) {
                ggml_tensor * node = cgraph->nodes[i];
                if (node->view_src != nullptr && node->view_src->op == GGML_OP_NONE && ggml_backend_buffer_is_host(node->view_src->buffer)) {
                    // FIXME s_copy_main is on the CPU and its view seems to be incorrectly added to the graph nodes.
                    // For regular usage this doesn't matter since it's a noop but trying to call ggml_backend_meta_buffer_simple_tensor results in a crash.
                    bcj.nodes[i] = node;
                    continue;
                }
                bcj.nodes[i] = ggml_backend_meta_buffer_simple_tensor(node, j);
                GGML_ASSERT(bcj.nodes[i]);
            }
        }

        {
            // For MoE models it may make sense to delay the AllReduce in order to reduce I/O:
            // Given the index i of a node, returns the index of the node after which the AllReduce should be done.
            // The return value is >= i; it is only advanced past nodes matching the patterns checked below.
            auto get_i_delayed = [&](const int i) -> int {
                int id = i; // i_delayed
                int idr = i; // i_delayed return, last safe return value

                ggml_tensor * node = cgraph->nodes[id];
                int32_t n_used = ggml_node_get_use_count(cgraph, id);

                // Skip MIRRORED nodes that don't consume node
                auto skip_unrelated = [&]() {
                    while (id + 1 < cgraph->n_nodes) {
                        ggml_tensor * next = cgraph->nodes[id+1];
                        if (ggml_backend_meta_get_split_state(next, false).axis != GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
                            break;
                        }
                        bool safe = true;
                        for (int s = 0; s < GGML_MAX_SRC; s++) {
                            if (next->src[s] == nullptr) {
                                continue;
                            }
                            if (next->src[s] == node) {
                                safe = false;
                                break;
                            }
                            if (ggml_backend_meta_get_split_state(next->src[s], false).axis != GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
                                safe = false;
                                break;
                            }
                        }
                        if (!safe) {
                            break;
                        }
                        id++;
                    }
                };

                skip_unrelated();
                if (id + 1 >= cgraph->n_nodes) {
                    return idr;
                }
                // Optional ADD_ID with PARTIAL src[1] and MIRRORED src[2] that consumes node:
                {
                    ggml_tensor * next = cgraph->nodes[id+1];
                    if (next->op == GGML_OP_ADD_ID && next->src[0] == node &&
                            ggml_backend_meta_get_split_state(next->src[1], false).axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL &&
                            ggml_backend_meta_get_split_state(next->src[2], false).axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
                        node = next;
                        id++;
                        idr = id;
                        n_used = ggml_node_get_use_count(cgraph, id);
                    }
                }
                // Chain of MULs with MIRRORED src[1]
                while (true) {
                    skip_unrelated();
                    if (id + 1 >= cgraph->n_nodes) {
                        return idr;
                    }
                    ggml_tensor * next = cgraph->nodes[id+1];
                    if (next->op == GGML_OP_MUL && next->src[0] == node &&
                            ggml_backend_meta_get_split_state(next->src[1], false).axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
                        node = next;
                        id++;
                        idr = id;
                        n_used = ggml_node_get_use_count(cgraph, id);
                    } else {
                        break;
                    }
                }

                // The remaining pattern is n_used views of node followed by n_used-1 ADDs summing them up,
                //     i.e. 2*n_used-1 further nodes.
                // With fewer than 2 uses there is no ADD to match (its src[1] would have to be the ADD itself),
                //     and the lookups below would read cgraph->nodes[id+2] which is not covered by the bounds check.
                if (n_used < 2 || n_used != node->ne[1] || id + 2*n_used-1 >= cgraph->n_nodes) {
                    return idr;
                }
                // One view per row of node along dimension 1, each used exactly once:
                for (int32_t k = 0; k < n_used; k++) {
                    ggml_tensor * next = cgraph->nodes[id+1];
                    if (next->op != GGML_OP_VIEW || next->view_src != node || next->view_offs != k*node->nb[1] ||
                            next->ne[0] != node->ne[0] || next->ne[1] != node->ne[2] || next->nb[1] != node->nb[2] ||
                            ggml_node_get_use_count(cgraph, id+1) != 1) {
                        return idr;
                    }
                    id++;
                }
                // First ADD: sum of the first two views.
                {
                    ggml_tensor * next = cgraph->nodes[id+1];
                    if (next->op != GGML_OP_ADD || next->src[0] != cgraph->nodes[id - (n_used-1)] ||
                            next->src[1] != cgraph->nodes[id - (n_used-2)] || ggml_node_get_use_count(cgraph, id+1) != 1) {
                        return idr;
                    }
                    id++;
                }
                // Remaining ADDs: previous sum plus the next view.
                for (int32_t k = 0; k < n_used - 2; k++) {
                    ggml_tensor * next = cgraph->nodes[id+1];
                    if (next->op != GGML_OP_ADD || next->src[0] != cgraph->nodes[id] ||
                            next->src[1] != cgraph->nodes[id - (n_used-2)] || ggml_node_get_use_count(cgraph, id+1) != 1) {
                        return idr;
                    }
                    id++;
                }
                idr = id;
                return idr;
            };

            // Determine the subgraph boundaries and the largest PARTIAL node (size of the temporary reduce buffers).
            int i_start = 0;
            for (int i = 0; i < cgraph->n_nodes; i++) {
                ggml_tensor * node = cgraph->nodes[i];
                if (node->view_src != nullptr && node->view_src->op == GGML_OP_NONE && ggml_backend_buffer_is_host(node->view_src->buffer)) {
                    continue;
                }
                const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(node, /*assume_sync =*/ false);
                if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL) {
                    max_tmp_size = std::max(max_tmp_size, ggml_nbytes(node));
                }
                const bool new_subgraph = i + 1 == cgraph->n_nodes || split_state.axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL;
                if (!new_subgraph) {
                    continue;
                }

                const int i_delayed = get_i_delayed(i);

                // If we can delay the AllReduce we need to consider the interaction with zero-sized tensor slices.
                // A backend with such a slice would normally have valid data after participating in the AllReduce with a node that has
                //     its compute flag disabled and thus gets its data zeroed out.
                // If the AllReduce is delayed then the nodes until that point also need to have their compute flag disabled.
                if (i_delayed > i) {
                    for (size_t j = 0; j < n_backends; j++) {
                        auto & bcj = backend_ctx->backend_configs[j];
                        if ((bcj.nodes[i]->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
                            for (int ii = i + 1; ii <= i_delayed; ii++) {
                                bcj.nodes[ii]->flags &= ~GGML_TENSOR_FLAG_COMPUTE;
                            }
                        }
                    }
                }

                i = i_delayed;

                for (size_t j = 0; j < n_backends; j++) {
                    auto & bcj = backend_ctx->backend_configs[j];
                    bcj.cgraphs[n_subgraphs].offset = i_start;
                }
                n_subgraphs++;
                i_start = i + 1;
            }
            GGML_ASSERT(i_start == cgraph->n_nodes);
        }

        backend_ctx->uid         = cgraph->uid;
        backend_ctx->n_subgraphs = n_subgraphs;

        // Grow the temporary buffers used by the fallback AllReduce if needed.
        if (max_tmp_size > backend_ctx->max_tmp_size) {
            for (size_t j = 0; j < n_backends; j++) {
                auto & bcj = backend_ctx->backend_configs[j];
                for (size_t i = 0; i < backend_ctx->n_reduce_steps; i++) {
                    bcj.bufs[i].reset(ggml_backend_alloc_buffer(bcj.backend, max_tmp_size));
                }
            }
            backend_ctx->max_tmp_size = max_tmp_size;
        }

        // Recreate the ggml context holding the graph/tensor metadata if the previous one is too small.
        if (max_nnodes_raised || n_subgraphs > backend_ctx->max_subgraphs) {
            backend_ctx->max_subgraphs = std::max(backend_ctx->max_subgraphs, n_subgraphs);
            const size_t n_nodes_per_device = 3 * backend_ctx->n_reduce_steps; // tmp + ADD (+zeroing) graph per step and device
            const size_t n_cgraphs_per_device = 2 * backend_ctx->n_reduce_steps; // ADD ( + zeroing) graph per step and device
            const size_t mem_per_device_graphs_main = backend_ctx->max_subgraphs*ggml_graph_overhead_custom(backend_ctx->max_nnodes, cgraph->grads);
            const size_t mem_per_device_graphs_aux = n_cgraphs_per_device*backend_ctx->max_subgraphs*ggml_graph_overhead_custom(1, cgraph->grads);
            const size_t mem_per_device_nodes_aux = n_nodes_per_device*backend_ctx->max_subgraphs*ggml_tensor_overhead();
            const ggml_init_params params = {
                /*.mem_size   =*/ n_backends * (mem_per_device_graphs_main + mem_per_device_graphs_aux + mem_per_device_nodes_aux),
                /*.mem_buffer =*/ nullptr,
                /*.no_alloc   =*/ true,
            };
            backend_ctx->ctx.reset(ggml_init(params));
            // All main graphs live in the context that was just replaced, so every slot that may be used later without
            //     another reallocation (up to max_subgraphs) has to be recreated, not only the n_subgraphs of this graph.
            // They are created with the maximum node capacity so that a later graph with more nodes per subgraph
            //     (but n_nodes <= max_nnodes) still fits; this matches mem_per_device_graphs_main above.
            for (size_t j = 0; j < n_backends; j++) {
                auto & bcj = backend_ctx->backend_configs[j];
                for (size_t i = 0; i < backend_ctx->max_subgraphs; i++) {
                    bcj.cgraphs[i].cgraph_main = ggml_new_graph_custom(backend_ctx->ctx.get(), backend_ctx->max_nnodes, /*grads =*/ false);
                }
            }
            backend_ctx->cgraphs_aux.resize(n_backends*n_cgraphs_per_device*backend_ctx->max_subgraphs);
            for (size_t k = 0; k < backend_ctx->cgraphs_aux.size(); k++) {
                backend_ctx->cgraphs_aux[k] = ggml_new_graph_custom(backend_ctx->ctx.get(), 1, cgraph->grads);
            }
            backend_ctx->nodes_aux.resize(n_backends*n_nodes_per_device*backend_ctx->max_subgraphs);
            for (size_t k = 0; k < backend_ctx->nodes_aux.size(); k++) {
                backend_ctx->nodes_aux[k] = ggml_new_tensor_1d(backend_ctx->ctx.get(), GGML_TYPE_F32, 1);
            }
        }

        // Fill the per-backend subgraphs with the simple nodes and copy the use counts from the meta graph.
        for (size_t j = 0; j < n_backends; j++) {
            auto & bcj = backend_ctx->backend_configs[j];
            for (size_t i_graph = 0; i_graph < n_subgraphs; i_graph++) {
                ggml_cgraph * cgraph_ij = bcj.cgraphs[i_graph].cgraph_main;
                const size_t i_node_start = bcj.cgraphs[i_graph].offset;
                const size_t i_node_stop = i_graph + 1 < n_subgraphs ? bcj.cgraphs[i_graph + 1].offset : cgraph->n_nodes;
                cgraph_ij->n_nodes = i_node_stop - i_node_start;
                ggml_hash_set_reset(&cgraph_ij->visited_hash_set);
                for (size_t i_node = i_node_start; i_node < i_node_stop; i_node++) {
                    ggml_tensor * node_ij = bcj.nodes[i_node];
                    cgraph_ij->nodes[i_node - i_node_start] = node_ij;
                    const size_t hash_pos_orig = ggml_hash_find(&cgraph->visited_hash_set, cgraph->nodes[i_node]);
                    const size_t hash_pos_ij = ggml_hash_insert(&cgraph_ij->visited_hash_set, node_ij);
                    cgraph_ij->use_counts[hash_pos_ij] = cgraph->use_counts[hash_pos_orig];
                }
                cgraph_ij->uid = ggml_graph_next_uid();
            }
        }
    }

    size_t iga = 0; // i graph aux
    size_t ina = 0; // i node aux

    // Takes the next preallocated auxiliary tensor, clears it, and gives it the type and shape (ne/nb) of t.
    auto get_node_aux = [&](ggml_tensor * t) -> ggml_tensor * {
        ggml_tensor * ret = backend_ctx->nodes_aux[ina++];
        memset(ret, 0, sizeof(ggml_tensor));
        ret->op   = GGML_OP_NONE;
        ret->type = t->type;
        for (size_t k = 0; k < GGML_MAX_DIMS; k++) {
            ret->ne[k] = t->ne[k];
            ret->nb[k] = t->nb[k];
        }
        return ret;
    };
    // Points tensor at temporary buffer i_buf of backend j, (re)allocating the buffer if it is missing or too small.
    auto set_tmp_data = [&](ggml_tensor * tensor, const size_t j, const size_t i_buf) {
        auto & bcj = backend_ctx->backend_configs[j];
        ggml_backend_buffer_ptr & buf_ptr = bcj.bufs[i_buf];
        if (!buf_ptr || ggml_backend_buffer_get_size(buf_ptr.get()) < backend_ctx->max_tmp_size) {
            buf_ptr.reset(ggml_backend_alloc_buffer(bcj.backend, backend_ctx->max_tmp_size));
        }
        tensor->buffer = buf_ptr.get();
        tensor->data   = ggml_backend_buffer_get_base(buf_ptr.get());
    };
    // FIXME usage_counts
    // Takes the next preallocated single-node auxiliary graph.
    auto get_cgraph_aux = [&]() -> ggml_cgraph * {
        ggml_cgraph * ret = backend_ctx->cgraphs_aux[iga++];
        return ret;
    };

    // Preferentially use backend-specific allreduce_tensor_async (e.g. NCCL for CUDA), use a generic fallback if unavailable:
    // The fallback sums the last node of subgraph i across all backends using tensor copies and single-node ADD graphs.
    auto allreduce_fallback = [&](size_t i) -> ggml_status {
        std::vector<ggml_cgraph *> step_cgraphs(n_backends, nullptr);

        // Zero out nodes that were disabled due to having a zero-sized slice:
        for (size_t j = 0; j < n_backends; j++) {
            auto & bcj = backend_ctx->backend_configs[j];
            ggml_tensor * node = bcj.cgraphs[i].cgraph_main->nodes[bcj.cgraphs[i].cgraph_main->n_nodes - 1];
            if (node->flags & GGML_TENSOR_FLAG_COMPUTE) {
                continue;
            }
            ggml_tensor * node_zero = get_node_aux(node);
            node_zero->op = GGML_OP_SCALE; // FIXME 0.0f * NaN == NaN
            node_zero->src[0] = node;
            ggml_set_op_params_f32(node_zero, 0, 0.0f);
            node_zero->data = node->data;
            node_zero->buffer = node->buffer;
            node_zero->flags |= GGML_TENSOR_FLAG_COMPUTE;

            step_cgraphs[j] = get_cgraph_aux();
            step_cgraphs[j]->nodes[0] = node_zero;
            step_cgraphs[j]->n_nodes = 1;
            const ggml_status status = ggml_backend_graph_compute_async(bcj.backend, step_cgraphs[j]);
            if (status != GGML_STATUS_SUCCESS) {
                return status;
            }
        }
        std::fill(step_cgraphs.begin(), step_cgraphs.end(), nullptr);

        // Copies the last node of backend j_src into temporary buffer i_buf of backend j_dst and prepares
        //     (but does not submit) a single-node graph in step_cgraphs[j_dst] that adds it onto the destination node.
        auto push_data = [&](const size_t j_src, const size_t j_dst, const size_t i_buf) {
            assert(step_cgraphs[j_dst] == nullptr);
            auto & bcj_src = backend_ctx->backend_configs[j_src];
            auto & bcj_dst = backend_ctx->backend_configs[j_dst];

            ggml_tensor * node_src = bcj_src.cgraphs[i].cgraph_main->nodes[bcj_src.cgraphs[i].cgraph_main->n_nodes - 1];
            ggml_tensor * node_dst = bcj_dst.cgraphs[i].cgraph_main->nodes[bcj_dst.cgraphs[i].cgraph_main->n_nodes - 1];
            GGML_ASSERT(ggml_is_contiguous(node_src));
            GGML_ASSERT(ggml_is_contiguous(node_dst));

            ggml_tensor * node_tmp = get_node_aux(node_dst);
            set_tmp_data(node_tmp, j_dst, i_buf);

            ggml_backend_tensor_copy_async(bcj_src.backend, bcj_dst.backend, node_src, node_tmp);

            // The ADD result is set up as a view of the destination node (or of its view source) at the same offset.
            ggml_tensor * node_red = get_node_aux(node_dst);
            node_red->view_src = node_dst->view_src == nullptr ? node_dst : node_dst->view_src;
            node_red->view_offs = node_dst->view_offs;
            node_red->op = GGML_OP_ADD;
            node_red->src[0] = node_dst;
            node_red->src[1] = node_tmp;
            node_red->flags |= GGML_TENSOR_FLAG_COMPUTE;
            ggml_backend_view_init(node_red);

            ggml_cgraph * cgraph_aux = get_cgraph_aux();
            cgraph_aux->nodes[0] = node_red;
            cgraph_aux->n_nodes = 1;
            step_cgraphs[j_dst] = cgraph_aux;
        };

        // Largest power of 2 that is <= n_backends/2:
        size_t offset_j = n_backends/2;
        while ((offset_j & (offset_j - 1)) != 0) {
            offset_j--;
        }
        const size_t offset_j_max = offset_j;
        size_t i_buf = 0;

        // If n_backends is not a power of 2, fold in the excess prior to butterfly reduction:
        for (size_t j_src = 2*offset_j_max; j_src < n_backends; j_src++) {
            const size_t j_dst = j_src - 2*offset_j_max;
            push_data(j_src, j_dst, i_buf);
            const ggml_status status = ggml_backend_graph_compute_async(backend_ctx->backend_configs[j_dst].backend, step_cgraphs[j_dst]);
            if (status != GGML_STATUS_SUCCESS) {
                return status;
            }
            i_buf = 1;
        }

        // Butterfly reduction:
        for (; offset_j >= 1; offset_j /= 2) {
            std::fill(step_cgraphs.begin(), step_cgraphs.end(), nullptr);

            // First issue all copies of this step, then submit the ADD graphs.
            for (size_t j = 0; j < 2*offset_j_max; j++) {
                const size_t j_other = j ^ offset_j;
                if (j_other >= n_backends) {
                    continue;
                }
                push_data(j, j_other, i_buf);
            }

            for (size_t j = 0; j < 2*offset_j_max; j++) {
                if (step_cgraphs[j] == nullptr) {
                    continue;
                }
                auto & bcj = backend_ctx->backend_configs[j];
                const ggml_status status = ggml_backend_graph_compute_async(bcj.backend, step_cgraphs[j]);
                if (status != GGML_STATUS_SUCCESS) {
                    return status;
                }
            }
            i_buf++;
        }
        assert(i_buf == backend_ctx->n_reduce_steps);

        // If n_backends is not a power of 2, copy back the reduced tensors to the excess:
        for (size_t j = 2*offset_j_max; j < n_backends; j++) {
            auto & bcj_src = backend_ctx->backend_configs[j - 2*offset_j_max];
            auto & bcj_dst = backend_ctx->backend_configs[j];

            ggml_tensor * node_src = bcj_src.cgraphs[i].cgraph_main->nodes[bcj_src.cgraphs[i].cgraph_main->n_nodes - 1];
            ggml_tensor * node_dst = bcj_dst.cgraphs[i].cgraph_main->nodes[bcj_dst.cgraphs[i].cgraph_main->n_nodes - 1];
            ggml_backend_tensor_copy_async(bcj_src.backend, bcj_dst.backend, node_src, node_dst);
        }

        return GGML_STATUS_SUCCESS;
    };


    // Submit the subgraphs one after another; after each one except the last, reduce its final node across backends.
    for (size_t i = 0; i < backend_ctx->n_subgraphs; i++) {
        for (size_t j = 0; j < n_backends; j++) {
            auto & bcj = backend_ctx->backend_configs[j];
            const ggml_status status = ggml_backend_graph_compute_async(bcj.backend, bcj.cgraphs[i].cgraph_main);
            if (status != GGML_STATUS_SUCCESS) {
                return status;
            }
        }

        if (n_backends > 1 && i < backend_ctx->n_subgraphs - 1) {
            bool backend_allreduce_success = false;
            if (backend_ctx->comm_ctx) {
                std::vector<ggml_tensor *> nodes;
                nodes.reserve(n_backends);
                for (size_t j = 0; j < n_backends; j++) {
                    auto & bcj = backend_ctx->backend_configs[j];
                    ggml_cgraph * cgraph_ij = bcj.cgraphs[i].cgraph_main;
                    nodes.push_back(cgraph_ij->nodes[cgraph_ij->n_nodes-1]);
                }
                backend_allreduce_success = backend_ctx->comm_allreduce(backend_ctx->comm_ctx, nodes.data());
            }

            if (!backend_allreduce_success) {
                const ggml_status status = allreduce_fallback(i);
                if (status != GGML_STATUS_SUCCESS) {
                    return status;
                }
            }
        }
    }
    return GGML_STATUS_SUCCESS;
}
2218
2219
// Interface table of the meta backend.
// Only get_name, free, set/get_tensor_async, synchronize and graph_compute are provided;
//     all other entries are left as nullptr.
static const ggml_backend_i ggml_backend_meta_i = {
2220
    /* .get_name                = */ ggml_backend_meta_get_name,
2221
    /* .free                    = */ ggml_backend_meta_free,
2222
    /* .set_tensor_async        = */ ggml_backend_meta_set_tensor_async,
2223
    /* .get_tensor_async        = */ ggml_backend_meta_get_tensor_async,
2224
    /* .set_tensor_2d_async     = */ nullptr,
2225
    /* .get_tensor_2d_async     = */ nullptr,
2226
    /* .cpy_tensor_async        = */ nullptr,
2227
    /* .synchronize             = */ ggml_backend_meta_synchronize,
2228
    /* .graph_plan_create       = */ nullptr,
2229
    /* .graph_plan_free         = */ nullptr,
2230
    /* .graph_plan_update       = */ nullptr,
2231
    /* .graph_plan_compute      = */ nullptr,
2232
    /* .graph_compute           = */ ggml_backend_meta_graph_compute,
2233
    /* .event_record            = */ nullptr,
2234
    /* .event_wait              = */ nullptr,
2235
    /* .graph_optimize          = */ nullptr,
2236
};
2237
2238
0
// Returns true if `backend` is non-null and its get_name callback is the one of the meta backend interface.
bool ggml_backend_is_meta(ggml_backend_t backend) {
    if (backend == nullptr) {
        return false;
    }
    const bool same_get_name = ggml_backend_meta_i.get_name == backend->iface.get_name;
    return same_get_name;
}
2241
2242
0
// Creates a new meta backend for `dev`.
// The backend context is constructed from `dev` and `params`; the returned backend uses the meta backend
//     GUID and interface table.
static ggml_backend_t ggml_backend_meta_device_init_backend(ggml_backend_dev_t dev, const char * params) {
    // Construct the context first, then the backend object that refers to it.
    auto * meta_ctx = new ggml_backend_meta_context(dev, params);

    ggml_backend_t result = new ggml_backend;
    result->context = meta_ctx;
    result->device  = dev;
    result->iface   = ggml_backend_meta_i;
    result->guid    = ggml_backend_meta_guid();
    return result;
}
2252
2253
0
// Returns the number of backend configs (one per simple backend) held by the meta backend.
// Aborts via GGML_ASSERT if `meta_backend` is not a meta backend.
size_t ggml_backend_meta_n_backends(ggml_backend_t meta_backend) {
    GGML_ASSERT(ggml_backend_is_meta(meta_backend));
    const auto * meta_ctx = static_cast<const ggml_backend_meta_context *>(meta_backend->context);
    const size_t count = meta_ctx->backend_configs.size();
    return count;
}
2258
2259
0
// Returns the simple backend stored at position `index` of the meta backend.
// Aborts via GGML_ASSERT if `meta_backend` is not a meta backend or if `index` is out of range.
ggml_backend_t ggml_backend_meta_simple_backend(ggml_backend_t meta_backend, size_t index) {
    GGML_ASSERT(ggml_backend_is_meta(meta_backend));
    const ggml_backend_meta_context * backend_ctx = (const ggml_backend_meta_context *) meta_backend->context;
    // Indexing backend_configs out of bounds would be undefined behavior, fail loudly instead.
    GGML_ASSERT(index < backend_ctx->backend_configs.size());
    return backend_ctx->backend_configs[index].backend;
}