Coverage Report

Created: 2025-11-24 06:10

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/llama.cpp/ggml/src/ggml-opt.cpp
Line
Count
Source
1
#include "ggml-opt.h"
2
3
#include "ggml.h"
4
#include "ggml-alloc.h"
5
#include "ggml-backend.h"
6
#include "ggml-impl.h"
7
8
#include <algorithm>
9
#include <cmath>
10
#include <cstdint>
11
#include <cinttypes>
12
#include <map>
13
#include <random>
14
#include <vector>
15
16
struct ggml_opt_dataset {
17
    struct ggml_context   * ctx    = nullptr;
18
    ggml_backend_buffer_t   buf    = nullptr;
19
    struct ggml_tensor    * data   = nullptr;
20
    struct ggml_tensor    * labels = nullptr;
21
22
    int64_t ndata       = -1;
23
    int64_t ndata_shard = -1;
24
    size_t  nbs_data    = -1;
25
    size_t  nbs_labels  = -1;
26
27
    std::vector<int64_t> permutation;
28
};
29
30
struct ggml_opt_context {
31
    ggml_backend_sched_t       backend_sched        = nullptr;
32
    ggml_cgraph              * allocated_graph      = nullptr;
33
    ggml_cgraph              * allocated_graph_copy = nullptr;
34
    struct ggml_context      * ctx_static           = nullptr;
35
    struct ggml_context      * ctx_cpu              = nullptr;
36
    struct ggml_context      * ctx_compute          = nullptr;
37
    struct ggml_context      * ctx_copy             = nullptr;
38
    ggml_backend_buffer_t      buf_static           = nullptr;
39
    ggml_backend_buffer_t      buf_cpu              = nullptr;
40
    std::mt19937               rng;
41
    enum ggml_opt_loss_type    loss_type;
42
    enum ggml_opt_build_type   build_type;
43
    enum ggml_opt_build_type   build_type_alloc;
44
45
    struct ggml_tensor * inputs  = nullptr;
46
    struct ggml_tensor * outputs = nullptr;
47
    struct ggml_tensor * labels  = nullptr;
48
49
    struct ggml_tensor * loss     = nullptr;
50
    struct ggml_tensor * pred     = nullptr;
51
    struct ggml_tensor * ncorrect = nullptr;
52
53
    struct ggml_cgraph * gf      = nullptr;
54
    struct ggml_cgraph * gb_grad = nullptr;
55
    struct ggml_cgraph * gb_opt  = nullptr;
56
    bool static_graphs           = false;
57
    bool eval_ready              = false;
58
    std::vector<struct ggml_tensor *> grad_accs;
59
    std::vector<struct ggml_tensor *> grad_m;
60
    std::vector<struct ggml_tensor *> grad_v;
61
62
    int64_t iter               = 1;
63
    int32_t opt_period         = 1;
64
    int32_t opt_i              = 0;
65
    bool    loss_per_datapoint = false;
66
67
    ggml_opt_get_optimizer_params get_opt_pars    = nullptr;
68
    void *                        get_opt_pars_ud = nullptr;
69
    struct ggml_tensor *          opt_step_params = nullptr; // Stores output of get_opt_pars.
70
71
    enum ggml_opt_optimizer_type optimizer = GGML_OPT_OPTIMIZER_TYPE_ADAMW;
72
};
73
74
// Accumulated results of evaluating one or more batches.
struct ggml_opt_result {
    int64_t              ndata    = 0; // number of datapoints the result covers
    std::vector<float>   loss;         // one loss value per physical batch
    std::vector<int32_t> pred;         // predictions, copied out by ggml_opt_result_pred
    int64_t              ncorrect = 0; // number of correct predictions; negative means "not available"

    int64_t opt_period         = -1;    // optimizer period the losses were scaled with
    bool    loss_per_datapoint = false; // whether the stored losses are per-datapoint values
};
83
84
// ====== Dataset ======
85
86
// Creates a dataset with room for ndata datapoints of ne_datapoint elements each and,
// if ne_label > 0, ndata labels of ne_label elements each. The tensors are allocated in a CPU buffer.
// The dataset is split into ndata/ndata_shard shards which initially are in sequential order.
// The returned dataset must be released with ggml_opt_dataset_free.
ggml_opt_dataset_t ggml_opt_dataset_init(
        enum ggml_type type_data,
        enum ggml_type type_label,
        int64_t        ne_datapoint,
        int64_t        ne_label,
        int64_t        ndata,
        int64_t        ndata_shard) {
    GGML_ASSERT(ne_datapoint >  0);
    GGML_ASSERT(ne_label     >= 0);
    GGML_ASSERT(ndata        >  0);
    GGML_ASSERT(ndata_shard  >  0);

    ggml_opt_dataset_t dataset = new ggml_opt_dataset;
    dataset->ndata       = ndata;
    dataset->ndata_shard = ndata_shard;

    // Only tensor metadata lives in the context (at most 2 tensors: data + labels).
    struct ggml_init_params ctx_params = {
        /*.mem_size   =*/ 2*ggml_tensor_overhead(),
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ true,
    };
    dataset->ctx = ggml_init(ctx_params);

    dataset->data     = ggml_new_tensor_2d(dataset->ctx, type_data, ne_datapoint, ndata);
    dataset->nbs_data = ggml_nbytes(dataset->data) * ndata_shard/ndata;

    const bool has_labels = ne_label > 0;
    dataset->labels     = has_labels ? ggml_new_tensor_2d(dataset->ctx, type_label, ne_label, ndata) : nullptr;
    dataset->nbs_labels = has_labels ? ggml_nbytes(dataset->labels) * ndata_shard/ndata : 0;

    dataset->buf = ggml_backend_alloc_ctx_tensors_from_buft(dataset->ctx, ggml_backend_cpu_buffer_type());

    // Start with the identity permutation over all shards.
    const int64_t nshards = ndata/ndata_shard;
    dataset->permutation.resize(nshards);
    int64_t ishard = 0;
    for (int64_t & entry : dataset->permutation) {
        entry = ishard++;
    }
    return dataset;
}
131
132
0
// Releases the buffer, the context and the dataset object itself.
// Passing nullptr is a no-op, consistent with ggml_opt_free.
void ggml_opt_dataset_free(ggml_opt_dataset_t dataset) {
    if (dataset == nullptr) {
        return;
    }
    ggml_backend_buffer_free(dataset->buf);
    ggml_free(dataset->ctx);
    delete dataset;
}
137
138
0
// Returns the total number of datapoints in the dataset.
int64_t ggml_opt_dataset_ndata(ggml_opt_dataset_t dataset) {
    const int64_t ndata = dataset->ndata;
    return ndata;
}
141
142
0
// Returns the tensor that holds all datapoints of the dataset.
struct ggml_tensor * ggml_opt_dataset_data(ggml_opt_dataset_t dataset) {
    struct ggml_tensor * data = dataset->data;
    return data;
}
145
146
0
// Returns the tensor that holds all labels of the dataset, nullptr if the dataset has no labels.
struct ggml_tensor * ggml_opt_dataset_labels(ggml_opt_dataset_t dataset) {
    struct ggml_tensor * labels = dataset->labels;
    return labels;
}
149
150
0
// Shuffles the shard permutation of the dataset using the RNG of opt_ctx.
// idata < 0 shuffles all shards; otherwise only the shards covering the first idata datapoints
// are shuffled among themselves (idata must be a multiple of the shard size).
void ggml_opt_dataset_shuffle(ggml_opt_context_t opt_ctx, ggml_opt_dataset_t dataset, int64_t idata) {
    GGML_ASSERT(idata <= dataset->ndata);

    auto first = dataset->permutation.begin();
    auto last  = dataset->permutation.end();

    if (idata >= 0) {
        // Restrict the shuffle to the leading shards.
        GGML_ASSERT(idata % dataset->ndata_shard == 0);
        last = first + idata / dataset->ndata_shard;
    }

    std::shuffle(first, last, opt_ctx->rng);
}
162
163
0
// Copies batch number ibatch of the dataset into data_batch (and labels_batch if the dataset has labels).
// The batch size is derived from the size of data_batch and must be a whole number of shards.
// The shards are read in the order given by the current permutation and written via ggml_backend_tensor_set.
//
// data_batch:   contiguous tensor with the same type as the dataset data.
// labels_batch: contiguous tensor with the same type as the dataset labels; must be nullptr
//               if and only if the dataset has no labels.
// ibatch:       non-negative batch index; (ibatch + 1) batches must fit into the dataset.
void ggml_opt_dataset_get_batch(ggml_opt_dataset_t dataset, struct ggml_tensor * data_batch, struct ggml_tensor * labels_batch, int64_t ibatch) {
    GGML_ASSERT(   data_batch && ggml_is_contiguous(data_batch));
    GGML_ASSERT(!labels_batch || ggml_is_contiguous(labels_batch));
    GGML_ASSERT((labels_batch == nullptr) == (dataset->labels == nullptr));
    GGML_ASSERT(                   data_batch->type == dataset->data->type);
    GGML_ASSERT(!labels_batch || labels_batch->type == dataset->labels->type);

    const size_t nb_data_batch = ggml_nbytes(data_batch);
    GGML_ASSERT(nb_data_batch % dataset->nbs_data == 0);
    const int64_t shards_per_batch = nb_data_batch / dataset->nbs_data;

    if (labels_batch) {
        const size_t nb_labels_batch = ggml_nbytes(labels_batch);
        GGML_ASSERT(nb_labels_batch == shards_per_batch*dataset->nbs_labels);
    }

    // A negative batch index would pass the upper bound check below and index the permutation out of bounds.
    GGML_ASSERT(ibatch >= 0);
    GGML_ASSERT((ibatch + 1)*shards_per_batch <= int64_t(dataset->permutation.size()));

    for (int64_t ishard_batch = 0; ishard_batch < shards_per_batch; ++ishard_batch) {
        // Shard of the dataset that ends up at position ishard_batch of the batch.
        const int64_t ishard = dataset->permutation[ibatch*shards_per_batch + ishard_batch];

        const char * ptr_data = (const char *) dataset->data->data + ishard*dataset->nbs_data;
        ggml_backend_tensor_set(data_batch, ptr_data, ishard_batch*dataset->nbs_data, dataset->nbs_data);

        if (!labels_batch) {
            continue;
        }

        const char * ptr_labels = (const char *) dataset->labels->data + ishard*dataset->nbs_labels;
        ggml_backend_tensor_set(labels_batch, ptr_labels, ishard_batch*dataset->nbs_labels, dataset->nbs_labels);
    }
}
195
196
0
// Copies batch number ibatch of the dataset into plain host memory.
// The shards are read in the order given by the current permutation and copied with memcpy.
//
// data_batch:    destination for nb_data_batch bytes of data; nb_data_batch must be a whole number of shards.
// labels_batch:  destination for the corresponding labels; must be nullptr if and only if
//                the dataset has no labels.
// ibatch:        non-negative batch index; (ibatch + 1) batches must fit into the dataset.
void ggml_opt_dataset_get_batch_host(ggml_opt_dataset_t dataset, void * data_batch, size_t nb_data_batch, void * labels_batch, int64_t ibatch) {
    GGML_ASSERT((labels_batch == nullptr) == (dataset->labels == nullptr));
    GGML_ASSERT(nb_data_batch % dataset->nbs_data == 0);

    const int64_t shards_per_batch = nb_data_batch / dataset->nbs_data;

    // A negative batch index would pass the upper bound check below and index the permutation out of bounds.
    GGML_ASSERT(ibatch >= 0);
    GGML_ASSERT((ibatch + 1)*shards_per_batch <= int64_t(dataset->permutation.size()));

    for (int64_t ishard_batch = 0; ishard_batch < shards_per_batch; ++ishard_batch) {
        // Shard of the dataset that ends up at position ishard_batch of the batch.
        const int64_t ishard = dataset->permutation[ibatch*shards_per_batch + ishard_batch];

        const char * ptr_data       = (const char *) dataset->data->data + ishard      *dataset->nbs_data;
        char       * ptr_data_batch = (char       *) data_batch          + ishard_batch*dataset->nbs_data;
        memcpy(ptr_data_batch, ptr_data, dataset->nbs_data);

        if (!labels_batch) {
            continue;
        }

        const char * ptr_labels       = (const char *) dataset->labels->data + ishard      *dataset->nbs_labels;
        char       * ptr_labels_batch = (char       *) labels_batch          + ishard_batch*dataset->nbs_labels;
        memcpy(ptr_labels_batch, ptr_labels, dataset->nbs_labels);
    }
}
220
221
// ====== Model / Context ======
222
223
0
struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata) {
224
0
    GGML_UNUSED(userdata);
225
226
0
    ggml_opt_optimizer_params result;
227
228
0
    result.adamw.alpha = 0.001f;
229
0
    result.adamw.beta1 = 0.9f;
230
0
    result.adamw.beta2 = 0.999f;
231
0
    result.adamw.eps   = 1e-8f;
232
0
    result.adamw.wd    = 0.0f;
233
234
0
    result.sgd.alpha   = 1e-3f;
235
0
    result.sgd.wd      = 0.0f;
236
237
0
    return result;
238
0
}
239
240
241
0
// Callback that returns a copy of the ggml_opt_optimizer_params that userdata points to.
// userdata must therefore point to a valid struct ggml_opt_optimizer_params.
struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata) {
    struct ggml_opt_optimizer_params * fixed_params = (struct ggml_opt_optimizer_params *) userdata;
    return *fixed_params;
}
244
245
struct ggml_opt_params ggml_opt_default_params(
246
        ggml_backend_sched_t      backend_sched,
247
0
        enum ggml_opt_loss_type   loss_type) {
248
0
    return {
249
0
        /*backend_sched   =*/ backend_sched,
250
0
        /*ctx_compute     =*/ nullptr,
251
0
        /*inputs          =*/ nullptr,
252
0
        /*logits          =*/ nullptr,
253
0
        /*loss_type       =*/ loss_type,
254
0
        /*build_type      =*/ GGML_OPT_BUILD_TYPE_OPT,
255
0
        /*opt_period      =*/ 1,
256
0
        /*get_opt_pars    =*/ ggml_opt_get_default_optimizer_params,
257
0
        /*get_opt_pars_ud =*/ nullptr,
258
0
        /*optimizer       =*/ GGML_OPT_OPTIMIZER_TYPE_ADAMW,
259
0
    };
260
0
}
261
262
0
// Returns the duplicate of `tensor` in `ctx`, creating it (and, recursively, the duplicates of its
// view source and its sources) on first use. `tensor_map` remembers which tensors were already
// duplicated so that each original maps to exactly one copy. The copy shares data, buffer and extra
// with the original. A nullptr input maps to nullptr.
static ggml_tensor * map_tensor(std::map<ggml_tensor *, ggml_tensor *> & tensor_map, ggml_context * ctx, ggml_tensor * tensor) {
    if (tensor == nullptr) {
        return nullptr;
    }

    const auto known = tensor_map.find(tensor);
    if (known != tensor_map.end()) {
        return known->second;
    }

    // Register the copy before recursing so that repeated references resolve to the same copy.
    ggml_tensor * copy = ggml_dup_tensor(ctx, tensor);
    tensor_map[tensor] = copy;

    copy->op = tensor->op;
    for (int idim = 0; idim < GGML_MAX_DIMS; idim++) {
        copy->nb[idim] = tensor->nb[idim];
    }
    copy->flags = tensor->flags;
    memcpy(copy->op_params, tensor->op_params, sizeof(tensor->op_params));
    strcpy(copy->name, tensor->name);

    // The copy aliases the memory of the original instead of owning its own.
    copy->data      = tensor->data;
    copy->buffer    = tensor->buffer;
    copy->extra     = tensor->extra;
    copy->view_offs = tensor->view_offs;

    copy->view_src = map_tensor(tensor_map, ctx, tensor->view_src);
    for (int isrc = 0; isrc < GGML_MAX_SRC; isrc++) {
        copy->src[isrc] = map_tensor(tensor_map, ctx, tensor->src[isrc]);
    }

    return copy;
}
292
293
0
// Creates a copy of the graph `src` in `ctx`. All leafs and nodes are duplicated via map_tensor,
// the gradient and gradient accumulator pointers are taken over from `src` unchanged.
// Asserts that the copy ends up with the same number of leafs and nodes as the original.
static ggml_cgraph * dup_graph(ggml_context * ctx, ggml_cgraph * src) {
    std::map<ggml_tensor *, ggml_tensor *> copies;

    ggml_cgraph * dst = ggml_new_graph_custom(ctx, src->size, /*grads =*/ true);

    for (int ileaf = 0; ileaf < src->n_leafs; ileaf++) {
        ggml_tensor * leaf_copy = map_tensor(copies, ctx, src->leafs[ileaf]);
        ggml_build_forward_expand(dst, leaf_copy);
    }
    GGML_ASSERT(dst->n_leafs == src->n_leafs);

    for (int inode = 0; inode < src->n_nodes; inode++) {
        ggml_tensor * node_copy = map_tensor(copies, ctx, src->nodes[inode]);
        ggml_build_forward_expand(dst, node_copy);
    }
    GGML_ASSERT(dst->n_nodes == src->n_nodes);

    // Gradients are stored at the hash set slot of their node, the slots differ between the graphs.
    for (int inode = 0; inode < src->n_nodes; ++inode) {
        const size_t slot_src = ggml_hash_find(&src->visited_hash_set, src->nodes[inode]);
        const size_t slot_dst = ggml_hash_find(&dst->visited_hash_set, dst->nodes[inode]);

        GGML_ASSERT(slot_src != GGML_HASHSET_FULL);
        GGML_ASSERT(ggml_bitset_get(src->visited_hash_set.used, slot_src));
        GGML_ASSERT(slot_dst != GGML_HASHSET_FULL);
        GGML_ASSERT(ggml_bitset_get(dst->visited_hash_set.used, slot_dst));

        dst->grads[slot_dst]     = src->grads[slot_src];
        dst->grad_accs[slot_dst] = src->grad_accs[slot_src];
    }

    return dst;
}
321
322
0
// Builds the graphs of the optimization context on top of the forward graph opt_ctx->gf:
//   - gf is extended with the loss (and for cross entropy also with pred and ncorrect),
//   - gb_grad is a copy of gf extended with the backward pass,
//   - gb_opt  is a copy of gb_grad extended with one optimizer step per parameter.
// How far the build proceeds depends on build_type and build_type_alloc, see the early returns.
// The static context, the gradient accumulators and the optimizer momenta are only created if they
// do not exist yet; the CPU context for the optimizer parameters is recreated on every call.
static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
    GGML_ASSERT(opt_ctx->ctx_compute && "no compute context set, either use static graphs or set one with ggml_opt_prepare_alloc");
    GGML_ASSERT((!opt_ctx->static_graphs || opt_ctx->inputs->data) && "when using static graphs the inputs must be allocated statically");

    const enum ggml_opt_optimizer_type optimizer = opt_ctx->optimizer;

    // Gradient accumulators are not needed if static graphs run an optimizer step after every backward pass.
    const bool accumulate = opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_GRAD &&
        !(opt_ctx->static_graphs && opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period == 1);

    // Only AdamW keeps momenta (m, v) per parameter.
    const bool need_momenta = opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT &&
        optimizer == GGML_OPT_OPTIMIZER_TYPE_ADAMW;

    ggml_set_input(opt_ctx->inputs);
    ggml_set_output(opt_ctx->outputs);

    int n_param = 0;
    for (int i = 0; i < opt_ctx->gf->n_nodes; ++i) {
        const struct ggml_tensor * node = opt_ctx->gf->nodes[i];
        if (node->flags & GGML_TENSOR_FLAG_PARAM) {
            n_param++;
        }
        GGML_ASSERT(!(node->flags & GGML_TENSOR_FLAG_LOSS) && "support for extra loss terms not implemented");
    }

    if (!opt_ctx->ctx_static) {
        // The static context is used for:
        //   - gradients (1 per loss, 1 tensor per param if using gradient accumulation)
        //   - optimizer momenta (2 tensors per param)
        //   - labels (if using static graphs)
        //   - loss (if using static graphs, up to 5 tensors)
        //   - pred (if using static graphs)
        //   - ncorrect (if using static graphs, 2 tensors).
        constexpr size_t n_loss = 1;
        const size_t tensors_per_param = (accumulate ? 1 : 0) + (need_momenta ? 2 : 0);
        const size_t tensors_const = opt_ctx->static_graphs ? 9 : 0;
        const size_t size_meta = (n_loss + tensors_per_param*n_param + tensors_const) * ggml_tensor_overhead();
        struct ggml_init_params params = {
            /*.mem_size   =*/ size_meta,
            /*.mem_buffer =*/ nullptr,
            /*.no_alloc   =*/ true,
        };
        opt_ctx->ctx_static = ggml_init(params);
    }
    GGML_ASSERT(opt_ctx->build_type <= opt_ctx->build_type_alloc);

    {
        // The cpu context is allocated statically if using static graphs, dynamically otherwise.
        // It is used for:
        //   - optimizer parameters (1 shared for all optimizer invocations)
        const size_t size_meta = 1 * ggml_tensor_overhead();
        struct ggml_init_params params = {
            /*.mem_size   =*/ size_meta,
            /*.mem_buffer =*/ nullptr,
            /*.no_alloc   =*/ true,
        };
        ggml_free(opt_ctx->ctx_cpu);
        opt_ctx->ctx_cpu = ggml_init(params);

        ggml_backend_buffer_free(opt_ctx->buf_cpu);
        opt_ctx->buf_cpu = nullptr;
    }

    // With static graphs the result tensors must outlive the compute context, so they go into the static context.
    struct ggml_context * ctx_results = opt_ctx->static_graphs ? opt_ctx->ctx_static : opt_ctx->ctx_compute;

    switch (opt_ctx->loss_type) {
        case GGML_OPT_LOSS_TYPE_MEAN: {
            opt_ctx->loss = ggml_sum(ctx_results, opt_ctx->outputs);
            ggml_set_name(opt_ctx->loss, "loss_sum");
            const float scale = 1.0f / (opt_ctx->opt_period * ggml_nelements(opt_ctx->outputs));
            opt_ctx->loss = ggml_scale(ctx_results, opt_ctx->loss, scale);
            ggml_set_name(opt_ctx->loss, "loss_mean");
            opt_ctx->loss_per_datapoint = true;
            break;
        }
        case GGML_OPT_LOSS_TYPE_SUM: {
            opt_ctx->loss = ggml_sum(ctx_results, opt_ctx->outputs);
            ggml_set_name(opt_ctx->loss, "loss_sum");
            opt_ctx->loss_per_datapoint = false;
            break;
        }
        case GGML_OPT_LOSS_TYPE_CROSS_ENTROPY: {
            opt_ctx->labels = ggml_dup_tensor(ctx_results, opt_ctx->outputs);
            ggml_set_input(opt_ctx->labels);
            ggml_set_name(opt_ctx->labels, "labels");
            opt_ctx->loss = ggml_cross_entropy_loss(ctx_results, opt_ctx->outputs, opt_ctx->labels);
            ggml_set_name(opt_ctx->loss, "loss_cross_entropy");
            if (opt_ctx->opt_period > 1) {
                opt_ctx->loss = ggml_scale(ctx_results, opt_ctx->loss, 1.0f / opt_ctx->opt_period);
                ggml_set_name(opt_ctx->loss, "loss_cross_entropy_scaled");
            }
            opt_ctx->loss_per_datapoint = true;
            break;
        }
        case GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR: {
            opt_ctx->labels = ggml_dup_tensor(ctx_results, opt_ctx->outputs);
            ggml_set_input(opt_ctx->labels);
            ggml_set_name(opt_ctx->labels, "labels");
            opt_ctx->loss = ggml_sub(ctx_results, opt_ctx->outputs, opt_ctx->labels);
            ggml_set_name(opt_ctx->loss, "loss_error");
            opt_ctx->loss = ggml_sqr(ctx_results, opt_ctx->loss);
            ggml_set_name(opt_ctx->loss, "loss_squared_error");
            opt_ctx->loss = ggml_sum(ctx_results, opt_ctx->loss);
            ggml_set_name(opt_ctx->loss, "loss_sum_squared_error");
            const float scale = 1.0f / (opt_ctx->opt_period * ggml_nelements(opt_ctx->outputs));
            opt_ctx->loss = ggml_scale(ctx_results, opt_ctx->loss, scale);
            ggml_set_name(opt_ctx->loss, "loss_mean_squared_error");
            opt_ctx->loss_per_datapoint = true;
            break;
        }
    }
    ggml_set_output(opt_ctx->loss);
    ggml_set_loss(opt_ctx->loss);
    ggml_build_forward_expand(opt_ctx->gf, opt_ctx->loss);

    if (opt_ctx->loss_type == GGML_OPT_LOSS_TYPE_CROSS_ENTROPY) {
        opt_ctx->pred = ggml_argmax(ctx_results, opt_ctx->outputs);
        ggml_set_name(opt_ctx->pred, "pred");
        ggml_set_output(opt_ctx->pred);
        ggml_build_forward_expand(opt_ctx->gf, opt_ctx->pred);

        opt_ctx->ncorrect = ggml_count_equal(ctx_results, opt_ctx->pred, ggml_argmax(ctx_results, opt_ctx->labels));
        ggml_set_name(opt_ctx->ncorrect, "ncorrect");
        ggml_set_output(opt_ctx->ncorrect);
        ggml_build_forward_expand(opt_ctx->gf, opt_ctx->ncorrect);
    }

    // Forward pass only: nothing else to build.
    if (opt_ctx->buf_static) {
        if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_FORWARD) {
            return;
        }
    } else if (opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_FORWARD) {
        opt_ctx->buf_static = ggml_backend_alloc_ctx_tensors(
            opt_ctx->ctx_static, ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0));
        return;
    }

    if (opt_ctx->grad_accs.empty()) {
        GGML_ASSERT(opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_GRAD);

        // The vectors are indexed by the position of the node in the forward graph.
        const int n_nodes = opt_ctx->gf->n_nodes;
        opt_ctx->grad_accs.resize(n_nodes);
        for (int i = 0; i < n_nodes; ++i) {
            ggml_tensor * node = opt_ctx->gf->nodes[i];
            if ((accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) || (node->flags & GGML_TENSOR_FLAG_LOSS)) {
                opt_ctx->grad_accs[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
            } else {
                opt_ctx->grad_accs[i] = nullptr;
            }
        }

        // need_momenta already implies build_type_alloc == GGML_OPT_BUILD_TYPE_OPT.
        if (need_momenta) {
            opt_ctx->grad_m.resize(n_nodes);
            opt_ctx->grad_v.resize(n_nodes);
            for (int i = 0; i < n_nodes; ++i) {
                ggml_tensor * node = opt_ctx->gf->nodes[i];
                if (node->flags & GGML_TENSOR_FLAG_PARAM) {
                    opt_ctx->grad_m[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
                    opt_ctx->grad_v[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
                } else {
                    opt_ctx->grad_m[i] = nullptr;
                    opt_ctx->grad_v[i] = nullptr;
                }
            }
        }
    }

    // gb_grad == graph backward gradients, forward pass, then backward pass to calculate gradients.
    opt_ctx->gb_grad = ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gf, /*force_grads =*/ true);
    ggml_build_backward_expand(opt_ctx->ctx_compute, opt_ctx->gb_grad, opt_ctx->grad_accs.data());

    // Forward + backward pass only: no optimizer step to build.
    if (opt_ctx->buf_static) {
        if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_GRAD) {
            return;
        }
    } else if (opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_GRAD) {
        opt_ctx->buf_static = ggml_backend_alloc_ctx_tensors(opt_ctx->ctx_static, ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0));
        ggml_graph_reset(opt_ctx->gb_grad);
        // Without this return the assert below would always abort for a context that was set up for gradients only.
        return;
    }

    GGML_ASSERT(opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT);

    // gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step.
    opt_ctx->gb_opt = ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gb_grad, /*force_grads =*/ true);

    // A single parameter tensor is shared by all optimizer steps: 7 values for AdamW, 2 otherwise.
    opt_ctx->opt_step_params = ggml_new_tensor_1d(opt_ctx->ctx_cpu, GGML_TYPE_F32, need_momenta ? 7 : 2);
    ggml_tensor * step_params = opt_ctx->opt_step_params;
    ggml_set_input(step_params);
    const char * optimizer_name = ggml_opt_optimizer_name(opt_ctx->optimizer);
    ggml_format_name(step_params, "%s_params", optimizer_name);
    for (int i = opt_ctx->gf->n_nodes-1; i >= 0; --i) {
        struct ggml_tensor * node = opt_ctx->gb_opt->nodes[i];
        struct ggml_tensor * grad = ggml_graph_get_grad(opt_ctx->gb_opt, node);

        if (grad && (node->flags & GGML_TENSOR_FLAG_PARAM)) {
            struct ggml_tensor * m = nullptr;
            struct ggml_tensor * v = nullptr;
            if (need_momenta) {
                m = opt_ctx->grad_m[i];
                v = opt_ctx->grad_v[i];
                ggml_format_name(m, "AdamW m for %s", node->name);
                ggml_format_name(v, "AdamW v for %s", node->name);
            }
            struct ggml_tensor * opt_step;
            switch (optimizer) {
                case GGML_OPT_OPTIMIZER_TYPE_ADAMW:
                    opt_step = ggml_opt_step_adamw(opt_ctx->ctx_compute, node, grad, m, v, step_params);
                    break;
                case GGML_OPT_OPTIMIZER_TYPE_SGD:
                    opt_step = ggml_opt_step_sgd(opt_ctx->ctx_compute, node, grad, step_params);
                    break;
                default:
                    GGML_ABORT("fatal error");
            }
            ggml_format_name(opt_step, "%s step for %s", optimizer_name, node->name);
            ggml_build_forward_expand(opt_ctx->gb_opt, opt_step);
        }
    }

    if (!opt_ctx->buf_static) {
        opt_ctx->buf_static = ggml_backend_alloc_ctx_tensors(
            opt_ctx->ctx_static, ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0));
        ggml_graph_reset(opt_ctx->gb_opt);
    }

    opt_ctx->buf_cpu = ggml_backend_alloc_ctx_tensors_from_buft(opt_ctx->ctx_cpu, ggml_backend_cpu_buffer_type());
}
548
549
0
// Creates a new optimization context from params.
// If params.ctx_compute is set the context uses static graphs: inputs and outputs are required and
// all graphs are built right away. Otherwise inputs and outputs must be unset and the graphs are
// built later, after ggml_opt_prepare_alloc has provided them.
// The returned context must be released with ggml_opt_free.
ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
    ggml_opt_context_t opt_ctx = new struct ggml_opt_context;

    opt_ctx->backend_sched    = params.backend_sched;
    opt_ctx->ctx_compute      = params.ctx_compute;
    opt_ctx->inputs           = params.inputs;
    opt_ctx->outputs          = params.outputs;
    opt_ctx->loss_type        = params.loss_type;
    opt_ctx->build_type       = params.build_type;
    opt_ctx->build_type_alloc = params.build_type;
    opt_ctx->opt_period       = params.opt_period;
    opt_ctx->get_opt_pars     = params.get_opt_pars;
    opt_ctx->get_opt_pars_ud  = params.get_opt_pars_ud;
    opt_ctx->optimizer        = params.optimizer;

    GGML_ASSERT(opt_ctx->opt_period >= 1);

    // A user-provided compute context selects static graphs.
    opt_ctx->static_graphs = opt_ctx->ctx_compute != nullptr;

    if (opt_ctx->static_graphs) {
        GGML_ASSERT(opt_ctx->inputs);
        GGML_ASSERT(opt_ctx->outputs);

        opt_ctx->gf = ggml_new_graph_custom(opt_ctx->ctx_compute, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); // Forward pass.
        ggml_build_forward_expand(opt_ctx->gf, opt_ctx->outputs);

        ggml_opt_build(opt_ctx);
    } else {
        GGML_ASSERT(!opt_ctx->inputs);
        GGML_ASSERT(!opt_ctx->outputs);
    }

    return opt_ctx;
}
583
584
0
// Releases all buffers and contexts owned by the optimization context and the context itself.
// The user-provided compute context is not owned and therefore not freed.
// Passing nullptr is a no-op.
void ggml_opt_free(ggml_opt_context_t opt_ctx) {
    if (opt_ctx == nullptr) {
        return;
    }
    ggml_backend_buffer_free(opt_ctx->buf_static);
    ggml_backend_buffer_free(opt_ctx->buf_cpu);
    ggml_free(opt_ctx->ctx_static);
    ggml_free(opt_ctx->ctx_cpu);
    // ctx_copy is created by ggml_opt_alloc when using static graphs and would otherwise be leaked.
    ggml_free(opt_ctx->ctx_copy);
    delete opt_ctx;
}
594
595
0
// Resets the graph state of the optimization context.
// optimizer == true:  resets the optimizer graph (gb_opt) and restarts the iteration counter at 1.
// optimizer == false: resets only the gradient graph (gb_grad).
void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer) {
    struct ggml_cgraph * graph = optimizer ? opt_ctx->gb_opt : opt_ctx->gb_grad;
    ggml_graph_reset(graph);

    if (optimizer) {
        opt_ctx->iter = 1;
    }
}
603
604
0
// Returns whether the optimization context uses static graphs.
bool ggml_opt_static_graphs(ggml_opt_context_t opt_ctx) {
    const bool is_static = opt_ctx->static_graphs;
    return is_static;
}
607
608
0
// Returns the input tensor of the forward graph.
struct ggml_tensor * ggml_opt_inputs(ggml_opt_context_t opt_ctx) {
    struct ggml_tensor * inputs = opt_ctx->inputs;
    return inputs;
}
611
612
0
// Returns the output tensor of the forward graph.
struct ggml_tensor * ggml_opt_outputs(ggml_opt_context_t opt_ctx) {
    struct ggml_tensor * outputs = opt_ctx->outputs;
    return outputs;
}
615
616
0
// Returns the labels tensor, nullptr if none has been created for the configured loss type.
struct ggml_tensor * ggml_opt_labels(ggml_opt_context_t opt_ctx) {
    struct ggml_tensor * labels = opt_ctx->labels;
    return labels;
}
619
620
0
// Returns the loss tensor, nullptr if the graphs have not been built yet.
struct ggml_tensor * ggml_opt_loss(ggml_opt_context_t opt_ctx) {
    struct ggml_tensor * loss = opt_ctx->loss;
    return loss;
}
623
624
0
// Returns the tensor with the predictions, nullptr if none has been created for the configured loss type.
struct ggml_tensor * ggml_opt_pred(ggml_opt_context_t opt_ctx) {
    struct ggml_tensor * pred = opt_ctx->pred;
    return pred;
}
627
628
0
// Returns the tensor with the number of correct predictions, nullptr if none has been created
// for the configured loss type.
struct ggml_tensor * ggml_opt_ncorrect(ggml_opt_context_t opt_ctx) {
    struct ggml_tensor * ncorrect = opt_ctx->ncorrect;
    return ncorrect;
}
631
632
0
// Looks up the gradient accumulator of node in the optimizer graph (gb_opt).
struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node) {
    struct ggml_cgraph * graph = opt_ctx->gb_opt;
    return ggml_graph_get_grad_acc(graph, node);
}
635
636
// ====== Optimization Result ======
637
638
0
// Allocates an empty result object; it must be released with ggml_opt_result_free.
ggml_opt_result_t ggml_opt_result_init() {
    ggml_opt_result_t result = new ggml_opt_result;
    return result;
}
641
642
0
// Destroys a result object created by ggml_opt_result_init.
// Deleting a null pointer has no effect, so nullptr is accepted.
void ggml_opt_result_free(ggml_opt_result_t result) {
    delete result;
}
645
646
0
// Discards the accumulated data of a result: counters go back to 0, losses and predictions are cleared.
// opt_period and loss_per_datapoint are left untouched.
void ggml_opt_result_reset(ggml_opt_result_t result) {
    // Counters.
    result->ndata    = 0;
    result->ncorrect = 0;

    // Per-batch / per-datapoint values.
    result->loss.clear();
    result->pred.clear();
}
652
653
0
// Writes the number of datapoints covered by the result to *ndata.
void ggml_opt_result_ndata(ggml_opt_result_t result, int64_t * ndata) {
    const int64_t count = result->ndata;
    *ndata = count;
}
656
657
0
// Computes the loss over all batches stored in the result and optionally its uncertainty.
//
// loss: receives the mean of the per-batch losses if the loss is per datapoint, their sum otherwise.
//       Set to 0.0 if the result holds no batches.
// unc:  optional (may be nullptr); receives the uncertainty of *loss estimated from the spread of the
//       per-batch losses, NAN if there are fewer than 2 batches.
void ggml_opt_result_loss(ggml_opt_result_t result, double * loss, double * unc) {
    const int64_t nbatches = result->loss.size(); // Number of physical batches.

    if (nbatches == 0) {
        *loss = 0.0;
        // unc is optional, it must not be dereferenced unconditionally.
        if (unc) {
            *unc = NAN;
        }
        return;
    }

    double sum         = 0.0;
    double sum_squared = 0.0;

    for (const float loss_batch : result->loss) {
        // If the loss is per datapoint it was scaled by 1.0f/opt_period for each physical batch.
        const float loss_scaled = result->loss_per_datapoint ? loss_batch*result->opt_period : loss_batch;
        sum         += loss_scaled;
        sum_squared += loss_scaled*loss_scaled;
    }

    const double mean = sum/nbatches;
    *loss = result->loss_per_datapoint ? mean : sum;

    if (!unc) {
        return;
    }

    if (nbatches < 2) {
        *unc = NAN;
        return;
    }

    const double var_sum = sum_squared/nbatches - mean*mean; // variance without Bessel's correction, i.e. nbatches/(nbatches-1)
    *unc = result->loss_per_datapoint ? sqrt(var_sum / (nbatches - 1)) : sqrt(var_sum * nbatches/(nbatches - 1));
}
691
692
0
// Copies all predictions stored in the result to pred.
// pred must have room for as many values as the result holds predictions.
void ggml_opt_result_pred(ggml_opt_result_t result, int32_t * pred) {
    std::copy(result->pred.begin(), result->pred.end(), pred);
}
697
698
0
// Computes the fraction of correct predictions and optionally its uncertainty.
//
// accuracy: receives ncorrect/ndata, NAN if the number of correct predictions is negative.
// unc:      optional (may be nullptr); receives sqrt(accuracy*(1 - accuracy)/(ndata - 1)),
//           NAN if the accuracy is unavailable or there are fewer than 2 datapoints.
void ggml_opt_result_accuracy(ggml_opt_result_t result, double * accuracy, double * unc) {
    const bool    have_ncorrect = result->ncorrect >= 0;
    const int64_t ndata         = result->ndata;

    if (have_ncorrect) {
        *accuracy = double(result->ncorrect) / double(ndata);
    } else {
        *accuracy = NAN;
    }

    if (unc == nullptr) {
        return;
    }

    if (have_ncorrect && ndata >= 2) {
        const double acc = *accuracy;
        *unc = sqrt(acc * (1.0 - acc) / double(ndata - 1));
    } else {
        *unc = NAN;
    }
}
708
709
// ====== Computation ======
710
711
void ggml_opt_prepare_alloc(
712
        ggml_opt_context_t    opt_ctx,
713
        struct ggml_context * ctx_compute,
714
        struct ggml_cgraph  * gf,
715
        struct ggml_tensor  * inputs,
716
0
        struct ggml_tensor  * outputs) {
717
0
    GGML_ASSERT(!opt_ctx->static_graphs);
718
0
    opt_ctx->ctx_compute = ctx_compute;
719
0
    opt_ctx->gf          = gf;
720
0
    opt_ctx->inputs      = inputs;
721
0
    opt_ctx->outputs     = outputs;
722
0
}
723
724
0
// Selects the graph for the next evaluation and makes sure it is allocated in the backend scheduler.
//   backward == false: the forward graph (gf) is used.
//   backward == true:  gb_opt is used if the next value of opt_i wraps around to 0, gb_grad otherwise.
// Must not be called twice without a ggml_opt_eval in between (eval_ready is asserted to be false).
// On return eval_ready is true.
void ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward) {
    GGML_ASSERT(!opt_ctx->eval_ready);

    // build_type still holds the value chosen by the previous call at this point.
    const bool first_of_period = opt_ctx->opt_i == 0;
    if (first_of_period && opt_ctx->opt_period > 1 && opt_ctx->build_type == GGML_OPT_BUILD_TYPE_OPT) {
        ggml_graph_reset(opt_ctx->gb_grad);
    }

    if (!backward) {
        opt_ctx->build_type = GGML_OPT_BUILD_TYPE_FORWARD;
    } else {
        const int32_t opt_i_next = (opt_ctx->opt_i + 1) % opt_ctx->opt_period;
        opt_ctx->build_type = opt_i_next != 0 ? GGML_OPT_BUILD_TYPE_GRAD : GGML_OPT_BUILD_TYPE_OPT;
    }

    // Without static graphs the graphs are (re)built on every call.
    if (!opt_ctx->static_graphs) {
        ggml_opt_build(opt_ctx);
    }

    struct ggml_cgraph * graph = nullptr;
    switch (opt_ctx->build_type) {
        case GGML_OPT_BUILD_TYPE_FORWARD: graph = opt_ctx->gf;      break;
        case GGML_OPT_BUILD_TYPE_GRAD:    graph = opt_ctx->gb_grad; break;
        case GGML_OPT_BUILD_TYPE_OPT:     graph = opt_ctx->gb_opt;  break;
    }
    GGML_ASSERT(graph);

    // Nothing to (re)allocate if the selected graph is the one that is already allocated.
    if (opt_ctx->allocated_graph != graph) {
        ggml_backend_sched_reset(opt_ctx->backend_sched); // clear allocation of previous graph

        if (!opt_ctx->static_graphs) {
            opt_ctx->allocated_graph_copy = graph;
        } else {
            // Static graphs are not allocated directly, a duplicate living in a fresh ctx_copy is used instead.
            ggml_init_params params = {
                /*.mem_size   =*/ graph->size*ggml_tensor_overhead() + ggml_graph_overhead_custom(graph->size, graph->grads),
                /*.mem_buffer =*/ nullptr,
                /*.no_alloc   =*/ true,
            };
            ggml_free(opt_ctx->ctx_copy);
            opt_ctx->ctx_copy = ggml_init(params);

            opt_ctx->allocated_graph_copy = dup_graph(opt_ctx->ctx_copy, graph);
        }

        ggml_backend_sched_alloc_graph(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
        opt_ctx->allocated_graph = graph;
    }

    opt_ctx->eval_ready = true;
}
780
781
0
// Computes the graph prepared by ggml_opt_alloc (eval_ready is asserted) and optionally records statistics.
//   1. If the allocated graph is gb_opt, the optimizer hyperparameters are fetched via get_opt_pars,
//      validated and written to opt_step_params.
//   2. The graph is computed; iter (optimizer steps only) and opt_i are advanced.
//   3. Without static graphs all graph pointers are cleared, eval_ready is reset in any case.
//   4. If result is non-null: ndata, loss, predictions (if available) and ncorrect (if available) are accumulated.
//      All batches accumulated in one result must have the same size (asserted).
void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result) {
    GGML_ASSERT(opt_ctx->eval_ready);

    const bool is_opt_step = opt_ctx->allocated_graph == opt_ctx->gb_opt;
    if (is_opt_step) {
        const ggml_opt_optimizer_params & opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud);

        switch (opt_ctx->optimizer) {
            case GGML_OPT_OPTIMIZER_TYPE_ADAMW: {
                GGML_ASSERT(opt_pars.adamw.alpha > 0.0f);
                GGML_ASSERT(opt_pars.adamw.beta1 >= 0.0f);
                GGML_ASSERT(opt_pars.adamw.beta1 <= 1.0f);
                GGML_ASSERT(opt_pars.adamw.beta2 >= 0.0f);
                GGML_ASSERT(opt_pars.adamw.beta2 <= 1.0f);
                GGML_ASSERT(opt_pars.adamw.eps >= 0.0f);
                GGML_ASSERT(opt_pars.adamw.wd >= 0.0f);
                GGML_ASSERT(opt_pars.adamw.wd <= 1.0f);

                // Layout of the parameter tensor: alpha, beta1, beta2, eps, wd, beta1h, beta2h
                float * adamw_par_data = ggml_get_data_f32(opt_ctx->opt_step_params);
                adamw_par_data[0] = opt_pars.adamw.alpha;
                adamw_par_data[1] = opt_pars.adamw.beta1;
                adamw_par_data[2] = opt_pars.adamw.beta2;
                adamw_par_data[3] = opt_pars.adamw.eps;
                adamw_par_data[4] = opt_pars.adamw.wd;
                // beta1h, beta2h: 1/(1 - beta^iter), depends on the number of optimizer steps so far
                adamw_par_data[5] = 1.0f / (1.0f - powf(opt_pars.adamw.beta1, opt_ctx->iter));
                adamw_par_data[6] = 1.0f / (1.0f - powf(opt_pars.adamw.beta2, opt_ctx->iter));
            } break;
            case GGML_OPT_OPTIMIZER_TYPE_SGD: {
                GGML_ASSERT(opt_pars.sgd.alpha > 0.0f);
                GGML_ASSERT(opt_pars.sgd.wd >= 0.0f);
                GGML_ASSERT(opt_pars.sgd.wd <= 1.0f);

                // Layout of the parameter tensor: alpha, wd
                float * sgd_par_data = ggml_get_data_f32(opt_ctx->opt_step_params);
                sgd_par_data[0] = opt_pars.sgd.alpha;
                sgd_par_data[1] = opt_pars.sgd.wd;
            } break;
            default:
                GGML_ABORT("fatal error");
        }
    }

    ggml_backend_sched_graph_compute(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);

    // iter counts optimizer steps only, opt_i counts evaluations modulo opt_period.
    if (is_opt_step) {
        ++opt_ctx->iter;
    }
    opt_ctx->opt_i = (opt_ctx->opt_i + 1) % opt_ctx->opt_period;

    // Without static graphs the graph pointers are dropped after each evaluation.
    if (!opt_ctx->static_graphs) {
        opt_ctx->allocated_graph_copy = nullptr;
        opt_ctx->allocated_graph      = nullptr;
        opt_ctx->gb_opt               = nullptr;
        opt_ctx->gb_grad              = nullptr;
        opt_ctx->gf                   = nullptr;
    }

    opt_ctx->eval_ready = false;

    if (result == nullptr) {
        return;
    }

    // An empty result adopts the settings of the context, afterwards they have to stay the same.
    if (result->ndata != 0) {
        GGML_ASSERT(result->loss_per_datapoint == opt_ctx->loss_per_datapoint);
        GGML_ASSERT(result->opt_period == opt_ctx->opt_period);
    } else {
        result->loss_per_datapoint = opt_ctx->loss_per_datapoint;
        result->opt_period         = opt_ctx->opt_period;
    }

    // The number of datapoints in this batch is taken from dimension 1 of the outputs.
    const int64_t ndata = opt_ctx->outputs->ne[1];
    GGML_ASSERT(result->ndata == ndata*int64_t(result->loss.size()) && "varying batch size not supported");
    result->ndata += ndata;

    GGML_ASSERT(ggml_is_scalar(opt_ctx->loss));
    GGML_ASSERT(opt_ctx->loss->type == GGML_TYPE_F32);
    float batch_loss;
    ggml_backend_tensor_get(opt_ctx->loss, &batch_loss, 0, ggml_nbytes(opt_ctx->loss));
    result->loss.push_back(batch_loss);

    if (opt_ctx->pred) {
        GGML_ASSERT(opt_ctx->pred->type == GGML_TYPE_I32);
        std::vector<int32_t> batch_pred(ndata);
        ggml_backend_tensor_get(opt_ctx->pred, batch_pred.data(), 0, ggml_nbytes(opt_ctx->pred));
        result->pred.insert(result->pred.end(), batch_pred.begin(), batch_pred.end());
    }

    // A missing ncorrect tensor (now or for an earlier batch) makes the accumulated count permanently unavailable.
    const bool ncorrect_available = opt_ctx->ncorrect && result->ncorrect >= 0;
    if (!ncorrect_available) {
        result->ncorrect = -1;
        return;
    }

    GGML_ASSERT(ggml_is_scalar(opt_ctx->ncorrect));
    GGML_ASSERT(opt_ctx->ncorrect->type == GGML_TYPE_I64);
    int64_t batch_ncorrect;
    ggml_backend_tensor_get(opt_ctx->ncorrect, &batch_ncorrect, 0, ggml_nbytes(opt_ctx->ncorrect));
    result->ncorrect += batch_ncorrect;
}
877
878
// ====== High-Level Functions ======
879
880
void ggml_opt_epoch(
881
        ggml_opt_context_t      opt_ctx,
882
        ggml_opt_dataset_t      dataset,
883
        ggml_opt_result_t       result_train,
884
        ggml_opt_result_t       result_eval,
885
        int64_t                 idata_split,
886
        ggml_opt_epoch_callback callback_train,
887
0
        ggml_opt_epoch_callback callback_eval) {
888
0
    GGML_ASSERT(ggml_opt_static_graphs(opt_ctx) && "ggml_opt_epoch requires static graphs");
889
0
    struct ggml_tensor * inputs = ggml_opt_inputs(opt_ctx);
890
0
    struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
891
0
    struct ggml_tensor * data   = ggml_opt_dataset_data(dataset);
892
0
    GGML_ASSERT(data->ne[0] == inputs->ne[0]);
893
894
0
    const int64_t ndata       =   data->ne[1];
895
0
    const int64_t ndata_batch = inputs->ne[1];
896
897
0
    GGML_ASSERT(data->ne[1] % inputs->ne[1] == 0);
898
0
    const int64_t nbatches = ndata/ndata_batch;
899
900
0
    idata_split = idata_split < 0 ? ndata : idata_split;
901
0
    GGML_ASSERT(idata_split % ndata_batch == 0);
902
0
    const int64_t ibatch_split = idata_split / ndata_batch;
903
904
0
    int64_t ibatch = 0;
905
0
    int64_t t_loop_start = ggml_time_us();
906
0
    for (; ibatch < ibatch_split; ++ibatch) {
907
0
        ggml_opt_alloc(opt_ctx, /*backward =*/ true);
908
0
        ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch);
909
0
        ggml_opt_eval(opt_ctx, result_train);
910
0
        if (callback_train) {
911
0
            callback_train(true, opt_ctx, dataset, result_train, ibatch+1, ibatch_split, t_loop_start);
912
0
        }
913
0
    }
914
0
    t_loop_start = ggml_time_us();
915
0
    for (; ibatch < nbatches; ++ibatch) {
916
0
        ggml_opt_alloc(opt_ctx, /*backward =*/ false);
917
0
        ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch);
918
0
        ggml_opt_eval(opt_ctx, result_eval);
919
0
        if (callback_eval) {
920
0
            callback_eval(false, opt_ctx, dataset, result_eval, ibatch+1-ibatch_split, nbatches-ibatch_split, t_loop_start);
921
0
        }
922
0
    }
923
0
}
924
925
void ggml_opt_epoch_callback_progress_bar(
926
        bool               train,
927
        ggml_opt_context_t opt_ctx,
928
        ggml_opt_dataset_t dataset,
929
        ggml_opt_result_t  result,
930
        int64_t            ibatch,
931
        int64_t            ibatch_max,
932
0
        int64_t            t_start_us) {
933
0
    fprintf(stderr, "%s[", train ? "train: " : "val:   ");
934
935
    // The progress bar consists of partially filled blocks, unicode has 8 separate fill levels.
936
0
    constexpr int64_t bar_length = 8;
937
0
    const int64_t ibatch8 = 8 * ibatch;
938
0
    for (int64_t j = 0; j < bar_length; ++j) {
939
0
        if        (ibatch_max * (8*j + 8) / bar_length < ibatch8) {
940
0
            fprintf(stderr, "\u2588"); // full block
941
0
        } else if (ibatch_max * (8*j + 7) / bar_length < ibatch8) {
942
0
            fprintf(stderr, "\u2589"); // 7/8 filled
943
0
        } else if (ibatch_max * (8*j + 6) / bar_length < ibatch8) {
944
0
            fprintf(stderr, "\u258A"); // 6/8 filled
945
0
        } else if (ibatch_max * (8*j + 5) / bar_length < ibatch8) {
946
0
            fprintf(stderr, "\u258B"); // 5/8 filled
947
0
        } else if (ibatch_max * (8*j + 4) / bar_length < ibatch8) {
948
0
            fprintf(stderr, "\u258C"); // 4/8 filled
949
0
        } else if (ibatch_max * (8*j + 3) / bar_length < ibatch8) {
950
0
            fprintf(stderr, "\u258D"); // 3/8 filled
951
0
        } else if (ibatch_max * (8*j + 2) / bar_length < ibatch8) {
952
0
            fprintf(stderr, "\u258E"); // 2/8 filled
953
0
        } else if (ibatch_max * (8*j + 1) / bar_length < ibatch8) {
954
0
            fprintf(stderr, "\u258F"); // 1/8 filled
955
0
        } else {
956
0
            fprintf(stderr, " ");
957
0
        }
958
0
    }
959
960
0
    const int64_t batch_size = ggml_opt_inputs(opt_ctx)->ne[1];
961
0
    const int64_t idata      = ibatch*batch_size;
962
0
    const int64_t idata_max  = ibatch_max*batch_size;
963
964
0
    double loss;
965
0
    double loss_unc;
966
0
    ggml_opt_result_loss(result, &loss, &loss_unc);
967
968
0
    double accuracy;
969
0
    double accuracy_unc;
970
0
    ggml_opt_result_accuracy(result, &accuracy, &accuracy_unc);
971
972
0
    const int64_t t_ibatch_us = ggml_time_us() - t_start_us;
973
0
    int64_t t_ibatch_s = t_ibatch_us / 1000000;
974
0
    const int64_t t_ibatch_h = t_ibatch_s / 3600;
975
0
    t_ibatch_s -= t_ibatch_h * 3600;
976
0
    const int64_t t_ibatch_m = t_ibatch_s / 60;
977
0
    t_ibatch_s -= t_ibatch_m * 60;
978
979
0
    const int64_t t_eta_us = t_ibatch_us * (ibatch_max - ibatch)/ibatch;
980
0
    int64_t t_eta_s = t_eta_us / 1000000;
981
0
    const int64_t t_eta_h = t_eta_s / 3600;
982
0
    t_eta_s -= t_eta_h * 3600;
983
0
    const int64_t t_eta_m = t_eta_s / 60;
984
0
    t_eta_s -= t_eta_m * 60;
985
986
0
    fprintf(stderr, "] data=%07" PRId64 "/%07" PRId64 " loss=%.5lf±%.5lf acc=%.2lf±%.2lf%% "
987
0
            "t=%02" PRId64 ":%02" PRId64 ":%02" PRId64 " ETA=%02" PRId64 ":%02" PRId64 ":%02" PRId64 " \r",
988
0
            idata, idata_max, loss, loss_unc, 100.0*accuracy, 100.0*accuracy_unc,
989
0
            t_ibatch_h, t_ibatch_m, t_ibatch_s, t_eta_h, t_eta_m, t_eta_s);
990
0
    if (ibatch == ibatch_max) {
991
0
        fprintf(stderr, "\n");
992
0
    }
993
0
    fflush(stderr);
994
995
0
    GGML_UNUSED(dataset);
996
0
}
997
998
void ggml_opt_fit(
999
        ggml_backend_sched_t            backend_sched,
1000
        ggml_context                  * ctx_compute,
1001
        ggml_tensor                   * inputs,
1002
        ggml_tensor                   * outputs,
1003
        ggml_opt_dataset_t              dataset,
1004
        enum ggml_opt_loss_type         loss_type,
1005
        enum ggml_opt_optimizer_type    optimizer,
1006
        ggml_opt_get_optimizer_params   get_opt_pars,
1007
        int64_t                         nepoch,
1008
        int64_t                         nbatch_logical,
1009
        float                           val_split,
1010
0
        bool                            silent) {
1011
0
    ggml_time_init();
1012
0
    const int64_t t_start_us = ggml_time_us();
1013
1014
0
    const int64_t ndata           = ggml_opt_dataset_data(dataset)->ne[1];
1015
0
    const int64_t nbatch_physical = inputs->ne[1];
1016
0
    GGML_ASSERT(ndata          % nbatch_logical  == 0);
1017
0
    GGML_ASSERT(nbatch_logical % nbatch_physical == 0);
1018
1019
0
    const int64_t opt_period       = nbatch_logical / nbatch_physical;
1020
0
    const int64_t nbatches_logical = ndata / nbatch_logical;
1021
1022
0
    GGML_ASSERT(val_split >= 0.0f);
1023
0
    GGML_ASSERT(val_split <  1.0f);
1024
0
    const int64_t ibatch_split = int64_t(((1.0f - val_split) * nbatches_logical)) * opt_period; // train <-> val split index (physical)
1025
0
    const int64_t idata_split  = ibatch_split * nbatch_physical;
1026
1027
0
    int64_t epoch = 1;
1028
1029
0
    ggml_opt_params params = ggml_opt_default_params(backend_sched, loss_type);
1030
0
    params.ctx_compute     = ctx_compute;
1031
0
    params.inputs          = inputs;
1032
0
    params.outputs         = outputs;
1033
0
    params.opt_period      = opt_period;
1034
0
    params.get_opt_pars    = get_opt_pars;
1035
0
    params.get_opt_pars_ud = &epoch;
1036
0
    params.optimizer       = optimizer;
1037
0
    ggml_opt_context_t opt_ctx = ggml_opt_init(params);
1038
1039
    // Shuffling the data is generally useful but there is only a point if not all data is used in a single batch.
1040
0
    if (nbatch_logical < ndata) {
1041
0
        ggml_opt_dataset_shuffle(opt_ctx, dataset, -1); // Shuffle all data (train + validation).
1042
0
    }
1043
1044
0
    ggml_opt_result_t result_train = ggml_opt_result_init();
1045
0
    ggml_opt_result_t result_val   = ggml_opt_result_init();
1046
1047
0
    ggml_opt_epoch_callback epoch_callback = silent ? nullptr : ggml_opt_epoch_callback_progress_bar;
1048
1049
0
    for (; epoch <= nepoch; ++epoch) {
1050
0
        if (nbatch_logical < idata_split) {
1051
0
            ggml_opt_dataset_shuffle(opt_ctx, dataset, idata_split);
1052
0
        }
1053
1054
0
        ggml_opt_result_reset(result_train);
1055
0
        ggml_opt_result_reset(result_val);
1056
1057
0
        if (!silent) {
1058
0
            fprintf(stderr, "%s: epoch %04" PRId64 "/%04" PRId64 ":\n", __func__, epoch, nepoch);
1059
0
        }
1060
0
        ggml_opt_epoch(opt_ctx, dataset, result_train, result_val, idata_split, epoch_callback, epoch_callback);
1061
0
        if (!silent) {
1062
0
            fprintf(stderr, "\n");
1063
0
        }
1064
0
    }
1065
1066
0
    if (!silent) {
1067
0
        int64_t t_total_s = (ggml_time_us() - t_start_us) / 1000000;
1068
0
        const int64_t t_total_h = t_total_s / 3600;
1069
0
        t_total_s -= t_total_h * 3600;
1070
0
        const int64_t t_total_m = t_total_s / 60;
1071
0
        t_total_s -= t_total_m * 60;
1072
0
        fprintf(stderr, "%s: training took %02" PRId64 ":%02" PRId64 ":%02" PRId64 "\n", __func__, t_total_h, t_total_m, t_total_s);
1073
0
    }
1074
1075
0
    ggml_opt_free(opt_ctx);
1076
0
    ggml_opt_result_free(result_train);
1077
0
    ggml_opt_result_free(result_val);
1078
0
}
1079
1080
0
// Returns the optimizer type stored in the optimization context c.
enum ggml_opt_optimizer_type ggml_opt_context_optimizer_type(ggml_opt_context_t c) {
    const enum ggml_opt_optimizer_type optimizer = c->optimizer;
    return optimizer;
}
1083
1084
0
// Maps an optimizer type to a short lowercase name: "adamw" or "sgd".
// Any other value yields "undefined". The returned strings are literals and must not be freed.
GGML_API const char * ggml_opt_optimizer_name(enum ggml_opt_optimizer_type o) {
    if (o == GGML_OPT_OPTIMIZER_TYPE_ADAMW) {
        return "adamw";
    }
    if (o == GGML_OPT_OPTIMIZER_TYPE_SGD) {
        return "sgd";
    }
    return "undefined";
}