/src/llama.cpp/ggml/src/ggml-backend-meta.cpp
Line | Count | Source |
#include "ggml.h"
#include "ggml-impl.h"
#include "ggml-backend.h"
#include "ggml-backend-impl.h"
#include "ggml-alloc.h"
#include "ggml-cpp.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <map>
#include <memory>
#include <mutex>
#include <set>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
21 | | |
22 | | struct ggml_backend_meta_device; |
23 | | struct ggml_backend_meta_buffer_type; |
24 | | struct ggml_backend_meta_buffer; |
25 | | struct ggml_backend_meta; |
26 | | |
27 | 0 | const char * ggml_backend_meta_split_axis_name(enum ggml_backend_meta_split_axis split_axis) { |
28 | 0 | switch (split_axis) { |
29 | 0 | case GGML_BACKEND_SPLIT_AXIS_0: |
30 | 0 | return "0"; |
31 | 0 | case GGML_BACKEND_SPLIT_AXIS_1: |
32 | 0 | return "1"; |
33 | 0 | case GGML_BACKEND_SPLIT_AXIS_2: |
34 | 0 | return "2"; |
35 | 0 | case GGML_BACKEND_SPLIT_AXIS_3: |
36 | 0 | return "3"; |
37 | 0 | case GGML_BACKEND_SPLIT_AXIS_MIRRORED: |
38 | 0 | return "MIRRORED"; |
39 | 0 | case GGML_BACKEND_SPLIT_AXIS_PARTIAL: |
40 | 0 | return "PARTIAL"; |
41 | 0 | case GGML_BACKEND_SPLIT_AXIS_NONE: |
42 | 0 | return "NONE"; |
43 | 0 | case GGML_BACKEND_SPLIT_AXIS_UNKNOWN: |
44 | 0 | return "UNKNOWN"; |
45 | 0 | default: |
46 | 0 | GGML_ABORT("fatal error"); |
47 | 0 | } |
48 | 0 | } |
49 | | |
50 | | // |
51 | | // meta backend device |
52 | | // |
53 | | |
54 | | struct ggml_backend_meta_device_context { |
55 | | std::vector<ggml_backend_dev_t> simple_devs; |
56 | | ggml_backend_meta_get_split_state_t get_split_state; |
57 | | void * get_split_state_ud; |
58 | | |
59 | | std::string name; |
60 | | std::string description; |
61 | | |
62 | | ggml_backend_meta_device_context( |
63 | | std::vector<ggml_backend_dev_t> simple_devs, ggml_backend_meta_get_split_state_t get_split_state, void * get_split_state_ud) : |
64 | 0 | simple_devs(std::move(simple_devs)), get_split_state(get_split_state), get_split_state_ud(get_split_state_ud) { |
65 | 0 | name = std::string("Meta("); |
66 | 0 | description = std::string("Meta("); |
67 | 0 | for (size_t i = 0; i < simple_devs.size(); i++) { |
68 | 0 | if (i > 0) { |
69 | 0 | name += ","; |
70 | 0 | description += ","; |
71 | 0 | } |
72 | 0 | name += ggml_backend_dev_name (simple_devs[i]); |
73 | 0 | description += ggml_backend_dev_description(simple_devs[i]); |
74 | 0 | } |
75 | 0 | name += ")"; |
76 | 0 | description += ")"; |
77 | 0 | } |
78 | | |
79 | 0 | bool operator<(const ggml_backend_meta_device_context & other) const { |
80 | 0 | return std::tie(simple_devs, get_split_state, get_split_state_ud) |
81 | 0 | < std::tie(other.simple_devs, other.get_split_state, other.get_split_state_ud); |
82 | 0 | } |
83 | | }; |
84 | | |
85 | | static bool ggml_backend_dev_is_meta(ggml_backend_dev_t dev); |
86 | | |
87 | 0 | static const char * ggml_backend_meta_device_get_name(ggml_backend_dev_t dev) { |
88 | 0 | GGML_ASSERT(ggml_backend_dev_is_meta(dev)); |
89 | 0 | const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context; |
90 | 0 | return meta_dev_ctx->name.c_str(); |
91 | 0 | } |
92 | | |
93 | 0 | static const char * ggml_backend_meta_device_get_description(ggml_backend_dev_t dev) { |
94 | 0 | GGML_ASSERT(ggml_backend_dev_is_meta(dev)); |
95 | 0 | const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context; |
96 | 0 | return meta_dev_ctx->description.c_str(); |
97 | 0 | } |
98 | | |
99 | 0 | static void ggml_backend_meta_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { |
100 | 0 | GGML_ASSERT(ggml_backend_dev_is_meta(dev)); |
101 | 0 | const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context; |
102 | 0 | *free = 0; |
103 | 0 | *total = 0; |
104 | 0 | for (ggml_backend_dev_t dev : meta_dev_ctx->simple_devs) { |
105 | 0 | size_t tmp_free, tmp_total; |
106 | 0 | ggml_backend_dev_memory(dev, &tmp_free, &tmp_total); |
107 | 0 | *free += tmp_free; |
108 | 0 | *total += tmp_total; |
109 | 0 | } |
110 | 0 | } |
111 | | |
112 | 0 | static enum ggml_backend_dev_type ggml_backend_meta_device_get_type(ggml_backend_dev_t dev) { |
113 | 0 | return GGML_BACKEND_DEVICE_TYPE_META; |
114 | | |
115 | 0 | GGML_UNUSED(dev); |
116 | 0 | } |
117 | | |
118 | 0 | static void ggml_backend_meta_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) { |
119 | 0 | GGML_ASSERT(ggml_backend_dev_is_meta(dev)); |
120 | 0 | const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context; |
121 | | |
122 | | // TODO replace placeholders |
123 | 0 | props->name = ggml_backend_meta_device_get_name(dev); |
124 | 0 | props->description = ggml_backend_meta_device_get_description(dev); |
125 | 0 | props->type = ggml_backend_meta_device_get_type(dev); |
126 | 0 | props->device_id = 0; |
127 | |
|
128 | 0 | ggml_backend_meta_device_get_memory(dev, &props->memory_free, &props->memory_total); |
129 | |
|
130 | 0 | props->caps = { |
131 | 0 | /* .async = */ true, |
132 | 0 | /* .host_buffer = */ false, // Not implemented. |
133 | 0 | /* .buffer_from_host_ptr = */ false, // Not implemented. |
134 | 0 | /* .events = */ false, // Not implemented. |
135 | 0 | }; |
136 | 0 | for (ggml_backend_dev_t simple_dev : meta_dev_ctx->simple_devs) { |
137 | 0 | ggml_backend_dev_props tmp_props; |
138 | 0 | ggml_backend_dev_get_props(simple_dev, &tmp_props); |
139 | 0 | props->caps.async = props->caps.async && tmp_props.caps.async; |
140 | 0 | props->caps.host_buffer = props->caps.host_buffer && tmp_props.caps.host_buffer; |
141 | 0 | props->caps.buffer_from_host_ptr = props->caps.buffer_from_host_ptr && tmp_props.caps.buffer_from_host_ptr; |
142 | 0 | props->caps.events = props->caps.events && tmp_props.caps.events; |
143 | 0 | } |
144 | 0 | } |
145 | | |
146 | | static ggml_backend_t ggml_backend_meta_device_init_backend(ggml_backend_dev_t dev, const char * params); |
147 | | |
148 | | static ggml_backend_buffer_type_t ggml_backend_meta_device_get_buffer_type(ggml_backend_dev_t dev); |
149 | | |
150 | | static ggml_backend_buffer_type_t ggml_backend_meta_device_get_host_buffer_type(ggml_backend_dev_t dev); |
151 | | |
152 | 0 | static bool ggml_backend_meta_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { |
153 | 0 | GGML_ASSERT(ggml_backend_dev_is_meta(dev)); |
154 | 0 | const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context; |
155 | 0 | return std::all_of(meta_dev_ctx->simple_devs.begin(), meta_dev_ctx->simple_devs.end(), |
156 | 0 | [op](ggml_backend_dev_t simple_dev) { return ggml_backend_dev_supports_op(simple_dev, op); }); |
157 | 0 | } |
158 | | |
159 | 0 | static bool ggml_backend_meta_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { |
160 | 0 | GGML_ASSERT(ggml_backend_dev_is_meta(dev)); |
161 | 0 | ggml_backend_dev_t dev_buft = ggml_backend_buft_get_device(buft); |
162 | 0 | if (!ggml_backend_dev_is_meta(dev_buft)) { |
163 | 0 | return false; |
164 | 0 | } |
165 | 0 | const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context; |
166 | 0 | const ggml_backend_meta_device_context * meta_buft_dev_ctx = (const ggml_backend_meta_device_context *) dev_buft->context; |
167 | 0 | if (meta_dev_ctx->simple_devs.size() != meta_buft_dev_ctx->simple_devs.size()) { |
168 | 0 | return false; |
169 | 0 | } |
170 | 0 | for (size_t i = 0; i < meta_dev_ctx->simple_devs.size(); i++) { |
171 | 0 | if (meta_dev_ctx->simple_devs[i] != meta_buft_dev_ctx->simple_devs[i]) { |
172 | 0 | return false; |
173 | 0 | } |
174 | 0 | } |
175 | 0 | return true; |
176 | 0 | } |
177 | | |
// Function table shared by every meta device.
// ggml_backend_dev_is_meta() recognizes a meta device by comparing its iface.get_name pointer
// against the one in this table.
// Entries set to nullptr (buffer_from_host_ptr, offload_op, event_*) are not provided by the meta device.
static const ggml_backend_device_i ggml_backend_meta_device_iface = {
    /* .get_name             = */ ggml_backend_meta_device_get_name,
    /* .get_description      = */ ggml_backend_meta_device_get_description,
    /* .get_memory           = */ ggml_backend_meta_device_get_memory,
    /* .get_type             = */ ggml_backend_meta_device_get_type,
    /* .get_props            = */ ggml_backend_meta_device_get_props,
    /* .init_backend         = */ ggml_backend_meta_device_init_backend,
    /* .get_buffer_type      = */ ggml_backend_meta_device_get_buffer_type,
    /* .get_host_buffer_type = */ ggml_backend_meta_device_get_host_buffer_type,
    /* .buffer_from_host_ptr = */ nullptr,
    /* .supports_op          = */ ggml_backend_meta_device_supports_op,
    /* .supports_buft        = */ ggml_backend_meta_device_supports_buft,
    /* .offload_op           = */ nullptr,
    /* .event_new            = */ nullptr,
    /* .event_free           = */ nullptr,
    /* .event_synchronize    = */ nullptr,
};
195 | | |
196 | 0 | static bool ggml_backend_dev_is_meta(ggml_backend_dev_t dev) { |
197 | 0 | return dev != nullptr && dev->iface.get_name == ggml_backend_meta_device_iface.get_name; |
198 | 0 | } |
199 | | |
200 | 0 | static size_t ggml_backend_meta_dev_n_devs(ggml_backend_dev_t meta_dev) { |
201 | 0 | GGML_ASSERT(ggml_backend_dev_is_meta(meta_dev)); |
202 | 0 | const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) meta_dev->context; |
203 | 0 | return meta_dev_ctx->simple_devs.size(); |
204 | 0 | } |
205 | | |
206 | 0 | static ggml_backend_dev_t ggml_backend_meta_dev_simple_dev(ggml_backend_dev_t meta_dev, size_t index) { |
207 | 0 | GGML_ASSERT(ggml_backend_dev_is_meta(meta_dev)); |
208 | 0 | const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) meta_dev->context; |
209 | 0 | GGML_ASSERT(index < meta_dev_ctx->simple_devs.size()); |
210 | 0 | return meta_dev_ctx->simple_devs[index]; |
211 | 0 | } |
212 | | |
213 | | ggml_backend_dev_t ggml_backend_meta_device( |
214 | 0 | ggml_backend_dev_t * devs, size_t n_devs, ggml_backend_meta_get_split_state_t get_split_state, void * get_split_state_ud) { |
215 | 0 | GGML_ASSERT(n_devs <= GGML_BACKEND_META_MAX_DEVICES); |
216 | | // TODO: this is not thread-safe - needs to be fixed |
217 | 0 | static std::vector<std::unique_ptr<ggml_backend_meta_device_context>> ctxs; |
218 | 0 | static std::map<ggml_backend_meta_device_context, struct ggml_backend_device> meta_devs; |
219 | |
|
220 | 0 | std::vector<ggml_backend_dev_t> simple_devs; |
221 | 0 | simple_devs.reserve(n_devs); |
222 | 0 | for (size_t i = 0; i < n_devs; i++) { |
223 | 0 | simple_devs.push_back(devs[i]); |
224 | 0 | } |
225 | 0 | ggml_backend_meta_device_context ctx(simple_devs, get_split_state, get_split_state_ud); |
226 | |
|
227 | 0 | { |
228 | 0 | auto it = meta_devs.find(ctx); |
229 | 0 | if (it != meta_devs.end()) { |
230 | 0 | return &it->second; |
231 | 0 | } |
232 | 0 | } |
233 | 0 | ctxs.push_back(std::make_unique<ggml_backend_meta_device_context>(ctx)); |
234 | |
|
235 | 0 | struct ggml_backend_device meta_dev = { |
236 | 0 | /*iface =*/ ggml_backend_meta_device_iface, |
237 | 0 | /*reg =*/ nullptr, |
238 | 0 | /*ctx =*/ ctxs.back().get(), |
239 | 0 | }; |
240 | |
|
241 | 0 | auto result = meta_devs.emplace(*ctxs.back(), meta_dev); |
242 | 0 | return &result.first->second; |
243 | 0 | } |
244 | | |
245 | | // |
246 | | // meta backend buffer type |
247 | | // |
248 | | |
249 | | struct ggml_backend_meta_buffer_type_context { |
250 | | std::vector<ggml_backend_buffer_type_t> simple_bufts; |
251 | | |
252 | | std::string name; |
253 | | |
254 | 0 | ggml_backend_meta_buffer_type_context(std::vector<ggml_backend_buffer_type_t> simple_bufts) : simple_bufts(std::move(simple_bufts)) { |
255 | 0 | name = "Meta("; |
256 | 0 | for (size_t i = 0; i < simple_bufts.size(); i++) { |
257 | 0 | if (i > 0) { |
258 | 0 | name += ","; |
259 | 0 | } |
260 | 0 | name += ggml_backend_buft_name(simple_bufts[i]); |
261 | 0 | } |
262 | 0 | name += ")"; |
263 | 0 | } |
264 | | |
265 | 0 | bool operator<(const ggml_backend_meta_buffer_type_context & other) const { |
266 | 0 | return simple_bufts < other.simple_bufts; |
267 | 0 | } |
268 | | }; |
269 | | |
270 | 0 | static size_t ggml_backend_meta_buft_n_bufts(ggml_backend_buffer_type_t meta_buft) { |
271 | 0 | GGML_ASSERT(ggml_backend_buft_is_meta(meta_buft)); |
272 | 0 | const ggml_backend_meta_buffer_type_context * meta_buft_ctx = (const ggml_backend_meta_buffer_type_context *) meta_buft->context; |
273 | 0 | return meta_buft_ctx->simple_bufts.size(); |
274 | 0 | } |
275 | | |
276 | 0 | static const char * ggml_backend_meta_buffer_type_get_name(ggml_backend_buffer_type_t buft) { |
277 | 0 | GGML_ASSERT(ggml_backend_buft_is_meta(buft)); |
278 | 0 | const ggml_backend_meta_buffer_type_context * meta_buft_ctx = (const ggml_backend_meta_buffer_type_context *) buft->context; |
279 | 0 | return meta_buft_ctx->name.c_str(); |
280 | 0 | } |
281 | | |
282 | 0 | static ggml_backend_buffer_type_t ggml_backend_meta_buft_simple_buft(ggml_backend_buffer_type_t meta_buft, size_t index) { |
283 | 0 | GGML_ASSERT(ggml_backend_buft_is_meta(meta_buft)); |
284 | 0 | const ggml_backend_meta_buffer_type_context * meta_buft_ctx = (const ggml_backend_meta_buffer_type_context *) meta_buft->context; |
285 | 0 | GGML_ASSERT(index < meta_buft_ctx->simple_bufts.size()); |
286 | 0 | return meta_buft_ctx->simple_bufts[index]; |
287 | 0 | } |
288 | | |
289 | | static ggml_backend_buffer_t ggml_backend_meta_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size); |
290 | | |
291 | 0 | static size_t ggml_backend_meta_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { |
292 | 0 | const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft); |
293 | 0 | size_t max_alignment = 1; |
294 | 0 | for (size_t i = 0; i < n_simple_bufts; i++) { |
295 | 0 | const size_t alignment = ggml_backend_buft_get_alignment(ggml_backend_meta_buft_simple_buft(buft, i)); |
296 | 0 | max_alignment = std::max(max_alignment, alignment); |
297 | 0 | GGML_ASSERT(max_alignment % alignment == 0); |
298 | 0 | } |
299 | 0 | return max_alignment; |
300 | 0 | } |
301 | | |
302 | 0 | static size_t ggml_backend_meta_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { |
303 | 0 | const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft); |
304 | 0 | size_t max_size = SIZE_MAX; |
305 | 0 | for (size_t i = 0; i < n_simple_bufts; i++) { |
306 | 0 | max_size = std::min(max_size, ggml_backend_buft_get_max_size(ggml_backend_meta_buft_simple_buft(buft, i))); |
307 | 0 | } |
308 | 0 | return max_size; |
309 | 0 | } |
310 | | |
311 | 0 | static size_t ggml_backend_meta_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { |
312 | 0 | const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft); |
313 | 0 | size_t max_alloc_size = 0; |
314 | 0 | for (size_t i = 0; i < n_simple_bufts; i++) { |
315 | 0 | const size_t alloc_size = ggml_backend_buft_get_alloc_size(ggml_backend_meta_buft_simple_buft(buft, i), tensor); |
316 | 0 | max_alloc_size = std::max(max_alloc_size, alloc_size); |
317 | 0 | } |
318 | 0 | return max_alloc_size; |
319 | 0 | } |
320 | | |
321 | 0 | static bool ggml_backend_meta_buffer_type_is_host(ggml_backend_buffer_type_t buft) { |
322 | 0 | const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft); |
323 | 0 | for (size_t i = 0; i < n_simple_bufts; i++) { |
324 | 0 | if (!ggml_backend_buft_is_host(ggml_backend_meta_buft_simple_buft(buft, i))) { |
325 | 0 | return false; |
326 | 0 | } |
327 | 0 | } |
328 | 0 | return true; |
329 | 0 | } |
330 | | |
// Function table shared by every meta buffer type.
// ggml_backend_buft_is_meta() recognizes a meta buffer type by comparing its iface.get_name
// pointer against the one in this table.
static const struct ggml_backend_buffer_type_i ggml_backend_meta_buffer_type_iface = {
    /* .get_name       = */ ggml_backend_meta_buffer_type_get_name,
    /* .alloc_buffer   = */ ggml_backend_meta_buffer_type_alloc_buffer,
    /* .get_alignment  = */ ggml_backend_meta_buffer_type_get_alignment,
    /* .get_max_size   = */ ggml_backend_meta_buffer_type_get_max_size,
    /* .get_alloc_size = */ ggml_backend_meta_buffer_type_get_alloc_size,
    /* .is_host        = */ ggml_backend_meta_buffer_type_is_host,
};
339 | | |
340 | 0 | bool ggml_backend_buft_is_meta(ggml_backend_buffer_type_t buft) { |
341 | 0 | return buft != nullptr && buft->iface.get_name == ggml_backend_meta_buffer_type_iface.get_name; |
342 | 0 | } |
343 | | |
344 | 0 | static ggml_backend_buffer_type_t ggml_backend_meta_device_get_buffer_type(ggml_backend_dev_t dev) { |
345 | 0 | static std::map<ggml_backend_dev_t, struct ggml_backend_buffer_type> meta_bufts; |
346 | 0 | GGML_ASSERT(ggml_backend_dev_is_meta(dev)); |
347 | 0 | { |
348 | 0 | auto it = meta_bufts.find(dev); |
349 | 0 | if (it != meta_bufts.end()) { |
350 | 0 | return &it->second; |
351 | 0 | } |
352 | 0 | } |
353 | | |
354 | 0 | const size_t n_devs = ggml_backend_meta_dev_n_devs(dev); |
355 | 0 | std::vector<ggml_backend_buffer_type_t> simple_bufts; |
356 | 0 | simple_bufts.reserve(n_devs); |
357 | 0 | for (size_t i = 0; i < n_devs; i++) { |
358 | 0 | simple_bufts.push_back(ggml_backend_dev_buffer_type(ggml_backend_meta_dev_simple_dev(dev, i))); |
359 | 0 | } |
360 | 0 | ggml_backend_meta_buffer_type_context * buft_ctx = new ggml_backend_meta_buffer_type_context(simple_bufts); |
361 | |
|
362 | 0 | struct ggml_backend_buffer_type meta_buft = { |
363 | 0 | /*iface =*/ ggml_backend_meta_buffer_type_iface, |
364 | 0 | /*device =*/ dev, |
365 | 0 | /*ctx =*/ buft_ctx, |
366 | 0 | }; |
367 | 0 | auto result = meta_bufts.emplace(dev, meta_buft); |
368 | 0 | return &result.first->second; |
369 | 0 | } |
370 | | |
371 | 0 | static ggml_backend_buffer_type_t ggml_backend_meta_device_get_host_buffer_type(ggml_backend_dev_t dev) { |
372 | 0 | GGML_ASSERT(ggml_backend_dev_is_meta(dev)); |
373 | 0 | const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context; |
374 | |
|
375 | 0 | ggml_backend_buffer_type_t host_buft = nullptr; |
376 | 0 | for (ggml_backend_dev_t simple_dev : meta_dev_ctx->simple_devs) { |
377 | 0 | ggml_backend_buffer_type_t simple_host_buft = ggml_backend_dev_host_buffer_type(simple_dev); |
378 | 0 | if (simple_host_buft == nullptr) { |
379 | 0 | return nullptr; |
380 | 0 | } |
381 | 0 | if (host_buft == nullptr) { |
382 | 0 | host_buft = simple_host_buft; |
383 | 0 | } else if (host_buft != simple_host_buft) { |
384 | | // if different simple devices have different host buffer types, |
385 | | // we cannot provide a single host buffer type for the meta device |
386 | 0 | return nullptr; |
387 | 0 | } |
388 | 0 | } |
389 | 0 | return host_buft; |
390 | 0 | } |
391 | | |
392 | | // |
393 | | // meta backend buffer |
394 | | // |
395 | | |
396 | | // Container to hold the tensor slices per simple ggml backend buffer. |
397 | | struct ggml_backend_meta_simple_tensor_container { |
398 | | std::vector<ggml_context_ptr> ctxs; |
399 | | std::map<const ggml_tensor *, std::vector<ggml_tensor *>> simple_tensors; |
400 | | |
401 | 0 | ggml_backend_meta_simple_tensor_container(const ggml_init_params & params, const int n_simple) { |
402 | 0 | ctxs.reserve(n_simple); |
403 | 0 | for (int i = 0; i < n_simple; i++) { |
404 | 0 | ctxs.emplace_back(ggml_init(params)); |
405 | 0 | } |
406 | 0 | } |
407 | 0 | ggml_backend_meta_simple_tensor_container() {} |
408 | | }; |
409 | | |
// Per-meta-buffer state: the simple buffers backing the meta buffer, the containers that map
// meta tensors to their per-buffer slices, and a cache of computed split states.
struct ggml_backend_meta_buffer_context {
    // FIXME
    // Most tensors can simply be stored statically in their own buffer.
    // Externally created views however also need a mapping to simple tensors but they use the buffer of the view source.
    // If external views are simply using that buffer they will slowly deplete its memory.
    // Current solution: rotating set of 2 "compute" containers to hold external views, works correctly for llama.cpp.
    // Long-term: tie the lifetime of external views to the meta backend executing the graph instead,
    // currently not possible due to graph-external operations in the backend scheduler.
    ggml_backend_meta_simple_tensor_container stc_static;     // slices of tensors allocated in this buffer
    ggml_backend_meta_simple_tensor_container stc_compute[2]; // rotating containers for external views
    int stc_compute_index = 0;      // active compute container, used by get_simple_tensor_container()
    int stc_compute_index_next = 0; // NOTE(review): not read in this block; presumably the next index to rotate to — confirm at its use site
    std::vector<ggml_backend_buffer_ptr> bufs; // one simple buffer per simple device; takes ownership of the handles passed to the constructor

    // FIXME
    // The size of the split state cache is unbounded and can theoretically grow infinitely large.
    // However, it is also expensive to build and clearing it on every rebuild in ggml_backend_meta_graph_compute is too expensive.
    // nbtc = size of a ggml_tensor excluding its trailing padding member.
    // NOTE(review): the char[nbtc] in each cache entry presumably stores a copy of the tensor's bytes
    // to detect stale entries — confirm where split_state_cache is filled.
    static constexpr size_t nbtc = GGML_TENSOR_SIZE - sizeof(ggml_tensor::padding);
    std::map<std::pair<const ggml_tensor *, bool>, std::pair<ggml_backend_meta_split_state, char[nbtc]>> split_state_cache;

    int debug; // value of the GGML_META_DEBUG environment variable (atoi), 0 if unset

    // Moves the contents out of the caller's containers (they are left moved-from) and wraps each
    // raw simple buffer in an owning handle.
    ggml_backend_meta_buffer_context(
            ggml_backend_meta_simple_tensor_container & stc_static,
            ggml_backend_meta_simple_tensor_container & stc_compute_0,
            ggml_backend_meta_simple_tensor_container & stc_compute_1,
            const std::vector<ggml_backend_buffer_t> & bufs)
            : stc_static(std::move(stc_static)), stc_compute{std::move(stc_compute_0), std::move(stc_compute_1)} {
        this->bufs.reserve(bufs.size());
        for (ggml_backend_buffer_t buf : bufs) {
            this->bufs.emplace_back(buf);
        }
        const char * GGML_META_DEBUG = getenv("GGML_META_DEBUG");
        debug = GGML_META_DEBUG ? atoi(GGML_META_DEBUG) : 0;
    }

    // Container holding the slices of `tensor`: the static one if the tensor is registered there,
    // otherwise the currently active compute container.
    ggml_backend_meta_simple_tensor_container & get_simple_tensor_container(const ggml_tensor * tensor) {
        if (stc_static.simple_tensors.find(tensor) != stc_static.simple_tensors.end()) {
            return stc_static;
        }
        return stc_compute[stc_compute_index];
    }
};
453 | | |
454 | 0 | static void ggml_backend_meta_buffer_free_buffer(ggml_backend_buffer_t buffer) { |
455 | 0 | GGML_ASSERT(ggml_backend_buffer_is_meta(buffer)); |
456 | 0 | ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) buffer->context; |
457 | 0 | delete buf_ctx; |
458 | 0 | } |
459 | | |
460 | 0 | static size_t ggml_backend_meta_buffer_n_bufs(ggml_backend_buffer_t meta_buf) { |
461 | 0 | GGML_ASSERT(ggml_backend_buffer_is_meta(meta_buf)); |
462 | 0 | ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) meta_buf->context; |
463 | 0 | return buf_ctx->bufs.size(); |
464 | 0 | } |
465 | | |
466 | 0 | static ggml_backend_buffer_t ggml_backend_meta_buffer_simple_buffer(ggml_backend_buffer_t meta_buf, size_t index) { |
467 | 0 | GGML_ASSERT(ggml_backend_buffer_is_meta(meta_buf)); |
468 | 0 | ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) meta_buf->context; |
469 | 0 | GGML_ASSERT(index < buf_ctx->bufs.size()); |
470 | 0 | return buf_ctx->bufs[index].get(); |
471 | 0 | } |
472 | | |
473 | 0 | static struct ggml_tensor * ggml_backend_meta_buffer_simple_tensor(const struct ggml_tensor * tensor, size_t index) { |
474 | 0 | GGML_ASSERT(ggml_backend_buffer_is_meta(tensor->buffer)); |
475 | 0 | ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context; |
476 | 0 | GGML_ASSERT(index < buf_ctx->bufs.size()); |
477 | |
|
478 | 0 | ggml_backend_meta_simple_tensor_container & stc = buf_ctx->get_simple_tensor_container(tensor); |
479 | 0 | auto it = stc.simple_tensors.find(tensor); |
480 | 0 | if (it == stc.simple_tensors.end()) { |
481 | 0 | return nullptr; |
482 | 0 | } |
483 | 0 | return it->second[index]; |
484 | 0 | } |
485 | | |
486 | | static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struct ggml_tensor * tensor, bool assume_sync); |
487 | | |
488 | | static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( |
489 | 0 | ggml_backend_meta_simple_tensor_container & stc, const struct ggml_tensor * tensor, bool assume_sync) { |
490 | | // FIXME Currently this function preserves/erases the information in n_segments and nr in an inconsistent way. |
491 | | // Since the operations in question are developed specifically for llama.cpp this currently does not manifest as a bug there. |
492 | | // However, in a broader ggml context with arbitrary ggml graphs this can lead to unexpected results. |
493 | 0 | const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(tensor->buffer); |
494 | 0 | ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context; |
495 | |
|
496 | 0 | auto split_states_equal = [&](const ggml_backend_meta_split_state & a, const ggml_backend_meta_split_state & b) -> bool { |
497 | 0 | if (a.axis != b.axis) { |
498 | 0 | return false; |
499 | 0 | } |
500 | 0 | for (size_t j = 0; j < n_bufs; j++) { |
501 | 0 | int64_t sum_a = 0; |
502 | 0 | for (size_t s = 0; s < a.n_segments; s++) { |
503 | 0 | sum_a += a.ne[s*n_bufs + j] * a.nr[s]; |
504 | 0 | } |
505 | 0 | int64_t sum_b = 0; |
506 | 0 | for (size_t s = 0; s < b.n_segments; s++) { |
507 | 0 | sum_b += b.ne[s*n_bufs + j] * b.nr[s]; |
508 | 0 | } |
509 | 0 | if (sum_a != sum_b) { |
510 | 0 | return false; |
511 | 0 | } |
512 | 0 | } |
513 | 0 | return true; |
514 | 0 | }; |
515 | |
|
516 | 0 | auto handle_generic = [&](const std::vector<ggml_backend_meta_split_state> & src_ss, bool scalar_only) -> ggml_backend_meta_split_state { |
517 | 0 | ggml_backend_meta_split_state ret = {GGML_BACKEND_SPLIT_AXIS_NONE, {0}, {1}, 1}; |
518 | 0 | for (size_t i = 0; i < GGML_MAX_SRC; i++) { |
519 | 0 | if (tensor->src[i] == nullptr || tensor->src[i] == tensor) { |
520 | 0 | continue; |
521 | 0 | } |
522 | 0 | if (ret.axis == GGML_BACKEND_SPLIT_AXIS_NONE) { |
523 | 0 | ret = src_ss[i]; |
524 | 0 | } else if (!split_states_equal(src_ss[i], ret)) { |
525 | 0 | ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; |
526 | 0 | break; |
527 | 0 | } |
528 | 0 | } |
529 | 0 | if (ret.axis == GGML_BACKEND_SPLIT_AXIS_NONE) { |
530 | 0 | ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; |
531 | 0 | } |
532 | 0 | if (scalar_only && ret.axis >= 0 && ret.axis < GGML_MAX_DIMS) { |
533 | 0 | ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; |
534 | 0 | } |
535 | 0 | GGML_ASSERT(ret.axis != GGML_BACKEND_SPLIT_AXIS_UNKNOWN); |
536 | 0 | return ret; |
537 | 0 | }; |
538 | | |
539 | | // Some ops process data on a per-row bases: |
540 | 0 | auto handle_per_row = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { |
541 | 0 | GGML_ASSERT(src_ss[0].axis != GGML_BACKEND_SPLIT_AXIS_0); |
542 | 0 | return src_ss[0]; |
543 | 0 | }; |
544 | | |
545 | | // Some ops broadcast the src1 data across src0: |
546 | 0 | auto handle_bin_bcast = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { |
547 | 0 | if (src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS && |
548 | 0 | tensor->src[1]->ne[src_ss[0].axis] == 1 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { |
549 | 0 | return src_ss[0]; |
550 | 0 | } |
551 | 0 | if (src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && (src_ss[0].axis == src_ss[1].axis || |
552 | 0 | (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && (src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL)))) { |
553 | 0 | return src_ss[0]; // GGML_OP_ADD_ID |
554 | 0 | } |
555 | 0 | GGML_ASSERT(tensor->src[2] == nullptr || src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED); |
556 | 0 | return handle_generic(src_ss, /*scalar_only =*/ false); |
557 | 0 | }; |
558 | |
|
559 | 0 | auto handle_concat = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { |
560 | 0 | const ggml_backend_meta_split_axis concat_axis = ggml_backend_meta_split_axis(ggml_get_op_params_i32(tensor, 0)); |
561 | 0 | if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[1].axis >= 0 && src_ss[1].axis < GGML_MAX_DIMS) { |
562 | 0 | GGML_ASSERT(concat_axis != src_ss[1].axis); |
563 | 0 | return src_ss[1]; |
564 | 0 | } |
565 | 0 | if (src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS) { |
566 | 0 | GGML_ASSERT(concat_axis != src_ss[0].axis); |
567 | 0 | return src_ss[0]; |
568 | 0 | } |
569 | 0 | if (src_ss[0].axis == src_ss[1].axis && src_ss[0].axis != concat_axis) { |
570 | 0 | return src_ss[0]; |
571 | 0 | } |
572 | 0 | return handle_generic(src_ss, /*scalar_only =*/ true); |
573 | 0 | }; |
574 | |
|
575 | 0 | auto handle_mul_mat = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { |
576 | 0 | if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { |
577 | 0 | return {GGML_BACKEND_SPLIT_AXIS_MIRRORED, {0}, {1}, 1}; |
578 | 0 | } |
579 | 0 | if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_1 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { |
580 | 0 | ggml_backend_meta_split_state ret = src_ss[0]; |
581 | 0 | ret.axis = GGML_BACKEND_SPLIT_AXIS_0; |
582 | 0 | ret.nr[0] = 1; |
583 | 0 | ret.n_segments = 1; |
584 | 0 | return ret; |
585 | 0 | } |
586 | 0 | if (src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_1 && src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { |
587 | 0 | return src_ss[1]; |
588 | 0 | } |
589 | 0 | if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_0 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_0) { |
590 | 0 | GGML_ASSERT(split_states_equal(src_ss[0], src_ss[1])); |
591 | 0 | return {assume_sync ? GGML_BACKEND_SPLIT_AXIS_MIRRORED : GGML_BACKEND_SPLIT_AXIS_PARTIAL, {0}, {1}, 1}; |
592 | 0 | } |
593 | 0 | GGML_ABORT("fatal error"); |
594 | | //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; |
595 | 0 | }; |
596 | |
|
597 | 0 | auto handle_reshape = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { |
598 | 0 | switch (src_ss[0].axis) { |
599 | 0 | case GGML_BACKEND_SPLIT_AXIS_0: |
600 | 0 | case GGML_BACKEND_SPLIT_AXIS_1: |
601 | 0 | case GGML_BACKEND_SPLIT_AXIS_2: |
602 | 0 | case GGML_BACKEND_SPLIT_AXIS_3: { |
603 | 0 | GGML_ASSERT(src_ss[0].n_segments == 1); |
604 | 0 | if (src_ss[0].axis == ggml_n_dims(tensor->src[0]) - 1 && src_ss[0].nr[0] == 1) { |
605 | 0 | return {ggml_backend_meta_split_axis(ggml_n_dims(tensor) - 1), {0}, {1}, 1}; |
606 | 0 | } |
607 | 0 | int64_t base_ne_in = tensor->src[0]->ne[0]; |
608 | 0 | for (int dim = 1; dim <= src_ss[0].axis; dim++) { |
609 | 0 | base_ne_in *= tensor->src[0]->ne[dim]; |
610 | 0 | } |
611 | 0 | base_ne_in /= src_ss[0].nr[0]; |
612 | 0 | int64_t base_ne_out = 1; |
613 | 0 | for (int dim = 0; dim < GGML_MAX_DIMS; dim++) { |
614 | 0 | const int64_t base_ne_out_next = base_ne_out *= tensor->ne[dim]; |
615 | 0 | if (base_ne_out_next % base_ne_in == 0) { |
616 | 0 | return {ggml_backend_meta_split_axis(dim), {0}, {uint32_t(base_ne_out_next/base_ne_in)}, 1}; |
617 | 0 | } |
618 | 0 | if (base_ne_out_next > base_ne_in) { |
619 | 0 | GGML_ASSERT(src_ss[0].n_segments == 1); |
620 | 0 | GGML_ASSERT(src_ss[0].nr[0] == 1); |
621 | 0 | return {ggml_backend_meta_split_axis(dim), {0}, {1}, 1}; |
622 | 0 | } |
623 | 0 | base_ne_out = base_ne_out_next; |
624 | 0 | } |
625 | 0 | GGML_ABORT("shape mismatch for %s", ggml_op_name(tensor->op)); |
626 | 0 | } |
627 | 0 | case GGML_BACKEND_SPLIT_AXIS_MIRRORED: |
628 | 0 | case GGML_BACKEND_SPLIT_AXIS_PARTIAL: { |
629 | 0 | return src_ss[0]; |
630 | 0 | } |
631 | 0 | default: { |
632 | 0 | GGML_ABORT("fatal error"); |
633 | | //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; |
634 | 0 | } |
635 | 0 | } |
636 | 0 | }; |
637 | |
|
638 | 0 | auto handle_cpy = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { |
639 | 0 | if (src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS) { |
640 | 0 | return handle_reshape(src_ss); |
641 | 0 | } |
642 | 0 | return handle_generic(src_ss, /*scalar_only =*/ false); |
643 | 0 | }; |
644 | |
|
645 | 0 | auto handle_view = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { |
646 | 0 | if (ggml_is_contiguous(tensor) && ggml_is_contiguous(tensor->src[0])) { |
647 | 0 | return handle_reshape(src_ss); |
648 | 0 | } |
649 | 0 | const int axis = src_ss[0].axis; |
650 | 0 | { |
651 | 0 | bool all_strides_the_same = true; |
652 | 0 | for (int dim = 0; dim < GGML_MAX_DIMS; dim++) { |
653 | 0 | if (tensor->ne[dim] == 1 && tensor->src[0]->ne[dim] == 1) { |
654 | 0 | continue; |
655 | 0 | } |
656 | 0 | if (tensor->nb[dim] != tensor->src[0]->nb[dim]) { |
657 | 0 | all_strides_the_same = false; |
658 | 0 | break; |
659 | 0 | } |
660 | 0 | } |
661 | 0 | if (all_strides_the_same) { |
662 | 0 | return src_ss[0]; |
663 | 0 | } |
664 | 0 | } |
665 | 0 | if (!ggml_is_permuted(tensor) && !ggml_is_permuted(tensor->src[0]) && axis >= 0 && axis < GGML_MAX_DIMS-1) { |
666 | 0 | for (int dim = 0; dim < GGML_MAX_DIMS-1; dim++) { |
667 | 0 | if (tensor->nb[dim+1] == tensor->src[0]->nb[axis+1]) { |
668 | 0 | return {ggml_backend_meta_split_axis(dim), {0}, {1}, 1}; |
669 | 0 | } |
670 | 0 | } |
671 | 0 | GGML_ABORT("fatal error"); |
672 | 0 | } |
673 | 0 | if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED || src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL) { |
674 | 0 | return src_ss[0]; |
675 | 0 | } |
676 | 0 | GGML_ABORT("view of permuted tensor not implemented"); |
677 | | //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; |
678 | 0 | }; |
679 | |
|
680 | 0 | auto handle_permute = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { |
681 | 0 | switch (src_ss[0].axis) { |
682 | 0 | case GGML_BACKEND_SPLIT_AXIS_0: |
683 | 0 | case GGML_BACKEND_SPLIT_AXIS_1: |
684 | 0 | case GGML_BACKEND_SPLIT_AXIS_2: |
685 | 0 | case GGML_BACKEND_SPLIT_AXIS_3: { |
686 | 0 | GGML_ASSERT(src_ss[0].n_segments == 1 || src_ss[0].nr[0] == 1); |
687 | 0 | return {ggml_backend_meta_split_axis(tensor->op_params[src_ss[0].axis]), {0}, {src_ss[0].nr[0]}, 1}; |
688 | 0 | } |
689 | 0 | case GGML_BACKEND_SPLIT_AXIS_MIRRORED: |
690 | 0 | case GGML_BACKEND_SPLIT_AXIS_PARTIAL: { |
691 | 0 | return src_ss[0]; |
692 | 0 | } |
693 | 0 | default: { |
694 | 0 | GGML_ABORT("fatal error"); |
695 | | //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; |
696 | 0 | } |
697 | 0 | } |
698 | 0 | }; |
699 | |
|
700 | 0 | auto handle_transpose = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { |
701 | 0 | switch (src_ss[0].axis) { |
702 | 0 | case GGML_BACKEND_SPLIT_AXIS_0: |
703 | 0 | case GGML_BACKEND_SPLIT_AXIS_1: { |
704 | 0 | GGML_ASSERT(src_ss[0].n_segments == 1 || src_ss[0].nr[0] == 1); |
705 | 0 | return {ggml_backend_meta_split_axis(int(src_ss[0].axis) ^ 1), {0}, {src_ss[0].nr[0]}, 1}; |
706 | 0 | } |
707 | 0 | case GGML_BACKEND_SPLIT_AXIS_2: |
708 | 0 | case GGML_BACKEND_SPLIT_AXIS_3: |
709 | 0 | case GGML_BACKEND_SPLIT_AXIS_MIRRORED: |
710 | 0 | case GGML_BACKEND_SPLIT_AXIS_PARTIAL: { |
711 | 0 | return src_ss[0]; |
712 | 0 | } |
713 | 0 | default: { |
714 | 0 | GGML_ABORT("fatal error"); |
715 | | //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; |
716 | 0 | } |
717 | 0 | } |
718 | 0 | }; |
719 | |
|
720 | 0 | auto handle_get_rows = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { |
721 | 0 | if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_0 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { |
722 | 0 | return src_ss[0]; |
723 | 0 | } |
724 | 0 | return handle_generic(src_ss, /*scalar_only =*/ true); |
725 | 0 | }; |
726 | |
|
727 | 0 | auto handle_set_rows = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { |
728 | 0 | GGML_ASSERT(src_ss[0].axis != GGML_BACKEND_SPLIT_AXIS_1); |
729 | 0 | GGML_ASSERT(src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED); |
730 | 0 | GGML_ASSERT(split_states_equal(src_ss[0], src_ss[2])); |
731 | 0 | return src_ss[0]; |
732 | 0 | }; |
733 | |
|
734 | 0 | auto handle_rope = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { |
735 | 0 | GGML_ASSERT(src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED); |
736 | 0 | return src_ss[0]; |
737 | 0 | }; |
738 | |
|
739 | 0 | auto handle_pad = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { |
740 | 0 | if (src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS) { |
741 | 0 | GGML_ASSERT(tensor->op_params[2*src_ss[0].axis + 0] == 0); |
742 | 0 | GGML_ASSERT(tensor->op_params[2*src_ss[0].axis + 1] == 0); |
743 | 0 | } |
744 | 0 | return src_ss[0]; |
745 | 0 | }; |
746 | |
|
747 | 0 | auto handle_flash_attn_ext = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { |
748 | 0 | GGML_ASSERT( src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_2); |
749 | 0 | GGML_ASSERT( src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_2); |
750 | 0 | GGML_ASSERT( src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_2); |
751 | 0 | GGML_ASSERT(tensor->src[4] == nullptr || src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED); |
752 | 0 | GGML_ASSERT(tensor->src[4] == nullptr || src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_0); |
753 | 0 | return {GGML_BACKEND_SPLIT_AXIS_1, {0}, {1}, 1}; |
754 | 0 | }; |
755 | |
|
756 | 0 | auto handle_ssm_conv = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { |
757 | 0 | if (src_ss[0].axis == src_ss[1].axis) { |
758 | 0 | if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_0) { |
759 | 0 | return {GGML_BACKEND_SPLIT_AXIS_1, {0}, {1}, 1}; |
760 | 0 | } |
761 | 0 | if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_1) { |
762 | 0 | return {GGML_BACKEND_SPLIT_AXIS_0, {0}, {1}, 1}; |
763 | 0 | } |
764 | 0 | } |
765 | 0 | return handle_generic(src_ss, /*scalar_only =*/ false); |
766 | 0 | }; |
767 | |
|
768 | 0 | auto handle_gated_delta_net = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { |
769 | 0 | if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && |
770 | 0 | src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && |
771 | 0 | src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { |
772 | 0 | return src_ss[0]; |
773 | 0 | } |
774 | 0 | GGML_ASSERT(src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_1); |
775 | 0 | GGML_ASSERT(src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_1); |
776 | 0 | GGML_ASSERT(src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_1); |
777 | 0 | GGML_ASSERT(src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_1); |
778 | 0 | GGML_ASSERT(src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_1); |
779 | | // state shape is [S_v, S_v, H_v, n_seqs] (s0 only); the heads dim is its own axis 2, |
780 | | // so a head-aligned split on the input cache lands on axis 2 here. |
781 | 0 | GGML_ASSERT(src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_2 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_1 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_0); |
782 | 0 | return {GGML_BACKEND_SPLIT_AXIS_0, {0}, {1}, 1}; |
783 | 0 | }; |
784 | |
|
785 | 0 | auto calculate_split_state = [&]() -> ggml_backend_meta_split_state { |
786 | 0 | if (ggml_nelements(tensor) == 0) { |
787 | 0 | return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; |
788 | 0 | } |
789 | 0 | if (ggml_backend_buffer_get_usage(tensor->buffer) != GGML_BACKEND_BUFFER_USAGE_COMPUTE && tensor->view_src == nullptr) { |
790 | 0 | ggml_backend_dev_t dev = ggml_backend_buft_get_device(ggml_backend_buffer_get_type(tensor->buffer)); |
791 | 0 | const ggml_backend_meta_device_context * dev_ctx = (const ggml_backend_meta_device_context *) dev->context; |
792 | 0 | ggml_backend_meta_split_state ret = dev_ctx->get_split_state(tensor, dev_ctx->get_split_state_ud); |
793 | 0 | if (ret.axis >= 0 && ret.axis <= GGML_MAX_DIMS) { |
794 | 0 | const int64_t granularity = ret.axis == GGML_BACKEND_SPLIT_AXIS_0 ? ggml_blck_size(tensor->type) : 1; |
795 | 0 | int64_t ne_sum = 0; |
796 | 0 | for (size_t s = 0; s < ret.n_segments; s++) { |
797 | 0 | for (size_t j = 0; j < n_bufs; j++) { |
798 | 0 | GGML_ASSERT(ret.ne[s*n_bufs + j] % granularity == 0); |
799 | 0 | ne_sum += ret.ne[s*n_bufs + j] * ret.nr[s]; |
800 | 0 | } |
801 | 0 | } |
802 | 0 | GGML_ASSERT(ne_sum == tensor->ne[ret.axis]); |
803 | 0 | } |
804 | 0 | return ret; |
805 | 0 | } |
806 | | |
807 | 0 | std::vector<ggml_backend_meta_split_state> src_ss(GGML_MAX_SRC, {GGML_BACKEND_SPLIT_AXIS_NONE, {0}, {1}, 1}); |
808 | 0 | for (size_t i = 0; i < GGML_MAX_SRC; i++) { |
809 | 0 | if (tensor->src[i] == nullptr || tensor->src[i] == tensor) { |
810 | 0 | src_ss[i] = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; |
811 | 0 | continue; |
812 | 0 | } |
813 | 0 | src_ss[i] = ggml_backend_meta_get_split_state(stc, tensor->src[i], /*assume_sync =*/ true); |
814 | 0 | GGML_ASSERT(src_ss[i].axis != GGML_BACKEND_SPLIT_AXIS_UNKNOWN); |
815 | 0 | } |
816 | |
|
817 | 0 | ggml_backend_meta_split_state split_state; |
818 | 0 | switch (tensor->op) { |
819 | 0 | case GGML_OP_NONE: { |
820 | 0 | split_state = {GGML_BACKEND_SPLIT_AXIS_MIRRORED, {0}, {1}, 1}; |
821 | 0 | } break; |
822 | 0 | case GGML_OP_DUP: { |
823 | 0 | split_state = handle_generic(src_ss, /*scalar_only =*/ true); |
824 | 0 | } break; |
825 | 0 | case GGML_OP_ADD: |
826 | 0 | case GGML_OP_ADD_ID: { |
827 | 0 | split_state = handle_bin_bcast(src_ss); |
828 | 0 | } break; |
829 | 0 | case GGML_OP_ADD1: |
830 | 0 | case GGML_OP_ACC: { |
831 | 0 | split_state = handle_generic(src_ss, /*scalar_only =*/ true); |
832 | 0 | } break; |
833 | 0 | case GGML_OP_SUB: |
834 | 0 | case GGML_OP_MUL: |
835 | 0 | case GGML_OP_DIV: { |
836 | 0 | split_state = handle_bin_bcast(src_ss); |
837 | 0 | } break; |
838 | 0 | case GGML_OP_SQR: |
839 | 0 | case GGML_OP_SQRT: |
840 | 0 | case GGML_OP_LOG: |
841 | 0 | case GGML_OP_SIN: |
842 | 0 | case GGML_OP_COS: { |
843 | 0 | split_state = handle_generic(src_ss, /*scalar_only =*/ false); |
844 | 0 | } break; |
845 | 0 | case GGML_OP_SUM: { |
846 | 0 | split_state = handle_generic(src_ss, /*scalar_only =*/ true); |
847 | 0 | } break; |
848 | 0 | case GGML_OP_SUM_ROWS: |
849 | 0 | case GGML_OP_CUMSUM: |
850 | 0 | case GGML_OP_MEAN: |
851 | 0 | case GGML_OP_ARGMAX: |
852 | 0 | case GGML_OP_COUNT_EQUAL: { |
853 | 0 | split_state = handle_per_row(src_ss); |
854 | 0 | } break; |
855 | 0 | case GGML_OP_REPEAT: |
856 | 0 | case GGML_OP_REPEAT_BACK: { |
857 | 0 | split_state = handle_generic(src_ss, /*scalar_only =*/ false); |
858 | 0 | } break; |
859 | 0 | case GGML_OP_CONCAT: { |
860 | 0 | split_state = handle_concat(src_ss); |
861 | 0 | } break; |
862 | 0 | case GGML_OP_SILU_BACK: { |
863 | 0 | split_state = handle_generic(src_ss, /*scalar_only =*/ false); |
864 | 0 | } break; |
865 | 0 | case GGML_OP_NORM: |
866 | 0 | case GGML_OP_RMS_NORM: |
867 | 0 | case GGML_OP_RMS_NORM_BACK: |
868 | 0 | case GGML_OP_GROUP_NORM: |
869 | 0 | case GGML_OP_L2_NORM: { |
870 | 0 | split_state = handle_per_row(src_ss); |
871 | 0 | } break; |
872 | 0 | case GGML_OP_MUL_MAT: |
873 | 0 | case GGML_OP_MUL_MAT_ID: { |
874 | 0 | split_state = handle_mul_mat(src_ss); |
875 | 0 | } break; |
876 | 0 | case GGML_OP_OUT_PROD: { |
877 | 0 | split_state = handle_generic(src_ss, /*scalar_only =*/ true); |
878 | 0 | } break; |
879 | 0 | case GGML_OP_SCALE: { |
880 | 0 | split_state = handle_generic(src_ss, /*scalar_only =*/ false); |
881 | 0 | } break; |
882 | 0 | case GGML_OP_SET: { |
883 | 0 | split_state = handle_generic(src_ss, /*scalar_only =*/ true); |
884 | 0 | } break; |
885 | 0 | case GGML_OP_CPY: { |
886 | 0 | split_state = handle_cpy(src_ss); |
887 | 0 | } break; |
888 | 0 | case GGML_OP_CONT: |
889 | 0 | case GGML_OP_RESHAPE: { |
890 | 0 | split_state = handle_reshape(src_ss); |
891 | 0 | } break; |
892 | 0 | case GGML_OP_VIEW: { |
893 | 0 | split_state = handle_view(src_ss); |
894 | 0 | } break; |
895 | 0 | case GGML_OP_PERMUTE: { |
896 | 0 | split_state = handle_permute(src_ss); |
897 | 0 | } break; |
898 | 0 | case GGML_OP_TRANSPOSE: { |
899 | 0 | split_state = handle_transpose(src_ss); |
900 | 0 | } break; |
901 | 0 | case GGML_OP_GET_ROWS: { |
902 | 0 | split_state = handle_get_rows(src_ss); |
903 | 0 | } break; |
904 | 0 | case GGML_OP_GET_ROWS_BACK: { |
905 | 0 | split_state = handle_generic(src_ss, /*scalar_only =*/ true); |
906 | 0 | } break; |
907 | 0 | case GGML_OP_SET_ROWS: { |
908 | 0 | split_state = handle_set_rows(src_ss); |
909 | 0 | } break; |
910 | 0 | case GGML_OP_DIAG: |
911 | 0 | case GGML_OP_DIAG_MASK_INF: |
912 | 0 | case GGML_OP_DIAG_MASK_ZERO: { |
913 | 0 | split_state = handle_generic(src_ss, /*scalar_only =*/ true); |
914 | 0 | } break; |
915 | 0 | case GGML_OP_SOFT_MAX: |
916 | 0 | case GGML_OP_SOFT_MAX_BACK: { |
917 | 0 | split_state = handle_generic(src_ss, /*scalar_only =*/ false); |
918 | 0 | } break; |
919 | 0 | case GGML_OP_ROPE: { |
920 | 0 | split_state = handle_rope(src_ss); |
921 | 0 | } break; |
922 | 0 | case GGML_OP_ROPE_BACK: { |
923 | 0 | split_state = handle_generic(src_ss, /*scalar_only =*/ true); |
924 | 0 | } break; |
925 | 0 | case GGML_OP_CLAMP: { |
926 | 0 | split_state = handle_generic(src_ss, /*scalar_only =*/ false); |
927 | 0 | } break; |
928 | 0 | case GGML_OP_CONV_TRANSPOSE_1D: |
929 | 0 | case GGML_OP_IM2COL: |
930 | 0 | case GGML_OP_IM2COL_BACK: |
931 | 0 | case GGML_OP_IM2COL_3D: |
932 | 0 | case GGML_OP_CONV_2D: |
933 | 0 | case GGML_OP_CONV_3D: |
934 | 0 | case GGML_OP_CONV_2D_DW: |
935 | 0 | case GGML_OP_CONV_TRANSPOSE_2D: |
936 | 0 | case GGML_OP_POOL_1D: |
937 | 0 | case GGML_OP_POOL_2D: |
938 | 0 | case GGML_OP_POOL_2D_BACK: |
939 | 0 | case GGML_OP_UPSCALE: { |
940 | 0 | split_state = handle_generic(src_ss, /*scalar_only =*/ true); |
941 | 0 | } break; |
942 | 0 | case GGML_OP_PAD: { |
943 | 0 | split_state = handle_pad(src_ss); |
944 | 0 | } break; |
945 | 0 | case GGML_OP_PAD_REFLECT_1D: |
946 | 0 | case GGML_OP_ROLL: |
947 | 0 | case GGML_OP_ARANGE: |
948 | 0 | case GGML_OP_TIMESTEP_EMBEDDING: { |
949 | 0 | split_state = handle_generic(src_ss, /*scalar_only =*/ true); |
950 | 0 | } break; |
951 | 0 | case GGML_OP_ARGSORT: |
952 | 0 | case GGML_OP_TOP_K: { |
953 | 0 | split_state = handle_per_row(src_ss); |
954 | 0 | } break; |
955 | 0 | case GGML_OP_LEAKY_RELU: { |
956 | 0 | split_state = handle_generic(src_ss, /*scalar_only =*/ false); |
957 | 0 | } break; |
958 | 0 | case GGML_OP_TRI: { |
959 | 0 | split_state = handle_generic(src_ss, /*scalar_only =*/ true); |
960 | 0 | } break; |
961 | 0 | case GGML_OP_FILL: { |
962 | 0 | split_state = handle_generic(src_ss, /*scalar_only =*/ false); |
963 | 0 | } break; |
964 | 0 | case GGML_OP_FLASH_ATTN_EXT: { |
965 | 0 | split_state = handle_flash_attn_ext(src_ss); |
966 | 0 | } break; |
967 | 0 | case GGML_OP_FLASH_ATTN_BACK: { |
968 | 0 | split_state = handle_generic(src_ss, /*scalar_only =*/ true); |
969 | 0 | } break; |
970 | 0 | case GGML_OP_SSM_CONV: { |
971 | 0 | split_state = handle_ssm_conv(src_ss); |
972 | 0 | } break; |
973 | 0 | case GGML_OP_SSM_SCAN: |
974 | 0 | case GGML_OP_WIN_PART: |
975 | 0 | case GGML_OP_WIN_UNPART: |
976 | 0 | case GGML_OP_GET_REL_POS: |
977 | 0 | case GGML_OP_ADD_REL_POS: |
978 | 0 | case GGML_OP_RWKV_WKV6: |
979 | 0 | case GGML_OP_GATED_LINEAR_ATTN: |
980 | 0 | case GGML_OP_RWKV_WKV7: |
981 | 0 | case GGML_OP_SOLVE_TRI: { |
982 | 0 | split_state = handle_generic(src_ss, /*scalar_only =*/ true); |
983 | 0 | } break; |
984 | 0 | case GGML_OP_GATED_DELTA_NET: { |
985 | 0 | split_state = handle_gated_delta_net(src_ss); |
986 | 0 | } break; |
987 | 0 | case GGML_OP_UNARY: { |
988 | 0 | split_state = handle_generic(src_ss, /*scalar_only =*/ false); |
989 | 0 | } break; |
990 | 0 | case GGML_OP_MAP_CUSTOM1: |
991 | 0 | case GGML_OP_MAP_CUSTOM2: |
992 | 0 | case GGML_OP_MAP_CUSTOM3: |
993 | 0 | case GGML_OP_CUSTOM: { |
994 | 0 | split_state = handle_generic(src_ss, /*scalar_only =*/ true); |
995 | 0 | } break; |
996 | 0 | case GGML_OP_CROSS_ENTROPY_LOSS: |
997 | 0 | case GGML_OP_CROSS_ENTROPY_LOSS_BACK: { |
998 | 0 | split_state = handle_per_row(src_ss); |
999 | 0 | } break; |
1000 | 0 | case GGML_OP_OPT_STEP_ADAMW: |
1001 | 0 | case GGML_OP_OPT_STEP_SGD: |
1002 | 0 | case GGML_OP_GLU: { |
1003 | 0 | split_state = handle_generic(src_ss, /*scalar_only =*/ false); |
1004 | 0 | } break; |
1005 | 0 | default: { |
1006 | 0 | GGML_ABORT("ggml op not implemented: %s", ggml_op_name(tensor->op)); |
1007 | 0 | split_state = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; |
1008 | 0 | } break; |
1009 | 0 | } |
1010 | 0 | if (split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS) { |
1011 | 0 | bool first_src_split_by_axis = true; |
1012 | 0 | const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(tensor->buffer); |
1013 | |
|
1014 | 0 | for (size_t i = 0; i < GGML_MAX_SRC; i++) { |
1015 | 0 | if (tensor->src[i] == nullptr || src_ss[i].axis < 0 || src_ss[i].axis >= GGML_MAX_DIMS) { |
1016 | 0 | continue; |
1017 | 0 | } |
1018 | 0 | if (first_src_split_by_axis) { |
1019 | 0 | for (size_t j = 0; j < n_bufs; j++) { |
1020 | | // Take over ratio from src: |
1021 | 0 | for (size_t s = 0; s < src_ss[i].n_segments; s++) { |
1022 | 0 | split_state.ne[s*n_bufs + j] = 0; |
1023 | 0 | } |
1024 | 0 | for (size_t s = 0; s < src_ss[i].n_segments; s++) { |
1025 | 0 | split_state.ne[j] += src_ss[i].ne[s*n_bufs + j] * src_ss[i].nr[s]; |
1026 | 0 | } |
1027 | 0 | split_state.ne[j] *= tensor->ne[split_state.axis]; |
1028 | 0 | if (split_state.ne[j] != 0 || tensor->src[i]->ne[src_ss[i].axis] != 0) { |
1029 | 0 | const int64_t div = tensor->src[i]->ne[src_ss[i].axis] * split_state.nr[0]; |
1030 | 0 | GGML_ASSERT(split_state.ne[j] % div == 0); |
1031 | 0 | split_state.ne[j] /= div; |
1032 | 0 | } |
1033 | 0 | } |
1034 | 0 | } else { |
1035 | 0 | GGML_ASSERT(split_state.n_segments == 1); |
1036 | 0 | for (size_t j = 0; j < n_bufs; j++) { |
1037 | | // Assert that ratio is consistent: |
1038 | 0 | int64_t sum = 0; |
1039 | 0 | for (size_t s = 0; s < src_ss[i].n_segments; s++) { |
1040 | 0 | sum += src_ss[i].ne[s*n_bufs + j] * src_ss[i].nr[s]; |
1041 | 0 | } |
1042 | 0 | GGML_ASSERT(split_state.ne[j]*split_state.nr[0] * tensor->src[i]->ne[src_ss[i].axis] |
1043 | 0 | == sum * tensor->ne[split_state.axis]); |
1044 | 0 | } |
1045 | 0 | } |
1046 | 0 | first_src_split_by_axis = false; |
1047 | 0 | } |
1048 | 0 | GGML_ASSERT(!first_src_split_by_axis); |
1049 | 0 | } |
1050 | 0 | return split_state; |
1051 | 0 | }; |
1052 | |
|
1053 | 0 | const std::pair key = std::make_pair(tensor, assume_sync); |
1054 | 0 | auto it = buf_ctx->split_state_cache.find(key); |
1055 | 0 | if (it != buf_ctx->split_state_cache.end() && memcmp(it->second.second, (const char *) tensor, sizeof(it->second.second)) != 0) { |
1056 | 0 | buf_ctx->split_state_cache.clear(); |
1057 | 0 | it = buf_ctx->split_state_cache.end(); |
1058 | 0 | } |
1059 | |
|
1060 | 0 | if (it == buf_ctx->split_state_cache.end()) { |
1061 | 0 | buf_ctx->split_state_cache[key].first = calculate_split_state(); |
1062 | 0 | memcpy(buf_ctx->split_state_cache[key].second, tensor, sizeof(buf_ctx->split_state_cache[key].second)); |
1063 | 0 | if (buf_ctx->debug > 0) { |
1064 | 0 | std::string srcs_info; |
1065 | 0 | for (size_t i = 0; i < GGML_MAX_SRC; i++) { |
1066 | 0 | if (tensor->src[i] == nullptr) { |
1067 | 0 | continue; |
1068 | 0 | } |
1069 | 0 | if (!srcs_info.empty()) { |
1070 | 0 | srcs_info += ", "; |
1071 | 0 | } |
1072 | 0 | const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor->src[0], true); |
1073 | 0 | GGML_ASSERT(split_state.n_segments == 1); |
1074 | 0 | const char * axis_name = ggml_backend_meta_split_axis_name(split_state.axis); |
1075 | 0 | std::string ne_info; |
1076 | 0 | for (size_t j = 0; j < n_bufs; j++) { |
1077 | 0 | if (!ne_info.empty()) { |
1078 | 0 | ne_info += ", "; |
1079 | 0 | } |
1080 | 0 | ne_info += std::to_string(split_state.ne[j]) + "x" + std::to_string(split_state.nr[0]); |
1081 | 0 | } |
1082 | 0 | srcs_info += std::string(tensor->src[i]->name) + "[" + ggml_op_name(tensor->src[i]->op) + ", " + axis_name + ", {" + ne_info + "}]"; |
1083 | 0 | } |
1084 | 0 | std::string ne_info; |
1085 | 0 | for (size_t j = 0; j < n_bufs; j++) { |
1086 | 0 | if (!ne_info.empty()) { |
1087 | 0 | ne_info += ", "; |
1088 | 0 | } |
1089 | 0 | const ggml_backend_meta_split_state & ss = buf_ctx->split_state_cache[key].first; |
1090 | 0 | ne_info += std::to_string(ss.ne[j]) + "x" + std::to_string(ss.nr[0]); |
1091 | 0 | } |
1092 | 0 | GGML_LOG_DEBUG("SPLIT_STATE: {%s} -> %s[%s, %s, {%s}]\n", srcs_info.c_str(), tensor->name, ggml_op_name(tensor->op), |
1093 | 0 | ggml_backend_meta_split_axis_name(buf_ctx->split_state_cache[key].first.axis), ne_info.c_str()); |
1094 | 0 | } |
1095 | 0 | } |
1096 | |
|
1097 | 0 | ggml_backend_meta_split_state ret = buf_ctx->split_state_cache[key].first; |
1098 | 0 | GGML_ASSERT(ret.axis != GGML_BACKEND_SPLIT_AXIS_NONE); |
1099 | | #ifndef NDEBUG |
1100 | | if (ret.axis >= 0 && ret.axis < GGML_MAX_DIMS) { |
1101 | | int64_t ne_ret = 0; |
1102 | | for (size_t s = 0; s < ret.n_segments; s++) { |
1103 | | for (size_t j = 0; j < n_bufs; j++) { |
1104 | | ne_ret += ret.ne[s*n_bufs + j] * ret.nr[s]; |
1105 | | } |
1106 | | } |
1107 | | assert(ne_ret == tensor->ne[int(ret.axis)]); |
1108 | | } |
1109 | | #endif // NDEBUG |
1110 | 0 | return ret; |
1111 | 0 | } |
1112 | | |
1113 | 0 | static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struct ggml_tensor * tensor, bool assume_sync) { |
1114 | 0 | GGML_ASSERT(ggml_backend_buffer_is_meta(tensor->buffer)); |
1115 | 0 | ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context; |
1116 | 0 | return ggml_backend_meta_get_split_state(buf_ctx->get_simple_tensor_container(tensor), tensor, assume_sync); |
1117 | 0 | } |
1118 | | |
1119 | 0 | static void * ggml_backend_meta_buffer_get_base(ggml_backend_buffer_t buffer) { |
1120 | 0 | GGML_UNUSED(buffer); |
1121 | 0 | return (void *) 0x1000000000000000; // FIXME |
1122 | 0 | } |
1123 | | |
1124 | 0 | static enum ggml_status ggml_backend_meta_buffer_init_tensor_impl(ggml_backend_meta_simple_tensor_container & stc, ggml_tensor * tensor) { |
1125 | 0 | GGML_ASSERT(ggml_backend_buffer_is_meta(tensor->buffer)); |
1126 | 0 | ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context; |
1127 | 0 | const size_t n_simple_bufs = ggml_backend_meta_buffer_n_bufs(tensor->buffer); |
1128 | |
|
1129 | 0 | const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(stc, tensor, /*assume_sync =*/ true); |
1130 | 0 | GGML_ASSERT(ggml_nelements(tensor) == 0 || split_state.axis != GGML_BACKEND_SPLIT_AXIS_UNKNOWN); |
1131 | 0 | GGML_ASSERT(split_state.n_segments <= 16); |
1132 | |
|
1133 | 0 | int split_dim = split_state.axis; |
1134 | 0 | int64_t ne[GGML_MAX_DIMS]; |
1135 | 0 | size_t nb[GGML_MAX_DIMS]; |
1136 | 0 | for (size_t k = 0; k < GGML_MAX_DIMS; k++) { |
1137 | 0 | ne[k] = tensor->ne[k]; |
1138 | 0 | nb[k] = tensor->nb[k]; |
1139 | 0 | } |
1140 | |
|
1141 | 0 | std::vector<ggml_tensor *> simple_tensors; |
1142 | 0 | simple_tensors.reserve(n_simple_bufs); |
1143 | 0 | for (size_t j = 0; j < n_simple_bufs; j++) { |
1144 | 0 | ggml_context * simple_ctx = stc.ctxs[j].get(); |
1145 | 0 | ggml_backend_buffer_t simple_buf = buf_ctx->bufs[j].get(); |
1146 | |
|
1147 | 0 | if (split_dim >= 0 && split_dim < GGML_MAX_DIMS) { |
1148 | | // TODO: the following assert fails for llama-parallel even though the results are correct: |
1149 | | // GGML_ASSERT(ggml_is_contiguously_allocated(tensor)); |
1150 | 0 | ne[split_dim] = 0; |
1151 | 0 | for (size_t s = 0; s < split_state.n_segments; s++) { |
1152 | 0 | ne[split_dim] += split_state.ne[s*n_simple_bufs + j] * split_state.nr[s]; |
1153 | 0 | } |
1154 | 0 | for (int i = 0; i < GGML_MAX_DIMS; i++) { |
1155 | 0 | if (tensor->nb[i] > tensor->nb[split_dim]) { |
1156 | 0 | nb[i] = tensor->nb[i] * ne[split_dim]/tensor->ne[split_dim]; |
1157 | 0 | } |
1158 | 0 | } |
1159 | 0 | } |
1160 | |
|
1161 | 0 | ggml_tensor * t_ij = ggml_new_tensor(simple_ctx, tensor->type, GGML_MAX_DIMS, ne); |
1162 | 0 | t_ij->op = tensor->op; |
1163 | 0 | for (int i = 0; i < GGML_MAX_DIMS; i++) { |
1164 | 0 | t_ij->nb[i] = nb[i]; |
1165 | 0 | } |
1166 | 0 | t_ij->flags = tensor->flags; |
1167 | 0 | memcpy(t_ij->op_params, tensor->op_params, sizeof(tensor->op_params)); |
1168 | 0 | ggml_set_name(t_ij, tensor->name); |
1169 | 0 | t_ij->buffer = simple_buf; |
1170 | 0 | t_ij->view_src = tensor->view_src; |
1171 | 0 | t_ij->view_offs = tensor->view_offs; |
1172 | 0 | if (t_ij->view_src != nullptr && ggml_backend_buffer_is_meta(t_ij->view_src->buffer)) { |
1173 | 0 | t_ij->view_src = ggml_backend_meta_buffer_simple_tensor(tensor->view_src, j); |
1174 | 0 | if (t_ij->view_offs > 0 && split_dim >= 0 && split_dim < GGML_MAX_DIMS) { |
1175 | 0 | GGML_ASSERT(tensor->ne[split_dim] != 0); |
1176 | 0 | const int split_dim_view_src = ggml_backend_meta_get_split_state(tensor->view_src, /*assume_sync =*/ true).axis; |
1177 | 0 | GGML_ASSERT(split_dim_view_src >= 0 && split_dim_view_src < GGML_MAX_DIMS); |
1178 | | |
1179 | | // The offset can be internal to the data split, in those cases the view offset should not be scaled. |
1180 | | // If however, the offset is larger than the data split then it needs to be scaled proportionally. |
1181 | 0 | bool split_internal_offset = t_ij->view_offs <= tensor->view_src->nb[split_dim_view_src]; |
1182 | 0 | for (int i = 0; i < GGML_MAX_DIMS; i++) { |
1183 | 0 | const size_t dim_size = tensor->ne[i] * tensor->nb[i]; |
1184 | 0 | if (tensor->view_offs <= dim_size && dim_size < tensor->nb[split_dim]) { |
1185 | 0 | split_internal_offset = true; |
1186 | 0 | break; |
1187 | 0 | } |
1188 | 0 | } |
1189 | 0 | if (!split_internal_offset) { |
1190 | 0 | t_ij->view_offs = t_ij->view_offs * ne[split_dim]/tensor->ne[split_dim]; |
1191 | 0 | } |
1192 | 0 | } |
1193 | 0 | } |
1194 | 0 | if (t_ij->view_src != nullptr) { |
1195 | 0 | t_ij->data = (char *) t_ij->view_src->data + t_ij->view_offs; |
1196 | 0 | } else if (simple_buf != nullptr) { |
1197 | 0 | t_ij->data = (char *) ggml_backend_buffer_get_base(simple_buf) |
1198 | 0 | + size_t(tensor->data) - size_t(ggml_backend_buffer_get_base(tensor->buffer)); |
1199 | 0 | } |
1200 | 0 | t_ij->extra = tensor->extra; |
1201 | 0 | for (int i = 0; i < GGML_MAX_SRC; i++) { |
1202 | 0 | t_ij->src[i] = tensor->src[i]; |
1203 | 0 | if (tensor->src[i] == tensor) { |
1204 | 0 | t_ij->src[i] = t_ij; |
1205 | 0 | } else if (t_ij->src[i] != nullptr && ggml_backend_buffer_is_meta(t_ij->src[i]->buffer)) { |
1206 | 0 | t_ij->src[i] = ggml_backend_meta_buffer_simple_tensor(tensor->src[i], j); |
1207 | 0 | } |
1208 | 0 | } |
1209 | |
|
1210 | 0 | simple_tensors.push_back(t_ij); |
1211 | 0 | } |
1212 | | |
1213 | | // If one of the sources has a zero-sized slice, disable the computation: |
1214 | 0 | for (int i = 0; i < GGML_MAX_SRC; i++) { |
1215 | 0 | if (tensor->src[i] == nullptr || !ggml_backend_buffer_is_meta(tensor->src[i]->buffer)) { |
1216 | 0 | continue; |
1217 | 0 | } |
1218 | | |
1219 | 0 | const ggml_backend_meta_split_state split_state_src = ggml_backend_meta_get_split_state(tensor->src[i], /*assume_sync =*/ true); |
1220 | 0 | if (split_state_src.axis < 0 || split_state_src.axis >= GGML_MAX_DIMS) { |
1221 | 0 | continue; |
1222 | 0 | } |
1223 | 0 | for (size_t j = 0; j < n_simple_bufs; j++) { |
1224 | 0 | int64_t ne_sum = 0; |
1225 | 0 | for (size_t s = 0; s < split_state_src.n_segments; s++) { |
1226 | 0 | ne_sum += split_state_src.ne[s*n_simple_bufs + j] * split_state_src.nr[s]; |
1227 | 0 | } |
1228 | 0 | if (ne_sum == 0) { |
1229 | 0 | simple_tensors[j]->flags &= ~GGML_TENSOR_FLAG_COMPUTE; |
1230 | 0 | } |
1231 | 0 | } |
1232 | 0 | } |
1233 | |
|
1234 | 0 | stc.simple_tensors[tensor] = simple_tensors; |
1235 | |
|
1236 | 0 | return GGML_STATUS_SUCCESS; |
1237 | 0 | } |
1238 | | |
1239 | 0 | static enum ggml_status ggml_backend_meta_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { |
1240 | 0 | GGML_ASSERT(ggml_backend_buffer_is_meta(buffer)); |
1241 | 0 | ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) buffer->context; |
1242 | 0 | buf_ctx->stc_compute_index = buf_ctx->stc_compute_index_next; |
1243 | 0 | return ggml_backend_meta_buffer_init_tensor_impl(buf_ctx->get_simple_tensor_container(tensor), tensor); |
1244 | 0 | } |
1245 | | |
1246 | 0 | static void ggml_backend_meta_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { |
1247 | 0 | const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(buffer); |
1248 | 0 | GGML_ASSERT(ggml_is_contiguous(tensor)); |
1249 | |
|
1250 | 0 | const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false); |
1251 | |
|
1252 | 0 | if (split_state.n_segments != 1 || split_state.nr[0] != 1) { |
1253 | 0 | GGML_ASSERT(split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS); |
1254 | 0 | GGML_ASSERT(split_state.nr[0] != 0); |
1255 | 0 | GGML_ASSERT(tensor->ne[3] == 1); |
1256 | |
|
1257 | 0 | size_t offset_data = 0; |
1258 | 0 | std::vector<size_t> simple_offsets(n_bufs, 0); |
1259 | 0 | if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_0) { |
1260 | 0 | GGML_ASSERT(tensor->ne[2] == 1); |
1261 | |
|
1262 | 0 | const size_t row_stride = tensor->nb[1]; |
1263 | 0 | GGML_ASSERT(offset % row_stride == 0); |
1264 | 0 | GGML_ASSERT(size % row_stride == 0); |
1265 | 0 | const int64_t row_start = offset / row_stride; |
1266 | 0 | const int64_t row_count = size / row_stride; |
1267 | 0 | GGML_ASSERT(row_start + row_count <= tensor->ne[1]); |
1268 | |
|
1269 | 0 | const int64_t blck_size = ggml_blck_size(tensor->type); |
1270 | 0 | for (size_t s = 0; s < split_state.n_segments; s++) { |
1271 | 0 | for (size_t r = 0; r < split_state.nr[s]; r++) { |
1272 | 0 | for (size_t j = 0; j < n_bufs; j++) { |
1273 | 0 | ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); |
1274 | 0 | GGML_ASSERT(split_state.ne[s*n_bufs + j] % blck_size == 0); |
1275 | 0 | const size_t nbytes = split_state.ne[s*n_bufs + j]/blck_size * tensor->nb[0]; |
1276 | 0 | ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data, |
1277 | 0 | simple_offsets[j] + row_start * simple_tensor->nb[1], nbytes, |
1278 | 0 | row_count, simple_tensor->nb[1], tensor->nb[1]); |
1279 | 0 | offset_data += nbytes; |
1280 | 0 | simple_offsets[j] += nbytes; |
1281 | 0 | } |
1282 | 0 | } |
1283 | 0 | } |
1284 | 0 | GGML_ASSERT(offset_data*row_count == size); |
1285 | 0 | return; |
1286 | 0 | } |
1287 | 0 | GGML_ASSERT(split_state.axis == GGML_BACKEND_SPLIT_AXIS_1); |
1288 | |
|
1289 | 0 | const size_t row_stride = tensor->nb[2]; |
1290 | 0 | GGML_ASSERT(offset % row_stride == 0); |
1291 | 0 | GGML_ASSERT(size % row_stride == 0); |
1292 | 0 | const int64_t row_start = offset / row_stride; |
1293 | 0 | const int64_t row_count = size / row_stride; |
1294 | 0 | GGML_ASSERT(row_start + row_count <= tensor->ne[2]); |
1295 | |
|
1296 | 0 | for (size_t s = 0; s < split_state.n_segments; s++) { |
1297 | 0 | for (size_t r = 0; r < split_state.nr[s]; r++) { |
1298 | 0 | for (size_t j = 0; j < n_bufs; j++) { |
1299 | 0 | ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); |
1300 | 0 | const size_t nbytes = split_state.ne[s*n_bufs + j] * tensor->nb[1]; |
1301 | 0 | ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data, |
1302 | 0 | simple_offsets[j] + row_start * simple_tensor->nb[2], nbytes, |
1303 | 0 | row_count, simple_tensor->nb[2], tensor->nb[2]); |
1304 | 0 | offset_data += nbytes; |
1305 | 0 | simple_offsets[j] += nbytes; |
1306 | 0 | } |
1307 | 0 | } |
1308 | 0 | } |
1309 | 0 | GGML_ASSERT(offset_data*row_count == size); |
1310 | 0 | return; |
1311 | 0 | } |
1312 | | |
1313 | 0 | switch (split_state.axis) { |
1314 | 0 | case GGML_BACKEND_SPLIT_AXIS_0: |
1315 | 0 | case GGML_BACKEND_SPLIT_AXIS_1: |
1316 | 0 | case GGML_BACKEND_SPLIT_AXIS_2: { |
1317 | | // Exploit that tensors are contiguous to splice it with simple tensors as "chunks". |
1318 | 0 | const size_t chunk_size_full = tensor->nb[split_state.axis + 1]; |
1319 | 0 | GGML_ASSERT(offset % chunk_size_full == 0); |
1320 | 0 | GGML_ASSERT(size % chunk_size_full == 0); |
1321 | 0 | const int64_t i_start = offset /chunk_size_full; |
1322 | 0 | const int64_t i_stop = (offset + size)/chunk_size_full; |
1323 | 0 | size_t offset_j = 0; |
1324 | 0 | for (size_t j = 0; j < n_bufs; j++) { |
1325 | 0 | ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); |
1326 | 0 | const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1]; |
1327 | 0 | if (chunk_size_j == 0) { |
1328 | 0 | continue; |
1329 | 0 | } |
1330 | 0 | const size_t simple_offset = i_start * chunk_size_j; |
1331 | 0 | ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_j, simple_offset, chunk_size_j, i_stop - i_start, chunk_size_j, chunk_size_full); |
1332 | 0 | offset_j += chunk_size_j; |
1333 | 0 | } |
1334 | 0 | GGML_ASSERT(offset_j == chunk_size_full); |
1335 | 0 | } break; |
1336 | 0 | case GGML_BACKEND_SPLIT_AXIS_MIRRORED: { |
1337 | 0 | for (size_t j = 0; j < n_bufs; j++) { |
1338 | 0 | ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); |
1339 | 0 | ggml_backend_tensor_set(simple_tensor, data, offset, size); |
1340 | 0 | } |
1341 | 0 | } break; |
1342 | 0 | case GGML_BACKEND_SPLIT_AXIS_PARTIAL: { |
1343 | 0 | GGML_ASSERT(tensor->type == GGML_TYPE_F32); |
1344 | 0 | const int64_t ne = ggml_nelements(tensor); |
1345 | 0 | std::vector<float> tmp; |
1346 | 0 | tmp.reserve(ne); |
1347 | 0 | for (int64_t i = 0; i < ne; i++) { |
1348 | 0 | tmp.push_back(((const float *) data)[i] / n_bufs); |
1349 | 0 | } |
1350 | 0 | for (size_t j = 0; j < n_bufs; j++) { |
1351 | 0 | ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); |
1352 | 0 | ggml_backend_tensor_set(simple_tensor, tmp.data(), offset, size); |
1353 | 0 | } |
1354 | 0 | } break; |
1355 | 0 | default: { |
1356 | 0 | GGML_ABORT("fatal error"); |
1357 | 0 | } |
1358 | 0 | } |
1359 | 0 | } |
1360 | | |
1361 | 0 | static void ggml_backend_meta_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { |
1362 | 0 | const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(buffer); |
1363 | 0 | GGML_ASSERT(ggml_is_contiguous(tensor)); |
1364 | |
|
1365 | 0 | const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false); |
1366 | |
|
1367 | 0 | if (split_state.n_segments != 1 || split_state.nr[0] != 1) { |
1368 | 0 | GGML_ASSERT(split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS); |
1369 | 0 | GGML_ASSERT(split_state.nr[0] != 0); |
1370 | 0 | GGML_ASSERT(tensor->ne[3] == 1); |
1371 | |
|
1372 | 0 | size_t offset_data = 0; |
1373 | 0 | std::vector<size_t> simple_offsets(n_bufs, 0); |
1374 | 0 | if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_0) { |
1375 | 0 | GGML_ASSERT(tensor->ne[2] == 1); |
1376 | |
|
1377 | 0 | const size_t row_stride = tensor->nb[1]; |
1378 | 0 | GGML_ASSERT(offset % row_stride == 0); |
1379 | 0 | GGML_ASSERT(size % row_stride == 0); |
1380 | 0 | const int64_t row_start = offset / row_stride; |
1381 | 0 | const int64_t row_count = size / row_stride; |
1382 | 0 | GGML_ASSERT(row_start + row_count <= tensor->ne[1]); |
1383 | |
|
1384 | 0 | const int64_t blck_size = ggml_blck_size(tensor->type); |
1385 | 0 | for (size_t s = 0; s < split_state.n_segments; s++) { |
1386 | 0 | for (size_t r = 0; r < split_state.nr[s]; r++) { |
1387 | 0 | for (size_t j = 0; j < n_bufs; j++) { |
1388 | 0 | const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); |
1389 | 0 | GGML_ASSERT(split_state.ne[s*n_bufs + j] % blck_size == 0); |
1390 | 0 | const size_t nbytes = split_state.ne[s*n_bufs + j]/blck_size * tensor->nb[0]; |
1391 | 0 | ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data, |
1392 | 0 | simple_offsets[j] + row_start * simple_tensor->nb[1], nbytes, |
1393 | 0 | row_count, simple_tensor->nb[1], tensor->nb[1]); |
1394 | 0 | offset_data += nbytes; |
1395 | 0 | simple_offsets[j] += nbytes; |
1396 | 0 | } |
1397 | 0 | } |
1398 | 0 | } |
1399 | 0 | GGML_ASSERT(offset_data*row_count == size); |
1400 | 0 | return; |
1401 | 0 | } |
1402 | 0 | GGML_ASSERT(split_state.axis == GGML_BACKEND_SPLIT_AXIS_1); |
1403 | |
|
1404 | 0 | const size_t row_stride = tensor->nb[2]; |
1405 | 0 | GGML_ASSERT(offset % row_stride == 0); |
1406 | 0 | GGML_ASSERT(size % row_stride == 0); |
1407 | 0 | const int64_t row_start = offset / row_stride; |
1408 | 0 | const int64_t row_count = size / row_stride; |
1409 | 0 | GGML_ASSERT(row_start + row_count <= tensor->ne[2]); |
1410 | |
|
1411 | 0 | for (size_t s = 0; s < split_state.n_segments; s++) { |
1412 | 0 | for (size_t r = 0; r < split_state.nr[s]; r++) { |
1413 | 0 | for (size_t j = 0; j < n_bufs; j++) { |
1414 | 0 | const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); |
1415 | 0 | const size_t nbytes = split_state.ne[s*n_bufs + j] * tensor->nb[1]; |
1416 | 0 | ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data, |
1417 | 0 | simple_offsets[j] + row_start * simple_tensor->nb[2], nbytes, |
1418 | 0 | row_count, simple_tensor->nb[2], tensor->nb[2]); |
1419 | 0 | offset_data += nbytes; |
1420 | 0 | simple_offsets[j] += nbytes; |
1421 | 0 | } |
1422 | 0 | } |
1423 | 0 | } |
1424 | 0 | GGML_ASSERT(offset_data*row_count == size); |
1425 | 0 | return; |
1426 | 0 | } |
1427 | | |
1428 | 0 | switch (split_state.axis) { |
1429 | 0 | case GGML_BACKEND_SPLIT_AXIS_0: |
1430 | 0 | case GGML_BACKEND_SPLIT_AXIS_1: |
1431 | 0 | case GGML_BACKEND_SPLIT_AXIS_2: { |
1432 | | // Exploit that tensors are contiguous to splice it with simple tensors as "chunks". |
1433 | 0 | const size_t chunk_size_full = tensor->nb[split_state.axis + 1]; |
1434 | 0 | GGML_ASSERT(offset % chunk_size_full == 0); |
1435 | 0 | GGML_ASSERT(size % chunk_size_full == 0); |
1436 | 0 | const int64_t i_start = offset /chunk_size_full; |
1437 | 0 | const int64_t i_stop = (offset + size)/chunk_size_full; |
1438 | 0 | size_t offset_j = 0; |
1439 | 0 | for (size_t j = 0; j < n_bufs; j++){ |
1440 | 0 | const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); |
1441 | 0 | const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1]; |
1442 | 0 | if (chunk_size_j == 0) { |
1443 | 0 | continue; |
1444 | 0 | } |
1445 | 0 | const size_t simple_offset = i_start * chunk_size_j; |
1446 | 0 | ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_j, simple_offset, chunk_size_j, i_stop - i_start, chunk_size_j, chunk_size_full); |
1447 | 0 | offset_j += chunk_size_j; |
1448 | 0 | } |
1449 | 0 | GGML_ASSERT(offset_j == chunk_size_full); |
1450 | 0 | } break; |
1451 | 0 | case GGML_BACKEND_SPLIT_AXIS_MIRRORED: { |
1452 | | // TODO other simple backend may be better |
1453 | 0 | const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, 0); |
1454 | 0 | ggml_backend_tensor_get(simple_tensor, data, offset, size); |
1455 | 0 | } break; |
1456 | 0 | default: { |
1457 | 0 | GGML_ABORT("fatal error"); |
1458 | 0 | } |
1459 | 0 | } |
1460 | 0 | } |
1461 | | |
1462 | 0 | static void ggml_backend_meta_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { |
1463 | 0 | const size_t n_buffers = ggml_backend_meta_buffer_n_bufs(buffer); |
1464 | 0 | for (size_t i = 0; i < n_buffers; i++) { |
1465 | 0 | ggml_backend_buffer_clear(ggml_backend_meta_buffer_simple_buffer(buffer, i), value); |
1466 | 0 | } |
1467 | 0 | } |
1468 | | |
1469 | 0 | static void ggml_backend_meta_buffer_reset(ggml_backend_buffer_t buffer) { |
1470 | 0 | GGML_ASSERT(ggml_backend_buffer_is_meta(buffer)); |
1471 | 0 | ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) buffer->context; |
1472 | 0 | for (size_t i = 0; i < buf_ctx->bufs.size(); i++) { |
1473 | 0 | ggml_backend_buffer_reset(ggml_backend_meta_buffer_simple_buffer(buffer, i)); |
1474 | 0 | } |
1475 | 0 | } |
1476 | | |
1477 | | static const ggml_backend_buffer_i ggml_backend_meta_buffer_iface = { |
1478 | | /* .free_buffer = */ ggml_backend_meta_buffer_free_buffer, |
1479 | | /* .get_base = */ ggml_backend_meta_buffer_get_base, |
1480 | | /* .init_tensor = */ ggml_backend_meta_buffer_init_tensor, |
1481 | | /* .memset_tensor = */ nullptr, // TODO implement |
1482 | | /* .set_tensor = */ ggml_backend_meta_buffer_set_tensor, |
1483 | | /* .get_tensor = */ ggml_backend_meta_buffer_get_tensor, |
1484 | | /* .set_tensor_2d = */ nullptr, |
1485 | | /* .get_tensor_2d = */ nullptr, |
1486 | | /* .cpy_tensor = */ nullptr, |
1487 | | /* .clear = */ ggml_backend_meta_buffer_clear, |
1488 | | /* .reset = */ ggml_backend_meta_buffer_reset, |
1489 | | }; |
1490 | | |
1491 | 0 | bool ggml_backend_buffer_is_meta(ggml_backend_buffer_t buf) { |
1492 | 0 | return buf != nullptr && buf->iface.free_buffer == ggml_backend_meta_buffer_iface.free_buffer; |
1493 | 0 | } |
1494 | | |
1495 | 0 | static ggml_backend_buffer_t ggml_backend_meta_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { |
1496 | 0 | const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft); |
1497 | |
|
1498 | 0 | const ggml_init_params params = { |
1499 | 0 | /*.mem_size =*/ 1024*1024*ggml_tensor_overhead(), // FIXME |
1500 | 0 | /*.mem_buffer =*/ nullptr, |
1501 | 0 | /*.no_alloc =*/ true, |
1502 | 0 | }; |
1503 | 0 | ggml_backend_meta_simple_tensor_container stc_static; |
1504 | 0 | ggml_backend_meta_simple_tensor_container stc_compute_0(params, n_simple_bufts); |
1505 | 0 | ggml_backend_meta_simple_tensor_container stc_compute_1(params, n_simple_bufts); |
1506 | |
|
1507 | 0 | size_t max_size = 0; |
1508 | 0 | std::vector<ggml_backend_buffer_t> bufs; |
1509 | 0 | bufs.reserve(n_simple_bufts); |
1510 | 0 | for (size_t i = 0; i < n_simple_bufts; i++) { |
1511 | 0 | bufs.push_back(ggml_backend_buft_alloc_buffer(ggml_backend_meta_buft_simple_buft(buft, i), size)); |
1512 | 0 | GGML_ASSERT(bufs.back() != nullptr); |
1513 | 0 | max_size = std::max(max_size, ggml_backend_buffer_get_size(bufs.back())); |
1514 | 0 | } |
1515 | 0 | ggml_backend_meta_buffer_context * buf_ctx = new ggml_backend_meta_buffer_context(stc_static, stc_compute_0, stc_compute_1, bufs); |
1516 | |
|
1517 | 0 | return ggml_backend_buffer_init(buft, ggml_backend_meta_buffer_iface, buf_ctx, max_size); |
1518 | 0 | } |
1519 | | |
1520 | 0 | struct ggml_backend_buffer * ggml_backend_meta_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft) { |
1521 | 0 | const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft); |
1522 | |
|
1523 | 0 | constexpr size_t compute_headroom = 16; // Maximum number of views per statically allocated tensor that can be created between evals. |
1524 | 0 | const ggml_init_params params_static = { |
1525 | 0 | /*.mem_size =*/ ggml_get_mem_size(ctx), |
1526 | 0 | /*.mem_buffer =*/ nullptr, |
1527 | 0 | /*.no_alloc =*/ true, |
1528 | 0 | }; |
1529 | 0 | const ggml_init_params params_compute = { |
1530 | 0 | /*.mem_size =*/ compute_headroom*ggml_get_mem_size(ctx), |
1531 | 0 | /*.mem_buffer =*/ nullptr, |
1532 | 0 | /*.no_alloc =*/ true, |
1533 | 0 | }; |
1534 | 0 | ggml_backend_meta_simple_tensor_container stc_static (params_static, n_simple_bufts); |
1535 | 0 | ggml_backend_meta_simple_tensor_container stc_compute_0(params_compute, n_simple_bufts); |
1536 | 0 | ggml_backend_meta_simple_tensor_container stc_compute_1(params_compute, n_simple_bufts); |
1537 | |
|
1538 | 0 | std::vector<ggml_backend_buffer_t> bufs(n_simple_bufts, nullptr); |
1539 | 0 | ggml_backend_meta_buffer_context * meta_buf_ctx = new ggml_backend_meta_buffer_context(stc_static, stc_compute_0, stc_compute_1, bufs); |
1540 | |
|
1541 | 0 | ggml_backend_buffer_t meta_buf = ggml_backend_buffer_init(buft, ggml_backend_meta_buffer_iface, meta_buf_ctx, 0); |
1542 | 0 | for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) { |
1543 | 0 | t->buffer = meta_buf; |
1544 | 0 | ggml_backend_meta_buffer_init_tensor_impl(meta_buf_ctx->stc_static, t); |
1545 | 0 | t->data = (void *) 0x2000000000000000; // FIXME |
1546 | 0 | } |
1547 | 0 | for (size_t i = 0; i < n_simple_bufts; i++) { |
1548 | 0 | ggml_context * ctx = meta_buf_ctx->stc_static.ctxs[i].get(); |
1549 | 0 | ggml_backend_buffer_type_t simple_buft = ggml_backend_meta_buft_simple_buft(buft, i); |
1550 | | |
1551 | | // If a ggml_context only has zero-sized tensors, ggml_backend_alloc_ctx_tensors_from_buft returns NULL. |
1552 | | // For those edge cases, allocate a dummy buffer instead. |
1553 | 0 | bool any_nonzero_slice = false; |
1554 | 0 | for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) { |
1555 | 0 | if (ggml_nelements(t) != 0) { |
1556 | 0 | any_nonzero_slice = true; |
1557 | 0 | break; |
1558 | 0 | } |
1559 | 0 | } |
1560 | 0 | if (any_nonzero_slice) { |
1561 | 0 | meta_buf_ctx->bufs[i].reset(ggml_backend_alloc_ctx_tensors_from_buft(ctx, simple_buft)); |
1562 | 0 | } else { |
1563 | 0 | meta_buf_ctx->bufs[i].reset(ggml_backend_buft_alloc_buffer(simple_buft, 0)); |
1564 | 0 | for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) { |
1565 | 0 | t->buffer = meta_buf_ctx->bufs[i].get(); |
1566 | 0 | } |
1567 | 0 | } |
1568 | 0 | GGML_ASSERT(meta_buf_ctx->bufs[i]); |
1569 | 0 | meta_buf->size = std::max(meta_buf->size, ggml_backend_buffer_get_size(meta_buf_ctx->bufs[i].get())); |
1570 | 0 | } |
1571 | 0 | return meta_buf; |
1572 | 0 | } |
1573 | | |
1574 | | // |
1575 | | // meta backend |
1576 | | // |
1577 | | |
1578 | 0 | static ggml_guid_t ggml_backend_meta_guid() { |
1579 | 0 | static ggml_guid guid = {0xf1, 0x0e, 0x34, 0xcf, 0x9c, 0x6f, 0x43, 0xcb, 0x96, 0x92, 0xbe, 0x8e, 0xbb, 0x71, 0x3f, 0xda}; |
1580 | 0 | return &guid; |
1581 | 0 | } |
1582 | | |
1583 | | struct ggml_backend_meta_context { |
1584 | | struct cgraph_config { |
1585 | | ggml_cgraph * cgraph_main = nullptr; |
1586 | | int offset = 0; // Node offset vs. original graph |
1587 | | |
1588 | | std::vector<ggml_cgraph *> cgraphs_aux; |
1589 | | }; |
1590 | | struct backend_config { |
1591 | | ggml_backend_t backend; |
1592 | | |
1593 | | std::vector<cgraph_config> cgraphs; |
1594 | | std::vector<ggml_tensor *> nodes; |
1595 | | std::vector<ggml_backend_buffer_ptr> bufs; |
1596 | | |
1597 | 0 | backend_config(ggml_backend_t backend, const size_t n_reduce_steps) : backend(backend) { |
1598 | 0 | bufs.resize(n_reduce_steps); |
1599 | 0 | } |
1600 | | }; |
1601 | | std::string name; |
1602 | | std::vector<backend_config> backend_configs; |
1603 | | ggml_context_ptr ctx; |
1604 | | std::vector<ggml_cgraph *> cgraphs_aux; |
1605 | | std::vector<ggml_tensor *> nodes_aux; |
1606 | | size_t n_reduce_steps; |
1607 | | int max_nnodes = 0; |
1608 | | size_t max_tmp_size = 0; |
1609 | | size_t max_subgraphs = 0; |
1610 | | size_t n_subgraphs = 0; |
1611 | | uint64_t uid = 0; |
1612 | | |
1613 | | void * comm_ctx = nullptr; |
1614 | | ggml_backend_comm_allreduce_tensor_t comm_allreduce = nullptr; |
1615 | | |
1616 | 0 | ggml_backend_meta_context(ggml_backend_dev_t meta_dev, const char * params) { |
1617 | 0 | const size_t n_devs = ggml_backend_meta_dev_n_devs(meta_dev); |
1618 | 0 | n_reduce_steps = std::ceil(std::log2(n_devs)); |
1619 | 0 | name = "Meta("; |
1620 | 0 | std::vector<ggml_backend_t> simple_backends; |
1621 | 0 | backend_configs.reserve(n_devs); |
1622 | 0 | simple_backends.reserve(n_devs); |
1623 | 0 | for (size_t i = 0; i < n_devs; i++) { |
1624 | 0 | ggml_backend_dev_t simple_dev = ggml_backend_meta_dev_simple_dev(meta_dev, i); |
1625 | 0 | if (i > 0) { |
1626 | 0 | name += ","; |
1627 | 0 | } |
1628 | 0 | name += ggml_backend_dev_name(simple_dev); |
1629 | 0 | simple_backends.push_back(ggml_backend_dev_init(simple_dev, params)); |
1630 | 0 | backend_configs.emplace_back(simple_backends.back(), n_reduce_steps); |
1631 | 0 | } |
1632 | 0 | name += ")"; |
1633 | |
|
1634 | 0 | if (n_devs > 1) { |
1635 | 0 | ggml_backend_comm_init_t comm_init = (ggml_backend_comm_init_t) ggml_backend_reg_get_proc_address( |
1636 | 0 | ggml_backend_dev_backend_reg(ggml_backend_get_device(simple_backends[0])), "ggml_backend_comm_init"); |
1637 | 0 | if (comm_init != nullptr) { |
1638 | 0 | comm_ctx = comm_init(simple_backends.data(), simple_backends.size()); |
1639 | 0 | } |
1640 | 0 | } |
1641 | 0 | if (comm_ctx != nullptr) { |
1642 | 0 | comm_allreduce = (ggml_backend_comm_allreduce_tensor_t) |
1643 | 0 | ggml_backend_reg_get_proc_address(ggml_backend_dev_backend_reg( |
1644 | 0 | ggml_backend_get_device(simple_backends[0])), "ggml_backend_comm_allreduce_tensor"); |
1645 | 0 | GGML_ASSERT(comm_allreduce != nullptr); |
1646 | 0 | } |
1647 | 0 | } |
1648 | | |
1649 | 0 | ~ggml_backend_meta_context() { |
1650 | 0 | if (comm_ctx != nullptr) { |
1651 | 0 | ggml_backend_comm_free_t comm_free = (ggml_backend_comm_free_t) ggml_backend_reg_get_proc_address( |
1652 | 0 | ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_configs[0].backend)), "ggml_backend_comm_free"); |
1653 | 0 | GGML_ASSERT(comm_free != nullptr); |
1654 | 0 | comm_free(comm_ctx); |
1655 | 0 | } |
1656 | 0 | for (auto & bc : backend_configs) { |
1657 | 0 | ggml_backend_free(bc.backend); |
1658 | 0 | } |
1659 | 0 | } |
1660 | | }; |
1661 | | |
1662 | 0 | static const char * ggml_backend_meta_get_name(ggml_backend_t backend) { |
1663 | 0 | GGML_ASSERT(ggml_backend_is_meta(backend)); |
1664 | 0 | const ggml_backend_meta_context * backend_ctx = (const ggml_backend_meta_context *) backend->context; |
1665 | 0 | return backend_ctx->name.c_str(); |
1666 | 0 | } |
1667 | | |
1668 | 0 | static void ggml_backend_meta_free(ggml_backend_t backend) { |
1669 | 0 | GGML_ASSERT(ggml_backend_is_meta(backend)); |
1670 | 0 | ggml_backend_meta_context * backend_ctx = (ggml_backend_meta_context *) backend->context; |
1671 | 0 | delete backend_ctx; |
1672 | 0 | delete backend; |
1673 | 0 | } |
1674 | | |
1675 | 0 | static void ggml_backend_meta_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { |
1676 | 0 | const size_t n_backends = ggml_backend_meta_n_backends(backend); |
1677 | 0 | GGML_ASSERT(offset == 0); |
1678 | 0 | GGML_ASSERT(ggml_is_contiguous(tensor)); |
1679 | |
|
1680 | 0 | const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false); |
1681 | 0 | GGML_ASSERT(split_state.n_segments == 1); |
1682 | 0 | GGML_ASSERT(split_state.nr[0] == 1); |
1683 | |
|
1684 | 0 | switch (split_state.axis) { |
1685 | 0 | case GGML_BACKEND_SPLIT_AXIS_0: |
1686 | 0 | case GGML_BACKEND_SPLIT_AXIS_1: |
1687 | 0 | case GGML_BACKEND_SPLIT_AXIS_2: { |
1688 | | // Exploit that tensors are contiguous to splice it with simple tensors as "chunks". |
1689 | 0 | const size_t chunk_size_full = tensor->nb[split_state.axis + 1]; |
1690 | 0 | GGML_ASSERT(offset % chunk_size_full == 0); |
1691 | 0 | GGML_ASSERT(size % chunk_size_full == 0); |
1692 | 0 | const int64_t i_start = offset /chunk_size_full; |
1693 | 0 | const int64_t i_stop = (offset + size)/chunk_size_full; |
1694 | 0 | size_t offset_j = 0; |
1695 | 0 | for (size_t j = 0; j < n_backends; j++){ |
1696 | 0 | ggml_backend_t simple_backend = ggml_backend_meta_simple_backend(backend, j); |
1697 | 0 | ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); |
1698 | 0 | const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1]; |
1699 | 0 | if (chunk_size_j == 0) { |
1700 | 0 | continue; |
1701 | 0 | } |
1702 | 0 | ggml_backend_tensor_set_2d_async(simple_backend, simple_tensor, (const char *) data + offset_j, offset, chunk_size_j, |
1703 | 0 | i_stop - i_start, chunk_size_j, chunk_size_full); |
1704 | 0 | offset_j += chunk_size_j; |
1705 | 0 | } |
1706 | 0 | GGML_ASSERT(offset_j == chunk_size_full); |
1707 | 0 | } break; |
1708 | 0 | case GGML_BACKEND_SPLIT_AXIS_MIRRORED: { |
1709 | 0 | for (size_t j = 0; j < n_backends; j++) { |
1710 | 0 | ggml_backend_tensor_set_async( |
1711 | 0 | ggml_backend_meta_simple_backend(backend, j), ggml_backend_meta_buffer_simple_tensor(tensor, j), data, offset, size); |
1712 | 0 | } |
1713 | 0 | } break; |
1714 | 0 | default: { |
1715 | 0 | GGML_ABORT("fatal error"); |
1716 | 0 | } |
1717 | 0 | } |
1718 | 0 | } |
1719 | | |
1720 | 0 | static void ggml_backend_meta_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { |
1721 | 0 | const size_t n_backends = ggml_backend_meta_n_backends(backend); |
1722 | 0 | GGML_ASSERT(offset == 0); |
1723 | 0 | GGML_ASSERT(ggml_is_contiguous(tensor)); |
1724 | |
|
1725 | 0 | const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false); |
1726 | 0 | GGML_ASSERT(split_state.n_segments == 1); |
1727 | 0 | GGML_ASSERT(split_state.nr[0] == 1); |
1728 | |
|
1729 | 0 | switch (split_state.axis) { |
1730 | 0 | case GGML_BACKEND_SPLIT_AXIS_0: |
1731 | 0 | case GGML_BACKEND_SPLIT_AXIS_1: |
1732 | 0 | case GGML_BACKEND_SPLIT_AXIS_2: { |
1733 | | // Exploit that tensors are contiguous to splice it with simple tensors as "chunks". |
1734 | 0 | const size_t chunk_size_full = tensor->nb[split_state.axis + 1]; |
1735 | 0 | GGML_ASSERT(offset % chunk_size_full == 0); |
1736 | 0 | GGML_ASSERT(size % chunk_size_full == 0); |
1737 | 0 | const int64_t i_start = offset /chunk_size_full; |
1738 | 0 | const int64_t i_stop = (offset + size)/chunk_size_full; |
1739 | 0 | size_t offset_j = 0; |
1740 | 0 | for (size_t j = 0; j < n_backends; j++){ |
1741 | 0 | ggml_backend_t simple_backend = ggml_backend_meta_simple_backend(backend, j); |
1742 | 0 | const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); |
1743 | 0 | const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1]; |
1744 | 0 | if (chunk_size_j == 0) { |
1745 | 0 | continue; |
1746 | 0 | } |
1747 | 0 | ggml_backend_tensor_get_2d_async(simple_backend, simple_tensor, (char *) data + offset_j, offset, chunk_size_j, |
1748 | 0 | i_stop - i_start, chunk_size_j, chunk_size_full); |
1749 | 0 | offset_j += chunk_size_j; |
1750 | 0 | } |
1751 | 0 | GGML_ASSERT(offset_j == chunk_size_full); |
1752 | 0 | } break; |
1753 | 0 | case GGML_BACKEND_SPLIT_AXIS_MIRRORED: { |
1754 | | // TODO other simple backend may be better |
1755 | 0 | ggml_backend_t simple_backend = ggml_backend_meta_simple_backend(backend, 0); |
1756 | 0 | const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, 0); |
1757 | 0 | ggml_backend_tensor_get_async(simple_backend, simple_tensor, data, offset, size); |
1758 | 0 | } break; |
1759 | 0 | default: { |
1760 | 0 | GGML_ABORT("fatal error"); |
1761 | 0 | } |
1762 | 0 | } |
1763 | 0 | } |
1764 | | |
1765 | 0 | static void ggml_backend_meta_synchronize(ggml_backend_t backend) { |
1766 | 0 | const size_t n_backends = ggml_backend_meta_n_backends(backend); |
1767 | 0 | for (size_t i = 0; i < n_backends; i++) { |
1768 | 0 | ggml_backend_synchronize(ggml_backend_meta_simple_backend(backend, i)); |
1769 | 0 | } |
1770 | 0 | } |
1771 | | |
1772 | 0 | static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { |
1773 | 0 | GGML_ASSERT(cgraph->grads == nullptr); |
1774 | 0 | const size_t n_backends = ggml_backend_meta_n_backends(backend); |
1775 | 0 | ggml_backend_meta_context * backend_ctx = (ggml_backend_meta_context *) backend->context; |
1776 | | |
1777 | | // If the previous cgraph had a defined UID it can be used to skip rebuilding the subgraphs per simple backend. |
1778 | 0 | const bool needs_rebuild = (cgraph->uid == 0) || (cgraph->uid != backend_ctx->uid); |
1779 | |
|
1780 | 0 | bool max_nnodes_raised = false; |
1781 | 0 | if (cgraph->n_nodes > backend_ctx->max_nnodes) { |
1782 | 0 | for (size_t j = 0; j < n_backends; j++) { |
1783 | 0 | auto & bcj = backend_ctx->backend_configs[j]; |
1784 | 0 | bcj.nodes.resize(cgraph->n_nodes); |
1785 | 0 | bcj.cgraphs.resize(cgraph->n_nodes); |
1786 | 0 | } |
1787 | 0 | backend_ctx->max_nnodes = cgraph->n_nodes; |
1788 | 0 | max_nnodes_raised = true; |
1789 | 0 | assert(needs_rebuild); |
1790 | 0 | } |
1791 | |
|
1792 | 0 | if (needs_rebuild) { |
1793 | 0 | std::set<ggml_backend_buffer_t> used_buffers; |
1794 | 0 | for (int i = 0; i < cgraph->n_leafs; i++) { |
1795 | 0 | if (ggml_backend_buffer_is_meta(cgraph->leafs[i]->buffer)) { |
1796 | 0 | used_buffers.emplace(cgraph->leafs[i]->buffer); |
1797 | 0 | } |
1798 | 0 | } |
1799 | 0 | for (int i = 0; i < cgraph->n_nodes; i++) { |
1800 | 0 | if (ggml_backend_buffer_is_meta(cgraph->nodes[i]->buffer)) { |
1801 | 0 | used_buffers.emplace(cgraph->nodes[i]->buffer); |
1802 | 0 | } |
1803 | 0 | } |
1804 | 0 | for (ggml_backend_buffer_t buf : used_buffers) { |
1805 | 0 | ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) buf->context; |
1806 | 0 | buf_ctx->stc_compute_index_next = buf_ctx->stc_compute_index ^ 1; |
1807 | 0 | ggml_backend_meta_simple_tensor_container & stc = buf_ctx->stc_compute[buf_ctx->stc_compute_index_next]; |
1808 | 0 | for (ggml_context_ptr & ctx : stc.ctxs) { |
1809 | 0 | ggml_reset(ctx.get()); |
1810 | 0 | } |
1811 | 0 | stc.simple_tensors.clear(); |
1812 | 0 | } |
1813 | 0 | size_t n_subgraphs = 0; |
1814 | 0 | size_t max_tmp_size = 0; |
1815 | |
|
1816 | 0 | for (size_t j = 0; j < n_backends; j++) { |
1817 | 0 | auto & bcj = backend_ctx->backend_configs[j]; |
1818 | |
|
1819 | 0 | for (int i = 0; i < cgraph->n_nodes; i++) { |
1820 | 0 | ggml_tensor * node = cgraph->nodes[i]; |
1821 | 0 | if (node->view_src != nullptr && node->view_src->op == GGML_OP_NONE && ggml_backend_buffer_is_host(node->view_src->buffer)) { |
1822 | | // FIXME s_copy_main is on the CPU and its view seems to be incorrectly added to the graph nodes. |
1823 | | // For regular usage this doesn't matter since it's a noop but trying to call ggml_backend_meta_buffer_simple_tensor results in a crash. |
1824 | 0 | bcj.nodes[i] = node; |
1825 | 0 | continue; |
1826 | 0 | } |
1827 | 0 | bcj.nodes[i] = ggml_backend_meta_buffer_simple_tensor(node, j); |
1828 | 0 | GGML_ASSERT(bcj.nodes[i]); |
1829 | 0 | } |
1830 | 0 | } |
1831 | |
|
1832 | 0 | { |
1833 | | // For MoE models it may make sense to delay the AllReduce in order to reduce I/O: |
1834 | 0 | auto get_i_delayed = [&](const int i) -> int { |
1835 | 0 | int id = i; // i_delayed |
1836 | 0 | int idr = i; // i_delayed return, last safe return value |
1837 | |
|
1838 | 0 | ggml_tensor * node = cgraph->nodes[id]; |
1839 | 0 | int32_t n_used = ggml_node_get_use_count(cgraph, id); |
1840 | | |
1841 | | // Skip MIRRORED nodes that don't consume node |
1842 | 0 | auto skip_unrelated = [&]() { |
1843 | 0 | while (id + 1 < cgraph->n_nodes) { |
1844 | 0 | ggml_tensor * next = cgraph->nodes[id+1]; |
1845 | 0 | if (ggml_backend_meta_get_split_state(next, false).axis != GGML_BACKEND_SPLIT_AXIS_MIRRORED) { |
1846 | 0 | break; |
1847 | 0 | } |
1848 | 0 | bool safe = true; |
1849 | 0 | for (int s = 0; s < GGML_MAX_SRC; s++) { |
1850 | 0 | if (next->src[s] == nullptr) { |
1851 | 0 | continue; |
1852 | 0 | } |
1853 | 0 | if (next->src[s] == node) { |
1854 | 0 | safe = false; |
1855 | 0 | break; |
1856 | 0 | } |
1857 | 0 | if (ggml_backend_meta_get_split_state(next->src[s], false).axis != GGML_BACKEND_SPLIT_AXIS_MIRRORED) { |
1858 | 0 | safe = false; |
1859 | 0 | break; |
1860 | 0 | } |
1861 | 0 | } |
1862 | 0 | if (!safe) { |
1863 | 0 | break; |
1864 | 0 | } |
1865 | 0 | id++; |
1866 | 0 | } |
1867 | 0 | }; |
1868 | |
|
1869 | 0 | skip_unrelated(); |
1870 | 0 | if (id + 1 >= cgraph->n_nodes) { |
1871 | 0 | return idr; |
1872 | 0 | } |
1873 | 0 | { |
1874 | 0 | ggml_tensor * next = cgraph->nodes[id+1]; |
1875 | 0 | if (next->op == GGML_OP_ADD_ID && next->src[0] == node && |
1876 | 0 | ggml_backend_meta_get_split_state(next->src[1], false).axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL && |
1877 | 0 | ggml_backend_meta_get_split_state(next->src[2], false).axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { |
1878 | 0 | node = next; |
1879 | 0 | id++; |
1880 | 0 | idr = id; |
1881 | 0 | n_used = ggml_node_get_use_count(cgraph, id); |
1882 | 0 | } |
1883 | 0 | } |
1884 | | // Chain of MULs with MIRRORED src[1] |
1885 | 0 | while (true) { |
1886 | 0 | skip_unrelated(); |
1887 | 0 | if (id + 1 >= cgraph->n_nodes) { |
1888 | 0 | return idr; |
1889 | 0 | } |
1890 | 0 | ggml_tensor * next = cgraph->nodes[id+1]; |
1891 | 0 | if (next->op == GGML_OP_MUL && next->src[0] == node && |
1892 | 0 | ggml_backend_meta_get_split_state(next->src[1], false).axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { |
1893 | 0 | node = next; |
1894 | 0 | id++; |
1895 | 0 | idr = id; |
1896 | 0 | n_used = ggml_node_get_use_count(cgraph, id); |
1897 | 0 | } else { |
1898 | 0 | break; |
1899 | 0 | } |
1900 | 0 | } |
1901 | | |
1902 | 0 | if (n_used != node->ne[1] || id + 2*n_used-1 >= cgraph->n_nodes) { |
1903 | 0 | return idr; |
1904 | 0 | } |
1905 | 0 | for (int32_t k = 0; k < n_used; k++) { |
1906 | 0 | ggml_tensor * next = cgraph->nodes[id+1]; |
1907 | 0 | if (next->op != GGML_OP_VIEW || next->view_src != node || next->view_offs != k*node->nb[1] || |
1908 | 0 | next->ne[0] != node->ne[0] || next->ne[1] != node->ne[2] || next->nb[1] != node->nb[2] || |
1909 | 0 | ggml_node_get_use_count(cgraph, id+1) != 1) { |
1910 | 0 | return idr; |
1911 | 0 | } |
1912 | 0 | id++; |
1913 | 0 | } |
1914 | 0 | { |
1915 | 0 | ggml_tensor * next = cgraph->nodes[id+1]; |
1916 | 0 | if (next->op != GGML_OP_ADD || next->src[0] != cgraph->nodes[id - (n_used-1)] || |
1917 | 0 | next->src[1] != cgraph->nodes[id - (n_used-2)] || ggml_node_get_use_count(cgraph, id+1) != 1) { |
1918 | 0 | return idr; |
1919 | 0 | } |
1920 | 0 | id++; |
1921 | 0 | } |
1922 | 0 | for (int32_t k = 0; k < n_used - 2; k++) { |
1923 | 0 | ggml_tensor * next = cgraph->nodes[id+1]; |
1924 | 0 | if (next->op != GGML_OP_ADD || next->src[0] != cgraph->nodes[id] || |
1925 | 0 | next->src[1] != cgraph->nodes[id - (n_used-2)] || ggml_node_get_use_count(cgraph, id+1) != 1) { |
1926 | 0 | return idr; |
1927 | 0 | } |
1928 | 0 | id++; |
1929 | 0 | } |
1930 | 0 | idr = id; |
1931 | 0 | return idr; |
1932 | 0 | }; |
1933 | |
|
1934 | 0 | int i_start = 0; |
1935 | 0 | for (int i = 0; i < cgraph->n_nodes; i++) { |
1936 | 0 | ggml_tensor * node = cgraph->nodes[i]; |
1937 | 0 | if (node->view_src != nullptr && node->view_src->op == GGML_OP_NONE && ggml_backend_buffer_is_host(node->view_src->buffer)) { |
1938 | 0 | continue; |
1939 | 0 | } |
1940 | 0 | const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(node, /*assume_sync =*/ false); |
1941 | 0 | if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL) { |
1942 | 0 | max_tmp_size = std::max(max_tmp_size, ggml_nbytes(node)); |
1943 | 0 | } |
1944 | 0 | const bool new_subgraph = i + 1 == cgraph->n_nodes || split_state.axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL; |
1945 | 0 | if (!new_subgraph) { |
1946 | 0 | continue; |
1947 | 0 | } |
1948 | | |
1949 | 0 | const int i_delayed = get_i_delayed(i); |
1950 | | |
1951 | | // If we can delay the AllReduce we need to consider the interaction with zero-sized tensor slices. |
1952 | | // A backend with such a slice would normally have valid data after participating in the AllReduce with a node that has |
1953 | | // its compute flag disabled and thus gets its data zeroed out. |
1954 | | // If the AllReduce is delayed then the nodes until that point also need to have their compute flag disabled. |
1955 | 0 | if (i_delayed > i) { |
1956 | 0 | for (size_t j = 0; j < n_backends; j++) { |
1957 | 0 | auto & bcj = backend_ctx->backend_configs[j]; |
1958 | 0 | if ((bcj.nodes[i]->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { |
1959 | 0 | for (int ii = i + 1; ii <= i_delayed; ii++) { |
1960 | 0 | bcj.nodes[ii]->flags &= ~GGML_TENSOR_FLAG_COMPUTE; |
1961 | 0 | } |
1962 | 0 | } |
1963 | 0 | } |
1964 | 0 | } |
1965 | |
|
1966 | 0 | i = i_delayed; |
1967 | |
|
1968 | 0 | for (size_t j = 0; j < n_backends; j++) { |
1969 | 0 | auto & bcj = backend_ctx->backend_configs[j]; |
1970 | 0 | bcj.cgraphs[n_subgraphs].offset = i_start; |
1971 | 0 | } |
1972 | 0 | n_subgraphs++; |
1973 | 0 | i_start = i + 1; |
1974 | 0 | } |
1975 | 0 | GGML_ASSERT(i_start == cgraph->n_nodes); |
1976 | 0 | } |
1977 | |
|
1978 | 0 | backend_ctx->uid = cgraph->uid; |
1979 | 0 | backend_ctx->n_subgraphs = n_subgraphs; |
1980 | |
|
1981 | 0 | if (max_tmp_size > backend_ctx->max_tmp_size) { |
1982 | 0 | for (size_t j = 0; j < n_backends; j++) { |
1983 | 0 | auto & bcj = backend_ctx->backend_configs[j]; |
1984 | 0 | for (size_t i = 0; i < backend_ctx->n_reduce_steps; i++) { |
1985 | 0 | bcj.bufs[i].reset(ggml_backend_alloc_buffer(bcj.backend, max_tmp_size)); |
1986 | 0 | } |
1987 | 0 | } |
1988 | 0 | backend_ctx->max_tmp_size = max_tmp_size; |
1989 | 0 | } |
1990 | |
|
1991 | 0 | if (max_nnodes_raised || n_subgraphs > backend_ctx->max_subgraphs) { |
1992 | 0 | backend_ctx->max_subgraphs = std::max(backend_ctx->max_subgraphs, n_subgraphs); |
1993 | 0 | const size_t n_nodes_per_device = 3 * backend_ctx->n_reduce_steps; // tmp + ADD (+zeroing) graph per step and device |
1994 | 0 | const size_t n_cgraphs_per_device = 2 * backend_ctx->n_reduce_steps; // ADD ( + zeroing) graph per step and device |
1995 | 0 | const size_t mem_per_device_graphs_main = backend_ctx->max_subgraphs*ggml_graph_overhead_custom(backend_ctx->max_nnodes, cgraph->grads); |
1996 | 0 | const size_t mem_per_device_graphs_aux = n_cgraphs_per_device*backend_ctx->max_subgraphs*ggml_graph_overhead_custom(1, cgraph->grads); |
1997 | 0 | const size_t mem_per_device_nodes_aux = n_nodes_per_device*backend_ctx->max_subgraphs*ggml_tensor_overhead(); |
1998 | 0 | const ggml_init_params params = { |
1999 | 0 | /*.mem_size =*/ n_backends * (mem_per_device_graphs_main + mem_per_device_graphs_aux + mem_per_device_nodes_aux), |
2000 | 0 | /*.mem_buffer =*/ nullptr, |
2001 | 0 | /*.no_alloc =*/ true, |
2002 | 0 | }; |
2003 | 0 | backend_ctx->ctx.reset(ggml_init(params)); |
2004 | 0 | for (size_t j = 0; j < n_backends; j++) { |
2005 | 0 | auto & bcj = backend_ctx->backend_configs[j]; |
2006 | 0 | for (size_t i = 0; i < n_subgraphs; i++) { |
2007 | 0 | bcj.cgraphs[i].cgraph_main = ggml_new_graph_custom(backend_ctx->ctx.get(), cgraph->n_nodes, /*grads =*/ false); |
2008 | 0 | } |
2009 | 0 | } |
2010 | 0 | backend_ctx->cgraphs_aux.resize(n_backends*n_cgraphs_per_device*backend_ctx->max_subgraphs); |
2011 | 0 | for (size_t k = 0; k < backend_ctx->cgraphs_aux.size(); k++) { |
2012 | 0 | backend_ctx->cgraphs_aux[k] = ggml_new_graph_custom(backend_ctx->ctx.get(), 1, cgraph->grads); |
2013 | 0 | } |
2014 | 0 | backend_ctx->nodes_aux.resize(n_backends*n_nodes_per_device*backend_ctx->max_subgraphs); |
2015 | 0 | for (size_t k = 0; k < backend_ctx->nodes_aux.size(); k++) { |
2016 | 0 | backend_ctx->nodes_aux[k] = ggml_new_tensor_1d(backend_ctx->ctx.get(), GGML_TYPE_F32, 1); |
2017 | 0 | } |
2018 | 0 | } |
2019 | |
|
2020 | 0 | for (size_t j = 0; j < n_backends; j++) { |
2021 | 0 | auto & bcj = backend_ctx->backend_configs[j]; |
2022 | 0 | for (size_t i_graph = 0; i_graph < n_subgraphs; i_graph++) { |
2023 | 0 | ggml_cgraph * cgraph_ij = bcj.cgraphs[i_graph].cgraph_main; |
2024 | 0 | const size_t i_node_start = bcj.cgraphs[i_graph].offset; |
2025 | 0 | const size_t i_node_stop = i_graph + 1 < n_subgraphs ? bcj.cgraphs[i_graph + 1].offset : cgraph->n_nodes; |
2026 | 0 | cgraph_ij->n_nodes = i_node_stop - i_node_start; |
2027 | 0 | ggml_hash_set_reset(&cgraph_ij->visited_hash_set); |
2028 | 0 | for (size_t i_node = i_node_start; i_node < i_node_stop; i_node++) { |
2029 | 0 | ggml_tensor * node_ij = bcj.nodes[i_node]; |
2030 | 0 | cgraph_ij->nodes[i_node - i_node_start] = node_ij; |
2031 | 0 | const size_t hash_pos_orig = ggml_hash_find(&cgraph->visited_hash_set, cgraph->nodes[i_node]); |
2032 | 0 | const size_t hash_pos_ij = ggml_hash_insert(&cgraph_ij->visited_hash_set, node_ij); |
2033 | 0 | cgraph_ij->use_counts[hash_pos_ij] = cgraph->use_counts[hash_pos_orig]; |
2034 | 0 | } |
2035 | 0 | cgraph_ij->uid = ggml_graph_next_uid(); |
2036 | 0 | } |
2037 | 0 | } |
2038 | 0 | } |
2039 | |
|
2040 | 0 | size_t iga = 0; // i graph aux |
2041 | 0 | size_t ina = 0; // i node aux |
2042 | |
|
2043 | 0 | auto get_node_aux = [&](ggml_tensor * t) -> ggml_tensor * { |
2044 | 0 | ggml_tensor * ret = backend_ctx->nodes_aux[ina++]; |
2045 | 0 | memset(ret, 0, sizeof(ggml_tensor)); |
2046 | 0 | ret->op = GGML_OP_NONE; |
2047 | 0 | ret->type = t->type; |
2048 | 0 | for (size_t k = 0; k < GGML_MAX_DIMS; k++) { |
2049 | 0 | ret->ne[k] = t->ne[k]; |
2050 | 0 | ret->nb[k] = t->nb[k]; |
2051 | 0 | } |
2052 | 0 | return ret; |
2053 | 0 | }; |
2054 | 0 | auto set_tmp_data = [&](ggml_tensor * tensor, const size_t j, const size_t i_buf) { |
2055 | 0 | auto & bcj = backend_ctx->backend_configs[j]; |
2056 | 0 | ggml_backend_buffer_ptr & buf_ptr = bcj.bufs[i_buf]; |
2057 | 0 | if (!buf_ptr || ggml_backend_buffer_get_size(buf_ptr.get()) < backend_ctx->max_tmp_size) { |
2058 | 0 | buf_ptr.reset(ggml_backend_alloc_buffer(bcj.backend, backend_ctx->max_tmp_size)); |
2059 | 0 | } |
2060 | 0 | tensor->buffer = buf_ptr.get(); |
2061 | 0 | tensor->data = ggml_backend_buffer_get_base(buf_ptr.get()); |
2062 | 0 | }; |
2063 | | // FIXME usage_counts |
2064 | 0 | auto get_cgraph_aux = [&]() -> ggml_cgraph * { |
2065 | 0 | ggml_cgraph * ret = backend_ctx->cgraphs_aux[iga++]; |
2066 | 0 | return ret; |
2067 | 0 | }; |
2068 | | |
2069 | | // Preferentially use backend-specific allreduce_tensor_async (e.g. NCCL for CUDA), use a generic fallback if unavailable: |
2070 | 0 | auto allreduce_fallback = [&](size_t i) -> ggml_status { |
2071 | 0 | std::vector<ggml_cgraph *> step_cgraphs(n_backends, nullptr); |
2072 | | |
2073 | | // Zero out nodes that were disabled due to having a zero-sized slice: |
2074 | 0 | for (size_t j = 0; j < n_backends; j++) { |
2075 | 0 | auto & bcj = backend_ctx->backend_configs[j]; |
2076 | 0 | ggml_tensor * node = bcj.cgraphs[i].cgraph_main->nodes[bcj.cgraphs[i].cgraph_main->n_nodes - 1]; |
2077 | 0 | if (node->flags & GGML_TENSOR_FLAG_COMPUTE) { |
2078 | 0 | continue; |
2079 | 0 | } |
2080 | 0 | ggml_tensor * node_zero = get_node_aux(node); |
2081 | 0 | node_zero->op = GGML_OP_SCALE; // FIXME 0.0f * NaN == NaN |
2082 | 0 | node_zero->src[0] = node; |
2083 | 0 | ggml_set_op_params_f32(node_zero, 0, 0.0f); |
2084 | 0 | node_zero->data = node->data; |
2085 | 0 | node_zero->buffer = node->buffer; |
2086 | 0 | node_zero->flags |= GGML_TENSOR_FLAG_COMPUTE; |
2087 | |
|
2088 | 0 | step_cgraphs[j] = get_cgraph_aux(); |
2089 | 0 | step_cgraphs[j]->nodes[0] = node_zero; |
2090 | 0 | step_cgraphs[j]->n_nodes = 1; |
2091 | 0 | const ggml_status status = ggml_backend_graph_compute_async(bcj.backend, step_cgraphs[j]); |
2092 | 0 | if (status != GGML_STATUS_SUCCESS) { |
2093 | 0 | return status; |
2094 | 0 | } |
2095 | 0 | } |
2096 | 0 | std::fill(step_cgraphs.begin(), step_cgraphs.end(), nullptr); |
2097 | |
|
2098 | 0 | auto push_data = [&](const size_t j_src, const size_t j_dst, const size_t i_buf) { |
2099 | 0 | assert(step_cgraphs[j_dst] == nullptr); |
2100 | 0 | auto & bcj_src = backend_ctx->backend_configs[j_src]; |
2101 | 0 | auto & bcj_dst = backend_ctx->backend_configs[j_dst]; |
2102 | |
|
2103 | 0 | ggml_tensor * node_src = bcj_src.cgraphs[i].cgraph_main->nodes[bcj_src.cgraphs[i].cgraph_main->n_nodes - 1]; |
2104 | 0 | ggml_tensor * node_dst = bcj_dst.cgraphs[i].cgraph_main->nodes[bcj_dst.cgraphs[i].cgraph_main->n_nodes - 1]; |
2105 | 0 | GGML_ASSERT(ggml_is_contiguous(node_src)); |
2106 | 0 | GGML_ASSERT(ggml_is_contiguous(node_dst)); |
2107 | |
|
2108 | 0 | ggml_tensor * node_tmp = get_node_aux(node_dst); |
2109 | 0 | set_tmp_data(node_tmp, j_dst, i_buf); |
2110 | |
|
2111 | 0 | ggml_backend_tensor_copy_async(bcj_src.backend, bcj_dst.backend, node_src, node_tmp); |
2112 | |
|
2113 | 0 | ggml_tensor * node_red = get_node_aux(node_dst); |
2114 | 0 | node_red->view_src = node_dst->view_src == nullptr ? node_dst : node_dst->view_src; |
2115 | 0 | node_red->view_offs = node_dst->view_offs; |
2116 | 0 | node_red->op = GGML_OP_ADD; |
2117 | 0 | node_red->src[0] = node_dst; |
2118 | 0 | node_red->src[1] = node_tmp; |
2119 | 0 | node_red->flags |= GGML_TENSOR_FLAG_COMPUTE; |
2120 | 0 | ggml_backend_view_init(node_red); |
2121 | |
|
2122 | 0 | ggml_cgraph * cgraph_aux = get_cgraph_aux(); |
2123 | 0 | cgraph_aux->nodes[0] = node_red; |
2124 | 0 | cgraph_aux->n_nodes = 1; |
2125 | 0 | step_cgraphs[j_dst] = cgraph_aux; |
2126 | 0 | }; |
2127 | |
|
2128 | 0 | size_t offset_j = n_backends/2; |
2129 | 0 | while ((offset_j & (offset_j - 1)) != 0) { |
2130 | 0 | offset_j--; |
2131 | 0 | } |
2132 | 0 | const size_t offset_j_max = offset_j; |
2133 | 0 | size_t i_buf = 0; |
2134 | | |
2135 | | // If n_backends is not a power of 2, fold in the excess prior to butterfly reduction: |
2136 | 0 | for (size_t j_src = 2*offset_j_max; j_src < n_backends; j_src++) { |
2137 | 0 | const size_t j_dst = j_src - 2*offset_j_max; |
2138 | 0 | push_data(j_src, j_dst, i_buf); |
2139 | 0 | const ggml_status status = ggml_backend_graph_compute_async(backend_ctx->backend_configs[j_dst].backend, step_cgraphs[j_dst]); |
2140 | 0 | if (status != GGML_STATUS_SUCCESS) { |
2141 | 0 | return status; |
2142 | 0 | } |
2143 | 0 | i_buf = 1; |
2144 | 0 | } |
2145 | | |
2146 | | // Butterfly reduction: |
2147 | 0 | for (; offset_j >= 1; offset_j /= 2) { |
2148 | 0 | std::fill(step_cgraphs.begin(), step_cgraphs.end(), nullptr); |
2149 | |
|
2150 | 0 | for (size_t j = 0; j < 2*offset_j_max; j++) { |
2151 | 0 | const size_t j_other = j ^ offset_j; |
2152 | 0 | if (j_other >= n_backends) { |
2153 | 0 | continue; |
2154 | 0 | } |
2155 | 0 | push_data(j, j_other, i_buf); |
2156 | 0 | } |
2157 | |
|
2158 | 0 | for (size_t j = 0; j < 2*offset_j_max; j++) { |
2159 | 0 | if (step_cgraphs[j] == nullptr) { |
2160 | 0 | continue; |
2161 | 0 | } |
2162 | 0 | auto & bcj = backend_ctx->backend_configs[j]; |
2163 | 0 | const ggml_status status = ggml_backend_graph_compute_async(bcj.backend, step_cgraphs[j]); |
2164 | 0 | if (status != GGML_STATUS_SUCCESS) { |
2165 | 0 | return status; |
2166 | 0 | } |
2167 | 0 | } |
2168 | 0 | i_buf++; |
2169 | 0 | } |
2170 | 0 | assert(i_buf == backend_ctx->n_reduce_steps); |
2171 | | |
2172 | | // If n_backends is not a power of 2, copy back the reduced tensors to the excess: |
2173 | 0 | for (size_t j = 2*offset_j_max; j < n_backends; j++) { |
2174 | 0 | auto & bcj_src = backend_ctx->backend_configs[j - 2*offset_j_max]; |
2175 | 0 | auto & bcj_dst = backend_ctx->backend_configs[j]; |
2176 | |
|
2177 | 0 | ggml_tensor * node_src = bcj_src.cgraphs[i].cgraph_main->nodes[bcj_src.cgraphs[i].cgraph_main->n_nodes - 1]; |
2178 | 0 | ggml_tensor * node_dst = bcj_dst.cgraphs[i].cgraph_main->nodes[bcj_dst.cgraphs[i].cgraph_main->n_nodes - 1]; |
2179 | 0 | ggml_backend_tensor_copy_async(bcj_src.backend, bcj_dst.backend, node_src, node_dst); |
2180 | 0 | } |
2181 | |
|
2182 | 0 | return GGML_STATUS_SUCCESS; |
2183 | 0 | }; |
2184 | | |
2185 | |
|
2186 | 0 | for (size_t i = 0; i < backend_ctx->n_subgraphs; i++) { |
2187 | 0 | for (size_t j = 0; j < n_backends; j++) { |
2188 | 0 | auto & bcj = backend_ctx->backend_configs[j]; |
2189 | 0 | const ggml_status status = ggml_backend_graph_compute_async(bcj.backend, bcj.cgraphs[i].cgraph_main); |
2190 | 0 | if (status != GGML_STATUS_SUCCESS) { |
2191 | 0 | return status; |
2192 | 0 | } |
2193 | 0 | } |
2194 | | |
2195 | 0 | if (n_backends > 1 && i < backend_ctx->n_subgraphs - 1) { |
2196 | 0 | bool backend_allreduce_success = false; |
2197 | 0 | if (backend_ctx->comm_ctx) { |
2198 | 0 | std::vector<ggml_tensor *> nodes; |
2199 | 0 | nodes.reserve(n_backends); |
2200 | 0 | for (size_t j = 0; j < n_backends; j++) { |
2201 | 0 | auto & bcj = backend_ctx->backend_configs[j]; |
2202 | 0 | ggml_cgraph * cgraph_ij = bcj.cgraphs[i].cgraph_main; |
2203 | 0 | nodes.push_back(cgraph_ij->nodes[cgraph_ij->n_nodes-1]); |
2204 | 0 | } |
2205 | 0 | backend_allreduce_success = backend_ctx->comm_allreduce(backend_ctx->comm_ctx, nodes.data()); |
2206 | 0 | } |
2207 | |
|
2208 | 0 | if (!backend_allreduce_success) { |
2209 | 0 | const ggml_status status = allreduce_fallback(i); |
2210 | 0 | if (status != GGML_STATUS_SUCCESS) { |
2211 | 0 | return status; |
2212 | 0 | } |
2213 | 0 | } |
2214 | 0 | } |
2215 | 0 | } |
2216 | 0 | return GGML_STATUS_SUCCESS; |
2217 | 0 | } |
2218 | | |
2219 | | static const ggml_backend_i ggml_backend_meta_i = { |
2220 | | /* .get_name = */ ggml_backend_meta_get_name, |
2221 | | /* .free = */ ggml_backend_meta_free, |
2222 | | /* .set_tensor_async = */ ggml_backend_meta_set_tensor_async, |
2223 | | /* .get_tensor_async = */ ggml_backend_meta_get_tensor_async, |
2224 | | /* .set_tensor_2d_async = */ nullptr, |
2225 | | /* .get_tensor_2d_async = */ nullptr, |
2226 | | /* .cpy_tensor_async = */ nullptr, |
2227 | | /* .synchronize = */ ggml_backend_meta_synchronize, |
2228 | | /* .graph_plan_create = */ nullptr, |
2229 | | /* .graph_plan_free = */ nullptr, |
2230 | | /* .graph_plan_update = */ nullptr, |
2231 | | /* .graph_plan_compute = */ nullptr, |
2232 | | /* .graph_compute = */ ggml_backend_meta_graph_compute, |
2233 | | /* .event_record = */ nullptr, |
2234 | | /* .event_wait = */ nullptr, |
2235 | | /* .graph_optimize = */ nullptr, |
2236 | | }; |
2237 | | |
2238 | 0 | bool ggml_backend_is_meta(ggml_backend_t backend) { |
2239 | 0 | return backend != nullptr && backend->iface.get_name == ggml_backend_meta_i.get_name; |
2240 | 0 | } |
2241 | | |
2242 | 0 | static ggml_backend_t ggml_backend_meta_device_init_backend(ggml_backend_dev_t dev, const char * params) { |
2243 | 0 | ggml_backend_meta_context * backend_ctx = new ggml_backend_meta_context(dev, params); |
2244 | |
|
2245 | 0 | ggml_backend_t backend = new struct ggml_backend; |
2246 | 0 | backend->guid = ggml_backend_meta_guid(); |
2247 | 0 | backend->iface = ggml_backend_meta_i; |
2248 | 0 | backend->device = dev; |
2249 | 0 | backend->context = backend_ctx; |
2250 | 0 | return backend; |
2251 | 0 | } |
2252 | | |
2253 | 0 | size_t ggml_backend_meta_n_backends(ggml_backend_t meta_backend) { |
2254 | 0 | GGML_ASSERT(ggml_backend_is_meta(meta_backend)); |
2255 | 0 | const ggml_backend_meta_context * backend_ctx = (const ggml_backend_meta_context *) meta_backend->context; |
2256 | 0 | return backend_ctx->backend_configs.size(); |
2257 | 0 | } |
2258 | | |
2259 | 0 | ggml_backend_t ggml_backend_meta_simple_backend(ggml_backend_t meta_backend, size_t index) { |
2260 | 0 | GGML_ASSERT(ggml_backend_is_meta(meta_backend)); |
2261 | 0 | const ggml_backend_meta_context * backend_ctx = (const ggml_backend_meta_context *) meta_backend->context; |
2262 | 0 | return backend_ctx->backend_configs[index].backend; |
2263 | 0 | } |