Skip to content

Commit

Permalink
progress
Browse files Browse the repository at this point in the history
  • Loading branch information
mxyng committed Nov 17, 2023
1 parent 9f04e5a commit c4a3ccd
Show file tree
Hide file tree
Showing 3 changed files with 285 additions and 0 deletions.
118 changes: 118 additions & 0 deletions progress/bar.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
package progress

import (
"fmt"
"os"
"strings"
"time"

"github.com/jmorganca/ollama/format"
"golang.org/x/term"
)

type Bar struct {
message string
messageWidth int

maxValue int64
initialValue int64
currentValue int64

started time.Time
stopped time.Time
}

func NewBar(message string, maxValue, initialValue int64) *Bar {
return &Bar{
message: message,
messageWidth: -1,
maxValue: maxValue,
initialValue: initialValue,
currentValue: initialValue,
started: time.Now(),
}
}

func (b *Bar) String() string {
termWidth, _, err := term.GetSize(int(os.Stderr.Fd()))
if err != nil {
panic(err)
}

var pre, mid, suf strings.Builder

if b.message != "" {
message := strings.TrimSpace(b.message)
if b.messageWidth > 0 && len(message) > b.messageWidth {
message = message[:b.messageWidth]
}

fmt.Fprintf(&pre, "%s", message)
if b.messageWidth-pre.Len() >= 0 {
pre.WriteString(strings.Repeat(" ", b.messageWidth-pre.Len()))
}

pre.WriteString(" ")
}

fmt.Fprintf(&pre, "%.1f%% ", b.percent())

fmt.Fprintf(&suf, "(%s/%s, %s/s, %s)",
format.HumanBytes(b.currentValue),
format.HumanBytes(b.maxValue),
format.HumanBytes(int64(b.rate())),
b.elapsed())

mid.WriteString("[")

// pad 3 for last = or > and "] "
f := termWidth - pre.Len() - mid.Len() - suf.Len() - 3
n := int(float64(f) * b.percent() / 100)
if n > 0 {
mid.WriteString(strings.Repeat("=", n))
}

if b.currentValue >= b.maxValue {
mid.WriteString("=")
} else {
mid.WriteString(">")
}

if f-n > 0 {
mid.WriteString(strings.Repeat(" ", f-n))
}

mid.WriteString("] ")

return pre.String() + mid.String() + suf.String()
}

func (b *Bar) Set(value int64) {
if value >= b.maxValue {
value = b.maxValue
b.stopped = time.Now()
}

b.currentValue = value
}

func (b *Bar) percent() float64 {
if b.maxValue > 0 {
return float64(b.currentValue) / float64(b.maxValue) * 100
}

return 0
}

func (b *Bar) rate() float64 {
return (float64(b.currentValue) - float64(b.initialValue)) / b.elapsed().Seconds()
}

func (b *Bar) elapsed() time.Duration {
stopped := b.stopped
if stopped.IsZero() {
stopped = time.Now()
}

return stopped.Sub(b.started).Round(time.Second)
}
65 changes: 65 additions & 0 deletions progress/progress.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package progress

import (
"fmt"
"io"
"sync"
"time"
)

type State interface {
String() string
}

type Progress struct {
mu sync.Mutex
pos int
w io.Writer

ticker *time.Ticker
states []State
}

func NewProgress(w io.Writer) *Progress {
p := &Progress{pos: -1, w: w}
go p.start()
return p
}

func (p *Progress) Stop() {
if p.ticker != nil {
p.ticker.Stop()
p.ticker = nil
p.render()
}
}

func (p *Progress) Add(key string, state State) {
p.mu.Lock()
defer p.mu.Unlock()

p.states = append(p.states, state)
}

func (p *Progress) render() error {
p.mu.Lock()
defer p.mu.Unlock()

fmt.Fprintf(p.w, "\033[%dA", p.pos)
for _, state := range p.states {
fmt.Fprintln(p.w, state.String())
}

if len(p.states) > 0 {
p.pos = len(p.states)
}

return nil
}

func (p *Progress) start() {
p.ticker = time.NewTicker(100 * time.Millisecond)
for range p.ticker.C {
p.render()
}
}
102 changes: 102 additions & 0 deletions progress/spinner.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package progress

import (
"fmt"
"os"
"strings"
"time"

"golang.org/x/term"
)

type Spinner struct {
message string
messageWidth int

parts []string

value int

ticker *time.Ticker
started time.Time
stopped time.Time
}

func NewSpinner(message string) *Spinner {
s := &Spinner{
message: message,
parts: []string{
"⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏",
},
started: time.Now(),
}
go s.start()
return s
}

func (s *Spinner) String() string {
termWidth, _, err := term.GetSize(int(os.Stderr.Fd()))
if err != nil {
panic(err)
}

var pre strings.Builder
if len(s.message) > 0 {
message := strings.TrimSpace(s.message)
if s.messageWidth > 0 && len(message) > s.messageWidth {
message = message[:s.messageWidth]
}

fmt.Fprintf(&pre, "%s", message)
if s.messageWidth-pre.Len() >= 0 {
pre.WriteString(strings.Repeat(" ", s.messageWidth-pre.Len()))
}

pre.WriteString(" ")
}

var pad int
if s.stopped.IsZero() {
// spinner has a string length of 3 but a rune length of 1
// in order to align correctly, we need to pad with (3 - 1) = 2 spaces
spinner := s.parts[s.value]
pre.WriteString(spinner)
pad = len(spinner) - len([]rune(spinner))
}

var suf strings.Builder
fmt.Fprintf(&suf, "(%s)", s.elapsed())

var mid strings.Builder
f := termWidth - pre.Len() - mid.Len() - suf.Len() + pad
if f > 0 {
mid.WriteString(strings.Repeat(" ", f))
}

return pre.String() + mid.String() + suf.String()
}

func (s *Spinner) start() {
s.ticker = time.NewTicker(100 * time.Millisecond)
for range s.ticker.C {
s.value = (s.value + 1) % len(s.parts)
if !s.stopped.IsZero() {
return
}
}
}

func (s *Spinner) Stop() {
if s.stopped.IsZero() {
s.stopped = time.Now()
}
}

func (s *Spinner) elapsed() time.Duration {
stopped := s.stopped
if stopped.IsZero() {
stopped = time.Now()
}

return stopped.Sub(s.started).Round(time.Second)
}

0 comments on commit c4a3ccd

Please sign in to comment.