Coverage Report

Created: 2024-09-08 06:43

/src/fftw3/rdft/hc2hc-direct.c
Line
Count
Source (jump to first uncovered line)
1
/*
2
 * Copyright (c) 2003, 2007-14 Matteo Frigo
3
 * Copyright (c) 2003, 2007-14 Massachusetts Institute of Technology
4
 *
5
 * This program is free software; you can redistribute it and/or modify
6
 * it under the terms of the GNU General Public License as published by
7
 * the Free Software Foundation; either version 2 of the License, or
8
 * (at your option) any later version.
9
 *
10
 * This program is distributed in the hope that it will be useful,
11
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
12
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13
 * GNU General Public License for more details.
14
 *
15
 * You should have received a copy of the GNU General Public License
16
 * along with this program; if not, write to the Free Software
17
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
18
 *
19
 */
20
21
22
#include "rdft/hc2hc.h"
23
24
typedef struct {
25
     hc2hc_solver super;
26
     const hc2hc_desc *desc;
27
     khc2hc k;
28
     int bufferedp;
29
} S;
30
31
typedef struct {
32
     plan_hc2hc super;
33
     khc2hc k;
34
     plan *cld0, *cldm; /* children for 0th and middle butterflies */
35
     INT r, m, v;
36
     INT ms, vs, mb, me;
37
     stride rs, brs;
38
     twid *td;
39
     const S *slv;
40
} P;
41
42
/*************************************************************
43
  Nonbuffered code
44
*************************************************************/
45
/* Unbuffered execution.  For each of the v vectors: run the 0th-butterfly
   child, then the codelet in place on butterflies [mb, me) (given a
   pointer at index mb and one at its mirror index m - mb), then the
   middle-butterfly child at index m / 2. */
static void apply(const plan *ego_, R *IO)
{
     const P *pln = (const P *) ego_;
     plan_rdft *first = (plan_rdft *) pln->cld0;
     plan_rdft *middle = (plan_rdft *) pln->cldm;
     INT m = pln->m, ms = pln->ms;
     INT mb = pln->mb, me = pln->me;
     INT vs = pln->vs;
     INT remaining;

     for (remaining = pln->v; remaining > 0; --remaining, IO += vs) {
          R *front = IO + mb * ms;          /* butterfly mb */
          R *back = IO + (m - mb) * ms;     /* its mirror, index m - mb */
          R *mid = IO + (m / 2) * ms;       /* middle butterfly */

          first->apply((plan *) first, IO, IO);
          pln->k(front, back, pln->td->W, pln->rs, mb, me, ms);
          middle->apply((plan *) middle, mid, mid);
     }
}
61
62
/*************************************************************
63
  Buffered code
64
*************************************************************/
65
66
/* should not be 2^k to avoid associativity conflicts */
67
static INT compute_batchsize(INT radix)
68
0
{
69
     /* round up to multiple of 4 */
70
0
     radix += 3;
71
0
     radix &= -4;
72
73
0
     return (radix + 2);
74
0
}
75
76
/* Process butterflies [mb, me) through the scratch buffer bufp.
   Gather: the front half is copied from IOp stepping +ms, and the mirrored
   half from IOm stepping -ms, into the buffer (the mirrored half is stored
   backwards starting at the last slot of the first buffer row).
   Compute: run the codelet on the buffer with unit m-stride.
   Scatter: copy both halves back to where they came from. */
static void dobatch(const P *ego, R *IOp, R *IOm,
                    INT mb, INT me, R *bufp)
{
     INT bstride = WS(ego->brs, 1);
     INT rstride = WS(ego->rs, 1);
     INT nr = ego->r;
     INT step = ego->ms;
     INT count = me - mb;
     R *srcp = IOp + mb * step;
     R *srcm = IOm - mb * step;
     R *bufm = bufp + bstride - 1;

     /* gather into the buffer */
     X(cpy2d_ci)(srcp, bufp, nr, rstride, bstride, count,  step,  1, 1);
     X(cpy2d_ci)(srcm, bufm, nr, rstride, bstride, count, -step, -1, 1);

     ego->k(bufp, bufm, ego->td->W, ego->brs, mb, me, 1);

     /* scatter back to the array */
     X(cpy2d_co)(bufp, srcp, nr, bstride, rstride, count,  1,  step, 1);
     X(cpy2d_co)(bufm, srcm, nr, bstride, rstride, count, -1, -step, 1);
}
93
94
/* Buffered execution.  Same sequence as apply(), but the codelet runs via
   dobatch() on a scratch buffer, compute_batchsize(r) butterflies at a
   time.  The trailing dobatch() call for the remainder is unconditional. */
static void apply_buf(const plan *ego_, R *IO)
{
     const P *pln = (const P *) ego_;
     plan_rdft *first = (plan_rdft *) pln->cld0;
     plan_rdft *middle = (plan_rdft *) pln->cldm;
     INT m = pln->m, ms = pln->ms, r = pln->r;
     INT mb = pln->mb, me = pln->me;
     INT batch = compute_batchsize(r);
     size_t nbytes = r * batch * 2 * sizeof(R);
     INT remaining, lo;
     R *scratch;

     BUF_ALLOC(R *, scratch, nbytes);

     for (remaining = pln->v; remaining > 0; --remaining, IO += pln->vs) {
          R *front = IO;
          R *back = IO + m * ms;
          R *mid = IO + ms * (m / 2);

          first->apply((plan *) first, IO, IO);

          /* full batches first, then whatever is left in [lo, me) */
          lo = mb;
          while (lo + batch < me) {
               dobatch(pln, front, back, lo, lo + batch, scratch);
               lo += batch;
          }
          dobatch(pln, front, back, lo, me, scratch);

          middle->apply((plan *) middle, mid, mid);
     }

     BUF_FREE(scratch, nbytes);
}
123
124
/* Forward the wakefulness change to both children, then let
   X(twiddle_awake) manage the twiddle table stored in td, passing
   n = r * m, the radix r, and (m - 1) / 2. */
static void awake(plan *ego_, enum wakefulness wakefulness)
{
     P *pln = (P *) ego_;
     INT n = pln->r * pln->m;
     INT mtw = (pln->m - 1) / 2;

     X(plan_awake)(pln->cld0, wakefulness);
     X(plan_awake)(pln->cldm, wakefulness);
     X(twiddle_awake)(wakefulness, &pln->td, pln->slv->desc->tw,
                      n, pln->r, mtw);
}
133
134
/* Release the two child plans and the two stride objects created in
   mkcldw().  The twiddle pointer td is not touched here. */
static void destroy(plan *ego_)
{
     P *pln = (P *) ego_;

     /* child plans */
     X(plan_destroy_internal)(pln->cld0);
     X(plan_destroy_internal)(pln->cldm);

     /* strides */
     X(stride_destroy)(pln->rs);
     X(stride_destroy)(pln->brs);
}
142
143
/* Print a one-line description of the plan and its two children.
   The buffered variant additionally reports the batch size. */
static void print(const plan *ego_, printer *p)
{
     const P *pln = (const P *) ego_;
     const S *solver = pln->slv;
     const hc2hc_desc *d = solver->desc;
     INT twlen = X(twiddle_length)(pln->r, d->tw);

     if (!solver->bufferedp) {
          p->print(p, "(hc2hc-direct-%D/%D%v \"%s\"%(%p%)%(%p%))",
                   pln->r, twlen, pln->v, d->nam,
                   pln->cld0, pln->cldm);
     } else {
          INT batchsz = compute_batchsize(pln->r);

          p->print(p, "(hc2hc-directbuf/%D-%D/%D%v \"%s\"%(%p%)%(%p%))",
                   batchsz, pln->r, twlen, pln->v, d->nam,
                   pln->cld0, pln->cldm);
     }
}
159
160
/* Problem-independent check: the codelet must match both the requested
   radix and the requested transform kind. */
static int applicable0(const S *ego, rdft_kind kind, INT r)
{
     const hc2hc_desc *d = ego->desc;

     if (r != d->radix)
          return 0;
     return kind == d->genus->kind;
}
169
170
static int applicable(const S *ego, rdft_kind kind, INT r, INT m, INT v,
171
          const planner *plnr)
172
0
{
173
0
     if (!applicable0(ego, kind, r))
174
0
          return 0;
175
176
0
     if (NO_UGLYP(plnr) && X(ct_uglyp)((ego->bufferedp? (INT)512 : (INT)16),
177
0
               v, m * r, r)) 
178
0
    return 0;
179
180
0
     return 1;
181
0
}
182
183
0
/* CLDMP: nonzero when the range [mstart, mstart + mcount) ends just past
   the middle butterfly, i.e. mstart + mcount == m/2 + 1 (never true for
   odd m).  CLD0P: nonzero when the range starts at butterfly 0. */
#define CLDMP(m, mstart, mcount) ((m) + 2 == 2 * ((mstart) + (mcount)))
#define CLD0P(mstart) (0 == (mstart))
185
186
static plan *mkcldw(const hc2hc_solver *ego_, 
187
        rdft_kind kind, INT r, INT m, INT ms, INT v, INT vs, 
188
        INT mstart, INT mcount,
189
        R *IO, planner *plnr)
190
0
{
191
0
     const S *ego = (const S *) ego_;
192
0
     P *pln;
193
0
     const hc2hc_desc *e = ego->desc;
194
0
     plan *cld0 = 0, *cldm = 0;
195
0
     INT imid = (m / 2) * ms;
196
0
     INT rs = m * ms;
197
198
0
     static const plan_adt padt = {
199
0
    0, awake, print, destroy
200
0
     };
201
202
0
     if (!applicable(ego, kind, r, m, v, plnr))
203
0
          return (plan *)0;
204
205
0
     cld0 = X(mkplan_d)(
206
0
    plnr, 
207
0
    X(mkproblem_rdft_1_d)((CLD0P(mstart) ?
208
0
         X(mktensor_1d)(r, rs, rs) : X(mktensor_0d)()),
209
0
        X(mktensor_0d)(),
210
0
        TAINT(IO, vs), TAINT(IO, vs), 
211
0
        kind));
212
0
     if (!cld0) goto nada;
213
214
0
     cldm = X(mkplan_d)(
215
0
    plnr, 
216
0
    X(mkproblem_rdft_1_d)((CLDMP(m, mstart, mcount) ?
217
0
         X(mktensor_1d)(r, rs, rs) : X(mktensor_0d)()),
218
0
        X(mktensor_0d)(),
219
0
        TAINT(IO + imid, vs), TAINT(IO + imid, vs),
220
0
        kind == R2HC ? R2HCII : HC2RIII));
221
0
     if (!cldm) goto nada;
222
    
223
0
     pln = MKPLAN_HC2HC(P, &padt, ego->bufferedp ? apply_buf : apply);
224
225
0
     pln->k = ego->k;
226
0
     pln->td = 0;
227
0
     pln->r = r; pln->rs = X(mkstride)(r, rs);
228
0
     pln->m = m; pln->ms = ms;
229
0
     pln->v = v; pln->vs = vs;
230
0
     pln->slv = ego;
231
0
     pln->brs = X(mkstride)(r, 2 * compute_batchsize(r));
232
0
     pln->cld0 = cld0;
233
0
     pln->cldm = cldm;
234
0
     pln->mb = mstart + CLD0P(mstart);
235
0
     pln->me = mstart + mcount - CLDMP(m, mstart, mcount);
236
237
0
     X(ops_zero)(&pln->super.super.ops);
238
0
     X(ops_madd2)(v * ((pln->me - pln->mb) / e->genus->vl),
239
0
      &e->ops, &pln->super.super.ops);
240
0
     X(ops_madd2)(v, &cld0->ops, &pln->super.super.ops);
241
0
     X(ops_madd2)(v, &cldm->ops, &pln->super.super.ops);
242
243
0
     if (ego->bufferedp) 
244
0
    pln->super.super.ops.other += 4 * r * (pln->me - pln->mb) * v;
245
246
0
     pln->super.super.could_prune_now_p =
247
0
    (!ego->bufferedp && r >= 5 && r < 64 && m >= r);
248
249
0
     return &(pln->super.super);
250
251
0
 nada:
252
0
     X(plan_destroy_internal)(cld0);
253
0
     X(plan_destroy_internal)(cldm);
254
0
     return 0;
255
0
}
256
257
/* Register one solver for the given codelet and buffering mode.  If the
   optional X(mksolver_hc2hc_hook) constructor is set, a second solver
   built by it is registered as well, with the same settings. */
static void regone(planner *plnr, khc2hc codelet, const hc2hc_desc *desc,
                   int bufferedp)
{
     S *solver = (S *)X(mksolver_hc2hc)(sizeof(S), desc->radix, mkcldw);

     solver->k = codelet;
     solver->desc = desc;
     solver->bufferedp = bufferedp;
     REGISTER_SOLVER(plnr, &(solver->super.super));

     if (!X(mksolver_hc2hc_hook))
          return;

     solver = (S *)X(mksolver_hc2hc_hook)(sizeof(S), desc->radix, mkcldw);
     solver->k = codelet;
     solver->desc = desc;
     solver->bufferedp = bufferedp;
     REGISTER_SOLVER(plnr, &(solver->super.super));
}
273
274
void X(regsolver_hc2hc_direct)(planner *plnr, khc2hc codelet,
275
             const hc2hc_desc *desc)
276
46
{
277
46
     regone(plnr, codelet, desc, /* bufferedp */0);
278
46
     regone(plnr, codelet, desc, /* bufferedp */1);
279
46
}