blob: b470120cab8d521af0e13e119f79e26692985491 [file] [log] [blame]
From 79f89e6e5a659846d1068e8b1bd8e491ccdef861 Mon Sep 17 00:00:00 2001
From: Pablo Galindo <Pablogsal@gmail.com>
Date: Thu, 23 Jan 2020 14:07:05 +0000
Subject: [PATCH] bpo-39421: Fix posible crash in heapq with custom comparison
operators (GH-18118)
* bpo-39421: Fix posible crash in heapq with custom comparison operators
* fixup! bpo-39421: Fix posible crash in heapq with custom comparison operators
* fixup! fixup! bpo-39421: Fix posible crash in heapq with custom comparison operators
---
Lib/test/test_heapq.py | 31 ++++++++++++++++
.../2020-01-22-15-53-37.bpo-39421.O3nG7u.rst | 2 ++
Modules/_heapqmodule.c | 35 ++++++++++++++-----
3 files changed, 59 insertions(+), 9 deletions(-)
create mode 100644 Misc/NEWS.d/next/Core and Builtins/2020-01-22-15-53-37.bpo-39421.O3nG7u.rst
Backport:
* Drop Misc/NEWS.d
* test_heapq.py:
+ Update hunk context
+ list.clear() -> del list[:]
* _heapqmodule.c: Port the patch with significant changes
+ PyObject_RichCompareBool -> cmp_lt
+ X[Y] -> PyList_GET_ITEM(X, Y)
+ 4th hunk: newitem refcount is already incremented, parent refcount extended
diff --git a/Lib/test/test_heapq.py b/Lib/test/test_heapq.py
index 861ba7540d..6902573e8f 100644
--- a/Lib/test/test_heapq.py
+++ b/Lib/test/test_heapq.py
@@ -432,6 +432,37 @@ def test_heappop_mutating_heap(self):
with self.assertRaises((IndexError, RuntimeError)):
self.module.heappop(heap)
+ def test_comparison_operator_modifiying_heap(self):
+ # See bpo-39421: Strong references need to be taken
+ # when comparing objects as they can alter the heap
+ class EvilClass(int):
+ def __lt__(self, o):
+ del heap[:]
+ return NotImplemented
+
+ heap = []
+ self.module.heappush(heap, EvilClass(0))
+ self.assertRaises(IndexError, self.module.heappushpop, heap, 1)
+
+ def test_comparison_operator_modifiying_heap_two_heaps(self):
+
+ class h(int):
+ def __lt__(self, o):
+ del list2[:]
+ return NotImplemented
+
+ class g(int):
+ def __lt__(self, o):
+ del list1[:]
+ return NotImplemented
+
+ list1, list2 = [], []
+
+ self.module.heappush(list1, h(0))
+ self.module.heappush(list2, g(0))
+
+ self.assertRaises((IndexError, RuntimeError), self.module.heappush, list1, g(1))
+ self.assertRaises((IndexError, RuntimeError), self.module.heappush, list2, h(1))
class TestErrorHandlingPython(TestErrorHandling):
module = py_heapq
diff --git a/Modules/_heapqmodule.c b/Modules/_heapqmodule.c
index a84cade3aa..6bc18b5f82 100644
--- a/Modules/_heapqmodule.c
+++ b/Modules/_heapqmodule.c
@@ -36,7 +36,11 @@ siftdown(PyListObject *heap, Py_ssize_t startpos, Py_ssize_t pos)
while (pos > startpos) {
parentpos = (pos - 1) >> 1;
parent = PyList_GET_ITEM(heap, parentpos);
+ Py_INCREF(newitem);
+ Py_INCREF(parent);
cmp = cmp_lt(newitem, parent);
+ Py_DECREF(parent);
+ Py_DECREF(newitem);
if (cmp == -1)
return -1;
if (size != PyList_GET_SIZE(heap)) {
@@ -78,9 +82,13 @@ siftup(PyListObject *heap, Py_ssize_t pos)
childpos = 2*pos + 1; /* leftmost child position */
rightpos = childpos + 1;
if (rightpos < endpos) {
- cmp = cmp_lt(
- PyList_GET_ITEM(heap, childpos),
- PyList_GET_ITEM(heap, rightpos));
+ PyObject* a = PyList_GET_ITEM(heap, childpos);
+ PyObject* b = PyList_GET_ITEM(heap, rightpos);
+ Py_INCREF(a);
+ Py_INCREF(b);
+ cmp = cmp_lt(a, b);
+ Py_DECREF(a);
+ Py_DECREF(b);
if (cmp == -1)
return -1;
if (cmp == 0)
@@ -264,7 +271,10 @@ _heapq_heappushpop_impl(PyObject *module, PyObject *heap, PyObject *item)
return item;
}
- cmp = cmp_lt(PyList_GET_ITEM(heap, 0), item);
+ PyObject* top = PyList_GET_ITEM(heap, 0);
+ Py_INCREF(top);
+ cmp = cmp_lt(top, item);
+ Py_DECREF(top);
if (cmp == -1)
return NULL;
if (cmp == 0) {
@@ -420,14 +430,17 @@ siftdown_max(PyListObject *heap, Py_ssize_t startpos, Py_ssize_t pos)
while (pos > startpos){
parentpos = (pos - 1) >> 1;
parent = PyList_GET_ITEM(heap, parentpos);
+ Py_INCREF(parent);
cmp = cmp_lt(parent, newitem);
if (cmp == -1) {
+ Py_DECREF(parent);
Py_DECREF(newitem);
return -1;
}
- if (cmp == 0)
+ if (cmp == 0) {
+ Py_DECREF(parent);
break;
+ }
- Py_INCREF(parent);
Py_DECREF(PyList_GET_ITEM(heap, pos));
PyList_SET_ITEM(heap, pos, parent);
pos = parentpos;
@@ -462,9 +476,13 @@ siftup_max(PyListObject *heap, Py_ssize_t pos)
childpos = 2*pos + 1; /* leftmost child position */
rightpos = childpos + 1;
if (rightpos < endpos) {
- cmp = cmp_lt(
- PyList_GET_ITEM(heap, rightpos),
- PyList_GET_ITEM(heap, childpos));
+ PyObject* a = PyList_GET_ITEM(heap, rightpos);
+ PyObject* b = PyList_GET_ITEM(heap, childpos);
+ Py_INCREF(a);
+ Py_INCREF(b);
+ cmp = cmp_lt(a, b);
+ Py_DECREF(a);
+ Py_DECREF(b);
if (cmp == -1) {
Py_DECREF(newitem);
return -1;