Coverage Report

Created: 2025-11-16 06:54

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/fftw3/rdft/ct-hc2c-direct.c
Line
Count
Source
1
/*
2
 * Copyright (c) 2003, 2007-14 Matteo Frigo
3
 * Copyright (c) 2003, 2007-14 Massachusetts Institute of Technology
4
 *
5
 * This program is free software; you can redistribute it and/or modify
6
 * it under the terms of the GNU General Public License as published by
7
 * the Free Software Foundation; either version 2 of the License, or
8
 * (at your option) any later version.
9
 *
10
 * This program is distributed in the hope that it will be useful,
11
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
12
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13
 * GNU General Public License for more details.
14
 *
15
 * You should have received a copy of the GNU General Public License
16
 * along with this program; if not, write to the Free Software
17
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
18
 *
19
 */
20
21
22
#include "ct-hc2c.h"
23
24
typedef struct {
25
     hc2c_solver super;
26
     const hc2c_desc *desc;
27
     int bufferedp;
28
     khc2c k;
29
} S;
30
31
typedef struct {
32
     plan_hc2c super;
33
     khc2c k;
34
     plan *cld0, *cldm; /* children for 0th and middle butterflies */
35
     INT r, m, v, extra_iter;
36
     INT ms, vs;
37
     stride rs, brs;
38
     twid *td;
39
     const S *slv;
40
} P;
41
42
/*************************************************************
43
  Nonbuffered code
44
 *************************************************************/
45
/* Non-buffered execution.  For each of the ego->v vectors: run the
   0th-butterfly child in place, invoke the codelet with the range
   arguments 1 and (m+1)/2, then run the middle-butterfly child in
   place at offset (m/2)*ms. */
static void apply(const plan *ego_, R *cr, R *ci)
{
     const P *ego = (const P *) ego_;
     plan_rdft2 *cld0 = (plan_rdft2 *) ego->cld0;
     plan_rdft2 *cldm = (plan_rdft2 *) ego->cldm;
     const INT m = ego->m;
     const INT ms = ego->ms;
     const INT vs = ego->vs;
     const INT top = (m - 1) * ms;   /* offset of element m-1 */
     const INT mid = (m / 2) * ms;   /* offset of the middle element */
     INT left;

     for (left = ego->v; left > 0; --left) {
          cld0->apply((plan *) cld0, cr, ci, cr, ci);

          ego->k(cr + ms, ci + ms, cr + top, ci + top,
                 ego->td->W, ego->rs, 1, (m + 1) / 2, ms);

          cldm->apply((plan *) cldm, cr + mid, ci + mid, cr + mid, ci + mid);

          cr += vs;
          ci += vs;
     }
}
61
62
/* Non-buffered execution for plans created with extra_iter set.  Same
   structure as apply(), except that the codelet is invoked twice per
   vector. */
static void apply_extra_iter(const plan *ego_, R *cr, R *ci)
{
     const P *ego = (const P *) ego_;
     plan_rdft2 *cld0 = (plan_rdft2 *) ego->cld0;
     plan_rdft2 *cldm = (plan_rdft2 *) ego->cldm;
     const INT m = ego->m;
     const INT ms = ego->ms;
     const INT vs = ego->vs;
     const INT mm = (m - 1) / 2;
     const INT mid = (m / 2) * ms;
     INT left;

     for (left = ego->v; left > 0; --left) {
          cld0->apply((plan *) cld0, cr, ci, cr, ci);

          /* For 4-way SIMD when (m+1)/2-1 is odd: first iterate over
             an even vector length MM-1, then run the last iteration as
             a 2-vector with vector stride 0.  The twiddle factors of
             the second half of that last iteration are bogus, but
             only the results of the first half are stored. */
          ego->k(cr + ms, ci + ms,
                 cr + (m - 1) * ms, ci + (m - 1) * ms,
                 ego->td->W, ego->rs, 1, mm, ms);
          ego->k(cr + mm * ms, ci + mm * ms,
                 cr + (m - mm) * ms, ci + (m - mm) * ms,
                 ego->td->W, ego->rs, mm, mm + 2, 0);

          cldm->apply((plan *) cldm, cr + mid, ci + mid, cr + mid, ci + mid);

          cr += vs;
          ci += vs;
     }
}
89
90
/*************************************************************
91
  Buffered code
92
 *************************************************************/
93
94
/* should not be 2^k to avoid associativity conflicts */
95
static INT compute_batchsize(INT radix)
96
0
{
97
     /* round up to multiple of 4 */
98
0
     radix += 3;
99
0
     radix &= -4;
100
101
0
     return (radix + 2);
102
0
}
103
104
/* Process the codelet range [mb, me) of one vector through the batch
   buffer: copy both halves of the data into the buffer, run the
   codelet on the buffer, and copy the results back.

   Rp/Ip and Rm/Im are the two base pointer pairs of the vector; the
   first pair is indexed upwards (+mb*ms), the second downwards
   (-mb*ms).  bufp is the scratch buffer.  If extra_iter is nonzero
   the codelet runs for one additional iteration, whose inputs are
   zeroed and whose outputs are not copied back. */
static void dobatch(const P *ego, R *Rp, R *Ip, R *Rm, R *Im,
                    INT mb, INT me, INT extra_iter, R *bufp)
{
     const INT b = WS(ego->brs, 1);
     const INT rs = WS(ego->rs, 1);
     const INT ms = ego->ms;
     const INT n = me - mb;
     const INT hr = ego->r / 2;
     R *bufm = bufp + b - 2;
     R *pr = Rp + mb * ms;
     R *pi = Ip + mb * ms;
     R *mr = Rm - mb * ms;
     R *mi = Im - mb * ms;

     /* gather: data -> buffer */
     X(cpy2d_pair_ci)(pr, pi, bufp, bufp + 1,
                      hr, rs, b,
                      n, ms, 2);
     X(cpy2d_pair_ci)(mr, mi, bufm, bufm + 1,
                      hr, rs, b,
                      n, -ms, -2);

     if (extra_iter) {
          /* Zero the element consumed by the extra iteration.  Leaving
             it uninitialized would be harmless for the result, since
             the transformed value is ignored, but zeroing avoids FP
             exceptions in case somebody is trapping them. */
          A(n < compute_batchsize(ego->r));
          X(zero1d_pair)(bufp + 2*n, bufp + 1 + 2*n, hr, b);
          X(zero1d_pair)(bufm - 2*n, bufm + 1 - 2*n, hr, b);
     }

     ego->k(bufp, bufp + 1, bufm, bufm + 1, ego->td->W,
            ego->brs, mb, me + extra_iter, 2);

     /* scatter: buffer -> data */
     X(cpy2d_pair_co)(bufp, bufp + 1, pr, pi,
                      hr, b, rs,
                      n, 2, ms);
     X(cpy2d_pair_co)(bufm, bufm + 1, mr, mi,
                      hr, b, rs,
                      n, -2, -ms);
}
139
140
/* Buffered execution.  Allocates one scratch buffer of
   r * batchsize * 2 elements, then for each of the ego->v vectors runs
   the 0th-butterfly child, feeds the codelet range [1, (m+1)/2)
   through dobatch() in chunks of at most batchsize, and finally runs
   the middle-butterfly child.  Only the last chunk is given the
   plan's extra_iter flag. */
static void apply_buf(const plan *ego_, R *cr, R *ci)
{
     const P *ego = (const P *) ego_;
     plan_rdft2 *cld0 = (plan_rdft2 *) ego->cld0;
     plan_rdft2 *cldm = (plan_rdft2 *) ego->cldm;
     const INT ms = ego->ms;
     const INT batchsz = compute_batchsize(ego->r);
     const INT me = (ego->m + 1) / 2;
     size_t bufsz = ego->r * batchsz * 2 * sizeof(R);
     R *buf;
     INT left;

     BUF_ALLOC(R *, buf, bufsz);

     for (left = ego->v; left > 0; --left) {
          R *Rm = cr + ego->m * ms;
          R *Im = ci + ego->m * ms;
          INT j = 1;

          cld0->apply((plan *) cld0, cr, ci, cr, ci);

          /* full batches */
          while (j + batchsz < me) {
               dobatch(ego, cr, ci, Rm, Im, j, j + batchsz, 0, buf);
               j += batchsz;
          }

          /* remaining partial (or exactly full) batch */
          dobatch(ego, cr, ci, Rm, Im, j, me, ego->extra_iter, buf);

          cldm->apply((plan *) cldm,
                      cr + me * ms, ci + me * ms,
                      cr + me * ms, ci + me * ms);

          cr += ego->vs;
          ci += ego->vs;
     }

     BUF_FREE(buf, bufsz);
}
174
175
/*************************************************************
176
  common code
177
 *************************************************************/
178
/* Propagate the wakefulness state to both child plans and to the
   twiddle table.  The twiddle request is for n = r*m, radix r, and
   (m-1)/2 + extra_iter entries. */
static void awake(plan *ego_, enum wakefulness wakefulness)
{
     P *self = (P *) ego_;
     const INT r = self->r;
     const INT m = self->m;

     X(plan_awake)(self->cld0, wakefulness);
     X(plan_awake)(self->cldm, wakefulness);

     X(twiddle_awake)(wakefulness, &self->td, self->slv->desc->tw,
                      r * m, r,
                      (m - 1) / 2 + self->extra_iter);
}
188
189
/* Release what the plan owns: the two child plans and the two stride
   objects created in mkcldw(). */
static void destroy(plan *ego_)
{
     P *self = (P *) ego_;

     /* children */
     X(plan_destroy_internal)(self->cld0);
     X(plan_destroy_internal)(self->cldm);

     /* strides */
     X(stride_destroy)(self->rs);
     X(stride_destroy)(self->brs);
}
197
198
/* Describe the plan through p->print.  The buffered variant prints
   "hc2c-directbuf" and additionally reports the batch size; both
   variants report the radix, the twiddle length, extra_iter, v, the
   codelet name and the two child plans. */
static void print(const plan *ego_, printer *p)
{
     const P *ego = (const P *) ego_;
     const S *slv = ego->slv;
     const hc2c_desc *desc = slv->desc;
     const INT r = ego->r;
     const INT twlen = X(twiddle_length)(r, desc->tw);

     if (slv->bufferedp) {
          p->print(p, "(hc2c-directbuf/%D-%D/%D/%D%v \"%s\"%(%p%)%(%p%))",
                   compute_batchsize(r),
                   r, twlen,
                   ego->extra_iter, ego->v, desc->nam,
                   ego->cld0, ego->cldm);
     } else {
          p->print(p, "(hc2c-direct-%D/%D/%D%v \"%s\"%(%p%)%(%p%))",
                   r, twlen,
                   ego->extra_iter, ego->v, desc->nam,
                   ego->cld0, ego->cldm);
     }
}
216
217
static int applicable0(const S *ego, rdft_kind kind,
218
           INT r, INT rs,
219
           INT m, INT ms, 
220
           INT v, INT vs,
221
           const R *cr, const R *ci,
222
           const planner *plnr,
223
           INT *extra_iter)
224
0
{
225
0
     const hc2c_desc *e = ego->desc;
226
0
     UNUSED(v);
227
228
0
     return (
229
0
    1
230
0
    && r == e->radix
231
0
    && kind == e->genus->kind
232
233
    /* first v-loop iteration */
234
0
    && ((*extra_iter = 0,
235
0
         e->genus->okp(cr + ms, ci + ms, cr + (m-1)*ms, ci + (m-1)*ms,
236
0
           rs, 1, (m+1)/2, ms, plnr))
237
0
              ||
238
0
        (*extra_iter = 1,
239
0
         ((e->genus->okp(cr + ms, ci + ms, cr + (m-1)*ms, ci + (m-1)*ms,
240
0
             rs, 1, (m-1)/2, ms, plnr))
241
0
    &&
242
0
    (e->genus->okp(cr + ms, ci + ms, cr + (m-1)*ms, ci + (m-1)*ms,
243
0
             rs, (m-1)/2, (m-1)/2 + 2, 0, plnr)))))
244
    
245
    /* subsequent v-loop iterations */
246
0
    && (cr += vs, ci += vs, 1)
247
248
0
    && e->genus->okp(cr + ms, ci + ms, cr + (m-1)*ms, ci + (m-1)*ms,
249
0
         rs, 1, (m+1)/2 - *extra_iter, ms, plnr)
250
0
    );
251
0
}
252
253
static int applicable0_buf(const S *ego, rdft_kind kind,
254
         INT r, INT rs,
255
         INT m, INT ms, 
256
         INT v, INT vs,
257
         const R *cr, const R *ci,
258
         const planner *plnr, INT *extra_iter)
259
0
{
260
0
     const hc2c_desc *e = ego->desc;
261
0
     INT batchsz, brs;
262
0
     UNUSED(v); UNUSED(rs); UNUSED(ms); UNUSED(vs);
263
264
0
     return (
265
0
    1
266
0
    && r == e->radix
267
0
    && kind == e->genus->kind
268
269
    /* ignore cr, ci, use buffer */
270
0
    && (cr = (const R *)0, ci = cr + 1, 
271
0
        batchsz = compute_batchsize(r), 
272
0
        brs = 4 * batchsz, 1)
273
274
0
    && e->genus->okp(cr, ci, cr + brs - 2, ci + brs - 2, 
275
0
         brs, 1, 1+batchsz, 2, plnr)
276
277
0
    && ((*extra_iter = 0,
278
0
         e->genus->okp(cr, ci, cr + brs - 2, ci + brs - 2, 
279
0
           brs, 1, 1 + (((m-1)/2) % batchsz), 2, plnr))
280
0
        ||
281
0
        (*extra_iter = 1,
282
0
         e->genus->okp(cr, ci, cr + brs - 2, ci + brs - 2, 
283
0
           brs, 1, 1 + 1 + (((m-1)/2) % batchsz), 2, plnr)))
284
        
285
0
    );
286
0
}
287
288
static int applicable(const S *ego, rdft_kind kind,
289
          INT r, INT rs,
290
          INT m, INT ms, 
291
          INT v, INT vs,
292
          R *cr, R *ci,
293
          const planner *plnr, INT *extra_iter)
294
0
{
295
0
     if (ego->bufferedp) {
296
0
    if (!applicable0_buf(ego, kind, r, rs, m, ms, v, vs, cr, ci, plnr,
297
0
             extra_iter))
298
0
         return 0;
299
0
     } else {
300
0
    if (!applicable0(ego, kind, r, rs, m, ms, v, vs, cr, ci, plnr,
301
0
         extra_iter))
302
0
         return 0;
303
0
     }
304
305
0
     if (NO_UGLYP(plnr) && X(ct_uglyp)((ego->bufferedp? (INT)512 : (INT)16),
306
0
               v, m * r, r))
307
0
    return 0;
308
309
0
     return 1;
310
0
}
311
312
static plan *mkcldw(const hc2c_solver *ego_, rdft_kind kind,
313
        INT r, INT rs,
314
        INT m, INT ms, 
315
        INT v, INT vs,
316
        R *cr, R *ci,
317
        planner *plnr)
318
0
{
319
0
     const S *ego = (const S *) ego_;
320
0
     P *pln;
321
0
     const hc2c_desc *e = ego->desc;
322
0
     plan *cld0 = 0, *cldm = 0;
323
0
     INT imid = (m / 2) * ms;
324
0
     INT extra_iter;
325
326
0
     static const plan_adt padt = {
327
0
    0, awake, print, destroy
328
0
     };
329
330
0
     if (!applicable(ego, kind, r, rs, m, ms, v, vs, cr, ci, plnr, 
331
0
         &extra_iter))
332
0
          return (plan *)0;
333
334
0
     cld0 = X(mkplan_d)(
335
0
    plnr, 
336
0
    X(mkproblem_rdft2_d)(X(mktensor_1d)(r, rs, rs),
337
0
             X(mktensor_0d)(),
338
0
             TAINT(cr, vs), TAINT(ci, vs),
339
0
             TAINT(cr, vs), TAINT(ci, vs),
340
0
             kind));
341
0
     if (!cld0) goto nada;
342
343
0
     cldm = X(mkplan_d)(
344
0
    plnr, 
345
0
    X(mkproblem_rdft2_d)(((m % 2) ?
346
0
        X(mktensor_0d)() : X(mktensor_1d)(r, rs, rs) ),
347
0
             X(mktensor_0d)(),
348
0
             TAINT(cr + imid, vs), TAINT(ci + imid, vs),
349
0
             TAINT(cr + imid, vs), TAINT(ci + imid, vs),
350
0
             kind == R2HC ? R2HCII : HC2RIII));
351
0
     if (!cldm) goto nada;
352
353
0
     if (ego->bufferedp)
354
0
    pln = MKPLAN_HC2C(P, &padt, apply_buf);
355
0
     else
356
0
    pln = MKPLAN_HC2C(P, &padt, extra_iter ? apply_extra_iter : apply);
357
358
0
     pln->k = ego->k;
359
0
     pln->td = 0;
360
0
     pln->r = r; pln->rs = X(mkstride)(r, rs);
361
0
     pln->m = m; pln->ms = ms;
362
0
     pln->v = v; pln->vs = vs;
363
0
     pln->slv = ego;
364
0
     pln->brs = X(mkstride)(r, 4 * compute_batchsize(r));
365
0
     pln->cld0 = cld0;
366
0
     pln->cldm = cldm;
367
0
     pln->extra_iter = extra_iter;
368
369
0
     X(ops_zero)(&pln->super.super.ops);
370
0
     X(ops_madd2)(v * (((m - 1) / 2) / e->genus->vl),
371
0
      &e->ops, &pln->super.super.ops);
372
0
     X(ops_madd2)(v, &cld0->ops, &pln->super.super.ops);
373
0
     X(ops_madd2)(v, &cldm->ops, &pln->super.super.ops);
374
375
0
     if (ego->bufferedp) 
376
0
    pln->super.super.ops.other += 4 * r * m * v;
377
378
0
     return &(pln->super.super);
379
380
0
 nada:
381
0
     X(plan_destroy_internal)(cld0);
382
0
     X(plan_destroy_internal)(cldm);
383
0
     return 0;
384
0
}
385
386
static void regone(planner *plnr, khc2c codelet,
387
       const hc2c_desc *desc, 
388
       hc2c_kind hc2ckind, 
389
       int bufferedp)
390
112
{
391
112
     S *slv = (S *)X(mksolver_hc2c)(sizeof(S), desc->radix, hc2ckind, mkcldw);
392
112
     slv->k = codelet;
393
112
     slv->desc = desc;
394
112
     slv->bufferedp = bufferedp;
395
112
     REGISTER_SOLVER(plnr, &(slv->super.super));
396
112
}
397
398
void X(regsolver_hc2c_direct)(planner *plnr, khc2c codelet,
399
            const hc2c_desc *desc,
400
            hc2c_kind hc2ckind)
401
56
{
402
56
     regone(plnr, codelet, desc, hc2ckind, /* bufferedp */0);
403
56
     regone(plnr, codelet, desc, hc2ckind, /* bufferedp */1);
404
56
}