Cython Compilation

Goal - one way to make Python code run hundreds of times faster is to compile it with Cython. Let's see how.


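Here we will create a simple program that counts the Pythagorean triplets up to a given number n. These are the ordered triples (i, j, k) with i² + j² = k² and 1 ≤ i, j, k ≤ n, so (3, 4, 5) and (4, 3, 5) are counted separately.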
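To make the definition concrete, here is a brute-force sketch that lists the triplets instead of counting them. `triplets` is a hypothetical helper for illustration and is not one of the article's files.

```python
from itertools import product


def triplets(n):
    """List ordered (i, j, k) with i*i + j*j == k*k and 1 <= i, j, k <= n (brute force)."""
    return [
        (i, j, k)
        for i, j, k in product(range(1, n + 1), repeat=3)
        if i * i + j * j == k * k
    ]


print(triplets(10))       # [(3, 4, 5), (4, 3, 5), (6, 8, 10), (8, 6, 10)]
print(len(triplets(10)))  # 4
```

The count that `find(n)` returns below is `len(triplets(n))`. Both orderings of each leg pair are included.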


Create a caller .py file

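```python
# File: p_triplets.py
#
# Prerequisites:
#   - gcc (or another C compiler) to build the Cython extension
#
# Compile the .pyx file before running this script:
#   $ pip install cython easycython
#   $ easycython c_util.pyx

from time import perf_counter

import util     # pure Python implementation (util.py)
import c_util   # compiled Cython extension built from c_util.pyx


if __name__ == "__main__":
    n = int(input("enter n:"))

    start = perf_counter()
    count = util.find(n)
    taken = perf_counter() - start
    print("python code", n, count, taken, "s")

    start = perf_counter()
    count = c_util.find(n)
    taken = perf_counter() - start
    print("cython code", n, count, taken, "s")
```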
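The script times each call once, so the numbers can be noisy. Below is a minimal sketch that uses the standard library `timeit` module to take the best of several runs instead. `best_time` is a hypothetical helper name, and the lambda workload is only a stand-in for `util.find` or `c_util.find`.

```python
import timeit


def best_time(func, n, repeat=5):
    """Call func(n) `repeat` times and return the fastest single run, in seconds.

    Taking the minimum of several runs reduces noise from other processes,
    which a single timing cannot do.
    """
    timer = timeit.Timer(lambda: func(n))
    return min(timer.repeat(repeat=repeat, number=1))


if __name__ == "__main__":
    # Stand-in workload for illustration; pass util.find or c_util.find instead.
    print(best_time(lambda n: sum(range(n)), 100_000))
```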

Create a file for the pure Python implementation

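```python
# File: util.py


def find(n):
    """Count ordered triplets (i, j, k) with i**2 + j**2 == k**2 and 1 <= i, j, k <= n."""
    count = 0
    i = 1
    while i <= n:
        j = 1
        while j <= n:
            k = 1
            while k <= n:
                s = i**2 + j**2
                r = k ** 2
                if r == s:
                    count += 1
                # The early exit is disabled here, so every k up to n is checked
                # (c_util.pyx enables it):
                # elif r > s:
                #     break
                k += 1
            j += 1
        i += 1
    return count
```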
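`util.find` checks every k up to n, but the Cython version leaves the k loop as soon as k² exceeds i² + j². The sketch below is a pure Python version with the same early exit. It can help show how much of the measured speedup comes from that break and how much comes from compilation. `find_with_break` is a hypothetical name that is not part of the original files. It also uses `for`/`range` loops, which are the more idiomatic Python form of the `while` counters.

```python
def find_with_break(n):
    """Count ordered (i, j, k) with i*i + j*j == k*k and 1 <= i, j, k <= n,
    leaving the k loop once k*k exceeds i*i + j*j (the same break as c_util.pyx)."""
    count = 0
    for i in range(1, n + 1):
        for j in range(1, n + 1):
            s = i * i + j * j          # computed once per (i, j) pair
            for k in range(1, n + 1):
                r = k * k
                if r == s:
                    count += 1
                elif r > s:
                    break              # k*k only grows, so no later k can match
    return count


print(find_with_break(10))  # 4
```

To compare, drop it into util.py and time it next to `util.find` and `c_util.find` in the caller script.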



Create a .pyx file for Cython

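```cython
# File: c_util.pyx

import cython


# Compared with util.py, this version adds:
#   - the cpdef signature with C int argument and return types
#   - the cdef declaration of the loop variables and counter
#   - an early break once k**2 exceeds i**2 + j**2


cpdef int find(int n):
    cdef int i, j, k, count
    count = 0
    i = 1
    while i <= n:
        j = 1
        while j <= n:
            k = 1
            while k <= n:
                s = i**2 + j**2
                r = k ** 2
                if r == s:
                    count += 1
                elif r > s:
                    # k**2 only grows from here, so no later k can match
                    break
                k += 1
            j += 1
        i += 1
    return count
```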



Compile the .pyx file with the easycython package. It needs gcc (or another C compiler) and produces a shared object (.so) file that Python can import as c_util.

```
$ pip install cython easycython
$ easycython c_util.pyx
```


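Run the code and compare the time taken to find the triplets.

```
$ python p_triplets.py
enter n:200
python code 200 254 4.901634931564331 s
cython code 200 254 0.006403207778930664 s
```

In this run, the compiled version is roughly 765 times faster (4.90 s vs 0.0064 s). Part of that gap is algorithmic rather than a result of compilation. The Cython loop breaks out as soon as k**2 exceeds i**2 + j**2, while util.find checks every k up to n.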



Just for fun, here is a Java version of the same algorithm, including the early break.


```java
// File: PTriplets.java

public class PTriplets {

    // Note: the loops use "< n", so n itself is excluded
    // (the Python and Cython versions use "<= n").
    public static int triplet(int n) {
        int count = 0;
        for (int i = 1; i < n; ++i) {
            for (int j = 1; j < n; ++j) {
                for (int k = 1; k < n; ++k) {
                    int k2 = k * k;
                    int s2 = i * i + j * j;
                    if (k2 == s2) {
                        count++;
                    } else if (k2 > s2) {
                        break;  // k2 only grows, so no later k can match
                    }
                }
            }
        }
        return count;
    }

    public static void main(String[] args) {
        int n = Integer.parseInt(args[0]);
        long start = System.currentTimeMillis();
        int count = triplet(n);
        long taken = System.currentTimeMillis() - start;
        System.out.println(String.format("n: %d, count: %d, taken (ms): %d", n, count, taken));
    }
}
```


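Compile and run

```
$ javac PTriplets.java
$ java PTriplets 200
n: 200, count: 250, taken (ms): 9
```

The count is 250 rather than 254 because the Java loops stop at n - 1. They therefore miss the four ordered triplets whose hypotenuse is exactly 200: (56, 192, 200), (120, 160, 200) and their swaps.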
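The snippet below is a quick check of the 250 vs 254 difference. It lists the ordered leg pairs whose hypotenuse is exactly 200. Because the Java loops only run up to n - 1, they never reach k == 200, so these are the triplets the Java count leaves out. The value 254 is the Python result from the run shown earlier.

```python
n = 200
# Ordered (i, j) pairs with i*i + j*j == n*n; a loop with k < n never tests k == n.
skipped = [(i, j) for i in range(1, n) for j in range(1, n) if i * i + j * j == n * n]
print(skipped)             # [(56, 192), (120, 160), (160, 120), (192, 56)]
print(254 - len(skipped))  # 250, matching the Java output
```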