Cython Compilation

Goal: one way to make Python code run much faster (200x or more in the run below) is to compile it with Cython. Let's see how.


Here we will create a simple program that counts the Pythagorean triplets (i, j, k), i.e. i**2 + j**2 == k**2, whose members are all at most a given number n.
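
To make the counting rule concrete, here is a minimal sketch that lists the triplets for a small n. The helper names is_triplet and triplets_up_to are made up for this illustration and are not part of the files below. It shows that (3, 4, 5) and (4, 3, 5) are counted as two separate triplets, which is also how the find functions in this post count.

```python
from itertools import product


def is_triplet(i, j, k):
    """True when i, j, k satisfy i^2 + j^2 == k^2."""
    return i * i + j * j == k * k


def triplets_up_to(n):
    """List every ordered (i, j, k) with 1 <= i, j, k <= n that is a triplet."""
    values = range(1, n + 1)
    return [(i, j, k) for i, j, k in product(values, repeat=3) if is_triplet(i, j, k)]


print(triplets_up_to(5))  # [(3, 4, 5), (4, 3, 5)]
```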


Create a caller .py file

# p_triplets.py
from time import time   
import util
import c_util

"""
Prerequisites
- gcc for C compilation

Compile .pyx file
$ pip install cython easycython
$ easycython c_util.pyx

"""

if __name__ == "__main__":
    n = int(input("enter n:"))
    start = time()
    count = util.find(n)
    taken = time() - start
    print("python code", n, count, taken, "s") 
    start = time()
    count = c_util.find(n)
    taken = time() - start
    print("cython code", n, count, taken, "s") 
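
The two timing blocks above repeat the same start/stop pattern. As a sketch only, the pattern could be factored into a small helper; timed is a hypothetical name, and time.perf_counter is used here instead of time.time because it is the standard library clock intended for measuring short intervals.

```python
from time import perf_counter


def timed(func, *args):
    """Call func(*args) and return (result, elapsed_seconds)."""
    start = perf_counter()
    result = func(*args)
    return result, perf_counter() - start


# Stand-in workload; in p_triplets.py this would be util.find or c_util.find.
result, taken = timed(sum, range(1000))
print(result, f"{taken:.6f} s")
```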

Create a file for the pure Python implementation

# File: util.py

def find(n):
    count=0
    i=1
    while i<=n:
        j = 1
        while j<=n:
            k = 1
            while k<=n:
                #print(i, j, k)
                s = i**2 + j**2
                r = k ** 2
                if r == s:
                    #print(i, j, k)
                    count += 1
                #elif r > s:
                #    break
                k += 1
            j += 1
        i += 1
    return count
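
The commented-out elif/break in util.py hints at an early exit: once k**2 exceeds i**2 + j**2, no larger k can match. Below is a sketch of the same count with that exit enabled, written with for loops. The name find_early_exit is hypothetical, and this is not the version that was timed in this post.

```python
def find_early_exit(n):
    """Count ordered triplets (i, j, k) with 1 <= i, j, k <= n and i^2 + j^2 == k^2."""
    count = 0
    for i in range(1, n + 1):
        for j in range(1, n + 1):
            s = i * i + j * j
            for k in range(1, n + 1):
                r = k * k
                if r == s:
                    count += 1
                elif r > s:
                    # k*k only grows from here, so no later k can equal s.
                    break
    return count


print(find_early_exit(20))  # 12
```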



Create a .pyx file for Cython

# file: c_util.pyx

import cython 

# Compared with util.py: the typed cpdef signature and the cdef declaration are additional,
# and the early break (elif r > s) is enabled

cpdef int find(int n):
    cdef int i, j, k, count
    
    count=0
    i=1
    while i<=n:
        j = 1
        while j<=n:
            k = 1
            while k<=n:
                #print(i, j, k)
                s = i**2 + j**2
                r = k ** 2
                if r == s:
                    count += 1
                elif r > s:
                    break
                k += 1
            j += 1
        i += 1
    return count



Compile the .pyx file using the easycython package. You will need gcc to run this utility. It generates a shared object (.so) file.

$ pip install easycython
$ easycython c_util.pyx 


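Run the code and compare the time taken to find the triplets. Note that c_util.pyx breaks out of the inner loop once k**2 exceeds i**2 + j**2, while that break is commented out in util.py, so part of the gap comes from doing less work and not only from compilation.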

$ python p_triplets.py 
enter n:200
python code 200 254 4.901634931564331 s
cython code 200 254 0.006403207778930664 s
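
The speedup for this particular run can be worked out from the two timings printed above. The variable names are chosen for this illustration only.

```python
# Timings copied from the run above, in seconds.
python_seconds = 4.901634931564331
cython_seconds = 0.006403207778930664

speedup = python_seconds / cython_seconds
print(f"speedup: {speedup:.1f}x")
```

For this run the ratio comes out at roughly 765x. Timings will vary between machines and runs.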



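Just for fun, here is a Java equivalent of the same code. Its loops stop at n - 1 (i < n rather than i <= n), so for n = 200 it skips the four ordered triplets whose largest member is exactly 200 and reports 250 instead of 254.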


// file: PTriplets.java 

public class PTriplets {
    public static int triplet(int n){
        int count = 0;
        for(int i=1;i<n;++i){
            for(int j=1;j<n;++j){
                for(int k=1;k<n;++k){
                    int k2 = k * k;
                    int s2 = i*i + j*j;
                    if(k2 == s2){
                        count++;
                    }else if(k2>s2){
                        break;
                    }
                }
            }
        }
        return count;
    }

    public static void main(String[] args){
        // n is read from the first command-line argument
        int n = Integer.parseInt(args[0]);
        long start = System.currentTimeMillis();
        int count = triplet(n);
        long taken = System.currentTimeMillis() - start;
        System.out.println(String.format("n: %d, count: %d, taken (ms): %d"
                , n, count, taken));
    }
}


Compile and run

$ javac PTriplets.java 
$ java PTriplets 200
n: 200, count: 250, taken (ms): 9