Skip to content

Commit

Permalink
Add serialization corruption
Browse files Browse the repository at this point in the history
  • Loading branch information
bstrausser committed May 2, 2024
1 parent c2d3e29 commit 331d5c0
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 3 deletions.
5 changes: 4 additions & 1 deletion arraycontainer.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ type arrayContainer struct {
content []uint16
}

var ErrArrayIncorrectSort = errors.New("incorrectly sorted array")

func (ac *arrayContainer) String() string {
s := "{"
for it := ac.getShortIterator(); it.hasNext(); {
Expand Down Expand Up @@ -1097,6 +1099,7 @@ func (ac *arrayContainer) addOffset(x uint16) (container, container) {
func (ac *arrayContainer) validate() error {
cardinality := ac.getCardinality()

// TODO use ERR consts
if cardinality <= 0 {
return errors.New("zero or negative size")
}
Expand All @@ -1109,7 +1112,7 @@ func (ac *arrayContainer) validate() error {
for i := 1; i < len(ac.content); i++ {
next := ac.content[i]
if previous > next {
return errors.New("incorrectly sorted array")
return ErrArrayIncorrectSort
}
previous = next

Expand Down
70 changes: 70 additions & 0 deletions roaring_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2678,6 +2678,76 @@ func TestBitMapValidation(t *testing.T) {
assert.NoError(t, bm.Validate())
}

func TestBitMapValidationFromDeserialization(t *testing.T) {
// To understand what is going on here, read https://github.com/RoaringBitmap/RoaringFormatSpec

defer func() {
if err := recover(); err != nil {
// TODO assert on the error type.
fmt.Println("panicked")
}
}()

bm := NewBitmap()

// Maintainers: If you change this construction you must change the statements below.
// The tests expect a certain size, with values at certain location.
// TODO: A good extension would be to inspect the map and dynamically figure out what you can corrupt
randomEntries := make([]uint32, 0, 10)
for i := 0; i < 10; i++ {
randomEntries = append(randomEntries, uint32(i))
}
bm.AddMany(randomEntries)
assert.NoError(t, bm.Validate())
serialized, err := bm.ToBytes()
assert.NoError(t, err)

deserializedBitMap := NewBitmap()
deserializedBitMap.MustReadFrom(bytes.NewReader(serialized))
deserializedBitMap.Validate()

// corrupt the byte stream
// break sort order, serialized[34] equals 9 in a correct sort
serialized[34] = 0
corruptedDeserializedBitMap := NewBitmap()
corruptedDeserializedBitMap.MustReadFrom(bytes.NewReader(serialized))
// We will never hit this because of the recover
t.Errorf("did not panic")
}

func TestBitMapValidationFromDeserializationRuns(t *testing.T) {
// See above tests for more information

bm := NewBitmap()
bm.AddRange(100, 110)
assert.NoError(t, bm.Validate())
serialized, err := bm.ToBytes()
serialized[13] = 0
assert.NoError(t, err)
corruptedDeserializedBitMap := NewBitmap()
corruptedDeserializedBitMap.ReadFrom(bytes.NewReader(serialized))
assert.ErrorIs(t, corruptedDeserializedBitMap.Validate(), ErrRunIntervalLength)
}

func TestBitMapValidationFromDeserializationNumRuns(t *testing.T) {
// See above tests for more information

bm := NewBitmap()
bm.AddRange(100, 110)
bm.AddRange(115, 125)
// assert.NoError(t, bm.Validate())
serialized, err := bm.ToBytes()
assert.NoError(t, err)
// Force run overlap
serialized[15] = 108

corruptedDeserializedBitMap := NewBitmap()
corruptedDeserializedBitMap.ReadFrom(bytes.NewReader(serialized))
v := corruptedDeserializedBitMap.Validate()
fmt.Println(v)
assert.ErrorIs(t, corruptedDeserializedBitMap.Validate(), ErrRunIntervalOverlap)
}

func BenchmarkFromDense(b *testing.B) {
testDense(func(name string, rb *Bitmap) {
dense := make([]uint64, rb.DenseSize())
Expand Down
8 changes: 6 additions & 2 deletions runcontainer.go
Original file line number Diff line number Diff line change
Expand Up @@ -1509,9 +1509,11 @@ func (iv interval16) isNonContiguousDisjoint(b interval16) bool {
if nonContiguous1 || nonContiguous2 {
return false
}
ivl := iv.last()
bl := b.last()

c1 := iv.start <= b.start && b.start <= iv.last()
c2 := b.start <= iv.start && iv.start <= b.last()
c1 := iv.start <= b.start && b.start <= ivl
c2 := b.start <= iv.start && iv.start <= bl

return !c1 && !c2
}
Expand Down Expand Up @@ -2662,6 +2664,8 @@ func (rc *runContainer16) validate() error {
return ErrRunIntervalEqual
}

// only check the start of runs
// if the run length overlap the next check will catch that.
if outer.start >= inner.start {
return ErrRunNonSorted
}
Expand Down

0 comments on commit 331d5c0

Please sign in to comment.