Skip to content

Instantly share code, notes, and snippets.

@MadcowD
Created October 6, 2024 21:53
Show Gist options
  • Save MadcowD/1acc45a642eefe6110b3abd5da4eb155 to your computer and use it in GitHub Desktop.
patch
--- original_generator.c
+++ modified_generator.c
@@ -1,6 +1,7 @@
/* Generator object implementation */
#define _PY_INTERPRETER
+#include "pycore_exceptions.h" // For accessing exception attributes
#include "Python.h"
#include "pycore_call.h" // _PyObject_CallNoArgs()
#include "pycore_ceval.h" // _PyEval_EvalFrame()
@@ -349,6 +350,31 @@
/* No need to visit cr_origin, because it's just tuples/str/int, so can't
participate in a reference cycle. */
Py_VISIT(gen->gi_exc_state.exc_value);
return 0;
+}
+
+/*
+ * Raise StopAsyncIteration carrying `value` as its first argument.
+ *
+ * value: payload to attach, or NULL to raise a bare StopAsyncIteration.
+ *
+ * The exception instance is built explicitly with a single positional
+ * argument, so `value` is stored as-is even when it is a tuple
+ * (PyErr_SetObject() would otherwise treat a tuple as the constructor's
+ * argument list).
+ *
+ * Returns 0 once StopAsyncIteration has been set.  Returns -1 when building
+ * the exception instance failed; that construction error is left set.
+ */
+int
+_PyAsyncGen_SetStopAsyncIterationValue(PyObject *value)
+{
+    if (value != NULL) {
+        PyObject *exc_instance = PyObject_CallOneArg(PyExc_StopAsyncIteration,
+                                                     value);
+        if (exc_instance == NULL) {
+            return -1;
+        }
+        PyErr_SetObject(PyExc_StopAsyncIteration, exc_instance);
+        Py_DECREF(exc_instance);
+        return 0;
+    }
+
+    PyErr_SetNone(PyExc_StopAsyncIteration);
+    return 0;
+}
+
+/*
+ * If StopAsyncIteration is the pending exception, consume it and return its
+ * payload through *pvalue as a new reference.
+ *
+ * The payload is read in this order:
+ *   1. a `value` attribute, when the exception object has one (subclasses);
+ *   2. otherwise the first constructor argument, args[0], which is where
+ *      _PyAsyncGen_SetStopAsyncIterationValue() puts it;
+ *   3. otherwise None.
+ * The built-in StopAsyncIteration has no `value` member (unlike
+ * StopIteration), so step 2 is what makes a set/fetch round trip work.
+ *
+ * Returns 0 on success; the exception has then been cleared.  Returns -1,
+ * leaving *pvalue untouched, when StopAsyncIteration is not pending (any
+ * other pending exception stays set) or when reading `value` raised
+ * something other than AttributeError (that error is left set).
+ */
+int
+_PyAsyncGen_FetchStopAsyncIterationValue(PyObject **pvalue)
+{
+    if (!PyErr_ExceptionMatches(PyExc_StopAsyncIteration)) {
+        return -1;
+    }
+
+    PyObject *exc = PyErr_GetRaisedException();
+    PyObject *value = PyObject_GetAttr(exc, &_Py_ID(value));
+    if (value == NULL) {
+        if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
+            Py_DECREF(exc);
+            return -1;
+        }
+        PyErr_Clear();
+
+        /* No `value` attribute: fall back to args[0], else None. */
+        PyObject *args = PyException_GetArgs(exc);
+        if (args != NULL && PyTuple_Check(args) && PyTuple_GET_SIZE(args) > 0) {
+            value = Py_NewRef(PyTuple_GET_ITEM(args, 0));
+        }
+        else {
+            value = Py_NewRef(Py_None);
+        }
+        Py_XDECREF(args);
+    }
+
+    Py_DECREF(exc);
+    *pvalue = value;
+    return 0;
}
PyCodeObject *
PyGen_GetCode(PyGenObject *gen) {
@@ -1133,6 +1169,33 @@
return result ? PYGEN_RETURN : PYGEN_ERROR;
}
+/* _PyAsyncGen_FetchStopAsyncIterationValue() is defined once, earlier in this
+ * file.  Defining the same external function again in this translation unit
+ * is a redefinition error, so no second copy lives here. */
+
static PySendResult
gen_send_ex2(PyGenObject *gen, PyObject *arg, PyObject **presult,
int exc, int closing)
{
PyThreadState *tstate = _PyThreadState_GET();
_PyInterpreterFrame *frame = &gen->gi_iframe;
*presult = NULL;
if (gen->gi_frame_state == FRAME_CREATED && arg && arg != Py_None) {
const char *msg = "can't send non-None value to a "
"just-started generator";
if (PyCoro_CheckExact(gen)) {
msg = NON_INIT_CORO_MSG;
}
else if (PyAsyncGen_CheckExact(gen)) {
msg = "can't send non-None value to a "
"just-started async generator";
}
PyErr_SetString(PyExc_TypeError, msg);
return PYGEN_ERROR;
}
if (gen->gi_frame_state == FRAME_EXECUTING) {
const char *msg = "generator already executing";
if (PyCoro_CheckExact(gen)) {
msg = "coroutine already executing";
}
else if (PyAsyncGen_CheckExact(gen)) {
msg = "async generator already executing";
}
PyErr_SetString(PyExc_ValueError, msg);
return PYGEN_ERROR;
}
if (FRAME_STATE_FINISHED(gen->gi_frame_state)) {
if (PyCoro_CheckExact(gen) && !closing) {
/* `gen` is an exhausted coroutine: raise an error,
except when called from gen_close(), which should
always be a silent method. */
PyErr_SetString(
PyExc_RuntimeError,
"cannot reuse already awaited coroutine");
}
else if (arg && !exc) {
/* `gen` is an exhausted generator:
only return value if called from send(). */
*presult = Py_NewRef(Py_None);
return PYGEN_RETURN;
}
return PYGEN_ERROR;
}
assert((gen->gi_frame_state == FRAME_CREATED) ||
FRAME_STATE_SUSPENDED(gen->gi_frame_state));
/* Push arg onto the frame's value stack */
PyObject *arg_obj = arg ? arg : Py_None;
_PyFrame_StackPush(frame, PyStackRef_FromPyObjectNew(arg_obj));
_PyErr_StackItem *prev_exc_info = tstate->exc_info;
gen->gi_exc_state.previous_item = prev_exc_info;
tstate->exc_info = &gen->gi_exc_state;
if (exc) {
assert(_PyErr_Occurred(tstate));
_PyErr_ChainStackItem();
}
gen->gi_frame_state = FRAME_EXECUTING;
EVAL_CALL_STAT_INC(EVAL_CALL_GENERATOR);
PyObject *result = _PyEval_EvalFrame(tstate, frame, exc);
assert(tstate->exc_info == prev_exc_info);
assert(gen->gi_exc_state.previous_item == NULL);
assert(gen->gi_frame_state != FRAME_EXECUTING);
assert(frame->previous == NULL);
/* If the generator just returned (as opposed to yielding), signal
* that the generator is exhausted. */
if (result) {
if (FRAME_STATE_SUSPENDED(gen->gi_frame_state)) {
*presult = result;
return PYGEN_NEXT;
}
if (PyAsyncGen_CheckExact(gen)) {
/* For async generators, set StopAsyncIteration with the return value */
if (_PyAsyncGen_SetStopAsyncIterationValue(result) < 0) {
/* Failed to set StopAsyncIteration */
return PYGEN_ERROR;
}
Py_DECREF(result);
*presult = NULL;
return PYGEN_RETURN;
}
assert(result == Py_None || !PyAsyncGen_CheckExact(gen));
if (result == Py_None && !PyAsyncGen_CheckExact(gen) && !arg) {
/* Return NULL if called by gen_iternext() */
Py_CLEAR(result);
}
}
else {
assert(!PyErr_ExceptionMatches(PyExc_StopIteration));
assert(!PyAsyncGen_CheckExact(gen) ||
!PyErr_ExceptionMatches(PyExc_StopAsyncIteration));
}
assert(gen->gi_exc_state.exc_value == NULL);
assert(gen->gi_frame_state == FRAME_CLEARED);
*presult = result;
return result ? PYGEN_RETURN : PYGEN_ERROR;
}
@@ -1354,6 +1491,31 @@
return result ? PYGEN_RETURN : PYGEN_ERROR;
}
/* Coroutine Object */
+/* _PyAsyncGen_FetchStopAsyncIterationValue() is defined once, earlier in this
+ * file.  Defining the same external function again in this translation unit
+ * is a redefinition error, so no second copy lives here. */
+
static PyObject *
gen_send_ex(PyGenObject *gen, PyObject *arg, int exc, int closing)
{
PyObject *result;
if (gen_send_ex2(gen, arg, &result, exc, closing) == PYGEN_RETURN) {
if (PyAsyncGen_CheckExact(gen)) {
/* For async generators, retrieve the return value from StopAsyncIteration */
- PyObject *exc = PyErr_GetRaisedException();
- if (_PyGen_FetchStopIterationValue(&result) == 0) {
- /* Attach the value to StopAsyncIteration */
- PyErr_SetObject(PyExc_StopAsyncIteration, result);
- Py_DECREF(result);
+ /* Fetch the value from StopAsyncIteration */
+ if (_PyAsyncGen_FetchStopAsyncIterationValue(&result) == 0) {
+ /* The value is already set in the exception */
+ /* Return the value to the caller */
+ return result;
+ }
+ /* If fetching the value failed, propagate the exception */
+ return NULL;
}
else if (result == Py_None) {
PyErr_SetNone(PyExc_StopIteration);
Py_CLEAR(result);
}
else {
_PyGen_SetStopIterationValue(result);
}
Py_CLEAR(result);
}
return result;
}
@@ -1663,6 +1767,23 @@
return NULL;
}
+/*
+ * Finalization hook for async generators: close one whose frame has not
+ * finished yet.
+ *
+ * Does nothing once the frame has finished.  Otherwise it runs gen_close().
+ * An error raised by the close is reported through PyErr_WriteUnraisable()
+ * rather than propagated.  Any exception already pending on entry is
+ * stashed before the close and restored afterwards.
+ */
+void
+_PyAsyncGen_Finalize(PyObject *self)
+{
+    PyAsyncGenObject *agen = (PyAsyncGenObject *)self;
+
+    if (!FRAME_STATE_FINISHED(agen->ag_frame_state)) {
+        /* Stash the pending exception (if any) so gen_close() starts clean. */
+        PyObject *saved_exc = PyErr_GetRaisedException();
+
+        PyObject *close_result = gen_close((PyObject *)agen, NULL);
+        if (close_result != NULL) {
+            Py_DECREF(close_result);
+        }
+        else if (PyErr_Occurred()) {
+            PyErr_WriteUnraisable(self);
+        }
+
+        PyErr_SetRaisedException(saved_exc);
+    }
+}
+
void
_PyGen_Finalize(PyObject *self)
{
@@ -1748,6 +1890,7 @@
}
}
+ if (PyAsyncGen_CheckExact(self)) {
/* Save the current exception, if any. */
PyObject *exc = PyErr_GetRaisedException();
@@ -1758,6 +1901,16 @@
PyErr_SetRaisedException(exc);
}
else {
+ if (PyAsyncGen_CheckExact(self)) {
+ /* Call the specific async generator finalizer */
+ _PyAsyncGen_Finalize(self);
+ }
+
+ /* Existing finalization logic for regular generators and coroutines */
PyObject *res = gen_close((PyObject*)gen, NULL);
if (res == NULL) {
if (PyErr_Occurred()) {
PyErr_WriteUnraisable(self);
}
else {
Py_DECREF(res);
}
}
}
}
@@ -2333,6 +2456,32 @@
/* Normal finalization steps */
Py_CLEAR(ags->ags_gen);
Py_CLEAR(ags->ags_sendval);
+
+ /* Finalize return value if stored (optional) */
+ if (ags->ags_gen->ag_return_value) {
+ Py_CLEAR(ags->ags_gen->ag_return_value);
+ }
+
_PyObject_GC_UNTRACK((PyObject*)ags);
PyObject_GC_Del(ags);
}
+/*
+ * Finalizer for athrow()/aclose() awaitables.
+ *
+ * If the awaitable is still in AWAITABLE_STATE_INIT (presumably it was
+ * never driven), hand it to _PyErr_WarnUnawaitedAgenMethod().  The method
+ * name passed is "athrow" when agt_args is set and "aclose" otherwise.
+ */
+static void
+async_gen_athrow_finalize(PyObject *self)
+{
+    PyAsyncGenAThrow *awaitable = _PyAsyncGenAThrow_CAST(self);
+    if (awaitable->agt_state != AWAITABLE_STATE_INIT) {
+        return;
+    }
+
+    PyObject *method_name;
+    if (awaitable->agt_args != NULL) {
+        method_name = &_Py_ID(athrow);
+    }
+    else {
+        method_name = &_Py_ID(aclose);
+    }
+    _PyErr_WarnUnawaitedAgenMethod(awaitable->agt_gen, method_name);
+}
+
+/*
+ * send() implementation for the awaitable returned by asend()/__anext__().
+ *
+ * Raises RuntimeError (returning NULL) in two cases:
+ *   - the awaitable is already closed;
+ *   - it is first driven while ags_gen->ag_running_async is set, in which
+ *     case the awaitable is closed as well.
+ * On the first drive, a NULL or None `arg` is replaced with ags_sendval and
+ * the state moves to AWAITABLE_STATE_ITER.  Every call then sets
+ * ag_running_async, resumes the generator with gen_send(), and passes the
+ * outcome through async_gen_unwrap_value().  A NULL outcome closes the
+ * awaitable.
+ */
+static PyObject *
+async_gen_asend_send(PyObject *self, PyObject *arg)
+{
+    PyAsyncGenASend *awaitable = _PyAsyncGenASend_CAST(self);
+
+    if (awaitable->ags_state == AWAITABLE_STATE_CLOSED) {
+        PyErr_SetString(PyExc_RuntimeError,
+                        "cannot reuse already awaited __anext__()/asend()");
+        return NULL;
+    }
+
+    if (awaitable->ags_state == AWAITABLE_STATE_INIT) {
+        /* Refuse to start while the generator reports an async op in flight. */
+        if (awaitable->ags_gen->ag_running_async) {
+            awaitable->ags_state = AWAITABLE_STATE_CLOSED;
+            PyErr_SetString(PyExc_RuntimeError,
+                            "anext(): asynchronous generator is already running");
+            return NULL;
+        }
+        awaitable->ags_state = AWAITABLE_STATE_ITER;
+        if (arg == NULL || arg == Py_None) {
+            arg = awaitable->ags_sendval;
+        }
+    }
+
+    awaitable->ags_gen->ag_running_async = 1;
+    PyObject *raw = gen_send((PyObject *)awaitable->ags_gen, arg);
+    PyObject *result = async_gen_unwrap_value(awaitable->ags_gen, raw);
+    if (result == NULL) {
+        awaitable->ags_state = AWAITABLE_STATE_CLOSED;
+    }
+    return result;
+}
+
/*
* Set StopIteration with specified value. Value can be arbitrary object
* or NULL.
@@ -2450,6 +2623,13 @@
return NULL;
}
+/*
+ * Consume a pending StopAsyncIteration and discard its payload.
+ *
+ * `result` is not read; it is kept only so the signature stays unchanged.
+ * The payload returned by _PyAsyncGen_FetchStopAsyncIterationValue() is a
+ * new reference.  It is fetched into a local and released here, so nothing
+ * leaks.
+ *
+ * Returns 0 if StopAsyncIteration was pending; it is now cleared.  Returns
+ * -1 otherwise: either a different exception (or none) was pending and is
+ * left as-is, or reading the payload raised and that error is left set.
+ */
+static int
+async_gen_unwrap_stopasynciteration(PyObject *result)
+{
+    PyObject *value = NULL;
+
+    (void)result;
+    if (_PyAsyncGen_FetchStopAsyncIterationValue(&value) < 0) {
+        return -1;
+    }
+    Py_DECREF(value);
+    return 0;
+}
+
static PyObject *
async_gen_aclose(PyAsyncGenObject *o, PyObject *arg)
{
@@ -2548,6 +2738,7 @@
o->agt_gen->ag_running_async = 0;
o->agt_state = AWAITABLE_STATE_CLOSED;
}
+/* Async generator throw modifications */
static PyObject *
async_gen_athrow(PyAsyncGenObject *o, PyObject *args)
{
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment