| # RUN: %PYTHON %s 2>&1 | FileCheck %s |
| |
| from mlir.ir import * |
| from mlir.passmanager import * |
| from mlir.dialects.builtin import ModuleOp |
| from mlir.dialects import arith |
| from mlir.rewrite import * |
| |
| |
def run(f):
    """Print a FileCheck-anchorable banner for ``f``, then invoke it.

    Meant to be applied as a decorator, so each test body executes as soon as
    the module is loaded. The emitted line is ``TEST: <function name>``,
    preceded by an empty line, which ``CHECK-LABEL`` directives match on.
    Nothing is returned, so the decorated name ends up bound to ``None``.
    """
    banner = f"\nTEST: {f.__name__}"
    print(banner)
    f()
| |
| |
# CHECK-LABEL: TEST: testRewritePattern
@run
def testRewritePattern():
    """Exercise Python-defined rewrite patterns under the greedy driver.

    Two patterns are registered: one rewriting ``arith.addi`` into
    ``arith.muli`` on the same operands, and one rewriting an
    ``arith.constant`` whose value is 1 into a constant of value 2. Each is
    exercised on its own small module, and the rewritten IR is printed so
    FileCheck can verify it.
    """

    def to_muli(op, rewriter):
        # The pattern is registered for arith.AddIOp only, so this must hold.
        assert isinstance(op, arith.AddIOp)
        with rewriter.ip:
            product = arith.muli(op.lhs, op.rhs, loc=op.location)
        # Replace with the defining operation of the new result value.
        rewriter.replace_op(op, product.owner)

    def constant_1_to_2(op, rewriter):
        # A truthy return value reports "failed to match" to the driver, so
        # every constant other than 1 is left untouched.
        if op.value.value != 1:
            return True
        with rewriter.ip:
            two = arith.constant(op.type, 2, loc=op.location)
        # Replace with the list of new result values.
        rewriter.replace_op(op, [two])

    add_ir = r"""
    module {
      func.func @add(%a: i64, %b: i64) -> i64 {
        %sum = arith.addi %a, %b : i64
        return %sum : i64
      }
    }
    """
    const_ir = r"""
    module {
      func.func @const() -> (i64, i64) {
        %0 = arith.constant 1 : i64
        %1 = arith.constant 3 : i64
        return %0, %1 : i64, i64
      }
    }
    """

    with Context():
        patterns = RewritePatternSet()
        for op_class, callback in (
            (arith.AddIOp, to_muli),
            (arith.ConstantOp, constant_1_to_2),
        ):
            patterns.add(op_class, callback)
        frozen = patterns.freeze()

        # Expected output of the first module (addi -> muli):
        # CHECK: %0 = arith.muli %arg0, %arg1 : i64
        # CHECK: return %0 : i64
        # Expected output of the second module (constant 1 -> 2, 3 kept):
        # CHECK: %c2_i64 = arith.constant 2 : i64
        # CHECK: %c3_i64 = arith.constant 3 : i64
        # CHECK: return %c2_i64, %c3_i64 : i64, i64
        for source in (add_ir, const_ir):
            module = ModuleOp.parse(source)
            apply_patterns_and_fold_greedily(module, frozen)
            print(module)