Skip to content

Commit

Permalink
ZTS Client to read access tokens from file system (AthenZ#1028)
Browse files Browse the repository at this point in the history
Co-authored-by: dma <[email protected]>
  • Loading branch information
MartinTrojans and dma authored Jul 17, 2020
1 parent 9a3c35a commit 3b5bcdf
Show file tree
Hide file tree
Showing 11 changed files with 378 additions and 20 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package com.yahoo.athenz.zts;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.yahoo.athenz.auth.token.AccessToken;
import com.yahoo.athenz.auth.token.jwts.JwtsSigningKeyResolver;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.io.IOException;
import java.util.*;

public class ZTSAccessTokenFileLoader {

private static final Logger LOG = LoggerFactory.getLogger(ZTSAccessTokenFileLoader.class);

static public final String ACCESS_TOKEN_PATH_PROPERTY = "athenz.zts.client.accesstoken.path";
static private final String DEFAULT_ACCESS_TOKEN_DIR_PATH = "/var/lib/sia/tokens/";
static private final String ROLE_NAME_CONNECTOR = ",";
static private final String DOMAIN_ROLE_CONNECTOR = ":role:";
final private String path;
private JwtsSigningKeyResolver accessSignKeyResolver;
private ObjectMapper objectMapper = new ObjectMapper();
private Map<String, String> roleNameMap;

public ZTSAccessTokenFileLoader(JwtsSigningKeyResolver resolver) {
roleNameMap = new HashMap<>();
accessSignKeyResolver = resolver;
path = System.getProperty(ACCESS_TOKEN_PATH_PROPERTY, DEFAULT_ACCESS_TOKEN_DIR_PATH);
}

public void preload() {
File dir = new File(path);

// preload the map from the <domain, rolesname> -> <file path>
// expected dir should be <base token path>/<domain dir>/<token file>s
// after preload the map, when we look up the access token,
// the map will directly read the required file
if (dir.exists() && dir.isDirectory()) {
for (File domainDir: dir.listFiles()) {
if (domainDir.isDirectory()) {
for (File tokenFile: domainDir.listFiles()) {
if (!tokenFile.isDirectory()) {
AccessTokenResponse accessTokenResponse = null;
try {
accessTokenResponse = objectMapper.readValue(tokenFile, AccessTokenResponse.class);
} catch (IOException e) {
LOG.error("Failed to load or parse token file: {}", tokenFile);
}

// if access token parsed fail, continue to scan tokens
if (accessTokenResponse == null) {
continue;
}

AccessTokenResponseCacheEntry cacheEntry = new AccessTokenResponseCacheEntry(accessTokenResponse);

// check access token is still valid
if (!cacheEntry.isExpired(-1)) {
addToRoleMap(domainDir.getName(), tokenFile.getName(), accessTokenResponse);
}
}
}
}
}
}

}

// function load the access token from file
public AccessTokenResponse lookupAccessTokenFromDisk(String domain, List<String> rolesName) throws IOException {
final String rolesStr = getRolesStr(domain, rolesName);
final String fileName = roleNameMap.get(rolesStr);
LOG.debug("Trying to fetch access token from disk for domain: {}, roleNames: {}, roleMap key: {}. file name: {}",
domain, rolesName, rolesStr, fileName);
if (fileName == null) {
return null;
}
File tokenFile = new File(path + File.separator + domain + File.separator + fileName);

return objectMapper.readValue(tokenFile, AccessTokenResponse.class);
}

static private String getRolesStr(String domain, List<String> roleNames) {
// in case the rolesName is immutable, make a copy of role name list
if (roleNames == null || roleNames.isEmpty()) {
//if no role name specific, should return all roles
return domain + DOMAIN_ROLE_CONNECTOR + "*";
}
List<String> roleNamesCopy = new ArrayList<>(roleNames);
Collections.sort(roleNamesCopy);
return domain + DOMAIN_ROLE_CONNECTOR + String.join(ROLE_NAME_CONNECTOR, roleNamesCopy);
}

private void addToRoleMap(String domain, String fileName, AccessTokenResponse accessTokenResponse) {
// parse roles from access token
final String token = accessTokenResponse.getAccess_token();

try {
AccessToken accessToken = new AccessToken(token, accessSignKeyResolver);
List<String> roleNames = accessToken.getScope();
roleNameMap.put(getRolesStr(domain, roleNames), fileName);
} catch (Exception e) {
LOG.error("Got error to parse access token file {}, error: {}", fileName, e.getMessage());
return;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import com.oath.auth.KeyRefresherException;
import com.oath.auth.KeyRefresherListener;
import com.oath.auth.Utils;
import com.yahoo.athenz.auth.token.jwts.JwtsSigningKeyResolver;
import org.apache.http.config.Registry;
import org.apache.http.config.RegistryBuilder;
import org.apache.http.conn.socket.ConnectionSocketFactory;
Expand Down Expand Up @@ -111,6 +112,7 @@ public class ZTSClient implements Closeable {
static private int reqConnectTimeout = 30000;
static private String x509CertDNSName = null;
static private String confZtsUrl = null;
static private JwtsSigningKeyResolver resolver = null;

private boolean enablePrefetch = true;
private boolean ztsClientOverride = false;
Expand Down Expand Up @@ -173,6 +175,7 @@ public class ZTSClient implements Closeable {
private static ServiceLoader<ZTSClientService> ztsTokenProviders;
private static AtomicReference<Set<String>> svcLoaderCacheKeys;
private static PrivateKeyStore PRIVATE_KEY_STORE = loadServicePrivateKey();
private static ZTSAccessTokenFileLoader ztsAccessTokenFileLoader;

enum TokenType {
ROLE,
Expand Down Expand Up @@ -220,6 +223,10 @@ static boolean initConfigValues() {
// finally retrieve our configuration ZTS url from our config file

lookupZTSUrl();

// init zts file utility

initZTSAccessTokenFileLoader();

return true;
}
Expand Down Expand Up @@ -328,7 +335,19 @@ public static void lookupZTSUrl() {
}
}
}


public static void initZTSAccessTokenFileLoader() {
if (resolver == null) {
resolver = new JwtsSigningKeyResolver(null, null);
}
ztsAccessTokenFileLoader = new ZTSAccessTokenFileLoader(resolver);
ztsAccessTokenFileLoader.preload();
}

public static void setAccessTokenSignKeyResolver(JwtsSigningKeyResolver jwtsSigningKeyResolver) {
resolver = jwtsSigningKeyResolver;
}

/**
* Constructs a new ZTSClient object with default settings.
* The url for ZTS Server is automatically retrieved from the athenz
Expand Down Expand Up @@ -1150,7 +1169,7 @@ public AccessTokenResponse getAccessToken(String domainName, List<String> roleNa
public AccessTokenResponse getAccessToken(String domainName, List<String> roleNames, String idTokenServiceName,
String proxyForPrincipal, long expiryTime, boolean ignoreCache) {

AccessTokenResponse accessTokenResponse;
AccessTokenResponse accessTokenResponse = null;

// first lookup in our cache to see if it can be satisfied
// only if we're not asked to ignore the cache
Expand All @@ -1177,31 +1196,41 @@ public AccessTokenResponse getAccessToken(String domainName, List<String> roleNa
}
}

// if no hit then we need to look up in disk
try {
accessTokenResponse = ztsAccessTokenFileLoader.lookupAccessTokenFromDisk(domainName, roleNames);
} catch (IOException e) {
LOG.error("GetAccessToken: failed to load access token from disk ", e.getMessage());
}

// if no hit then we need to request a new token from ZTS

updateServicePrincipal();
try {
final String requestBody = generateAccessTokenRequestBody(domainName, roleNames,
idTokenServiceName, proxyForPrincipal, expiryTime);
accessTokenResponse = ztsClient.postAccessTokenRequest(requestBody);
} catch (ResourceException ex) {
if (cacheKey != null && !ignoreCache) {
accessTokenResponse = lookupAccessTokenResponseInCache(cacheKey, -1);
if (accessTokenResponse != null) {
return accessTokenResponse;
if (accessTokenResponse == null) {
updateServicePrincipal();
try {
final String requestBody = generateAccessTokenRequestBody(domainName, roleNames,
idTokenServiceName, proxyForPrincipal, expiryTime);
accessTokenResponse = ztsClient.postAccessTokenRequest(requestBody);
} catch (ResourceException ex) {
if (cacheKey != null && !ignoreCache) {
accessTokenResponse = lookupAccessTokenResponseInCache(cacheKey, -1);
if (accessTokenResponse != null) {
return accessTokenResponse;
}
}
}
throw new ZTSClientException(ex.getCode(), ex.getData());
} catch (Exception ex) {
if (cacheKey != null && !ignoreCache) {
accessTokenResponse = lookupAccessTokenResponseInCache(cacheKey, -1);
if (accessTokenResponse != null) {
return accessTokenResponse;
throw new ZTSClientException(ex.getCode(), ex.getData());
} catch (Exception ex) {
if (cacheKey != null && !ignoreCache) {
accessTokenResponse = lookupAccessTokenResponseInCache(cacheKey, -1);
if (accessTokenResponse != null) {
return accessTokenResponse;
}
}
throw new ZTSClientException(ZTSClientException.BAD_REQUEST, ex.getMessage());
}
throw new ZTSClientException(ZTSClientException.BAD_REQUEST, ex.getMessage());
}


// need to add the token to our cache. If our principal was
// updated then we need to retrieve a new cache key

Expand All @@ -1213,6 +1242,7 @@ public AccessTokenResponse getAccessToken(String domainName, List<String> roleNa
ACCESS_TOKEN_CACHE.put(cacheKey, new AccessTokenResponseCacheEntry(accessTokenResponse));
}
}

return accessTokenResponse;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ public void init() {
System.setProperty(ZTSClient.ZTS_CLIENT_PROP_PREFETCH_AUTO_ENABLE, "false");
System.setProperty(ZTSClient.ZTS_CLIENT_PROP_X509CSR_DN, "ou=eng,o=athenz,c=us");
System.setProperty(ZTSClient.ZTS_CLIENT_PROP_X509CSR_DOMAIN, "athenz.cloud");
System.setProperty(ZTSClient.ZTS_CLIENT_PROP_ATHENZ_CONF, "src/test/resources/athenz.conf");
}

@Test
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package com.yahoo.athenz.zts;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.yahoo.athenz.auth.token.AccessToken;
import com.yahoo.athenz.auth.util.Crypto;
import io.jsonwebtoken.SignatureAlgorithm;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.security.PrivateKey;
import java.security.cert.X509Certificate;
import java.util.Collections;

import static org.testng.Assert.fail;

public class AccessTokenTestFileHelper {

private static final File ecPrivateKey = new File("./src/test/resources/unit_test_ec_private.key");
private static final File tokenFile = new File("./src/test/resources/test.domain/admin");
private static final File invalidTokenFile = new File("./src/test/resources/test.domain/invalid");

private static AccessToken createAccessToken(long now) {

AccessToken accessToken = new AccessToken();
accessToken.setAuthTime(now);
accessToken.setScope(Collections.singletonList("admin"));
accessToken.setSubject("subject");
accessToken.setUserId("userid");
accessToken.setExpiryTime(now + 3600);
accessToken.setIssueTime(now);
accessToken.setClientId("mtls");
accessToken.setAudience("coretech");
accessToken.setVersion(1);
accessToken.setIssuer("athenz");
accessToken.setProxyPrincipal("proxy.user");
accessToken.setConfirmEntry("x5t#uri", "spiffe://athenz/sa/api");

try {
Path path = Paths.get("src/test/resources/mtls_token_spec.cert");
String certStr = new String(Files.readAllBytes(path));
X509Certificate cert = Crypto.loadX509Certificate(certStr);
accessToken.setConfirmX509CertHash(cert);
} catch (IOException ignored) {
fail();
}

return accessToken;
}

public static void setupTokenFile() {
AccessTokenResponse accessTokenResponse = new AccessTokenResponse();
long now = System.currentTimeMillis() / 1000;
AccessToken accessToken = createAccessToken(now);
PrivateKey privateKey = Crypto.loadPrivateKey(ecPrivateKey);
String accessJws = accessToken.getSignedToken(privateKey, "eckey1", SignatureAlgorithm.ES256);

accessTokenResponse.setAccess_token(accessJws);
accessTokenResponse.setExpires_in(28800);
accessTokenResponse.setScope("admin");
accessTokenResponse.setToken_type("Bearer");

ObjectMapper objectMapper = new ObjectMapper();

try {
objectMapper.writeValue(tokenFile, accessTokenResponse);
System.out.println("Write new access token " + accessTokenResponse.toString() + " to file: " + tokenFile + " successfully");
} catch (IOException e) {
e.printStackTrace();
fail();
}
}

public static void setupInvalidTokenFile() {
String str = "Invalid access token";

try {
BufferedWriter writer = new BufferedWriter(new FileWriter(invalidTokenFile));
writer.write(str);
writer.close();
} catch (IOException e) {
e.printStackTrace();
fail();
}

}
}
Loading

0 comments on commit 3b5bcdf

Please sign in to comment.