forked from gravitational/teleport
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathjoin_iam.go
500 lines (436 loc) · 17.4 KB
/
join_iam.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
/*
* Teleport
* Copyright (C) 2023 Gravitational, Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package auth
import (
"bufio"
"bytes"
"context"
"encoding/base64"
"encoding/json"
"net/http"
"net/url"
"slices"
"strings"
awssdk "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/coreos/go-semver/semver"
"github.com/gravitational/trace"
"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/client"
"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/types"
cloudaws "github.com/gravitational/teleport/lib/cloud/aws"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/teleport/lib/utils/aws"
)
// expectedSTSIdentityRequestBody is the only request body accepted for a
// forwarded sts:GetCallerIdentity call.
//
// Pinning the STS API version here may be stricter than necessary, but the
// body is produced by the Teleport node and only changes when the AWS SDK
// dependency is updated. Since Auth should always be upgraded before nodes,
// there will be a chance to relax this check on Auth if a newer API version
// ever needs to be allowed.
const expectedSTSIdentityRequestBody = "Action=GetCallerIdentity&Version=2011-06-15"

// challengeHeaderKey is the header that carries the join challenge. It is
// compared verbatim against the SigV4 SignedHeaders list, which AWS always
// renders in lowercase, so this key must stay lowercase.
// https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-auth-using-authorization-header.html#sigv4-auth-header-overview
const challengeHeaderKey = "x-teleport-challenge"
// validateSTSHost reports whether stsHost may be used to verify a joining
// node's identity. It returns an AccessDenied error when stsHost is not in
// validSTSEndpoints. It also returns AccessDenied when cfg.fips is set and
// stsHost is not in fipsSTSEndpoints. Otherwise it returns nil.
//
// This check is security-critical. The client chooses which URL the Auth
// server contacts to validate its identity. If an attacker-controlled URL
// could pass as an STS endpoint, the IAM join method would be broken.
//
// To keep the check simple and auditable, stsHost is matched against a static
// allow-list of known endpoints. That list has to be updated as AWS adds
// regions.
func validateSTSHost(stsHost string, cfg *iamRegisterConfig) error {
	switch {
	case !slices.Contains(validSTSEndpoints, stsHost):
		return trace.AccessDenied("IAM join request uses unknown STS host %q. "+
			"This could mean that the Teleport Node attempting to join the cluster is "+
			"running in a new AWS region which is unknown to this Teleport auth server. "+
			"Alternatively, if this URL looks suspicious, an attacker may be attempting to "+
			"join your Teleport cluster. "+
			"Following is the list of valid STS endpoints known to this auth server. "+
			"If a legitimate STS endpoint is not included, please file an issue at "+
			"https://github.com/gravitational/teleport. %v",
			stsHost, validSTSEndpoints)
	case cfg.fips && !slices.Contains(fipsSTSEndpoints, stsHost):
		// In FIPS mode, only FIPS endpoints from the allow-list are acceptable.
		return trace.AccessDenied("node selected non-FIPS STS endpoint (%s) for the IAM join method", stsHost)
	default:
		return nil
	}
}
// validateSTSIdentityRequest checks that a received sts:GetCallerIdentity
// request is valid and includes the challenge as a signed header. An example
// valid request looks like:
// ```
// POST / HTTP/1.1
// Host: sts.amazonaws.com
// Accept: application/json
// Authorization: AWS4-HMAC-SHA256 Credential=AAAAAAAAAAAAAAAAAAAA/20211108/us-east-1/sts/aws4_request, SignedHeaders=accept;content-length;content-type;host;x-amz-date;x-amz-security-token;x-teleport-challenge, Signature=999...
// Content-Length: 43
// Content-Type: application/x-www-form-urlencoded; charset=utf-8
// User-Agent: aws-sdk-go/1.37.17 (go1.17.1; darwin; amd64)
// X-Amz-Date: 20211108T190420Z
// X-Amz-Security-Token: aaa...
// X-Teleport-Challenge: 0ezlc3usTAkXeZTcfOazUq0BGrRaKmb4EwODk8U7J5A
//
// Action=GetCallerIdentity&Version=2011-06-15
// ```
//
// The checks run in this order:
//   - the Host is an allowed STS endpoint (see validateSTSHost; FIPS-only when
//     cfg.fips is set),
//   - the method is POST,
//   - the challengeHeaderKey header equals challenge,
//   - the SigV4 Authorization header lists challengeHeaderKey among its
//     SignedHeaders, so the challenge is covered by the AWS signature,
//   - the body is exactly expectedSTSIdentityRequestBody.
//
// The body is read via utils.GetAndReplaceRequestBody. Judging by its name,
// that helper puts the body back so the request can still be forwarded to AWS
// afterwards; confirm against its implementation. Any returned error is also
// logged as a warning.
func validateSTSIdentityRequest(req *http.Request, challenge string, cfg *iamRegisterConfig) (err error) {
	defer func() {
		// Always log a warning on the Auth server if the function detects an
		// invalid sts:GetCallerIdentity request; it's either going to be caused
		// by a node in an unknown region or an attacker.
		if err != nil {
			log.WithError(err).Warn("Detected an invalid sts:GetCallerIdentity used by a client attempting to use the IAM join method.")
		}
	}()

	if err := validateSTSHost(req.Host, cfg); err != nil {
		return trace.Wrap(err)
	}

	if req.Method != http.MethodPost {
		// Report the offending method itself (previously the request URI was
		// reported here, which made the message misleading).
		return trace.AccessDenied("sts identity request method %q does not match expected method %q", req.Method, http.MethodPost)
	}

	if req.Header.Get(challengeHeaderKey) != challenge {
		return trace.AccessDenied("sts identity request does not include challenge header or it does not match")
	}

	// The challenge header is only meaningful if AWS signed it. Otherwise a
	// captured request could be replayed with a different challenge.
	authHeader := req.Header.Get(aws.AuthorizationHeader)
	sigV4, err := aws.ParseSigV4(authHeader)
	if err != nil {
		return trace.Wrap(err)
	}
	if !slices.Contains(sigV4.SignedHeaders, challengeHeaderKey) {
		return trace.AccessDenied("sts identity request auth header %q does not include "+
			challengeHeaderKey+" as a signed header", authHeader)
	}

	body, err := utils.GetAndReplaceRequestBody(req)
	if err != nil {
		return trace.Wrap(err)
	}
	if string(body) != expectedSTSIdentityRequestBody {
		return trace.BadParameter("sts request body %q does not equal expected %q", string(body), expectedSTSIdentityRequestBody)
	}
	return nil
}
// parseSTSRequest deserializes a raw HTTP/1.x request, such as the bytes
// produced by createSignedSTSIdentityRequest, into an *http.Request that can
// be sent with an HTTP client.
//
// Only the root request URI "/" is accepted; anything else is AccessDenied.
func parseSTSRequest(req []byte) (*http.Request, error) {
	reader := bufio.NewReader(bytes.NewReader(req))
	parsed, err := http.ReadRequest(reader)
	if err != nil {
		return nil, trace.Wrap(err)
	}

	if uri := parsed.RequestURI; uri != "/" {
		return nil, trace.AccessDenied("unexpected sts identity request URI: %q", uri)
	}

	// http.ReadRequest fills in RequestURI, which must not be set on a client
	// request, so clear it and rebuild URL from the Host header instead. The
	// scheme is always forced to https.
	parsed.RequestURI = ""
	parsed.URL = &url.URL{Scheme: "https", Host: parsed.Host}
	return parsed, nil
}
// The types below mirror the JSON shape of an sts:GetCallerIdentity response
// when it is requested with "Accept: application/json":
//
//	{"GetCallerIdentityResponse": {"GetCallerIdentityResult": {"Account": ..., "Arn": ...}}}
type (
	// awsIdentity is the caller identity reported by STS: the AWS account ID
	// and the ARN of the calling principal.
	awsIdentity struct {
		Account string `json:"Account"`
		Arn     string `json:"Arn"`
	}

	// getCallerIdentityResponse wraps the GetCallerIdentityResult object.
	getCallerIdentityResponse struct {
		GetCallerIdentityResult awsIdentity `json:"GetCallerIdentityResult"`
	}

	// stsIdentityResponse is the top-level JSON document returned by STS.
	stsIdentityResponse struct {
		GetCallerIdentityResponse getCallerIdentityResponse `json:"GetCallerIdentityResponse"`
	}
)
// executeSTSIdentityRequest forwards the signed sts:GetCallerIdentity request
// to AWS and decodes the caller identity from the JSON response.
//
// If client is nil, http.DefaultClient is used. The request is bound to ctx so
// it can be canceled. At most teleport.MaxHTTPResponseSize bytes of the
// response are read. The function returns:
//   - AccessDenied when STS replies with a non-200 status,
//   - the JSON decoding error when the body cannot be decoded,
//   - BadParameter when the account ID or the ARN is empty.
func executeSTSIdentityRequest(ctx context.Context, client utils.HTTPDoClient, req *http.Request) (*awsIdentity, error) {
	httpClient := client
	if httpClient == nil {
		httpClient = http.DefaultClient
	}

	resp, err := httpClient.Do(req.WithContext(ctx))
	if err != nil {
		return nil, trace.Wrap(err)
	}
	defer resp.Body.Close()

	body, err := utils.ReadAtMost(resp.Body, teleport.MaxHTTPResponseSize)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, trace.AccessDenied("aws sts api returned status: %q body: %q",
			resp.Status, body)
	}

	var parsed stsIdentityResponse
	if err := json.Unmarshal(body, &parsed); err != nil {
		return nil, trace.Wrap(err)
	}

	identity := &parsed.GetCallerIdentityResponse.GetCallerIdentityResult
	switch {
	case identity.Account == "":
		return nil, trace.BadParameter("received empty AWS account ID from sts API")
	case identity.Arn == "":
		return nil, trace.BadParameter("received empty AWS identity ARN from sts API")
	}
	return identity, nil
}
// arnMatches reports whether arn matches pattern, an AWS ARN that may contain
// wildcards. It delegates to globMatch, which is defined elsewhere. Per the
// IAM resource-element convention, "*" is expected to match any run of zero
// or more characters and "?" a single character; confirm against globMatch.
// See https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_policies_elements_resource.html
func arnMatches(pattern, arn string) (bool, error) {
	matched, err := globMatch(pattern, arn)
	return matched, err
}
// checkIAMAllowRules returns nil if identity satisfies at least one of
// allowRules, and an AccessDenied error naming the identity ARN and token
// otherwise.
//
// A rule is satisfied when each field it sets matches:
//   - AWSAccount must equal identity.Account,
//   - AWSARN, which may contain wildcards, must match identity.Arn.
//
// Fields left empty are not checked, so a rule with neither field set accepts
// any identity. NOTE(review): presumably token validation elsewhere rejects
// such rules; confirm. Errors from pattern matching are returned immediately.
func checkIAMAllowRules(identity *awsIdentity, token string, allowRules []*types.TokenRule) error {
	for _, rule := range allowRules {
		if rule.AWSAccount != "" && rule.AWSAccount != identity.Account {
			continue
		}
		if rule.AWSARN == "" {
			return nil
		}
		matched, err := arnMatches(rule.AWSARN, identity.Arn)
		if err != nil {
			return trace.Wrap(err)
		}
		if matched {
			return nil
		}
	}
	return trace.AccessDenied("instance %v did not match any allow rules in token %v", identity.Arn, token)
}
// checkIAMRequest verifies an IAM join request end to end. It checks that:
//  1. the named provision token exists and uses the IAM join method,
//  2. the embedded sts:GetCallerIdentity request parses and passes
//     validateSTSIdentityRequest (host, method, signed challenge, body),
//  3. AWS accepts that request, and the identity it reports matches one of
//     the token's allow rules.
//
// The first failing step's error is returned.
func (a *Server) checkIAMRequest(ctx context.Context, challenge string, req *proto.RegisterUsingIAMMethodRequest, cfg *iamRegisterConfig) error {
	provToken, err := a.GetToken(ctx, req.RegisterUsingTokenRequest.Token)
	if err != nil {
		return trace.Wrap(err)
	}
	if provToken.GetJoinMethod() != types.JoinMethodIAM {
		return trace.AccessDenied("this token does not support the IAM join method")
	}

	stsReq, err := parseSTSRequest(req.StsIdentityRequest)
	if err != nil {
		return trace.Wrap(err)
	}
	if err = validateSTSIdentityRequest(stsReq, challenge, cfg); err != nil {
		return trace.Wrap(err)
	}

	// Only now is the client-supplied request sent to AWS. The identity in
	// the response comes from AWS itself, not from the client.
	identity, err := executeSTSIdentityRequest(ctx, a.httpClientForAWSSTS, stsReq)
	if err != nil {
		return trace.Wrap(err)
	}
	return trace.Wrap(checkIAMAllowRules(identity, provToken.GetName(), provToken.GetAllowRules()))
}
// generateIAMChallenge returns a new challenge for the IAM join flow. It is
// encoded with unpadded standard base64 (base64.RawStdEncoding). The size
// argument 32 is presumably the number of random bytes (which would encode to
// 43 characters); confirm against generateChallenge.
func generateIAMChallenge() (string, error) {
	const challengeSize = 32
	encoding := base64.RawStdEncoding
	challenge, err := generateChallenge(encoding, challengeSize)
	return challenge, trace.Wrap(err)
}
// iamRegisterConfig carries the settings RegisterUsingIAMMethod uses while
// validating an IAM join request.
type iamRegisterConfig struct {
	// authVersion defaults to teleport.SemVersion and can be overridden with
	// withAuthVersion. It is not read within this section of the file.
	authVersion *semver.Version
	// fips, when true, makes validateSTSHost accept only FIPS STS endpoints.
	fips bool
}

// defaultIAMRegisterConfig returns a config whose authVersion is the running
// Teleport version and whose FIPS requirement is the given value.
func defaultIAMRegisterConfig(fips bool) *iamRegisterConfig {
	cfg := new(iamRegisterConfig)
	cfg.authVersion = teleport.SemVersion
	cfg.fips = fips
	return cfg
}
// iamRegisterOption modifies an iamRegisterConfig. RegisterUsingIAMMethod
// applies the options in order, after the defaults are set.
type iamRegisterOption func(cfg *iamRegisterConfig)

// withAuthVersion overrides the config's authVersion with v.
func withAuthVersion(v *semver.Version) iamRegisterOption {
	return func(c *iamRegisterConfig) { c.authVersion = v }
}

// withFips overrides whether only FIPS STS endpoints are accepted.
func withFips(fips bool) iamRegisterOption {
	return func(c *iamRegisterConfig) { c.fips = fips }
}
// RegisterUsingIAMMethod registers the caller using the IAM join method and
// returns signed certs to join the cluster.
//
// A new challenge is generated and passed to challengeResponse. The callback
// must return a *proto.RegisterUsingIAMMethodRequest whose StsIdentityRequest
// is an sts:GetCallerIdentity request signed with the caller's AWS
// credentials, with the challenge included as a signed header. If the common
// token checks and the IAM-specific checks pass, bot certs are generated for
// types.RoleBot and regular certs for every other role.
//
// opts override the defaults from defaultIAMRegisterConfig, whose FIPS setting
// is taken from the server.
func (a *Server) RegisterUsingIAMMethod(ctx context.Context, challengeResponse client.RegisterIAMChallengeResponseFunc, opts ...iamRegisterOption) (*proto.Certs, error) {
	cfg := defaultIAMRegisterConfig(a.fips)
	for _, apply := range opts {
		apply(cfg)
	}

	challenge, err := generateIAMChallenge()
	if err != nil {
		return nil, trace.Wrap(err)
	}

	// Send the challenge to the client and receive its signed STS request.
	req, err := challengeResponse(challenge)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	if err := req.CheckAndSetDefaults(); err != nil {
		return nil, trace.Wrap(err)
	}

	// Token checks shared by all join methods.
	provisionToken, err := a.checkTokenJoinRequestCommon(ctx, req.RegisterUsingTokenRequest)
	if err != nil {
		return nil, trace.Wrap(err)
	}

	// IAM-specific checks: signed GetCallerIdentity request, challenge and
	// token allow rules.
	if err := a.checkIAMRequest(ctx, challenge, req, cfg); err != nil {
		return nil, trace.Wrap(err)
	}

	switch req.RegisterUsingTokenRequest.Role {
	case types.RoleBot:
		certs, err := a.generateCertsBot(ctx, provisionToken, req.RegisterUsingTokenRequest, nil)
		return certs, trace.Wrap(err)
	default:
		certs, err := a.generateCerts(ctx, provisionToken, req.RegisterUsingTokenRequest, nil)
		return certs, trace.Wrap(err)
	}
}
type (
	// stsIdentityRequestConfig controls how the client-side STS client used
	// by createSignedSTSIdentityRequest picks its endpoint. Zero values
	// presumably mean "unset" and leave the choice to the AWS SDK and shared
	// config; confirm against aws-sdk-go's endpoints package.
	stsIdentityRequestConfig struct {
		regionalEndpointOption endpoints.STSRegionalEndpoint
		fipsEndpointOption     endpoints.FIPSEndpointState
	}

	// stsIdentityRequestOption modifies an stsIdentityRequestConfig.
	stsIdentityRequestOption func(cfg *stsIdentityRequestConfig)
)
// withRegionalEndpoint selects the regional STS endpoint when
// useRegionalEndpoint is true, and the legacy (global) endpoint otherwise.
func withRegionalEndpoint(useRegionalEndpoint bool) stsIdentityRequestOption {
	return func(cfg *stsIdentityRequestConfig) {
		cfg.regionalEndpointOption = endpoints.LegacySTSEndpoint
		if useRegionalEndpoint {
			cfg.regionalEndpointOption = endpoints.RegionalSTSEndpoint
		}
	}
}
// withFIPSEndpoint explicitly enables the FIPS STS endpoint when useFIPS is
// true, and explicitly disables it otherwise.
func withFIPSEndpoint(useFIPS bool) stsIdentityRequestOption {
	return func(cfg *stsIdentityRequestConfig) {
		cfg.fipsEndpointOption = endpoints.FIPSEndpointStateDisabled
		if useFIPS {
			cfg.fipsEndpointOption = endpoints.FIPSEndpointStateEnabled
		}
	}
}
// createSignedSTSIdentityRequest runs on the joining node. It builds an
// sts:GetCallerIdentity request and signs it with the local AWS credentials,
// including challenge in the challengeHeaderKey header. It returns the signed
// request serialized in HTTP/1.x wire format.
//
// The request is only signed here, never sent; the Auth server forwards it
// to AWS later.
func createSignedSTSIdentityRequest(ctx context.Context, challenge string, opts ...stsIdentityRequestOption) ([]byte, error) {
	cfg := new(stsIdentityRequestConfig)
	for _, apply := range opts {
		apply(cfg)
	}

	stsClient, err := newSTSClient(ctx, cfg)
	if err != nil {
		return nil, trace.Wrap(err)
	}

	// The second return value is the output struct, which is only populated
	// once the request is sent. This request is never sent here.
	awsReq, _ := stsClient.GetCallerIdentityRequest(&sts.GetCallerIdentityInput{})

	// Both headers are set before signing so that they are covered by the
	// signature. The Accept header asks STS for a JSON response, which is
	// easier to parse.
	header := awsReq.HTTPRequest.Header
	header.Set(challengeHeaderKey, challenge)
	header.Set("Accept", "application/json")

	if err := awsReq.Sign(); err != nil {
		return nil, trace.Wrap(err)
	}

	buf := new(bytes.Buffer)
	if err := awsReq.HTTPRequest.Write(buf); err != nil {
		return nil, trace.Wrap(err)
	}
	return buf.Bytes(), nil
}
// newSTSClient builds an STS client from the shared AWS config, applying the
// FIPS and regional-endpoint choices from cfg. It then corrects two cases the
// SDK handles poorly:
//   - The SDK resolved the global endpoint but a regional one was requested.
//     The region is looked up from the EC2 IMDS, and an error is returned if
//     that lookup fails. If no regional endpoint was requested, an
//     informational message is logged instead.
//   - FIPS is enabled but the resolved endpoint is not in validSTSEndpoints.
//     The client falls back to the us-east-1 FIPS endpoint.
func newSTSClient(ctx context.Context, cfg *stsIdentityRequestConfig) (*sts.STS, error) {
	sess, err := session.NewSessionWithOptions(session.Options{
		SharedConfigState: session.SharedConfigEnable,
		Config: awssdk.Config{
			UseFIPSEndpoint:     cfg.fipsEndpointOption,
			STSRegionalEndpoint: cfg.regionalEndpointOption,
		},
	})
	if err != nil {
		return nil, trace.Wrap(err)
	}

	// endpointHost strips the scheme so the endpoint can be compared against
	// the host-only lists globalSTSEndpoints and validSTSEndpoints.
	endpointHost := func(c *sts.STS) string {
		return strings.TrimPrefix(c.Endpoint, "https://")
	}

	stsClient := sts.New(sess)

	if slices.Contains(globalSTSEndpoints, endpointHost(stsClient)) {
		switch cfg.regionalEndpointOption {
		case endpoints.RegionalSTSEndpoint:
			// A regional endpoint was requested but could not be resolved
			// from the environment, so ask the EC2 IMDS for the region.
			region, err := getEC2LocalRegion(ctx)
			if err != nil {
				return nil, trace.Wrap(err, "failed to resolve local AWS region from environment or IMDS")
			}
			stsClient = sts.New(sess, awssdk.NewConfig().WithRegion(region))
		default:
			log.Info("Attempting to use the global STS endpoint for the IAM join method. " +
				"This will probably fail in non-default AWS partitions such as China or GovCloud, or if FIPS mode is enabled. " +
				"Consider setting the AWS_REGION environment variable, setting the region in ~/.aws/config, or enabling the IMDSv2.")
		}
	}

	if cfg.fipsEndpointOption == endpoints.FIPSEndpointStateEnabled &&
		!slices.Contains(validSTSEndpoints, endpointHost(stsClient)) {
		// The SDK generates invalid endpoints when asked for the FIPS
		// endpoint of a region that has none. The us-east-1 FIPS endpoint
		// should work for every region in the standard partition. GovCloud
		// should not reach this branch because all of its regional endpoints
		// support FIPS. In China or other partitions this will fail, and FIPS
		// mode is unsupported there.
		log.Infof("AWS SDK resolved FIPS STS endpoint %s, which does not appear to be valid. "+
			"Attempting to use the FIPS STS endpoint for us-east-1.",
			stsClient.Endpoint)
		stsClient = sts.New(sess, awssdk.NewConfig().WithRegion("us-east-1"))
	}

	return stsClient, nil
}
// getEC2LocalRegion asks the EC2 instance metadata service (IMDS) for the
// region this instance runs in. It returns a NotFound error if the IMDS is
// not available. Errors from creating the metadata client or from the region
// lookup are returned wrapped.
func getEC2LocalRegion(ctx context.Context) (string, error) {
	imds, err := cloudaws.NewInstanceMetadataClient(ctx)
	if err != nil {
		return "", trace.Wrap(err)
	}
	if available := imds.IsAvailable(ctx); !available {
		return "", trace.NotFound("IMDS is unavailable")
	}
	region, err := imds.GetRegion(ctx)
	return region, trace.Wrap(err)
}