Skip to content

Commit

Permalink
Context propagation
Browse files Browse the repository at this point in the history
  • Loading branch information
teivah committed Feb 21, 2020
1 parent 5374a5b commit 95b2a9f
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 81 deletions.
16 changes: 8 additions & 8 deletions factory_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,9 @@ func Test_Defer_ComposedDup(t *testing.T) {
next <- Of(2)
next <- Of(3)
done()
}}).Map(func(i interface{}) (_ interface{}, _ error) {
}}).Map(func(_ context.Context, i interface{}) (_ interface{}, _ error) {
return i.(int) + 1, nil
}).Map(func(i interface{}) (_ interface{}, _ error) {
}).Map(func(_ context.Context, i interface{}) (_ interface{}, _ error) {
return i.(int) + 1, nil
})
Assert(context.Background(), t, obs, HasItems(3, 4, 5), HasNotRaisedError())
Expand All @@ -166,9 +166,9 @@ func Test_Defer_ComposedDup_EagerObservation(t *testing.T) {
next <- Of(2)
next <- Of(3)
done()
}}).Map(func(i interface{}) (_ interface{}, _ error) {
}}).Map(func(_ context.Context, i interface{}) (_ interface{}, _ error) {
return i.(int) + 1, nil
}, WithEagerObservation()).Map(func(i interface{}) (_ interface{}, _ error) {
}, WithEagerObservation()).Map(func(_ context.Context, i interface{}) (_ interface{}, _ error) {
return i.(int) + 1, nil
})
Assert(context.Background(), t, obs, HasItems(3, 4, 5), HasNotRaisedError())
Expand Down Expand Up @@ -211,12 +211,12 @@ func Test_FromChannel_SimpleCapacity(t *testing.T) {

func Test_FromChannel_ComposedCapacity(t *testing.T) {
obs1 := FromChannel(make(chan Item, 10)).
Map(func(_ interface{}) (interface{}, error) {
Map(func(_ context.Context, _ interface{}) (interface{}, error) {
return 1, nil
}, WithBufferedChannel(11))
assert.Equal(t, 11, cap(obs1.Observe()))

obs2 := obs1.Map(func(_ interface{}) (interface{}, error) {
obs2 := obs1.Map(func(_ context.Context, _ interface{}) (interface{}, error) {
return 1, nil
}, WithBufferedChannel(12))
assert.Equal(t, 12, cap(obs2.Observe()))
Expand All @@ -240,12 +240,12 @@ func Test_FromItems_SimpleCapacity(t *testing.T) {
}

func Test_FromItems_ComposedCapacity(t *testing.T) {
obs1 := Just([]Item{Of(1)}).Map(func(_ interface{}) (interface{}, error) {
obs1 := Just([]Item{Of(1)}).Map(func(_ context.Context, _ interface{}) (interface{}, error) {
return 1, nil
}, WithBufferedChannel(11))
assert.Equal(t, 11, cap(obs1.Observe()))

obs2 := obs1.Map(func(_ interface{}) (interface{}, error) {
obs2 := obs1.Map(func(_ context.Context, _ interface{}) (interface{}, error) {
return 1, nil
}, WithBufferedChannel(12))
assert.Equal(t, 12, cap(obs2.Observe()))
Expand Down
43 changes: 21 additions & 22 deletions observable_operator.go
Original file line number Diff line number Diff line change
Expand Up @@ -757,8 +757,8 @@ type distinctOperator struct {
keyset map[interface{}]interface{}
}

func (op *distinctOperator) next(_ context.Context, item Item, dst chan<- Item, operatorOptions operatorOptions) {
key, err := op.apply(item.V)
func (op *distinctOperator) next(ctx context.Context, item Item, dst chan<- Item, operatorOptions operatorOptions) {
key, err := op.apply(ctx, item.V)
if err != nil {
dst <- Error(err)
operatorOptions.stop()
Expand Down Expand Up @@ -806,8 +806,8 @@ type distinctUntilChangedOperator struct {
current interface{}
}

func (op *distinctUntilChangedOperator) next(_ context.Context, item Item, dst chan<- Item, operatorOptions operatorOptions) {
key, err := op.apply(item.V)
func (op *distinctUntilChangedOperator) next(ctx context.Context, item Item, dst chan<- Item, operatorOptions operatorOptions) {
key, err := op.apply(ctx, item.V)
if err != nil {
dst <- Error(err)
operatorOptions.stop()
Expand Down Expand Up @@ -1285,9 +1285,8 @@ type mapOperator struct {
apply Func
}

// TODO pass context in map?
func (op *mapOperator) next(_ context.Context, item Item, dst chan<- Item, operatorOptions operatorOptions) {
res, err := op.apply(item.V)
func (op *mapOperator) next(ctx context.Context, item Item, dst chan<- Item, operatorOptions operatorOptions) {
res, err := op.apply(ctx, item.V)
if err != nil {
dst <- Error(err)
operatorOptions.stop()
Expand All @@ -1313,7 +1312,7 @@ func (op *mapOperator) gatherNext(_ context.Context, item Item, dst chan<- Item,

// Marshal transforms the items emitted by an Observable by applying a marshalling to each item.
func (o *ObservableImpl) Marshal(marshaller Marshaller, opts ...Option) Observable {
return o.Map(func(i interface{}) (interface{}, error) {
return o.Map(func(_ context.Context, i interface{}) (interface{}, error) {
return marshaller(i)
}, opts...)
}
Expand Down Expand Up @@ -1500,9 +1499,9 @@ type reduceOperator struct {
empty bool
}

func (op *reduceOperator) next(_ context.Context, item Item, dst chan<- Item, operatorOptions operatorOptions) {
func (op *reduceOperator) next(ctx context.Context, item Item, dst chan<- Item, operatorOptions operatorOptions) {
op.empty = false
v, err := op.apply(op.acc, item.V)
v, err := op.apply(ctx, op.acc, item.V)
if err != nil {
dst <- Error(err)
operatorOptions.stop()
Expand Down Expand Up @@ -1735,8 +1734,8 @@ type scanOperator struct {
current interface{}
}

func (op *scanOperator) next(_ context.Context, item Item, dst chan<- Item, operatorOptions operatorOptions) {
v, err := op.apply(op.current, item.V)
func (op *scanOperator) next(ctx context.Context, item Item, dst chan<- Item, operatorOptions operatorOptions) {
v, err := op.apply(ctx, op.current, item.V)
if err != nil {
dst <- Error(err)
operatorOptions.stop()
Expand Down Expand Up @@ -2119,7 +2118,7 @@ func (o *ObservableImpl) StartWithIterable(iterable Iterable, opts ...Option) Ob

// SumFloat32 calculates the average of float32 emitted by an Observable and emits a float32.
func (o *ObservableImpl) SumFloat32(opts ...Option) OptionalSingle {
return o.Reduce(func(acc interface{}, elem interface{}) (interface{}, error) {
return o.Reduce(func(_ context.Context, acc interface{}, elem interface{}) (interface{}, error) {
if acc == nil {
acc = float32(0)
}
Expand All @@ -2145,7 +2144,7 @@ func (o *ObservableImpl) SumFloat32(opts ...Option) OptionalSingle {

// SumFloat64 calculates the average of float64 emitted by an Observable and emits a float64.
func (o *ObservableImpl) SumFloat64(opts ...Option) OptionalSingle {
return o.Reduce(func(acc interface{}, elem interface{}) (interface{}, error) {
return o.Reduce(func(_ context.Context, acc interface{}, elem interface{}) (interface{}, error) {
if acc == nil {
acc = float64(0)
}
Expand Down Expand Up @@ -2173,7 +2172,7 @@ func (o *ObservableImpl) SumFloat64(opts ...Option) OptionalSingle {

// SumInt64 calculates the average of integers emitted by an Observable and emits an int64.
func (o *ObservableImpl) SumInt64(opts ...Option) OptionalSingle {
return o.Reduce(func(acc interface{}, elem interface{}) (interface{}, error) {
return o.Reduce(func(_ context.Context, acc interface{}, elem interface{}) (interface{}, error) {
if acc == nil {
acc = int64(0)
}
Expand Down Expand Up @@ -2357,8 +2356,8 @@ type toMapOperator struct {
m map[interface{}]interface{}
}

func (op *toMapOperator) next(_ context.Context, item Item, dst chan<- Item, operatorOptions operatorOptions) {
k, err := op.keySelector(item.V)
func (op *toMapOperator) next(ctx context.Context, item Item, dst chan<- Item, operatorOptions operatorOptions) {
k, err := op.keySelector(ctx, item.V)
if err != nil {
dst <- Error(err)
operatorOptions.stop()
Expand Down Expand Up @@ -2397,15 +2396,15 @@ type toMapWithValueSelector struct {
m map[interface{}]interface{}
}

func (op *toMapWithValueSelector) next(_ context.Context, item Item, dst chan<- Item, operatorOptions operatorOptions) {
k, err := op.keySelector(item.V)
func (op *toMapWithValueSelector) next(ctx context.Context, item Item, dst chan<- Item, operatorOptions operatorOptions) {
k, err := op.keySelector(ctx, item.V)
if err != nil {
dst <- Error(err)
operatorOptions.stop()
return
}

v, err := op.valueSelector(item.V)
v, err := op.valueSelector(ctx, item.V)
if err != nil {
dst <- Error(err)
operatorOptions.stop()
Expand Down Expand Up @@ -2461,7 +2460,7 @@ func (op *toSliceOperator) gatherNext(_ context.Context, _ Item, _ chan<- Item,

// Unmarshal transforms the items emitted by an Observable by applying an unmarshalling to each item.
func (o *ObservableImpl) Unmarshal(unmarshaller Unmarshaller, factory func() interface{}, opts ...Option) Observable {
return o.Map(func(i interface{}) (interface{}, error) {
return o.Map(func(_ context.Context, i interface{}) (interface{}, error) {
v := factory()
err := unmarshaller(i.([]byte), v)
if err != nil {
Expand Down Expand Up @@ -2507,7 +2506,7 @@ func (o *ObservableImpl) ZipFromIterable(iterable Iterable, zipper Func2, opts .
next <- i2
return
}
v, err := zipper(i1.V, i2.V)
v, err := zipper(ctx, i1.V, i2.V)
if err != nil {
next <- Error(err)
return
Expand Down
13 changes: 7 additions & 6 deletions observable_operator_bench_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package rxgo

import (
"context"
"testing"
"time"
)
Expand All @@ -16,7 +17,7 @@ func Benchmark_Range_Sequential(b *testing.B) {
for i := 0; i < b.N; i++ {
b.StopTimer()
obs := Range(0, benchNumberOfElementsLarge, WithBufferedChannel(benchChannelCap)).
Map(func(i interface{}) (interface{}, error) {
Map(func(_ context.Context, i interface{}) (interface{}, error) {
return i, nil
})
b.StartTimer()
Expand All @@ -28,7 +29,7 @@ func Benchmark_Range_Serialize(b *testing.B) {
for i := 0; i < b.N; i++ {
b.StopTimer()
obs := Range(0, benchNumberOfElementsLarge, WithBufferedChannel(benchChannelCap)).
Map(func(i interface{}) (interface{}, error) {
Map(func(_ context.Context, i interface{}) (interface{}, error) {
return i, nil
}, WithCPUPool(), WithBufferedChannel(benchChannelCap))
b.StartTimer()
Expand All @@ -40,7 +41,7 @@ func Benchmark_Reduce_Sequential(b *testing.B) {
for i := 0; i < b.N; i++ {
b.StopTimer()
obs := Range(0, benchNumberOfElementsSmall, WithBufferedChannel(benchChannelCap)).
Reduce(func(acc interface{}, elem interface{}) (interface{}, error) {
Reduce(func(_ context.Context, acc interface{}, elem interface{}) (interface{}, error) {
// Simulate a blocking IO call
time.Sleep(5 * time.Millisecond)
if a, ok := acc.(int); ok {
Expand All @@ -61,7 +62,7 @@ func Benchmark_Reduce_Parallel(b *testing.B) {
for i := 0; i < b.N; i++ {
b.StopTimer()
obs := Range(0, benchNumberOfElementsSmall, WithBufferedChannel(benchChannelCap)).
Reduce(func(acc interface{}, elem interface{}) (interface{}, error) {
Reduce(func(_ context.Context, acc interface{}, elem interface{}) (interface{}, error) {
// Simulate a blocking IO call
time.Sleep(5 * time.Millisecond)
if a, ok := acc.(int); ok {
Expand All @@ -82,7 +83,7 @@ func Benchmark_Map_Sequential(b *testing.B) {
for i := 0; i < b.N; i++ {
b.StopTimer()
obs := Range(0, benchNumberOfElementsSmall, WithBufferedChannel(benchChannelCap)).
Map(func(i interface{}) (interface{}, error) {
Map(func(_ context.Context, i interface{}) (interface{}, error) {
// Simulate a blocking IO call
time.Sleep(5 * time.Millisecond)
return i, nil
Expand All @@ -96,7 +97,7 @@ func Benchmark_Map_Parallel(b *testing.B) {
for i := 0; i < b.N; i++ {
b.StopTimer()
obs := Range(0, benchNumberOfElementsSmall, WithBufferedChannel(benchChannelCap)).
Map(func(i interface{}) (interface{}, error) {
Map(func(_ context.Context, i interface{}) (interface{}, error) {
// Simulate a blocking IO call
time.Sleep(5 * time.Millisecond)
return i, nil
Expand Down
Loading

0 comments on commit 95b2a9f

Please sign in to comment.