Skip to content

简易class绑定

本章通过一个基础示例讲解并演示类绑定:创建一个简单的 C++ 类,并将其绑定到 Python 中使用。

编写 CMakeLists.txt

使用 CMake 配置并管理此 C++ 项目。

text
mmath/
├── CMakeLists.txt
└── mmath.cpp
cmake
# CMakeLists.txt
# Builds the `mmath` Python extension module from mmath.cpp.
# Python is located with FindPython3 instead of hard-coded install paths, so the
# project builds against any Python installation (pick a specific interpreter
# with -DPython3_ROOT_DIR=... when several are installed).
cmake_minimum_required(VERSION 3.18)   # the Development.Module component needs 3.18

project(PYTHON_TEST)

if(MSVC)
    # /utf-8: sources contain non-ASCII (Chinese) string literals.
    # /EHsc: standard C++ exception handling model.
    add_compile_options(/utf-8 /EHsc)
endif()

# Provides the Python headers and, on Windows, the import library an extension needs.
find_package(Python3 3.12 REQUIRED COMPONENTS Development.Module)

# MODULE: a plugin loaded at runtime by `import`, never linked into other targets.
add_library(mmath MODULE "mmath.cpp")

# Python imports "mmath.pyd" on Windows and "mmath.so" elsewhere — never "libmmath".
set_target_properties(mmath PROPERTIES PREFIX "")
if(WIN32)
    set_target_properties(mmath PROPERTIES SUFFIX ".pyd")
endif()

target_link_libraries(mmath PRIVATE Python3::Module)

# `template <auto>` in the attribute bindings needs at least C++17.
set_property(TARGET mmath PROPERTY CXX_STANDARD 20)

定义一个C++类

此处以Vec3为例,创建一个简单的三维向量类。

cpp
#include <cmath>     // std::sqrt (not guaranteed to be reachable through <iostream>)
#include <iostream>

// Simple 3-D vector of doubles; this is the C++ class that gets bound to Python.
// Mutating operations (setters, normalize, scale, addVec3, crossSelf) update this
// vector in place and never create a new one.
class Vec3
{
protected:
    double x, y, z;
public:
    // Zero vector (0, 0, 0).
    Vec3() : x(0.0), y(0.0), z(0.0) {}
    Vec3(double x, double y, double z) : x(x), y(y), z(z) {}

    double getX() const { return x; }
    double getY() const { return y; }
    double getZ() const { return z; }
    void setX(double x) { this->x = x; }
    void setY(double y) { this->y = y; }
    void setZ(double z) { this->z = z; }

    // Euclidean length: sqrt(x^2 + y^2 + z^2).
    double length() const
    {
        return std::sqrt(x * x + y * y + z * z);
    }

    // Scales the vector to unit length in place.
    // Returns false and leaves the vector untouched when its length is below
    // 1e-12: dividing by such a value would only amplify floating-point noise.
    bool normalize()
    {
        double len = length();
        if (len < 1e-12)
        {
            return false;
        }
        x /= len;
        y /= len;
        z /= len;
        return true;
    }

    // Multiplies every component by `scalar`.
    void scale(double scalar)
    {
        x *= scalar;
        y *= scalar;
        z *= scalar;
    }

    // Component-wise addition: *this += other.
    void addVec3(const Vec3& other)
    {
        x += other.x;
        y += other.y;
        z += other.z;
    }

    // Dot product of *this and other.
    double dot(const Vec3& other) const
    {
        return x * other.x + y * other.y + z * other.z;
    }

    // Cross product stored in place: *this = *this x other.
    // All three results are computed before any component is overwritten,
    // so `other` may alias *this.
    void crossSelf(const Vec3& other)
    {
        double cx = y * other.z - z * other.y;
        double cy = z * other.x - x * other.z;
        double cz = x * other.y - y * other.x;
        x = cx;
        y = cy;
        z = cz;
    }
};

声明Py对象内存布局

自定义Python类型需要定义一个结构体来描述对象的内存布局。

cpp
// Memory layout of one Python-level Vec3 instance; the type object allocates
// sizeof(PyVec3) bytes per object.
struct PyVec3
{
    PyObject_HEAD   // standard Python object header (refcount, type pointer); must be the first member
    Vec3* ptr;      // extra per-object field: owning pointer to the wrapped C++ Vec3
};

编写构造与析构函数

为自定义类型实现 `new`(对应 `tp_new` 槽位,负责内存分配与初始化)与 `dealloc`(对应 `tp_dealloc` 槽位,负责释放)方法。

cpp
// Python.h must come before any standard header: it may define macros that
// affect how system headers are compiled.
#include <Python.h>
#include <iostream>

// Py Vec3 constructor (tp_new slot).
//
// Accepted call forms; the values may be passed positionally or by keyword x/y/z:
//   Vec3()         -> (0, 0, 0)
//   Vec3(x, y, z)  -> (x, y, z); each value must be convertible to float
// Any other number of arguments raises TypeError.
// Returns a new reference, or NULL with a Python exception set.
static PyObject* PyVec3_new(PyTypeObject* type, PyObject* args, PyObject* kwds)
{
    // Count keyword arguments too: they used to be ignored entirely, so
    // Vec3(x=1, y=2, z=3) silently produced a zero vector.
    Py_ssize_t argCount = PyTuple_Size(args);
    if (kwds != NULL)
    {
        argCount += PyDict_Size(kwds);
    }
    if (argCount != 0 && argCount != 3)
    {
        PyErr_SetString(PyExc_TypeError, "参数数量不符合需求");
        return NULL;
    }

    // Parse before allocating so a failure needs no cleanup. On failure PyArg_*
    // has already set a precise TypeError (non-numeric value, unknown keyword,
    // argument given twice, ...), which is kept instead of being overwritten.
    static const char* kwlist[] = {"x", "y", "z", NULL};
    double x = 0.0, y = 0.0, z = 0.0;
    if (!PyArg_ParseTupleAndKeywords(args, kwds, "|ddd:Vec3", const_cast<char**>(kwlist), &x, &y, &z))
    {
        return NULL;
    }

    // tp_alloc zero-fills the object (so ptr starts as NULL) and already sets
    // MemoryError on failure — do not replace it with a different exception.
    PyVec3* self = (PyVec3*)type->tp_alloc(type, 0);
    if (self == NULL)
    {
        return NULL;
    }

    // A C++ exception (e.g. std::bad_alloc) must not unwind through CPython's C frames.
    try
    {
        self->ptr = new Vec3(x, y, z);
    }
    catch (...)
    {
        Py_DECREF(self);   // runs tp_dealloc; deleting the still-NULL ptr is a no-op
        return PyErr_NoMemory();
    }
    return (PyObject*)self;
}

// Py Vec3 destructor (tp_dealloc slot).
// Destroys the wrapped C++ Vec3 first, then hands the Python object's memory
// back to the type's allocator — after tp_free, `self` (and self->ptr) is gone.
static void PyVec3_dealloc(PyVec3* self)
{
    PyTypeObject* const ownType = Py_TYPE(self);
    Vec3* const native = self->ptr;
    delete native;   // no-op if ptr was never set
    ownType->tp_free(reinterpret_cast<PyObject*>(self));
}

定义PyClass信息表

每个 Python 类都需要一个 `PyTypeObject` 来描述类型信息,如方法表、属性表、构造函数与析构函数等。

cpp
#include <iostream>
#include <Python.h>

// Method table of PyVec3 — a placeholder with no methods yet.
// The all-zero entry marks the end of the table.
static PyMethodDef PyVec3_methods[] = {
    {nullptr, nullptr, 0, nullptr}  // sentinel
};

// Attribute (getter/setter) table of PyVec3, used for `obj.x`-style access —
// also an empty placeholder, terminated by a zeroed entry.
static PyGetSetDef PyVec3_getset[] = {
    {nullptr}  // sentinel
};

// Type object describing mmath.Vec3 to the interpreter; every slot is documented
// in the "Type Objects" chapter of the Python/C API reference.
// Designated initializers (C++20, matching CXX_STANDARD 20) name each slot
// explicitly. The old positional list was labelled with slots that no longer
// exist (tp_print is tp_vectorcall_offset since Python 3.8, tp_reserved is
// tp_as_async), and it silently misassigns everything after a miscounted slot.
// Slots that are not named are zero-initialized.
static PyTypeObject PyVec3Type = {
    .ob_base = PyVarObject_HEAD_INIT(NULL, 0)
    .tp_name = "mmath.Vec3",
    .tp_basicsize = sizeof(PyVec3),
    .tp_itemsize = 0,
    // destructor: releases the wrapped C++ object and the Python object's memory
    .tp_dealloc = (destructor) PyVec3_dealloc,
    .tp_flags = Py_TPFLAGS_DEFAULT,
    .tp_doc = "Vec3 objects",
    .tp_methods = PyVec3_methods,
    .tp_getset = PyVec3_getset,
    // constructor: allocation plus initialization of the wrapped C++ object
    .tp_new = PyVec3_new,
};

添加方法绑定

为类添加方法绑定,以便在Python中调用。

cpp
#include <iostream>
#include <Python.h>

// Vec3.getTuple() -> (x, y, z): copies the wrapped C++ vector's components into
// a new Python tuple of floats.
static PyObject* PyVec3_getTuple(PyVec3* self, PyObject* Py_UNUSED(ignored))
{
    const Vec3& v = *self->ptr;   // the wrapped C++ object
    const double cx = v.getX();
    const double cy = v.getY();
    const double cz = v.getZ();
    return Py_BuildValue("(ddd)", cx, cy, cz);
}

// Vec3.length() -> float: Euclidean length of the vector.
static PyObject* PyVec3_length(PyVec3* self, PyObject* Py_UNUSED(ignored))
{
    const double len = self->ptr->length();
    return PyFloat_FromDouble(len);
}

// Vec3.normalize() -> bool: rescales the vector to unit length in place;
// False means the vector was (near) zero length and was left unchanged.
static PyObject* PyVec3_normalize(PyVec3* self, PyObject* Py_UNUSED(ignored))
{
    if (self->ptr->normalize())
    {
        Py_RETURN_TRUE;
    }
    Py_RETURN_FALSE;
}

// Vec3.scale(scalar) -> None: multiplies every component by `scalar` in place.
// Raises TypeError when the argument is not a single number.
static PyObject* PyVec3_scale(PyVec3* self, PyObject* args)
{
    double factor = 0.0;
    const bool parsed = PyArg_ParseTuple(args, "d", &factor) != 0;
    if (!parsed)
    {
        // replace the parser's message with the tutorial's localized one
        PyErr_SetString(PyExc_TypeError, "需要一个数值参数");
        return NULL;
    }
    self->ptr->scale(factor);
    Py_RETURN_NONE;
}

// Completed method table: Python name, C implementation, calling convention,
// docstring. Terminated by an all-zero sentinel entry.
static PyMethodDef PyVec3_methods[] = {
    // methods without arguments
    {"getTuple",  (PyCFunction) PyVec3_getTuple,  METH_NOARGS,  "转换为tuple对象"},
    {"length",    (PyCFunction) PyVec3_length,    METH_NOARGS,  "获取向量长度"},
    {"normalize", (PyCFunction) PyVec3_normalize, METH_NOARGS,  "归一化向量"},
    // methods receiving a positional-argument tuple
    {"scale",     (PyCFunction) PyVec3_scale,     METH_VARARGS, "按比例缩放向量大小"},
    {nullptr, nullptr, 0, nullptr}   // sentinel
};

添加属性绑定

为类添加属性绑定,以便在Python中直接访问/修改属性。

cpp
// For reuse, templates generate one getter and one setter per attribute from a
// Vec3 member-function pointer, so a new attribute only needs a table entry.

// Getter: calls the bound Vec3 const member function and boxes the result as a
// Python float. Returns a new reference (NULL with an exception on failure).
template <auto Getter>
static PyObject* PyVec3_GetAttr(PyVec3* self, void* closure)
{
    return PyFloat_FromDouble((self->ptr->*Getter)());
}

// Setter: converts `value` (float or int) to double and passes it to the bound
// Vec3 setter. Returns 0 on success, -1 with a Python exception set on failure.
template <auto Setter>
static int PyVec3_SetAttr(PyVec3* self, PyObject* value, void* closure)
{
    // `del obj.x` invokes the setter with value == NULL; the type checks below
    // would dereference it, so deletion is rejected explicitly.
    if (value == NULL)
    {
        PyErr_SetString(PyExc_TypeError, "cannot delete attribute");
        return -1;
    }
    if (!PyFloat_Check(value) && !PyLong_Check(value))
    {
        PyErr_SetString(PyExc_TypeError, "must be a number");
        return -1;
    }
    double val = PyFloat_AsDouble(value);
    // An int too large for a double raises OverflowError and yields -1.0; that
    // error must be propagated instead of storing -1.0 with an exception pending.
    if (val == -1.0 && PyErr_Occurred())
    {
        return -1;
    }
    (self->ptr->*Setter)(val);
    return 0;
}

// Completed attribute table: name, getter, setter, docstring, closure.
static PyGetSetDef PyVec3_getset[] = {
    {"x", (getter) PyVec3_GetAttr<&Vec3::getX>, (setter) PyVec3_SetAttr<&Vec3::setX>, "X 坐标", NULL},
    {"y", (getter) PyVec3_GetAttr<&Vec3::getY>, (setter) PyVec3_SetAttr<&Vec3::setY>, "Y 坐标", NULL},
    {"z", (getter) PyVec3_GetAttr<&Vec3::getZ>, (setter) PyVec3_SetAttr<&Vec3::setZ>, "Z 坐标", NULL},
    {NULL}   // sentinel
};

绑定class到模块

在完成类的定义后,需要将其绑定到Python模块中。

cpp
// Python.h must come before any standard header (it may define macros that
// affect system headers).
#include <Python.h>
#include <iostream>

// Module entry point, called by the interpreter on `import mmath`; its name must
// be PyInit_<module name>. Returns a new module reference, or NULL with an
// exception set.
PyMODINIT_FUNC PyInit_mmath()
{
    static PyModuleDef moduleDef = {
        PyModuleDef_HEAD_INIT,
        "mmath",
        "测试C扩展数学运算库",
        -1,   // m_size: no per-module state (module uses global state)
        NULL, NULL, NULL, NULL, NULL
    };
    // Validate the type and fill in inherited/default slots.
    if (PyType_Ready(&PyVec3Type) < 0)
    {
        return NULL;
    }
    PyObject* m = PyModule_Create(&moduleDef);
    if (m == NULL)
    {
        return NULL;
    }
    // PyModule_AddObjectRef (Python 3.10+) does not steal a reference, so no manual
    // Py_INCREF is needed and nothing leaks when adding fails.
    if (PyModule_AddObjectRef(m, "Vec3", (PyObject*)&PyVec3Type) < 0)
    {
        Py_DECREF(m);
        return NULL;
    }
    return m;
}

代码测试

完成编译后,将生成的扩展模块(Windows 下为 `mmath.pyd`)所在目录加入 `sys.path`,或直接在该目录下启动 Python,即可测试此类。

python
import mmath

vec = mmath.Vec3(11, 22, 33.5)
vec.y = 0                 # property setter: only y changes
vec.normalize()           # normalize in place (returns True on success)
print(vec.getTuple())     # print the resulting (x, y, z)

完整绑定实现

此处是完整的C++代码示例,包含类定义、方法绑定和模块入口。

cpp
// Python.h must come before any standard header (it may define macros that
// affect system headers).
#include <Python.h>
#include <cmath>      // std::sqrt (not guaranteed to be reachable through <iostream>)
#include <iostream>
// By Zero123

// Simple 3-D vector of doubles; this is the C++ class bound to Python as mmath.Vec3.
// Mutating operations (setters, normalize, scale, addVec3, crossSelf) update this
// vector in place and never create a new one.
class Vec3
{
protected:
    double x, y, z;
public:
    // Zero vector (0, 0, 0).
    Vec3() : x(0.0), y(0.0), z(0.0) {}
    Vec3(double x, double y, double z) : x(x), y(y), z(z) {}

    double getX() const { return x; }
    double getY() const { return y; }
    double getZ() const { return z; }
    void setX(double x) { this->x = x; }
    void setY(double y) { this->y = y; }
    void setZ(double z) { this->z = z; }

    // Euclidean length: sqrt(x^2 + y^2 + z^2).
    double length() const
    {
        return std::sqrt(x * x + y * y + z * z);
    }

    // Scales the vector to unit length in place.
    // Returns false and leaves the vector untouched when its length is below
    // 1e-12: dividing by such a value would only amplify floating-point noise.
    bool normalize()
    {
        double len = length();
        if (len < 1e-12)
        {
            return false;
        }
        x /= len;
        y /= len;
        z /= len;
        return true;
    }

    // Multiplies every component by `scalar`.
    void scale(double scalar)
    {
        x *= scalar;
        y *= scalar;
        z *= scalar;
    }

    // Component-wise addition: *this += other.
    void addVec3(const Vec3& other)
    {
        x += other.x;
        y += other.y;
        z += other.z;
    }

    // Dot product of *this and other.
    double dot(const Vec3& other) const
    {
        return x * other.x + y * other.y + z * other.z;
    }

    // Cross product stored in place: *this = *this x other.
    // All three results are computed before any component is overwritten,
    // so `other` may alias *this.
    void crossSelf(const Vec3& other)
    {
        double cx = y * other.z - z * other.y;
        double cy = z * other.x - x * other.z;
        double cz = x * other.y - y * other.x;
        x = cx;
        y = cy;
        z = cz;
    }
};

// Memory layout of one Python-level mmath.Vec3 instance.
struct PyVec3
{
    PyObject_HEAD   // standard Python object header; must be the first member
    Vec3* ptr;      // owning pointer to the wrapped C++ Vec3 (deleted in PyVec3_dealloc)
};

// Py Vec3 constructor (tp_new slot).
//
// Accepted call forms; the values may be passed positionally or by keyword x/y/z:
//   Vec3()         -> (0, 0, 0)
//   Vec3(x, y, z)  -> (x, y, z); each value must be convertible to float
// Any other number of arguments raises TypeError.
// Returns a new reference, or NULL with a Python exception set.
static PyObject* PyVec3_new(PyTypeObject* type, PyObject* args, PyObject* kwds)
{
    // Count keyword arguments too: they used to be ignored entirely, so
    // Vec3(x=1, y=2, z=3) silently produced a zero vector.
    Py_ssize_t argCount = PyTuple_Size(args);
    if (kwds != NULL)
    {
        argCount += PyDict_Size(kwds);
    }
    if (argCount != 0 && argCount != 3)
    {
        PyErr_SetString(PyExc_TypeError, "参数数量不符合需求");
        return NULL;
    }

    // Parse before allocating so a failure needs no cleanup. On failure PyArg_*
    // has already set a precise TypeError (non-numeric value, unknown keyword,
    // argument given twice, ...), which is kept instead of being overwritten.
    static const char* kwlist[] = {"x", "y", "z", NULL};
    double x = 0.0, y = 0.0, z = 0.0;
    if (!PyArg_ParseTupleAndKeywords(args, kwds, "|ddd:Vec3", const_cast<char**>(kwlist), &x, &y, &z))
    {
        return NULL;
    }

    // tp_alloc zero-fills the object (so ptr starts as NULL) and already sets
    // MemoryError on failure — do not replace it with a different exception.
    PyVec3* self = (PyVec3*)type->tp_alloc(type, 0);
    if (self == NULL)
    {
        return NULL;
    }

    // A C++ exception (e.g. std::bad_alloc) must not unwind through CPython's C frames.
    try
    {
        self->ptr = new Vec3(x, y, z);
    }
    catch (...)
    {
        Py_DECREF(self);   // runs tp_dealloc; deleting the still-NULL ptr is a no-op
        return PyErr_NoMemory();
    }
    return (PyObject*)self;
}

// Py Vec3 destructor (tp_dealloc slot).
// Destroys the wrapped C++ Vec3 first, then returns the Python object's memory
// to the type's allocator — after tp_free, `self` (and self->ptr) is gone.
static void PyVec3_dealloc(PyVec3* self)
{
    PyTypeObject* const ownType = Py_TYPE(self);
    Vec3* const native = self->ptr;
    delete native;   // no-op if ptr was never set
    ownType->tp_free(reinterpret_cast<PyObject*>(self));
}

// Vec3.getX() -> float: x coordinate of the wrapped vector.
static PyObject* PyVec3_getX(PyVec3* self, PyObject* Py_UNUSED(ignored))
{
    const Vec3& v = *self->ptr;
    return PyFloat_FromDouble(v.getX());
}

// Vec3.getY() -> float: y coordinate of the wrapped vector.
static PyObject* PyVec3_getY(PyVec3* self, PyObject* Py_UNUSED(ignored))
{
    const Vec3& v = *self->ptr;
    return PyFloat_FromDouble(v.getY());
}

// Vec3.getZ() -> float: z coordinate of the wrapped vector.
static PyObject* PyVec3_getZ(PyVec3* self, PyObject* Py_UNUSED(ignored))
{
    const Vec3& v = *self->ptr;
    return PyFloat_FromDouble(v.getZ());
}

// Shared body of setX/setY/setZ: parses exactly one numeric argument and forwards
// it to the given Vec3 setter. Returns None, or NULL with a TypeError set.
static PyObject* PyVec3_setComponent(PyVec3* self, PyObject* args, void (Vec3::*setter)(double))
{
    double value;
    if (!PyArg_ParseTuple(args, "d", &value))
    {
        PyErr_SetString(PyExc_TypeError, "需要一个数值参数");
        return NULL;
    }
    (self->ptr->*setter)(value);
    Py_RETURN_NONE;
}

// Vec3.setX(v) -> None
static PyObject* PyVec3_setX(PyVec3* self, PyObject* args)
{
    return PyVec3_setComponent(self, args, &Vec3::setX);
}

// Vec3.setY(v) -> None
static PyObject* PyVec3_setY(PyVec3* self, PyObject* args)
{
    return PyVec3_setComponent(self, args, &Vec3::setY);
}

// Vec3.setZ(v) -> None
static PyObject* PyVec3_setZ(PyVec3* self, PyObject* args)
{
    return PyVec3_setComponent(self, args, &Vec3::setZ);
}

// Vec3.getTuple() -> (x, y, z): new Python tuple holding the three coordinates.
static PyObject* PyVec3_getTuple(PyVec3* self, PyObject* Py_UNUSED(ignored))
{
    const Vec3& v = *self->ptr;
    const double cx = v.getX();
    const double cy = v.getY();
    const double cz = v.getZ();
    return Py_BuildValue("(ddd)", cx, cy, cz);
}

// Vec3.length() -> float: Euclidean length of the vector.
static PyObject* PyVec3_length(PyVec3* self, PyObject* Py_UNUSED(ignored))
{
    const double len = self->ptr->length();
    return PyFloat_FromDouble(len);
}

// Vec3.normalize() -> bool: rescales to unit length in place; False means the
// vector was (near) zero length and was left unchanged.
static PyObject* PyVec3_normalize(PyVec3* self, PyObject* Py_UNUSED(ignored))
{
    if (self->ptr->normalize())
    {
        Py_RETURN_TRUE;
    }
    Py_RETURN_FALSE;
}

// Vec3.setXYZ(x, y, z) -> None: overwrites all three coordinates at once.
// Raises TypeError unless exactly three numbers are given.
static PyObject* PyVec3_setXYZ(PyVec3* self, PyObject* args)
{
    double nx, ny, nz;
    if (PyArg_ParseTuple(args, "ddd", &nx, &ny, &nz) == 0)
    {
        PyErr_SetString(PyExc_TypeError, "参数解析错误");
        return NULL;
    }
    Vec3& target = *self->ptr;
    target.setX(nx);
    target.setY(ny);
    target.setZ(nz);
    Py_RETURN_NONE;
}

// Vec3.scale(scalar) -> None: multiplies every component by `scalar` in place.
// Raises TypeError when the argument is not a single number.
static PyObject* PyVec3_scale(PyVec3* self, PyObject* args)
{
    double factor = 0.0;
    const bool parsed = PyArg_ParseTuple(args, "d", &factor) != 0;
    if (!parsed)
    {
        // replace the parser's message with the tutorial's localized one
        PyErr_SetString(PyExc_TypeError, "需要一个数值参数");
        return NULL;
    }
    self->ptr->scale(factor);
    Py_RETURN_NONE;
}

// Forward declarations: these methods take another Vec3 and are defined after
// PyVec3Type, because their "O!" argument check needs the type's address.
static PyObject* PyVec3_addVec3(PyVec3*, PyObject*);
static PyObject* PyVec3_dot(PyVec3*, PyObject*);
static PyObject* PyVec3_crossSelf(PyVec3*, PyObject*);

// Method table: Python name, C implementation, calling convention, docstring.
// Terminated by an all-zero sentinel entry.
static PyMethodDef PyVec3_methods[] = {
    {"getX", (PyCFunction) PyVec3_getX, METH_NOARGS, "获取 X 坐标"},
    {"getY", (PyCFunction) PyVec3_getY, METH_NOARGS, "获取 Y 坐标"},
    {"getZ", (PyCFunction) PyVec3_getZ, METH_NOARGS, "获取 Z 坐标"},
    {"setX", (PyCFunction) PyVec3_setX, METH_VARARGS, "设置 X 坐标"},
    {"setY", (PyCFunction) PyVec3_setY, METH_VARARGS, "设置 Y 坐标"},
    {"setZ", (PyCFunction) PyVec3_setZ, METH_VARARGS, "设置 Z 坐标"},
    {"getTuple", (PyCFunction) PyVec3_getTuple, METH_NOARGS, "转换为tuple对象"},
    {"setXYZ", (PyCFunction) PyVec3_setXYZ, METH_VARARGS, "批量设置xyz值"},
    {"length", (PyCFunction) PyVec3_length, METH_NOARGS, "获取向量长度"},
    {"normalize", (PyCFunction) PyVec3_normalize, METH_NOARGS, "归一化向量"},
    {"scale", (PyCFunction) PyVec3_scale, METH_VARARGS, "按比例缩放向量大小"},
    // in-place vector addition (docstring was a copy of scale's by mistake)
    {"addVec3", (PyCFunction) PyVec3_addVec3, METH_VARARGS, "向量加法"},
    {"dot", (PyCFunction) PyVec3_dot, METH_VARARGS, "点积运算"},
    {"crossSelf", (PyCFunction) PyVec3_crossSelf, METH_VARARGS, "叉积运算"},
    {NULL, NULL, 0, NULL}   // sentinel
};

// Attribute bindings (obj.x / obj.y / obj.z). Templates generate one getter and
// one setter per attribute from a Vec3 member-function pointer.

// Getter: calls the bound Vec3 const member function and boxes the result as a
// Python float. Returns a new reference (NULL with an exception on failure).
template <auto Getter>
static PyObject* PyVec3_GetAttr(PyVec3* self, void* closure)
{
    return PyFloat_FromDouble((self->ptr->*Getter)());
}

// Setter: converts `value` (float or int) to double and passes it to the bound
// Vec3 setter. Returns 0 on success, -1 with a Python exception set on failure.
template <auto Setter>
static int PyVec3_SetAttr(PyVec3* self, PyObject* value, void* closure)
{
    // `del obj.x` invokes the setter with value == NULL; the type checks below
    // would dereference it, so deletion is rejected explicitly.
    if (value == NULL)
    {
        PyErr_SetString(PyExc_TypeError, "cannot delete attribute");
        return -1;
    }
    if (!PyFloat_Check(value) && !PyLong_Check(value))
    {
        PyErr_SetString(PyExc_TypeError, "must be a number");
        return -1;
    }
    double val = PyFloat_AsDouble(value);
    // An int too large for a double raises OverflowError and yields -1.0; that
    // error must be propagated instead of storing -1.0 with an exception pending.
    if (val == -1.0 && PyErr_Occurred())
    {
        return -1;
    }
    (self->ptr->*Setter)(val);
    return 0;
}

// Attribute table: name, getter, setter, docstring, closure.
static PyGetSetDef PyVec3_getset[] = {
    {"x", (getter) PyVec3_GetAttr<&Vec3::getX>, (setter) PyVec3_SetAttr<&Vec3::setX>, "X 坐标", NULL},
    {"y", (getter) PyVec3_GetAttr<&Vec3::getY>, (setter) PyVec3_SetAttr<&Vec3::setY>, "Y 坐标", NULL},
    {"z", (getter) PyVec3_GetAttr<&Vec3::getZ>, (setter) PyVec3_SetAttr<&Vec3::setZ>, "Z 坐标", NULL},
    {NULL}   // sentinel
};

// Type object describing mmath.Vec3 to the interpreter; every slot is documented
// in the "Type Objects" chapter of the Python/C API reference.
// Designated initializers (C++20, matching CXX_STANDARD 20) name each slot
// explicitly. The old positional list was labelled with slots that no longer
// exist (tp_print is tp_vectorcall_offset since Python 3.8, tp_reserved is
// tp_as_async), and it silently misassigns everything after a miscounted slot.
// Slots that are not named are zero-initialized.
static PyTypeObject PyVec3Type = {
    .ob_base = PyVarObject_HEAD_INIT(NULL, 0)
    .tp_name = "mmath.Vec3",
    .tp_basicsize = sizeof(PyVec3),
    .tp_itemsize = 0,
    .tp_dealloc = (destructor) PyVec3_dealloc,
    .tp_flags = Py_TPFLAGS_DEFAULT,
    .tp_doc = "Vec3 objects",
    .tp_methods = PyVec3_methods,
    .tp_getset = PyVec3_getset,
    .tp_new = PyVec3_new,
};

static PyObject* PyVec3_addVec3(PyVec3* self, PyObject* args)
{
    PyObject* otherObj;
    if (!PyArg_ParseTuple(args, "O!", &PyVec3Type, &otherObj))
    {
        PyErr_SetString(PyExc_TypeError, "需要一个 Vec3 对象");
        return NULL;
    }
    PyVec3* other = (PyVec3*)otherObj;
    self->ptr->addVec3(*other->ptr);
    Py_RETURN_NONE;
}

static PyObject* PyVec3_dot(PyVec3* self, PyObject* args)
{
    PyObject* otherObj;
    if (!PyArg_ParseTuple(args, "O!", &PyVec3Type, &otherObj))
    {
        PyErr_SetString(PyExc_TypeError, "需要一个 Vec3 对象");
        return NULL;
    }
    PyVec3* other = (PyVec3*)otherObj;
    return PyFloat_FromDouble(self->ptr->dot(*other->ptr));
}

static PyObject* PyVec3_crossSelf(PyVec3* self, PyObject* args)
{
    PyObject* otherObj;
    if (!PyArg_ParseTuple(args, "O!", &PyVec3Type, &otherObj))
    {
        PyErr_SetString(PyExc_TypeError, "需要一个 Vec3 对象");
        return NULL;
    }
    PyVec3* other = (PyVec3*)otherObj;
    self->ptr->crossSelf(*other->ptr);
    Py_RETURN_NONE;
}

// Module entry point, called by the interpreter on `import mmath`; its name must
// be PyInit_<module name>. Returns a new module reference, or NULL with an
// exception set.
PyMODINIT_FUNC PyInit_mmath()
{
    static PyModuleDef moduleDef = {
        PyModuleDef_HEAD_INIT,
        "mmath",
        "测试C扩展数学运算库",
        -1,   // m_size: no per-module state (module uses global state)
        NULL, NULL, NULL, NULL, NULL
    };
    // Validate the type and fill in inherited/default slots.
    if (PyType_Ready(&PyVec3Type) < 0)
    {
        return NULL;
    }
    PyObject* m = PyModule_Create(&moduleDef);
    if (m == NULL)
    {
        return NULL;
    }
    // PyModule_AddObjectRef (Python 3.10+) does not steal a reference, so no manual
    // Py_INCREF is needed and nothing leaks when adding fails.
    if (PyModule_AddObjectRef(m, "Vec3", (PyObject*)&PyVec3Type) < 0)
    {
        Py_DECREF(m);
        return NULL;
    }
    return m;
}

完整pyi声明

提供了C++类的完整声明信息,便于IDE进行类型推断和代码补全。

python
# Type stub for the mmath C extension, for IDE type inference and completion.
from typing import overload

class Vec3:
    """3-D vector of floats backed by a C++ Vec3 object."""
    @overload
    def __init__(self) -> None: ...
    @overload
    def __init__(self, x: float, y: float, z: float) -> None: ...

    @property
    def x(self) -> float: ...
    @x.setter
    def x(self, value: float) -> None: ...

    @property
    def y(self) -> float: ...
    @y.setter
    def y(self, value: float) -> None: ...

    @property
    def z(self) -> float: ...
    @z.setter
    def z(self, value: float) -> None: ...

    def setXYZ(self, x: float, y: float, z: float) -> None: ...
    def getX(self) -> float: ...
    def getY(self) -> float: ...
    def getZ(self) -> float: ...
    def setX(self, x: float) -> None: ...
    def setY(self, y: float) -> None: ...
    def setZ(self, z: float) -> None: ...
    def getTuple(self) -> tuple[float, float, float]: ...

    # Mutating Vec3 operations modify this vector in place; none returns a new Vec3.
    def length(self) -> float: ...
    # Returns False (vector left unchanged) when the vector has (near) zero length.
    def normalize(self) -> bool: ...
    def scale(self, scalar: float) -> None: ...
    def addVec3(self, other: 'Vec3') -> None: ...
    def dot(self, other: 'Vec3') -> float: ...
    def crossSelf(self, other: 'Vec3') -> None: ...

由于使用官方 C-API 手写 class 绑定的工作量过于庞大,正式项目中建议使用 pybind11、nanobind 等第三方库来简化绑定工作。

Released under the BSD3 License