bpf: Fix subreg optimization for BPF_FETCH

All 32-bit variants of BPF_FETCH (add, and, or, xor, xchg, cmpxchg)
define a 32-bit subreg and thus have zext_dst set. Their encoding,
however, uses dst_reg field as a base register, which causes
opt_subreg_zext_lo32_rnd_hi32() to zero-extend said base register
instead of the one the insn really defines (r0 or src_reg).

Fix by properly choosing a register being defined, similar to how
check_atomic() already does that.

Change-Id: I71fe565d5ded9edf8ec7113a1410d9dda0d899d8
Signed-off-by: Ilya Leoshkevich <iii@linux.ibm.com>
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
Link: https://lore.kernel.org/bpf/20210210204502.83429-1-iii@linux.ibm.com
SOURCE=UPSTREAM(b2e37a7114ef52b862b4421ed4cd40c4ed2a0642)
BUG=b:153508410
diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
index 000b722..4874955 100644
--- a/kernel/bpf/verifier.c
+++ b/kernel/bpf/verifier.c
@@ -10411,6 +10411,7 @@
 	for (i = 0; i < len; i++) {
 		int adj_idx = i + delta;
 		struct bpf_insn insn;
+		u8 load_reg;
 
 		insn = insns[adj_idx];
 		if (!aux[adj_idx].zext_dst) {
@@ -10463,9 +10464,27 @@
 		if (!bpf_jit_needs_zext() && !is_cmpxchg_insn(&insn))
 			continue;
 
+		/* zext_dst means that we want to zero-extend whatever register
+		 * the insn defines, which is dst_reg most of the time, with
+		 * the notable exception of BPF_STX + BPF_ATOMIC + BPF_FETCH.
+		 */
+		if (BPF_CLASS(insn.code) == BPF_STX &&
+		    BPF_MODE(insn.code) == BPF_ATOMIC) {
+			/* BPF_STX + BPF_ATOMIC insns without BPF_FETCH do not
+			 * define any registers, therefore zext_dst cannot be
+			 * set.
+			 */
+			if (WARN_ON(!(insn.imm & BPF_FETCH)))
+				return -EINVAL;
+			load_reg = insn.imm == BPF_CMPXCHG ? BPF_REG_0
+							   : insn.src_reg;
+		} else {
+			load_reg = insn.dst_reg;
+		}
+
 		zext_patch[0] = insn;
-		zext_patch[1].dst_reg = insn.dst_reg;
-		zext_patch[1].src_reg = insn.dst_reg;
+		zext_patch[1].dst_reg = load_reg;
+		zext_patch[1].src_reg = load_reg;
 		patch = zext_patch;
 		patch_len = 2;
 apply_patch_buffer: