Cython Compilation
Goal: one way to make Python code run much faster (hundreds of times faster in the example below) is to compile it with Cython. Let's see how.
Here we will write a simple program that counts the Pythagorean triplets (i, j, k), with i² + j² = k², whose members all lie between 1 and a given number n.
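To make the counting rule concrete: (i, j) and (j, i) count as two different triplets, so (3, 4, 5) and (4, 3, 5) are both counted. The following is a minimal sketch that lists the triplets for a small n. list_triplets is a hypothetical helper for illustration and is not one of the tutorial's files.

```python
def list_triplets(n):
    """Return every ordered triplet (i, j, k) with 1 <= i, j, k <= n and i**2 + j**2 == k**2."""
    return [
        (i, j, k)
        for i in range(1, n + 1)
        for j in range(1, n + 1)
        for k in range(1, n + 1)
        if i * i + j * j == k * k
    ]


print(list_triplets(10))
# [(3, 4, 5), (4, 3, 5), (6, 8, 10), (8, 6, 10)]
```

The programs below only count these triplets instead of collecting them.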
Create the driver script, which reads n and times both implementations:
# p_triplets.py
"""
Prerequisites
- gcc (or another C compiler) to build the generated C code
Compile the .pyx file first:
$ pip install cython easycython
$ easycython c_util.pyx
"""
from time import time

import util    # pure Python implementation (util.py)
import c_util  # compiled Cython module built from c_util.pyx

if __name__ == "__main__":
    n = int(input("enter n:"))

    start = time()
    count = util.find(n)
    taken = time() - start
    print("python code", n, count, taken, "s")

    start = time()
    count = c_util.find(n)
    taken = time() - start
    print("cython code", n, count, taken, "s")
Create util.py with the pure Python implementation:
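# File: util.py
def find(n):
    """Count ordered triplets (i, j, k), 1 <= i, j, k <= n, with i**2 + j**2 == k**2."""
    count = 0
    i = 1
    while i <= n:
        j = 1
        while j <= n:
            k = 1
            while k <= n:
                s = i**2 + j**2
                r = k ** 2
                if r == s:
                    count += 1
                # Early exit is disabled here, so every k up to n is tried.
                # The Cython version keeps it enabled.
                # elif r > s:
                #     break
                k += 1
            j += 1
        i += 1
    return count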
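The triple loop tries n³ combinations. As an aside, and not part of the original comparison, the same count can be computed in O(n²) by checking whether i² + j² is a perfect square no larger than n². The following is a minimal pure Python sketch, assuming Python 3.8+ for math.isqrt. find_fast is a hypothetical name.

```python
import math


def find_fast(n):
    """Count ordered triplets (i, j, k), 1 <= i, j, k <= n, with i**2 + j**2 == k**2."""
    count = 0
    limit = n * n
    for i in range(1, n + 1):
        for j in range(1, n + 1):
            s = i * i + j * j
            if s > limit:
                break  # s grows with j, so no larger j can give k <= n
            k = math.isqrt(s)  # integer square root: floor(sqrt(s))
            if k * k == s:
                count += 1
    return count


print(find_fast(200))  # 254
```

This changes the algorithm rather than the language, and that often gives the bigger win. The Cython comparison in this post keeps the original triple loop.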
Create c_util.pyx, the Cython version:
# file: c_util.pyx
import cython

# Differences from util.py: the cpdef signature with C int types,
# the cdef declarations, and the early break in the inner loop.
cpdef int find(int n):
    cdef int i, j, k, count
    count = 0
    i = 1
    while i <= n:
        j = 1
        while j <= n:
            k = 1
            while k <= n:
                s = i**2 + j**2
                r = k ** 2
                if r == s:
                    count += 1
                elif r > s:
                    # k**2 only grows from here, so no later k can match
                    break
                k += 1
            j += 1
        i += 1
    return count
Compile the .pyx file with the easycython package. It needs a C compiler such as gcc, and it generates a shared object (.so) file that Python can import as the module c_util.
$ pip install easycython
$ easycython c_util.pyx
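If you prefer not to depend on easycython, Cython's own cythonize helper can be driven from a setuptools build script. The following is a minimal sketch, assuming Cython and setuptools are installed. Build it with python setup.py build_ext --inplace, which places the compiled c_util module next to the source.

```python
# setup.py
from setuptools import setup
from Cython.Build import cythonize

setup(
    # translate c_util.pyx to C and build it as an extension module named c_util
    ext_modules=cythonize("c_util.pyx"),
)
```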
Run the code and compare the time taken to find the triplets.
$ python p_triplets.py
enter n:200
python code 200 254 4.901634931564331 s
cython code 200 254 0.006403207778930664 s
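The two versions do not do exactly the same work. util.find has its early break commented out, while c_util.find keeps it. The following rough pure Python sketch counts inner-loop iterations with and without the break, so you can judge how much of the speedup is due to compilation. count_iterations is a hypothetical helper.

```python
def count_iterations(n, early_exit):
    """Return (triplet count, number of inner-loop iterations) for the triple loop."""
    count = 0
    iterations = 0
    for i in range(1, n + 1):
        for j in range(1, n + 1):
            s = i * i + j * j
            for k in range(1, n + 1):
                iterations += 1
                r = k * k
                if r == s:
                    count += 1
                elif early_exit and r > s:
                    break  # mirrors the elif/break in c_util.pyx
    return count, iterations


print(count_iterations(200, early_exit=False))  # (254, 8000000)
print(count_iterations(200, early_exit=True))
```

The triplet count is the same either way. Only the number of iterations changes.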
Just for fun, here is a Java equivalent. Like the Cython version, it breaks out of the inner loop early.
// file: PTriplets.java
public class PTriplets {
    // Counts ordered triplets (i, j, k) with i*i + j*j == k*k
    public static int triplet(int n) {
        int count = 0;
        for (int i = 1; i < n; ++i) {
            for (int j = 1; j < n; ++j) {
                for (int k = 1; k < n; ++k) {
                    int k2 = k * k;
                    int s2 = i * i + j * j;
                    if (k2 == s2) {
                        count++;
                    } else if (k2 > s2) {
                        break; // k*k only grows, so no later k can match
                    }
                }
            }
        }
        return count;
    }

    public static void main(String[] args) {
        int n = Integer.parseInt(args[0]);
        long start = System.currentTimeMillis();
        int count = triplet(n);
        long taken = System.currentTimeMillis() - start;
        System.out.println(String.format("n: %d, count: %d, taken (ms): %d",
                n, count, taken));
    }
}
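Compile and run. The loops use < n rather than <= n, so k never reaches n. As a result, the count comes out 4 short of the Python and Cython result (250 instead of 254).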
$ javac PTriplets.java
$ java PTriplets 200
n: 200, count: 250, taken (ms): 9
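Here is a short Python check of the triplets that a k < n bound leaves out for n = 200. They are exactly the ones whose hypotenuse is 200.

```python
n = 200
# ordered pairs (i, j) with i**2 + j**2 == n**2, i.e. triplets whose hypotenuse is exactly n
missing = [(i, j, n) for i in range(1, n) for j in range(1, n) if i * i + j * j == n * n]
print(missing)
# [(56, 192, 200), (120, 160, 200), (160, 120, 200), (192, 56, 200)]
print(254 - len(missing))  # 250
```

Changing the three loop conditions to <= n makes the Java count match the other two versions.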