Sunday, January 17, 2010

Profiling adventures and cython - Introducing cython

In the previous blog post, I made some attempts at speeding up the function mandel() by making changes in the Python code. While I had some success in doing so, it was clearly not enough for my purpose. As a result, I will now try to use cython. Before I do this, I note again the result from the last profiling run, limiting the information to the 5 longest-running functions or methods.

       3673150 function calls in 84.807 CPU seconds

 ncalls  tottime  percall  cumtime  percall filename:lineno(function)
3168000   73.210    0.000   73.210    0.000 mandel1c.py:7(mandel)
     11   10.855    0.987   84.204    7.655 viewer1.py:23(draw_fractal)
      1    0.593    0.593    0.593    0.593 {_tkinter.create}
 504530    0.137    0.000    0.137    0.000 viewer1.py:17(draw_pixel)
     37    0.009    0.000    0.009    0.000 {built-in method call}

The goal of cython could be described as providing an easy way to convert a Python module into a C extension. This is what I will do. [There are other ways to work with cython extensions than what I use here; for more information, please consult the cython web site.] Note that I am NOT a cython expert; this is only the first project for which I use cython. Since I am not interested in creating an application for distribution, I do not use the setup method for cython; it is quite possible that there are better ways to use cython than what I explore here.
I first start by taking my existing module and copying it into a new file, with a ".pyx" extension instead of the traditional ".py".

# mandel2_cy.pyx
# cython: profile=True

def mandel(c, max_iterations=20):
    '''determines if a point is in the Mandelbrot set based on deciding if,
       after a maximum allowed number of iterations, the absolute value of
       the resulting number is greater or equal to 2.'''
    z = 0
    for i in range(0, max_iterations):
        z = z**2 + c
        if abs(z) >= 2:
            return False
    return abs(z) < 2

Note that I have removed the equivalence between range and xrange. I have done this because having xrange present like this in the file results in a compilation error when running cython with Python 3.1. Furthermore, as will be seen later, it is not really needed even for Python 2.x when using cython properly.
I have also included a commented line stating that 'profile' was equal to True; this is a cython directive that will enable the Python profiler to also include cython functions in its tally.

In order to import this module, I also need to modify the viewer to import the cython module. Here is the new version.
# viewer2.py

import pyximport
pyximport.install()

from mandel2_cy import mandel
from viewer import Viewer
import time

import sys
if sys.version_info < (3,):
    import Tkinter as tk
    range = xrange
else:
    import tkinter as tk

class FancyViewer(Viewer):
    '''Application to display fractals'''

    def draw_pixel(self, x, y):
        '''Simulates drawing a given pixel in black by drawing a black line
           of length equal to one pixel.'''
        return
        #self.canvas.create_line(x, y, x+1, y, fill="black")

    def draw_fractal(self):
        '''draws a fractal on the canvas'''
        self.calculating = True
        begin = time.time()
        # clear the canvas
        self.canvas.create_rectangle(0, 0, self.canvas_width,
                                    self.canvas_height, fill="white")
        for x in range(0, self.canvas_width):
            real = self.min_x + x*self.pixel_size
            for y in range(0, self.canvas_height):
                imag = self.min_y + y*self.pixel_size
                c = complex(real, imag)
                if mandel(c, self.nb_iterations):
                    self.draw_pixel(x, self.canvas_height - y)
        self.status.config(text="Time required = %.2f s  [%s iterations]  %s" %(
                                (time.time() - begin), self.nb_iterations,
                                                                self.zoom_info))
        self.status2.config(text=self.info())
        self.calculating = False

if __name__ == "__main__":
    root = tk.Tk()
    app = FancyViewer(root)
    root.mainloop()

Other than the top few lines, nothing has changed. Time to run the profiler with this new version.
       6841793 function calls in 50.145 CPU seconds

 ncalls  tottime  percall  cumtime  percall filename:lineno(function)
3168000   35.913    0.000   35.913    0.000 mandel2_cy.pyx:4(mandel)
     11   10.670    0.970   48.754    4.432 viewer2.py:26(draw_fractal)
3168000    2.001    0.000   37.914    0.000 {mandel2_cy.mandel}
      1    1.356    1.356    1.356    1.356 {_tkinter.create}
 505173    0.167    0.000    0.167    0.000 viewer2.py:20(draw_pixel)

A reduction from 85 to 50 seconds; cython must be doing something right! Note that the calls to abs() have been eliminated by using cython. All I did is import the module via Python without making any other change to the code.
Note also that mandel appears twice: once (the longest running) as the function defined on line 4 of mandel2_cy.pyx, and once as an object belonging to the module mandel2_cy. I will come back to this later but, for now, I will make some changes to help cython do even better.
As mentioned before, cython is a tool to help create C extensions. One of the differences between C and Python is that variables have a declared type in C. If one tells cython what type a given variable is, cython can often use that information to make the code run faster. As an example, I know that two of the variables are integers, which is a native C type; I can add this information as follows.
# mandel2a_cy.pyx
# cython: profile=True

def mandel(c, int max_iterations=20):
    '''determines if a point is in the Mandelbrot set based on deciding if,
       after a maximum allowed number of iterations, the absolute value of
       the resulting number is greater or equal to 2.'''
    cdef int i
    z = 0
    for i in range(0, max_iterations):
        z = z**2 + c
        if abs(z) >= 2:
            return False
    return abs(z) < 2

Running the profiler with this change yields the following:
       6841793 function calls in 39.860 CPU seconds

 ncalls  tottime  percall  cumtime  percall filename:lineno(function)
3168000   27.431    0.000   27.431    0.000 mandel2a_cy.pyx:4(mandel)
     11    9.869    0.897   39.339    3.576 viewer2.py:26(draw_fractal)
3168000    1.906    0.000   29.337    0.000 {mandel2a_cy.mandel}
      1    0.511    0.511    0.511    0.511 {_tkinter.create}
 505173    0.131    0.000    0.131    0.000 viewer2.py:20(draw_pixel)

Another significant time reduction, this time of the order of 20%. And we didn't tell cython that "z" and "c" are complex yet.

Actually, standard C did not have a complex data type before C99, and I do not rely on one here. So, I can choose one of two strategies:
  1. I can change the code so that I deal only with real numbers, by working out myself how to multiply and add complex numbers.
  2. I can use some special cython technique to extract all the relevant information about the Python built-in complex data type without changing the code inside the function (other than adding some type declaration).
I will choose the second of these methods and see what it gives. The required changes are as follows:
# mandel2b_cy.pyx
# cython: profile=True

cdef extern from "complexobject.h":

    struct Py_complex:
        double real
        double imag

    ctypedef class __builtin__.complex [object PyComplexObject]:
        cdef Py_complex cval


def mandel(complex c, int max_iterations=20):
    '''determines if a point is in the Mandelbrot set based on deciding if,
       after a maximum allowed number of iterations, the absolute value of
       the resulting number is greater or equal to 2.'''
    cdef int i
    cdef complex z

    z = 0. + 0.j

    for i in range(0, max_iterations):
        z = z**2 + c
        if abs(z) >= 2:
            return False
    return abs(z) < 2

The timing results are the following:
       6841793 function calls in 38.424 CPU seconds

 ncalls  tottime  percall  cumtime  percall filename:lineno(function)
3168000   26.771    0.000   26.771    0.000 mandel2b_cy.pyx:14(mandel)
     11    9.435    0.858   38.209    3.474 viewer2.py:26(draw_fractal)
3168000    1.865    0.000   28.636    0.000 {mandel2b_cy.mandel}
      1    0.205    0.205    0.205    0.205 {_tkinter.create}
 505173    0.136    0.000    0.136    0.000 viewer2.py:20(draw_pixel)

The time difference between this run and the previous one is within the variation I observe from one profiling run to the next (using exactly the same program). Therefore, I conclude that this latest attempt didn't speed up the code. It is possible that I have overlooked something to ensure that cython could make use of the information about the complex datatype more efficiently ... It seems like I need a different strategy. I will resort to doing the complex algebra myself, and work only with real numbers. Here's the modified code for the mandel module.
# mandel2c_cy.pyx
# cython: profile=True

def mandel(double real, double imag, int max_iterations=20):
    '''determines if a point is in the Mandelbrot set based on deciding if,
       after a maximum allowed number of iterations, the absolute value of
       the resulting number is greater or equal to 2.'''
    cdef double z_real = 0., z_imag = 0.
    cdef int i

    for i in range(0, max_iterations):
        z_real, z_imag = ( z_real*z_real - z_imag*z_imag + real,
                     2*z_real*z_imag + imag )
        if (z_real*z_real + z_imag*z_imag) >= 4:
            return False
    return (z_real*z_real + z_imag*z_imag) < 4
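
The update used in the loop comes from expanding (a + bi)^2 + (x + yi) = (a^2 - b^2 + x) + (2ab + y)i. As a quick sanity check of that algebra, here is a small pure-Python sketch (not part of the original modules; the function names are made up for illustration) that compares the hand-written real arithmetic with Python's built-in complex type over a few iterations.

```python
def step_complex(z, c):
    """One Mandelbrot iteration using Python's complex type."""
    return z * z + c


def step_real(z_real, z_imag, real, imag):
    """The same iteration written with real numbers only."""
    return (z_real * z_real - z_imag * z_imag + real,
            2 * z_real * z_imag + imag)


def max_difference(c, iterations=10):
    """Largest gap between the two versions over a few iterations."""
    z = 0j
    z_real, z_imag = 0.0, 0.0
    worst = 0.0
    for _ in range(iterations):
        z = step_complex(z, c)
        z_real, z_imag = step_real(z_real, z_imag, c.real, c.imag)
        worst = max(worst, abs(z.real - z_real), abs(z.imag - z_imag))
    return worst


# A point that stays bounded, so the comparison is not swamped by huge values.
print(max_difference(complex(-0.5, 0.5)))
```

Any difference printed should be at the level of floating-point rounding, which suggests the two formulations can be swapped freely.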

I also change the call within draw_fractal() so that I don't use complex variables. The result is extremely encouraging:
       6841793 function calls in 7.205 CPU seconds

 ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     11    4.379    0.398    7.066    0.642 viewer2a.py:26(draw_fractal)
3168000    1.557    0.000    2.570    0.000 {mandel2c_cy.mandel}
3168000    1.013    0.000    1.013    0.000 mandel2c_cy.pyx:4(mandel)
      1    0.130    0.130    0.130    0.130 {_tkinter.create}
 505173    0.114    0.000    0.114    0.000 viewer2a.py:20(draw_pixel)

The total execution time has been reduced from 38 to 7 seconds. mandel() is no longer the largest contributor to the overall execution time; draw_fractal() is. The program is still a bit too slow: without actually doing any drawing, it takes approximately 0.6 seconds to generate one fractal image. However, I can do better. Looking at the code, I notice that draw_fractal() contains two nested for loops, resulting in all those calls to mandel(). Remember how telling cython about integer types used in loops sped up the code? This suggests that perhaps I should do something similar and move some of the code of draw_fractal() to the cython module. Here's a modified viewer module.
# viewer2b.py

import pyximport
pyximport.install()

from mandel2d_cy import create_fractal
from viewer import Viewer
import time

import sys
if sys.version_info < (3,):
    import Tkinter as tk
    range = xrange
else:
    import tkinter as tk

class FancyViewer(Viewer):
    '''Application to display fractals'''

    def draw_fractal(self):
        '''draws a fractal on the canvas'''
        self.calculating = True
        begin = time.time()
        # clear the canvas
        self.canvas.create_rectangle(0, 0, self.canvas_width,
                                    self.canvas_height, fill="white")
        create_fractal(self.canvas_width, self.canvas_height,
                       self.min_x, self.min_y, self.pixel_size,
                       self.nb_iterations, self.canvas)
        self.status.config(text="Time required = %.2f s  [%s iterations]  %s" %(
                                (time.time() - begin), self.nb_iterations,
                                                                self.zoom_info))
        self.status2.config(text=self.info())
        self.calculating = False

if __name__ == "__main__":
    root = tk.Tk()
    app = FancyViewer(root)
    root.mainloop()

And here is the new cython module, without any additional type declaration.
# mandel2d_cy.pyx
# cython: profile=True

def mandel(double real, double imag, int max_iterations=20):
    '''determines if a point is in the Mandelbrot set based on deciding if,
       after a maximum allowed number of iterations, the absolute value of
       the resulting number is greater or equal to 2.'''
    cdef double z_real = 0., z_imag = 0.
    cdef int i

    for i in range(0, max_iterations):
        z_real, z_imag = ( z_real*z_real - z_imag*z_imag + real,
                     2*z_real*z_imag + imag )
        if (z_real*z_real + z_imag*z_imag) >= 4:
            return False
    return (z_real*z_real + z_imag*z_imag) < 4

def draw_pixel(x, y, canvas):
    '''Simulates drawing a given pixel in black by drawing a black line
       of length equal to one pixel.'''
    return
    #canvas.create_line(x, y, x+1, y, fill="black")

def create_fractal(canvas_width, canvas_height,
                       min_x, min_y, pixel_size,
                       nb_iterations, canvas):
    for x in range(0, canvas_width):
        real = min_x + x*pixel_size
        for y in range(0, canvas_height):
            imag = min_y + y*pixel_size
            if mandel(real, imag, nb_iterations):
                draw_pixel(x, canvas_height - y, canvas)

The profiling result is as follows:
       3673815 function calls in 3.873 CPU seconds

 ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     11    2.632    0.239    3.706    0.337 mandel2d_cy.pyx:24(create_fractal)
3168000    1.002    0.000    1.002    0.000 mandel2d_cy.pyx:4(mandel)
      1    0.155    0.155    0.155    0.155 {_tkinter.create}
 505173    0.072    0.000    0.072    0.000 mandel2d_cy.pyx:18(draw_pixel)
     37    0.009    0.000    0.009    0.000 {built-in method call}

Simply by moving over some of the code to the cython module, I have reduced the profiling time to almost half of its previous value. Looking more closely at the profiling results, I also notice that calls to mandel() now only appear once; some overhead in calling cython functions from python modules has disappeared. Let's see what happens if I now add some type information.
def create_fractal(int canvas_width, int canvas_height,
                       double min_x, double min_y, double pixel_size,
                       int nb_iterations, canvas):
    cdef int x, y
    cdef double real, imag

    for x in range(0, canvas_width):
        real = min_x + x*pixel_size
        for y in range(0, canvas_height):
            imag = min_y + y*pixel_size
            if mandel(real, imag, nb_iterations):
                draw_pixel(x, canvas_height - y, canvas)

The result is only slightly better:
       3673815 function calls in 3.475 CPU seconds

 ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     11    2.189    0.199    3.308    0.301 mandel2e_cy.pyx:24(create_fractal)
3168000    1.046    0.000    1.046    0.000 mandel2e_cy.pyx:4(mandel)
      1    0.135    0.135    0.135    0.135 {_tkinter.create}
 505173    0.074    0.000    0.074    0.000 mandel2e_cy.pyx:18(draw_pixel)
     37    0.028    0.001    0.028    0.001 {built-in method call}

However, one thing I remember from the little I know about C is that, not only do variables have to be declared to be of a certain type, but the same has to be done to functions as well. Here, mandel() has not been declared to be of a specific type, so cython assumes it to be a generic Python object. After reading the cython documentation, and noticing that mandel() is only called from within the cython module, I conclude that not only should I specify the type for mandel() but that it probably makes sense to specify that it can be "inlined"; I also do the same for draw_pixel().
# mandel2f_cy.pyx
# cython: profile=True

cdef inline bint mandel(double real, double imag, int max_iterations=20):
    '''determines if a point is in the Mandelbrot set based on deciding if,
       after a maximum allowed number of iterations, the absolute value of
       the resulting number is greater or equal to 2.'''
    cdef double z_real = 0., z_imag = 0.
    cdef int i

    for i in range(0, max_iterations):
        z_real, z_imag = ( z_real*z_real - z_imag*z_imag + real,
                     2*z_real*z_imag + imag )
        if (z_real*z_real + z_imag*z_imag) >= 4:
            return False
    return (z_real*z_real + z_imag*z_imag) < 4

cdef inline void draw_pixel(x, y, canvas):
    '''Simulates drawing a given pixel in black by drawing a black line
       of length equal to one pixel.'''
    return
    #canvas.create_line(x, y, x+1, y, fill="black")

This yields a nice improvement.
       3673815 function calls in 2.333 CPU seconds

 ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     11    1.190    0.108    2.194    0.199 mandel2f_cy.pyx:24(create_fractal)
3168000    0.930    0.000    0.930    0.000 mandel2f_cy.pyx:4(mandel)
      1    0.127    0.127    0.127    0.127 {_tkinter.create}
 505173    0.074    0.000    0.074    0.000 mandel2f_cy.pyx:18(draw_pixel)
     37    0.009    0.000    0.009    0.000 {built-in method call}

However... I asked cython to "inline" mandel() and draw_pixel(), thus treating them as pure C functions. Yet, they both appear in the Python profiling information, which was not the case for abs() once I used cython for the first time. The reason they appear is that cython has been instructed to profile all functions in the module, via the directive at the top. I can selectively turn off the profiling for an individual function by importing the "cython module" and using a special purpose decorator as follows.
# mandel2g_cy.pyx
# cython: profile=True

import cython

@cython.profile(False)
cdef inline bint mandel(double real, double imag, int max_iterations=20):
    '''determines if a point is in the Mandelbrot set based on deciding if,
       after a maximum allowed number of iterations, the absolute value of
       the resulting number is greater or equal to 2.'''
    cdef double z_real = 0., z_imag = 0.
    cdef int i

    for i in range(0, max_iterations):
        z_real, z_imag = ( z_real*z_real - z_imag*z_imag + real,
                     2*z_real*z_imag + imag )
        if (z_real*z_real + z_imag*z_imag) >= 4:
            return False
    return (z_real*z_real + z_imag*z_imag) < 4

cdef inline void draw_pixel(x, y, canvas):
    '''Simulates drawing a given pixel in black by drawing a black line
       of length equal to one pixel.'''
    return
    #canvas.create_line(x, y, x+1, y, fill="black")

The result is even better than I would have expected!
      505519 function calls in 0.817 CPU seconds

ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    11    0.605    0.055    0.676    0.061 mandel2g_cy.pyx:27(create_fractal)
     1    0.128    0.128    0.128    0.128 {_tkinter.create}
504877    0.070    0.000    0.070    0.000 mandel2g_cy.pyx:21(draw_pixel)
    37    0.010    0.000    0.010    0.000 {built-in method call}
    11    0.001    0.000    0.678    0.062 viewer2b.py:20(draw_fractal)

From 85 seconds (at the beginning of this post) down to 0.8 seconds: a reduction by a factor of 100 ...thank you cython!  :-)
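
For the record, the speed-up at each step can be worked out from the totals reported by the profiler in this post. The short script below is only a convenience for doing that arithmetic; the labels are my own shorthand, and the numbers are copied from the profiling runs above.

```python
# Total CPU seconds reported by each profiling run in this post (100 iterations).
timings = [
    ("mandel1c.py", 84.807),   # pure Python baseline
    ("mandel2_cy", 50.145),    # same code compiled by cython
    ("mandel2a_cy", 39.860),   # int declarations added
    ("mandel2b_cy", 38.424),   # complex type declared
    ("mandel2c_cy", 7.205),    # real arithmetic only
    ("mandel2d_cy", 3.873),    # loops moved into the cython module
    ("mandel2e_cy", 3.475),    # loop variables typed
    ("mandel2f_cy", 2.333),    # cdef inline functions
    ("mandel2g_cy", 0.817),    # profiling of mandel() turned off
]

baseline = timings[0][1]
for name, seconds in timings:
    print("%-12s %8.3f s   speed-up x%.1f" % (name, seconds, baseline / seconds))

overall = baseline / timings[-1][1]
print("overall speed-up: x%.1f" % overall)
```

Keep in mind that the last step also removes the profiler's own bookkeeping for 3168000 calls, so part of that final gain is reduced measurement overhead rather than faster code.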

However, increasing the number of iterations to 1000 (from the current value of 100 used for testing) does increase the time significantly.
      495235 function calls in 3.872 CPU seconds

ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    11    3.653    0.332    3.723    0.338 mandel2g_cy.pyx:27(create_fractal)
     1    0.136    0.136    0.136    0.136 {_tkinter.create}
494593    0.071    0.000    0.071    0.000 mandel2g_cy.pyx:21(draw_pixel)
    37    0.009    0.000    0.009    0.000 {built-in method call}
    11    0.001    0.000    3.726    0.339 viewer2b.py:20(draw_fractal)

It is probably a good time to put back the drawing to see what the overall time profile looks like in a more realistic situation.
      5441165 function calls in 20.747 CPU seconds

ncalls  tottime  percall  cumtime  percall filename:lineno(function)
494604    8.682    0.000   14.427    0.000 Tkinter.py:2135(_create)
    11    3.863    0.351   20.572    1.870 mandel2g_cy.pyx:27(create_fractal)
494632    2.043    0.000    5.326    0.000 Tkinter.py:1046(_options)
494657    1.845    0.000    2.861    0.000 Tkinter.py:77(_cnfmerge)
494593    1.548    0.000   16.709    0.000 mandel2g_cy.pyx:21(draw_pixel)

Clearly, the limiting time factor is now the Tkinter-based drawing, and not the other code. It is time to think of a better drawing strategy. However, this will have to wait until the next post.
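
To put a number on that last statement, here is the arithmetic done on the figures from the final profiling run above (1000 iterations, drawing enabled). Nothing here is measured anew; the variable names are just labels for values copied from the table.

```python
# Figures copied from the last profiling run (1000 iterations, drawing enabled).
total_seconds = 20.747        # whole profiling run
draw_pixel_cumtime = 16.709   # draw_pixel() including the Tkinter calls it makes
draw_pixel_calls = 494593     # number of pixels drawn

share = draw_pixel_cumtime / total_seconds
per_pixel_us = draw_pixel_cumtime / draw_pixel_calls * 1e6

print("share of the run spent drawing pixels: %.1f%%" % (100 * share))
print("average cost per pixel drawn: %.1f microseconds" % per_pixel_us)
```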

Saturday, January 16, 2010

Profiling adventures [and cython]: basic profiling

In the previous blog post, I introduced a simple Tkinter-based viewer for the Mandelbrot set. As I mentioned at that time, the viewer was really too slow to be usable. In this post, I will start doing some basic profiling and start looking for some strategies designed to make it faster.


The first rule for making an application faster is to do a proper profile rather than guessing. I make use of the cProfile module, focusing on the main method (draw_fractal()) which I wish to make faster, and taking a closer look only at the most time-consuming functions/methods.


# profile1.py

import pstats
import cProfile

from viewer1 import tk, FancyViewer

def main():
    root = tk.Tk()
    app = FancyViewer(root)
    app.nb_iterations = 100
    for i in range(10):
        app.draw_fractal()

if __name__ == "__main__":
    cProfile.run("main()", "Profile.prof")
    s = pstats.Stats("Profile.prof")
    s.strip_dirs().sort_stats("time").print_stats(10)

The profile run will call draw_fractal() once, when app is created, with the number of iterations for mandel() set at 20 (the default) and then call it again 10 times with a larger number of iterations. Running the profiler adds some overhead. Based on the previous run without the profiler, I would have expected a profiler run to take approximately 75 seconds: a little over 4 seconds for the initial set up and slightly more than 7 seconds for each of the subsequent runs. Instead, what I observe is the following:


69802205 function calls in 115.015 CPU seconds

  Ordered by: internal time
  List reduced from 60 to 10 due to restriction <10>

  ncalls  tottime  percall  cumtime  percall filename:lineno(function)
 3168000   59.908    0.000   81.192    0.000 mandel1a.py:3(mandel)
57908682   14.005    0.000   14.005    0.000 {abs}
      11   13.387    1.217  114.484   10.408 viewer1.py:23(draw_fractal)
  505184   10.950    0.000   17.672    0.000 Tkinter.py:2135(_create)
 3168001    7.279    0.000    7.279    0.000 {range}
  505212    2.359    0.000    6.220    0.000 Tkinter.py:1046(_options)
  505237    2.169    0.000    3.389    0.000 Tkinter.py:77(_cnfmerge)
  505173    1.457    0.000   19.902    0.000 viewer1.py:17(draw_pixel)
 1010400    0.906    0.000    0.906    0.000 {method 'update' of 'dict' objects}
 1010417    0.817    0.000    0.817    0.000 {_tkinter._flatten}

Clearly, running the profiler adds some overhead. I should also add that there are variations from run to run done with the profiler, caused by background activities. As a consequence, I normally run the profiler 3 times and focus on the fastest of the three runs; however I will not bother to do this here: I simply want to start by establishing some rough baseline to identify the main contributors to the relative lack of speed of this program.
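
The "best of three" habit mentioned above can be sketched with the standard timeit module. Note that this swaps in timeit for the profiler, and workload() is a made-up stand-in for whatever is being measured (for example a call to draw_fractal()); it is only meant to illustrate the idea of keeping the fastest run.

```python
import timeit


def workload():
    """Hypothetical stand-in for the code being timed."""
    total = 0
    for i in range(100000):
        total += i * i
    return total


# Repeat the same measurement three times; each run calls workload() once.
runs = timeit.repeat(workload, repeat=3, number=1)

# Background activity can only slow a run down, so the fastest run is
# usually the most representative one.
best = min(runs)

print("runs: " + ", ".join("%.4f s" % r for r in runs))
print("best of 3: %.4f s" % best)
```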


It appears clear that the largest contributor to the overall execution time is mandel(). Going down the lists of functions that contribute significantly to the overall time, I notice quite a few calls to Tkinter function/methods. So as to reduce the time to take a given profile, and to focus on mandel(), I will temporarily eliminate some Tkinter calls by changing draw_pixel() as follows.


def draw_pixel(self, x, y):
    '''Simulates drawing a given pixel in black by drawing a black line
       of length equal to one pixel.'''
    return
    #self.canvas.create_line(x, y, x+1, y, fill="black")

Also, since I want to establish a rough baseline, I should probably see what happens when I increase the number of iterations from 100 to 1000 for mandel(), which is what I expect to have to use in many cases to get accurate pictures. I do this first using Python 2.5


465659765 function calls in 574.947 CPU seconds

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
  3168000  419.296    0.000  561.467    0.000 mandel1a.py:3(mandel)
458828552  100.464    0.000  100.464    0.000 {abs}
  3168001   41.707    0.000   41.707    0.000 {range}
       11   12.731    1.157  574.341   52.213 viewer1.py:23(draw_fractal)
        1    0.596    0.596    0.596    0.596 {_tkinter.create}
   494593    0.140    0.000    0.140    0.000 viewer1.py:17(draw_pixel)
       37    0.010    0.000    0.010    0.000 {built-in method call}
       11    0.000    0.000    0.001    0.000 Tkinter.py:2135(_create)
       54    0.000    0.000    0.000    0.000 {method 'update' of 'dict' objects}
       39    0.000    0.000    0.000    0.000 Tkinter.py:1046(_options)

Ouch! Close to 10 minutes of running time. However, it is clear that I have accomplished my goal of reducing the importance of Tkinter calls so that I can focus on my own code. Let's repeat this profiling test using Python 3.1.


462491974 function calls (462491973 primitive calls) in 506.148 CPU seconds

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
  3168000  386.169    0.000  493.823    0.000 mandel1a.py:3(mandel)
458828552  107.654    0.000  107.654    0.000 {built-in method abs}
       11   11.474    1.043  505.439   45.949 viewer1.py:23(draw_fractal)
        1    0.694    0.694    0.694    0.694 {built-in method create}
   494593    0.140    0.000    0.140    0.000 viewer1.py:17(draw_pixel)
       48    0.014    0.000    0.014    0.000 {method 'call' of 'tkapp' objects}
       39    0.000    0.000    0.001    0.000 __init__.py:1032(_options)
       64    0.000    0.000    0.001    0.000 __init__.py:66(_cnfmerge)
      206    0.000    0.000    0.000    0.000 {built-in method isinstance}
      2/1    0.000    0.000  506.148  506.148 {built-in method exec}

We note that the total time taken is significantly less. Doing a comparison function by function, two significant differences appear. The built-in function abs is 7% slower with Python 3.1, which is a bit disappointing. On the other hand, range no longer appears as a function in Python 3.1; this appears to be the main contributor to the significant decrease in time when using Python 3.1 as compared with Python 2.5. This is easily understood: range in Python 3.1 does not create a list like it did in Python 2.x; it is rather like the old xrange. This suggests that I modify mandel1a.py to be as follows:


# mandel1b.py

import sys
if sys.version_info < (3,):
    range = xrange

def mandel(c, max_iterations=20):
    '''determines if a point is in the Mandelbrot set based on deciding if,
       after a maximum allowed number of iterations, the absolute value of
       the resulting number is greater or equal to 2.'''
    z = 0
    for iter in range(0, max_iterations):
        z = z**2 + c
        if abs(z) >= 2:
            return False
    return abs(z) < 2
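
To see why a list-building range costs more than a lazy one, the following sketch compares the two under Python 3, where range already behaves like the old xrange. It assumes Python 3 only; list(range(n)) is used here to mimic what range(n) returned in Python 2.x.

```python
import sys

n = 1000000

lazy = range(n)          # Python 3: a small object, no list is built
eager = list(range(n))   # mimics what Python 2's range(n) returned

lazy_size = sys.getsizeof(lazy)
eager_size = sys.getsizeof(eager)

print("range object: %d bytes" % lazy_size)
print("full list:    %d bytes (not counting the integers themselves)" % eager_size)
```

Building such a list on every one of the 3168000 calls to mandel() is work that the lazy version simply avoids.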

I also make a similar change to viewer1.py. From now on, except where otherwise noted, I will focus on using only Python 2.5. So, after making this change, we can run the profiler one more time with the same number of iterations. The result is as follows:
462491765 function calls in 503.926 CPU seconds

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
  3168000  391.297    0.000  491.642    0.000 mandel1b.py:7(mandel)
458828552  100.345    0.000  100.345    0.000 {abs}
       11   11.409    1.037  503.189   45.744 viewer1.py:23(draw_fractal)
        1    0.726    0.726    0.726    0.726 {_tkinter.create}
   494593    0.136    0.000    0.136    0.000 viewer1.py:17(draw_pixel)
       37    0.009    0.000    0.009    0.000 {built-in method call}
       11    0.000    0.000    0.001    0.000 Tkinter.py:2135(_create)
       54    0.000    0.000    0.000    0.000 {method 'update' of 'dict' objects}
       39    0.000    0.000    0.000    0.000 Tkinter.py:1046(_options)
       64    0.000    0.000    0.001    0.000 Tkinter.py:77(_cnfmerge)

This is now approximately the same as what we had for Python 3.1, as expected. Moving down the list of time-consuming functions, we note that abs appears to be another function we should look at. Let's first reduce the number of iterations inside mandel to 100, so that a profiling run does not take as long but proper attention is still focused on mandel as well as abs. Here's the result from a typical run to use as a new baseline:

61582475 function calls in 81.998 CPU seconds

  ncalls  tottime  percall  cumtime  percall filename:lineno(function)
 3168000   55.347    0.000   69.054    0.000 mandel1b.py:7(mandel)
57908682   13.706    0.000   13.706    0.000 {abs}
      11   11.591    1.054   80.773    7.343 viewer1.py:23(draw_fractal)
       1    1.213    1.213    1.213    1.213 {_tkinter.create}
  505173    0.125    0.000    0.125    0.000 viewer1.py:17(draw_pixel)
      37    0.011    0.000    0.011    0.000 {built-in method call}
      11    0.000    0.000    0.001    0.000 Tkinter.py:2135(_create)
       4    0.000    0.000    0.000    0.000 {posix.stat}
       3    0.000    0.000    0.000    0.000 Tkinter.py:1892(_setup)
      39    0.000    0.000    0.000    0.000 Tkinter.py:1046(_options)

Since a fair bit of time is spent inside abs(), perhaps I could speed things up by using another method. The way that we approximate the Mandelbrot set is by iterating a number of times and checking if the absolute value of the complex number is greater than 2; if it is, then it can be proven that subsequent iterations will yield larger and larger values, which means that the number we are considering is not in the Mandelbrot set. Since taking the absolute value of a complex number involves taking a square root, perhaps we can speed things up by not taking the square root. Let's implement this and try it out.

# mandel1c.py

import sys
if sys.version_info < (3,):
    range = xrange

def mandel(c, max_iterations=20):
    '''determines if a point is in the Mandelbrot set based on deciding if,
       after a maximum allowed number of iterations, the absolute value of
       the resulting number is greater or equal to 2.'''
    z = 0
    for iter in range(0, max_iterations):
        z = z**2 + c
        z_sq = z.real**2 + z.imag**2
        if z_sq >= 4:
            return False
    return z_sq < 4

Running the profiler with this version gives the following:

3673150 function calls in 84.807 CPU seconds

 ncalls  tottime  percall  cumtime  percall filename:lineno(function)
3168000   73.210    0.000   73.210    0.000 mandel1c.py:7(mandel)
     11   10.855    0.987   84.204    7.655 viewer1.py:23(draw_fractal)
      1    0.593    0.593    0.593    0.593 {_tkinter.create}
 504530    0.137    0.000    0.137    0.000 viewer1.py:17(draw_pixel)
     37    0.009    0.000    0.009    0.000 {built-in method call}
     11    0.000    0.000    0.001    0.000 Tkinter.py:2135(_create)
     54    0.000    0.000    0.000    0.000 {method 'update' of 'dict' objects}
     39    0.000    0.000    0.000    0.000 Tkinter.py:1046(_options)
     64    0.000    0.000    0.001    0.000 Tkinter.py:77(_cnfmerge)
     22    0.000    0.000    0.001    0.000 Tkinter.py:1172(_configure)

The result is worse than before, even though the total number of function calls has been cut by a factor of almost 17! Actually, this should not come as a total surprise: abs is a Python built-in function, which has already been optimized in C. Extracting the real and imaginary parts explicitly like we have done is bound to be a time-consuming operation when performed in pure Python as opposed to C. At this point, we might be tempted to convert complex numbers everywhere into pairs of real numbers so as to reduce the overhead of dealing with complex numbers ... but this would not have any significant effect on the overall time. [Those curious may want to try ... I've done it and it's not worth reporting in detail.]
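
Readers who want to check this on their own machine can time the two tests in isolation. The sketch below uses the standard timeit module rather than the profiler, and the function names are made up for illustration. Timings depend heavily on the Python version and the machine, so no particular outcome is claimed here; the only thing the code guarantees is that both tests give the same verdict.

```python
import timeit

z = complex(0.3, -0.4)


def using_abs():
    """Escape test with the built-in abs(), which takes a square root in C."""
    return abs(z) >= 2


def using_squares():
    """Escape test without the square root, done with Python-level arithmetic."""
    return z.real**2 + z.imag**2 >= 4


def same_verdict(point):
    """True if both forms of the test agree for the given complex number."""
    return (abs(point) >= 2) == (point.real**2 + point.imag**2 >= 4)


t_abs = timeit.timeit(using_abs, number=200000)
t_sq = timeit.timeit(using_squares, number=200000)

print("abs(z) >= 2:           %.3f s" % t_abs)
print("real**2 + imag**2 >= 4: %.3f s" % t_sq)
```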
 
Clearly, I need a different strategy if I am to reduce the execution time significantly. It is time to introduce cython. However, this will have to wait until the next blog post!

Friday, January 15, 2010

Profiling adventures and cython - setting the stage

Summary: This post is the first in a series dedicated to
examining the use of profiling and, eventually, of cython, as a
means to greatly improve the speed of an application. The intended
audience is programmers who have never done any profiling and/or
never used cython before. Note that we will not make use of cython
until the third post in this series.


Preamble

Python is a great multi-purpose language which is really fun to use. However, it is too slow for some applications. Since I only program for fun, I had never really faced a situation where I found Python's speed to be truly a limiting factor - at least, not until a few weeks ago when I did some exploration of a four-colouring grid problem I talked about. I started exploring ways to speed things up using only Python and trying to come up with different algorithms, but every one I tried was just too slow. So, I decided it was time to take the plunge and do something different. After considering various alternatives, like using shedskin or attempting to write a C extension (which I found too daunting since I don't know C), I decided to try to use cython.

cython, for those that don't know it, is a Python look-alike language that claims
to make writing C extensions for the Python language
as easy as Python itself.

After looking at a few examples on the web, I concluded that such a rather bold statement might very well be true and that it was worthwhile trying it out on a more complex example. Furthermore, I thought it might be of interest to record what I've done in a series of blog posts, as a more detailed example than what I had found so far on the web. As I was wondering if an esoteric problem like the four-colouring grid challenge mentioned previously was a good candidate to use as an example, by sheer serendipity, I came across a link on reddit by a new programmer about his simple Mandelbrot viewer.

Who does not like fractals? ... Furthermore, I have never written a fractal viewer. This seemed like a good time to write one. So, my goal at the end of this series of posts is to have a "nice" (for some definition of "nice") fractal viewer that is fast enough for explorations of the Mandelbrot set. In addition, in order to make it easy for anyone having Python on their system to follow along and try their own variation, I decided to stick to the following constraints:
  • With the exception of cython, I will only use modules found in the
    standard library. This means using Tkinter for the GUI.
  • The program should work using Python 2.5+ (including Python 3).

So, without further ado, and based on the example found on the reddit link I mentioned, here's a very basic fractal viewer that can be used as a starting point.


''' mandel1.py

Mandelbrot set drawn in black and white.'''

import time

import sys
if sys.version_info < (3,):
    import Tkinter as tk
else:
    import tkinter as tk

def mandel(c):
    '''determines if a point is in the Mandelbrot set based on deciding if,
       after 20 iterations, the absolute value of the resulting number is
       greater or equal to 2.'''
    z = 0
    for iter in range(0, 20):
        z = z**2 + c
        if abs(z) >= 2:
            return False
    return abs(z) < 2

class Viewer(object):
    def __init__(self, parent, width=500, height=500):
        self.canvas = tk.Canvas(parent, width=width, height=height)
        self.width = width
        self.height = height

        # the following "shift" variables are used to center the drawing
        self.shift_x = 0.5*self.width
        self.shift_y = 0.5*self.height
        self.scale = 0.01

        self.canvas.pack()
        self.draw_fractal()

    def draw_pixel(self, x, y):
        '''Simulates drawing a given pixel in black by drawing a black line
           of length equal to one pixel.'''
        self.canvas.create_line(x, y, x+1, y, fill="black")

    def draw_fractal(self):
        '''draws a fractal picture'''
        begin = time.time()
        print("Inside draw_fractal.")
        for x in range(0, self.width):
            real = (x - self.shift_x)*self.scale
            for y in range(0, self.height):
                imag = (y - self.shift_y)*self.scale
                c = complex(real, imag)
                if mandel(c):
                    self.draw_pixel(x, self.height - y)
        print("Time taken for calculating and drawing = %s" %
                                                (time.time() - begin))

if __name__ == "__main__":
    print("Starting...")
    root = tk.Tk()
    app = Viewer(root)
    root.mainloop()

At this point, perhaps a few comments about the program might be useful:
  1. I have tried to write the code in the most straightforward and pythonic way, with no thought given to making calculations fast. It should be remembered that this is just a starting point: first we make it work, then, if needed, we make it fast.
  2. The function mandel() is the simplest translation of the Mandelbrot fractal iteration into Python code that I could come up with. The fact that Python has a built-in complex type makes it very easy to implement the standard Mandelbrot set algorithm.
  3. I have taken the maximum number of iterations inside mandel() to be 20, the same value used in the post I mentioned before. According to the very simple method used to time the application, it takes about 2 seconds on my computer to draw a simple picture. This is annoyingly slow. Furthermore, by looking at the resulting picture, and trying out different numbers of iterations in mandel(), it is clear that 20 iterations is not sufficient to adequately represent the Mandelbrot set; this is especially noticeable when exploring smaller regions of the complex plane. A more realistic value might be 100, if not 1000, iterations, which takes too long to be practical.
  4. Tkinter's canvas does not have a method to set the colour of individual pixels. We can simulate such a method by drawing a line (for which there is a primitive method) of length 1.
  5. The screen vertical coordinates ("y") increase in values from the top towards the bottom, in opposite direction to the usual vertical coordinates in the complex plane. While the picture produced is vertically symmetric about the x-axis, I nonetheless wrote the code so that the inversion of direction was properly handled.
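
The effect of the iteration count mentioned in comment 3 can be explored without any GUI. The following is a rough sketch under stated assumptions: count_inside() and its sampling step are made up for this illustration, and mandel() is a simplified copy of the function above with the iteration limit turned into a parameter.

```python
def mandel(c, max_iterations=20):
    '''returns True if c has not escaped after max_iterations'''
    z = 0
    for _ in range(max_iterations):
        z = z**2 + c
        if abs(z) >= 2:
            return False
    return True

def count_inside(max_iterations, step=0.05):
    '''counts sample points of the square [-2.5, 2.5) x [-2.5, 2.5)
       that are still considered part of the set'''
    n = int(round(5.0 / step))
    count = 0
    for i in range(n):
        for j in range(n):
            c = complex(-2.5 + i*step, -2.5 + j*step)
            if mandel(c, max_iterations):
                count += 1
    return count

for limit in (20, 100):
    print("%4d iterations: %d points kept" % (limit, count_inside(limit)))
```

A point that survives 100 iterations necessarily survives the first 20, so the count can only stay the same or shrink as the limit is raised; any points dropped along the way are ones that 20 iterations wrongly kept.
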
This basic application is not really useful as a tool for exploring the Mandelbrot set, as the region of the complex plane it displays is fixed. However, it is useful to start with something simple like this as a first prototype. Once we know it is working we can move on to a better second version. So, let's write a fancier fractal viewer following the outline below:

class Viewer(object):
    '''Base class viewer to display fractals'''

        # The viewer should be able to enlarge ("zoom in") various regions
        # of the complex plane.  I will implement this
        # using keyboard shortcuts.
        #
        self.parent.bind("+", self.zoom_in)
        self.parent.bind("-", self.zoom_out)
    def zoom_in(self, event):
    def zoom_out(self, event):
    def change_scale(self, scale):

        #  Since one might want to "zoom in" quickly in some regions,
        # and then be able to do finer scale adjustments,
        # I will use keyboard shortcuts to enable switching back
        # and forth between two possible zooming mode.
        # A better application might give the user more control
        # over the zooming scale.
        self.parent.bind("n", self.normal_zoom)
        self.parent.bind("b", self.bigger_zoom)
    def normal_zoom(self, event, scale=1):
    def bigger_zoom(self, event):

        # Set the maximum number of iterations via a keyboard-triggered event
        self.parent.bind("i", self.set_max_iter)
    def set_max_iter(self, event):

        # Like what is done with google maps and other
        # similar applications, we should be able to move the image
        # to look at various regions of interest in the complex plane.
        # I will implement this using mouse controls.
        self.parent.bind("<Button-1>", self.mouse_down)
        self.parent.bind("<Button1-Motion>", self.mouse_motion)
        self.parent.bind("<Button1-ButtonRelease>", self.mouse_up)
    def mouse_down(self, event):
    def mouse_motion(self, event):
    def mouse_up(self, event):

        # Presuming that "nice pictures" will be eventually produced,
        # and that it might be desired to reproduce them,
        # I will include some information about the region of the
        # complex plane currently displayed.
    def info(self):
        '''information about fractal location'''
  • Furthermore, while I plan to use proper profiling tools, I will nonetheless display some basic timing information as part of the GUI as a quick evaluation of the speed of the application.
  • Finally, since I expect both the function mandel() and the drawing method draw_fractal() to be the speed-limiting factors, I will leave them out of the fractal viewer and work on them separately. If it turns out that the profiling information obtained indicates otherwise, I will revisit this hypothesis.
Here is a second prototype for my fractal viewer, having the features described above.

''' viewer.py

Base class viewer for fractals.'''

import sys
if sys.version_info < (3,):
    import Tkinter as tk
    import tkSimpleDialog as tk_dialog
else:
    import tkinter as tk
    from tkinter import simpledialog as tk_dialog

class Viewer(object):
    '''Base class viewer to display fractals'''

    def __init__(self, parent, width=600, height=480,
                 min_x=-2.5, min_y=-1.5, max_x=1.):

        self.parent = parent
        self.canvas_width = width
        self.canvas_height = height

        # The following are drawing boundaries in the complex plane
        self.min_x = min_x
        self.min_y = min_y
        self.max_x = max_x
        self.calculate_pixel_size()
        self.max_y = self.min_y + self.canvas_height*self.pixel_size

        self.calculating = False
        self.nb_iterations = 20
        self.normal_zoom(None)

        self.canvas = tk.Canvas(parent, width=width, height=height)
        self.canvas.pack()
        self.status = tk.Label(self.parent, text="", bd=1, relief=tk.SUNKEN,
                               anchor=tk.W)
        self.status.pack(side=tk.BOTTOM, fill=tk.X)
        self.status2 = tk.Label(self.parent, text=self.info(), bd=1,
                                relief=tk.SUNKEN, anchor=tk.W)
        self.status2.pack(side=tk.BOTTOM, fill=tk.X)

        # We change the size of the image using the keyboard.
        self.parent.bind("+", self.zoom_in)
        self.parent.bind("-", self.zoom_out)
        self.parent.bind("n", self.normal_zoom)
        self.parent.bind("b", self.bigger_zoom)

        # Set the maximum number of iterations via a keyboard-triggered event
        self.parent.bind("i", self.set_max_iter)

        # We move the canvas using the mouse.
        self.translation_line = None
        self.parent.bind("<Button-1>", self.mouse_down)
        self.parent.bind("<Button1-Motion>", self.mouse_motion)
        self.parent.bind("<Button1-ButtonRelease>", self.mouse_up)

        self.draw_fractal()  # Needs to be implemented by subclass

    def info(self):
        '''information about fractal location'''
        return "Location: (%f, %f) to (%f, %f)" %(self.min_x, self.min_y,
                                                  self.max_x, self.max_y)

    def calculate_pixel_size(self):
        '''Calculates the size of a (square) pixel in complex plane
        coordinates based on the canvas_width.'''
        self.pixel_size = 1.*(self.max_x - self.min_x)/self.canvas_width
        return

    def mouse_down(self, event):
        '''records the x and y positions of the mouse when the left button
           is clicked.'''
        self.start_x = self.canvas.canvasx(event.x)
        self.start_y = self.canvas.canvasy(event.y)

    def mouse_motion(self, event):
        '''keep track of the mouse motion by drawing a line from its
           starting point to the current point.'''
        x = self.canvas.canvasx(event.x)
        y = self.canvas.canvasy(event.y)

        if (self.start_x != event.x)  and (self.start_y != event.y) :
            self.canvas.delete(self.translation_line)
            self.translation_line = self.canvas.create_line(
                                self.start_x, self.start_y, x, y, fill="orange")
            self.canvas.update_idletasks()

    def mouse_up(self, event):
        '''Moves the canvas based on the mouse motion'''
        dx = (self.start_x - event.x)*self.pixel_size
        dy = (self.start_y - event.y)*self.pixel_size
        self.min_x += dx
        self.max_x += dx
        # y-coordinate in complex plane run in opposite direction from
        # screen coordinates
        self.min_y -= dy
        self.max_y -= dy
        self.canvas.delete(self.translation_line)
        self.status.config(text="Moving the fractal.  Please wait.")
        self.status.update_idletasks()
        self.status2.config(text=self.info())
        self.status2.update_idletasks()
        self.draw_fractal()

    def normal_zoom(self, event, scale=1):
        '''Sets the zooming in/out scale to its normal value'''
        if scale==1:
            self.zoom_info = "[normal zoom]"
        else:
            self.zoom_info = "[faster zoom]"
        if event is not None:
            self.status.config(text=self.zoom_info)
            self.status.update_idletasks()
        self.zoom_in_scale = 0.1
        self.zoom_out_scale = -0.125

    def bigger_zoom(self, event):
        '''Increases the zooming in/out scale from its normal value'''
        self.normal_zoom(event, scale=3)
        self.zoom_in_scale = 0.3
        self.zoom_out_scale = -0.4

    def zoom_in(self, event):
        '''decreases the size of the region of the complex plane displayed'''
        if self.calculating:
            return
        self.status.config(text="Zooming in.  Please wait.")
        self.status.update_idletasks()
        self.change_scale(self.zoom_in_scale)

    def zoom_out(self, event):
        '''increases the size of the region of the complex plane displayed'''
        if self.calculating:
            return
        self.status.config(text="Zooming out.  Please wait.")
        self.status.update_idletasks()
        self.change_scale(self.zoom_out_scale)

    def change_scale(self, scale):
        '''changes the size of the region of the complex plane displayed and
           redraws'''
        if self.calculating:
            return
        dx = (self.max_x - self.min_x)*scale
        dy = (self.max_y - self.min_y)*scale
        self.min_x += dx
        self.max_x -= dx
        self.min_y += dy
        self.max_y -= dy
        self.calculate_pixel_size()
        self.draw_fractal()

    def set_max_iter(self, event):
        '''set maximum number of iterations'''
        i = tk_dialog.askinteger('title', 'prompt')
        if i is not None:
            self.nb_iterations = i
            self.status.config(text="Redrawing.  Please wait.")
            self.status.update_idletasks()
            self.draw_fractal()

    def draw_fractal(self):
        '''draws a fractal on the canvas'''
        raise NotImplementedError
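
The arithmetic behind calculate_pixel_size() and change_scale() can be tried on its own. This is only a sketch: pixel_size() and rescale() are stand-alone helpers written for this illustration (they are not methods of the Viewer class), fed with the default values used in the constructor above.

```python
def pixel_size(min_x, max_x, canvas_width):
    '''size of a square pixel in complex plane coordinates'''
    return 1.0 * (max_x - min_x) / canvas_width

def rescale(low, high, scale):
    '''moves both ends of an interval inward (scale > 0) or
       outward (scale < 0) by a fraction of its length'''
    delta = (high - low) * scale
    return low + delta, high - delta

size = pixel_size(-2.5, 1.0, 600)
max_y = -1.5 + 480 * size            # upper boundary for a 480 pixel high canvas
print("pixel size = %f, max_y = %f" % (size, max_y))

zoomed = rescale(-2.5, 1.0, 0.1)                     # normal zoom in
restored = rescale(zoomed[0], zoomed[1], -0.125)     # normal zoom out
print("zoomed in: (%f, %f)" % zoomed)
print("zoomed back out: (%f, %f)" % restored)
```

Zooming in by 0.1 keeps 80% of the width, and zooming out by -0.125 multiplies the width by 1.25, so the normal pair undoes itself. By the same arithmetic, the "bigger" pair (0.3 and -0.4) does not bring the view back exactly to where it started.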

I move the Mandelbrot set calculation into a separate file.

# mandel1a.py

def mandel(c, max_iterations=20):
    '''determines if a point is in the Mandelbrot set based on deciding if,
       after a maximum allowed number of iterations, the absolute value of
       the resulting number is greater or equal to 2.'''
    z = 0
    for iter in range(0, max_iterations):
        z = z**2 + c
        if abs(z) >= 2:
            return False
    return abs(z) < 2

And, finally, I implement the missing functions for the viewer in a new main application.


# viewer1.py

from mandel1a import mandel
from viewer import Viewer
import time

import sys
if sys.version_info < (3,):
    import Tkinter as tk
    range = xrange
else:
    import tkinter as tk

class FancyViewer(Viewer):
    '''Application to display fractals'''

    def draw_pixel(self, x, y):
        '''Simulates drawing a given pixel in black by drawing a black line
           of length equal to one pixel.'''
        self.canvas.create_line(x, y, x+1, y, fill="black")

    def draw_fractal(self):
        '''draws a fractal on the canvas'''
        self.calculating = True
        begin = time.time()
        # clear the canvas
        self.canvas.create_rectangle(0, 0, self.canvas_width,
                                    self.canvas_height, fill="white")
        for x in range(0, self.canvas_width):
            real = self.min_x + x*self.pixel_size
            for y in range(0, self.canvas_height):
                imag = self.min_y + y*self.pixel_size
                c = complex(real, imag)
                if mandel(c, self.nb_iterations):
                    self.draw_pixel(x, self.canvas_height - y)
        self.status.config(text="Time required = %.2f s  [%s iterations]  %s" %(
                                (time.time() - begin), self.nb_iterations,
                                                                self.zoom_info))
        self.status2.config(text=self.info())
        self.status2.update_idletasks()
        self.calculating = False

if __name__ == "__main__":
    root = tk.Tk()
    app = FancyViewer(root)
    root.mainloop()


Let me conclude with a few black and white pictures obtained using this program, which, if you look at the time, highlight the need for something faster. First for 20 iterations, drawn in 4 seconds.




Then, for 100 iterations - better image, but 7 seconds to draw...





Next post, I'll start profiling the application and make it faster.

Sunday, January 10, 2010

Python + cython: faster than C?

[Note added on January 15, 2013: I am amazed that this post, clearly written tongue-in-cheek 3 years ago, can still spur people into doing their own test, and feeling the urge to comment.]

Now that I have your attention... ;-)

I've been playing around with cython and will write more about it soon.  But I could not resist doing a quick test when I read this post, comparing C and Java on some toy micro benchmark.

First, the result:

andre$ gcc -o fib -O3 fib.c
andre$ time ./fib
433494437

real    0m3.765s
user    0m3.463s
sys     0m0.028s

andre$ time python fib_test.py
433494437

real    0m2.953s
user    0m2.452s
sys     0m0.380s

Yes, Python+cython is faster than C!  :-)

Now, the code.  First, the C program taken straight from this post, except for a different, more meaningful value for "num" ;-)

#include <stdio.h>

double fib(int num)
{
    if (num <= 1)
        return 1;
    else
        return fib(num-1) + fib(num-2);
}

int main(void)
{
    printf("%.0f\n", fib(42));
    return 0;
}

Next, the cython code ...

# fib.pyx

cdef inline int fib(int num):
    if (num <= 1):
        return 1
    else:
        return fib(num-1) + fib(num-2)

def fib_import (int num):
    return fib(num)

... and the Python file that calls it

# fib_test.py

import pyximport
pyximport.install()

from fib import fib_import

print(fib_import(42))


  1. I know, I declared fib() to be of type "double" in the C code (like what was done in the original post) and "int" in the cython code; however, if I declare it to be of type "int" in the C code, I get -0 as the answer instead of 433494437.  I could declare it to be of type "unsigned int" ... but even when I did that, the Python code was marginally faster.
  2. If I declared fib to be of type "double" in the cython code, it was slightly slower than the corresponding C code.  However, what Python user would ever think of an integer variable as a floating point number! ;-)
  3. What is most impressive is that the cython code is NOT pre-compiled, unlike the C code.  However, to be fair ... I did run it a few times and took the best result, and it was always slower the first time.
  4. Yes, it is silly to compare such micro-benchmarks ... but it can be fun, no? ;-)
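
The value printed by both programs can be cross-checked in plain Python. This is a sketch only: it uses an iterative loop instead of the recursive definition (so it says nothing about speed), and INT32_MAX assumes that a C int is a 32-bit signed integer, as it is on most current platforms.

```python
def fib(num):
    '''same sequence as above: fib(0) = fib(1) = 1'''
    a, b = 1, 1
    for _ in range(num - 1):
        a, b = b, a + b
    return b

INT32_MAX = 2**31 - 1   # assumption: C int is a 32-bit signed integer

print(fib(42))
print(fib(42) <= INT32_MAX)
```

Since the result fits in such an int, the -0 mentioned in note 1 is probably not an overflow. A likely explanation, which the notes above do not verify, is that the "%.0f" format in printf expects a double and was handed an int.
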
More seriously: using cython is a relatively painless way to significantly increase the speed of Python applications ... without having to learn a totally different language.

Saturday, January 02, 2010

Rur-ple is alive again!


Thanks to some wonderful work by some developers who joined the project, rur-ple is moving forward again. :-)  Its new home also features a brand new logo for it displayed above.

Thanks in particular go to Frederic Muller who is moving things along rather nicely.

Friday, December 25, 2009

The 17x17 challenge: setting up the grid

In the previous post, I talked about the 17x17 challenge, but did not provide any code.  Today, I will describe the code I use to represent a given grid.  Consider the following grid where A and b are used to represent two different colours:
AbAb
bAbA
bbAA
bAAA
This grid has two rectangles, one of which is emphasized below
....
.A.A
....
.A.A
Finding such a rectangle in a small grid is fairly easy to do.  A naive, but very inefficient way to do this is to have a function like the following:
def find_rect_candidates(row_1, row_2):
    points_in_common = [0, 0, ...] # one per colour
    for i, point in enumerate(row_1):
        if row_1[i] == row_2[i]:
            points_in_common[row_1[i]] += 1
    return points_in_common
where row_x[i] is set to a given colour, represented by an integer (0, 1, 2, ...). For an NxN grid, the first loop is of order N.  Note that this does not tell us (yet) if we have found a rectangle; we still have to loop over points_in_common and see if any entry is greater than 1.

A better way to do the comparison, which does not grow with N, is mentioned in this post and is based on the following observation:  for a given row, at a given point, a given colour is either present (True or 1) or not (False or 0). Thus, a given colour distribution on a single row can be represented as a string of 0s and 1s ... which can be thought of as the binary representation of an integer.  For example, the row containing the colour "A" in the pattern ".....A..A" can have this pattern represented by the number 9="1001".  Consider another row represented by the number 6="110". These two rows have no points in common (for this colour) and hence do not form a rectangle.  If we do a bitwise "and" for these two rows, i.e. 9&6 we will get zero.  This is achieved by a single operation instead of a series of comparisons.
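
The two rows used in this example can be reproduced with Python's built-in int() and bin(). This is just a sketch of the encoding; the names row_a and row_b are made up for the illustration.

```python
# "1" marks a point of the chosen colour, "0" anything else
row_a = int("000001001", 2)   # pattern ".....A..A"
row_b = int("000000110", 2)   # pattern "......AA."

print(row_a)               # 9
print(row_b)               # 6
print(row_a & row_b)       # 0: no column is shared by the two rows
print(bin(row_a | row_b))  # 0b1111: columns used by either row
```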

What happens if two rows have a single point in common (for a given colour) so that no rectangle is present?  For example, consider 9="1001" and 14="1110".  If we do a bitwise "and" we have

answer = 9 & 14 = 8 = "1000"

Taking the bitwise "and" of answer and answer-1, we get

answer & (answer-1) = 8 & 7 = "1000" & "111" = 0.

If we have two or more points (1 bits) in common, it is easy to see that the bitwise comparison of answer & (answer-1) will not give zero.

Going back to the function above, we could rewrite it instead as follows:

def find_rect_candidates(row_1, row_2, colours):
    points_in_common = [0, 0, ...] # one per colour
    for c in colours:
        points_in_common[c] = row_1[c] & row_2[c]
    return points_in_common


where we still have to do the same (N-independent) processing of the return value as with the function above to determine if we do have rectangles.
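
Putting the encoding and the bit test together, the small 4x4 grid shown at the beginning of this post can be checked directly. The following is a sketch for illustration only: row_to_bits(), has_rectangle() and rectangle_row_pairs() are hypothetical helpers, and the grid is stored as plain strings rather than in the representation described in this post.

```python
GRID = ["AbAb",
        "bAbA",
        "bbAA",
        "bAAA"]

def row_to_bits(row, colour):
    '''encodes the positions of one colour in a row as an integer'''
    return int("".join("1" if ch == colour else "0" for ch in row), 2)

def has_rectangle(bits_1, bits_2):
    '''True if two encoded rows share at least two columns'''
    common = bits_1 & bits_2
    return common & (common - 1) != 0

def rectangle_row_pairs(grid, colour):
    '''lists the pairs of row indices forming a rectangle for a colour'''
    encoded = [row_to_bits(row, colour) for row in grid]
    pairs = []
    for i in range(len(encoded)):
        for j in range(i + 1, len(encoded)):
            if has_rectangle(encoded[i], encoded[j]):
                pairs.append((i, j))
    return pairs

print(rectangle_row_pairs(GRID, "A"))   # [(1, 3), (2, 3)]
print(rectangle_row_pairs(GRID, "b"))   # []
```

The two pairs found for colour A correspond to the two rectangles mentioned for this grid; the pair (1, 3) is the one emphasized above, with rows counted from 0.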

Now, without further ado, here is the basic code that I use to represent a grid, together with two utility functions.

def rectangles_for_two_rows(row_1, row_2):
    '''returns 0 if two rows (given colour, encoded as a bit string)
        do not form a rectangle -
        otherwise returns the points in common encoded as a bit string'''
    intersect = row_1 & row_2
    if not intersect:   # no point in common
        return 0
    # perhaps there is only one point in common of the form 00010000...
    if not(intersect & (intersect-1)):
        return 0
    else:
        return intersect

def count_bits(x):
    '''counts the number of bits
    see http://stackoverflow.com/questions/407587/python-set-bits-count-popcount
    for reference and other alternative'''
    return bin(x).count('1')

class AbstractGrid(object):
    def __init__(self, nb_colours, max_rows, max_columns):
        self.nb_colours = nb_colours
        self.colours = list(range(nb_colours))
        self.max_rows = max_rows
        self.max_columns = max_columns
        self.powers_of_two = [2**i for i in range(max_columns)]
        self.grid = {}

    def initialise(self):
        '''initialise a grid according to some strategy'''
        raise NotImplementedError

    def print_grid(self):
        '''prints a representation of the grid.
        Used for diagnostic only - no need to optimize further.'''
        for row in self.grid:
            row_rep = []
            for column in range(self.max_columns-1, -1, -1):
                for colour_ in self.colours:
                    if self.powers_of_two[column] & self.grid[row][colour_]:
                        row_rep.append(str(colour_))
            print("".join(row_rep))

    def identify_intersect_points(self):
        '''identifies the dots that contribute to forming rectangles'''
        # possibly could cut down the calculation time by only computing
        # for colours that have changed...
        nb_intersect = [0 for colour in self.colours]
        intersect_info = []
        for colour in self.colours:
            for row in self.grid:
                for other_row in range(row+1, self.max_rows):
                    intersect = rectangles_for_two_rows(
                                            self.grid[row][colour],
                                            self.grid[other_row][colour])
                    if intersect != 0:
                        nb_intersect[colour] += count_bits(intersect)
                        intersect_info.append([colour, row, other_row, intersect])
        return nb_intersect, intersect_info

The actual code I use has a few additional methods introduced for convenience.  If anyone can find a better method to identify intersection points between rows (from which rectangles can be formed), I would be very interested to hear about it.

The 17x17 challenge

In a recent blog post, William Gasarch issued a challenge: find a 17x17 four-colour grid for which there are no rectangles with 4 corners of the same colour.  If you can do this, Gasarch will give you $289.  For a more detailed explanation, you can either read Gasarch's post, or this post which contains a slightly friendlier explanation of the problem.

This problem, which is fairly easy to state, is extremely hard to solve as the number of possible grid configurations is 4^289, a number much too large to explore by random searches or by naive methods.  As can be seen from the comments of the posts mentioned above, many people have devoted a lot of cpu cycles in an attempt to find a solution, but with no success.  Like others, I have tried to write programs to solve it ... but with no luck.  In a future post, I may write about the various strategies I have implemented in Python in an attempt to solve this problem.

Gasarch mentions that it is not known if a solution can be found for a 17x18 four-colour grid or even an 18x18 grid or a 17x19 one.  (Note that because of symmetry, an MxN solution can be rotated to give an NxM solution.)  While the number of configurations grows with the size of the grid, I believe I have found a proof demonstrating that no solution can be found for a 17x18 grid.  This proof builds on the work of Elizabeth Kupin, who has found the largest "colour class", i.e. the largest number of single-colour points on a 17x17 grid which is rectangle-free.  A generic solution found by Kupin is reproduced below:



Consider a 17x18 grid (17 rows and 18 columns).  Such a grid has 306 points.  Using the pigeonhole principle, one can easily show that, for any colouring, at least one colour must cover at least 77 points.  (Indeed, the most symmetric colouring would be 77, 77, 76, 76.)  Also, if a 17x18 rectangle-free solution exists, it must contain a 17x17 rectangle-free subset, which we take to be the above solution (other solutions for the 17x17 grid can be derived from this one using interchanges of rows and/or columns).
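
The pigeonhole count is a one-line calculation. As a quick sketch (the variable names are made up for the illustration):

```python
rows, columns, colours = 17, 18, 4

points = rows * columns              # total number of points on the grid
at_least = -(-points // colours)     # ceiling of points / colours

print(points)       # 306
print(at_least)     # 77
print(77 + 77 + 76 + 76 == points)   # True: the most symmetric colouring
```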

Let us attempt to add points in an additional column. First, we can try to add points in a row with 4 elements.  Without loss of generality, we can take this row to be the first one (row A).  Once we do this, we cannot add any more points to rows 3-17 without creating a rectangle.  The only row to which we can add a point is the second (row B), bringing the total number of points to 76 - one short of what we need.

Perhaps we can first remove a point from the 17x17 solution and then add a new column.  There are three cases we must consider: 1) that of a point belonging to a row with 5 points and a column with 5 points; 2) that of a point belonging to a row with 5 points and a column with 4 points; 3) that of a point belonging to a row with 4 points and a column with 5 points.

Case 1) Without loss of generality, let us move the point on row F in the first column to a new additional column, keeping the number of points at 74.  It is easy to show that the only rows to which we can then add an additional point without creating a rectangle are rows A, C, D, E. Once we add such a point (say on row A), we can no longer add a point on any of the remaining rows (C, D, E) without creating a rectangle.

Case 2) Without loss of generality, let us move the left-most point on row Q to a new column.  The only row to which we can then add a point in this new column is row A, bringing the total to 75.  We can't add another point without creating a rectangle.

Case 3) Again, without loss of generality, let us remove the top-left element (A) and move it to a new column added to the above solution.  We can add one more point, bringing the total to 75, to that new column (in row B) without creating a rectangle; any other addition of a point on that new column will result in a rectangle.


This concludes the sketch of the proof. 

UPDATE: Instead of adding a column, we can add a row (R) with 3 points located in the 9th, 12th and 17th columns, bringing the total to 77 points and keeping the grid rectangle-free.  So the 17x18 case is still open... but I can't see how one could add a row with an extra 4 points to build a "single-colour" rectangle-free 17x19 grid.  (However, I won't venture again to say it is a proof until I have looked at all cases exhaustively.)

Note that the construction of solutions for a 17x17 grid such as those found by Kupin, even for a single colour, can possibly be used to restrict the search space for the more general problem and help find a solution.  Unfortunately, no one has been able to do this (yet).