Skip to content

Instantly share code, notes, and snippets.

@kiyoon
Last active December 21, 2023 09:45
Show Gist options
  • Save kiyoon/3037e680d95c5c2871d8028305fd41de to your computer and use it in GitHub Desktop.
Save kiyoon/3037e680d95c5c2871d8028305fd41de to your computer and use it in GitHub Desktop.
from rich.progress import (
Progress,
ProgressColumn,
SpinnerColumn,
Task,
TextColumn,
TimeElapsedColumn,
)
from rich.text import Text
class SpeedColumn(ProgressColumn):
    """Progress column showing throughput as steps/s and its inverse, s/step."""

    def render(self, task: "Task") -> Text:
        """Render the current speed of ``task``.

        Args:
            task: The task whose ``speed`` attribute (steps per second, or
                ``None`` when no estimate exists) is displayed.

        Returns:
            ``Text("")`` when ``task.speed`` is ``None``; otherwise
            ``"<speed> steps/s <1/speed> s/step"`` with three decimals each.
            A non-positive speed renders only the steps/s part, because
            seconds-per-step is undefined there.
        """
        # Read the property once: it is recomputed on every access, and both
        # numbers shown should describe the same measurement.
        speed = task.speed
        if speed is None:
            return Text("")
        if speed <= 0:
            # Guard the reciprocal below: 1 / 0.0 would raise
            # ZeroDivisionError in the middle of a render.
            return Text(f"{speed:.3f} steps/s")
        return Text(f"{speed:.3f} steps/s {1 / speed:.3f} s/step")
def get_rich_progress(*, transient: bool = False, **progress_kwargs):
    """
    Returns a Rich Progress Bar with default columns and a custom SpeedColumn.

    Columns, in order: a spinner, Rich's default columns
    (``Progress.get_default_columns()``), elapsed time, ``completed/total``,
    and ``SpeedColumn``.

    Args:
        transient: Forwarded to ``Progress(transient=...)``. Defaults to
            ``False``, as before.
        **progress_kwargs: Any other keyword arguments, forwarded unchanged to
            ``rich.progress.Progress`` (e.g. ``console``,
            ``refresh_per_second``).

    Returns:
        A newly constructed ``Progress`` that has not been started yet.

    Usage:
    ```
    pbar = get_rich_progress()
    # add a task; start=False defers its timer until start_task() is called
    pbar_task = pbar.add_task("Training", total=len(train_dataloader), start=False)
    # optional: add custom columns
    # You need to first initialize the column with the task field name
    # pbar_task = pbar.add_task("Training", total=len(train_dataloader), start=False, loss=0.0)
    # pbar.columns += (TextColumn("Loss: {task.fields[loss]}"),)
    # start progress bar (counting time from here)
    pbar.start()
    pbar.start_task(pbar_task)
    # update progress bar
    pbar.update(pbar_task, advance=1, loss=0.123)
    # stop the live display when done; alternatively wrap the work in
    # `with pbar:`, which calls start() on entry and stop() on exit
    pbar.stop()
    ```
    """
    return Progress(
        SpinnerColumn(),
        *Progress.get_default_columns(),
        TimeElapsedColumn(),
        TextColumn("{task.completed}/{task.total}"),
        SpeedColumn(),
        transient=transient,
        **progress_kwargs,
    )
def main():
    """Run a short demo: 100 steps, 0.1 s apart, with an extra Loss column."""
    import time

    pbar = get_rich_progress()
    # The custom ``loss`` field must exist on the task before the Loss column
    # renders ``{task.fields[loss]}``, so declare it when adding the task.
    # start=False defers the task timer until start_task() below.
    pbar_task = pbar.add_task("Training", total=100, start=False, loss=0.0)
    pbar.columns += (TextColumn("Loss: {task.fields[loss]}"),)
    # ``with`` calls start() on entry and stop() on exit, so the live display
    # is always stopped, even if the loop raises or is interrupted.
    with pbar:
        pbar.start_task(pbar_task)
        for _ in range(100):
            pbar.update(pbar_task, advance=1, loss=0.123)
            time.sleep(0.1)
# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment