blob: 821e47085a5bddb70c9a0d8f2948a63025df0b68 [file] [log] [blame] [edit]
# RUN: %PYTHON %s 2>&1 | FileCheck %s
from mlir.ir import *
from mlir.passmanager import *
from mlir.dialects.builtin import ModuleOp
from mlir.dialects import arith
from mlir.rewrite import *
def run(f):
    """Announce the test callable *f* on stdout, then execute it.

    A blank line followed by ``TEST: <function name>`` is printed before the
    call.  Nothing is returned, so when this is applied as a decorator the
    decorated name ends up bound to ``None``; the test body has already run
    by then.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    f()
# CHECK-LABEL: TEST: testRewritePattern
@run
def testRewritePattern():
    """Exercise Python-defined rewrite patterns via the greedy driver.

    Two callbacks are registered: one replaces an ``arith.AddIOp`` with an
    ``arith.muli`` built from the same operands, the other replaces an
    ``arith.ConstantOp`` whose value is 1 with a constant 2 of the same type.
    Two modules are parsed, rewritten and printed; the FileCheck directive
    comments in the body verify the printed IR, so they are kept verbatim and
    in their original order.
    """

    def to_muli(op, rewriter):
        # Create the replacement while the rewriter's ``ip`` is active.
        with rewriter.ip:
            assert isinstance(op, arith.AddIOp)
            product = arith.muli(op.lhs, op.rhs, loc=op.location)
        rewriter.replace_op(op, product.owner)

    def constant_1_to_2(op, rewriter):
        if op.value.value != 1:
            return True  # failed to match
        with rewriter.ip:
            two = arith.constant(op.type, 2, loc=op.location)
        rewriter.replace_op(op, [two])

    def rewrite_and_print(asm, frozen_patterns):
        # Parse the assembly, apply the frozen patterns, and print the result.
        parsed = ModuleOp.parse(asm)
        apply_patterns_and_fold_greedily(parsed, frozen_patterns)
        print(parsed)

    with Context():
        pattern_set = RewritePatternSet()
        for root, callback in (
            (arith.AddIOp, to_muli),
            (arith.ConstantOp, constant_1_to_2),
        ):
            pattern_set.add(root, callback)
        frozen = pattern_set.freeze()

        # The MLIR snippets stay flush-left so the literal text passed to the
        # parser is unchanged.
        rewrite_and_print(
            r"""
module {
func.func @add(%a: i64, %b: i64) -> i64 {
%sum = arith.addi %a, %b : i64
return %sum : i64
}
}
""",
            frozen,
        )
        # CHECK: %0 = arith.muli %arg0, %arg1 : i64
        # CHECK: return %0 : i64

        rewrite_and_print(
            r"""
module {
func.func @const() -> (i64, i64) {
%0 = arith.constant 1 : i64
%1 = arith.constant 3 : i64
return %0, %1 : i64, i64
}
}
""",
            frozen,
        )
        # CHECK: %c2_i64 = arith.constant 2 : i64
        # CHECK: %c3_i64 = arith.constant 3 : i64
        # CHECK: return %c2_i64, %c3_i64 : i64, i64