本文同步刊載於 「為你自己學 Python - 轉呀轉呀七彩迭代器)」
在 Python 裡迭代器(Iterator)使用的頻率很高,讓我們可以不用像其它程式語言一樣用 for
迴圈就能遍歷各種「容器」,而且這所謂的容器還不是只有串列,字典、字串、範圍(Range)這些都能用類似的操作,這個章節我們就來看看迭代器是怎麼實作的。
在「為你自己學 Python」的物件導向程式設計 - 進階篇有提到三個看起來有點像的東西,分別是 Iteration、Iterable 以及 Iterator,很快的複習一下:
根據 Python 對迭代器的定義,只要有實作「迭代器協議(Iterator Protocol)」的物件,就能被稱之迭代器。迭代器協議的內容也很簡單,只要有實作 __iter__()
以及 __next__()
這兩個魔術方法就可以了。不過這是在 Python 層級的定義,我們來看看在 CPython 是怎麼實作的。
在 Python 要建立一個迭代器,可以使用內建函數 iter()
:
iter([9, 5, 2, 7])
所以我們先從這個函數的實作原始碼看起:
// 檔案:Python/clinic/bltinmodule.c.h
static PyObject *
builtin_iter(PyObject *module, PyObject *const *args, Py_ssize_t nargs)
{
PyObject *return_value = NULL;
PyObject *object;
PyObject *sentinel = NULL;
if (!_PyArg_CheckPositional("iter", nargs, 1, 2)) {
goto exit;
}
object = args[0];
if (nargs < 2) {
goto skip_optional;
}
sentinel = args[1];
skip_optional:
return_value = builtin_iter_impl(module, object, sentinel);
exit:
return return_value;
}
看起來真正實作的函數是 builtin_iter_impl()
:
// 檔案:Python/bltinmodule.c
static PyObject *
builtin_iter_impl(PyObject *module, PyObject *object, PyObject *sentinel)
{
if (sentinel == NULL)
return PyObject_GetIter(object);
if (!PyCallable_Check(object)) {
PyErr_SetString(PyExc_TypeError,
"iter(object, sentinel): object must be callable");
return NULL;
}
return PyCallIter_New(object, sentinel);
}
還滿容易懂的,如果沒有帶「哨兵(sentinel)」的話就直接呼叫 PyObject_GetIter()
,否則就是 PyCallIter_New()
。但這裡的哨兵是什麼意思?
其實哨兵的意思是如果迭代器回傳的值等於哨兵的話就停止迭代,舉個例子:
from random import randint
numbers = iter(lambda: randint(1, 10), 7)
for num in numbers:
print(num)
因為我在 iter()
的第二個參數帶了 7,所以上面這段程式碼會不斷的產生 1 到 10 之間的隨機數字,直到數字等於 7 為止。如果沒有帶哨兵的話就會一直迭代下去。
我們先從比較簡單的 PyCallIter_New()
開始看:
// 檔案:Objects/iterobject.c
PyObject *
PyCallIter_New(PyObject *callable, PyObject *sentinel)
{
calliterobject *it;
it = PyObject_GC_New(calliterobject, &PyCallIter_Type);
if (it == NULL)
return NULL;
it->it_callable = Py_NewRef(callable);
it->it_sentinel = Py_NewRef(sentinel);
_PyObject_GC_TRACK(it);
return (PyObject *)it;
}
這裡會建立一個 calliterobject
結構的物件,這個結構是用來儲存迭代器的資訊:
// 檔案:Objects/iterobject.c
typedef struct {
PyObject_HEAD
PyObject *it_callable;
PyObject *it_sentinel;
} calliterobject;
還滿單純的,我們順便看一下 PyCallIter_Type
的結構:
// 檔案:Objects/iterobject.c
PyTypeObject PyCallIter_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0)
"callable_iterator", /* tp_name */
sizeof(calliterobject), /* tp_basicsize */
0, /* tp_itemsize */
// ... 略 ...
0, /* tp_richcompare */
0, /* tp_weaklistoffset */
PyObject_SelfIter, /* tp_iter */
(iternextfunc)calliter_iternext, /* tp_iternext */
calliter_methods, /* tp_methods */
};
這裡的重點應該在 tp_iter
跟 tp_iternext
,tp_iter
會回傳迭代器物件自己,而 tp_iternext
則應該要回傳下一個元素,其實這就是我們前面提到迭代器協議需要實作的 __iter__
以及 __next__
方法。
// 檔案:Objects/object.c
PyObject *
PyObject_SelfIter(PyObject *obj)
{
return Py_NewRef(obj);
}
就真的是回傳自己這個迭代器物件而已,很簡單,再看看 calliter_iternext()
函數:
// 檔案:Objects/iterobject.c
static PyObject *
calliter_iternext(calliterobject *it)
{
PyObject *result;
// ... 錯誤處理 ...
result = _PyObject_CallNoArgs(it->it_callable);
if (result != NULL && it->it_sentinel != NULL){
int ok;
ok = PyObject_RichCompareBool(it->it_sentinel, result, Py_EQ);
if (ok == 0) {
return result;
}
if (ok > 0) {
Py_CLEAR(it->it_callable);
Py_CLEAR(it->it_sentinel);
}
}
else if (PyErr_ExceptionMatches(PyExc_StopIteration)) {
PyErr_Clear();
Py_CLEAR(it->it_callable);
Py_CLEAR(it->it_sentinel);
}
Py_XDECREF(result);
return NULL;
}
中間使用 PyObject_RichCompareBool()
函數比對是不是等於哨兵,如果是的話就停止迭代,並且把 it_callable
以及 it_sentinel
的值清空,不然的話就回傳元素。
這裡可以另外看一個小亮點,在這裡也可以看到如果迭代器回傳的值是 StopIteration
的時候,會呼叫 PyErr_Clear()
清除當前執行緒的錯誤狀態並停止迭代,這也是為什麼一般我們用 next()
拿下一個拿到沒東西的時候會拋出 StopIteration
例外,但在 for
迴圈或是串列推導式裡不會出錯的原因。
整個看起來算滿簡單的。我們再看另一個 PyObject_GetIter()
函數,這個就稍微複雜一點:
// 檔案:Objects/abstract.c
PyObject *
PyObject_GetIter(PyObject *o)
{
PyTypeObject *t = Py_TYPE(o);
getiterfunc f;
f = t->tp_iter;
if (f == NULL) {
if (PySequence_Check(o))
return PySeqIter_New(o);
return type_error("'%.200s' object is not iterable", o);
}
else {
PyObject *res = (*f)(o);
if (res != NULL && !PyIter_Check(res)) {
PyErr_Format(PyExc_TypeError,
"iter() returned non-iterator "
"of type '%.100s'",
Py_TYPE(res)->tp_name);
Py_SETREF(res, NULL);
}
return res;
}
}
先看看有沒有實作 tp_iter
成員,如果有就呼叫它:
PyObject *res = (*f)(o);
這行程式碼就是在做這件事,不過這有一些有趣的細節待會看。
如果沒有沒有實作 tp_iter
成員的話沒關係,就再檢查是不是一種序列,如果是的話就用這個序列建立一個泛用型的 PySeqIter_New()
的迭代器,再看看這個 PySeqIter_New()
在做什麼:
// 檔案:Objects/iterobject.c
PyObject *
PySeqIter_New(PyObject *seq)
{
seqiterobject *it;
if (!PySequence_Check(seq)) {
PyErr_BadInternalCall();
return NULL;
}
it = PyObject_GC_New(seqiterobject, &PySeqIter_Type);
if (it == NULL)
return NULL;
it->it_index = 0;
it->it_seq = Py_NewRef(seq);
_PyObject_GC_TRACK(it);
return (PyObject *)it;
}
這個函數會建立一個 seqiterobject
結構的物件,這個結構的設計滿簡單的:
// 檔案:Objects/iterobject.c
typedef struct {
PyObject_HEAD
Py_ssize_t it_index;
PyObject *it_seq;
} seqiterobject;
it_index
記錄的是目前迭代的位置,而 it_seq
則是指向被迭代的序列物件。再看一下 PySeqIter_Type
的結構:
// 檔案:Objects/iterobject.c
PyTypeObject PySeqIter_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0)
"iterator", /* tp_name */
sizeof(seqiterobject), /* tp_basicsize */
// ... 略 ...
0, /* tp_richcompare */
0, /* tp_weaklistoffset */
PyObject_SelfIter, /* tp_iter */
iter_iternext, /* tp_iternext */
seqiter_methods, /* tp_methods */
0, /* tp_members */
};
這個 tp_iter
成員也是回傳自己而已,來看看 tp_iternext
的實作:
// 檔案:Objects/iterobject.c
static PyObject *
iter_iternext(PyObject *iterator)
{
seqiterobject *it;
PyObject *seq;
PyObject *result;
assert(PySeqIter_Check(iterator));
it = (seqiterobject *)iterator;
seq = it->it_seq;
// ... 錯誤處理 ...
result = PySequence_GetItem(seq, it->it_index);
if (result != NULL) {
it->it_index++;
return result;
}
if (PyErr_ExceptionMatches(PyExc_IndexError) ||
PyErr_ExceptionMatches(PyExc_StopIteration))
{
PyErr_Clear();
it->it_seq = NULL;
Py_DECREF(seq);
}
return NULL;
}
PySequence_GetItem()
函數根據 it_index
索引值去拿序列中的元素,如果拿到的話就回傳,不然的話就停止迭代。這裡也同樣也可以看到如果拿到 IndexError
或是 StopIteration
例外的時候就停止迭代,而且有 PyErr_Clear()
所以不會引發錯誤。
如果在 iter()
函數裡傳入不同的可迭代物件,會得到不同的迭代器物件,我們來看看不同的迭代器物件長什麼樣子:
>>> iter([])
<list_iterator object>
>>> iter(range(0))
<range_iterator object>
>>> iter({})
<dict_keyiterator object>
>>> iter('hello')
<str_ascii_iterator object>
>>> iter('七龍珠')
<str_iterator object>
怎麼這麼多種?這是因為不同的可迭代物件有不同的迭代器實作,在剛才的 PyObject_GetIter()
函數裡的這一行:
PyObject *res = (*f)(o);
就是呼叫 tp_iter
成員的實作函數,並且把目前這個可迭代物件傳進去。不同的資料型態,可能有著不同的 tp_iter
實作,例如串列:
// 檔案:Objects/listobject.c
static PyObject *
list_iter(PyObject *seq)
{
_PyListIterObject *it;
// ... 錯誤處理 ...
it = PyObject_GC_New(_PyListIterObject, &PyListIter_Type);
if (it == NULL)
return NULL;
it->it_index = 0;
it->it_seq = (PyListObject *)Py_NewRef(seq);
_PyObject_GC_TRACK(it);
return (PyObject *)it;
}
這裡會做出一個 PyListIter_Type
類型的迭代器物件,這個迭代器物件的 tp_iternext
成員實作如下:
// 檔案:Objects/listobject.c
static PyObject *
listiter_next(_PyListIterObject *it)
{
PyListObject *seq;
PyObject *item;
// ... 錯誤處理 ...
if (it->it_index < PyList_GET_SIZE(seq)) {
item = PyList_GET_ITEM(seq, it->it_index);
++it->it_index;
return Py_NewRef(item);
}
it->it_seq = NULL;
Py_DECREF(seq);
return NULL;
}
串列的 tp_iternext
比較簡單。如果傳入的可迭代物件是字串,就會看看這個字全部都是 ASCII 還是有其它編碼而決定會建立哪一種迭代器,字典、範圍都是一樣的做法。也就是因為這樣,所以在上面會看到不同的迭代器物件。
有興趣的話,可再順著同樣的思路去追看看由範圍、字串跟字典這些物件所建立的迭代器物件的 tp_iternext
實作,這樣就能了解不同的迭代器物件是怎麼運作的。
同樣都是可以被 next()
函數操作的物件,跟上個章節介紹的產生器比起來,迭代器的實作簡單多了 :)
本文同步刊載於 「為你自己學 Python - 轉呀轉呀七彩迭代器)」