Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) 2003, 2007-14 Matteo Frigo |
3 | | * Copyright (c) 2003, 2007-14 Massachusetts Institute of Technology |
4 | | * |
5 | | * This program is free software; you can redistribute it and/or modify |
6 | | * it under the terms of the GNU General Public License as published by |
7 | | * the Free Software Foundation; either version 2 of the License, or |
8 | | * (at your option) any later version. |
9 | | * |
10 | | * This program is distributed in the hope that it will be useful, |
11 | | * but WITHOUT ANY WARRANTY; without even the implied warranty of |
12 | | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
13 | | * GNU General Public License for more details. |
14 | | * |
15 | | * You should have received a copy of the GNU General Public License |
16 | | * along with this program; if not, write to the Free Software |
17 | | * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA |
18 | | * |
19 | | */ |
20 | | |
21 | | |
22 | | /* direct DFT solver, if we have a codelet */ |
23 | | |
24 | | #include "dft/dft.h" |
25 | | |
26 | | typedef struct { |
27 | | solver super; |
28 | | const kdft_desc *desc; |
29 | | kdft k; |
30 | | int bufferedp; |
31 | | } S; |
32 | | |
33 | | typedef struct { |
34 | | plan_dft super; |
35 | | |
36 | | stride is, os, bufstride; |
37 | | INT n, vl, ivs, ovs; |
38 | | kdft k; |
39 | | const S *slv; |
40 | | } P; |
41 | | |
42 | | static void dobatch(const P *ego, R *ri, R *ii, R *ro, R *io, |
43 | | R *buf, INT batchsz) |
44 | 40 | { |
45 | 40 | X(cpy2d_pair_ci)(ri, ii, buf, buf+1, |
46 | 40 | ego->n, WS(ego->is, 1), WS(ego->bufstride, 1), |
47 | 40 | batchsz, ego->ivs, 2); |
48 | | |
49 | 40 | if (IABS(WS(ego->os, 1)) < IABS(ego->ovs)) { |
50 | | /* transform directly to output */ |
51 | 40 | ego->k(buf, buf+1, ro, io, |
52 | 40 | ego->bufstride, ego->os, batchsz, 2, ego->ovs); |
53 | 40 | } else { |
54 | | /* transform to buffer and copy back */ |
55 | 0 | ego->k(buf, buf+1, buf, buf+1, |
56 | 0 | ego->bufstride, ego->bufstride, batchsz, 2, 2); |
57 | 0 | X(cpy2d_pair_co)(buf, buf+1, ro, io, |
58 | 0 | ego->n, WS(ego->bufstride, 1), WS(ego->os, 1), |
59 | 0 | batchsz, 2, ego->ovs); |
60 | 0 | } |
61 | 40 | } |
62 | | |
63 | | static INT compute_batchsize(INT n) |
64 | 2.14k | { |
65 | | /* round up to multiple of 4 */ |
66 | 2.14k | n += 3; |
67 | 2.14k | n &= -4; |
68 | | |
69 | 2.14k | return (n + 2); |
70 | 2.14k | } |
71 | | |
72 | | static void apply_buf(const plan *ego_, R *ri, R *ii, R *ro, R *io) |
73 | 40 | { |
74 | 40 | const P *ego = (const P *) ego_; |
75 | 40 | R *buf; |
76 | 40 | INT vl = ego->vl, n = ego->n, batchsz = compute_batchsize(n); |
77 | 40 | INT i; |
78 | 40 | size_t bufsz = n * batchsz * 2 * sizeof(R); |
79 | | |
80 | 40 | BUF_ALLOC(R *, buf, bufsz); |
81 | | |
82 | 40 | for (i = 0; i < vl - batchsz; i += batchsz) { |
83 | 0 | dobatch(ego, ri, ii, ro, io, buf, batchsz); |
84 | 0 | ri += batchsz * ego->ivs; ii += batchsz * ego->ivs; |
85 | 0 | ro += batchsz * ego->ovs; io += batchsz * ego->ovs; |
86 | 0 | } |
87 | 40 | dobatch(ego, ri, ii, ro, io, buf, vl - i); |
88 | | |
89 | 40 | BUF_FREE(buf, bufsz); |
90 | 40 | } |
91 | | |
92 | | static void apply(const plan *ego_, R *ri, R *ii, R *ro, R *io) |
93 | 1.35k | { |
94 | 1.35k | const P *ego = (const P *) ego_; |
95 | 1.35k | ASSERT_ALIGNED_DOUBLE; |
96 | 1.35k | ego->k(ri, ii, ro, io, ego->is, ego->os, ego->vl, ego->ivs, ego->ovs); |
97 | 1.35k | } |
98 | | |
99 | | static void apply_extra_iter(const plan *ego_, R *ri, R *ii, R *ro, R *io) |
100 | 0 | { |
101 | 0 | const P *ego = (const P *) ego_; |
102 | 0 | INT vl = ego->vl; |
103 | |
|
104 | 0 | ASSERT_ALIGNED_DOUBLE; |
105 | | |
106 | | /* for 4-way SIMD when VL is odd: iterate over an |
107 | | even vector length VL, and then execute the last |
108 | | iteration as a 2-vector with vector stride 0. */ |
109 | 0 | ego->k(ri, ii, ro, io, ego->is, ego->os, vl - 1, ego->ivs, ego->ovs); |
110 | |
|
111 | 0 | ego->k(ri + (vl - 1) * ego->ivs, ii + (vl - 1) * ego->ivs, |
112 | 0 | ro + (vl - 1) * ego->ovs, io + (vl - 1) * ego->ovs, |
113 | 0 | ego->is, ego->os, 1, 0, 0); |
114 | 0 | } |
115 | | |
116 | | static void destroy(plan *ego_) |
117 | 1.49k | { |
118 | 1.49k | P *ego = (P *) ego_; |
119 | 1.49k | X(stride_destroy)(ego->is); |
120 | 1.49k | X(stride_destroy)(ego->os); |
121 | 1.49k | X(stride_destroy)(ego->bufstride); |
122 | 1.49k | } |
123 | | |
124 | | static void print(const plan *ego_, printer *p) |
125 | 0 | { |
126 | 0 | const P *ego = (const P *) ego_; |
127 | 0 | const S *s = ego->slv; |
128 | 0 | const kdft_desc *d = s->desc; |
129 | |
|
130 | 0 | if (ego->slv->bufferedp) |
131 | 0 | p->print(p, "(dft-directbuf/%D-%D%v \"%s\")", |
132 | 0 | compute_batchsize(d->sz), d->sz, ego->vl, d->nam); |
133 | 0 | else |
134 | 0 | p->print(p, "(dft-direct-%D%v \"%s\")", d->sz, ego->vl, d->nam); |
135 | 0 | } |
136 | | |
137 | | static int applicable_buf(const solver *ego_, const problem *p_, |
138 | | const planner *plnr) |
139 | 26.9k | { |
140 | 26.9k | const S *ego = (const S *) ego_; |
141 | 26.9k | const problem_dft *p = (const problem_dft *) p_; |
142 | 26.9k | const kdft_desc *d = ego->desc; |
143 | 26.9k | INT vl; |
144 | 26.9k | INT ivs, ovs; |
145 | 26.9k | INT batchsz; |
146 | | |
147 | 26.9k | return ( |
148 | 26.9k | 1 |
149 | 26.9k | && p->sz->rnk == 1 |
150 | 20.8k | && p->vecsz->rnk == 1 |
151 | 12.3k | && p->sz->dims[0].n == d->sz |
152 | | |
153 | | /* check strides etc */ |
154 | 668 | && X(tensor_tornk1)(p->vecsz, &vl, &ivs, &ovs) |
155 | | |
156 | | /* UGLY if IS <= IVS */ |
157 | 668 | && !(NO_UGLYP(plnr) && |
158 | 668 | X(iabs)(p->sz->dims[0].is) <= X(iabs)(ivs)) |
159 | | |
160 | 609 | && (batchsz = compute_batchsize(d->sz), 1) |
161 | 609 | && (d->genus->okp(d, 0, ((const R *)0) + 1, p->ro, p->io, |
162 | 609 | 2 * batchsz, p->sz->dims[0].os, |
163 | 609 | batchsz, 2, ovs, plnr)) |
164 | 609 | && (d->genus->okp(d, 0, ((const R *)0) + 1, p->ro, p->io, |
165 | 609 | 2 * batchsz, p->sz->dims[0].os, |
166 | 609 | vl % batchsz, 2, ovs, plnr)) |
167 | | |
168 | | |
169 | 609 | && (0 |
170 | | /* can operate out-of-place */ |
171 | 609 | || p->ri != p->ro |
172 | | |
173 | | /* can operate in-place as long as strides are the same */ |
174 | 208 | || X(tensor_inplace_strides2)(p->sz, p->vecsz) |
175 | | |
176 | | /* can do it if the problem fits in the buffer, no matter |
177 | | what the strides are */ |
178 | 182 | || vl <= batchsz |
179 | 609 | ) |
180 | 26.9k | ); |
181 | 26.9k | } |
182 | | |
183 | | static int applicable(const solver *ego_, const problem *p_, |
184 | | const planner *plnr, int *extra_iterp) |
185 | 27.3k | { |
186 | 27.3k | const S *ego = (const S *) ego_; |
187 | 27.3k | const problem_dft *p = (const problem_dft *) p_; |
188 | 27.3k | const kdft_desc *d = ego->desc; |
189 | 27.3k | INT vl; |
190 | 27.3k | INT ivs, ovs; |
191 | | |
192 | 27.3k | return ( |
193 | 27.3k | 1 |
194 | 27.3k | && p->sz->rnk == 1 |
195 | 21.2k | && p->vecsz->rnk <= 1 |
196 | 17.6k | && p->sz->dims[0].n == d->sz |
197 | | |
198 | | /* check strides etc */ |
199 | 1.06k | && X(tensor_tornk1)(p->vecsz, &vl, &ivs, &ovs) |
200 | | |
201 | 1.06k | && ((*extra_iterp = 0, |
202 | 1.06k | (d->genus->okp(d, p->ri, p->ii, p->ro, p->io, |
203 | 1.06k | p->sz->dims[0].is, p->sz->dims[0].os, |
204 | 1.06k | vl, ivs, ovs, plnr))) |
205 | 0 | || |
206 | 0 | (*extra_iterp = 1, |
207 | 0 | ((d->genus->okp(d, p->ri, p->ii, p->ro, p->io, |
208 | 0 | p->sz->dims[0].is, p->sz->dims[0].os, |
209 | 0 | vl - 1, ivs, ovs, plnr)) |
210 | 0 | && |
211 | 0 | (d->genus->okp(d, p->ri, p->ii, p->ro, p->io, |
212 | 0 | p->sz->dims[0].is, p->sz->dims[0].os, |
213 | 0 | 2, 0, 0, plnr))))) |
214 | | |
215 | 1.06k | && (0 |
216 | | /* can operate out-of-place */ |
217 | 1.06k | || p->ri != p->ro |
218 | | |
219 | | /* can always compute one transform */ |
220 | 202 | || vl == 1 |
221 | | |
222 | | /* can operate in-place as long as strides are the same */ |
223 | 202 | || X(tensor_inplace_strides2)(p->sz, p->vecsz) |
224 | 1.06k | ) |
225 | 27.3k | ); |
226 | 27.3k | } |
227 | | |
228 | | |
229 | | static plan *mkplan(const solver *ego_, const problem *p_, planner *plnr) |
230 | 54.3k | { |
231 | 54.3k | const S *ego = (const S *) ego_; |
232 | 54.3k | P *pln; |
233 | 54.3k | const problem_dft *p; |
234 | 54.3k | iodim *d; |
235 | 54.3k | const kdft_desc *e = ego->desc; |
236 | | |
237 | 54.3k | static const plan_adt padt = { |
238 | 54.3k | X(dft_solve), X(null_awake), print, destroy |
239 | 54.3k | }; |
240 | | |
241 | 54.3k | UNUSED(plnr); |
242 | | |
243 | 54.3k | if (ego->bufferedp) { |
244 | 26.9k | if (!applicable_buf(ego_, p_, plnr)) |
245 | 26.4k | return (plan *)0; |
246 | 556 | pln = MKPLAN_DFT(P, &padt, apply_buf); |
247 | 27.3k | } else { |
248 | 27.3k | int extra_iterp = 0; |
249 | 27.3k | if (!applicable(ego_, p_, plnr, &extra_iterp)) |
250 | 26.4k | return (plan *)0; |
251 | 942 | pln = MKPLAN_DFT(P, &padt, extra_iterp ? apply_extra_iter : apply); |
252 | 942 | } |
253 | | |
254 | 1.49k | p = (const problem_dft *) p_; |
255 | 1.49k | d = p->sz->dims; |
256 | 1.49k | pln->k = ego->k; |
257 | 1.49k | pln->n = d[0].n; |
258 | 1.49k | pln->is = X(mkstride)(pln->n, d[0].is); |
259 | 1.49k | pln->os = X(mkstride)(pln->n, d[0].os); |
260 | 1.49k | pln->bufstride = X(mkstride)(pln->n, 2 * compute_batchsize(pln->n)); |
261 | | |
262 | 1.49k | X(tensor_tornk1)(p->vecsz, &pln->vl, &pln->ivs, &pln->ovs); |
263 | 1.49k | pln->slv = ego; |
264 | | |
265 | 1.49k | X(ops_zero)(&pln->super.super.ops); |
266 | 1.49k | X(ops_madd2)(pln->vl / e->genus->vl, &e->ops, &pln->super.super.ops); |
267 | | |
268 | 1.49k | if (ego->bufferedp) |
269 | 556 | pln->super.super.ops.other += 4 * pln->n * pln->vl; |
270 | | |
271 | 1.49k | pln->super.super.could_prune_now_p = !ego->bufferedp; |
272 | 1.49k | return &(pln->super.super); |
273 | 54.3k | } |
274 | | |
275 | | static solver *mksolver(kdft k, const kdft_desc *desc, int bufferedp) |
276 | 38 | { |
277 | 38 | static const solver_adt sadt = { PROBLEM_DFT, mkplan, 0 }; |
278 | 38 | S *slv = MKSOLVER(S, &sadt); |
279 | 38 | slv->k = k; |
280 | 38 | slv->desc = desc; |
281 | 38 | slv->bufferedp = bufferedp; |
282 | 38 | return &(slv->super); |
283 | 38 | } |
284 | | |
285 | | solver *X(mksolver_dft_direct)(kdft k, const kdft_desc *desc) |
286 | 19 | { |
287 | 19 | return mksolver(k, desc, 0); |
288 | 19 | } |
289 | | |
290 | | solver *X(mksolver_dft_directbuf)(kdft k, const kdft_desc *desc) |
291 | 19 | { |
292 | 19 | return mksolver(k, desc, 1); |
293 | 19 | } |