Description
I'm currently working with a fork of sklearn in which I've added progress bars to long-running estimators (particularly MiniBatchKMeans). Some of these algorithms take up to 9 hours to run, and having an estimated time remaining has been quite helpful for knowing when to come back and look at the results.
I was wondering whether it would be worthwhile to compile some of these into a pull request and add this functionality to the library. If so, there are a few logistical questions:
Which progress implementation should I use? I see three options here.
- I could add a dependency on the external click library and use its ProgressBar class.
- I could port a simplified version of my own ProgIter class into `sklearn.utils`.
- I could port a simplified version of click's ProgressBar class.
All options essentially wrap an iterator and dump some info to stdout about how far through the iterator they are.
These options are ordered from least to most work. The first option is the least work, but it adds a dependency. I have experience with the second option, but it is not as widely used as click (although my port would use a more click-like signature). The third option requires me to delve into the click library, which I haven't had any experience with yet. However, judging from my initial look at click's implementation, mine may be more efficient: click seems to compute and display progress information on every iteration, whereas my implementation tries to minimally impact the wrapped loop by adjusting its report frequency.
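To make that contrast concrete, here is a minimal sketch of the adaptive-frequency idea on its own, written for Python 3. It is not the ProgIter code: the helper name `adaptive_progress` and its `target_seconds` parameter are hypothetical. The clock is read only once every `freq` items, and `freq` is rescaled after each report so that reports land roughly `target_seconds` apart.

```python
import io
import sys
import time


def adaptive_progress(iterable, target_seconds=2.0, stream=None):
    """Yield items from ``iterable`` while reporting progress adaptively.

    Hypothetical helper: the clock is read only once every ``freq`` items,
    and ``freq`` is rescaled so that reports arrive roughly
    ``target_seconds`` apart.
    """
    stream = sys.stdout if stream is None else stream
    freq = 1
    count = last_count = 0
    last_time = time.perf_counter()
    for count, item in enumerate(iterable, start=1):
        yield item
        if count - last_count >= freq:
            now = time.perf_counter()
            rate = (count - last_count) / max(now - last_time, 1e-9)
            stream.write('\r{:d} items, {:.2f} Hz'.format(count, rate))
            stream.flush()
            # The freq that would have hit target_seconds exactly, capped at
            # doubling so one unusually fast interval cannot overshoot.
            freq = max(1, min(int(rate * target_seconds), freq * 2))
            last_count, last_time = count, now
    stream.write('\r{:d} items, done\n'.format(count))


# Usage: wrap any iterable; most iterations skip the clock entirely.
buf = io.StringIO()
results = list(adaptive_progress(range(100000), stream=buf))
print(len(results))  # 100000
```

Capping growth at doubling per report mirrors the "don't make drastic changes" rule in the mockup below, in simplified form.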
Here is an example. Currently, a code block without progress reporting looks like this:
```python
if self.verbose:
    print('Computing label assignment and total inertia')
x_squared_norms = row_norms(X, squared=True)
slices = gen_batches(X.shape[0], self.batch_size)
results = [_labels_inertia(X[s], x_squared_norms[s],
                           self.cluster_centers_) for s in slices]
labels, inertia = zip(*results)
return np.hstack(labels), np.sum(inertia)
```
If I port a refactored, pared-down version of my progress iterator to sklearn, the same block would look something like this:
```python
if self.verbose:
    print('Computing label assignment and total inertia')
x_squared_norms = row_norms(X, squared=True)
n_samples = X.shape[0]
# gen_batches also yields a final partial batch, so round up
total_batches = (n_samples + self.batch_size - 1) // self.batch_size
slices = gen_batches(n_samples, self.batch_size)
slices = ProgIter(slices, label='computing labels and inertia',
                  length=total_batches, enabled=self.verbose)
results = [_labels_inertia(X[s], x_squared_norms[s],
                           self.cluster_centers_) for s in slices]
labels, inertia = zip(*results)
return np.hstack(labels), np.sum(inertia)
```
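The `length` passed to ProgIter should equal the number of slices `gen_batches` produces, and that is a ceiling division rather than a floor division: when `n_samples` is not a multiple of `batch_size`, a final shorter slice is also yielded. The sketch below avoids importing sklearn by using a hypothetical stand-in, `gen_batches_sketch`, that mimics this slicing behaviour; `n_batches` is likewise an illustrative helper.

```python
def gen_batches_sketch(n, batch_size):
    """Hypothetical stand-in that mimics sklearn.utils.gen_batches slicing."""
    start = 0
    for _ in range(n // batch_size):
        end = start + batch_size
        yield slice(start, end)
        start = end
    if start < n:
        # The leftover samples form one final, shorter batch.
        yield slice(start, n)


def n_batches(n, batch_size):
    # Ceiling division using integers only (no floating point).
    return -(-n // batch_size)


n_samples, batch_size = 10, 3
slices = list(gen_batches_sketch(n_samples, batch_size))
print(slices)  # [slice(0, 3, None), slice(3, 6, None), slice(6, 9, None), slice(9, 10, None)]
print(len(slices), n_samples // batch_size, n_batches(n_samples, batch_size))  # 4 3 4
```

With floor division the bar would reach "4/3" on the last batch; the ceiling keeps the count and the ETA consistent.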
Here is a mockup of the simplified ProgIter class. I pared it down rather quickly, so there may be some small issues in this version.
```python
from __future__ import absolute_import, division, print_function, unicode_literals
import sys
import time
import timeit
import datetime
import collections
import six

WIN32 = sys.platform.startswith('win32')
# timeit.default_timer picks the most precise wall clock for the platform
# and Python version (time.clock was removed in Python 3.8).
default_timer = timeit.default_timer
if WIN32:
    BEFORE_BAR = '\r'
    AFTER_BAR = '\n'
else:
    # Hide the cursor while the bar is drawn and restore it afterwards
    BEFORE_BAR = '\r\033[?25l'
    AFTER_BAR = '\033[?25h\n'

text_type = six.text_type


class ProgIter(object):
    """Wrap an iterable and periodically report progress to stdout."""

    def __init__(self, iterable=None, label=None, length=None, enabled=True,
                 freq=1, adjust=True, eta_window=64):
        if length is None:
            try:
                length = len(iterable)
            except Exception:
                pass
        if label is None:
            label = ''
        self.iterable = iterable
        self.label = label
        self.length = length
        self.freq = freq
        self.enabled = enabled
        self.adjust = adjust
        self.count = 0
        # Window size for estimated time remaining
        self.eta_window = eta_window

    def __call__(self, iterable):
        self.iterable = iterable
        return self

    def __iter__(self):
        if not self.enabled:
            return iter(self.iterable)
        else:
            return self.iter_rate()

    def build_msg_fmtstr(self):
        with_wall = True
        tzname = time.tzname[0]
        # An unknown length (None) is shown as '?' and the eta is omitted
        length_ = '?' if self.length is None else text_type(self.length)
        msg_head = ['', self.label, ' {count:4d}/', length_, '... ']
        msg_tail = [
            'rate={rate:4.2f} Hz,',
            ('' if self.length is None else ' eta={eta},'),
            ' ellapsed={ellapsed},',
            (' wall={wall} ' + tzname if with_wall else ''),
        ]
        return ''.join([BEFORE_BAR] + msg_head + msg_tail)

    def tryflush(self):
        try:
            # flush sometimes causes issues in IPython notebooks
            sys.stdout.flush()
        except IOError:
            pass

    def write(self, msg):
        sys.stdout.write(msg)

    def iter_rate(self):
        freq = self.freq
        self.count = 0
        between_count = 0
        last_count = 0
        # Target number of seconds between progress reports
        # (used for freq adjustment)
        time_thresh = 2.0
        max_between_time = -1.0
        max_between_count = -1.0
        iters_per_second = -1
        # Prepare for iteration
        msg_fmtstr = self.build_msg_fmtstr()
        self.tryflush()
        msg = msg_fmtstr.format(
            count=self.count, rate=0.0,
            eta=text_type('0:00:00'),
            ellapsed=text_type('0:00:00'),
            wall=time.strftime('%H:%M'),
        )
        self.write(msg)
        self.tryflush()
        start_time = default_timer()
        last_time = start_time
        # Use the last eta_window (default 64) measurements for a more stable average
        measure_between_time = collections.deque([], maxlen=self.eta_window)
        measure_est_seconds = collections.deque([], maxlen=self.eta_window)
        # Wrap the for loop with a generator
        for self.count, item in enumerate(self.iterable, start=1):
            yield item
            if self.count % freq == 0:
                # Update progress information every so often
                now_time = default_timer()
                between_time = now_time - last_time
                between_count = self.count - last_count
                total_seconds = now_time - start_time
                measure_between_time.append(between_count / (float(between_time) + 1E-9))
                iters_per_second = sum(measure_between_time) / len(measure_between_time)
                if self.length is None:
                    est_seconds_left = -1
                else:
                    iters_left = self.length - self.count
                    est_eta = iters_left / (iters_per_second + 1E-9)
                    measure_est_seconds.append(est_eta)
                    est_seconds_left = sum(measure_est_seconds) / len(measure_est_seconds)
                last_count = self.count
                last_time = now_time
                # Adjust frequency if printing too quickly (or too slowly)
                # so progress reporting doesn't slow down the actual loop
                if self.adjust and (between_time < time_thresh or
                                    between_time > time_thresh * 2.0):
                    max_between_time = max(max_between_time, between_time)
                    max_between_time = max(max_between_time, 1E-9)
                    max_between_count = max(max_between_count, between_count)
                    # If progress was uniform and all time estimates were
                    # perfect this would be the new freq to achieve time_thresh
                    new_freq = int(time_thresh * max_between_count / max_between_time)
                    new_freq = max(new_freq, 1)
                    # But things are not perfect. So, don't make drastic changes
                    max_freq_change_up = max(256, freq * 2)
                    max_freq_change_down = freq // 2
                    if (new_freq - freq) > max_freq_change_up:
                        freq += max_freq_change_up
                    elif (freq - new_freq) > max_freq_change_down:
                        freq -= max_freq_change_down
                    else:
                        freq = new_freq
                msg = msg_fmtstr.format(
                    count=self.count,
                    rate=iters_per_second,
                    eta=text_type(datetime.timedelta(seconds=int(est_seconds_left))),
                    ellapsed=text_type(datetime.timedelta(seconds=int(total_seconds))),
                    wall=time.strftime('%H:%M'),
                )
                self.write(msg)
                self.tryflush()
        if self.count % freq != 0:
            # Write the final progress line if it was not written in the loop
            est_seconds_left = 0
            now_time = default_timer()
            total_seconds = now_time - start_time
            if iters_per_second < 0:
                # No in-loop update happened; fall back to the overall rate
                iters_per_second = self.count / (total_seconds + 1E-9)
            msg = msg_fmtstr.format(
                count=self.count,
                rate=iters_per_second,
                eta=text_type(datetime.timedelta(seconds=int(est_seconds_left))),
                ellapsed=text_type(datetime.timedelta(seconds=int(total_seconds))),
                wall=time.strftime('%H:%M'),
            )
            self.write(msg)
            self.tryflush()
        self.write(AFTER_BAR)
```
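The ETA in `iter_rate` is smoothed twice: each report appends an instantaneous rate to a bounded `collections.deque`, the ETA is computed from the mean of that window, and the reported ETA is the mean of the last `eta_window` ETA estimates. Isolated into a small class (`WindowedETA` is a hypothetical name, not part of the mockup), the averaging looks like this:

```python
import collections


class WindowedETA(object):
    """Hypothetical helper isolating the rate/ETA averaging in iter_rate."""

    def __init__(self, length, window=64):
        self.length = length
        # Bounded deques silently drop the oldest entry once full.
        self.rates = collections.deque(maxlen=window)
        self.etas = collections.deque(maxlen=window)

    def update(self, count, between_count, between_time):
        """Record one report interval; return (mean rate, mean ETA seconds)."""
        # Instantaneous rate over the interval since the previous report.
        self.rates.append(between_count / (between_time + 1e-9))
        mean_rate = sum(self.rates) / len(self.rates)
        # ETA from the smoothed rate, then smoothed again over the window.
        self.etas.append((self.length - count) / (mean_rate + 1e-9))
        return mean_rate, sum(self.etas) / len(self.etas)


# With window=2, only the two most recent intervals influence the result.
eta = WindowedETA(length=100, window=2)
eta.update(count=10, between_count=10, between_time=1.0)  # ~10 Hz, ~9 s left
eta.update(count=20, between_count=10, between_time=2.0)  # interval ran at ~5 Hz
rate, seconds_left = eta.update(count=30, between_count=10, between_time=1.0)
print(round(rate, 3), round(seconds_left, 3))  # 7.5 10.0
```

The window trades responsiveness for stability: a larger `eta_window` hides short stalls but reacts more slowly to a genuine change in speed.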
Here is an example showing what this does:
```python
# Define a function that takes some time
def is_prime(n):
    return n >= 2 and not any(n % i == 0 for i in range(2, n))

N = 50000
# Default behavior adjusts frequency of progress reporting so
# the performance of the loop is minimally impacted
iterable = (is_prime(n) for n in range(N))
piterable = ProgIter(iterable, length=N)
_ = list(piterable)
```
```
50000/50000... rate=25097.04 Hz, eta=0:00:00, ellapsed=0:00:22, wall=11:22 EST
```
```python
# Adjustments can be turned off to give constant feedback
iterable = (is_prime(n) for n in range(N))
piterable = ProgIter(iterable, length=N, adjust=False, freq=100)
_ = list(piterable)
```
```
2200/50000... rate=32500.44 Hz, eta=0:00:01, ellapsed=0:00:00, wall=11:23 EST
43300/50000... rate=1398.73 Hz, eta=0:00:06, ellapsed=0:00:18, wall=11:23 EST
50000/50000... rate=1129.99 Hz, eta=0:00:02, ellapsed=0:00:25, wall=11:25 EST
```
Only one line of the above output is shown at a time; I copied lines from different points in a few runs to illustrate the output.
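Only one line is visible at a time because every message starts with `BEFORE_BAR`, whose `'\r'` returns the cursor to the start of the line so the next message overwrites the previous one; on non-Windows terminals the `\033[?25l` and `\033[?25h` escape codes additionally hide and restore the cursor. A rough way to see what ends up on screen is to keep only the text after the last carriage return on each line. `visible_lines` is a hypothetical helper, and it assumes each message is at least as long as the one it overwrites:

```python
import re

# ANSI sequences used by the mockup to hide/show the cursor.
CURSOR_CODES = re.compile(r'\x1b\[\?25[lh]')


def visible_lines(raw):
    """Approximate what a terminal shows for text that overwrites itself
    with carriage returns.

    Simplification: a real terminal leaves trailing characters behind when
    a shorter message overwrites a longer one; this sketch ignores that.
    """
    raw = CURSOR_CODES.sub('', raw)
    return [line.split('\r')[-1] for line in raw.split('\n')]


raw = ('\r\x1b[?25l 100/300... rate=10.00 Hz'
       '\r\x1b[?25l 200/300... rate=10.00 Hz'
       '\r\x1b[?25l 300/300... rate=10.00 Hz\x1b[?25h\n')
# The trailing '' is the empty line where the cursor rests after AFTER_BAR.
print(visible_lines(raw))  # [' 300/300... rate=10.00 Hz', '']
```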
Lastly, here are some timeit results showing that frequency adjustment adds minimal overhead:
```ipython
def is_prime(n):
    return n >= 2 and not any(n % i == 0 for i in range(2, n))

N = 1000
# Time the raw loop without any progress reporting
%timeit list((is_prime(n) for n in range(N)))
100 loops, best of 3: 10 ms per loop

# Time the same loop when printing on every iteration
%timeit list(ProgIter((is_prime(n) for n in range(N)), length=N, freq=1, adjust=False))
10 loops, best of 3: 159 ms per loop

# Time the same loop when printing on an adjusted schedule
%timeit list(ProgIter((is_prime(n) for n in range(N)), length=N, freq=1, adjust=True))
100 loops, best of 3: 11 ms per loop
```