blob: 49e29f60c31b6f863921ee8d34ac434c202d1a23 [file] [log] [blame]
commit 7bb8bfa0622b8ee55c3f748004dcf4d83d48cf97
Author: Sanjay Patel <spatel@rotateright.com>
Date: Sun May 30 06:43:33 2021 -0400
[InstCombine] fix miscompile from vector select substitution
This is similar to the fix in c590a9880d7a ( PR49832 ), but
we missed handling the pattern for select of bools (no compare
inst).
We can't substitute a vector value because the equality condition
replacement that we are attempting requires that the condition
is true/false for the entire value. Vector select can be partly
true/false.
I added an assert for vector types, so we shouldn't hit this again.
Fixed formatting while auditing the callers.
https://llvm.org/PR50500
diff --git a/llvm/include/llvm/Analysis/InstructionSimplify.h b/llvm/include/llvm/Analysis/InstructionSimplify.h
index e1e7da14376e..75ce4e38df16 100644
--- a/llvm/include/llvm/Analysis/InstructionSimplify.h
+++ b/llvm/include/llvm/Analysis/InstructionSimplify.h
@@ -299,7 +299,7 @@ Value *SimplifyInstruction(Instruction *I, const SimplifyQuery &Q,
/// return null.
/// AllowRefinement specifies whether the simplification can be a refinement
/// (e.g. 0 instead of poison), or whether it needs to be strictly identical.
-Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
+Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
const SimplifyQuery &Q, bool AllowRefinement);
/// Replace all uses of 'I' with 'SimpleV' and simplify the uses recursively.
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index c72670b901fe..e6baed1779cd 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -3852,10 +3852,12 @@ Value *llvm::SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
return ::SimplifyFCmpInst(Predicate, LHS, RHS, FMF, Q, RecursionLimit);
}
-static Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
+static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
const SimplifyQuery &Q,
bool AllowRefinement,
unsigned MaxRecurse) {
+ assert(!Op->getType()->isVectorTy() && "This is not safe for vectors");
+
// Trivial replacement.
if (V == Op)
return RepOp;
@@ -3965,10 +3967,10 @@ static Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
return ConstantFoldInstOperands(I, ConstOps, Q.DL, Q.TLI);
}
-Value *llvm::SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
+Value *llvm::simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
const SimplifyQuery &Q,
bool AllowRefinement) {
- return ::SimplifyWithOpReplaced(V, Op, RepOp, Q, AllowRefinement,
+ return ::simplifyWithOpReplaced(V, Op, RepOp, Q, AllowRefinement,
RecursionLimit);
}
@@ -4090,17 +4092,17 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal,
// Note that the equivalence/replacement opportunity does not hold for vectors
// because each element of a vector select is chosen independently.
if (Pred == ICmpInst::ICMP_EQ && !CondVal->getType()->isVectorTy()) {
- if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q,
+ if (simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q,
/* AllowRefinement */ false, MaxRecurse) ==
TrueVal ||
- SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q,
+ simplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q,
/* AllowRefinement */ false, MaxRecurse) ==
TrueVal)
return FalseVal;
- if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q,
+ if (simplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q,
/* AllowRefinement */ true, MaxRecurse) ==
FalseVal ||
- SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q,
+ simplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q,
/* AllowRefinement */ true, MaxRecurse) ==
FalseVal)
return FalseVal;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 68124a188f71..f48cc901b742 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -1142,7 +1142,7 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel,
Value *CmpLHS = Cmp.getOperand(0), *CmpRHS = Cmp.getOperand(1);
if (TrueVal != CmpLHS &&
isGuaranteedNotToBeUndefOrPoison(CmpRHS, SQ.AC, &Sel, &DT)) {
- if (Value *V = SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, SQ,
+ if (Value *V = simplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, SQ,
/* AllowRefinement */ true))
return replaceOperand(Sel, Swapped ? 2 : 1, V);
@@ -1164,7 +1164,7 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel,
}
if (TrueVal != CmpRHS &&
isGuaranteedNotToBeUndefOrPoison(CmpLHS, SQ.AC, &Sel, &DT))
- if (Value *V = SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, SQ,
+ if (Value *V = simplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, SQ,
/* AllowRefinement */ true))
return replaceOperand(Sel, Swapped ? 2 : 1, V);
@@ -1195,9 +1195,9 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel,
// We have an 'EQ' comparison, so the select's false value will propagate.
// Example:
// (X == 42) ? 43 : (X + 1) --> (X == 42) ? (X + 1) : (X + 1) --> X + 1
- if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, SQ,
+ if (simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, SQ,
/* AllowRefinement */ false) == TrueVal ||
- SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, SQ,
+ simplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, SQ,
/* AllowRefinement */ false) == TrueVal) {
return replaceInstUsesWith(Sel, FalseVal);
}
@@ -2714,12 +2714,14 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
match(TrueVal, m_Specific(B)) && match(FalseVal, m_Zero()))
return replaceOperand(SI, 0, A);
- if (Value *S = SimplifyWithOpReplaced(TrueVal, CondVal, One, SQ,
- /* AllowRefinement */ true))
- return replaceOperand(SI, 1, S);
- if (Value *S = SimplifyWithOpReplaced(FalseVal, CondVal, Zero, SQ,
- /* AllowRefinement */ true))
- return replaceOperand(SI, 2, S);
+ if (!SelType->isVectorTy()) {
+ if (Value *S = simplifyWithOpReplaced(TrueVal, CondVal, One, SQ,
+ /* AllowRefinement */ true))
+ return replaceOperand(SI, 1, S);
+ if (Value *S = simplifyWithOpReplaced(FalseVal, CondVal, Zero, SQ,
+ /* AllowRefinement */ true))
+ return replaceOperand(SI, 2, S);
+ }
if (match(FalseVal, m_Zero()) || match(TrueVal, m_One())) {
Use *Y = nullptr;
diff --git a/llvm/test/Transforms/InstCombine/select-safe-bool-transforms.ll b/llvm/test/Transforms/InstCombine/select-safe-bool-transforms.ll
index c15a64ee7315..fef4081c0bb6 100644
--- a/llvm/test/Transforms/InstCombine/select-safe-bool-transforms.ll
+++ b/llvm/test/Transforms/InstCombine/select-safe-bool-transforms.ll
@@ -468,3 +468,28 @@ define i1 @bor_lor_right2(i1 %A, i1 %B) {
ret i1 %res
}
+; Value equivalence substitution does not account for vector
+; transforms, so it needs a scalar condition operand.
+; For example, this would miscompile if %a = {1, 0}.
+
+define <2 x i1> @PR50500_trueval(<2 x i1> %a, <2 x i1> %b) {
+; CHECK-LABEL: @PR50500_trueval(
+; CHECK-NEXT: [[S:%.*]] = shufflevector <2 x i1> [[A:%.*]], <2 x i1> poison, <2 x i32> <i32 1, i32 0>
+; CHECK-NEXT: [[R:%.*]] = select <2 x i1> [[A]], <2 x i1> [[S]], <2 x i1> [[B:%.*]]
+; CHECK-NEXT: ret <2 x i1> [[R]]
+;
+ %s = shufflevector <2 x i1> %a, <2 x i1> poison, <2 x i32> <i32 1, i32 0>
+ %r = select <2 x i1> %a, <2 x i1> %s, <2 x i1> %b
+ ret <2 x i1> %r
+}
+
+define <2 x i1> @PR50500_falseval(<2 x i1> %a, <2 x i1> %b) {
+; CHECK-LABEL: @PR50500_falseval(
+; CHECK-NEXT: [[S:%.*]] = shufflevector <2 x i1> [[A:%.*]], <2 x i1> poison, <2 x i32> <i32 1, i32 0>
+; CHECK-NEXT: [[R:%.*]] = select <2 x i1> [[A]], <2 x i1> [[B:%.*]], <2 x i1> [[S]]
+; CHECK-NEXT: ret <2 x i1> [[R]]
+;
+ %s = shufflevector <2 x i1> %a, <2 x i1> poison, <2 x i32> <i32 1, i32 0>
+ %r = select <2 x i1> %a, <2 x i1> %b, <2 x i1> %s
+ ret <2 x i1> %r
+}