fix, clean, add checkpoints
This commit is contained in:
parent
e3257d32c0
commit
dcf9b798dc
1 changed files with 112 additions and 27 deletions
139
sho/snp.py
139
sho/snp.py
|
|
@ -1,31 +1,52 @@
|
||||||
|
import sys
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import copy
|
import copy
|
||||||
|
|
||||||
|
|
||||||
########################################################################
|
########################################################################
|
||||||
# Utilities
|
# Utilities
|
||||||
########################################################################
|
########################################################################
|
||||||
|
|
||||||
def x(a):
|
def x(a):
|
||||||
|
"""Return the first element of a 2-tuple.
|
||||||
|
>>> x([1,2])
|
||||||
|
1
|
||||||
|
"""
|
||||||
return a[0]
|
return a[0]
|
||||||
|
|
||||||
|
|
||||||
def y(a):
|
def y(a):
|
||||||
|
"""Return the second element of a 2-tuple.
|
||||||
|
>>> y([1,2])
|
||||||
|
2
|
||||||
|
"""
|
||||||
return a[1]
|
return a[1]
|
||||||
|
|
||||||
|
|
||||||
def distance(a,b):
|
def distance(a,b):
|
||||||
"""Euclidean distance (in pixels)."""
|
"""Euclidean distance (in pixels).
|
||||||
|
|
||||||
|
>>> distance( (1,1),(2,2) ) == math.sqrt(2)
|
||||||
|
True
|
||||||
|
"""
|
||||||
return np.sqrt( (x(a)-x(b))**2 + (y(a)-y(b))**2 )
|
return np.sqrt( (x(a)-x(b))**2 + (y(a)-y(b))**2 )
|
||||||
|
|
||||||
|
|
||||||
def highlight_sensors(domain, sensors):
|
def highlight_sensors(domain, sensors, val=2):
|
||||||
|
"""Add twos to the given domain, in the cells where the given
|
||||||
|
sensors are located.
|
||||||
|
|
||||||
|
>>> highlight_sensors( [[0,0],[1,1]], [(0,0),(1,1)] )
|
||||||
|
[[2, 0], [1, 2]]
|
||||||
|
"""
|
||||||
for s in sensors:
|
for s in sensors:
|
||||||
# `coverage` fills the domain with ones,
|
# `coverage` fills the domain with ones,
|
||||||
# adding twos will be visible in an image.
|
# adding twos will be visible in an image.
|
||||||
domain[y(s)][x(s)] = 2
|
domain[y(s)][x(s)] = val
|
||||||
return domain
|
return domain
|
||||||
|
|
||||||
|
|
||||||
########################################################################
|
########################################################################
|
||||||
# Objective functions
|
# Objective functions
|
||||||
########################################################################
|
########################################################################
|
||||||
|
|
@ -133,7 +154,9 @@ def make_init(init, **kwargs):
|
||||||
def num_neighb_square(sol, scale):
|
def num_neighb_square(sol, scale):
|
||||||
"""Draw a random vector in a square of witdh `scale`
|
"""Draw a random vector in a square of witdh `scale`
|
||||||
around the given one."""
|
around the given one."""
|
||||||
return sol + np.random.random(len(sol)) * scale - scale/2
|
# TODO handle constraints
|
||||||
|
new = sol + (np.random.random(len(sol)) * scale - scale/2)
|
||||||
|
return new
|
||||||
|
|
||||||
|
|
||||||
def bit_neighb_square(sol, scale):
|
def bit_neighb_square(sol, scale):
|
||||||
|
|
@ -145,6 +168,7 @@ def bit_neighb_square(sol, scale):
|
||||||
for px in range(len(sol[py])):
|
for px in range(len(sol[py])):
|
||||||
if sol[py][px] == 1:
|
if sol[py][px] == 1:
|
||||||
new[py][px] = 0 # Remove original position.
|
new[py][px] = 0 # Remove original position.
|
||||||
|
# TODO handle constraints
|
||||||
d = np.random.randint(-scale//2,scale//2,2)
|
d = np.random.randint(-scale//2,scale//2,2)
|
||||||
new[py+y(d)][px+x(d)] = 1
|
new[py+y(d)][px+x(d)] = 1
|
||||||
return new
|
return new
|
||||||
|
|
@ -162,38 +186,81 @@ def make_neig(neighb, **kwargs):
|
||||||
# Stopping criterions
|
# Stopping criterions
|
||||||
########################################################################
|
########################################################################
|
||||||
|
|
||||||
def iter_max(val, sol, nb_it):
|
def iter_max(i, val, sol, nb_it):
|
||||||
"""Return a generator of nb_it items."""
|
if i < nb_it:
|
||||||
# Directly return the `range` generator.
|
return True
|
||||||
return range(nb_it)
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
def make_iter(iters, **kwargs):
|
def make_iter(iters, **kwargs):
|
||||||
"""Make an iterations operator from the given function.
|
"""Make an iterations operator from the given function.
|
||||||
A iter. op. takes a value and a solution and returns
|
A iter. op. takes a value and a solution and returns
|
||||||
the current number of iterations."""
|
the current number of iterations."""
|
||||||
def cont(val, sol):
|
def f(i, val, sol):
|
||||||
return iters(val, sol, **kwargs)
|
return iters(i, val, sol, **kwargs)
|
||||||
return cont
|
return f
|
||||||
|
|
||||||
|
|
||||||
|
# Stopping criterions that are actually just checkpoints.
|
||||||
|
|
||||||
|
def combine(i, val, sol, agains):
|
||||||
|
"""Combine several stopping criterions in one."""
|
||||||
|
res = True
|
||||||
|
for again in agains:
|
||||||
|
res = res and again(i, val, sol)
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def save(i, val, sol, filename="run.csv", fmt="{it} ; {val} ; {sol}\n"):
|
||||||
|
"""Save all iterations to a file."""
|
||||||
|
# Append a line at the end of the file.
|
||||||
|
with open(filename.format(it=i), 'a') as fd:
|
||||||
|
fd.write( fmt.format(it=i, val=val, sol=sol) )
|
||||||
|
return True # No incidence on termination.
|
||||||
|
|
||||||
|
|
||||||
|
def iter_log(i, val, sol, fmt="{it} {val}\n"):
|
||||||
|
"""Print progress on stderr."""
|
||||||
|
sys.stderr.write( fmt.format(it=i, val=val) )
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
########################################################################
|
########################################################################
|
||||||
# Algorithms
|
# Algorithms
|
||||||
########################################################################
|
########################################################################
|
||||||
|
|
||||||
def search(func, init, neighb, iters):
|
def random(func, init, again):
|
||||||
"""Iterative randomized heuristic template."""
|
"""Iterative random search template."""
|
||||||
|
best_sol = None
|
||||||
|
best_val = - np.inf
|
||||||
|
val,sol = best_val,best_sol
|
||||||
|
i = 0
|
||||||
|
while again(i, val, sol):
|
||||||
|
sol = init()
|
||||||
|
val = func(sol)
|
||||||
|
if val > best_val:
|
||||||
|
best_val = val
|
||||||
|
best_sol = sol
|
||||||
|
i += 1
|
||||||
|
return best_val, best_sol
|
||||||
|
|
||||||
|
|
||||||
|
def greedy(func, init, neighb, again):
|
||||||
|
"""Iterative randomized greedy heuristic template."""
|
||||||
best_sol = init()
|
best_sol = init()
|
||||||
best_val = func(best_sol)
|
best_val = func(best_sol)
|
||||||
for i in iters(best_val, best_sol):
|
val,sol = best_val,best_sol
|
||||||
|
i = 1
|
||||||
|
while again(i, val, sol):
|
||||||
sol = neighb(best_sol)
|
sol = neighb(best_sol)
|
||||||
val = func(sol)
|
val = func(sol)
|
||||||
if val > best_val:
|
if val > best_val:
|
||||||
best_val = val
|
best_val = val
|
||||||
best_sol = sol
|
best_sol = sol
|
||||||
return val,sol
|
i += 1
|
||||||
|
return best_val, best_sol
|
||||||
|
|
||||||
|
# TODO add a simulated annealing solver.
|
||||||
# TODO add a population-based stochastic heuristic template.
|
# TODO add a population-based stochastic heuristic template.
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -221,8 +288,8 @@ if __name__=="__main__":
|
||||||
can.add_argument("-i", "--iters", metavar="NB", default=100, type=int,
|
can.add_argument("-i", "--iters", metavar="NB", default=100, type=int,
|
||||||
help="Maximum number of iterations")
|
help="Maximum number of iterations")
|
||||||
|
|
||||||
can.add_argument("-s", "--seed", metavar="VAL", default=0, type=int,
|
can.add_argument("-s", "--seed", metavar="VAL", default=None, type=int,
|
||||||
help="Random pseudo-generator seed (0 for epoch)")
|
help="Random pseudo-generator seed (none for current epoch)")
|
||||||
|
|
||||||
solvers = ["num_greedy","bit_greedy"]
|
solvers = ["num_greedy","bit_greedy"]
|
||||||
can.add_argument("-m", "--solver", metavar="NAME", choices=solvers, default="num_greedy",
|
can.add_argument("-m", "--solver", metavar="NAME", choices=solvers, default="num_greedy",
|
||||||
|
|
@ -244,10 +311,31 @@ if __name__=="__main__":
|
||||||
# in case you would start "runs" in parallel.
|
# in case you would start "runs" in parallel.
|
||||||
np.random.seed(the.seed)
|
np.random.seed(the.seed)
|
||||||
|
|
||||||
|
# Weird numpy way to ensure single line print of array.
|
||||||
|
np.set_printoptions(linewidth = np.inf)
|
||||||
|
|
||||||
domain = np.zeros((the.domain_width, the.domain_width))
|
domain = np.zeros((the.domain_width, the.domain_width))
|
||||||
|
|
||||||
|
# Common termination and checkpointing.
|
||||||
|
iters = make_iter(
|
||||||
|
combine,
|
||||||
|
agains = [
|
||||||
|
make_iter(iter_max,
|
||||||
|
nb_it = the.iters),
|
||||||
|
make_iter(save,
|
||||||
|
filename = the.solver+".csv",
|
||||||
|
fmt = "{it} ; {val} ; {sol}\n"),
|
||||||
|
make_iter(iter_log,
|
||||||
|
fmt="\r{it} {val}")
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Erase the previous file.
|
||||||
|
with open(the.solver+".csv", 'w') as fd:
|
||||||
|
fd.write("# {} {}\n".format(the.solver,the.domain_width))
|
||||||
|
|
||||||
if the.solver == "num_greedy":
|
if the.solver == "num_greedy":
|
||||||
val,sol = search(
|
val,sol = greedy(
|
||||||
make_func(num_cover_sum,
|
make_func(num_cover_sum,
|
||||||
domain_width = the.domain_width,
|
domain_width = the.domain_width,
|
||||||
sensor_range = the.sensor_range * the.domain_width),
|
sensor_range = the.sensor_range * the.domain_width),
|
||||||
|
|
@ -256,13 +344,12 @@ if __name__=="__main__":
|
||||||
scale = the.domain_width),
|
scale = the.domain_width),
|
||||||
make_neig(num_neighb_square,
|
make_neig(num_neighb_square,
|
||||||
scale = the.domain_width/10), # TODO think of an alternative.
|
scale = the.domain_width/10), # TODO think of an alternative.
|
||||||
make_iter(iter_max,
|
iters
|
||||||
nb_it = the.iters)
|
|
||||||
)
|
)
|
||||||
sensors = num_to_sensors(sol)
|
sensors = num_to_sensors(sol)
|
||||||
|
|
||||||
elif the.solver == "bit_greedy":
|
elif the.solver == "bit_greedy":
|
||||||
val,sol = search(
|
val,sol = greedy(
|
||||||
make_func(bit_cover_sum,
|
make_func(bit_cover_sum,
|
||||||
domain_width = the.domain_width,
|
domain_width = the.domain_width,
|
||||||
sensor_range = the.sensor_range),
|
sensor_range = the.sensor_range),
|
||||||
|
|
@ -271,15 +358,13 @@ if __name__=="__main__":
|
||||||
nb_sensors = the.nb_sensors),
|
nb_sensors = the.nb_sensors),
|
||||||
make_neig(bit_neighb_square,
|
make_neig(bit_neighb_square,
|
||||||
scale = the.domain_width/10),
|
scale = the.domain_width/10),
|
||||||
make_iter(iter_max,
|
iters
|
||||||
nb_it = the.iters)
|
|
||||||
)
|
)
|
||||||
sensors = bit_to_sensors(sol)
|
sensors = bit_to_sensors(sol)
|
||||||
|
|
||||||
# TODO add a simulated annealing solver.
|
|
||||||
|
|
||||||
# Fancy output.
|
# Fancy output.
|
||||||
print(val,":",sensors)
|
print("\n",val,":",sensors)
|
||||||
|
|
||||||
domain = coverage(domain, sensors,
|
domain = coverage(domain, sensors,
|
||||||
the.sensor_range * the.domain_width)
|
the.sensor_range * the.domain_width)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue