diff --git a/sho/snp.py b/sho/snp.py index 76559c8..0914842 100644 --- a/sho/snp.py +++ b/sho/snp.py @@ -1,31 +1,52 @@ +import sys import numpy as np import matplotlib.pyplot as plt import copy + ######################################################################## # Utilities ######################################################################## def x(a): + """Return the first element of a 2-tuple. + >>> x([1,2]) + 1 + """ return a[0] def y(a): + """Return the second element of a 2-tuple. + >>> y([1,2]) + 2 + """ return a[1] def distance(a,b): - """Euclidean distance (in pixels).""" + """Euclidean distance (in pixels). + + >>> distance( (1,1),(2,2) ) == math.sqrt(2) + True + """ return np.sqrt( (x(a)-x(b))**2 + (y(a)-y(b))**2 ) -def highlight_sensors(domain, sensors): +def highlight_sensors(domain, sensors, val=2): + """Add twos to the given domain, in the cells where the given + sensors are located. + + >>> highlight_sensors( [[0,0],[1,1]], [(0,0),(1,1)] ) + [[2, 0], [1, 2]] + """ for s in sensors: # `coverage` fills the domain with ones, # adding twos will be visible in an image. - domain[y(s)][x(s)] = 2 + domain[y(s)][x(s)] = val return domain + ######################################################################## # Objective functions ######################################################################## @@ -133,7 +154,9 @@ def make_init(init, **kwargs): def num_neighb_square(sol, scale): """Draw a random vector in a square of witdh `scale` around the given one.""" - return sol + np.random.random(len(sol)) * scale - scale/2 + # TODO handle constraints + new = sol + (np.random.random(len(sol)) * scale - scale/2) + return new def bit_neighb_square(sol, scale): @@ -145,6 +168,7 @@ def bit_neighb_square(sol, scale): for px in range(len(sol[py])): if sol[py][px] == 1: new[py][px] = 0 # Remove original position. + # TODO handle constraints d = np.random.randint(-scale//2,scale//2,2) new[py+y(d)][px+x(d)] = 1 return new @@ -162,38 +186,81 @@ def make_neig(neighb, **kwargs): # Stopping criterions ######################################################################## -def iter_max(val, sol, nb_it): - """Return a generator of nb_it items.""" - # Directly return the `range` generator. - return range(nb_it) - +def iter_max(i, val, sol, nb_it): + if i < nb_it: + return True + else: + return False def make_iter(iters, **kwargs): """Make an iterations operator from the given function. A iter. op. takes a value and a solution and returns the current number of iterations.""" - def cont(val, sol): - return iters(val, sol, **kwargs) - return cont + def f(i, val, sol): + return iters(i, val, sol, **kwargs) + return f + + +# Stopping criterions that are actually just checkpoints. + +def combine(i, val, sol, agains): + """Combine several stopping criterions in one.""" + res = True + for again in agains: + res = res and again(i, val, sol) + return res + + +def save(i, val, sol, filename="run.csv", fmt="{it} ; {val} ; {sol}\n"): + """Save all iterations to a file.""" + # Append a line at the end of the file. + with open(filename.format(it=i), 'a') as fd: + fd.write( fmt.format(it=i, val=val, sol=sol) ) + return True # No incidence on termination. + + +def iter_log(i, val, sol, fmt="{it} {val}\n"): + """Print progress on stderr.""" + sys.stderr.write( fmt.format(it=i, val=val) ) + return True ######################################################################## # Algorithms ######################################################################## -def search(func, init, neighb, iters): - """Iterative randomized heuristic template.""" +def random(func, init, again): + """Iterative random search template.""" + best_sol = None + best_val = - np.inf + val,sol = best_val,best_sol + i = 0 + while again(i, val, sol): + sol = init() + val = func(sol) + if val > best_val: + best_val = val + best_sol = sol + i += 1 + return best_val, best_sol + + +def greedy(func, init, neighb, again): + """Iterative randomized greedy heuristic template.""" best_sol = init() best_val = func(best_sol) - for i in iters(best_val, best_sol): + val,sol = best_val,best_sol + i = 1 + while again(i, val, sol): sol = neighb(best_sol) val = func(sol) if val > best_val: best_val = val best_sol = sol - return val,sol - + i += 1 + return best_val, best_sol +# TODO add a simulated annealing solver. # TODO add a population-based stochastic heuristic template. @@ -221,8 +288,8 @@ if __name__=="__main__": can.add_argument("-i", "--iters", metavar="NB", default=100, type=int, help="Maximum number of iterations") - can.add_argument("-s", "--seed", metavar="VAL", default=0, type=int, - help="Random pseudo-generator seed (0 for epoch)") + can.add_argument("-s", "--seed", metavar="VAL", default=None, type=int, + help="Random pseudo-generator seed (none for current epoch)") solvers = ["num_greedy","bit_greedy"] can.add_argument("-m", "--solver", metavar="NAME", choices=solvers, default="num_greedy", @@ -244,10 +311,31 @@ if __name__=="__main__": # in case you would start "runs" in parallel. np.random.seed(the.seed) + # Weird numpy way to ensure single line print of array. + np.set_printoptions(linewidth = np.inf) + domain = np.zeros((the.domain_width, the.domain_width)) + # Common termination and checkpointing. + iters = make_iter( + combine, + agains = [ + make_iter(iter_max, + nb_it = the.iters), + make_iter(save, + filename = the.solver+".csv", + fmt = "{it} ; {val} ; {sol}\n"), + make_iter(iter_log, + fmt="\r{it} {val}") + ] + ) + + # Erase the previous file. + with open(the.solver+".csv", 'w') as fd: + fd.write("# {} {}\n".format(the.solver,the.domain_width)) + if the.solver == "num_greedy": - val,sol = search( + val,sol = greedy( make_func(num_cover_sum, domain_width = the.domain_width, sensor_range = the.sensor_range * the.domain_width), @@ -256,13 +344,12 @@ if __name__=="__main__": scale = the.domain_width), make_neig(num_neighb_square, scale = the.domain_width/10), # TODO think of an alternative. - make_iter(iter_max, - nb_it = the.iters) + iters ) sensors = num_to_sensors(sol) elif the.solver == "bit_greedy": - val,sol = search( + val,sol = greedy( make_func(bit_cover_sum, domain_width = the.domain_width, sensor_range = the.sensor_range), @@ -271,15 +358,13 @@ if __name__=="__main__": nb_sensors = the.nb_sensors), make_neig(bit_neighb_square, scale = the.domain_width/10), - make_iter(iter_max, - nb_it = the.iters) + iters ) sensors = bit_to_sensors(sol) - # TODO add a simulated annealing solver. # Fancy output. - print(val,":",sensors) + print("\n",val,":",sensors) domain = coverage(domain, sensors, the.sensor_range * the.domain_width)