Java、Python、C混合编程



2021年09月12日    Author:Guofei

文章归类: 合集    文章编号: 173

版权声明:本文作者是郭飞。转载随意,标明原文链接即可。本人邮箱
原文链接:https://www.guofei.site/2021/09/12/hybrid.html


编译后的文件

Windows:

  • .dll 动态
  • .lib 静态

Linux

  • .o 对应 Windows上的 .obj
  • .a 多个 .o 的合并
  • .so 对应 Windows上的 .dll

MacOS

  • 动态 .dylib
  • 静态 .a

【ctypes】python 调 C 编译后的动态库

优缺点:

  • 优点:无需处理 Python 版本
  • 缺点:不能多线程

步骤

  1. 写 c 文件, 保存为 file.c
    #include<stdio.h>
    /* Add up the integers 0..999 and print the total to stdout as "sum:<n>". */
    void add()
    {
        int total = 0;
        int k = 0;
        while (k < 1000) {
            total += k;
            k++;
        }
        printf("sum:%d\n", total);
    }
    
  2. file.c 文件编译为 file.so:gcc file.c -shared -fPIC -o file.so
  3. 在Python中导入和使用 .so
    from ctypes import cdll

    # Load the shared library built from file.c (path is relative to the CWD).
    clib = cdll.LoadLibrary("./file.so")
    # add() is declared `void add()` in C. ctypes assumes an int return value
    # unless told otherwise, so declare the real signature explicitly.
    clib.add.argtypes = []
    clib.add.restype = None
    clib.add()
    

如何同时支持 windows/Linux : https://www.bilibili.com/video/BV17y4y1W7BY?p=3

数据类型

| ctypes 类型 | C 类型 | Python 数据类型 |
| --- | --- | --- |
| c_bool | _Bool | bool |
| c_char | char | 单字符字节串对象 |
| c_wchar | wchar_t | 单字符字符串 |
| c_byte | char | int |
| c_ubyte | unsigned char | int |
| c_short | short | int |
| c_ushort | unsigned short | int |
| c_int | int | int |
| c_uint | unsigned int | int |
| c_long | long | int |
| c_ulong | unsigned long | int |
| c_longlong | __int64 或 long long | int |
| c_ulonglong | unsigned __int64 或 unsigned long long | int |
| c_size_t | size_t | int |
| c_ssize_t | ssize_t 或 Py_ssize_t | int |
| c_float | float | float |
| c_double | double | float |
| c_longdouble | long double | float |
| c_char_p | char * (NUL terminated) | 字节串对象或 None |
| c_wchar_p | wchar_t * (NUL terminated) | 字符串或 None |
| c_void_p | void * | int 或 None |

注:c_bool 的构造函数接受任何具有真值的对象。

例子:

// c 
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/*
 * Demo entry point showing how each kind of argument arrives from the caller.
 *
 * Prints x1, x2, x3, the first 5 ints of x4 and the first 3 strings of x5,
 * then stores 8 into *x6, prints x7 and overwrites x7 with a new string.
 *
 * Returns x1 + (int) x2 + strlen(x3) + x4[0].
 */
int my_c_fun(int x1, float x2,
             char *x3,   /* pointer to a string */
             int *x4,    /* pointer to an array of int */
             char **x5,  /* pointer to an array of strings */
             int *x6,    /* pointer to an int whose value is modified here */
             char *x7    /* pointer to a string whose content is modified here */
) {
    printf("Get input :\n");
    printf("x1 = %d\n", x1);
    printf("x2 = %f\n", x2);
    printf("x3 = %s\n", x3);

    printf("x4 = ");
    for (const int *item = x4; item != x4 + 5; ++item) {
        printf(" %d", *item);
    }
    printf("\n");

    printf("x5 = ");
    for (char **item = x5; item != x5 + 3; ++item) {
        printf(" %s", *item);
    }
    printf("\n");

    /* Write results back through the pointers supplied by the caller. */
    *x6 = 8;
    printf("%s\n", x7);
    strcpy(x7, "你好");

    int total = x1;
    total += (int) x2;
    total += (int) strlen(x3);
    total += x4[0];
    return total;
}


// python

from ctypes import cdll
import ctypes

# Load the shared library that exports my_c_fun.
clib = cdll.LoadLibrary("./file.so")

# Describe the C signature so ctypes converts each argument accordingly.
my_c_fun = clib.my_c_fun
my_c_fun.restype = ctypes.c_int
my_c_fun.argtypes = [
    ctypes.c_int,                     # x1
    ctypes.c_float,                   # x2
    ctypes.c_char_p,                  # x3: char *
    ctypes.POINTER(ctypes.c_int),     # x4: int *
    ctypes.POINTER(ctypes.c_char_p),  # x5: char **
    ctypes.POINTER(ctypes.c_int),     # x6: int * (written by the C side)
    ctypes.c_char_p,                  # x7: char * (written by the C side)
]

# Pass a list of int as a C int array.
int_lst = [1, 2, 3, 4, 5]
x4 = (ctypes.c_int * len(int_lst))(*int_lst)

# Pass a list of strings as a C array of char *.
str_lst = ["hello", "world", "!", "\0"]
x5 = (ctypes.c_char_p * len(str_lst))(*(s.encode("utf-8") for s in str_lst))

# Output parameters: an int and a 100-byte writable buffer.
x6 = ctypes.c_int()
x7 = ctypes.create_string_buffer("nihao".encode('utf-8'), 100)  # an initial value may be assigned first

res = my_c_fun(
    1,                        # plain int, no conversion needed
    1.2,                      # float
    "hello".encode('utf-8'),  # string as char *
    x4,
    x5,
    ctypes.byref(x6),
    x7,
)

print('res:', res)
print("x6:", x6.value)
print("x7:", x7.value.decode('utf-8'))
/*
 * Print every argument and return x1 + (int) x2 + strlen(x3) + x4[0].
 *
 * x3 is a string, x4 an int array (the first 5 elements are printed) and
 * x5 an array of strings (the first 3 elements are printed).
 */
int my_c_fun(int x1, float x2, char *x3, int *x4, char **x5) {
    printf("Get input :\n");
    printf("x1 = %d\n", x1);
    printf("x2 = %f\n", x2);
    printf("x3 = %s\n", x3);

    printf("x4 = ");
    int idx = 0;
    while (idx < 5) {
        printf(" %d", x4[idx]);
        idx++;
    }
    printf("\n");

    printf("x5 = ");
    idx = 0;
    while (idx < 3) {
        printf(" %s", x5[idx]);
        idx++;
    }
    printf("\n");

    int total = x1;
    total += (int) x2;
    total += (int) strlen(x3);
    total += x4[0];
    return total;
}

/*
 * Return a freshly malloc'ed, NUL-terminated copy of "hello".
 *
 * The caller owns the 20-byte buffer and must release it with free().
 * Returns NULL if the allocation fails.
 */
char *my_c_fun2() {
    char *p = malloc(20);
    if (p == NULL) {
        return NULL;  /* out of memory: never strcpy() into a NULL pointer */
    }
    strcpy(p, "hello");
    return p;
}


/*
 * Return a malloc'ed array of 20 ints holding 30, 29, ..., 11.
 *
 * The caller owns the array and must release it (e.g. via my_c_fun3_free).
 * Returns NULL if the allocation fails.
 */
int *my_c_fun3() {
    int *p = malloc(20 * sizeof(int));
    if (p == NULL) {
        return NULL;  /* out of memory: do not write through a NULL pointer */
    }
    for (int i = 0; i < 20; i++) {
        p[i] = 30 - i;
    }
    return p;
}

// The memory returned by my_c_fun3 should be freed promptly. Once it is freed here, the return value held by the Python script refers to freed memory as well.
void my_c_fun3_free(int *p) {
    free(p);
}


/*
 * Return a malloc'ed array of 5 string pointers.
 *
 * Slots 0 and 1 point at the string literals "hello" and "world"; the
 * remaining slots are set to NULL so that no slot is left uninitialised.
 * The caller must free() the array itself, but not the strings it points to.
 * Returns NULL if the allocation fails.
 */
char **my_c_fun4() {
    char **p = malloc(5 * sizeof(char *));
    if (p == NULL) {
        return NULL;  /* out of memory */
    }
    p[0] = "hello";
    p[1] = "world";
    for (int i = 2; i < 5; i++) {
        p[i] = NULL;  /* previously indeterminate values */
    }
    return p;
}

用 python 去调用上面那个 C

import ctypes

# NOTE(review): the path ends in ".o"; LoadLibrary needs a shared library
# (built with -shared) - confirm this file is one despite its extension.
clib = ctypes.cdll.LoadLibrary("/Users/guofei/c_algorithm/main.o")

# Declare the C signatures up front. `argtypes` lists the parameter types of a
# function and `restype` its return type; without them ctypes has to guess.
clib.my_c_fun.argtypes = [
    ctypes.c_int,
    ctypes.c_float,
    ctypes.c_char_p,
    ctypes.POINTER(ctypes.c_int),
    ctypes.POINTER(ctypes.c_char_p),
]
clib.my_c_fun.restype = ctypes.c_int

# my_c_fun3_free only calls free(p), so it can release any malloc'ed pointer.
clib.my_c_fun3_free.argtypes = [ctypes.c_void_p]
clib.my_c_fun3_free.restype = None

# Pass a list of int
int_lst = [1, 2, 3, 4, 5]
int_arr = (ctypes.c_int * len(int_lst))()
int_arr[:] = int_lst

# Pass a list of string
str_lst = ["hello", "world", "!", "\0"]
str_arr = (ctypes.c_char_p * len(str_lst))()
str_arr[:] = [i.encode("utf-8") for i in str_lst]

x = clib.my_c_fun(1  # plain int, no conversion needed
                  , ctypes.c_float(1.2)  # float
                  , ctypes.c_char_p("hello".encode('utf-8'))  # string as char *
                  , int_arr  # int *
                  , str_arr  # char **
                  )
print('res:', x)

# Return a string.
# With restype = c_char_p ctypes copies the text into a bytes object and drops
# the pointer, so the malloc'ed buffer could never be freed. Take the raw
# address instead, copy the text out, then release the buffer.
clib.my_c_fun2.argtypes = []
clib.my_c_fun2.restype = ctypes.c_void_p
ptr = clib.my_c_fun2()
if ptr is None:
    raise MemoryError("my_c_fun2 returned NULL")
x = ctypes.string_at(ptr)
clib.my_c_fun3_free(ptr)
print('res:', x)

# Return a list of int
clib.my_c_fun3.argtypes = []
clib.my_c_fun3.restype = ctypes.POINTER(ctypes.c_int * 20)
x = clib.my_c_fun3()
print('res:', list(x.contents))

# Remember to free after use; once freed, x points at released memory and
# must not be read again.
clib.my_c_fun3_free(x)

# Return a list of string
clib.my_c_fun4.argtypes = []
clib.my_c_fun4.restype = ctypes.POINTER(ctypes.c_char_p * 2)
x = clib.my_c_fun4()
print('res:', list(x.contents))
# list(...) above already copied the strings into Python bytes objects, so the
# malloc'ed pointer array can be released now.
clib.my_c_fun3_free(x)

Python 调 Rust

Rust混合编程

https://www.guofei.site/2022/08/28/rust2.html#混合编程

【JPype1】Python 调 Java

安装

pip install JPype1

示例java脚本

package com.guofei9987.learn_java;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class JavaClass {

    public static String tstStatic(String a) {
        return "调用了静态方法,入参:" + a;
    }

    public List<String> my_java_fun(int x1, String x2, List<Integer> x3, List<String> x4) {
        System.out.println("x1 = " + x1);
        System.out.println("x2 = " + x2);
        System.out.println("x3 = " + x3);
        System.out.println("x4 = " + x4);

        ArrayList<String> res = new ArrayList<String>(Arrays.asList("" + x1, x2)) {
        };
        res.addAll(x4);
        return res;
    }

    public String my_java_fun(int a) {
        return "调用了重载的方法,入参:" + a;
    }
}

Python 调用 Java

import jpype
import os

# NOTE: the second component is an absolute path, so under POSIX semantics
# os.path.join() ignores the working-directory prefix and returns it as is.
jar_file = "/Users/guofei9987/learn_java/target/learn_java-1.0-SNAPSHOT.jar"
jarpath = os.path.join(os.path.abspath("."), jar_file)

# Start the JVM with assertions enabled and the jar on the class path.
jvmPath = jpype.getDefaultJVMPath()
jpype.startJVM(jvmPath, "-ea", f"-Djava.class.path={jarpath}")

# Step 1: load the Java class.
javaClass = jpype.JClass("com.guofei9987.learn_java.JavaClass")

# Step 2: instantiate a Java object.
javaInstance = javaClass()

# Step 3: call Java methods.
javaClass.tstStatic("Hello")
# >>>'调用了静态方法,入参:Hello'

int_args = jpype.java.util.ArrayList([1, 2, 3])  # passed as the List<Integer> argument
str_args = jpype.java.util.ArrayList(["hello", "world"])  # passed as the List<String> argument
output = javaInstance.my_java_fun(
    1,        # int
    "hello",  # String
    int_args,
    str_args,
)
print(list(output))

# The single-int overload is selected from the argument list.
output2 = javaInstance.my_java_fun(1)
print(output2)

# Step 4: shut the JVM down.
jpype.shutdownJVM()

关于变量类型

  • java.lang.String
    • 把 java.lang.String 转为 Python 的 str 类型:str(java_string)
    • 把 Python 的 str 转为 java.lang.String:jpype.JClass("java.lang.String")("Hello, World!")

【PMML】Java调 sklearn 算法模型

描述

Python 训练算法模型,保存为 PMML 文件。然后 Java 调用 PMML 文件做预测,这样线上就变成纯 JAVA 项目了。

相关项目

  • Python 端的项目:https://github.com/jpmml/sklearn2pmml
  • Java端 项目:https://github.com/jpmml/jpmml-evaluator (其依赖的模型类库:https://github.com/jpmml/jpmml-model)
  • 对应的库:https://repo1.maven.org/maven2/org/jpmml/

安装

  • Python
    pip install scikit-learn==0.24.2
    pip install sklearn2pmml==0.74.4
    
  • JAVA pom.xml 文件
<!-- JPMML evaluator library; supplies the org.jpmml.evaluator classes imported by the Java prediction example. -->
<dependency>
    <groupId>org.jpmml</groupId>
    <artifactId>pmml-evaluator</artifactId>
    <version>1.5.9</version>
</dependency>

使用

python 训练模型

from sklearn import tree
from sklearn.datasets import load_iris
from sklearn2pmml.pipeline import PMMLPipeline
from sklearn2pmml import sklearn2pmml

# Load the iris data set: feature matrix X and class labels y.
iris = load_iris()
X = iris.data
y = iris.target

# A one-step pipeline whose only step is a decision-tree classifier.
steps = [("classifier", tree.DecisionTreeClassifier())]
pipeline = PMMLPipeline(steps)
pipeline.fit(X, y)

# Export the fitted pipeline to the file iris.pmml.
sklearn2pmml(pipeline, "iris.pmml", with_repr=True)

java 调用模型做预测

package com.alipay.tianqian;

import org.dmg.pmml.FieldName;
import org.jpmml.evaluator.*;
import org.xml.sax.SAXException;

import javax.xml.bind.JAXBException;
import java.io.*;
import java.util.*;


/**
 * Loads a PMML model from disk and evaluates it on a few hand-written samples.
 */
public class TestPmml {

    /**
     * Builds an evaluator from src/iris.pmml, runs it on two samples and
     * prints each input map together with the value of the target "y".
     */
    public static void main(String[] args) throws JAXBException, SAXException, IOException {
        TestPmml runner = new TestPmml();

        File pmmlFile = new File("src/iris.pmml");
        Evaluator model = new LoadingModelEvaluatorBuilder()
                .load(pmmlFile)
                .build();

        List<Map<String, Object>> inputs = new ArrayList<>();
        inputs.add(runner.getRawMap(5.1, 3.5, 1.4, 0.2));
        inputs.add(runner.getRawMap(4.9, 3, 1.5, 4));

        for (Map<String, Object> input : inputs) {
            Map<String, Object> output = runner.predict(model, input);
            Object label = output.get("y");
            System.out.println("X=" + input + " -> y=" + label);
        }
    }

    /**
     * Packs four raw values into a map keyed "x1" .. "x4".
     */
    private Map<String, Object> getRawMap(Object a, Object b, Object c, Object d) {
        Map<String, Object> raw = new HashMap<>();
        raw.put("x1", a);
        raw.put("x2", b);
        raw.put("x3", c);
        raw.put("x4", d);
        return raw;
    }

    /**
     * Runs the model on one raw input map and returns the result map.
     */
    private Map<String, Object> predict(Evaluator evaluator, Map<String, Object> data) {
        return evaluate(evaluator, getFieldMap(evaluator, data));
    }

    /**
     * Converts the raw input into PMML-format input: for every input field of
     * the evaluator, the raw value stored under the field's name is prepared
     * by that field.
     */
    private Map<FieldName, FieldValue> getFieldMap(Evaluator evaluator, Map<String, Object> input) {
        Map<FieldName, FieldValue> prepared = new LinkedHashMap<>();
        for (InputField field : evaluator.getInputFields()) {
            FieldName name = field.getName();
            Object rawValue = input.get(name.getValue());
            prepared.put(name, field.prepare(rawValue));
        }
        return prepared;
    }

    /**
     * Evaluates the prepared input and returns one entry per target field,
     * keyed by the field's name. Computable results are replaced by their
     * computed value.
     */
    private Map<String, Object> evaluate(Evaluator evaluator, Map<FieldName, FieldValue> input) {
        Map<FieldName, ?> results = evaluator.evaluate(input);
        Map<String, Object> output = new LinkedHashMap<>();
        for (TargetField field : evaluator.getTargetFields()) {
            FieldName name = field.getName();
            Object value = results.get(name);
            Object plain = (value instanceof Computable)
                    ? ((Computable) value).getResult()
                    : value;
            output.put(name.getValue(), plain);
        }
        return output;
    }
}

您的支持将鼓励我继续创作!