Skip to content

Commit

Permalink
Merge pull request exadel-inc#213 from exadel-inc/EFRS-835
Browse files Browse the repository at this point in the history
EFRS-835: synchronize API nodes
  • Loading branch information
andreigaevsky authored Dec 1, 2020
2 parents c29314a + 60b2e17 commit 80f0143
Show file tree
Hide file tree
Showing 21 changed files with 319 additions and 33 deletions.
1 change: 0 additions & 1 deletion api/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@
<dependency>
<groupId>org.postgresql</groupId>
<artifactId>postgresql</artifactId>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration;
import org.springframework.scheduling.annotation.EnableScheduling;

@SpringBootApplication(exclude = {DataSourceAutoConfiguration.class})
@EnableScheduling
public class TrainServiceApplication {

public static void main(String[] args) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,23 +84,24 @@ public static FaceCollection buildFromFaces(final List<Face> faces) {

synchronized public FaceBO addFace(final Face face) {
val cachedFace = new FaceBO(face.getFaceName(), face.getId());
facesMap.put(cachedFace, size.get());
val faceEmbeddings = face.getEmbedding().getEmbeddings()
.stream()
.mapToDouble(d -> d).toArray();
if (embeddings == null) {
embeddings = Nd4j.create(new double[][]{faceEmbeddings});
} else {
embeddings = Nd4j.concat(
0,
embeddings,
Nd4j.create(new double[][]{faceEmbeddings})
);
if(!facesMap.containsKey(cachedFace)) {
facesMap.put(cachedFace, size.get());
val faceEmbeddings = face.getEmbedding().getEmbeddings()
.stream()
.mapToDouble(d -> d).toArray();
if (embeddings == null) {
embeddings = Nd4j.create(new double[][]{faceEmbeddings});
} else {
embeddings = Nd4j.concat(
0,
embeddings,
Nd4j.create(new double[][]{faceEmbeddings})
);
}

embeddingsCopy = embeddings.dup();
size.getAndIncrement();
}

embeddingsCopy = embeddings.dup();
size.getAndIncrement();

return cachedFace;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package com.exadel.frs.core.trainservice.config.repository;

import static com.exadel.frs.core.trainservice.system.global.Constants.SERVER_UUID;
import com.exadel.frs.core.trainservice.dto.DbActionDto;
import com.exadel.frs.core.trainservice.service.DbActionService;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import lombok.extern.slf4j.Slf4j;
import org.postgresql.PGConnection;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.scheduling.annotation.Scheduled;

@Slf4j
class Listener {

@Autowired
private DbActionService actionService;

private final ObjectMapper mapper = new ObjectMapper();

private final Connection conn;
private final PGConnection pgconn;

Listener(Connection conn) throws SQLException {
this.conn = conn;
this.pgconn = (PGConnection)conn;
Statement stmt = conn.createStatement();
stmt.execute("LISTEN face_collection_update_msg");
stmt.close();
log.info(String.format("Listener %s is started", SERVER_UUID));
}

@Scheduled(fixedRate=500)
public void listen() {
try {
Statement stmt = conn.createStatement();
ResultSet rs = stmt.executeQuery("SELECT 1");
rs.close();
stmt.close();

org.postgresql.PGNotification[] notifications = pgconn.getNotifications();
if (notifications != null) {
for (final org.postgresql.PGNotification notification : notifications) {
log.info("Get notification: " + notification.getName());
actionService.synchronizeCache(mapper.readValue(notification.getParameter(), DbActionDto.class));
}
}
} catch (SQLException | JsonProcessingException e) {
e.printStackTrace();
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package com.exadel.frs.core.trainservice.config.repository;

import java.sql.DriverManager;
import java.sql.SQLException;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

@Configuration
public class NotificationConfig {

@Value("${spring.datasource-pg.username}")
private String dbUsername;

@Value("${spring.datasource-pg.password}")
private String dbPassword;

@Value("${spring.datasource-pg.url}")
private String dbUrl;

@Bean
public Listener dbListenerRun() throws SQLException {
return new Listener(DriverManager.getConnection(dbUrl, dbUsername, dbPassword));
}

@Bean
public Notifier dbNotifier() throws SQLException {
return new Notifier(DriverManager.getConnection(dbUrl, dbUsername, dbPassword));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package com.exadel.frs.core.trainservice.config.repository;

import com.exadel.frs.core.trainservice.dto.DbActionDto;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.sql.Connection;
import java.sql.SQLException;
import java.sql.Statement;

public class Notifier{

private final Connection conn;
private final ObjectMapper mapper = new ObjectMapper();

public Notifier(Connection conn) {
this.conn = conn;
}

public void notifyWithMessage(DbActionDto actionDto) {
try {
String actionString = mapper.writerWithDefaultPrettyPrinter().writeValueAsString(actionDto);
Statement stmt = conn.createStatement();
stmt.execute(String.format("SELECT pg_notify('face_collection_update_msg', '%s');", actionString));
stmt.close();
} catch (SQLException | JsonProcessingException sqle) {
sqle.printStackTrace();
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,7 @@ public List<Face> deleteFaceByName(final String faceName, final String modelApiK

public Face deleteFaceById(final String faceId) {
val foundFace = facesRepository.findById(faceId);
foundFace.ifPresent(face -> {
facesRepository.delete(face);
});
foundFace.ifPresent(facesRepository::delete);

return foundFace.orElse(null);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package com.exadel.frs.core.trainservice.dto;

import com.exadel.frs.core.trainservice.enums.DbAction;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.List;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

@Data
@AllArgsConstructor
@NoArgsConstructor
public class DbActionDto {

@JsonProperty("action")
private DbAction action;

@JsonProperty("apiKey")
private String apiKey;

@JsonProperty("faceIds")
private List<String> faceIds;

@JsonProperty("faceName")
private String faceName;

@JsonProperty("uuid")
private String serverUUID;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package com.exadel.frs.core.trainservice.enums;

import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;

@NoArgsConstructor
@AllArgsConstructor
public enum DbAction {
INSERT("I"),
DELETE("D"),
DELETE_ALL("DA");

@Getter
@Setter
private String code;

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package com.exadel.frs.core.trainservice.service;

import static com.exadel.frs.core.trainservice.system.global.Constants.SERVER_UUID;
import com.exadel.frs.core.trainservice.cache.FaceCacheProvider;
import com.exadel.frs.core.trainservice.dto.DbActionDto;
import com.exadel.frs.core.trainservice.repository.FacesRepository;
import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Service;

@Service
@RequiredArgsConstructor
public class DbActionService {

private final FaceCacheProvider faceCacheProvider;
private final FacesRepository facesRepository;

public void synchronizeCache(DbActionDto action) {
if (!action.getServerUUID().equals(SERVER_UUID)) {
switch (action.getAction()) {
case DELETE:
action.getFaceIds()
.forEach(face -> faceCacheProvider.getOrLoad(action.getApiKey())
.removeFace(face, action.getFaceName())
);
break;
case INSERT:
faceCacheProvider.getOrLoad(action.getApiKey())
.addFace(facesRepository.findById(action.getFaceIds().get(0)).get());
break;
case DELETE_ALL:
faceCacheProvider.invalidate(action.getApiKey());
break;
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,19 @@

package com.exadel.frs.core.trainservice.service;

import static com.exadel.frs.core.trainservice.enums.DbAction.DELETE;
import static com.exadel.frs.core.trainservice.system.global.Constants.SERVER_UUID;
import static java.util.stream.Collectors.toSet;
import com.exadel.frs.core.trainservice.cache.FaceBO;
import com.exadel.frs.core.trainservice.cache.FaceCacheProvider;
import com.exadel.frs.core.trainservice.config.repository.Notifier;
import com.exadel.frs.core.trainservice.dao.FaceDao;
import com.exadel.frs.core.trainservice.dto.DbActionDto;
import com.exadel.frs.core.trainservice.entity.Face;
import com.exadel.frs.core.trainservice.enums.DbAction;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.RequiredArgsConstructor;
import lombok.val;
import org.springframework.stereotype.Service;
Expand All @@ -31,6 +39,7 @@ public class FaceService {

private final FaceDao faceDao;
private final FaceCacheProvider faceCacheProvider;
private final Notifier notifier;

public Set<FaceBO> findFaces(final String apiKey) {
return faceCacheProvider.getOrLoad(apiKey).getFaces();
Expand All @@ -39,16 +48,22 @@ public Set<FaceBO> findFaces(final String apiKey) {
public Set<FaceBO> deleteFaceByName(final String faceName, final String apiKey) {
val faces = faceCacheProvider.getOrLoad(apiKey);

return faceDao.deleteFaceByName(faceName, apiKey)
.stream()
.map(face -> faces.removeFace(face.getId(), face.getFaceName()))
.collect(toSet());
val deletedFaces = faceDao.deleteFaceByName(faceName, apiKey);

val faceIds = deletedFaces.stream().map(Face::getId).collect(Collectors.toList());
deletedFaces.forEach(face -> notifier.notifyWithMessage(new DbActionDto(DELETE, apiKey, faceIds, face.getFaceName(), SERVER_UUID)));

return deletedFaces
.stream()
.map(face -> faces.removeFace(face.getId(), face.getFaceName()))
.collect(toSet());
}

public FaceBO deleteFaceById(final String id, final String apiKey) {
val collection = faceCacheProvider.getOrLoad(apiKey);
val face = faceDao.deleteFaceById(id);
if (face != null) {
notifier.notifyWithMessage(new DbActionDto(DELETE, face.getApiKey(), List.of(face.getId()), face.getFaceName(), SERVER_UUID));
return collection.removeFace(face.getId(), face.getFaceName());
}

Expand All @@ -57,6 +72,7 @@ public FaceBO deleteFaceById(final String id, final String apiKey) {

public void deleteFacesByModel(final String modelKey) {
faceDao.deleteFacesByApiKey(modelKey);
notifier.notifyWithMessage(new DbActionDto(DbAction.DELETE_ALL, modelKey, null, null, SERVER_UUID));
faceCacheProvider.invalidate(modelKey);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,21 @@

package com.exadel.frs.core.trainservice.service;

import static com.exadel.frs.core.trainservice.system.global.Constants.SERVER_UUID;
import com.exadel.frs.core.trainservice.cache.FaceBO;
import com.exadel.frs.core.trainservice.cache.FaceCacheProvider;
import com.exadel.frs.core.trainservice.config.repository.Notifier;
import com.exadel.frs.core.trainservice.dao.FaceDao;
import com.exadel.frs.core.trainservice.dto.DbActionDto;
import com.exadel.frs.core.trainservice.entity.Face.Embedding;
import com.exadel.frs.core.trainservice.enums.DbAction;
import com.exadel.frs.core.trainservice.exception.NoFacesFoundException;
import com.exadel.frs.core.trainservice.exception.TooManyFacesException;
import com.exadel.frs.core.trainservice.system.feign.python.FacesClient;
import com.exadel.frs.core.trainservice.system.feign.python.ScanResponse;
import feign.FeignException;
import java.io.IOException;
import java.util.List;
import lombok.RequiredArgsConstructor;
import lombok.val;
import org.springframework.stereotype.Service;
Expand All @@ -41,6 +46,7 @@ public class ScanServiceImpl implements ScanService {
private final FacesClient facesClient;
private final FaceDao faceDao;
private final FaceCacheProvider faceCacheProvider;
private final Notifier notifier;

@Override
public FaceBO scanAndSaveFace(
Expand All @@ -67,8 +73,18 @@ public FaceBO scanAndSaveFace(

val embeddingToSave = new Embedding(embedding, scanResponse.getCalculatorVersion());

val savedFace = faceDao.addNewFace(embeddingToSave, file, faceName, modelKey);

notifier.notifyWithMessage(new DbActionDto(
DbAction.INSERT,
savedFace.getApiKey(),
List.of(savedFace.getId()),
savedFace.getFaceName(),
SERVER_UUID
));

return faceCacheProvider.getOrLoad(modelKey).addFace(
faceDao.addNewFace(embeddingToSave, file, faceName, modelKey)
savedFace
);
}
}
Loading

0 comments on commit 80f0143

Please sign in to comment.