利用DeepSeek编写DuckDB的C插件实现大数运算
摘要: 本文展示了如何在DuckDB扩展中正确处理字符串数据,以调用OpenSSL的大数运算函数。作者首先遇到字符串处理问题,发现DuckDB对字符串有特殊的处理机制(短字符串内联存储,长字符串存储地址)。参考文档后,给出了两个关键函数:extract_string()用于从duckdb_string_t中提取C字符串,set_string()用于将结果写回duckdb_string_t。文中实现了两个标量函数:bn_add(大数加法)和bn_multiply(大数乘法)。
·
在C插件模板的基础上,把一个调用OpenSSL大数运算函数的示例发给DeepSeek。它第一次给出的代码可以编译通过,但执行出错。问题出在DuckDB对字符串做了特殊处理:短字符串内联存储,长字符串只存地址,而它并不了解这一点。把文档中的正确示例发给它之后,它输出了如下代码:
#include "duckdb_extension.h"
#include <string.h>
#include <stdlib.h>
// Hand-written declarations for the OpenSSL BIGNUM routines used in this file,
// provided here instead of including the OpenSSL headers. The implementations are
// resolved at link time (the build links with -lssl -lcrypto).
typedef struct bignum_st BIGNUM;
typedef struct bn_ctx_st BN_CTX;
// Empty parameter lists are written as (void) so these are real prototypes in C.
BIGNUM *BN_new(void);
void BN_free(BIGNUM *a);
int BN_dec2bn(BIGNUM **a, const char *str);
char *BN_bn2dec(const BIGNUM *a);
int BN_add(BIGNUM *r, const BIGNUM *a, const BIGNUM *b);
int BN_mul(BIGNUM *r, const BIGNUM *a, const BIGNUM *b, BN_CTX *ctx);
BN_CTX *BN_CTX_new(void);
void BN_CTX_free(BN_CTX *c);
void free(void *ptr);
DUCKDB_EXTENSION_EXTERN
// Copy the bytes of a duckdb_string_t into a freshly allocated, NUL-terminated C string.
//
// str - value to copy. Short strings keep their bytes inline in the struct
//       (value.inlined.inlined); longer ones reference them through value.pointer.ptr.
//
// Returns a buffer obtained from duckdb_malloc that the caller must release with
// duckdb_free, or NULL when the allocation fails.
static char *extract_string(duckdb_string_t str) {
    const char *src;
    size_t len;

    if (duckdb_string_is_inlined(str)) {
        len = str.value.inlined.length;
        src = str.value.inlined.inlined;
    } else {
        len = str.value.pointer.length;
        src = str.value.pointer.ptr;
    }

    char *result = duckdb_malloc(len + 1);
    if (!result) {
        // Report allocation failure to the caller instead of writing through NULL.
        return NULL;
    }
    memcpy(result, src, len);
    result[len] = '\0';
    return result;
}
// Store a C string into a duckdb_string_t.
//
// str  - destination value to overwrite.
// cstr - NUL-terminated source; NULL is treated as the empty string.
//
// Strings that fit in the inline buffer (12 bytes) are stored inside the struct without
// a terminator: the buffer has no room for one when the string is exactly 12 bytes long.
// Longer strings are copied into a duckdb_malloc buffer referenced by value.pointer.ptr,
// and the first bytes are mirrored into value.pointer.prefix. If that allocation fails,
// the empty string is stored instead.
//
// NOTE(review): nothing in this function releases the duckdb_malloc buffer, and no owner
// is visible here. When filling a result vector, duckdb_vector_assign_string_element
// copies into storage owned by the vector and is the safer choice -- confirm ownership
// before relying on this helper for output vectors.
static void set_string(duckdb_string_t *str, const char *cstr) {
    if (!cstr) {
        cstr = "";
    }
    size_t len = strlen(cstr);

    if (len > sizeof str->value.inlined.inlined) {
        char *heap_copy = duckdb_malloc(len + 1);
        if (heap_copy) {
            memcpy(heap_copy, cstr, len + 1);
            str->value.pointer.length = (uint32_t)len;
            // len > 12 here, so at least sizeof(prefix) bytes are available to copy.
            memcpy(str->value.pointer.prefix, heap_copy, sizeof str->value.pointer.prefix);
            str->value.pointer.ptr = heap_copy;
            return;
        }
        // Allocation failed: fall through and store an empty inline string.
        len = 0;
    }

    // Zero the whole inline buffer first so the bytes after the string are defined.
    memset(str->value.inlined.inlined, 0, sizeof str->value.inlined.inlined);
    str->value.inlined.length = (uint32_t)len;
    memcpy(str->value.inlined.inlined, cstr, len);
}
// Scalar function bn_add(VARCHAR, VARCHAR) -> VARCHAR: arbitrary-precision decimal addition.
//
// Each row is parsed with BN_dec2bn, summed with BN_add and written to the output vector
// as a decimal string. A row yields NULL when either argument is NULL, when an argument
// is not entirely a decimal integer (optional leading '-'), or when an allocation or
// OpenSSL call fails, so every output slot is always either assigned or marked NULL.
//
// info   - function context (not used by this function).
// input  - chunk holding the two VARCHAR argument vectors.
// output - VARCHAR result vector, one entry per input row.
static void BNAddFunction(duckdb_function_info info, duckdb_data_chunk input, duckdb_vector output) {
    (void)info;
    idx_t input_size = duckdb_data_chunk_get_size(input);
    duckdb_vector a_vec = duckdb_data_chunk_get_vector(input, 0);
    duckdb_vector b_vec = duckdb_data_chunk_get_vector(input, 1);
    duckdb_string_t *a_data = (duckdb_string_t *)duckdb_vector_get_data(a_vec);
    duckdb_string_t *b_data = (duckdb_string_t *)duckdb_vector_get_data(b_vec);
    // Validity masks may be NULL, in which case every row of that vector is valid.
    uint64_t *a_validity = duckdb_vector_get_validity(a_vec);
    uint64_t *b_validity = duckdb_vector_get_validity(b_vec);

    for (idx_t row = 0; row < input_size; row++) {
        BIGNUM *a = NULL;
        BIGNUM *b = NULL;
        BIGNUM *sum = NULL;
        char *a_str = NULL;
        char *b_str = NULL;
        char *res_str = NULL; // stays NULL unless the whole computation succeeds
        int a_used;
        int b_used;

        if ((a_validity && !duckdb_validity_row_is_valid(a_validity, row)) ||
            (b_validity && !duckdb_validity_row_is_valid(b_validity, row))) {
            goto row_done;
        }

        a = BN_new();
        b = BN_new();
        sum = BN_new();
        if (!a || !b || !sum) {
            goto row_done;
        }

        a_str = extract_string(a_data[row]);
        b_str = extract_string(b_data[row]);
        if (!a_str || !b_str) {
            goto row_done;
        }

        // BN_dec2bn returns the number of characters it consumed (0 on error). Requiring
        // it to equal the string length rejects inputs with trailing garbage ("12abc").
        a_used = BN_dec2bn(&a, a_str);
        if (a_used <= 0 || (size_t)a_used != strlen(a_str)) {
            goto row_done;
        }
        b_used = BN_dec2bn(&b, b_str);
        if (b_used <= 0 || (size_t)b_used != strlen(b_str)) {
            goto row_done;
        }

        if (!BN_add(sum, a, b)) {
            goto row_done;
        }
        res_str = BN_bn2dec(sum); // NULL on failure, handled below

    row_done:
        if (res_str) {
            // Copies the text into the output vector, so the temporary can be released.
            duckdb_vector_assign_string_element(output, row, res_str);
            free(res_str);
        } else {
            duckdb_vector_ensure_validity_writable(output);
            duckdb_validity_set_row_invalid(duckdb_vector_get_validity(output), row);
        }
        if (a_str) {
            duckdb_free(a_str);
        }
        if (b_str) {
            duckdb_free(b_str);
        }
        BN_free(a);
        BN_free(b);
        BN_free(sum);
    }
}
// Scalar function bn_multiply(VARCHAR, VARCHAR) -> VARCHAR: arbitrary-precision decimal
// multiplication.
//
// Each row is parsed with BN_dec2bn, multiplied with BN_mul and written to the output
// vector as a decimal string. A row yields NULL when either argument is NULL, when an
// argument is not entirely a decimal integer (optional leading '-'), or when a per-row
// allocation or OpenSSL call fails. If the shared BN_CTX cannot be allocated, the
// function reports an error through `info` instead of returning with unset output.
//
// info   - function context, used to report a fatal error.
// input  - chunk holding the two VARCHAR argument vectors.
// output - VARCHAR result vector, one entry per input row.
static void BNMultiplyFunction(duckdb_function_info info, duckdb_data_chunk input, duckdb_vector output) {
    idx_t input_size = duckdb_data_chunk_get_size(input);
    duckdb_vector a_vec = duckdb_data_chunk_get_vector(input, 0);
    duckdb_vector b_vec = duckdb_data_chunk_get_vector(input, 1);
    duckdb_string_t *a_data = (duckdb_string_t *)duckdb_vector_get_data(a_vec);
    duckdb_string_t *b_data = (duckdb_string_t *)duckdb_vector_get_data(b_vec);
    // Validity masks may be NULL, in which case every row of that vector is valid.
    uint64_t *a_validity = duckdb_vector_get_validity(a_vec);
    uint64_t *b_validity = duckdb_vector_get_validity(b_vec);

    // One scratch context is shared by every BN_mul call in this chunk.
    BN_CTX *ctx = BN_CTX_new();
    if (!ctx) {
        duckdb_scalar_function_set_error(info, "bn_multiply: failed to allocate BN_CTX");
        return;
    }

    for (idx_t row = 0; row < input_size; row++) {
        BIGNUM *a = NULL;
        BIGNUM *b = NULL;
        BIGNUM *product = NULL;
        char *a_str = NULL;
        char *b_str = NULL;
        char *res_str = NULL; // stays NULL unless the whole computation succeeds
        int a_used;
        int b_used;

        if ((a_validity && !duckdb_validity_row_is_valid(a_validity, row)) ||
            (b_validity && !duckdb_validity_row_is_valid(b_validity, row))) {
            goto row_done;
        }

        a = BN_new();
        b = BN_new();
        product = BN_new();
        if (!a || !b || !product) {
            goto row_done;
        }

        a_str = extract_string(a_data[row]);
        b_str = extract_string(b_data[row]);
        if (!a_str || !b_str) {
            goto row_done;
        }

        // BN_dec2bn returns the number of characters it consumed (0 on error). Requiring
        // it to equal the string length rejects inputs with trailing garbage ("12abc").
        a_used = BN_dec2bn(&a, a_str);
        if (a_used <= 0 || (size_t)a_used != strlen(a_str)) {
            goto row_done;
        }
        b_used = BN_dec2bn(&b, b_str);
        if (b_used <= 0 || (size_t)b_used != strlen(b_str)) {
            goto row_done;
        }

        if (!BN_mul(product, a, b, ctx)) {
            goto row_done;
        }
        res_str = BN_bn2dec(product); // NULL on failure, handled below

    row_done:
        if (res_str) {
            // Copies the text into the output vector, so the temporary can be released.
            duckdb_vector_assign_string_element(output, row, res_str);
            free(res_str);
        } else {
            duckdb_vector_ensure_validity_writable(output);
            duckdb_validity_set_row_invalid(duckdb_vector_get_validity(output), row);
        }
        if (a_str) {
            duckdb_free(a_str);
        }
        if (b_str) {
            duckdb_free(b_str);
        }
        BN_free(a);
        BN_free(b);
        BN_free(product);
    }

    BN_CTX_free(ctx);
}
// Register the big-number scalar functions on a connection.
//
// connection - connection on which "bn_add" and "bn_multiply" are registered.
//
// Both functions take two VARCHAR parameters and return VARCHAR. The status returned
// by duckdb_register_scalar_function is not inspected.
void RegisterBNFunctions(duckdb_connection connection) {
    // SQL name paired with the C implementation, in registration order.
    static const struct {
        const char *sql_name;
        void (*impl)(duckdb_function_info, duckdb_data_chunk, duckdb_vector);
    } entries[] = {
        {"bn_add", BNAddFunction},
        {"bn_multiply", BNMultiplyFunction},
    };

    // A single VARCHAR logical type describes every parameter and return value.
    duckdb_logical_type text_type = duckdb_create_logical_type(DUCKDB_TYPE_VARCHAR);

    for (size_t i = 0; i < sizeof entries / sizeof entries[0]; i++) {
        duckdb_scalar_function fn = duckdb_create_scalar_function();
        duckdb_scalar_function_set_name(fn, entries[i].sql_name);
        duckdb_scalar_function_add_parameter(fn, text_type);
        duckdb_scalar_function_add_parameter(fn, text_type);
        duckdb_scalar_function_set_return_type(fn, text_type);
        duckdb_scalar_function_set_function(fn, entries[i].impl);
        duckdb_register_scalar_function(connection, fn);
        duckdb_destroy_scalar_function(&fn);
    }

    duckdb_destroy_logical_type(&text_type);
}
将RegisterBNFunctions函数的声明加入add_numbers.h
void RegisterBNFunctions(duckdb_connection connection);
再在capi_quack.c中增加对它的调用
RegisterBNFunctions(connection);
重新编译并用appendmetadata.py追加元数据后,就可以加载使用了。可见,它像模板一样正确处理了NULL。注意,编译动态库时需要加上-lssl -lcrypto选项,以动态链接OpenSSL库;另外,load命令后面要跟插件文件的绝对路径。
root@6ae32a5ffcde:/par/extension-template-c-main/src# gcc -fPIC -shared -o libtest2.so *.c -I include -I ../duckdb_capi -lssl -lcrypto
root@6ae32a5ffcde:/par/extension-template-c-main/src# python3 ../../appendmetadata.py -l libtest2.so -n add -dv v1.2.0 --duckdb-platform linux_amd64 --extension-version 0.1
Creating extension binary:
- Input file: libtest2.so
- Output file: add.duckdb_extension
- Metadata:
- FIELD8 (unused) = EMPTY
- FIELD7 (unused) = EMPTY
- FIELD6 (unused) = EMPTY
- FIELD5 (abi_type) = C_STRUCT
- FIELD4 (extension_version) = 0.1
- FIELD3 (duckdb_version) = v1.2.0
- FIELD2 (duckdb_platform) = linux_amd64
- FIELD1 (header signature) = 4 (special value to identify a duckdb extension)
root@6ae32a5ffcde:/par/extension-template-c-main/src# /par/duckdb140 -unsigned
DuckDB v1.4.0 (Andium) b8a06e4a22
Enter ".help" for usage hints.
D load '/par/extension-template-c-main/src/add.duckdb_extension';
D select bn_add('123456780','123999999');
┌──────────────────────────────────┐
│ bn_add('123456780', '123999999') │
│ varchar │
├──────────────────────────────────┤
│ 247456779 │
└──────────────────────────────────┘
D select bn_add('123456780123456780','123999999'),bn_multiply('12345678901234567890','123999999');
┌───────────────────────────────────────────┬──────────────────────────────────────────────────┐
│ bn_add('123456780123456780', '123999999') │ bn_multiply('12345678901234567890', '123999999') │
│ varchar │ varchar │
├───────────────────────────────────────────┼──────────────────────────────────────────────────┤
│ 123456780247456779 │ 1530864171407407517125432110 │
└───────────────────────────────────────────┴──────────────────────────────────────────────────┘
D with t as(select '12345678' a,NULL::varchar b union all select '123456780123456780','123999999' union all select NULL,'123999999')select a,b,bn_multiply(a,b) from t;
┌────────────────────┬───────────┬────────────────────────────┐
│ a │ b │ bn_multiply(a, b) │
│ varchar │ varchar │ varchar │
├────────────────────┼───────────┼────────────────────────────┤
│ 12345678 │ NULL │ NULL │
│ 123456780123456780 │ 123999999 │ 15308640611851860596543220 │
│ NULL │ 123999999 │ NULL │
└────────────────────┴───────────┴────────────────────────────┘
更多推荐


所有评论(0)