Skip to content

Commit

Permalink
First serialize option implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
teivah committed Feb 25, 2020
1 parent ea311af commit 4905cb2
Show file tree
Hide file tree
Showing 6 changed files with 365 additions and 60 deletions.
162 changes: 157 additions & 5 deletions observable.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,18 +115,50 @@ func observable(iterable Iterable, operatorFactory func() operator, forceSeq, by
return &ObservableImpl{iterable: newChannelIterable(next)}
}

if forceSeq || !parallel {
return &ObservableImpl{
iterable: newFactoryIterable(func(propagatedOptions ...Option) <-chan Item {
mergedOptions := append(opts, propagatedOptions...)
option := parseOptions(mergedOptions...)

next := option.buildChannel()
ctx := option.buildContext()
runSeq(ctx, next, iterable, operatorFactory, option, mergedOptions...)
return next
}),
}
}

if serialized, f := option.isSerialized(); serialized {
ch := make(chan int, 1)
notif := make(chan int, 1)
obs := &ObservableImpl{
iterable: newFactoryIterable(func(propagatedOptions ...Option) <-chan Item {
mergedOptions := append(opts, propagatedOptions...)
option := parseOptions(mergedOptions...)

next := option.buildChannel()
ctx := option.buildContext()
observe := iterable.Observe(opts...)
go func() {
ch <- <-notif
runPar2(ctx, observe, next, iterable, operatorFactory, bypassGather, option, mergedOptions...)
}()
runFirstItem(ctx, f, notif, observe, next, iterable, operatorFactory, bypassGather, option, mergedOptions...)
return next
}),
}
return obs.serialize(ch, f)
}

return &ObservableImpl{
iterable: newFactoryIterable(func(propagatedOptions ...Option) <-chan Item {
mergedOptions := append(opts, propagatedOptions...)
option := parseOptions(mergedOptions...)

next := option.buildChannel()
ctx := option.buildContext()
if forceSeq || !parallel {
runSeq(ctx, next, iterable, operatorFactory, option, mergedOptions...)
} else {
runPar(ctx, next, iterable, operatorFactory, bypassGather, option, mergedOptions...)
}
runPar(ctx, next, iterable, operatorFactory, bypassGather, option, mergedOptions...)
return next
}),
}
Expand Down Expand Up @@ -318,6 +350,126 @@ func runPar(ctx context.Context, next chan Item, iterable Iterable, operatorFact
}()
}

func runPar2(ctx context.Context, observe <-chan Item, next chan Item, iterable Iterable, operatorFactory func() operator, bypassGather bool, option Option, opts ...Option) {
wg := sync.WaitGroup{}
_, pool := option.getPool()
wg.Add(pool)

var gather chan Item
if bypassGather {
gather = next
} else {
gather = make(chan Item, 1)

// Gather
go func() {
op := operatorFactory()
stopped := false
operator := operatorOptions{
stop: func() {
if option.getErrorStrategy() == StopOnError {
stopped = true
}
},
resetIterable: func(newIterable Iterable) {
observe = newIterable.Observe(opts...)
},
}
for item := range gather {
if stopped {
break
}
if item.Error() {
op.err(ctx, item, next, operator)
} else {
op.gatherNext(ctx, item, next, operator)
}
}
op.end(ctx, next)
close(next)
}()
}

// Scatter
for i := 0; i < pool; i++ {
go func() {
op := operatorFactory()
stopped := false
operator := operatorOptions{
stop: func() {
if option.getErrorStrategy() == StopOnError {
stopped = true
}
},
resetIterable: func(newIterable Iterable) {
observe = newIterable.Observe(opts...)
},
}
defer wg.Done()
for !stopped {
select {
case <-ctx.Done():
return
case item, ok := <-observe:
if !ok {
if !bypassGather {
Of(op).SendContext(ctx, gather)
}
return
}
if item.Error() {
op.err(ctx, item, gather, operator)
} else {
op.next(ctx, item, gather, operator)
}
}
}
}()
}

go func() {
wg.Wait()
close(gather)
}()
}

func runFirstItem(ctx context.Context, f func(interface{}) int, notif chan int, observe <-chan Item, next chan Item, iterable Iterable, operatorFactory func() operator, bypassGather bool, option Option, opts ...Option) {
go func() {
op := operatorFactory()
stopped := false
operator := operatorOptions{
stop: func() {
if option.getErrorStrategy() == StopOnError {
stopped = true
}
},
resetIterable: func(newIterable Iterable) {
observe = newIterable.Observe(opts...)
},
}

loop:
for !stopped {
select {
case <-ctx.Done():
break loop
case i, ok := <-observe:
if !ok {
break loop
}
if i.Error() {
op.err(ctx, i, next, operator)
} else {
op.next(ctx, i, next, operator)
notif <- f(i.V)
}
}
}
op.end(ctx, next)
close(next)
}()
}

func customObservableOperator(f func(ctx context.Context, next chan Item, option Option, opts ...Option), opts ...Option) Observable {
option := parseOptions(opts...)

Expand Down
98 changes: 98 additions & 0 deletions observable_operator.go
Original file line number Diff line number Diff line change
Expand Up @@ -1936,6 +1936,104 @@ func (o *ObservableImpl) Serialize(from int, identifier func(interface{}) int, o
}
}

func (o *ObservableImpl) serialize(fromCh chan int, identifier func(interface{}) int, opts ...Option) Observable {
option := parseOptions(opts...)
next := option.buildChannel()

ctx := option.buildContext()
mutex := sync.Mutex{}
minHeap := binaryheap.NewWith(func(a, b interface{}) int {
return a.(int) - b.(int)
})
status := make(map[int]interface{})
notif := make(chan struct{})
ready := make(chan struct{})

var from int
var counter int64
src := o.Observe(opts...)
go func() {
from = <-fromCh
minHeap.Push(from)
counter = int64(from)
close(ready)
}()

// Scatter
go func() {
<-ready
defer close(notif)

for {
select {
case <-ctx.Done():
return
case item, ok := <-src:
if !ok {
return
}
if item.Error() {
next <- item
return
}

id := identifier(item.V)
mutex.Lock()
if id != from {
minHeap.Push(id)
}
status[id] = item.V
mutex.Unlock()
select {
case <-ctx.Done():
return
case notif <- struct{}{}:
}
}
}
}()

// Gather
go func() {
<-ready
defer close(next)

for {
select {
case <-ctx.Done():
return
case _, ok := <-notif:
if !ok {
return
}

mutex.Lock()
for !minHeap.Empty() {
v, _ := minHeap.Peek()
id := v.(int)
if atomic.LoadInt64(&counter) == int64(id) {
if itemValue, contains := status[id]; contains {
minHeap.Pop()
delete(status, id)
mutex.Unlock()
Of(itemValue).SendContext(ctx, next)
mutex.Lock()
atomic.AddInt64(&counter, 1)
continue
}
}
break
}
mutex.Unlock()
}
}
}()

return &ObservableImpl{
iterable: newChannelIterable(next),
}
}

// Skip suppresses the first n items in the original Observable and
// returns a new Observable with the rest items.
// Cannot be run in parallel.
Expand Down
16 changes: 16 additions & 0 deletions observable_operator_bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,22 @@ func Benchmark_Range_Serialize(b *testing.B) {
}
}

func Benchmark_Range_OptionSerialize(b *testing.B) {
for i := 0; i < b.N; i++ {
b.StopTimer()
obs := Range(0, benchNumberOfElementsSmall, WithBufferedChannel(benchChannelCap)).
Map(func(_ context.Context, i interface{}) (interface{}, error) {
// Simulate a blocking IO call
time.Sleep(5 * time.Millisecond)
return i, nil
}, WithCPUPool(), WithBufferedChannel(benchChannelCap), Serialize(func(i interface{}) int {
return i.(int)
}))
b.StartTimer()
<-obs.Run()
}
}

func Benchmark_Reduce_Sequential(b *testing.B) {
for i := 0; i < b.N; i++ {
b.StopTimer()
Expand Down
78 changes: 78 additions & 0 deletions observable_operator_option_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package rxgo

import (
"context"
"testing"

"github.com/stretchr/testify/assert"
)

func Test_Observable_Option_WithOnErrorStrategy_Single(t *testing.T) {
obs := testObservable(1, 2, 3).
Map(func(_ context.Context, i interface{}) (interface{}, error) {
if i == 2 {
return nil, errFoo
}
return i, nil
}, WithErrorStrategy(ContinueOnError))
Assert(context.Background(), t, obs, HasItems(1, 3), HasError(errFoo))
}

func Test_Observable_Option_WithOnErrorStrategy_Propagate(t *testing.T) {
obs := testObservable(1, 2, 3).
Map(func(_ context.Context, i interface{}) (interface{}, error) {
if i == 1 {
return nil, errFoo
}
return i, nil
}).
Map(func(_ context.Context, i interface{}) (interface{}, error) {
if i == 2 {
return nil, errBar
}
return i, nil
}, WithErrorStrategy(ContinueOnError))
Assert(context.Background(), t, obs, HasItems(3), HasErrors(errFoo, errBar))
}

func Test_Observable_Option_SimpleCapacity(t *testing.T) {
ch := Just(1, WithBufferedChannel(5)).Observe()
assert.Equal(t, 5, cap(ch))
}

func Test_Observable_Option_ComposedCapacity(t *testing.T) {
obs1 := Just(1).Map(func(_ context.Context, _ interface{}) (interface{}, error) {
return 1, nil
}, WithBufferedChannel(11))
obs2 := obs1.Map(func(_ context.Context, _ interface{}) (interface{}, error) {
return 1, nil
}, WithBufferedChannel(12))

assert.Equal(t, 11, cap(obs1.Observe()))
assert.Equal(t, 12, cap(obs2.Observe()))
}

func Test_Observable_Option_ContextPropagation(t *testing.T) {
expectedCtx := context.Background()
var gotCtx context.Context
<-Just(1).Map(func(ctx context.Context, i interface{}) (interface{}, error) {
gotCtx = ctx
return i, nil
}, WithContext(expectedCtx)).Run()
assert.Equal(t, expectedCtx, gotCtx)
}

func Test_Observable_Option_Serialize(t *testing.T) {
idx := 0
<-Range(0, 10000).Map(func(_ context.Context, i interface{}) (interface{}, error) {
return i, nil
}, WithBufferedChannel(1), WithCPUPool(), Serialize(func(i interface{}) int {
return i.(int)
})).DoOnNext(func(i interface{}) {
v := i.(int)
if v != idx {
assert.FailNow(t, "not sequential", "expected=%d, got=%d", idx, v)
}
idx++
})
}
Loading

0 comments on commit 4905cb2

Please sign in to comment.