diff --git a/jwt/internal/types/date.go b/jwt/internal/types/date.go index 1f68b63b4..79f965f97 100644 --- a/jwt/internal/types/date.go +++ b/jwt/internal/types/date.go @@ -3,11 +3,20 @@ package types import ( "fmt" "strconv" + "strings" "time" "github.com/lestrrat-go/jwx/v2/internal/json" ) +const ( + DefaultPrecision uint32 = 0 // second level + MaxPrecision uint32 = 9 // nanosecond level +) + +var ParsePrecision = DefaultPrecision +var FormatPrecision = DefaultPrecision + // NumericDate represents the date format used in the 'nbf' claim type NumericDate struct { time.Time @@ -20,7 +29,7 @@ func (n *NumericDate) Get() time.Time { return n.Time } -func numericToTime(v interface{}, t *time.Time) bool { +func intToTime(v interface{}, t *time.Time) bool { var n int64 switch x := v.(type) { case int64: @@ -33,10 +42,6 @@ func numericToTime(v interface{}, t *time.Time) bool { n = int64(x) case int: n = int64(x) - case float32: - n = int64(x) - case float64: - n = int64(x) default: return false } @@ -45,27 +50,71 @@ func numericToTime(v interface{}, t *time.Time) bool { return true } +func parseNumericString(x string) (time.Time, error) { + var t time.Time // empty time for empty return value + var fractional string + whole := x + if i := strings.IndexRune(x, '.'); i > 0 { + if ParsePrecision > 0 && len(x) > i+1 { + fractional = x[i+1:] // everything after the '.' + if int(ParsePrecision) < len(fractional) { + // Remove insignificant digits + fractional = fractional[:int(ParsePrecision)] + } + // Replace missing fractional diits with zeros + for len(fractional) < int(MaxPrecision) { + fractional = fractional + "0" + } + } + whole = x[:i] + } + n, err := strconv.ParseInt(whole, 10, 64) + if err != nil { + return t, fmt.Errorf(`failed to parse whole value %q: %w`, whole, err) + } + var nsecs int64 + if fractional != "" { + v, err := strconv.ParseInt(fractional, 10, 64) + if err != nil { + return t, fmt.Errorf(`failed to parse fractional value %q: %w`, fractional, err) + } + nsecs = v + } + + return time.Unix(n, nsecs).UTC(), nil +} + func (n *NumericDate) Accept(v interface{}) error { var t time.Time - switch x := v.(type) { - case string: - i, err := strconv.ParseInt(x[:], 10, 64) + case float32: + tv, err := parseNumericString(fmt.Sprintf(`%.9f`, x)) if err != nil { - return fmt.Errorf(`invalid epoch value %#v`, x) + return fmt.Errorf(`failed to accept float32 %.9f: %w`, x, err) } - t = time.Unix(i, 0) - + t = tv + case float64: + tv, err := parseNumericString(fmt.Sprintf(`%.9f`, x)) + if err != nil { + return fmt.Errorf(`failed to accept float32 %.9f: %w`, x, err) + } + t = tv case json.Number: - intval, err := x.Int64() + tv, err := parseNumericString(x.String()) if err != nil { - return fmt.Errorf(`failed to convert json value %#v to int64: %w`, x, err) + return fmt.Errorf(`failed to accept json.Number %q: %w`, x.String(), err) } - t = time.Unix(intval, 0) + t = tv + case string: + tv, err := parseNumericString(x) + if err != nil { + return fmt.Errorf(`failed to accept string %q: %w`, x, err) + } + t = tv case time.Time: t = x default: - if !numericToTime(v, &t) { + if !intToTime(v, &t) { return fmt.Errorf(`invalid type %T`, v) } } @@ -73,13 +122,36 @@ func (n *NumericDate) Accept(v interface{}) error { return nil } +func (n NumericDate) String() string { + if FormatPrecision == 0 { + return strconv.FormatInt(n.Unix(), 10) + } + + // This is cheating,but it's better (easier) than doing floating point math + // We basically munge with strings after formatting an integer balue + // for nanoseconds since epoch + s := strconv.FormatInt(n.UnixNano(), 10) + for len(s) < int(MaxPrecision) { + s = "0" + s + } + + slwhole := len(s) - int(MaxPrecision) + s = s[:slwhole] + "." + s[slwhole:slwhole+int(FormatPrecision)] + if s[0] == '.' { + s = "0" + s + } + + return s +} + // MarshalJSON translates from internal representation to JSON NumericDate // See https://tools.ietf.org/html/rfc7519#page-6 func (n *NumericDate) MarshalJSON() ([]byte, error) { if n.IsZero() { return json.Marshal(nil) } - return json.Marshal(n.Unix()) + + return json.Marshal(n.String()) } func (n *NumericDate) UnmarshalJSON(data []byte) error { diff --git a/jwt/internal/types/date_test.go b/jwt/internal/types/date_test.go index 5562ccddb..19d369879 100644 --- a/jwt/internal/types/date_test.go +++ b/jwt/internal/types/date_test.go @@ -13,16 +13,13 @@ import ( ) func TestDate(t *testing.T) { - t.Parallel() t.Run("Get from a nil NumericDate", func(t *testing.T) { - t.Parallel() var n *types.NumericDate if !assert.Equal(t, time.Time{}, n.Get()) { return } }) t.Run("MarshalJSON with a zero value", func(t *testing.T) { - t.Parallel() var n *types.NumericDate buf, err := json.Marshal(n) if !assert.NoError(t, err, `json.Marshal against a zero value should succeed`) { @@ -33,18 +30,73 @@ func TestDate(t *testing.T) { return } }) + + // This test alters global behavior, and can't be ran in parallel t.Run("Accept values", func(t *testing.T) { - t.Parallel() // NumericDate allows assignment from various different Go types, // so that it's easier for the devs, and conversion to/from JSON - // use of "127" is just to allow use of int8's - now := time.Unix(127, 0).UTC() - for _, ut := range []interface{}{int64(127), int32(127), int16(127), int8(127), float32(127), float64(127), json.Number("127")} { - ut := ut - t.Run(fmt.Sprintf("%T", ut), func(t *testing.T) { - t.Parallel() + testcases := []struct { + Input interface{} + Expected time.Time + Precision int + }{ + { + Input: int64(127), + Expected: time.Unix(127, 0).UTC(), + }, + { + Input: int32(127), + Expected: time.Unix(127, 0).UTC(), + }, + { + Input: int16(127), + Expected: time.Unix(127, 0).UTC(), + }, + { + Input: int8(127), + Expected: time.Unix(127, 0).UTC(), + }, + { + Input: float32(127.11), + Expected: time.Unix(127, 0).UTC(), + }, + { + Input: float32(127.11), + Expected: time.Unix(127, 0).UTC(), + }, + { + Input: json.Number("127"), + Expected: time.Unix(127, 0).UTC(), + }, + { + Input: json.Number("127.11"), + Expected: time.Unix(127, 0).UTC(), + }, + { + Input: json.Number("127.11"), + Expected: time.Unix(127, 110000000).UTC(), + Precision: 4, + }, + { + Input: json.Number("127.110000011"), + Expected: time.Unix(127, 110000011).UTC(), + Precision: 9, + }, + { + Input: json.Number("127.110000011111"), + Expected: time.Unix(127, 110000011).UTC(), + Precision: 9, + }, + } + + for _, tc := range testcases { + tc := tc + precision := tc.Precision + t.Run(fmt.Sprintf("%v(type=%T, precision=%d)", tc.Input, tc.Input, precision), func(t *testing.T) { + jwt.Settings(jwt.WithNumericDateParsePrecision(precision)) + t1 := jwt.New() - err := t1.Set(jwt.IssuedAtKey, ut) + err := t1.Set(jwt.IssuedAtKey, tc.Input) if !assert.NoError(t, err) { return } @@ -53,7 +105,7 @@ func TestDate(t *testing.T) { return } realized := v.(time.Time) - if !assert.Equal(t, now, realized) { + if !assert.Equal(t, tc.Expected, realized) { return } }) diff --git a/jwt/jwt.go b/jwt/jwt.go index 4bc86e3c6..101fa5fc7 100644 --- a/jwt/jwt.go +++ b/jwt/jwt.go @@ -12,27 +12,58 @@ import ( "github.com/lestrrat-go/jwx/v2" "github.com/lestrrat-go/jwx/v2/internal/json" "github.com/lestrrat-go/jwx/v2/jws" + "github.com/lestrrat-go/jwx/v2/jwt/internal/types" ) // Settings controls global settings that are specific to JWTs. func Settings(options ...GlobalOption) { var flattenAudienceBool bool + var parsePrecision = types.MaxPrecision + 1 // illegal value, so we can detect nothing was set + var formatPrecision = types.MaxPrecision + 1 // illegal value, so we can detect nothing was set //nolint:forcetypeassert for _, option := range options { switch option.Ident() { case identFlattenAudience{}: flattenAudienceBool = option.Value().(bool) + case identNumericDateParsePrecision{}: + v := option.Value().(int) + // only accept this value if it's in our desired range + if v >= 0 && v <= int(types.MaxPrecision) { + parsePrecision = uint32(v) + } + case identNumericDateFormatPrecision{}: + v := option.Value().(int) + // only accept this value if it's in our desired range + if v >= 0 && v <= int(types.MaxPrecision) { + formatPrecision = uint32(v) + } + } + } + + if parsePrecision <= types.MaxPrecision { // remember we set default to max + 1 + v := atomic.LoadUint32(&types.ParsePrecision) + if v != parsePrecision { + atomic.CompareAndSwapUint32(&types.ParsePrecision, v, parsePrecision) } } - v := atomic.LoadUint32(&json.FlattenAudience) - if (v == 1) != flattenAudienceBool { - var newVal uint32 - if flattenAudienceBool { - newVal = 1 + if formatPrecision <= types.MaxPrecision { // remember we set default to max + 1 + v := atomic.LoadUint32(&types.FormatPrecision) + if v != formatPrecision { + atomic.CompareAndSwapUint32(&types.FormatPrecision, v, formatPrecision) + } + } + + { + v := atomic.LoadUint32(&json.FlattenAudience) + if (v == 1) != flattenAudienceBool { + var newVal uint32 + if flattenAudienceBool { + newVal = 1 + } + atomic.CompareAndSwapUint32(&json.FlattenAudience, v, newVal) } - atomic.CompareAndSwapUint32(&json.FlattenAudience, v, newVal) } } diff --git a/jwt/jwt_test.go b/jwt/jwt_test.go index 8a69419bb..f42e4ae5b 100644 --- a/jwt/jwt_test.go +++ b/jwt/jwt_test.go @@ -23,12 +23,14 @@ import ( "github.com/lestrrat-go/jwx/v2/internal/json" "github.com/lestrrat-go/jwx/v2/internal/jwxtest" "github.com/lestrrat-go/jwx/v2/jwe" + "github.com/lestrrat-go/jwx/v2/jwt/internal/types" "github.com/lestrrat-go/jwx/v2/jwa" "github.com/lestrrat-go/jwx/v2/jwk" "github.com/lestrrat-go/jwx/v2/jws" "github.com/lestrrat-go/jwx/v2/jwt" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) /* This is commented out, because it is intended to cause compilation errors */ @@ -1483,3 +1485,115 @@ func TestSerializer(t *testing.T) { } }) } + +func TestFractional(t *testing.T) { + t.Run("FormatPrecision", func(t *testing.T) { + var nd types.NumericDate + jwt.Settings(jwt.WithNumericDateParsePrecision(int(types.MaxPrecision))) + s := fmt.Sprintf("%d.100000001", aLongLongTimeAgo) + _ = nd.Accept(s) + jwt.Settings(jwt.WithNumericDateParsePrecision(0)) + testcases := []struct { + Input types.NumericDate + Expected string + Precision int + }{ + { + Input: nd, + Expected: fmt.Sprintf(`%d`, aLongLongTimeAgo), + }, + { + Input: types.NumericDate{Time: time.Unix(0, 1).UTC()}, + Expected: "0", + }, + { + Input: types.NumericDate{Time: time.Unix(0, 1).UTC()}, + Precision: 9, + Expected: "0.000000001", + }, + { + Input: types.NumericDate{Time: time.Unix(0, 100000000).UTC()}, + Precision: 9, + Expected: "0.100000000", + }, + } + + for i := 1; i <= int(types.MaxPrecision); i++ { + fractional := (fmt.Sprintf(`%d`, 100000001))[:i] + testcases = append(testcases, struct { + Input types.NumericDate + Expected string + Precision int + }{ + Input: nd, + Precision: i, + Expected: fmt.Sprintf(`%d.%s`, aLongLongTimeAgo, fractional), + }) + } + + for _, tc := range testcases { + tc := tc + t.Run(fmt.Sprintf("%s (precision=%d)", tc.Input, tc.Precision), func(t *testing.T) { + jwt.Settings(jwt.WithNumericDateFormatPrecision(tc.Precision)) + require.Equal(t, tc.Expected, tc.Input.String()) + }) + } + jwt.Settings(jwt.WithNumericDateFormatPrecision(0)) + }) + t.Run("ParsePrecision", func(t *testing.T) { + const template = `{"iat":"%s"}` + + testcases := []struct { + Input string + Expected time.Time + Precision int + }{ + { + Input: "0", + Expected: time.Unix(0, 0).UTC(), + }, + { + Input: "0.000000001", + Expected: time.Unix(0, 0).UTC(), + }, + { + Input: fmt.Sprintf("%d.111111111", aLongLongTimeAgo), + Expected: time.Unix(aLongLongTimeAgo, 0).UTC(), + }, + { + // Max precision + Input: fmt.Sprintf("%d.100000001", aLongLongTimeAgo), + Precision: int(types.MaxPrecision), + Expected: time.Unix(aLongLongTimeAgo, 100000001).UTC(), + }, + } + + for i := 1; i < int(types.MaxPrecision); i++ { + testcases = append(testcases, struct { + Input string + Expected time.Time + Precision int + }{ + Input: fmt.Sprintf("%d.100000001", aLongLongTimeAgo), + Precision: i, + Expected: time.Unix(aLongLongTimeAgo, 100000000).UTC(), + }) + } + + for _, tc := range testcases { + tc := tc + t.Run(fmt.Sprintf("%s (precision=%d)", tc.Input, tc.Precision), func(t *testing.T) { + jwt.Settings(jwt.WithNumericDateParsePrecision(tc.Precision)) + tok, err := jwt.Parse( + []byte(fmt.Sprintf(template, tc.Input)), + jwt.WithVerify(false), + jwt.WithValidate(false), + ) + require.NoError(t, err, `jwt.Parse should succeed`) + + require.Equal(t, tc.Expected, tok.IssuedAt(), `iat should match`) + }) + } + jwt.Settings(jwt.WithNumericDateParsePrecision(0)) + }) +} diff --git a/jwt/options.yaml b/jwt/options.yaml index d9274ce87..303caa9ee 100644 --- a/jwt/options.yaml +++ b/jwt/options.yaml @@ -172,4 +172,18 @@ options: argument_type: fs.FS comment: | WithFS specifies the source `fs.FS` object to read the file from. + - ident: NumericDateParsePrecision + interface: GlobalOption + argument_type: int + comment: | + WithNumericDateParsePrecision sets the precision up to which the + library uses to parse fractional dates found in the numeric date + fields. Default is 0 (second, no fractionals), max is 9 (nanosecond) + - ident: NumericDateFormatPrecision + interface: GlobalOption + argument_type: int + comment: | + WithNumericDateFormatPrecision sets the precision up to which the + library uses to format fractional dates found in the numeric date + fields. Default is 0 (second, no fractionals), max is 9 (nanosecond) diff --git a/jwt/options_gen.go b/jwt/options_gen.go index 4107803be..41c48ba62 100644 --- a/jwt/options_gen.go +++ b/jwt/options_gen.go @@ -131,6 +131,8 @@ type identFlattenAudience struct{} type identFormKey struct{} type identHeaderKey struct{} type identKeyProvider struct{} +type identNumericDateFormatPrecision struct{} +type identNumericDateParsePrecision struct{} type identPedantic struct{} type identSignOption struct{} type identToken struct{} @@ -174,6 +176,14 @@ func (identKeyProvider) String() string { return "WithKeyProvider" } +func (identNumericDateFormatPrecision) String() string { + return "WithNumericDateFormatPrecision" +} + +func (identNumericDateParsePrecision) String() string { + return "WithNumericDateParsePrecision" +} + func (identPedantic) String() string { return "WithPedantic" } @@ -268,6 +278,20 @@ func WithKeyProvider(v jws.KeyProvider) ParseOption { return &parseOption{option.New(identKeyProvider{}, v)} } +// WithNumericDateFormatPrecision sets the precision up to which the +// library uses to format fractional dates found in the numeric date +// fields. Default is 0 (second, no fractionals), max is 9 (nanosecond) +func WithNumericDateFormatPrecision(v int) GlobalOption { + return &globalOption{option.New(identNumericDateFormatPrecision{}, v)} +} + +// WithNumericDateParsePrecision sets the precision up to which the +// library uses to parse fractional dates found in the numeric date +// fields. Default is 0 (second, no fractionals), max is 9 (nanosecond) +func WithNumericDateParsePrecision(v int) GlobalOption { + return &globalOption{option.New(identNumericDateParsePrecision{}, v)} +} + // WithPedantic enables pedantic mode for parsing JWTs. Currently this only // applies to checking for the correct `typ` and/or `cty` when necessary. func WithPedantic(v bool) ParseOption { diff --git a/jwt/options_gen_test.go b/jwt/options_gen_test.go index 67a72e88e..ba07af850 100644 --- a/jwt/options_gen_test.go +++ b/jwt/options_gen_test.go @@ -18,6 +18,8 @@ func TestOptionIdent(t *testing.T) { require.Equal(t, "WithFormKey", identFormKey{}.String()) require.Equal(t, "WithHeaderKey", identHeaderKey{}.String()) require.Equal(t, "WithKeyProvider", identKeyProvider{}.String()) + require.Equal(t, "WithNumericDateFormatPrecision", identNumericDateFormatPrecision{}.String()) + require.Equal(t, "WithNumericDateParsePrecision", identNumericDateParsePrecision{}.String()) require.Equal(t, "WithPedantic", identPedantic{}.String()) require.Equal(t, "WithSignOption", identSignOption{}.String()) require.Equal(t, "WithToken", identToken{}.String())