Skip to content

Various fixes and improvements to the MIR Dataflow framework #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 14 commits into
base: mir-dflow
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 33 additions & 73 deletions src/librustc/mir/transform/dataflow.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,10 @@
use mir::repr as mir;
use mir::cfg::CFG;
use mir::repr::BasicBlock;
use mir::repr::{BasicBlock, START_BLOCK};
use rustc_data_structures::bitvec::BitVector;

use mir::transform::lattice::Lattice;


pub trait DataflowPass<'tcx> {
type Lattice: Lattice;
type Rewrite: Rewrite<'tcx, Self::Lattice>;
type Transfer: Transfer<'tcx, Self::Lattice>;
}

pub trait Rewrite<'tcx, L: Lattice> {
/// The rewrite function which given a statement optionally produces an alternative graph to be
/// placed in place of the original statement.
Expand All @@ -23,7 +16,7 @@ pub trait Rewrite<'tcx, L: Lattice> {
/// that is, given some fact `fact` true before both the statement and replacement graph, and
/// a fact `fact2` which is true after the statement, the same `fact2` must be true after the
/// replacement graph too.
fn stmt(&mir::Statement<'tcx>, &L, &mut CFG<'tcx>) -> StatementChange<'tcx>;
fn stmt(&self, &mir::Statement<'tcx>, &L, &mut CFG<'tcx>) -> StatementChange<'tcx>;

/// The rewrite function which given a terminator optionally produces an alternative graph to
/// be placed in place of the original terminator.
Expand All @@ -35,49 +28,39 @@ pub trait Rewrite<'tcx, L: Lattice> {
/// that is, given some fact `fact` true before both the terminator and replacement graph, and
/// a fact `fact2` which is true after the terminator, the same `fact2` must be true after the
/// replacement graph too.
fn term(&mir::Terminator<'tcx>, &L, &mut CFG<'tcx>) -> TerminatorChange<'tcx>;
fn term(&self, &mir::Terminator<'tcx>, &L, &mut CFG<'tcx>) -> TerminatorChange<'tcx>;

fn and_then<R2>(self, other: R2) -> RewriteAndThen<Self, R2> where Self: Sized {
RewriteAndThen(self, other)
}
}

/// This combinator has the following behaviour:
///
/// * Rewrite the node with the first rewriter.
/// * if the first rewriter replaced the node, 2nd rewriter is used to rewrite the replacement.
/// * otherwise 2nd rewriter is used to rewrite the original node.
pub struct RewriteAndThen<'tcx, R1, R2>(::std::marker::PhantomData<(&'tcx (), R1, R2)>);
impl<'tcx, L, R1, R2> Rewrite<'tcx, L> for RewriteAndThen<'tcx, R1, R2>
pub struct RewriteAndThen<R1, R2>(R1, R2);
impl<'tcx, L, R1, R2> Rewrite<'tcx, L> for RewriteAndThen<R1, R2>
where L: Lattice, R1: Rewrite<'tcx, L>, R2: Rewrite<'tcx, L> {
fn stmt(s: &mir::Statement<'tcx>, l: &L, c: &mut CFG<'tcx>) -> StatementChange<'tcx> {
let rs = <R1 as Rewrite<L>>::stmt(s, l, c);
fn stmt(&self, s: &mir::Statement<'tcx>, l: &L, c: &mut CFG<'tcx>) -> StatementChange<'tcx> {
let rs = self.0.stmt(s, l, c);
match rs {
StatementChange::None => <R2 as Rewrite<L>>::stmt(s, l, c),
StatementChange::None => self.1.stmt(s, l, c),
StatementChange::Remove => StatementChange::Remove,
StatementChange::Statement(ns) =>
match <R2 as Rewrite<L>>::stmt(&ns, l, c) {
match self.1.stmt(&ns, l, c) {
StatementChange::None => StatementChange::Statement(ns),
x => x
},
StatementChange::Statements(nss) => {
// We expect the common case of all statements in this vector being replaced/not
// replaced by other statements 1:1
let mut new_new_stmts = Vec::with_capacity(nss.len());
for s in nss {
match <R2 as Rewrite<L>>::stmt(&s, l, c) {
StatementChange::None => new_new_stmts.push(s),
StatementChange::Remove => {},
StatementChange::Statement(ns) => new_new_stmts.push(ns),
StatementChange::Statements(nss) => new_new_stmts.extend(nss)
}
}
StatementChange::Statements(new_new_stmts)
}
}
}

fn term(t: &mir::Terminator<'tcx>, l: &L, c: &mut CFG<'tcx>) -> TerminatorChange<'tcx> {
let rt = <R1 as Rewrite<L>>::term(t, l, c);
fn term(&self, t: &mir::Terminator<'tcx>, l: &L, c: &mut CFG<'tcx>) -> TerminatorChange<'tcx> {
let rt = self.0.term(t, l, c);
match rt {
TerminatorChange::None => <R2 as Rewrite<L>>::term(t, l, c),
TerminatorChange::Terminator(nt) => match <R2 as Rewrite<L>>::term(&nt, l, c) {
TerminatorChange::None => self.1.term(t, l, c),
TerminatorChange::Terminator(nt) => match self.1.term(&nt, l, c) {
TerminatorChange::None => TerminatorChange::Terminator(nt),
x => x
}
Expand All @@ -99,39 +82,22 @@ pub enum StatementChange<'tcx> {
Remove,
/// Replace with another single statement
Statement(mir::Statement<'tcx>),
/// Replace with a list of statements
Statements(Vec<mir::Statement<'tcx>>),
}

impl<'tcx> StatementChange<'tcx> {
fn normalise(&mut self) {
let old = ::std::mem::replace(self, StatementChange::None);
*self = match old {
StatementChange::Statements(mut stmts) => {
match stmts.len() {
0 => StatementChange::Remove,
1 => StatementChange::Statement(stmts.pop().unwrap()),
_ => StatementChange::Statements(stmts)
}
}
o => o
}
}
}
pub trait Transfer<'tcx> {
type Lattice: Lattice;

pub trait Transfer<'tcx, L: Lattice> {
type TerminatorOut;
/// The transfer function which given a statement and a fact produces a fact which is true
/// after the statement.
fn stmt(&mir::Statement<'tcx>, L) -> L;
fn stmt(&self, &mir::Statement<'tcx>, Self::Lattice) -> Self::Lattice;

/// The transfer function which given a terminator and a fact produces a fact for each
/// successor of the terminator.
///
/// Correctness precondition:
/// * The list of facts produced should only contain the facts for blocks which are successors
/// of the terminator being transferred.
fn term(&mir::Terminator<'tcx>, L) -> Self::TerminatorOut;
fn term(&self, &mir::Terminator<'tcx>, Self::Lattice) -> Vec<Self::Lattice>;
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change makes the trait not very reusable for backward analysis, where terminators only have a single edge regardless of how many edges the terminator has in the forward direction. That is the primary motivation for the TerminatorOut associated type.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True, but I'm not sure that the same trait should be reused for the forward and backwards cases. Regardless, even if the two directions share a trait, I'd prefer a clearer way of denoting which direction a pass goes than using Vec<Lattice> for forward and Lattice for backwards. Transfer could have another argument (Direction) which itself has an associated type denoting what output term should have.

I got rid of the associated TerminatorOut because it felt clumsy, it caused recursive trait bounds when I moved in the associated lattice type, and because it wasn't necessary yet. When backwards transformation gets implemented, I'd be happy for this to be generalized.

}


Expand Down Expand Up @@ -168,12 +134,14 @@ impl<F: Lattice> ::std::ops::IndexMut<BasicBlock> for Facts<F> {
}

/// Analyse and rewrite using dataflow in the forward direction
pub fn ar_forward<'tcx, T, P>(cfg: &CFG<'tcx>, fs: Facts<P::Lattice>, mut queue: BitVector)
-> (CFG<'tcx>, Facts<P::Lattice>)
// FIXME: shouldn’t need that T generic.
where T: Transfer<'tcx, P::Lattice, TerminatorOut=Vec<P::Lattice>>,
P: DataflowPass<'tcx, Transfer=T>
pub fn ar_forward<'tcx, T, R>(cfg: &CFG<'tcx>, fs: Facts<T::Lattice>, transfer: T, rewrite: R)
-> (CFG<'tcx>, Facts<T::Lattice>)
where T: Transfer<'tcx>,
R: Rewrite<'tcx, T::Lattice>
{
let mut queue = BitVector::new(cfg.len());
queue.insert(START_BLOCK.index());
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly here, I do not believe that beginning at the start block is always correct/desired/etc.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you give an example of when another location would be desired? For a forwards pass, the start block seems like a natural place to start, and allows for a number of blocks to be transformed only once - if the CFG is acyclic, fixpoint iteration is unnecessary starting at the start block.


fixpoint(cfg, Direction::Forward, |bb, fact, cfg| {
let new_graph = cfg.start_new_block();
let mut fact = fact.clone();
Expand All @@ -183,33 +151,25 @@ where T: Transfer<'tcx, P::Lattice, TerminatorOut=Vec<P::Lattice>>,
for stmt in &old_statements {
// Given a fact and statement produce a new fact and optionally a replacement
// graph.
let mut new_repl = P::Rewrite::stmt(&stmt, &fact, cfg);
new_repl.normalise();
match new_repl {
match rewrite.stmt(&stmt, &fact, cfg) {
StatementChange::None => {
fact = P::Transfer::stmt(stmt, fact);
fact = transfer.stmt(stmt, fact);
cfg.push(new_graph, stmt.clone());
}
StatementChange::Remove => changed = true,
StatementChange::Statement(stmt) => {
changed = true;
fact = P::Transfer::stmt(&stmt, fact);
fact = transfer.stmt(&stmt, fact);
cfg.push(new_graph, stmt);
}
StatementChange::Statements(stmts) => {
changed = true;
for stmt in &stmts { fact = P::Transfer::stmt(stmt, fact); }
cfg[new_graph].statements.extend(stmts);
}

}
}
// Swap the statements back in.
::std::mem::replace(&mut cfg[bb].statements, old_statements);

// Handle the terminator replacement and transfer.
let terminator = ::std::mem::replace(&mut cfg[bb].terminator, None).unwrap();
let repl = P::Rewrite::term(&terminator, &fact, cfg);
let repl = rewrite.term(&terminator, &fact, cfg);
match repl {
TerminatorChange::None => {
cfg[new_graph].terminator = Some(terminator.clone());
Expand All @@ -219,7 +179,7 @@ where T: Transfer<'tcx, P::Lattice, TerminatorOut=Vec<P::Lattice>>,
cfg[new_graph].terminator = Some(t);
}
}
let new_facts = P::Transfer::term(cfg[new_graph].terminator(), fact);
let new_facts = transfer.term(cfg[new_graph].terminator(), fact);
::std::mem::replace(&mut cfg[bb].terminator, Some(terminator));

(if changed { Some(new_graph) } else { None }, new_facts)
Expand Down
2 changes: 1 addition & 1 deletion src/librustc/mir/transform/lattice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ impl<T: Lattice> Lattice for WTop<T> {
/// ⊤ + V = ⊤ (no change)
/// V + ⊤ = ⊤
/// ⊤ + ⊤ = ⊤ (no change)
default fn join(&mut self, other: &Self) -> bool {
fn join(&mut self, other: &Self) -> bool {
match (self, other) {
(&mut WTop::Value(ref mut this), &WTop::Value(ref o)) => <T as Lattice>::join(this, o),
(&mut WTop::Top, _) => false,
Expand Down
2 changes: 1 addition & 1 deletion src/librustc_driver/driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -939,7 +939,7 @@ pub fn phase_3_run_analysis_passes<'tcx, F, R>(sess: &'tcx Session,
passes.push_pass(box mir::transform::remove_dead_blocks::RemoveDeadBlocks);
passes.push_pass(box mir::transform::qualify_consts::QualifyAndPromoteConstants);
passes.push_pass(box mir::transform::type_check::TypeckMir);
passes.push_pass(box mir::transform::acs_propagate::ACSPropagate);
passes.push_pass(box mir::transform::acs_propagate::AcsPropagate);
// passes.push_pass(box mir::transform::simplify_cfg::SimplifyCfg);
passes.push_pass(box mir::transform::remove_dead_blocks::RemoveDeadBlocks);
// And run everything.
Expand Down
60 changes: 30 additions & 30 deletions src/librustc_mir/transform/acs_propagate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
//! manually here.

use rustc_data_structures::fnv::FnvHashMap;
use rustc_data_structures::bitvec::BitVector;
use rustc::mir::repr::*;
use rustc::mir::visit::{MutVisitor, LvalueContext};
use rustc::mir::transform::lattice::Lattice;
Expand Down Expand Up @@ -58,35 +57,32 @@ impl<'tcx> Lattice for Either<'tcx> {
}
}

pub type ACSLattice<'a> = FnvHashMap<Lvalue<'a>, Either<'a>>;
pub type AcsLattice<'a> = FnvHashMap<Lvalue<'a>, Either<'a>>;

pub struct ACSPropagate;
pub struct AcsPropagate;

impl Pass for ACSPropagate {}
impl Pass for AcsPropagate {}

impl<'tcx> MirPass<'tcx> for ACSPropagate {
impl<'tcx> MirPass<'tcx> for AcsPropagate {
fn run_pass<'a>(&mut self, tcx: TyCtxt<'a, 'tcx, 'tcx>, src: MirSource, mir: &mut Mir<'tcx>) {
let mut q = BitVector::new(mir.cfg.len());
q.insert(START_BLOCK.index());
let ret = ar_forward::<ACSPropagateTransfer, ACSPropagate>(&mut mir.cfg, Facts::new(), q);
let ret = ar_forward(
&mut mir.cfg,
Facts::new(),
AcsPropagateTransfer,
AliasRewrite.and_then(ConstRewrite).and_then(SimplifyRewrite)
);
mir.cfg = ret.0;
pretty::dump_mir(tcx, "acs_propagate", &0, src, mir, None);
}

}

impl<'tcx> DataflowPass<'tcx> for ACSPropagate {
type Lattice = ACSLattice<'tcx>;
type Rewrite = RewriteAndThen<'tcx, AliasRewrite,
RewriteAndThen<'tcx, ConstRewrite, SimplifyRewrite>>;
type Transfer = ACSPropagateTransfer;
}
pub struct AcsPropagateTransfer;

pub struct ACSPropagateTransfer;
impl<'tcx> Transfer<'tcx> for AcsPropagateTransfer {
type Lattice = AcsLattice<'tcx>;

impl<'tcx> Transfer<'tcx, ACSLattice<'tcx>> for ACSPropagateTransfer {
type TerminatorOut = Vec<ACSLattice<'tcx>>;
fn stmt(s: &Statement<'tcx>, mut lat: ACSLattice<'tcx>) -> ACSLattice<'tcx> {
fn stmt(&self, s: &Statement<'tcx>, mut lat: AcsLattice<'tcx>) -> AcsLattice<'tcx> {
let StatementKind::Assign(ref lval, ref rval) = s.kind;
match *rval {
Rvalue::Use(Operand::Consume(ref nlval)) =>
Expand All @@ -97,7 +93,8 @@ impl<'tcx> Transfer<'tcx, ACSLattice<'tcx>> for ACSPropagateTransfer {
};
lat
}
fn term(t: &Terminator<'tcx>, lat: ACSLattice<'tcx>) -> Self::TerminatorOut {

fn term(&self, t: &Terminator<'tcx>, lat: AcsLattice<'tcx>) -> Vec<AcsLattice<'tcx>> {
// FIXME: this should inspect the terminators and set their known values to constants. Esp.
// for the if: in the truthy branch the operand is known to be true and in the falsy branch
// the operand is known to be false. Now we just ignore the potential here.
Expand All @@ -109,15 +106,16 @@ impl<'tcx> Transfer<'tcx, ACSLattice<'tcx>> for ACSPropagateTransfer {

pub struct AliasRewrite;

impl<'tcx> Rewrite<'tcx, ACSLattice<'tcx>> for AliasRewrite {
fn stmt(s: &Statement<'tcx>, l: &ACSLattice<'tcx>, cfg: &mut CFG<'tcx>)
impl<'tcx> Rewrite<'tcx, AcsLattice<'tcx>> for AliasRewrite {
fn stmt(&self, s: &Statement<'tcx>, l: &AcsLattice<'tcx>, cfg: &mut CFG<'tcx>)
-> StatementChange<'tcx> {
let mut ns = s.clone();
let mut vis = RewriteAliasVisitor(&l, false);
vis.visit_statement(START_BLOCK, &mut ns);
if vis.1 { StatementChange::Statement(ns) } else { StatementChange::None }
}
fn term(t: &Terminator<'tcx>, l: &ACSLattice<'tcx>, cfg: &mut CFG<'tcx>)

fn term(&self, t: &Terminator<'tcx>, l: &AcsLattice<'tcx>, cfg: &mut CFG<'tcx>)
-> TerminatorChange<'tcx> {
let mut nt = t.clone();
let mut vis = RewriteAliasVisitor(&l, false);
Expand All @@ -126,7 +124,7 @@ impl<'tcx> Rewrite<'tcx, ACSLattice<'tcx>> for AliasRewrite {
}
}

struct RewriteAliasVisitor<'a, 'tcx: 'a>(pub &'a ACSLattice<'tcx>, pub bool);
struct RewriteAliasVisitor<'a, 'tcx: 'a>(pub &'a AcsLattice<'tcx>, pub bool);
impl<'a, 'tcx> MutVisitor<'tcx> for RewriteAliasVisitor<'a, 'tcx> {
fn visit_lvalue(&mut self, lvalue: &mut Lvalue<'tcx>, context: LvalueContext) {
match context {
Expand All @@ -148,15 +146,16 @@ impl<'a, 'tcx> MutVisitor<'tcx> for RewriteAliasVisitor<'a, 'tcx> {

pub struct ConstRewrite;

impl<'tcx> Rewrite<'tcx, ACSLattice<'tcx>> for ConstRewrite {
fn stmt(s: &Statement<'tcx>, l: &ACSLattice<'tcx>, cfg: &mut CFG<'tcx>)
impl<'tcx> Rewrite<'tcx, AcsLattice<'tcx>> for ConstRewrite {
fn stmt(&self, s: &Statement<'tcx>, l: &AcsLattice<'tcx>, cfg: &mut CFG<'tcx>)
-> StatementChange<'tcx> {
let mut ns = s.clone();
let mut vis = RewriteConstVisitor(&l, false);
vis.visit_statement(START_BLOCK, &mut ns);
if vis.1 { StatementChange::Statement(ns) } else { StatementChange::None }
}
fn term(t: &Terminator<'tcx>, l: &ACSLattice<'tcx>, cfg: &mut CFG<'tcx>)

fn term(&self, t: &Terminator<'tcx>, l: &AcsLattice<'tcx>, cfg: &mut CFG<'tcx>)
-> TerminatorChange<'tcx> {
let mut nt = t.clone();
let mut vis = RewriteConstVisitor(&l, false);
Expand All @@ -165,7 +164,7 @@ impl<'tcx> Rewrite<'tcx, ACSLattice<'tcx>> for ConstRewrite {
}
}

struct RewriteConstVisitor<'a, 'tcx: 'a>(pub &'a ACSLattice<'tcx>, pub bool);
struct RewriteConstVisitor<'a, 'tcx: 'a>(pub &'a AcsLattice<'tcx>, pub bool);
impl<'a, 'tcx> MutVisitor<'tcx> for RewriteConstVisitor<'a, 'tcx> {
fn visit_operand(&mut self, op: &mut Operand<'tcx>) {
let repl = if let Operand::Consume(ref lval) = *op {
Expand All @@ -187,12 +186,13 @@ impl<'a, 'tcx> MutVisitor<'tcx> for RewriteConstVisitor<'a, 'tcx> {

pub struct SimplifyRewrite;

impl<'tcx> Rewrite<'tcx, ACSLattice<'tcx>> for SimplifyRewrite {
fn stmt(s: &Statement<'tcx>, l: &ACSLattice<'tcx>, cfg: &mut CFG<'tcx>)
impl<'tcx> Rewrite<'tcx, AcsLattice<'tcx>> for SimplifyRewrite {
fn stmt(&self, s: &Statement<'tcx>, l: &AcsLattice<'tcx>, cfg: &mut CFG<'tcx>)
-> StatementChange<'tcx> {
StatementChange::None
}
fn term(t: &Terminator<'tcx>, l: &ACSLattice<'tcx>, cfg: &mut CFG<'tcx>)

fn term(&self, t: &Terminator<'tcx>, l: &AcsLattice<'tcx>, cfg: &mut CFG<'tcx>)
-> TerminatorChange<'tcx> {
match t.kind {
TerminatorKind::If { ref targets, .. } if targets.0 == targets.1 => {
Expand Down