fix, clean, add checkpoints

This commit is contained in:
Johann Dreo 2018-12-13 21:06:14 +01:00
commit dcf9b798dc

View file

@ -1,31 +1,52 @@
import sys
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import copy import copy
######################################################################## ########################################################################
# Utilities # Utilities
######################################################################## ########################################################################
def x(a): def x(a):
"""Return the first element of a 2-tuple.
>>> x([1,2])
1
"""
return a[0] return a[0]
def y(a): def y(a):
"""Return the second element of a 2-tuple.
>>> y([1,2])
2
"""
return a[1] return a[1]
def distance(a,b): def distance(a,b):
"""Euclidean distance (in pixels).""" """Euclidean distance (in pixels).
>>> distance( (1,1),(2,2) ) == math.sqrt(2)
True
"""
return np.sqrt( (x(a)-x(b))**2 + (y(a)-y(b))**2 ) return np.sqrt( (x(a)-x(b))**2 + (y(a)-y(b))**2 )
def highlight_sensors(domain, sensors): def highlight_sensors(domain, sensors, val=2):
"""Add twos to the given domain, in the cells where the given
sensors are located.
>>> highlight_sensors( [[0,0],[1,1]], [(0,0),(1,1)] )
[[2, 0], [1, 2]]
"""
for s in sensors: for s in sensors:
# `coverage` fills the domain with ones, # `coverage` fills the domain with ones,
# adding twos will be visible in an image. # adding twos will be visible in an image.
domain[y(s)][x(s)] = 2 domain[y(s)][x(s)] = val
return domain return domain
######################################################################## ########################################################################
# Objective functions # Objective functions
######################################################################## ########################################################################
@ -133,7 +154,9 @@ def make_init(init, **kwargs):
def num_neighb_square(sol, scale): def num_neighb_square(sol, scale):
"""Draw a random vector in a square of witdh `scale` """Draw a random vector in a square of witdh `scale`
around the given one.""" around the given one."""
return sol + np.random.random(len(sol)) * scale - scale/2 # TODO handle constraints
new = sol + (np.random.random(len(sol)) * scale - scale/2)
return new
def bit_neighb_square(sol, scale): def bit_neighb_square(sol, scale):
@ -145,6 +168,7 @@ def bit_neighb_square(sol, scale):
for px in range(len(sol[py])): for px in range(len(sol[py])):
if sol[py][px] == 1: if sol[py][px] == 1:
new[py][px] = 0 # Remove original position. new[py][px] = 0 # Remove original position.
# TODO handle constraints
d = np.random.randint(-scale//2,scale//2,2) d = np.random.randint(-scale//2,scale//2,2)
new[py+y(d)][px+x(d)] = 1 new[py+y(d)][px+x(d)] = 1
return new return new
@ -162,38 +186,81 @@ def make_neig(neighb, **kwargs):
# Stopping criterions # Stopping criterions
######################################################################## ########################################################################
def iter_max(val, sol, nb_it): def iter_max(i, val, sol, nb_it):
"""Return a generator of nb_it items.""" if i < nb_it:
# Directly return the `range` generator. return True
return range(nb_it) else:
return False
def make_iter(iters, **kwargs): def make_iter(iters, **kwargs):
"""Make an iterations operator from the given function. """Make an iterations operator from the given function.
A iter. op. takes a value and a solution and returns A iter. op. takes a value and a solution and returns
the current number of iterations.""" the current number of iterations."""
def cont(val, sol): def f(i, val, sol):
return iters(val, sol, **kwargs) return iters(i, val, sol, **kwargs)
return cont return f
# Stopping criterions that are actually just checkpoints.
def combine(i, val, sol, agains):
"""Combine several stopping criterions in one."""
res = True
for again in agains:
res = res and again(i, val, sol)
return res
def save(i, val, sol, filename="run.csv", fmt="{it} ; {val} ; {sol}\n"):
"""Save all iterations to a file."""
# Append a line at the end of the file.
with open(filename.format(it=i), 'a') as fd:
fd.write( fmt.format(it=i, val=val, sol=sol) )
return True # No incidence on termination.
def iter_log(i, val, sol, fmt="{it} {val}\n"):
"""Print progress on stderr."""
sys.stderr.write( fmt.format(it=i, val=val) )
return True
######################################################################## ########################################################################
# Algorithms # Algorithms
######################################################################## ########################################################################
def search(func, init, neighb, iters): def random(func, init, again):
"""Iterative randomized heuristic template.""" """Iterative random search template."""
best_sol = None
best_val = - np.inf
val,sol = best_val,best_sol
i = 0
while again(i, val, sol):
sol = init()
val = func(sol)
if val > best_val:
best_val = val
best_sol = sol
i += 1
return best_val, best_sol
def greedy(func, init, neighb, again):
"""Iterative randomized greedy heuristic template."""
best_sol = init() best_sol = init()
best_val = func(best_sol) best_val = func(best_sol)
for i in iters(best_val, best_sol): val,sol = best_val,best_sol
i = 1
while again(i, val, sol):
sol = neighb(best_sol) sol = neighb(best_sol)
val = func(sol) val = func(sol)
if val > best_val: if val > best_val:
best_val = val best_val = val
best_sol = sol best_sol = sol
return val,sol i += 1
return best_val, best_sol
# TODO add a simulated annealing solver.
# TODO add a population-based stochastic heuristic template. # TODO add a population-based stochastic heuristic template.
@ -221,8 +288,8 @@ if __name__=="__main__":
can.add_argument("-i", "--iters", metavar="NB", default=100, type=int, can.add_argument("-i", "--iters", metavar="NB", default=100, type=int,
help="Maximum number of iterations") help="Maximum number of iterations")
can.add_argument("-s", "--seed", metavar="VAL", default=0, type=int, can.add_argument("-s", "--seed", metavar="VAL", default=None, type=int,
help="Random pseudo-generator seed (0 for epoch)") help="Random pseudo-generator seed (none for current epoch)")
solvers = ["num_greedy","bit_greedy"] solvers = ["num_greedy","bit_greedy"]
can.add_argument("-m", "--solver", metavar="NAME", choices=solvers, default="num_greedy", can.add_argument("-m", "--solver", metavar="NAME", choices=solvers, default="num_greedy",
@ -244,10 +311,31 @@ if __name__=="__main__":
# in case you would start "runs" in parallel. # in case you would start "runs" in parallel.
np.random.seed(the.seed) np.random.seed(the.seed)
# Weird numpy way to ensure single line print of array.
np.set_printoptions(linewidth = np.inf)
domain = np.zeros((the.domain_width, the.domain_width)) domain = np.zeros((the.domain_width, the.domain_width))
# Common termination and checkpointing.
iters = make_iter(
combine,
agains = [
make_iter(iter_max,
nb_it = the.iters),
make_iter(save,
filename = the.solver+".csv",
fmt = "{it} ; {val} ; {sol}\n"),
make_iter(iter_log,
fmt="\r{it} {val}")
]
)
# Erase the previous file.
with open(the.solver+".csv", 'w') as fd:
fd.write("# {} {}\n".format(the.solver,the.domain_width))
if the.solver == "num_greedy": if the.solver == "num_greedy":
val,sol = search( val,sol = greedy(
make_func(num_cover_sum, make_func(num_cover_sum,
domain_width = the.domain_width, domain_width = the.domain_width,
sensor_range = the.sensor_range * the.domain_width), sensor_range = the.sensor_range * the.domain_width),
@ -256,13 +344,12 @@ if __name__=="__main__":
scale = the.domain_width), scale = the.domain_width),
make_neig(num_neighb_square, make_neig(num_neighb_square,
scale = the.domain_width/10), # TODO think of an alternative. scale = the.domain_width/10), # TODO think of an alternative.
make_iter(iter_max, iters
nb_it = the.iters)
) )
sensors = num_to_sensors(sol) sensors = num_to_sensors(sol)
elif the.solver == "bit_greedy": elif the.solver == "bit_greedy":
val,sol = search( val,sol = greedy(
make_func(bit_cover_sum, make_func(bit_cover_sum,
domain_width = the.domain_width, domain_width = the.domain_width,
sensor_range = the.sensor_range), sensor_range = the.sensor_range),
@ -271,15 +358,13 @@ if __name__=="__main__":
nb_sensors = the.nb_sensors), nb_sensors = the.nb_sensors),
make_neig(bit_neighb_square, make_neig(bit_neighb_square,
scale = the.domain_width/10), scale = the.domain_width/10),
make_iter(iter_max, iters
nb_it = the.iters)
) )
sensors = bit_to_sensors(sol) sensors = bit_to_sensors(sol)
# TODO add a simulated annealing solver.
# Fancy output. # Fancy output.
print(val,":",sensors) print("\n",val,":",sensors)
domain = coverage(domain, sensors, domain = coverage(domain, sensors,
the.sensor_range * the.domain_width) the.sensor_range * the.domain_width)