Skip to content

Commit

Permalink
sync spdz connections info from gateway
Browse files Browse the repository at this point in the history
  • Loading branch information
ywy2090 committed Dec 5, 2024
1 parent 305a818 commit 9bd6881
Show file tree
Hide file tree
Showing 18 changed files with 361 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -85,4 +85,5 @@ public class Constant {

//// the serviceType
public static final String PIR_SERVICE_TYPE = "PIR";
public static final String SPDZ_SERVICE_TYPE = "SPDZ";
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import com.webank.wedpr.components.project.dao.ProjectMapperWrapper;
import com.webank.wedpr.components.scheduler.SchedulerBuilder;
import com.webank.wedpr.components.scheduler.core.SchedulerTaskImpl;
import com.webank.wedpr.components.scheduler.core.SpdzConnections;
import com.webank.wedpr.components.scheduler.executor.ExecuteResult;
import com.webank.wedpr.components.scheduler.executor.callback.TaskFinishedHandler;
import com.webank.wedpr.components.scheduler.executor.impl.ExecutiveContextBuilder;
Expand Down Expand Up @@ -81,6 +82,10 @@ public class SchedulerLoader {
@Autowired
private WeDPRTransport weDPRTransport;

@Qualifier("spdzConnections")
@Autowired
private SpdzConnections spdzConnections;

@Autowired private ServiceAuthMapper serviceAuthMapper;
@Autowired private ApiCredentialMapper apiCredentialMapper;
@Autowired private DatasetMapper datasetMapper;
Expand Down Expand Up @@ -134,7 +139,8 @@ protected void registerExecutors(
executorManager,
new ExecutiveContextBuilder(projectMapperWrapper),
threadPoolService,
datasetMapper);
datasetMapper,
spdzConnections);
executorManager.registerExecutor(ExecutorType.DAG.getType(), dagSchedulerExecutor);
// default
TaskFinishedHandler taskFinishedHandler =
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package com.webank.wedpr.components.scheduler.config;

import com.webank.wedpr.components.scheduler.core.SpdzConnections;
import com.webank.wedpr.sdk.jni.generated.Error;
import com.webank.wedpr.sdk.jni.transport.WeDPRTransport;
import com.webank.wedpr.sdk.jni.transport.handlers.GetPeersCallback;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

@Configuration
public class SpdzConnectionsConfig {

private static final Logger logger = LoggerFactory.getLogger(SpdzConnectionsConfig.class);

@Qualifier("weDPRTransport")
@Autowired
private WeDPRTransport weDPRTransport;

@Qualifier("spdzConnections")
@Autowired
private SpdzConnections spdzConnections;

@Bean
public void initUpdateSpdzConnectionsTask() throws Exception {

logger.info("init spdz connection update period task");

ScheduledExecutorService scheduledExecutorService = new ScheduledThreadPoolExecutor(1);

scheduledExecutorService.scheduleAtFixedRate(
new Runnable() {
@Override
public void run() {
weDPRTransport.asyncGetPeers(
new GetPeersCallback() {
@Override
public void onPeers(Error error, String jsonStr) {
spdzConnections.updateSpdzConnections(jsonStr);
}
});
}
},
10,
10,
TimeUnit.SECONDS);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package com.webank.wedpr.components.scheduler.core;

import java.util.List;
import lombok.Data;

@Data
public class GetPeers {

@Data
public static class Meta {
private List<ServiceInfo> serviceInfos;

// getters and setters
}

@Data
public static class Front {
private List<String> components;
private String endPoint;
private String meta;
private String nodeID;

// getters and setters
}

@Data
public static class Gateway {
private String agency;
private List<Front> frontList;
private String gatewayNodeID;

// getters and setters
}

@Data
public static class Peer {
private String agency;
private List<Gateway> gateway;

// getters and setters
}

@Data
public static class ServiceInfo {
private String entryPoint;
private String serviceName;
private List<String> components;

// getters and setters
}

private String agency;
private Gateway gateway;
private String nodeID;
private List<Peer> peers;

// getters and setters
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
package com.webank.wedpr.components.scheduler.core;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.webank.wedpr.common.utils.ObjectMapperFactory;
import com.webank.wedpr.common.utils.WeDPRException;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import lombok.Data;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;

@Component("spdzConnections")
@Data
public class SpdzConnections {

private static final Logger logger = LoggerFactory.getLogger(SpdzConnections.class);
private static final String SPDZ_SERVICE_NAME = "SPDZ";

@Data
public static class Connection {
private String ip;
private int port;

public Connection(String ip, int port) {
this.ip = ip;
this.port = port;
}
};

private final Map<String, Connection> agency2Connection = new ConcurrentHashMap<>();

public Connection getConnection(String agency) throws WeDPRException {
Connection connection = agency2Connection.get(agency);
if (connection == null) {
logger.error("cannot find spdz connection info, agency: {}", agency);
throw new WeDPRException("cannot find spdz connection info, agency: " + agency);
}
logger.info("get spdz connection: {}", connection);
return connection;
}

public boolean updateConnection(String agency, String spdzEndpoint) {
try {
String[] split = spdzEndpoint.split(":");

if (split.length != 2) {
throw new IllegalArgumentException(
"invalid endpoint format, endpoint: " + spdzEndpoint);
}

String ip = split[0].trim();
int port = Integer.parseInt(split[1].trim());

Connection connection = new Connection(ip, port);

Connection oldConnection = agency2Connection.get(agency);
if (oldConnection == null || !oldConnection.equals(connection)) {
agency2Connection.remove(agency);
agency2Connection.put(agency, connection);

logger.info(
"update spdz connection, agency: {}, endpoint: {}", agency, spdzEndpoint);
} else {
logger.debug(
"update spdz connection, agency: {}, endpoint: {}", agency, spdzEndpoint);
}

return true;
} catch (Exception e) {
logger.error(
"update spdz connection failed, agency: {}, endpoint: {}, e:",
agency,
spdzEndpoint,
e);
}

return false;
}

public void updateSpdzConnections(GetPeers getPeers) throws JsonProcessingException {
for (GetPeers.Peer peer : getPeers.getPeers()) {
String agency = peer.getAgency();

boolean updateResult = false;
for (GetPeers.Gateway gateway : peer.getGateway()) {
for (GetPeers.Front front : gateway.getFrontList()) {
String strMeta = front.getMeta();

if (strMeta == null || !strMeta.contains(SPDZ_SERVICE_NAME)) {
continue;
}

if (logger.isDebugEnabled()) {
logger.debug("agency: {}, meta: {}", agency, strMeta);
}

GetPeers.Meta meta =
ObjectMapperFactory.getObjectMapper()
.readValue(strMeta, GetPeers.Meta.class);

if (meta == null) {
continue;
}

for (GetPeers.ServiceInfo serviceInfo : meta.getServiceInfos()) {
String serviceName = serviceInfo.getServiceName();
String entryPoint = serviceInfo.getEntryPoint();
if (serviceName == null || !serviceName.contains(SPDZ_SERVICE_NAME)) {
continue;
}
// SPDZ
logger.debug("serviceName: {}, entryPoint: {}", serviceName, entryPoint);

updateResult = updateConnection(agency, entryPoint);
if (updateResult) {
break;
}
}

if (updateResult) {
break;
}
}
}
}
}

public void updateSpdzConnections(String jsonStr) {
if (jsonStr == null || jsonStr.isEmpty()) {
return;
}

try {
GetPeers getPeers =
ObjectMapperFactory.getObjectMapper().readValue(jsonStr, GetPeers.class);
updateSpdzConnections(getPeers);
} catch (Exception e) {
logger.error("e: ", e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import com.webank.wedpr.components.db.mapper.dataset.mapper.DatasetMapper;
import com.webank.wedpr.components.loadbalancer.LoadBalancer;
import com.webank.wedpr.components.project.dao.JobDO;
import com.webank.wedpr.components.scheduler.core.SpdzConnections;
import com.webank.wedpr.components.scheduler.dag.api.WorkFlowScheduler;
import com.webank.wedpr.components.scheduler.dag.base.DAG;
import com.webank.wedpr.components.scheduler.dag.base.DAGNode;
Expand Down Expand Up @@ -34,6 +35,7 @@ public class DagWorkFlowSchedulerImpl implements WorkFlowScheduler {
private DatasetMapper datasetMapper;
private FileStorageInterface fileStorageInterface;
private FileMetaBuilder fileMetaBuilder;
private SpdzConnections spdzConnections;

private final Integer workerRetryTimes = -1;
private final Integer workerRetryDelayMillis = -1;
Expand All @@ -43,12 +45,14 @@ public DagWorkFlowSchedulerImpl(
JobWorkerMapper jobWorkerMapper,
DatasetMapper datasetMapper,
FileStorageInterface fileStorageInterface,
FileMetaBuilder fileMetaBuilder) {
FileMetaBuilder fileMetaBuilder,
SpdzConnections spdzConnections) {
this.loadBalancer = loadBalancer;
this.jobWorkerMapper = jobWorkerMapper;
this.datasetMapper = datasetMapper;
this.fileStorageInterface = fileStorageInterface;
this.fileMetaBuilder = fileMetaBuilder;
this.spdzConnections = spdzConnections;
}

public LoadBalancer getLoadBalancer() {
Expand Down Expand Up @@ -128,7 +132,8 @@ public List<Worker> prepareDag(String jobId, JobDO jobDO, WorkFlow workFlow)
jobWorkerMapper,
datasetMapper,
fileStorageInterface,
fileMetaBuilder);
fileMetaBuilder,
spdzConnections);
workerList.add(worker);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import com.webank.wedpr.components.loadbalancer.LoadBalancer;
import com.webank.wedpr.components.project.dao.JobDO;
import com.webank.wedpr.components.scheduler.client.ModelClient;
import com.webank.wedpr.components.scheduler.core.SpdzConnections;
import com.webank.wedpr.components.scheduler.dag.entity.JobWorker;
import com.webank.wedpr.components.scheduler.executor.impl.model.FileMetaBuilder;
import com.webank.wedpr.components.scheduler.mapper.JobWorkerMapper;
Expand All @@ -28,7 +29,8 @@ public ModelWorker(
JobWorkerMapper jobWorkerMapper,
DatasetMapper datasetMapper,
FileStorageInterface fileStorageInterface,
FileMetaBuilder fileMetaBuilder) {
FileMetaBuilder fileMetaBuilder,
SpdzConnections spdzConnections) {
super(
jobDO,
jobWorker,
Expand All @@ -38,7 +40,8 @@ public ModelWorker(
jobWorkerMapper,
datasetMapper,
fileStorageInterface,
fileMetaBuilder);
fileMetaBuilder,
spdzConnections);
}

@Override
Expand Down
Loading

0 comments on commit 9bd6881

Please sign in to comment.