iT邦幫忙

2024 iThome 鐵人賽

DAY 28
0
Python

為你自己讀 CPython 原始碼系列 第 28

Day 28 - 轉呀轉呀七彩迭代器

  • 分享至 

  • xImage
  •  

本文同步刊載於 「為你自己學 Python - 轉呀轉呀七彩迭代器)

轉呀轉呀七彩迭代器

為你自己學 Python

在 Python 裡迭代器(Iterator)使用的頻率很高,讓我們可以不用像其它程式語言一樣用 for 迴圈就能遍歷各種「容器」,而且這所謂的容器還不是只有串列,字典、字串、範圍(Range)這些都能用類似的操作,這個章節我們就來看看迭代器是怎麼實作的。

迭代器協議

在「為你自己學 Python」的物件導向程式設計 - 進階篇有提到三個看起來有點像的東西,分別是 Iteration、Iterable 以及 Iterator,很快的複習一下:

  • Iteration,「迭代」,名詞,指的是遍歷某個物件裡面所有元素的過程。
  • Iterable,「可迭代的」,形容詞,是指可以被進行迭代的物件,本章節提到的「可迭代物件」就是指它。
  • Iterator,「迭代器」,名詞,有點像是一個容器,我們可以用特定的方法遍歷這個容器中的每個元素。

根據 Python 對迭代器的定義,只要有實作「迭代器協議(Iterator Protocol)」的物件,就能被稱之迭代器。迭代器協議的內容也很簡單,只要有實作 __iter__() 以及 __next__() 這兩個魔術方法就可以了。不過這是在 Python 層級的定義,我們來看看在 CPython 是怎麼實作的。

在 Python 要建立一個迭代器,可以使用內建函數 iter()

iter([9, 5, 2, 7])

所以我們先從這個函數的實作原始碼看起:

// 檔案:Python/clinic/bltinmodule.c.h

static PyObject *
builtin_iter(PyObject *module, PyObject *const *args, Py_ssize_t nargs)
{
    PyObject *return_value = NULL;
    PyObject *object;
    PyObject *sentinel = NULL;

    if (!_PyArg_CheckPositional("iter", nargs, 1, 2)) {
        goto exit;
    }
    object = args[0];
    if (nargs < 2) {
        goto skip_optional;
    }
    sentinel = args[1];
skip_optional:
    return_value = builtin_iter_impl(module, object, sentinel);

exit:
    return return_value;
}

看起來真正實作的函數是 builtin_iter_impl()

// 檔案:Python/bltinmodule.c

static PyObject *
builtin_iter_impl(PyObject *module, PyObject *object, PyObject *sentinel)
{
    if (sentinel == NULL)
        return PyObject_GetIter(object);
    if (!PyCallable_Check(object)) {
        PyErr_SetString(PyExc_TypeError,
                        "iter(object, sentinel): object must be callable");
        return NULL;
    }
    return PyCallIter_New(object, sentinel);
}

還滿容易懂的,如果沒有帶「哨兵(sentinel)」的話就直接呼叫 PyObject_GetIter(),否則就是 PyCallIter_New()。但這裡的哨兵是什麼意思?

站住,口令,誰!

其實哨兵的意思是如果迭代器回傳的值等於哨兵的話就停止迭代,舉個例子:

from random import randint

numbers = iter(lambda: randint(1, 10), 7)

for num in numbers:
    print(num)

因為我在 iter() 的第二個參數帶了 7,所以上面這段程式碼會不斷的產生 1 到 10 之間的隨機數字,直到數字等於 7 為止。如果沒有帶哨兵的話就會一直迭代下去。

我們先從比較簡單的 PyCallIter_New() 開始看:

// 檔案:Objects/iterobject.c

PyObject *
PyCallIter_New(PyObject *callable, PyObject *sentinel)
{
    calliterobject *it;
    it = PyObject_GC_New(calliterobject, &PyCallIter_Type);
    if (it == NULL)
        return NULL;
    it->it_callable = Py_NewRef(callable);
    it->it_sentinel = Py_NewRef(sentinel);
    _PyObject_GC_TRACK(it);
    return (PyObject *)it;
}

這裡會建立一個 calliterobject 結構的物件,這個結構是用來儲存迭代器的資訊:

// 檔案:Objects/iterobject.c

typedef struct {
    PyObject_HEAD
    PyObject *it_callable;
    PyObject *it_sentinel;
} calliterobject;

還滿單純的,我們順便看一下 PyCallIter_Type 的結構:

// 檔案:Objects/iterobject.c

PyTypeObject PyCallIter_Type = {
    PyVarObject_HEAD_INIT(&PyType_Type, 0)
    "callable_iterator",                        /* tp_name */
    sizeof(calliterobject),                     /* tp_basicsize */
    0,                                          /* tp_itemsize */
    // ... 略 ...
    0,                                          /* tp_richcompare */
    0,                                          /* tp_weaklistoffset */
    PyObject_SelfIter,                          /* tp_iter */
    (iternextfunc)calliter_iternext,            /* tp_iternext */
    calliter_methods,                           /* tp_methods */
};

這裡的重點應該在 tp_itertp_iternexttp_iter 會回傳迭代器物件自己,而 tp_iternext 則應該要回傳下一個元素,其實這就是我們前面提到迭代器協議需要實作的 __iter__ 以及 __next__ 方法。

// 檔案:Objects/object.c

PyObject *
PyObject_SelfIter(PyObject *obj)
{
    return Py_NewRef(obj);
}

就真的是回傳自己這個迭代器物件而已,很簡單,再看看 calliter_iternext() 函數:

// 檔案:Objects/iterobject.c

static PyObject *
calliter_iternext(calliterobject *it)
{
    PyObject *result;
    // ... 錯誤處理 ...

    result = _PyObject_CallNoArgs(it->it_callable);
    if (result != NULL && it->it_sentinel != NULL){
        int ok;

        ok = PyObject_RichCompareBool(it->it_sentinel, result, Py_EQ);
        if (ok == 0) {
            return result;
        }

        if (ok > 0) {
            Py_CLEAR(it->it_callable);
            Py_CLEAR(it->it_sentinel);
        }
    }
    else if (PyErr_ExceptionMatches(PyExc_StopIteration)) {
        PyErr_Clear();
        Py_CLEAR(it->it_callable);
        Py_CLEAR(it->it_sentinel);
    }
    Py_XDECREF(result);
    return NULL;
}

中間使用 PyObject_RichCompareBool() 函數比對是不是等於哨兵,如果是的話就停止迭代,並且把 it_callable 以及 it_sentinel 的值清空,不然的話就回傳元素。

這裡可以另外看一個小亮點,在這裡也可以看到如果迭代器回傳的值是 StopIteration 的時候,會呼叫 PyErr_Clear() 清除當前執行緒的錯誤狀態並停止迭代,這也是為什麼一般我們用 next() 拿下一個拿到沒東西的時候會拋出 StopIteration 例外,但在 for 迴圈或是串列推導式裡不會出錯的原因。

整個看起來算滿簡單的。我們再看另一個 PyObject_GetIter() 函數,這個就稍微複雜一點:

// 檔案:Objects/abstract.c

PyObject *
PyObject_GetIter(PyObject *o)
{
    PyTypeObject *t = Py_TYPE(o);
    getiterfunc f;

    f = t->tp_iter;
    if (f == NULL) {
        if (PySequence_Check(o))
            return PySeqIter_New(o);
        return type_error("'%.200s' object is not iterable", o);
    }
    else {
        PyObject *res = (*f)(o);
        if (res != NULL && !PyIter_Check(res)) {
            PyErr_Format(PyExc_TypeError,
                         "iter() returned non-iterator "
                         "of type '%.100s'",
                         Py_TYPE(res)->tp_name);
            Py_SETREF(res, NULL);
        }
        return res;
    }
}

先看看有沒有實作 tp_iter 成員,如果有就呼叫它:

PyObject *res = (*f)(o);

這行程式碼就是在做這件事,不過這有一些有趣的細節待會看。

如果沒有沒有實作 tp_iter 成員的話沒關係,就再檢查是不是一種序列,如果是的話就用這個序列建立一個泛用型的 PySeqIter_New() 的迭代器,再看看這個 PySeqIter_New() 在做什麼:

// 檔案:Objects/iterobject.c

PyObject *
PySeqIter_New(PyObject *seq)
{
    seqiterobject *it;

    if (!PySequence_Check(seq)) {
        PyErr_BadInternalCall();
        return NULL;
    }
    it = PyObject_GC_New(seqiterobject, &PySeqIter_Type);
    if (it == NULL)
        return NULL;
    it->it_index = 0;
    it->it_seq = Py_NewRef(seq);
    _PyObject_GC_TRACK(it);
    return (PyObject *)it;
}

這個函數會建立一個 seqiterobject 結構的物件,這個結構的設計滿簡單的:

// 檔案:Objects/iterobject.c

typedef struct {
    PyObject_HEAD
    Py_ssize_t it_index;
    PyObject *it_seq;
} seqiterobject;

it_index 記錄的是目前迭代的位置,而 it_seq 則是指向被迭代的序列物件。再看一下 PySeqIter_Type 的結構:

// 檔案:Objects/iterobject.c

PyTypeObject PySeqIter_Type = {
    PyVarObject_HEAD_INIT(&PyType_Type, 0)
    "iterator",                                 /* tp_name */
    sizeof(seqiterobject),                      /* tp_basicsize */
    // ... 略 ...
    0,                                          /* tp_richcompare */
    0,                                          /* tp_weaklistoffset */
    PyObject_SelfIter,                          /* tp_iter */
    iter_iternext,                              /* tp_iternext */
    seqiter_methods,                            /* tp_methods */
    0,                                          /* tp_members */
};

這個 tp_iter 成員也是回傳自己而已,來看看 tp_iternext 的實作:

// 檔案:Objects/iterobject.c

static PyObject *
iter_iternext(PyObject *iterator)
{
    seqiterobject *it;
    PyObject *seq;
    PyObject *result;

    assert(PySeqIter_Check(iterator));
    it = (seqiterobject *)iterator;
    seq = it->it_seq;
    // ... 錯誤處理 ...

    result = PySequence_GetItem(seq, it->it_index);
    if (result != NULL) {
        it->it_index++;
        return result;
    }
    if (PyErr_ExceptionMatches(PyExc_IndexError) ||
        PyErr_ExceptionMatches(PyExc_StopIteration))
    {
        PyErr_Clear();
        it->it_seq = NULL;
        Py_DECREF(seq);
    }
    return NULL;
}

PySequence_GetItem() 函數根據 it_index 索引值去拿序列中的元素,如果拿到的話就回傳,不然的話就停止迭代。這裡也同樣也可以看到如果拿到 IndexError 或是 StopIteration 例外的時候就停止迭代,而且有 PyErr_Clear() 所以不會引發錯誤。

不同的迭代器?

如果在 iter() 函數裡傳入不同的可迭代物件,會得到不同的迭代器物件,我們來看看不同的迭代器物件長什麼樣子:

>>> iter([])
<list_iterator object>

>>> iter(range(0))
<range_iterator object>

>>> iter({})
<dict_keyiterator object>

>>> iter('hello')
<str_ascii_iterator object>

>>> iter('七龍珠')
<str_iterator object>

怎麼這麼多種?這是因為不同的可迭代物件有不同的迭代器實作,在剛才的 PyObject_GetIter() 函數裡的這一行:

PyObject *res = (*f)(o);

就是呼叫 tp_iter 成員的實作函數,並且把目前這個可迭代物件傳進去。不同的資料型態,可能有著不同的 tp_iter 實作,例如串列:

// 檔案:Objects/listobject.c

static PyObject *
list_iter(PyObject *seq)
{
    _PyListIterObject *it;

    // ... 錯誤處理 ...
    it = PyObject_GC_New(_PyListIterObject, &PyListIter_Type);
    if (it == NULL)
        return NULL;
    it->it_index = 0;
    it->it_seq = (PyListObject *)Py_NewRef(seq);
    _PyObject_GC_TRACK(it);
    return (PyObject *)it;
}

這裡會做出一個 PyListIter_Type 類型的迭代器物件,這個迭代器物件的 tp_iternext 成員實作如下:

// 檔案:Objects/listobject.c

static PyObject *
listiter_next(_PyListIterObject *it)
{
    PyListObject *seq;
    PyObject *item;

    // ... 錯誤處理 ...

    if (it->it_index < PyList_GET_SIZE(seq)) {
        item = PyList_GET_ITEM(seq, it->it_index);
        ++it->it_index;
        return Py_NewRef(item);
    }

    it->it_seq = NULL;
    Py_DECREF(seq);
    return NULL;
}

串列的 tp_iternext 比較簡單。如果傳入的可迭代物件是字串,就會看看這個字全部都是 ASCII 還是有其它編碼而決定會建立哪一種迭代器,字典、範圍都是一樣的做法。也就是因為這樣,所以在上面會看到不同的迭代器物件。

有興趣的話,可再順著同樣的思路去追看看由範圍、字串跟字典這些物件所建立的迭代器物件的 tp_iternext 實作,這樣就能了解不同的迭代器物件是怎麼運作的。

同樣都是可以被 next() 函數操作的物件,跟上個章節介紹的產生器比起來,迭代器的實作簡單多了 :)

本文同步刊載於 「為你自己學 Python - 轉呀轉呀七彩迭代器)


上一篇
Day 27 - 產生一個產生器
下一篇
Day 29 - 無所不在的描述器
系列文
為你自己讀 CPython 原始碼31
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言