Coverage Report

Created: 2025-07-11 06:55

/src/fftw3/dft/dftw-direct.c
Line
Count
Source (jump to first uncovered line)
1
/*
2
 * Copyright (c) 2003, 2007-14 Matteo Frigo
3
 * Copyright (c) 2003, 2007-14 Massachusetts Institute of Technology
4
 *
5
 * This program is free software; you can redistribute it and/or modify
6
 * it under the terms of the GNU General Public License as published by
7
 * the Free Software Foundation; either version 2 of the License, or
8
 * (at your option) any later version.
9
 *
10
 * This program is distributed in the hope that it will be useful,
11
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
12
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13
 * GNU General Public License for more details.
14
 *
15
 * You should have received a copy of the GNU General Public License
16
 * along with this program; if not, write to the Free Software
17
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
18
 *
19
 */
20
21
22
#include "dft/ct.h"
23
24
/* Solver record: one registered combination of codelet, codelet
   descriptor and buffered/nonbuffered mode (filled in by regone). */
typedef struct {
25
     /* generic solver header; must stay first, since &slv->super.super is
        what gets registered and mkcldw casts a ct_solver * back to S * */
     ct_solver super;
26
     /* codelet descriptor; this file reads its radix, tw, genus, ops and
        nam members */
     const ct_desc *desc;
27
     /* nonzero selects the buffered variant (apply_buf) */
     int bufferedp;
28
     /* the codelet; copied into each plan by mkcldw */
     kdftw k;
29
} S;
30
31
/* Plan record created by mkcldw and consumed by the apply functions. */
typedef struct {
32
     /* generic plan header; must stay first, since the apply/awake/destroy
        callbacks cast a plan * back to P * */
     plan_dftw super;
33
     /* codelet invoked by the apply functions (copied from the solver) */
     kdftw k;
34
     /* radix, as passed to mkcldw */
     INT r;
35
     /* stride object built from (r, irs); passed to the codelet in the
        nonbuffered paths */
     stride rs;
36
     /* m, ms: as passed to mkcldw; v, vs: number of vectors and the stride
        between them (vs is set from ivs); mb, me: the index range
        [mstart, mstart + mcount) this plan covers; extra_iter: 0 or 1 as
        reported by applicable() */
     INT m, ms, v, vs, mb, me, extra_iter;
37
     /* stride object built from (r, 2 * compute_batchsize(r)); used with
        the temporary buffer in dobatch */
     stride brs;
38
     /* twiddle data; set to 0 at plan creation and handed to
        X(twiddle_awake) in awake() */
     twid *td;
39
     /* solver that created this plan */
     const S *slv;
40
} P;
41
42
43
/*************************************************************
44
  Nonbuffered code
45
 *************************************************************/
46
/*
 * Nonbuffered apply: call the codelet directly on the caller's arrays.
 *
 * For each of the plan's v vectors (spaced vs apart) the codelet is invoked
 * once on the index range [mb, me) with stride ms, starting at element
 * offset mb*ms of the real (rio) and imaginary (iio) arrays.
 */
static void apply(const plan *ego_, R *rio, R *iio)
{
     const P *pln = (const P *) ego_;
     const INT nvec = pln->v, vstride = pln->vs;
     const INT first = pln->mb, last = pln->me, mstride = pln->ms;
     const INT offset = first * mstride;
     INT iv;
     ASSERT_ALIGNED_DOUBLE;

     for (iv = 0; iv < nvec; ++iv) {
          pln->k(rio + offset, iio + offset, pln->td->W,
                 pln->rs, first, last, mstride);
          rio += vstride;
          iio += vstride;
     }
}
57
58
/*
 * Nonbuffered apply used when applicable0() reported extra_iter == 1.
 *
 * Each vector is processed with two codelet calls: the first covers
 * [mb, me-1) with the normal stride ms; the second starts at index me-1,
 * is given the range [me-1, me+1) and a stride of 0.  These are the same
 * two ranges that applicable0() validates on its extra_iter path.
 */
static void apply_extra_iter(const plan *ego_, R *rio, R *iio)
{
     const P *pln = (const P *) ego_;
     const INT vstride = pln->vs;
     const INT mstride = pln->ms;
     const INT first = pln->mb;
     const INT tail = pln->me - 1;  /* index where the second call begins */
     INT remaining;
     ASSERT_ALIGNED_DOUBLE;

     for (remaining = pln->v; remaining > 0; --remaining) {
          /* bulk of the range */
          pln->k(rio + first*mstride, iio + first*mstride, pln->td->W,
                 pln->rs, first, tail, mstride);
          /* final index, with zero stride */
          pln->k(rio + tail*mstride, iio + tail*mstride, pln->td->W,
                 pln->rs, tail, tail+2, 0);
          rio += vstride;
          iio += vstride;
     }
}
71
72
/*************************************************************
73
  Buffered code
74
 *************************************************************/
75
/*
 * Process the index range [mb, me) of one vector through the buffer:
 * copy the slice of (rA, iA) into buf as interleaved pairs, run the
 * codelet on buf (real part at buf, imaginary part at buf + 1, stride 2),
 * then copy the result back to the same slice.
 */
static void dobatch(const P *ego, R *rA, R *iA, INT mb, INT me, R *buf)
{
     const INT ms = ego->ms;
     const INT count = me - mb;
     const INT in_rs = WS(ego->rs, 1);    /* row stride in the user arrays */
     const INT buf_rs = WS(ego->brs, 1);  /* row stride in the buffer */
     R *re = rA + mb*ms;
     R *im = iA + mb*ms;
     R *bre = buf;
     R *bim = buf + 1;

     X(cpy2d_pair_ci)(re, im, bre, bim,
                      ego->r, in_rs, buf_rs,
                      count, ms, 2);
     ego->k(bre, bim, ego->td->W, ego->brs, mb, me, 2);
     X(cpy2d_pair_co)(bre, bim, re, im,
                      ego->r, buf_rs, in_rs,
                      count, 2, ms);
}
89
90
/* must be even for SIMD alignment; should not be 2^k to avoid
91
   associativity conflicts */
92
static INT compute_batchsize(INT radix)
93
1.74k
{
94
     /* round up to multiple of 4 */
95
1.74k
     radix += 3;
96
1.74k
     radix &= -4;
97
98
1.74k
     return (radix + 2);
99
1.74k
}
100
101
/*
 * Buffered apply: for every vector, walk [mb, me) in chunks of
 * compute_batchsize(r) indices, passing each chunk through dobatch().
 * The last chunk (which may be shorter than a full batch, but is always
 * issued) ends at me.  One temporary buffer of r * batchsz complex
 * elements is allocated for the whole call and released at the end.
 */
static void apply_buf(const plan *ego_, R *rio, R *iio)
{
     const P *pln = (const P *) ego_;
     const INT nvec = pln->v;
     const INT radix = pln->r;
     const INT batchsz = compute_batchsize(radix);
     const INT first = pln->mb, last = pln->me;
     size_t bufsz = radix * batchsz * 2 * sizeof(R);
     R *buf;
     INT iv;

     BUF_ALLOC(R *, buf, bufsz);

     for (iv = 0; iv < nvec; ++iv) {
          INT j = first;

          /* full batches */
          while (j + batchsz < last) {
               dobatch(pln, rio, iio, j, j + batchsz, buf);
               j += batchsz;
          }

          /* whatever is left, up to and including a full batch */
          dobatch(pln, rio, iio, j, last, buf);

          rio += pln->vs;
          iio += pln->vs;
     }

     BUF_FREE(buf, bufsz);
}
121
122
/*************************************************************
123
  common code
124
 *************************************************************/
125
/*
 * Plan "awake" hook: forward the wakefulness change to X(twiddle_awake),
 * giving it the address of the plan's twiddle pointer, the descriptor's
 * twiddle instructions, the total size r*m, the radix r, and m plus the
 * plan's extra_iter flag.
 */
static void awake(plan *ego_, enum wakefulness wakefulness)
{
     P *pln = (P *) ego_;
     const ct_desc *desc = pln->slv->desc;
     const INT total = pln->r * pln->m;
     const INT count = pln->m + pln->extra_iter;

     X(twiddle_awake)(wakefulness, &pln->td, desc->tw,
                      total, pln->r, count);
}
132
133
/*
 * Plan "destroy" hook: release the two stride objects that mkcldw()
 * created for this plan (buffer stride first, then the user-array stride).
 */
static void destroy(plan *ego_)
{
     P *pln = (P *) ego_;

     X(stride_destroy)(pln->brs);
     X(stride_destroy)(pln->rs);
}
139
140
/*
 * Plan "print" hook: emit a one-line description of the plan through the
 * printer.  Buffered plans additionally report their batch size.
 */
static void print(const plan *ego_, printer *p)
{
     const P *pln = (const P *) ego_;
     const ct_desc *desc = pln->slv->desc;

     if (!pln->slv->bufferedp) {
          p->print(p, "(dftw-direct-%D/%D%v \"%s\")",
                   pln->r, X(twiddle_length)(pln->r, desc->tw),
                   pln->v, desc->nam);
          return;
     }

     p->print(p, "(dftw-directbuf/%D-%D/%D%v \"%s\")",
              compute_batchsize(pln->r), pln->r,
              X(twiddle_length)(pln->r, desc->tw),
              pln->v, desc->nam);
}
154
155
/*
 * Applicability test for the nonbuffered variant.
 *
 * Returns nonzero iff the radix matches the codelet's, the transform is
 * in-place along both R (irs == ors) and V (ivs == ovs), and the codelet's
 * okp() predicate accepts the problem.
 *
 * *extra_iter is an output: it is set to 0 when okp() accepts the range
 * [mb, me) as is, and to 1 when the fallback (two calls covering
 * [mb, me-1) and [me-1, me+1), full array only) is accepted instead.
 * NOTE: because of && short-circuiting, *extra_iter is not written at all
 * when one of the first three comparisons fails; callers must only read it
 * after a nonzero return.
 */
static int applicable0(const S *ego,
156
           INT r, INT irs, INT ors,
157
           INT m, INT ms,
158
           INT v, INT ivs, INT ovs,
159
           INT mb, INT me,
160
           R *rio, R *iio,
161
           const planner *plnr, INT *extra_iter)
162
1.16k
{
163
1.16k
     const ct_desc *e = ego->desc;
164
1.16k
     UNUSED(v);
165
166
1.16k
     return (
167
1.16k
    1
168
1.16k
    && r == e->radix
169
1.16k
    && irs == ors /* in-place along R */
170
1.16k
    && ivs == ovs /* in-place along V */
171
172
    /* check for alignment/vector length restrictions */
    /* The comma operators below assign *extra_iter before each attempt:
       first try the plain range with extra_iter = 0; only if okp() rejects
       it does || fall through to the extra_iter = 1 attempt. */
173
1.16k
    && ((*extra_iter = 0,
174
1.16k
         e->genus->okp(e, rio, iio, irs, ivs, m, mb, me, ms, plnr))
175
1.16k
        ||
176
1.16k
        (*extra_iter = 1,
177
0
         (1
178
    /* FIXME: require full array, otherwise some threads
179
       may be extra_iter and other threads won't be.
180
       Generating the proper twiddle factors is a pain in
181
       this case */
182
0
    && mb == 0 && me == m
183
0
    && e->genus->okp(e, rio, iio, irs, ivs,
184
0
         m, mb, me - 1, ms, plnr)
185
0
    && e->genus->okp(e, rio, iio, irs, ivs,
186
0
         m, me - 1, me + 1, ms, plnr))))
187
188
    /* repeat the check with both pointers advanced by one vector stride,
       using whichever range (me or me - 1) was selected above */
1.16k
    && (e->genus->okp(e, rio + ivs, iio + ivs, irs, ivs,
189
1.16k
          m, mb, me - *extra_iter, ms, plnr))
190
191
1.16k
    );
192
1.16k
}
193
194
static int applicable0_buf(const S *ego,
195
         INT r, INT irs, INT ors,
196
         INT m, INT ms,
197
         INT v, INT ivs, INT ovs,
198
         INT mb, INT me,
199
         R *rio, R *iio,
200
         const planner *plnr)
201
624
{
202
624
     const ct_desc *e = ego->desc;
203
624
     INT batchsz;
204
624
     UNUSED(v); UNUSED(ms); UNUSED(rio); UNUSED(iio);
205
206
624
     return (
207
624
    1
208
624
    && r == e->radix
209
624
    && irs == ors /* in-place along R */
210
624
    && ivs == ovs /* in-place along V */
211
212
    /* check for alignment/vector length restrictions, both for
213
       batchsize and for the remainder */
214
624
    && (batchsz = compute_batchsize(r), 1)
215
624
    && (e->genus->okp(e, 0, ((const R *)0) + 1, 2 * batchsz, 0,
216
624
          m, mb, mb + batchsz, 2, plnr))
217
624
    && (e->genus->okp(e, 0, ((const R *)0) + 1, 2 * batchsz, 0,
218
624
          m, mb, me, 2, plnr))
219
624
    );
220
624
}
221
222
static int applicable(const S *ego,
223
          INT r, INT irs, INT ors,
224
          INT m, INT ms,
225
          INT v, INT ivs, INT ovs,
226
          INT mb, INT me,
227
          R *rio, R *iio,
228
          const planner *plnr, INT *extra_iter)
229
1.79k
{
230
1.79k
     if (ego->bufferedp) {
231
624
    *extra_iter = 0;
232
624
    if (!applicable0_buf(ego,
233
624
             r, irs, ors, m, ms, v, ivs, ovs, mb, me,
234
624
             rio, iio, plnr))
235
0
         return 0;
236
1.16k
     } else {
237
1.16k
    if (!applicable0(ego,
238
1.16k
         r, irs, ors, m, ms, v, ivs, ovs, mb, me,
239
1.16k
         rio, iio, plnr, extra_iter))
240
0
         return 0;
241
1.16k
     }
242
243
1.79k
     if (NO_UGLYP(plnr) && X(ct_uglyp)((ego->bufferedp? (INT)512 : (INT)16),
244
1.79k
               v, m * r, r))
245
668
    return 0;
246
247
1.12k
     if (m * r > 262144 && NO_FIXED_RADIX_LARGE_NP(plnr))
248
0
    return 0;
249
250
1.12k
     return 1;
251
1.12k
}
252
253
static plan *mkcldw(const ct_solver *ego_,
254
        INT r, INT irs, INT ors,
255
        INT m, INT ms,
256
        INT v, INT ivs, INT ovs,
257
        INT mstart, INT mcount,
258
        R *rio, R *iio,
259
        planner *plnr)
260
1.79k
{
261
1.79k
     const S *ego = (const S *) ego_;
262
1.79k
     P *pln;
263
1.79k
     const ct_desc *e = ego->desc;
264
1.79k
     INT extra_iter;
265
266
1.79k
     static const plan_adt padt = {
267
1.79k
    0, awake, print, destroy
268
1.79k
     };
269
270
1.79k
     A(mstart >= 0 && mstart + mcount <= m);
271
1.79k
     if (!applicable(ego,
272
1.79k
         r, irs, ors, m, ms, v, ivs, ovs, mstart, mstart + mcount,
273
1.79k
         rio, iio, plnr, &extra_iter))
274
668
          return (plan *)0;
275
276
1.12k
     if (ego->bufferedp) {
277
0
    pln = MKPLAN_DFTW(P, &padt, apply_buf);
278
1.12k
     } else {
279
1.12k
    pln = MKPLAN_DFTW(P, &padt, extra_iter ? apply_extra_iter : apply);
280
1.12k
     }
281
282
1.12k
     pln->k = ego->k;
283
1.12k
     pln->rs = X(mkstride)(r, irs);
284
1.12k
     pln->td = 0;
285
1.12k
     pln->r = r;
286
1.12k
     pln->m = m;
287
1.12k
     pln->ms = ms;
288
1.12k
     pln->v = v;
289
1.12k
     pln->vs = ivs;
290
1.12k
     pln->mb = mstart;
291
1.12k
     pln->me = mstart + mcount;
292
1.12k
     pln->slv = ego;
293
1.12k
     pln->brs = X(mkstride)(r, 2 * compute_batchsize(r));
294
1.12k
     pln->extra_iter = extra_iter;
295
296
1.12k
     X(ops_zero)(&pln->super.super.ops);
297
1.12k
     X(ops_madd2)(v * (mcount/e->genus->vl), &e->ops, &pln->super.super.ops);
298
299
1.12k
     if (ego->bufferedp) {
300
    /* 8 load/stores * N * V */
301
0
    pln->super.super.ops.other += 8 * r * mcount * v;
302
0
     }
303
304
1.12k
     pln->super.super.could_prune_now_p =
305
1.12k
    (!ego->bufferedp && r >= 5 && r < 64 && m >= r);
306
1.12k
     return &(pln->super.super);
307
1.79k
}
308
309
static void regone(planner *plnr, kdftw codelet,
310
       const ct_desc *desc, int dec, int bufferedp)
311
50
{
312
50
     S *slv = (S *)X(mksolver_ct)(sizeof(S), desc->radix, dec, mkcldw, 0);
313
50
     slv->k = codelet;
314
50
     slv->desc = desc;
315
50
     slv->bufferedp = bufferedp;
316
50
     REGISTER_SOLVER(plnr, &(slv->super.super));
317
50
     if (X(mksolver_ct_hook)) {
318
0
    slv = (S *)X(mksolver_ct_hook)(sizeof(S), desc->radix,
319
0
           dec, mkcldw, 0);
320
0
    slv->k = codelet;
321
0
    slv->desc = desc;
322
0
    slv->bufferedp = bufferedp;
323
0
    REGISTER_SOLVER(plnr, &(slv->super.super));
324
0
     }
325
50
}
326
327
void X(regsolver_ct_directw)(planner *plnr, kdftw codelet,
328
           const ct_desc *desc, int dec)
329
25
{
330
25
     regone(plnr, codelet, desc, dec, /* bufferedp */ 0);
331
25
     regone(plnr, codelet, desc, dec, /* bufferedp */ 1);
332
25
}