Coverage Report

Created: 2024-10-16 07:58

/rust/registry/src/index.crates.io-6f17d22bba15001f/cranelift-codegen-0.91.1/src/egraph.rs
Line
Count
Source (jump to first uncovered line)
1
//! Egraph-based mid-end optimization framework.
2
3
use crate::dominator_tree::DominatorTree;
4
use crate::egraph::stores::PackedMemoryState;
5
use crate::flowgraph::ControlFlowGraph;
6
use crate::loop_analysis::{LoopAnalysis, LoopLevel};
7
use crate::trace;
8
use crate::{
9
    fx::{FxHashMap, FxHashSet},
10
    inst_predicates::has_side_effect,
11
    ir::{Block, Function, Inst, InstructionData, InstructionImms, Opcode, Type},
12
};
13
use alloc::vec::Vec;
14
use core::ops::Range;
15
use cranelift_egraph::{EGraph, Id, Language, NewOrExisting};
16
use cranelift_entity::EntityList;
17
use cranelift_entity::SecondaryMap;
18
19
mod domtree;
20
mod elaborate;
21
mod node;
22
mod stores;
23
24
use elaborate::Elaborator;
25
pub use node::{Node, NodeCtx};
26
pub use stores::{AliasAnalysis, MemoryState};
27
28
pub struct FuncEGraph<'a> {
29
    /// Dominator tree, used for elaboration pass.
30
    domtree: &'a DominatorTree,
31
    /// Loop analysis results, used for built-in LICM during elaboration.
32
    loop_analysis: &'a LoopAnalysis,
33
    /// Last-store tracker for integrated alias analysis during egraph build.
34
    alias_analysis: AliasAnalysis,
35
    /// The egraph itself.
36
    pub(crate) egraph: EGraph<NodeCtx, Analysis>,
37
    /// "node context", containing arenas for node data.
38
    pub(crate) node_ctx: NodeCtx,
39
    /// Ranges in `side_effect_ids` for sequences of side-effecting
40
    /// eclasses per block.
41
    side_effects: SecondaryMap<Block, Range<u32>>,
42
    side_effect_ids: Vec<Id>,
43
    /// Map from store instructions to their nodes; used for store-to-load forwarding.
44
    pub(crate) store_nodes: FxHashMap<Inst, (Type, Id)>,
45
    /// Ranges in `blockparam_ids_tys` for sequences of blockparam
46
    /// eclass IDs and types per block.
47
    blockparams: SecondaryMap<Block, Range<u32>>,
48
    blockparam_ids_tys: Vec<(Id, Type)>,
49
    /// Which canonical node IDs do we want to rematerialize in each
50
    /// block where they're used?
51
    pub(crate) remat_ids: FxHashSet<Id>,
52
    /// Which canonical node IDs have an enode whose value subsumes
53
    /// all others it's unioned with?
54
    pub(crate) subsume_ids: FxHashSet<Id>,
55
    /// Statistics recorded during the process of building,
56
    /// optimizing, and lowering out of this egraph.
57
    pub(crate) stats: Stats,
58
    /// Current rewrite-recursion depth. Used to enforce a finite
59
    /// limit on rewrite rule application so that we don't get stuck
60
    /// in an infinite chain.
61
    pub(crate) rewrite_depth: usize,
62
}
63
64
#[derive(Clone, Debug, Default)]
65
pub(crate) struct Stats {
66
    pub(crate) node_created: u64,
67
    pub(crate) node_param: u64,
68
    pub(crate) node_result: u64,
69
    pub(crate) node_pure: u64,
70
    pub(crate) node_inst: u64,
71
    pub(crate) node_load: u64,
72
    pub(crate) node_dedup_query: u64,
73
    pub(crate) node_dedup_hit: u64,
74
    pub(crate) node_dedup_miss: u64,
75
    pub(crate) node_ctor_created: u64,
76
    pub(crate) node_ctor_deduped: u64,
77
    pub(crate) node_union: u64,
78
    pub(crate) node_subsume: u64,
79
    pub(crate) store_map_insert: u64,
80
    pub(crate) side_effect_nodes: u64,
81
    pub(crate) rewrite_rule_invoked: u64,
82
    pub(crate) rewrite_depth_limit: u64,
83
    pub(crate) store_to_load_forward: u64,
84
    pub(crate) elaborate_visit_node: u64,
85
    pub(crate) elaborate_memoize_hit: u64,
86
    pub(crate) elaborate_memoize_miss: u64,
87
    pub(crate) elaborate_memoize_miss_remat: u64,
88
    pub(crate) elaborate_licm_hoist: u64,
89
    pub(crate) elaborate_func: u64,
90
    pub(crate) elaborate_func_pre_insts: u64,
91
    pub(crate) elaborate_func_post_insts: u64,
92
}
93
94
impl<'a> FuncEGraph<'a> {
95
    /// Create a new EGraph for the given function. Requires the
96
    /// domtree to be precomputed as well; the domtree is used for
97
    /// scheduling when lowering out of the egraph.
98
0
    pub fn new(
99
0
        func: &Function,
100
0
        domtree: &'a DominatorTree,
101
0
        loop_analysis: &'a LoopAnalysis,
102
0
        cfg: &ControlFlowGraph,
103
0
    ) -> FuncEGraph<'a> {
104
0
        let num_values = func.dfg.num_values();
105
0
        let num_blocks = func.dfg.num_blocks();
106
0
        let node_count_estimate = num_values * 2;
107
0
        let alias_analysis = AliasAnalysis::new(func, cfg);
108
0
        let mut this = Self {
109
0
            domtree,
110
0
            loop_analysis,
111
0
            alias_analysis,
112
0
            egraph: EGraph::with_capacity(node_count_estimate, Some(Analysis)),
113
0
            node_ctx: NodeCtx::with_capacity_for_dfg(&func.dfg),
114
0
            side_effects: SecondaryMap::with_capacity(num_blocks),
115
0
            side_effect_ids: Vec::with_capacity(node_count_estimate),
116
0
            store_nodes: FxHashMap::default(),
117
0
            blockparams: SecondaryMap::with_capacity(num_blocks),
118
0
            blockparam_ids_tys: Vec::with_capacity(num_blocks * 10),
119
0
            remat_ids: FxHashSet::default(),
120
0
            subsume_ids: FxHashSet::default(),
121
0
            stats: Default::default(),
122
0
            rewrite_depth: 0,
123
0
        };
124
0
        this.store_nodes.reserve(func.dfg.num_values() / 8);
125
0
        this.remat_ids.reserve(func.dfg.num_values() / 4);
126
0
        this.subsume_ids.reserve(func.dfg.num_values() / 4);
127
0
        this.build(func);
128
0
        this
129
0
    }
130
131
0
    fn build(&mut self, func: &Function) {
132
0
        // Mapping of SSA `Value` to eclass ID.
133
0
        let mut value_to_id = FxHashMap::default();
134
135
        // For each block in RPO, create an enode for block entry, for
136
        // each block param, and for each instruction.
137
0
        for &block in self.domtree.cfg_postorder().iter().rev() {
138
0
            let loop_level = self.loop_analysis.loop_level(block);
139
0
            let blockparam_start =
140
0
                u32::try_from(self.blockparam_ids_tys.len()).expect("Overflow in blockparam count");
141
0
            for (i, &value) in func.dfg.block_params(block).iter().enumerate() {
142
0
                let ty = func.dfg.value_type(value);
143
0
                let param = self
144
0
                    .egraph
145
0
                    .add(
146
0
                        Node::Param {
147
0
                            block,
148
0
                            index: i
149
0
                                .try_into()
150
0
                                .expect("blockparam index should fit in Node::Param"),
151
0
                            ty,
152
0
                            loop_level,
153
0
                        },
154
0
                        &mut self.node_ctx,
155
0
                    )
156
0
                    .get();
157
0
                value_to_id.insert(value, param);
158
0
                self.blockparam_ids_tys.push((param, ty));
159
0
                self.stats.node_created += 1;
160
0
                self.stats.node_param += 1;
161
0
            }
162
0
            let blockparam_end =
163
0
                u32::try_from(self.blockparam_ids_tys.len()).expect("Overflow in blockparam count");
164
0
            self.blockparams[block] = blockparam_start..blockparam_end;
165
0
166
0
            let side_effect_start =
167
0
                u32::try_from(self.side_effect_ids.len()).expect("Overflow in side-effect count");
168
0
            for inst in func.layout.block_insts(block) {
169
                // Build args from SSA values.
170
0
                let args = EntityList::from_iter(
171
0
                    func.dfg.inst_args(inst).iter().map(|&arg| {
172
0
                        let arg = func.dfg.resolve_aliases(arg);
173
0
                        *value_to_id
174
0
                            .get(&arg)
175
0
                            .expect("Must have seen def before this use")
176
0
                    }),
177
0
                    &mut self.node_ctx.args,
178
0
                );
179
0
180
0
                let results = func.dfg.inst_results(inst);
181
0
                let ty = if results.len() == 1 {
182
0
                    func.dfg.value_type(results[0])
183
                } else {
184
0
                    crate::ir::types::INVALID
185
                };
186
187
0
                let load_mem_state = self.alias_analysis.get_state_for_load(inst);
188
0
                let is_readonly_load = match func.dfg[inst] {
189
                    InstructionData::Load {
190
                        opcode: Opcode::Load,
191
0
                        flags,
192
0
                        ..
193
0
                    } => flags.readonly() && flags.notrap(),
194
0
                    _ => false,
195
                };
196
197
                // Create the egraph node.
198
0
                let op = InstructionImms::from(&func.dfg[inst]);
199
0
                let opcode = op.opcode();
200
0
                let srcloc = func.srclocs[inst];
201
0
                let arity = u16::try_from(results.len())
202
0
                    .expect("More than 2^16 results from an instruction");
203
204
0
                let node = if is_readonly_load {
205
0
                    self.stats.node_created += 1;
206
0
                    self.stats.node_pure += 1;
207
0
                    Node::Pure {
208
0
                        op,
209
0
                        args,
210
0
                        ty,
211
0
                        arity,
212
0
                    }
213
0
                } else if let Some(load_mem_state) = load_mem_state {
214
0
                    let addr = args.as_slice(&self.node_ctx.args)[0];
215
0
                    trace!("load at inst {} has mem state {:?}", inst, load_mem_state);
216
0
                    self.stats.node_created += 1;
217
0
                    self.stats.node_load += 1;
218
0
                    Node::Load {
219
0
                        op,
220
0
                        ty,
221
0
                        addr,
222
0
                        mem_state: load_mem_state,
223
0
                        srcloc,
224
0
                    }
225
0
                } else if has_side_effect(func, inst) || opcode.can_load() {
226
0
                    self.stats.node_created += 1;
227
0
                    self.stats.node_inst += 1;
228
0
                    Node::Inst {
229
0
                        op,
230
0
                        args,
231
0
                        ty,
232
0
                        arity,
233
0
                        srcloc,
234
0
                        loop_level,
235
0
                    }
236
                } else {
237
0
                    self.stats.node_created += 1;
238
0
                    self.stats.node_pure += 1;
239
0
                    Node::Pure {
240
0
                        op,
241
0
                        args,
242
0
                        ty,
243
0
                        arity,
244
0
                    }
245
                };
246
0
                let dedup_needed = self.node_ctx.needs_dedup(&node);
247
0
                let is_pure = matches!(node, Node::Pure { .. });
248
249
0
                let mut id = self.egraph.add(node, &mut self.node_ctx);
250
0
251
0
                if dedup_needed {
252
0
                    self.stats.node_dedup_query += 1;
253
0
                    match id {
254
0
                        NewOrExisting::New(_) => {
255
0
                            self.stats.node_dedup_miss += 1;
256
0
                        }
257
0
                        NewOrExisting::Existing(_) => {
258
0
                            self.stats.node_dedup_hit += 1;
259
0
                        }
260
                    }
261
0
                }
262
263
0
                if opcode == Opcode::Store {
264
0
                    let store_data_ty = func.dfg.value_type(func.dfg.inst_args(inst)[0]);
265
0
                    self.store_nodes.insert(inst, (store_data_ty, id.get()));
266
0
                    self.stats.store_map_insert += 1;
267
0
                }
268
269
                // Loads that did not already merge into an existing
270
                // load: try to forward from a store (store-to-load
271
                // forwarding).
272
0
                if let NewOrExisting::New(new_id) = id {
273
0
                    if load_mem_state.is_some() {
274
0
                        let opt_id = crate::opts::store_to_load(new_id, self);
275
0
                        trace!("store_to_load: {} -> {}", new_id, opt_id);
276
0
                        if opt_id != new_id {
277
0
                            id = NewOrExisting::Existing(opt_id);
278
0
                        }
279
0
                    }
280
0
                }
281
282
                // Now either optimize (for new pure nodes), or add to
283
                // the side-effecting list (for all other new nodes).
284
0
                let id = match id {
285
0
                    NewOrExisting::Existing(id) => id,
286
0
                    NewOrExisting::New(id) if is_pure => {
287
0
                        // Apply all optimization rules immediately; the
288
0
                        // aegraph (acyclic egraph) works best when we do
289
0
                        // this so all uses pick up the eclass with all
290
0
                        // possible enodes.
291
0
                        crate::opts::optimize_eclass(id, self)
292
                    }
293
0
                    NewOrExisting::New(id) => {
294
0
                        self.side_effect_ids.push(id);
295
0
                        self.stats.side_effect_nodes += 1;
296
0
                        id
297
                    }
298
                };
299
300
                // Create results and save in Value->Id map.
301
0
                match results {
302
0
                    &[] => {}
303
0
                    &[one_result] => {
304
0
                        trace!("build: value {} -> id {}", one_result, id);
305
0
                        value_to_id.insert(one_result, id);
306
                    }
307
0
                    many_results => {
308
0
                        debug_assert!(many_results.len() > 1);
309
0
                        for (i, &result) in many_results.iter().enumerate() {
310
0
                            let ty = func.dfg.value_type(result);
311
0
                            let projection = self
312
0
                                .egraph
313
0
                                .add(
314
0
                                    Node::Result {
315
0
                                        value: id,
316
0
                                        result: i,
317
0
                                        ty,
318
0
                                    },
319
0
                                    &mut self.node_ctx,
320
0
                                )
321
0
                                .get();
322
0
                            self.stats.node_created += 1;
323
0
                            self.stats.node_result += 1;
324
0
                            trace!("build: value {} -> id {}", result, projection);
325
0
                            value_to_id.insert(result, projection);
326
                        }
327
                    }
328
                }
329
            }
330
331
0
            let side_effect_end =
332
0
                u32::try_from(self.side_effect_ids.len()).expect("Overflow in side-effect count");
333
0
            let side_effect_range = side_effect_start..side_effect_end;
334
0
            self.side_effects[block] = side_effect_range;
335
        }
336
0
    }
337
338
    /// Scoped elaboration: compute a final ordering of op computation
339
    /// for each block and replace the given Func body.
340
    ///
341
    /// This works in concert with the domtree. We do a preorder
342
    /// traversal of the domtree, tracking a scoped map from Id to
343
    /// (new) Value. The map's scopes correspond to levels in the
344
    /// domtree.
345
    ///
346
    /// At each block, we iterate forward over the side-effecting
347
    /// eclasses, and recursively generate their arg eclasses, then
348
    /// emit the ops themselves.
349
    ///
350
    /// To use an eclass in a given block, we first look it up in the
351
    /// scoped map, and get the Value if already present. If not, we
352
    /// need to generate it. We emit the extracted enode for this
353
    /// eclass after recursively generating its args. Eclasses are
354
    /// thus computed "as late as possible", but then memoized into
355
    /// the Id-to-Value map and available to all dominated blocks and
356
    /// for the rest of this block. (This subsumes GVN.)
357
0
    pub fn elaborate(&mut self, func: &mut Function) {
358
0
        let mut elab = Elaborator::new(
359
0
            func,
360
0
            self.domtree,
361
0
            self.loop_analysis,
362
0
            &self.egraph,
363
0
            &self.node_ctx,
364
0
            &self.remat_ids,
365
0
            &mut self.stats,
366
0
        );
367
0
        elab.elaborate(
368
0
            |block| {
369
0
                let blockparam_range = self.blockparams[block].clone();
370
0
                &self.blockparam_ids_tys
371
0
                    [blockparam_range.start as usize..blockparam_range.end as usize]
372
0
            },
373
0
            |block| {
374
0
                let side_effect_range = self.side_effects[block].clone();
375
0
                &self.side_effect_ids
376
0
                    [side_effect_range.start as usize..side_effect_range.end as usize]
377
0
            },
378
0
        );
379
0
    }
380
}
381
382
/// State for egraph analysis that computes all needed properties.
383
pub(crate) struct Analysis;
384
385
/// Analysis results for each eclass id.
386
#[derive(Clone, Debug)]
387
pub(crate) struct AnalysisValue {
388
    pub(crate) loop_level: LoopLevel,
389
}
390
391
impl Default for AnalysisValue {
392
0
    fn default() -> Self {
393
0
        Self {
394
0
            loop_level: LoopLevel::root(),
395
0
        }
396
0
    }
397
}
398
399
impl cranelift_egraph::Analysis for Analysis {
400
    type L = NodeCtx;
401
    type Value = AnalysisValue;
402
403
0
    fn for_node(
404
0
        &self,
405
0
        ctx: &NodeCtx,
406
0
        n: &Node,
407
0
        values: &SecondaryMap<Id, AnalysisValue>,
408
0
    ) -> AnalysisValue {
409
0
        let loop_level = match n {
410
0
            &Node::Pure { ref args, .. } => args
411
0
                .as_slice(&ctx.args)
412
0
                .iter()
413
0
                .map(|&arg| values[arg].loop_level)
414
0
                .max()
415
0
                .unwrap_or(LoopLevel::root()),
416
0
            &Node::Load { addr, .. } => values[addr].loop_level,
417
0
            &Node::Result { value, .. } => values[value].loop_level,
418
0
            &Node::Inst { loop_level, .. } | &Node::Param { loop_level, .. } => loop_level,
419
        };
420
421
0
        AnalysisValue { loop_level }
422
0
    }
423
424
0
    fn meet(&self, _ctx: &NodeCtx, v1: &AnalysisValue, v2: &AnalysisValue) -> AnalysisValue {
425
0
        AnalysisValue {
426
0
            loop_level: std::cmp::max(v1.loop_level, v2.loop_level),
427
0
        }
428
0
    }
429
}