Skip to content

Commit

Permalink
shards: unload shards only once
Browse files Browse the repository at this point in the history
Add a test for shard loader

Fixes google#38.

Change-Id: I5d97e10d0624018c439e0c65341ee4621f9118c8
  • Loading branch information
hanwen committed Jan 15, 2018
1 parent 13a8efc commit dca4ed5
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 6 deletions.
14 changes: 8 additions & 6 deletions shards/watcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ type shardWatcher struct {
dir string
timestamps map[string]time.Time
loader shardLoader
quit chan struct{}
quit chan<- struct{}
}

func (sw *shardWatcher) Close() error {
Expand All @@ -47,17 +47,18 @@ func (sw *shardWatcher) Close() error {
}

func NewDirectoryWatcher(dir string, loader shardLoader) (io.Closer, error) {
quitter := make(chan struct{}, 1)
sw := &shardWatcher{
dir: dir,
timestamps: map[string]time.Time{},
loader: loader,
quit: make(chan struct{}, 1),
quit: quitter,
}
if err := sw.scan(); err != nil {
return nil, err
}

if err := sw.watch(); err != nil {
if err := sw.watch(quitter); err != nil {
return nil, err
}

Expand Down Expand Up @@ -93,7 +94,7 @@ func (s *shardWatcher) scan() error {
return err
}

if len(fs) == 0 {
if len(s.timestamps) == 0 && len(fs) == 0 {
return fmt.Errorf("directory %s is empty", s.dir)
}

Expand All @@ -120,6 +121,7 @@ func (s *shardWatcher) scan() error {
for k := range s.timestamps {
if _, ok := ts[k]; !ok {
toDrop = append(toDrop, k)
delete(s.timestamps, k)
}
}

Expand All @@ -135,7 +137,7 @@ func (s *shardWatcher) scan() error {
return nil
}

func (s *shardWatcher) watch() error {
func (s *shardWatcher) watch(quitter <-chan struct{}) error {
watcher, err := fsnotify.NewWatcher()
if err != nil {
return err
Expand All @@ -153,7 +155,7 @@ func (s *shardWatcher) watch() error {
if err != nil {
log.Println("watcher error:", err)
}
case <-s.quit:
case <-quitter:
watcher.Close()
return
}
Expand Down
106 changes: 106 additions & 0 deletions shards/watcher_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
// Copyright 2018 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package shards

import (
"io/ioutil"
"os"
"path/filepath"
"strings"
"testing"
"time"
)

type loggingLoader struct {
loads chan string
drops chan string
}

func (l *loggingLoader) load(k string) {
l.loads <- k
}

func (l *loggingLoader) drop(k string) {
l.drops <- k
}

func advanceFS() {
time.Sleep(10 * time.Millisecond)
}

func TestDirWatcherUnloadOnce(t *testing.T) {
dir, err := ioutil.TempDir("", "")
if err != nil {
t.Fatal(err)
}

logger := &loggingLoader{
loads: make(chan string, 10),
drops: make(chan string, 10),
}
_, err = NewDirectoryWatcher(dir, logger)
if err == nil || !strings.Contains(err.Error(), "empty") {
t.Fatalf("got %v, want 'empty'", err)
}

shard := filepath.Join(dir, "foo.zoekt")
if err := ioutil.WriteFile(shard, []byte("hello"), 0644); err != nil {
t.Fatalf("WriteFile: %v", err)
}

dw, err := NewDirectoryWatcher(dir, logger)
if err != nil {
t.Fatalf("NewDirectoryWatcher: %v", err)
}
defer dw.Close()

if got := <-logger.loads; got != shard {
t.Fatalf("got load event %v, want %v", got, shard)
}

// Must sleep because of FS timestamp resolution.
advanceFS()
if err := ioutil.WriteFile(shard, []byte("changed"), 0644); err != nil {
t.Fatalf("WriteFile: %v", err)
}

if got := <-logger.loads; got != shard {
t.Fatalf("got load event %v, want %v", got, shard)
}

advanceFS()
if err := os.Remove(shard); err != nil {
t.Fatalf("Remove: %v", err)
}

if got := <-logger.drops; got != shard {
t.Fatalf("got drops event %v, want %v", got, shard)
}

advanceFS()
if err := ioutil.WriteFile(shard+".bla", []byte("changed"), 0644); err != nil {
t.Fatalf("WriteFile: %v", err)
}

dw.Close()

select {
case k := <-logger.loads:
t.Errorf("spurious load of %q", k)
case k := <-logger.drops:
t.Errorf("spurious drops of %q", k)
default:
}
}

0 comments on commit dca4ed5

Please sign in to comment.