Skip to content

Commit

Permalink
fix(hub): ensure that an update is dispatched if any of its topics is…
Browse files Browse the repository at this point in the history
… subscribed and allowed (dunglas#688)

* chore: enable debug mode by default in VS Code

* fix(hub): ensure that an update is dispatched if any of its topics is subscribed and allowed
  • Loading branch information
dunglas authored Aug 12, 2022
1 parent a9a90ca commit aff4aab
Show file tree
Hide file tree
Showing 8 changed files with 103 additions and 88 deletions.
3 changes: 2 additions & 1 deletion .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
"env": {
"MERCURE_PUBLISHER_JWT_KEY": "!ChangeMe!",
"MERCURE_SUBSCRIBER_JWT_KEY": "!ChangeMe!",
"MERCURE_EXTRA_DIRECTIVES": "anonymous"
"MERCURE_EXTRA_DIRECTIVES": "anonymous",
"GLOBAL_OPTIONS": "debug"
},
"args": ["run", "-config", "../../Caddyfile.dev"]
}
Expand Down
6 changes: 3 additions & 3 deletions bolt_transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,12 +231,12 @@ func TestBoltTransportDispatch(t *testing.T) {
subscribedNotAuthorized := &Update{Topics: []string{"https://example.com/foo"}, Private: true}
require.Nil(t, transport.Dispatch(subscribedNotAuthorized))

public := &Update{Topics: s.Topics}
public := &Update{Topics: s.SubscribedTopics}
require.Nil(t, transport.Dispatch(public))

assert.Equal(t, public, <-s.Receive())

private := &Update{Topics: s.PrivateTopics, Private: true}
private := &Update{Topics: s.AllowedPrivateTopics, Private: true}
require.Nil(t, transport.Dispatch(private))

assert.Equal(t, private, <-s.Receive())
Expand All @@ -256,7 +256,7 @@ func TestBoltTransportClosed(t *testing.T) {
require.Nil(t, transport.Close())
require.NotNil(t, transport.AddSubscriber(s))

assert.Equal(t, transport.Dispatch(&Update{Topics: s.Topics}), ErrClosedTransport)
assert.Equal(t, transport.Dispatch(&Update{Topics: s.SubscribedTopics}), ErrClosedTransport)

_, ok := <-s.out
assert.False(t, ok)
Expand Down
4 changes: 2 additions & 2 deletions local_transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func TestLocalTransportDispatch(t *testing.T) {
s.SetTopics([]string{"http://example.com/foo"}, nil)
assert.Nil(t, transport.AddSubscriber(s))

u := &Update{Topics: s.Topics}
u := &Update{Topics: s.SubscribedTopics}
require.Nil(t, transport.Dispatch(u))
assert.Equal(t, u, <-s.Receive())
}
Expand Down Expand Up @@ -97,7 +97,7 @@ func TestLiveReading(t *testing.T) {
s.SetTopics([]string{"https://example.com"}, nil)
require.Nil(t, transport.AddSubscriber(s))

u := &Update{Topics: s.Topics}
u := &Update{Topics: s.SubscribedTopics}
assert.Nil(t, transport.Dispatch(u))

receivedUpdate := <-s.Receive()
Expand Down
6 changes: 3 additions & 3 deletions publish_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ func TestPublishOK(t *testing.T) {
assert.True(t, ok)
require.NotNil(t, u)
assert.Equal(t, "id", u.ID)
assert.Equal(t, s.Topics, u.Topics)
assert.Equal(t, s.SubscribedTopics, u.Topics)
assert.Equal(t, "Hello!", u.Data)
assert.True(t, u.Private)
}(&wg)
Expand All @@ -201,7 +201,7 @@ func TestPublishOK(t *testing.T) {

req := httptest.NewRequest(http.MethodPost, defaultHubURL, strings.NewReader(form.Encode()))
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
req.Header.Add("Authorization", "Bearer "+createDummyAuthorizedJWT(hub, rolePublisher, s.Topics))
req.Header.Add("Authorization", "Bearer "+createDummyAuthorizedJWT(hub, rolePublisher, s.SubscribedTopics))

w := httptest.NewRecorder()
hub.PublishHandler(w, req)
Expand Down Expand Up @@ -239,7 +239,7 @@ func TestPublishGenerateUUID(t *testing.T) {
h := createDummy()

s := NewSubscriber("", zap.NewNop())
s.SetTopics([]string{"http://example.com/books/1"}, s.Topics)
s.SetTopics([]string{"http://example.com/books/1"}, s.SubscribedTopics)

require.Nil(t, h.transport.AddSubscriber(s))

Expand Down
118 changes: 62 additions & 56 deletions subscriber.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,17 @@ import (

// Subscriber represents a client subscribed to a list of topics.
type Subscriber struct {
ID string
EscapedID string
Claims *claims
EscapedTopics []string
RequestLastEventID string
RemoteAddr string
Topics []string
TopicRegexps []*regexp.Regexp
PrivateTopics []string
PrivateRegexps []*regexp.Regexp
Debug bool
ID string
EscapedID string
Claims *claims
EscapedTopics []string
RequestLastEventID string
RemoteAddr string
SubscribedTopics []string
SubscribedTopicRegexps []*regexp.Regexp
AllowedPrivateTopics []string
AllowedPrivateRegexps []*regexp.Regexp
Debug bool

disconnected int32
out chan *Update
Expand Down Expand Up @@ -121,26 +121,26 @@ func (s *Subscriber) Disconnect() {
}

// SetTopics compiles topic selector regexps.
func (s *Subscriber) SetTopics(topics, privateTopics []string) {
s.Topics = topics
s.TopicRegexps = make([]*regexp.Regexp, len(topics))
for i, ts := range topics {
func (s *Subscriber) SetTopics(subscribedTopics, allowedPrivateTopics []string) {
s.SubscribedTopics = subscribedTopics
s.SubscribedTopicRegexps = make([]*regexp.Regexp, len(subscribedTopics))
for i, ts := range subscribedTopics {
var r *regexp.Regexp
if tpl, err := uritemplate.New(ts); err == nil {
r = tpl.Regexp()
}
s.TopicRegexps[i] = r
s.SubscribedTopicRegexps[i] = r
}
s.PrivateTopics = privateTopics
s.PrivateRegexps = make([]*regexp.Regexp, len(privateTopics))
for i, ts := range privateTopics {
s.AllowedPrivateTopics = allowedPrivateTopics
s.AllowedPrivateRegexps = make([]*regexp.Regexp, len(allowedPrivateTopics))
for i, ts := range allowedPrivateTopics {
var r *regexp.Regexp
if tpl, err := uritemplate.New(ts); err == nil {
r = tpl.Regexp()
}
s.PrivateRegexps[i] = r
s.AllowedPrivateRegexps[i] = r
}
s.EscapedTopics = escapeTopics(topics)
s.EscapedTopics = escapeTopics(subscribedTopics)
}

func escapeTopics(topics []string) []string {
Expand All @@ -153,36 +153,48 @@ func escapeTopics(topics []string) []string {
}

// MatchTopic checks if the current subscriber can access to the given topic.
func (s *Subscriber) MatchTopic(topic string, private bool) (match bool) {
for i, ts := range s.Topics {
if ts == "*" || ts == topic {
match = true
//
//nolint:gocognit
func (s *Subscriber) MatchTopics(topics []string, private bool) bool {
var subscribed bool
canAccess := !private

break
}
for _, topic := range topics {
if !subscribed {
for i, ts := range s.SubscribedTopics {
if ts == "*" || ts == topic {
subscribed = true

r := s.TopicRegexps[i]
if r != nil && r.MatchString(topic) {
match = true
break
}

break
r := s.SubscribedTopicRegexps[i]
if r != nil && r.MatchString(topic) {
subscribed = true

break
}
}
}
}

if !match {
return false
}
if !private {
return true
}
if !canAccess {
for i, ts := range s.AllowedPrivateTopics {
if ts == "*" || ts == topic {
canAccess = true

for i, ts := range s.PrivateTopics {
if ts == "*" || ts == topic {
return true
break
}

r := s.AllowedPrivateRegexps[i]
if r != nil && r.MatchString(topic) {
canAccess = true

break
}
}
}

r := s.PrivateRegexps[i]
if r != nil && r.MatchString(topic) {
if subscribed && canAccess {
return true
}
}
Expand All @@ -192,20 +204,14 @@ func (s *Subscriber) MatchTopic(topic string, private bool) (match bool) {

// Match checks if the current subscriber can receive the given update.
func (s *Subscriber) Match(u *Update) bool {
for _, t := range u.Topics {
if s.MatchTopic(t, u.Private) {
return true
}
}

return false
return s.MatchTopics(u.Topics, u.Private)
}

// getSubscriptions return the list of subscriptions associated to this subscriber.
func (s *Subscriber) getSubscriptions(topic, context string, active bool) []subscription {
var subscriptions []subscription //nolint:prealloc
for k, t := range s.Topics {
if topic != "" && !s.MatchTopic(topic, false) {
for k, t := range s.SubscribedTopics {
if topic != "" && !s.MatchTopics([]string{topic}, false) {
continue
}

Expand Down Expand Up @@ -233,13 +239,13 @@ func (s *Subscriber) MarshalLogObject(enc zapcore.ObjectEncoder) error {
if s.RemoteAddr != "" {
enc.AddString("remote_addr", s.RemoteAddr)
}
if s.PrivateTopics != nil {
if err := enc.AddArray("topic_selectors", stringArray(s.PrivateTopics)); err != nil {
if s.AllowedPrivateTopics != nil {
if err := enc.AddArray("topic_selectors", stringArray(s.AllowedPrivateTopics)); err != nil {
return fmt.Errorf("log error: %w", err)
}
}
if s.Topics != nil {
if err := enc.AddArray("topics", stringArray(s.Topics)); err != nil {
if s.SubscribedTopics != nil {
if err := enc.AddArray("topics", stringArray(s.SubscribedTopics)); err != nil {
return fmt.Errorf("log error: %w", err)
}
}
Expand Down
27 changes: 10 additions & 17 deletions subscriber_list.go
Original file line number Diff line number Diff line change
@@ -1,39 +1,32 @@
package mercure

import (
"strings"

"github.com/kevburnsjr/skipfilter"
)

type filter struct {
topics []string
private bool
}

type SubscriberList struct {
skipfilter *skipfilter.SkipFilter
}

func NewSubscriberList(size int) *SubscriberList {
return &SubscriberList{
skipfilter: skipfilter.New(func(s interface{}, topic interface{}) bool {
p := strings.SplitN(topic.(string), "_", 2)
if len(p) < 2 {
return false
}
skipfilter: skipfilter.New(func(s interface{}, fil interface{}) bool {
f := fil.(*filter)

return s.(*Subscriber).MatchTopic(p[1], p[0] == "p")
return s.(*Subscriber).MatchTopics(f.topics, f.private)
}, size),
}
}

func (sc *SubscriberList) MatchAny(u *Update) (res []*Subscriber) {
scopedTopics := make([]interface{}, len(u.Topics))
for i, t := range u.Topics {
if u.Private {
scopedTopics[i] = "p_" + t
} else {
scopedTopics[i] = "_" + t
}
}
f := &filter{u.Topics, u.Private}

for _, m := range sc.skipfilter.MatchAny(scopedTopics...) {
for _, m := range sc.skipfilter.MatchAny(f) {
res = append(res, m.(*Subscriber))
}

Expand Down
25 changes: 20 additions & 5 deletions subscriber_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,16 @@ import (

func TestDispatch(t *testing.T) {
s := NewSubscriber("1", zap.NewNop())
s.Topics = []string{"http://example.com"}
s.SubscribedTopics = []string{"http://example.com"}
s.SubscribedTopics = []string{"http://example.com"}
defer s.Disconnect()

// Dispatch must be non-blocking
// Messages coming from the history can be sent after live messages, but must be received first
s.Dispatch(&Update{Topics: s.Topics, Event: Event{ID: "3"}}, false)
s.Dispatch(&Update{Topics: s.Topics, Event: Event{ID: "1"}}, true)
s.Dispatch(&Update{Topics: s.Topics, Event: Event{ID: "4"}}, false)
s.Dispatch(&Update{Topics: s.Topics, Event: Event{ID: "2"}}, true)
s.Dispatch(&Update{Topics: s.SubscribedTopics, Event: Event{ID: "3"}}, false)
s.Dispatch(&Update{Topics: s.SubscribedTopics, Event: Event{ID: "1"}}, true)
s.Dispatch(&Update{Topics: s.SubscribedTopics, Event: Event{ID: "4"}}, false)
s.Dispatch(&Update{Topics: s.SubscribedTopics, Event: Event{ID: "2"}}, true)
s.HistoryDispatched("")

s.Ready()
Expand Down Expand Up @@ -56,3 +57,17 @@ func TestLogSubscriber(t *testing.T) {
assert.Contains(t, log, `"topic_selectors":["https://example.com/foo"]`)
assert.Contains(t, log, `"topics":["https://example.com/bar"]`)
}

func TestMatchTopic(t *testing.T) {
s := NewSubscriber("", zap.NewNop())
s.SetTopics([]string{"https://example.com/no-match", "https://example.com/books/{id}"}, []string{"https://example.com/users/foo/{?topic}"})

assert.False(t, s.Match(&Update{Topics: []string{"https://example.com/not-subscribed"}}))
assert.False(t, s.Match(&Update{Topics: []string{"https://example.com/not-subscribed"}, Private: true}))
assert.False(t, s.Match(&Update{Topics: []string{"https://example.com/no-match"}, Private: true}))
assert.False(t, s.Match(&Update{Topics: []string{"https://example.com/books/1"}, Private: true}))
assert.False(t, s.Match(&Update{Topics: []string{"https://example.com/books/1", "https://example.com/users/bar/?topic=https%3A%2F%2Fexample.com%2Fbooks%2F1"}, Private: true}))

assert.True(t, s.Match(&Update{Topics: []string{"https://example.com/books/1"}}))
assert.True(t, s.Match(&Update{Topics: []string{"https://example.com/books/1", "https://example.com/users/foo/?topic=https%3A%2F%2Fexample.com%2Fbooks%2F1"}, Private: true}))
}
2 changes: 1 addition & 1 deletion subscription_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ func TestSubscriptionHandler(t *testing.T) {

var subscription subscription
json.Unmarshal(w.Body.Bytes(), &subscription)
expectedSub := s.getSubscriptions(s.Topics[1], "https://mercure.rocks/", true)[1]
expectedSub := s.getSubscriptions(s.SubscribedTopics[1], "https://mercure.rocks/", true)[1]
expectedSub.LastEventID, _, _ = hub.transport.(TransportSubscribers).GetSubscribers()
assert.Equal(t, expectedSub, subscription)

Expand Down

0 comments on commit aff4aab

Please sign in to comment.