/src/fftw3/dft/dftw-direct.c
Line | Count | Source (jump to first uncovered line) |
1 | | /* |
2 | | * Copyright (c) 2003, 2007-14 Matteo Frigo |
3 | | * Copyright (c) 2003, 2007-14 Massachusetts Institute of Technology |
4 | | * |
5 | | * This program is free software; you can redistribute it and/or modify |
6 | | * it under the terms of the GNU General Public License as published by |
7 | | * the Free Software Foundation; either version 2 of the License, or |
8 | | * (at your option) any later version. |
9 | | * |
10 | | * This program is distributed in the hope that it will be useful, |
11 | | * but WITHOUT ANY WARRANTY; without even the implied warranty of |
12 | | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
13 | | * GNU General Public License for more details. |
14 | | * |
15 | | * You should have received a copy of the GNU General Public License |
16 | | * along with this program; if not, write to the Free Software |
17 | | * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA |
18 | | * |
19 | | */ |
20 | | |
21 | | |
22 | | #include "dft/ct.h" |
23 | | |
/* Solver record for the direct twiddle-codelet solver.  Instances are
   created and filled in by regone(); mkcldw() receives one through its
   ct_solver pointer and casts it back to S. */
24 | | typedef struct { |
25 | | ct_solver super; |
/* codelet descriptor; this file reads its radix, tw, genus, ops and nam members */
26 | | const ct_desc *desc; |
/* nonzero selects the buffered variant (apply_buf / applicable0_buf) */
27 | | int bufferedp; |
/* codelet entry point; copied into every plan by mkcldw() */
28 | | kdftw k; |
29 | | } S; |
30 | | |
/* Plan record produced by mkcldw().  All members except td are set once
   in mkcldw(); td is reset to 0 there and later managed by awake(). */
31 | | typedef struct { |
32 | | plan_dftw super; |
/* codelet entry point (copied from the solver) */
33 | | kdftw k; |
/* radix, as passed to mkcldw() */
34 | | INT r; |
/* stride object built from (r, irs); handed to the codelet by the
   non-buffered paths and read via WS(rs, 1) in dobatch() */
35 | | stride rs; |
/* m, ms: as passed to mkcldw(); v, vs: vector count and vector stride
   (vs is taken from ivs); mb, me: mstart and mstart + mcount;
   extra_iter: 0 or 1 as reported by applicable() */
36 | | INT m, ms, v, vs, mb, me, extra_iter; |
/* stride object built from (r, 2 * compute_batchsize(r)); used for the
   scratch buffer in the buffered path */
37 | | stride brs; |
/* twiddle data; the codelet is given td->W */
38 | | twid *td; |
/* back-pointer to the solver that created this plan */
39 | | const S *slv; |
40 | | } P; |
41 | | |
42 | | |
43 | | /************************************************************* |
44 | | Nonbuffered code |
45 | | *************************************************************/ |
/* Non-buffered apply routine.  For each of the ego->v vectors (rio and
   iio advance by ego->vs per vector) the codelet ego->k is invoked once,
   directly on the caller's arrays.  The data pointers are offset by
   mb*ms, and the codelet receives the twiddle array ego->td->W, the
   stride object ego->rs, the bounds mb and ego->me, and the stride ms.

   ego_ : the plan; it is cast to P
   rio  : real parts, modified in place by the codelet
   iio  : imaginary parts, modified in place by the codelet */
46 | | static void apply(const plan *ego_, R *rio, R *iio) |
47 | 1.28k | { |
48 | 1.28k | const P *ego = (const P *) ego_; |
49 | 1.28k | INT i; |
50 | 1.28k | ASSERT_ALIGNED_DOUBLE; |
51 | 2.57k | for (i = 0; i < ego->v; ++i, rio += ego->vs, iio += ego->vs) { |
52 | 1.28k | INT mb = ego->mb, ms = ego->ms; |
53 | 1.28k | ego->k(rio + mb*ms, iio + mb*ms, ego->td->W, |
54 | 1.28k | ego->rs, mb, ego->me, ms); |
55 | 1.28k | } |
56 | 1.28k | } |
57 | | |
/* Non-buffered apply routine used when mkcldw() obtained a nonzero
   extra_iter from applicable().  Per vector the codelet is called twice:
   first with bounds (mb, me - 1) and stride ms, then once more starting
   at element mm = me - 1 with bounds (mm, mm + 2) and stride 0.
   applicable0() selects this mode only when the codelet's okp test
   rejects (mb, me) but accepts (mb, me - 1) and (me - 1, me + 1).
   NOTE(review): presumably the stride-0 call lets the codelet handle the
   last element as one iteration of a two-element range without touching
   memory past the array -- confirm against the codelet genus. */
58 | | static void apply_extra_iter(const plan *ego_, R *rio, R *iio) |
59 | 0 | { |
60 | 0 | const P *ego = (const P *) ego_; |
61 | 0 | INT i, v = ego->v, vs = ego->vs; |
62 | 0 | INT mb = ego->mb, me = ego->me, mm = me - 1, ms = ego->ms; |
63 | 0 | ASSERT_ALIGNED_DOUBLE; |
64 | 0 | for (i = 0; i < v; ++i, rio += vs, iio += vs) { |
/* everything except the last element */
65 | 0 | ego->k(rio + mb*ms, iio + mb*ms, ego->td->W, |
66 | 0 | ego->rs, mb, mm, ms); |
/* last element: bounds (mm, mm + 2), stride 0 */
67 | 0 | ego->k(rio + mm*ms, iio + mm*ms, ego->td->W, |
68 | 0 | ego->rs, mm, mm+2, 0); |
69 | 0 | } |
70 | 0 | } |
71 | | |
72 | | /************************************************************* |
73 | | Buffered code |
74 | | *************************************************************/ |
/* Process one batch in the buffered path, in three steps:
     1. cpy2d_pair_ci copies from rA/iA (offset by mb*ms) into buf and
        buf + 1, with extents ego->r and me - mb;
     2. the codelet runs on the buffer (real part at buf, imaginary part
        at buf + 1, stride object ego->brs, bounds mb and me, stride 2);
     3. cpy2d_pair_co copies the buffer back to the same place in rA/iA.
   In the buffer, real and imaginary values are one R apart and
   successive elements are two R apart.
   NOTE(review): "ci"/"co" presumably stand for copy-in/copy-out; the
   exact traversal order of the copy helpers is not visible here.

   ego    : plan supplying r, rs, brs, ms, k and td
   rA, iA : real/imaginary arrays for the current vector
   mb, me : bounds of this batch
   buf    : scratch buffer allocated by apply_buf() */
75 | | static void dobatch(const P *ego, R *rA, R *iA, INT mb, INT me, R *buf) |
76 | 0 | { |
77 | 0 | INT brs = WS(ego->brs, 1); |
78 | 0 | INT rs = WS(ego->rs, 1); |
79 | 0 | INT ms = ego->ms; |
80 |

81 | 0 | X(cpy2d_pair_ci)(rA + mb*ms, iA + mb*ms, buf, buf + 1, |
82 | 0 | ego->r, rs, brs, |
83 | 0 | me - mb, ms, 2); |
84 | 0 | ego->k(buf, buf + 1, ego->td->W, ego->brs, mb, me, 2); |
85 | 0 | X(cpy2d_pair_co)(buf, buf + 1, rA + mb*ms, iA + mb*ms, |
86 | 0 | ego->r, brs, rs, |
87 | 0 | me - mb, 2, ms); |
88 | 0 | } |
89 | | |
/* Return the batch size used by the buffered path for a given radix:
   the radix rounded up to a multiple of 4, plus 2.  For a positive radix
   the result has the form 4k + 2 with k >= 1, i.e. it is even and has an
   odd factor >= 3, so it is never a power of two.  The rounding uses
   "&= -4", which clears the two low bits on a two's-complement INT. */
90 | | /* must be even for SIMD alignment; should not be 2^k to avoid |
91 | | associativity conflicts */ |
92 | | static INT compute_batchsize(INT radix) |
93 | 1.74k | { |
94 | | /* round up to multiple of 4 */ |
95 | 1.74k | radix += 3; |
96 | 1.74k | radix &= -4; |
97 | | |
98 | 1.74k | return (radix + 2); |
99 | 1.74k | } |
100 | | |
/* Buffered apply routine.  Allocates a scratch buffer of
   r * batchsz * 2 values of type R (batchsz = compute_batchsize(r)),
   then for each of the v vectors walks [mb, me) in steps of batchsz,
   calling dobatch() for every piece.  The buffer is released before
   returning.

   ego_ : the plan; it is cast to P
   rio  : real parts, modified in place
   iio  : imaginary parts, modified in place */
101 | | static void apply_buf(const plan *ego_, R *rio, R *iio) |
102 | 0 | { |
103 | 0 | const P *ego = (const P *) ego_; |
104 | 0 | INT i, j, v = ego->v, r = ego->r; |
105 | 0 | INT batchsz = compute_batchsize(r); |
106 | 0 | R *buf; |
107 | 0 | INT mb = ego->mb, me = ego->me; |
108 | 0 | size_t bufsz = r * batchsz * 2 * sizeof(R); |
109 |

110 | 0 | BUF_ALLOC(R *, buf, bufsz); |
111 |

112 | 0 | for (i = 0; i < v; ++i, rio += ego->vs, iio += ego->vs) { |
/* full batches; the strict "<" leaves the final piece (at most batchsz
   elements) to the call after the loop */
113 | 0 | for (j = mb; j + batchsz < me; j += batchsz) |
114 | 0 | dobatch(ego, rio, iio, j, j + batchsz, buf); |
115 |

/* remainder: the me - j elements that are left */
116 | 0 | dobatch(ego, rio, iio, j, me, buf); |
117 | 0 | } |
118 |

119 | 0 | BUF_FREE(buf, bufsz); |
120 | 0 | } |
121 | | |
122 | | /************************************************************* |
123 | | common code |
124 | | *************************************************************/ |
/* plan_adt "awake" hook.  Forwards the requested wakefulness to
   twiddle_awake together with the address of ego->td, the codelet's
   twiddle description (desc->tw), and the sizes r * m, r and
   m + extra_iter.  The last argument grows by one when the plan uses
   the extra-iteration mode.
   NOTE(review): presumably twiddle_awake creates or releases the table
   behind ego->td depending on wakefulness -- its body is not visible
   here. */
125 | | static void awake(plan *ego_, enum wakefulness wakefulness) |
126 | 690 | { |
127 | 690 | P *ego = (P *) ego_; |
128 | | |
129 | 690 | X(twiddle_awake)(wakefulness, &ego->td, ego->slv->desc->tw, |
130 | 690 | ego->r * ego->m, ego->r, ego->m + ego->extra_iter); |
131 | 690 | } |
132 | | |
/* plan_adt "destroy" hook.  Releases the two stride objects that
   mkcldw() created with mkstride (brs and rs).  Nothing else owned by
   the plan is released in this function; ego->td is not touched here. */
133 | | static void destroy(plan *ego_) |
134 | 1.12k | { |
135 | 1.12k | P *ego = (P *) ego_; |
136 | 1.12k | X(stride_destroy)(ego->brs); |
137 | 1.12k | X(stride_destroy)(ego->rs); |
138 | 1.12k | } |
139 | | |
/* plan_adt "print" hook.  Emits a one-item description of the plan
   through p->print.  The buffered form is tagged "dftw-directbuf" and
   additionally lists compute_batchsize(r); the non-buffered form is
   tagged "dftw-direct".  Both list the radix, the value returned by
   twiddle_length(r, desc->tw), the vector count and the codelet name
   desc->nam.  %D and %v are conversions understood by the project's
   printer, not by printf. */
140 | | static void print(const plan *ego_, printer *p) |
141 | 0 | { |
142 | 0 | const P *ego = (const P *) ego_; |
143 | 0 | const S *slv = ego->slv; |
144 | 0 | const ct_desc *e = slv->desc; |
145 |

146 | 0 | if (slv->bufferedp) |
147 | 0 | p->print(p, "(dftw-directbuf/%D-%D/%D%v \"%s\")", |
148 | 0 | compute_batchsize(ego->r), ego->r, |
149 | 0 | X(twiddle_length)(ego->r, e->tw), ego->v, e->nam); |
150 | 0 | else |
151 | 0 | p->print(p, "(dftw-direct-%D/%D%v \"%s\")", |
152 | 0 | ego->r, X(twiddle_length)(ego->r, e->tw), ego->v, e->nam); |
153 | 0 | } |
154 | | |
/* Applicability test for the non-buffered solver.  Returns nonzero when
   all of the following hold:
     - r equals the codelet's radix;
     - irs == ors and ivs == ovs (the operation is in place);
     - the codelet's okp predicate accepts the range, either as is
       (*extra_iter is set to 0) or, failing that, in extra-iteration
       mode (*extra_iter is set to 1), which requires mb == 0, me == m
       and okp accepting both (mb, me - 1) and (me - 1, me + 1);
     - okp also accepts the pointers advanced by ivs, with the upper
       bound reduced by *extra_iter.
   *extra_iter is written through comma expressions inside the && / ||
   chain, so the order of the operands matters.  On a nonzero return it
   holds 0 or 1.  On a zero return it may hold either value, or be
   unwritten if one of the first three tests failed.

   v is unused.
   NOTE(review): the final okp call at rio + ivs / iio + ivs presumably
   verifies that the restrictions also hold for the following vector --
   confirm against the genus okp implementations. */
155 | | static int applicable0(const S *ego, |
156 | | INT r, INT irs, INT ors, |
157 | | INT m, INT ms, |
158 | | INT v, INT ivs, INT ovs, |
159 | | INT mb, INT me, |
160 | | R *rio, R *iio, |
161 | | const planner *plnr, INT *extra_iter) |
162 | 1.16k | { |
163 | 1.16k | const ct_desc *e = ego->desc; |
164 | 1.16k | UNUSED(v); |
165 | | |
166 | 1.16k | return ( |
167 | 1.16k | 1 |
168 | 1.16k | && r == e->radix |
169 | 1.16k | && irs == ors /* in-place along R */ |
170 | 1.16k | && ivs == ovs /* in-place along V */ |
171 | | |
172 | | /* check for alignment/vector length restrictions */ |
173 | 1.16k | && ((*extra_iter = 0, |
174 | 1.16k | e->genus->okp(e, rio, iio, irs, ivs, m, mb, me, ms, plnr)) |
175 | 1.16k | || |
176 | 1.16k | (*extra_iter = 1, |
177 | 0 | (1 |
178 | | /* FIXME: require full array, otherwise some threads |
179 | | may be extra_iter and other threads won't be. |
180 | | Generating the proper twiddle factors is a pain in |
181 | | this case */ |
182 | 0 | && mb == 0 && me == m |
183 | 0 | && e->genus->okp(e, rio, iio, irs, ivs, |
184 | 0 | m, mb, me - 1, ms, plnr) |
185 | 0 | && e->genus->okp(e, rio, iio, irs, ivs, |
186 | 0 | m, me - 1, me + 1, ms, plnr)))) |
187 | | |
188 | 1.16k | && (e->genus->okp(e, rio + ivs, iio + ivs, irs, ivs, |
189 | 1.16k | m, mb, me - *extra_iter, ms, plnr)) |
190 | | |
191 | 1.16k | ); |
192 | 1.16k | } |
193 | | |
/* Applicability test for the buffered solver.  Returns nonzero when
     - r equals the codelet's radix;
     - irs == ors and ivs == ovs (the operation is in place);
     - the codelet's okp predicate accepts the scratch-buffer layout
       both for a full batch (mb, mb + batchsz) and for (mb, me).
   The okp calls do not use rio/iio.  They pass a null real pointer, an
   imaginary pointer one R above it, a stride of 2 * batchsz, a vector
   stride of 0 and ms = 2, matching how dobatch() lays out and calls the
   codelet on the buffer.
   NOTE(review): the null-based pointers presumably stand in for the
   not-yet-allocated buffer so that okp can check relative alignment
   only -- confirm.

   v, ms, rio and iio are unused. */
194 | | static int applicable0_buf(const S *ego, |
195 | | INT r, INT irs, INT ors, |
196 | | INT m, INT ms, |
197 | | INT v, INT ivs, INT ovs, |
198 | | INT mb, INT me, |
199 | | R *rio, R *iio, |
200 | | const planner *plnr) |
201 | 624 | { |
202 | 624 | const ct_desc *e = ego->desc; |
203 | 624 | INT batchsz; |
204 | 624 | UNUSED(v); UNUSED(ms); UNUSED(rio); UNUSED(iio); |
205 | | |
206 | 624 | return ( |
207 | 624 | 1 |
208 | 624 | && r == e->radix |
209 | 624 | && irs == ors /* in-place along R */ |
210 | 624 | && ivs == ovs /* in-place along V */ |
211 | | |
212 | | /* check for alignment/vector length restrictions, both for |
213 | | batchsize and for the remainder */ |
214 | 624 | && (batchsz = compute_batchsize(r), 1) |
215 | 624 | && (e->genus->okp(e, 0, ((const R *)0) + 1, 2 * batchsz, 0, |
216 | 624 | m, mb, mb + batchsz, 2, plnr)) |
217 | 624 | && (e->genus->okp(e, 0, ((const R *)0) + 1, 2 * batchsz, 0, |
218 | 624 | m, mb, me, 2, plnr)) |
219 | 624 | ); |
220 | 624 | } |
221 | | |
/* Top-level applicability test used by mkcldw().  Returns 1 if a plan
   may be created, 0 otherwise.
     1. Dispatches to applicable0_buf() for buffered solvers (with
        *extra_iter forced to 0) or to applicable0() for non-buffered
        ones (which sets *extra_iter itself).
     2. When NO_UGLYP(plnr) holds, rejects problems for which ct_uglyp
        returns nonzero; the first argument to ct_uglyp is 512 for
        buffered solvers and 16 for non-buffered ones.
     3. Rejects m * r > 262144 when NO_FIXED_RADIX_LARGE_NP(plnr) holds.
   NOTE(review): the meaning of the 512/16 thresholds is defined by
   ct_uglyp, which is not visible here. */
222 | | static int applicable(const S *ego, |
223 | | INT r, INT irs, INT ors, |
224 | | INT m, INT ms, |
225 | | INT v, INT ivs, INT ovs, |
226 | | INT mb, INT me, |
227 | | R *rio, R *iio, |
228 | | const planner *plnr, INT *extra_iter) |
229 | 1.79k | { |
230 | 1.79k | if (ego->bufferedp) { |
231 | 624 | *extra_iter = 0; |
232 | 624 | if (!applicable0_buf(ego, |
233 | 624 | r, irs, ors, m, ms, v, ivs, ovs, mb, me, |
234 | 624 | rio, iio, plnr)) |
235 | 0 | return 0; |
236 | 1.16k | } else { |
237 | 1.16k | if (!applicable0(ego, |
238 | 1.16k | r, irs, ors, m, ms, v, ivs, ovs, mb, me, |
239 | 1.16k | rio, iio, plnr, extra_iter)) |
240 | 0 | return 0; |
241 | 1.16k | } |
242 | | |
243 | 1.79k | if (NO_UGLYP(plnr) && X(ct_uglyp)((ego->bufferedp? (INT)512 : (INT)16), |
244 | 1.79k | v, m * r, r)) |
245 | 668 | return 0; |
246 | | |
247 | 1.12k | if (m * r > 262144 && NO_FIXED_RADIX_LARGE_NP(plnr)) |
248 | 0 | return 0; |
249 | | |
250 | 1.12k | return 1; |
251 | 1.12k | } |
252 | | |
/* Plan constructor installed in the solver by regone().

   ego_           : the solver; it is cast to S
   r, irs, ors    : radix and its input/output strides
   m, ms          : size and stride along the m dimension
   v, ivs, ovs    : vector count and input/output vector strides
   mstart, mcount : sub-range of m handled by this plan; the assertion
                    requires mstart >= 0 and mstart + mcount <= m
   rio, iio       : data pointers, used only for the applicability test
   plnr           : the planner

   Returns a null plan when applicable() rejects the problem.  Otherwise
   returns a new plan whose apply routine is apply_buf (buffered),
   apply_extra_iter (non-buffered with extra_iter set) or apply. */
253 | | static plan *mkcldw(const ct_solver *ego_, |
254 | | INT r, INT irs, INT ors, |
255 | | INT m, INT ms, |
256 | | INT v, INT ivs, INT ovs, |
257 | | INT mstart, INT mcount, |
258 | | R *rio, R *iio, |
259 | | planner *plnr) |
260 | 1.79k | { |
261 | 1.79k | const S *ego = (const S *) ego_; |
262 | 1.79k | P *pln; |
263 | 1.79k | const ct_desc *e = ego->desc; |
264 | 1.79k | INT extra_iter; |
265 | | |
/* hooks: awake, print, destroy; the first slot is left 0 */
266 | 1.79k | static const plan_adt padt = { |
267 | 1.79k | 0, awake, print, destroy |
268 | 1.79k | }; |
269 | | |
270 | 1.79k | A(mstart >= 0 && mstart + mcount <= m); |
271 | 1.79k | if (!applicable(ego, |
272 | 1.79k | r, irs, ors, m, ms, v, ivs, ovs, mstart, mstart + mcount, |
273 | 1.79k | rio, iio, plnr, &extra_iter)) |
274 | 668 | return (plan *)0; |
275 | | |
276 | 1.12k | if (ego->bufferedp) { |
277 | 0 | pln = MKPLAN_DFTW(P, &padt, apply_buf); |
278 | 1.12k | } else { |
279 | 1.12k | pln = MKPLAN_DFTW(P, &padt, extra_iter ? apply_extra_iter : apply); |
280 | 1.12k | } |
281 | | |
/* applicable() has verified irs == ors and ivs == ovs, so only the
   input-side strides are stored */
282 | 1.12k | pln->k = ego->k; |
283 | 1.12k | pln->rs = X(mkstride)(r, irs); |
284 | 1.12k | pln->td = 0; |
285 | 1.12k | pln->r = r; |
286 | 1.12k | pln->m = m; |
287 | 1.12k | pln->ms = ms; |
288 | 1.12k | pln->v = v; |
289 | 1.12k | pln->vs = ivs; |
290 | 1.12k | pln->mb = mstart; |
291 | 1.12k | pln->me = mstart + mcount; |
292 | 1.12k | pln->slv = ego; |
/* brs is created for non-buffered plans as well; destroy() releases it
   unconditionally */
293 | 1.12k | pln->brs = X(mkstride)(r, 2 * compute_batchsize(r)); |
294 | 1.12k | pln->extra_iter = extra_iter; |
295 | | |
/* operation-count estimate: v * (mcount / vl) times the codelet's ops */
296 | 1.12k | X(ops_zero)(&pln->super.super.ops); |
297 | 1.12k | X(ops_madd2)(v * (mcount/e->genus->vl), &e->ops, &pln->super.super.ops); |
298 | | |
299 | 1.12k | if (ego->bufferedp) { |
300 | | /* 8 load/stores * N * V */ |
301 | 0 | pln->super.super.ops.other += 8 * r * mcount * v; |
302 | 0 | } |
303 | | |
304 | 1.12k | pln->super.super.could_prune_now_p = |
305 | 1.12k | (!ego->bufferedp && r >= 5 && r < 64 && m >= r); |
306 | 1.12k | return &(pln->super.super); |
307 | 1.79k | } |
308 | | |
/* Create and register one solver for the given codelet.  The solver is
   built with mksolver_ct (size of S, the descriptor's radix, dec, and
   mkcldw as plan constructor), then its k, desc and bufferedp members
   are filled in and it is registered with the planner.  If the
   mksolver_ct_hook pointer is non-null, a second solver is built through
   the hook with the same settings and registered as well.

   plnr      : planner to register with
   codelet   : codelet entry point stored in the solver
   desc      : codelet descriptor
   dec       : passed through to mksolver_ct / mksolver_ct_hook
   bufferedp : nonzero for the buffered variant */
309 | | static void regone(planner *plnr, kdftw codelet, |
310 | | const ct_desc *desc, int dec, int bufferedp) |
311 | 50 | { |
312 | 50 | S *slv = (S *)X(mksolver_ct)(sizeof(S), desc->radix, dec, mkcldw, 0); |
313 | 50 | slv->k = codelet; |
314 | 50 | slv->desc = desc; |
315 | 50 | slv->bufferedp = bufferedp; |
316 | 50 | REGISTER_SOLVER(plnr, &(slv->super.super)); |
317 | 50 | if (X(mksolver_ct_hook)) { |
318 | 0 | slv = (S *)X(mksolver_ct_hook)(sizeof(S), desc->radix, |
319 | 0 | dec, mkcldw, 0); |
320 | 0 | slv->k = codelet; |
321 | 0 | slv->desc = desc; |
322 | 0 | slv->bufferedp = bufferedp; |
323 | 0 | REGISTER_SOLVER(plnr, &(slv->super.super)); |
324 | 0 | } |
325 | 50 | } |
326 | | |
/* Public entry point: register the direct twiddle solvers for one
   codelet.  Calls regone() twice, first for the non-buffered variant
   (bufferedp = 0) and then for the buffered variant (bufferedp = 1). */
327 | | void X(regsolver_ct_directw)(planner *plnr, kdftw codelet, |
328 | | const ct_desc *desc, int dec) |
329 | 25 | { |
330 | 25 | regone(plnr, codelet, desc, dec, /* bufferedp */ 0); |
331 | 25 | regone(plnr, codelet, desc, dec, /* bufferedp */ 1); |
332 | 25 | } |