Skip to content

Commit

Permalink
query/batch: improve context usage (apache#1228)
Browse files Browse the repository at this point in the history
This also fixes an issue where using WithContext mutated the original
query, it now does what http.Request.WithContext does and returns a
shallow copy of the query.

Rename GetContext() to Context() and have it default o returning
context.Background() when no context has been set.

Use context to cancel speculated queries so they dont leak or run beyond
when one wins.

Fix hiding errors when speculating queries.
  • Loading branch information
Zariel authored Oct 25, 2018
1 parent 2127b8d commit 271c061
Show file tree
Hide file tree
Showing 9 changed files with 169 additions and 197 deletions.
25 changes: 16 additions & 9 deletions cassandra_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ func TestObserve_Pagination(t *testing.T) {
Iter().Scanner()
for i := 0; i < 50; i++ {
if !scanner.Next() {
t.Fatalf("next: should still be true: %d", i)
t.Fatalf("next: should still be true: %d: %v", i, scanner.Err())
}
if i%10 == 0 {
if observedRows != 10 {
Expand Down Expand Up @@ -1354,38 +1354,45 @@ func injectInvalidPreparedStatement(t *testing.T, session *Session, table string
}

func TestPrepare_MissingSchemaPrepare(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

s := createSession(t)
conn := getRandomConn(t, s)
defer s.Close()

insertQry := &Query{stmt: "INSERT INTO invalidschemaprep (val) VALUES (?)", values: []interface{}{5}, cons: s.cons,
session: s, pageSize: s.pageSize, trace: s.trace,
prefetch: s.prefetch, rt: s.cfg.RetryPolicy}

if err := conn.executeQuery(insertQry).err; err == nil {
insertQry := s.Query("INSERT INTO invalidschemaprep (val) VALUES (?)", 5)
if err := conn.executeQuery(ctx, insertQry).err; err == nil {
t.Fatal("expected error, but got nil.")
}

if err := createTable(s, "CREATE TABLE gocql_test.invalidschemaprep (val int, PRIMARY KEY (val))"); err != nil {
t.Fatal("create table:", err)
}

if err := conn.executeQuery(insertQry).err; err != nil {
if err := conn.executeQuery(ctx, insertQry).err; err != nil {
t.Fatal(err) // unconfigured columnfamily
}
}

func TestPrepare_ReprepareStatement(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

session := createSession(t)
defer session.Close()

stmt, conn := injectInvalidPreparedStatement(t, session, "test_reprepare_statement")
query := session.Query(stmt, "bar")
if err := conn.executeQuery(query).Close(); err != nil {
if err := conn.executeQuery(ctx, query).Close(); err != nil {
t.Fatalf("Failed to execute query for reprepare statement: %v", err)
}
}

func TestPrepare_ReprepareBatch(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

session := createSession(t)
defer session.Close()

Expand All @@ -1396,7 +1403,7 @@ func TestPrepare_ReprepareBatch(t *testing.T) {
stmt, conn := injectInvalidPreparedStatement(t, session, "test_reprepare_statement_batch")
batch := session.NewBatch(UnloggedBatch)
batch.Query(stmt, "bar")
if err := conn.executeBatch(batch).Close(); err != nil {
if err := conn.executeBatch(ctx, batch).Close(); err != nil {
t.Fatalf("Failed to execute query for reprepare statement: %v", err)
}
}
Expand Down
51 changes: 26 additions & 25 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -964,7 +964,7 @@ func marshalQueryValue(typ TypeInfo, value interface{}, dst *queryValues) error
return nil
}

func (c *Conn) executeQuery(qry *Query) *Iter {
func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter {
params := queryParams{
consistency: qry.cons,
}
Expand Down Expand Up @@ -992,7 +992,7 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
if qry.shouldPrepare() {
// Prepare all DML queries. Other queries can not be prepared.
var err error
info, err = c.prepareStatement(qry.context, qry.stmt, qry.trace)
info, err = c.prepareStatement(ctx, qry.stmt, qry.trace)
if err != nil {
return &Iter{err: err}
}
Expand Down Expand Up @@ -1043,7 +1043,7 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
}
}

framer, err := c.exec(qry.context, frame, qry.trace)
framer, err := c.exec(ctx, frame, qry.trace)
if err != nil {
return &Iter{err: err}
}
Expand All @@ -1070,19 +1070,18 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
if params.skipMeta {
if info != nil {
iter.meta = info.response
iter.meta.pagingState = x.meta.pagingState
iter.meta.pagingState = copyBytes(x.meta.pagingState)
} else {
return &Iter{framer: framer, err: errors.New("gocql: did not receive metadata but prepared info is nil")}
}
} else {
iter.meta = x.meta
}

if len(x.meta.pagingState) > 0 && !qry.disableAutoPage {
if x.meta.morePages() && !qry.disableAutoPage {
iter.next = &nextIter{
qry: *qry,
pos: int((1 - qry.prefetch) * float64(x.numRows)),
conn: c,
qry: qry,
pos: int((1 - qry.prefetch) * float64(x.numRows)),
}

iter.next.qry.pageState = copyBytes(x.meta.pagingState)
Expand All @@ -1096,7 +1095,7 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
return &Iter{framer: framer}
case *schemaChangeKeyspace, *schemaChangeTable, *schemaChangeFunction, *schemaChangeAggregate, *schemaChangeType:
iter := &Iter{framer: framer}
if err := c.awaitSchemaAgreement(); err != nil {
if err := c.awaitSchemaAgreement(ctx); err != nil {
// TODO: should have this behind a flag
Logger.Println(err)
}
Expand All @@ -1107,7 +1106,7 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
case *RequestErrUnprepared:
stmtCacheKey := c.session.stmtsLRU.keyFor(c.addr, c.currentKeyspace, qry.stmt)
if c.session.stmtsLRU.remove(stmtCacheKey) {
return c.executeQuery(qry)
return c.executeQuery(ctx, qry)
}

return &Iter{err: x, framer: framer}
Expand Down Expand Up @@ -1167,7 +1166,7 @@ func (c *Conn) UseKeyspace(keyspace string) error {
return nil
}

func (c *Conn) executeBatch(batch *Batch) *Iter {
func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter {
if c.version == protoVersion1 {
return &Iter{err: ErrUnsupported}
}
Expand All @@ -1190,7 +1189,7 @@ func (c *Conn) executeBatch(batch *Batch) *Iter {
b := &req.statements[i]

if len(entry.Args) > 0 || entry.binding != nil {
info, err := c.prepareStatement(batch.context, entry.Stmt, nil)
info, err := c.prepareStatement(batch.Context(), entry.Stmt, nil)
if err != nil {
return &Iter{err: err}
}
Expand Down Expand Up @@ -1233,7 +1232,7 @@ func (c *Conn) executeBatch(batch *Batch) *Iter {
}

// TODO: should batch support tracing?
framer, err := c.exec(batch.context, req, nil)
framer, err := c.exec(batch.Context(), req, nil)
if err != nil {
return &Iter{err: err}
}
Expand All @@ -1254,7 +1253,7 @@ func (c *Conn) executeBatch(batch *Batch) *Iter {
}

if found {
return c.executeBatch(batch)
return c.executeBatch(ctx, batch)
} else {
return &Iter{err: x, framer: framer}
}
Expand All @@ -1273,13 +1272,13 @@ func (c *Conn) executeBatch(batch *Batch) *Iter {
}
}

func (c *Conn) query(statement string, values ...interface{}) (iter *Iter) {
func (c *Conn) query(ctx context.Context, statement string, values ...interface{}) (iter *Iter) {
q := c.session.Query(statement, values...).Consistency(One)
q.trace = nil
return c.executeQuery(q)
return c.executeQuery(ctx, q)
}

func (c *Conn) awaitSchemaAgreement() (err error) {
func (c *Conn) awaitSchemaAgreement(ctx context.Context) (err error) {
const (
peerSchemas = "SELECT schema_version, peer FROM system.peers"
localSchemas = "SELECT schema_version FROM system.local WHERE key='local'"
Expand All @@ -1289,7 +1288,7 @@ func (c *Conn) awaitSchemaAgreement() (err error) {

endDeadline := time.Now().Add(c.session.cfg.MaxWaitSchemaAgreement)
for time.Now().Before(endDeadline) {
iter := c.query(peerSchemas)
iter := c.query(ctx, peerSchemas)

versions = make(map[string]struct{})

Expand All @@ -1309,7 +1308,7 @@ func (c *Conn) awaitSchemaAgreement() (err error) {
goto cont
}

iter = c.query(localSchemas)
iter = c.query(ctx, localSchemas)
for iter.Scan(&schemaVersion) {
versions[schemaVersion] = struct{}{}
schemaVersion = ""
Expand All @@ -1324,11 +1323,15 @@ func (c *Conn) awaitSchemaAgreement() (err error) {
}

cont:
time.Sleep(200 * time.Millisecond)
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(200 * time.Millisecond):
}
}

if err != nil {
return
return err
}

schemas := make([]string, 0, len(versions))
Expand All @@ -1340,10 +1343,8 @@ func (c *Conn) awaitSchemaAgreement() (err error) {
return fmt.Errorf("gocql: cluster schema versions not consistent: %+v", schemas)
}

const localHostInfo = "SELECT * FROM system.local WHERE key='local'"

func (c *Conn) localHostInfo() (*HostInfo, error) {
row, err := c.query(localHostInfo).rowMap()
func (c *Conn) localHostInfo(ctx context.Context) (*HostInfo, error) {
row, err := c.query(ctx, "SELECT * FROM system.local WHERE key='local'").rowMap()
if err != nil {
return nil, err
}
Expand Down
45 changes: 8 additions & 37 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ func TestCancel(t *testing.T) {
}
defer db.Close()

qry := db.Query("timeout")
qry := db.Query("timeout").WithContext(ctx)

// Make sure we finish the query without leftovers
var wg sync.WaitGroup
Expand All @@ -313,7 +313,7 @@ func TestCancel(t *testing.T) {
}()

// The query will timeout after about 1 seconds, so cancel it after a short pause
time.AfterFunc(20*time.Millisecond, qry.Cancel)
time.AfterFunc(20*time.Millisecond, cancel)
wg.Wait()
}

Expand Down Expand Up @@ -780,41 +780,11 @@ func TestStream0(t *testing.T) {
}
}

func TestConnClosedBlocked(t *testing.T) {
t.Skip("FLAKE: skipping test flake see https://github.com/gocql/gocql/issues/1088")
// issue 664
const proto = 3

srv := NewTestServer(t, proto, context.Background())
defer srv.Stop()
errorHandler := connErrorHandlerFn(func(conn *Conn, err error, closed bool) {
t.Log(err)
})

s, err := srv.session()
if err != nil {
t.Fatal(err)
}
defer s.Close()

conn, err := s.connect(srv.host(), errorHandler)
if err != nil {
t.Fatal(err)
}

if err := conn.conn.Close(); err != nil {
t.Fatal(err)
}

// This will block indefintaly if #664 is not fixed
err = conn.executeQuery(&Query{stmt: "void"}).Close()
if !strings.HasSuffix(err.Error(), "use of closed network connection") {
t.Fatalf("expected to get use of closed networking connection error got: %v\n", err)
}
}

func TestContext_Timeout(t *testing.T) {
srv := NewTestServer(t, defaultProto, context.Background())
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

srv := NewTestServer(t, defaultProto, ctx)
defer srv.Stop()

cluster := testCluster(defaultProto, srv.Address)
Expand All @@ -825,8 +795,9 @@ func TestContext_Timeout(t *testing.T) {
}
defer db.Close()

ctx, cancel := context.WithCancel(context.Background())
ctx, cancel = context.WithCancel(ctx)
cancel()

err = db.Query("timeout").WithContext(ctx).Exec()
if err != context.Canceled {
t.Fatalf("expected to get context cancel error: %v got %v", context.Canceled, err)
Expand Down
6 changes: 3 additions & 3 deletions control.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ func (c *controlConn) setupConn(conn *Conn) error {

// TODO(zariel): do we need to fetch host info everytime
// the control conn connects? Surely we have it cached?
host, err := conn.localHostInfo()
host, err := conn.localHostInfo(context.TODO())
if err != nil {
return err
}
Expand Down Expand Up @@ -446,7 +446,7 @@ func (c *controlConn) query(statement string, values ...interface{}) (iter *Iter

for {
iter = c.withConn(func(conn *Conn) *Iter {
return conn.executeQuery(q)
return conn.executeQuery(context.TODO(), q)
})

if gocqlDebug && iter.err != nil {
Expand All @@ -464,7 +464,7 @@ func (c *controlConn) query(statement string, values ...interface{}) (iter *Iter

func (c *controlConn) awaitSchemaAgreement() error {
return c.withConn(func(conn *Conn) *Iter {
return &Iter{err: conn.awaitSchemaAgreement()}
return &Iter{err: conn.awaitSchemaAgreement(context.TODO())}
}).err
}

Expand Down
4 changes: 4 additions & 0 deletions frame.go
Original file line number Diff line number Diff line change
Expand Up @@ -1011,6 +1011,10 @@ type resultMetadata struct {
actualColCount int
}

func (r *resultMetadata) morePages() bool {
return r.flags&flagHasMorePages == flagHasMorePages
}

func (r resultMetadata) String() string {
return fmt.Sprintf("[metadata flags=0x%x paging_state=% X columns=%v]", r.flags, r.pagingState, r.columns)
}
Expand Down
5 changes: 3 additions & 2 deletions host_source.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package gocql

import (
"context"
"errors"
"fmt"
"net"
Expand Down Expand Up @@ -555,7 +556,7 @@ func (r *ringDescriber) getClusterPeerInfo() ([]*HostInfo, error) {
var hosts []*HostInfo
iter := r.session.control.withConnHost(func(ch *connHost) *Iter {
hosts = append(hosts, ch.host)
return ch.conn.query("SELECT * FROM system.peers")
return ch.conn.query(context.TODO(), "SELECT * FROM system.peers")
})

if iter == nil {
Expand Down Expand Up @@ -622,7 +623,7 @@ func (r *ringDescriber) getHostInfo(ip net.IP, port int) (*HostInfo, error) {
return nil
}

return ch.conn.query("SELECT * FROM system.peers")
return ch.conn.query(context.TODO(), "SELECT * FROM system.peers")
})

if iter != nil {
Expand Down
2 changes: 1 addition & 1 deletion policies.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ type RetryableQuery interface {
Attempts() int
SetConsistency(c Consistency)
GetConsistency() Consistency
GetContext() context.Context
Context() context.Context
}

type RetryType uint16
Expand Down
Loading

0 comments on commit 271c061

Please sign in to comment.