[Feature Request] Progress Bars · Issue #7574 · scikit-learn/scikit-learn · GitHub
[Feature Request] Progress Bars #7574
Open
@Erotemic

Description


I'm currently working with a fork of sklearn where I include progress bars in long-running estimators (particularly MiniBatchKMeans). Some of the algorithms take up to 9 hours to run, and having an estimated time remaining has been quite helpful for knowing when to come back and look at the results.

I was wondering if it would be worthwhile to compile some of these into a pull request and add this functionality to the library. If so, there are a few logistical questions:

Which progress implementation should I use? I see three options here.

  1. I could add a dependency on the external click library and use its ProgressBar class.
  2. I could port a simplified version of my own ProgIter class into sklearn utils.
  3. I could port a simplified version of click's ProgressBar class.

All options essentially wrap an iterator and dump some info to stdout about how far through the iterator they are.
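To make that shared idea concrete, here is a minimal sketch of such a wrapper using only the standard library. `simple_progress` is a hypothetical name (it is not click's or sklearn's API). It reports on every iteration, which is the simplest but most expensive policy, and it writes to a configurable stream so the output can be captured.

```python
import io
import sys


def simple_progress(iterable, length=None, label='', stream=None):
    """Hypothetical minimal progress wrapper: yields every item of
    ``iterable`` and writes "label count/length" after each one."""
    stream = sys.stdout if stream is None else stream
    if length is None:
        try:
            length = len(iterable)
        except TypeError:
            pass  # generators have no len(); the total is shown as '?'
    total = '?' if length is None else str(length)
    for count, item in enumerate(iterable, start=1):
        yield item
        # '\r' returns the cursor to the start of the line, so on a
        # terminal each report overwrites the previous one
        stream.write('\r{} {}/{}'.format(label, count, total))
        stream.flush()
    stream.write('\n')


# Usage: capture the output in a buffer instead of the terminal
buf = io.StringIO()
result = list(simple_progress(range(3), label='demo', stream=buf))
```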

These options are ordered from least to most work. The first option is the least work, but it adds a dependency. I have experience with the second option, but it is not as widely used as click (although my port would use a more click-like signature). The third option requires me to delve into the click library, which I haven't had any experience with yet. However, judging from my initial look at click's implementation, my implementation may be more efficient: click seems to compute and display progress information on every iteration, whereas mine tries to minimally impact the wrapped loop by adjusting its report frequency.
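The frequency-adjustment idea can be isolated as a pure function. The sketch below is a simplified, hypothetical restatement of the rule used in the ProgIter mockup further down: aim for roughly one report every `time_thresh` seconds, but limit how far the frequency can move in a single step. Unlike the mockup, it only looks at the most recent window instead of keeping running maxima of the window's time and count.

```python
def adjust_freq(freq, between_count, between_time, time_thresh=2.0):
    """Return how many iterations to wait before the next report.

    between_count / between_time describe the window since the last report.
    """
    # Leave the frequency alone if reports already arrive every
    # time_thresh to 2 * time_thresh seconds
    if time_thresh <= between_time <= 2.0 * time_thresh:
        return freq
    # Frequency that would hit time_thresh exactly if the rate stayed constant
    new_freq = max(int(time_thresh * between_count / max(between_time, 1e-9)), 1)
    # Damp large jumps: grow by at most max(256, 2 * freq), shrink by at most half
    max_up = max(256, freq * 2)
    max_down = freq // 2
    if new_freq - freq > max_up:
        return freq + max_up
    if freq - new_freq > max_down:
        return freq - max_down
    return new_freq


# Simulated loop where every iteration costs a made-up 1/512 s
freq = 1
history = []
for _ in range(4):
    freq = adjust_freq(freq, between_count=freq, between_time=freq / 512.0)
    history.append(freq)
```

Starting from `freq=1`, the simulated loop settles at 1024 iterations per report, which is one report every 2 seconds of simulated work.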

Here is some example code. Without progress bars, a code block currently looks like this:

        if self.verbose:
            print('Computing label assignment and total inertia')
        x_squared_norms = row_norms(X, squared=True)
        slices = gen_batches(X.shape[0], self.batch_size)
        results = [_labels_inertia(X[s], x_squared_norms[s],
                                   self.cluster_centers_) for s in slices]
        labels, inertia = zip(*results)
        return np.hstack(labels), np.sum(inertia)
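In this block, each batch gets a label (the index of its nearest center) and an inertia (the sum of squared distances to that center). The per-batch results are then concatenated and summed. The NumPy sketch below illustrates that pattern. `batch_slices` and `labels_inertia_chunk` are hypothetical stand-ins for `gen_batches` and `_labels_inertia`, not sklearn's actual implementations.

```python
import numpy as np


def batch_slices(n_samples, batch_size):
    """Yield contiguous slices of at most batch_size rows (last may be shorter)."""
    for start in range(0, n_samples, batch_size):
        yield slice(start, min(start + batch_size, n_samples))


def labels_inertia_chunk(X, x_squared_norms, centers):
    # ||x - c||^2 = ||x||^2 - 2 x.c + ||c||^2 for every (row, center) pair
    dist = (x_squared_norms[:, np.newaxis]
            - 2.0 * X.dot(centers.T)
            + (centers ** 2).sum(axis=1)[np.newaxis, :])
    np.maximum(dist, 0.0, out=dist)  # clip tiny negatives caused by rounding
    labels = dist.argmin(axis=1)
    # Inertia: sum of squared distances to the nearest center
    inertia = dist[np.arange(X.shape[0]), labels].sum()
    return labels, inertia


def labels_inertia_batched(X, centers, batch_size):
    x_squared_norms = (X ** 2).sum(axis=1)
    results = [labels_inertia_chunk(X[s], x_squared_norms[s], centers)
               for s in batch_slices(X.shape[0], batch_size)]
    labels, inertia = zip(*results)
    return np.hstack(labels), np.sum(inertia)


X = np.array([[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0], [0.0, 2.0]])
centers = np.array([[0.0, 1.0], [10.0, 10.5]])
labels, inertia = labels_inertia_batched(X, centers, batch_size=2)
```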

If I port a refactored, pared-down version of my progress iterator to sklearn, it will look something like this:

        if self.verbose:
            print('Computing label assignment and total inertia')
        x_squared_norms = row_norms(X, squared=True)
        n_samples = X.shape[0]
        # gen_batches also yields the final partial batch, so round up
        total_batches = int(np.ceil(n_samples / float(self.batch_size)))
        slices = gen_batches(n_samples, self.batch_size)
        slices = ProgIter(slices, label='computing labels and inertia',
                          length=total_batches, enabled=self.verbose)
        results = [_labels_inertia(X[s], x_squared_norms[s],
                                   self.cluster_centers_) for s in slices]
        labels, inertia = zip(*results)
        return np.hstack(labels), np.sum(inertia)
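One detail in computing `total_batches` is easy to get wrong. `sklearn.utils.gen_batches` also yields the final, partial batch, so the number of slices is the ceiling of `n_samples / batch_size`, not the floor. With floor division the bar would be given one batch too few whenever `batch_size` does not divide `n_samples`. For example, 1050 samples in batches of 100 produce 11 slices, so the bar would end at 11/10. A small check with hypothetical helper names:

```python
def batch_slices(n_samples, batch_size):
    """Contiguous slices of at most batch_size rows; the last may be shorter
    (the same convention as sklearn.utils.gen_batches)."""
    return [slice(start, min(start + batch_size, n_samples))
            for start in range(0, n_samples, batch_size)]


def count_batches(n_samples, batch_size):
    # Integer ceiling division: avoids float rounding for very large n_samples
    return -(-n_samples // batch_size)


n_samples, batch_size = 1050, 100
floor_count = n_samples // batch_size              # misses the partial batch
ceil_count = count_batches(n_samples, batch_size)
actual = len(batch_slices(n_samples, batch_size))
```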

Here is a mockup of the simplified ProgIter object. I pared it down rather quickly, so there may be some small issues in this version.

from __future__ import absolute_import, division, print_function, unicode_literals
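import sys
import time
import datetime
import collections
# timeit.default_timer picks a suitable clock for the platform and Python
# version (time.clock, used previously on Windows, was removed in Python 3.8)
from timeit import default_timer
import six

WIN32 = sys.platform.startswith('win32')

if WIN32:
    # Plain carriage return / newline on Windows consoles
    BEFORE_BAR = '\r'
    AFTER_BAR = '\n'
else:
    # '\033[?25l' hides the cursor while the bar is drawn; '\033[?25h' shows it again
    BEFORE_BAR = '\r\033[?25l'
    AFTER_BAR = '\033[?25h\n'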

text_type = six.text_type

class ProgIter(object):
    def __init__(self, iterable=None, label=None, length=None, enabled=True,
                 freq=1, adjust=True, eta_window=64):
        if length is None:
            try:
                length = len(iterable)
            except Exception:
                pass
        if label is None:
            label = ''
        self.iterable = iterable
        self.label = label
        self.length = length
        self.freq = freq
        self.enabled = enabled
        self.adjust = adjust
        self.count = 0
        # Window size for estimated time remaining
        self.eta_window = eta_window

    def __call__(self, iterable):
        self.iterable = iterable
        return self

    def __iter__(self):
        if not self.enabled:
            return iter(self.iterable)
        else:
            return self.iter_rate()

    def build_msg_fmtstr(self):
        with_wall = True
        tzname = time.tzname[0]
        # Show '?' (and no eta) when the total length is unknown (None) or 0
        length_ = '?' if not self.length else text_type(self.length)
        msg_head = ['', self.label, ' {count:4d}/', length_, '...  ']
        msg_tail = [
            ('rate={rate:4.2f} Hz,'),
            ('' if not self.length else ' eta={eta},'),
            ' ellapsed={ellapsed},',
            (' wall={wall} ' + tzname if with_wall else ''),
        ]
        msg_fmtstr_time = ''.join(([BEFORE_BAR] + msg_head + msg_tail))
        return msg_fmtstr_time

    def tryflush(self):
        try:
            # flush sometimes causes issues in IPython notebooks
            sys.stdout.flush()
        except IOError:
            pass

    def write(self, msg):
        sys.stdout.write(msg)

    def iter_rate(self):
        freq          = self.freq
        self.count    = 0
        between_count = 0
        last_count    = 0

        # how long iterations should be before a flush
        # (used for freq adjustment)
        time_thresh = 2.0

        max_between_time = -1.0
        max_between_count = -1.0

        iters_per_second = -1

        # Prepare for iteration
        msg_fmtstr = self.build_msg_fmtstr()

        self.tryflush()

        msg = msg_fmtstr.format(
            count=self.count, rate=0.0,
            eta=text_type('0:00:00'),
            ellapsed=text_type('0:00:00'),
            wall=time.strftime('%H:%M'),
        )
        self.write(msg)

        self.tryflush()

        start_time = default_timer()
        last_time  = start_time

        # use last few (64) times to compute a more stable average rate
        measure_between_time = collections.deque([], maxlen=self.eta_window)
        measure_est_seconds = collections.deque([], maxlen=self.eta_window)

        # Wrap the for loop with a generator
        for self.count, item in enumerate(self.iterable, start=1):
            yield item

            if (self.count) % freq == 0:
                # update progress information every so often
                now_time          = default_timer()
                between_time      = (now_time - last_time)
                between_count     = self.count - last_count
                total_seconds     = (now_time - start_time)

                measure_between_time.append(between_count / (float(between_time) + 1E-9))
                iters_per_second = sum(measure_between_time) / len(measure_between_time)

                if self.length is None:
                    est_seconds_left = -1
                else:
                    iters_left = self.length - self.count
                    est_eta = iters_left / (iters_per_second + 1E-9)

                    measure_est_seconds.append(est_eta)
                    est_seconds_left = sum(measure_est_seconds) / len(measure_est_seconds)

                last_count = self.count
                last_time  = now_time
                # Adjust frequency if printing too quickly (or too slowly)
                # so progress reporting doesn't slow down the actual function
                if self.adjust and (between_time < time_thresh or
                                    between_time > time_thresh * 2.0):
                    max_between_time = max(max_between_time, between_time)
                    max_between_time = max(max_between_time, 1E-9)
                    max_between_count = max(max_between_count, between_count)
                    # If progress was uniform and all time estimates were
                    # perfect this would be the new freq to achieve time_thresh
                    new_freq = int(time_thresh * max_between_count / max_between_time)
                    new_freq = max(new_freq, 1)
                    # But things are not perfect. So, don't make drastic changes
                    max_freq_change_up = max(256, freq * 2)
                    max_freq_change_down = freq // 2
                    if (new_freq - freq) > max_freq_change_up:
                        freq += max_freq_change_up
                    elif (freq - new_freq) > max_freq_change_down:
                        freq -= max_freq_change_down
                    else:
                        freq = new_freq
                msg = msg_fmtstr.format(
                    count=self.count,
                    rate=iters_per_second,
                    eta=text_type(datetime.timedelta(seconds=int(est_seconds_left))),
                    ellapsed=text_type(datetime.timedelta(seconds=int(total_seconds))),
                    wall=time.strftime('%H:%M'),
                )
                self.write(msg)
                self.tryflush()

        if self.count != last_count:
            # Write the final progress line if it was not written in the loop
            est_seconds_left = 0
            now_time = default_timer()
            total_seconds = (now_time - start_time)
            if iters_per_second < 0:
                # No report happened inside the loop; fall back to the overall rate
                iters_per_second = self.count / (total_seconds + 1E-9)
            msg = msg_fmtstr.format(
                count=self.count,
                rate=iters_per_second,
                eta=text_type(datetime.timedelta(seconds=int(est_seconds_left))),
                ellapsed=text_type(datetime.timedelta(seconds=int(total_seconds))),
                wall=time.strftime('%H:%M'),
            )
            self.write(msg)
            self.tryflush()
        self.write(AFTER_BAR)
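The time-remaining estimate in the mockup is smoothed twice. The rate is averaged over the last `eta_window` reports, and the resulting ETA estimates are averaged too. The hypothetical `EtaEstimator` below pulls out that logic so it can be exercised with made-up timestamps instead of a real clock. It sketches the same idea and is not part of the proposed patch.

```python
import collections
import datetime


class EtaEstimator(object):
    """Hypothetical helper: averages per-window rates and the resulting
    time-remaining estimates over the last ``window`` reports."""

    def __init__(self, length, window=64, start_time=0.0):
        self.length = length
        self.rates = collections.deque(maxlen=window)
        self.etas = collections.deque(maxlen=window)
        self.last_count = 0
        self.last_time = start_time

    def update(self, count, now):
        between_count = count - self.last_count
        between_time = now - self.last_time
        # Small epsilon avoids division by zero on very fast loops
        self.rates.append(between_count / (between_time + 1e-9))
        rate = sum(self.rates) / len(self.rates)
        self.etas.append((self.length - count) / (rate + 1e-9))
        eta = sum(self.etas) / len(self.etas)
        self.last_count, self.last_time = count, now
        return rate, eta

    @staticmethod
    def format_seconds(seconds):
        # Same "H:MM:SS" style as str(datetime.timedelta(...)) in the mockup
        return str(datetime.timedelta(seconds=int(seconds)))


est = EtaEstimator(length=1000, window=64, start_time=0.0)
for count, now in [(100, 1.0), (200, 2.0), (300, 3.0)]:
    rate, eta = est.update(count, now)
```

Averaging the ETA values as well as the rates makes the displayed estimate change more smoothly, at the cost of lagging behind real changes in speed.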

Here is an example showing what this does:

# Define a function that takes some time
def is_prime(n):
    return n >= 2 and not any(n % i == 0 for i in range(2, n))
N = 50000

# Default behavior adjusts frequency of progress reporting so
# the performance of the loop is minimally impacted
iterable = (is_prime(n) for n in range(N))
piterable = ProgIter(iterable, length=N)
_ = list(piterable)
 50000/50000...  rate=25097.04 Hz, eta=0:00:00, ellapsed=0:00:22, wall=11:22 EST

# Adjustments can be turned off to give constant feedback
iterable = (is_prime(n) for n in range(N))
piterable = ProgIter(iterable, length=N, adjust=False, freq=100)
_ = list(piterable)
  2200/50000...  rate=32500.44 Hz, eta=0:00:01, ellapsed=0:00:00, wall=11:23 EST
43300/50000...    rate=1398.73 Hz, eta=0:00:06, ellapsed=0:00:18, wall=11:23 EST
50000/50000...  rate=1129.99 Hz, eta=0:00:02, ellapsed=0:00:25, wall=11:25 EST

# Only one line of the above output is shown at a time; I copied lines from
# different points of a few runs to illustrate the output.
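Only one line is visible because every message starts with `'\r'`, which moves the cursor back to column 0. Moving the cursor does not clear the line, so a shorter message can leave the tail of a longer one on screen. This can happen when the rate drops from five integer digits to four. The sketch below simulates a single terminal line to show the effect and one common remedy: padding each message to a fixed width. `render_line` is a hypothetical helper that assumes no newlines and ignores escape codes a real terminal would interpret.

```python
def render_line(output):
    """Simulate one terminal line: '\r' returns the cursor to column 0 and
    later characters overwrite earlier ones without erasing the rest."""
    cells = []
    col = 0
    for ch in output:
        if ch == '\r':
            col = 0
        else:
            if col < len(cells):
                cells[col] = ch
            else:
                cells.append(ch)
            col += 1
    return ''.join(cells)


long_msg = 'rate=25097.04 Hz'
short_msg = 'rate=1398.73 Hz'
# Without padding, the last character of the longer message survives
unpadded = render_line('\r' + long_msg + '\r' + short_msg)
# Padding every message to the same width overwrites the leftovers
width = max(len(long_msg), len(short_msg))
padded = render_line('\r' + long_msg.ljust(width) + '\r' + short_msg.ljust(width))
```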

Lastly, here is some timeit information showing that adjusting the report frequency adds minimal overhead:

def is_prime(n):
    return n >= 2 and not any(n % i == 0 for i in range(2, n))

N = 1000

# Time the raw loop without any progress reporting
%timeit list((is_prime(n) for n in range(N)))
100 loops, best of 3: 10 ms per loop

# Time the loop when printing on every iteration
%timeit list(ProgIter((is_prime(n) for n in range(N)), length=N, freq=1, adjust=False))
10 loops, best of 3: 159 ms per loop

# Time the loop when printing on an adjusted schedule
%timeit list(ProgIter((is_prime(n) for n in range(N)), length=N, freq=1, adjust=True))
100 loops, best of 3: 11 ms per loop
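`%timeit` is an IPython magic. In a plain Python script, a similar "best of N" measurement can be made with the standard `timeit` module. The sketch below uses a hypothetical `report_every` generator as a stand-in for ProgIter. Because it writes to an in-memory buffer instead of a terminal, its overhead will be smaller than the printing overhead measured above. Timings vary by machine, so no numbers are shown.

```python
import io
import timeit


def is_prime(n):
    return n >= 2 and not any(n % i == 0 for i in range(2, n))


def report_every(iterable, freq, stream):
    """Hypothetical stand-in for a progress wrapper: writes every `freq` items."""
    for count, item in enumerate(iterable, start=1):
        yield item
        if count % freq == 0:
            stream.write('\r{}'.format(count))


def best_per_loop(func, number=10, repeat=3):
    """Like %timeit's "best of N": fastest per-call average over `repeat` runs."""
    return min(timeit.repeat(func, number=number, repeat=repeat)) / number


N = 300
sink = io.StringIO()
raw = best_per_loop(lambda: list(is_prime(n) for n in range(N)))
every_item = best_per_loop(
    lambda: list(report_every((is_prime(n) for n in range(N)), 1, sink)))
every_100 = best_per_loop(
    lambda: list(report_every((is_prime(n) for n in range(N)), 100, sink)))
print('raw: {:.2e} s, freq=1: {:.2e} s, freq=100: {:.2e} s per loop'.format(
    raw, every_item, every_100))
```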
