diff --git a/java/api/src/main/java/com/exadel/frs/core/trainservice/controller/EmbeddingController.java b/java/api/src/main/java/com/exadel/frs/core/trainservice/controller/EmbeddingController.java index 1fd66632f9..bb4d9b9e48 100644 --- a/java/api/src/main/java/com/exadel/frs/core/trainservice/controller/EmbeddingController.java +++ b/java/api/src/main/java/com/exadel/frs/core/trainservice/controller/EmbeddingController.java @@ -29,6 +29,9 @@ import com.exadel.frs.core.trainservice.aspect.WriteEndpoint; import com.exadel.frs.core.trainservice.dto.Base64File; import com.exadel.frs.core.trainservice.dto.EmbeddingDto; +import com.exadel.frs.core.trainservice.dto.EmbeddingsRecognitionRequest; +import com.exadel.frs.core.trainservice.dto.EmbeddingsVerificationProcessResponse; +import com.exadel.frs.core.trainservice.dto.ProcessEmbeddingsParams; import com.exadel.frs.core.trainservice.dto.ProcessImageParams; import com.exadel.frs.core.trainservice.dto.VerificationResult; import com.exadel.frs.core.trainservice.mapper.EmbeddingMapper; @@ -69,7 +72,7 @@ @Validated @RestController -@RequestMapping(API_V1 + "/recognition/faces") +@RequestMapping(API_V1 + "/recognition") @RequiredArgsConstructor public class EmbeddingController { @@ -81,7 +84,7 @@ public class EmbeddingController { @WriteEndpoint @ResponseStatus(CREATED) - @PostMapping + @PostMapping("/faces") public EmbeddingDto addEmbedding( @ApiParam(value = IMAGE_WITH_ONE_FACE_DESC, required = true) @RequestParam @@ -112,7 +115,7 @@ public EmbeddingDto addEmbedding( @WriteEndpoint @ResponseStatus(CREATED) - @PostMapping(consumes = MediaType.APPLICATION_JSON_VALUE) + @PostMapping(value = "/faces", consumes = MediaType.APPLICATION_JSON_VALUE) public EmbeddingDto addEmbeddingBase64( @ApiParam(value = API_KEY_DESC, required = true) @RequestHeader(X_FRS_API_KEY_HEADER) @@ -142,7 +145,7 @@ public EmbeddingDto addEmbeddingBase64( } @ResponseBody - @GetMapping(value = "/{embeddingId}/img", produces = MediaType.APPLICATION_OCTET_STREAM_VALUE) + @GetMapping(value = "/faces/{embeddingId}/img", produces = MediaType.APPLICATION_OCTET_STREAM_VALUE) public byte[] downloadImg(HttpServletResponse response, @ApiParam(value = API_KEY_DESC, required = true) @RequestHeader(name = X_FRS_API_KEY_HEADER) @@ -157,7 +160,7 @@ public byte[] downloadImg(HttpServletResponse response, .orElse(new byte[]{}); } - @GetMapping + @GetMapping("/faces") public Faces listEmbeddings( @ApiParam(value = API_KEY_DESC, required = true) @RequestHeader(name = X_FRS_API_KEY_HEADER) @@ -172,7 +175,7 @@ public Faces listEmbeddings( } @WriteEndpoint - @DeleteMapping + @DeleteMapping("/faces") public Map removeAllSubjectEmbeddings( @ApiParam(value = API_KEY_DESC, required = true) @RequestHeader(name = X_FRS_API_KEY_HEADER) @@ -190,7 +193,7 @@ public Map removeAllSubjectEmbeddings( } @WriteEndpoint - @DeleteMapping("/{embeddingId}") + @DeleteMapping("/faces/{embeddingId}") public EmbeddingDto deleteEmbeddingById( @ApiParam(value = API_KEY_DESC, required = true) @RequestHeader(name = X_FRS_API_KEY_HEADER) @@ -204,7 +207,7 @@ public EmbeddingDto deleteEmbeddingById( } @WriteEndpoint - @PostMapping("/delete") + @PostMapping("/faces/delete") public List deleteEmbeddingsById( @ApiParam(value = API_KEY_DESC, required = true) @RequestHeader(name = X_FRS_API_KEY_HEADER) @@ -220,7 +223,7 @@ public List deleteEmbeddingsById( return dtoList; } - @PostMapping(value = "/{embeddingId}/verify", consumes = MediaType.MULTIPART_FORM_DATA_VALUE) + @PostMapping(value = "/faces/{embeddingId}/verify", consumes = MediaType.MULTIPART_FORM_DATA_VALUE) public VerificationResult recognizeFile( @ApiParam(value = API_KEY_DESC, required = true) @RequestHeader(X_FRS_API_KEY_HEADER) @@ -263,7 +266,7 @@ public VerificationResult recognizeFile( ); } - @PostMapping(value = "/{embeddingId}/verify", consumes = MediaType.APPLICATION_JSON_VALUE) + @PostMapping(value = "/faces/{embeddingId}/verify", consumes = MediaType.APPLICATION_JSON_VALUE) public VerificationResult recognizeBase64( @ApiParam(value = API_KEY_DESC, required = true) @RequestHeader(X_FRS_API_KEY_HEADER) @@ -307,6 +310,28 @@ public VerificationResult recognizeBase64( ); } + @PostMapping(value = "/embeddings/faces/{embeddingId}/verify", consumes = MediaType.APPLICATION_JSON_VALUE) + public EmbeddingsVerificationProcessResponse recognizeEmbeddings( + @ApiParam(value = API_KEY_DESC, required = true) + @RequestHeader(X_FRS_API_KEY_HEADER) + final String apiKey, + @ApiParam(value = IMAGE_ID_DESC, required = true) + @PathVariable + final UUID embeddingId, + @RequestBody + @Valid + final EmbeddingsRecognitionRequest recognitionRequest + ) { + ProcessEmbeddingsParams processParams = + ProcessEmbeddingsParams.builder() + .apiKey(apiKey) + .embeddings(recognitionRequest.getEmbeddings()) + .additionalParams(Map.of(IMAGE_ID, embeddingId)) + .build(); + + return subjectService.verifyEmbedding(processParams); + } + @RequiredArgsConstructor private static final class Faces { diff --git a/java/api/src/main/java/com/exadel/frs/core/trainservice/service/SubjectService.java b/java/api/src/main/java/com/exadel/frs/core/trainservice/service/SubjectService.java index 6477d078c2..542b8ac05f 100644 --- a/java/api/src/main/java/com/exadel/frs/core/trainservice/service/SubjectService.java +++ b/java/api/src/main/java/com/exadel/frs/core/trainservice/service/SubjectService.java @@ -4,6 +4,7 @@ import com.exadel.frs.commonservice.entity.Subject; import com.exadel.frs.commonservice.exception.EmbeddingNotFoundException; import com.exadel.frs.commonservice.exception.TooManyFacesException; +import com.exadel.frs.commonservice.exception.WrongEmbeddingCountException; import com.exadel.frs.commonservice.sdk.faces.FacesApiClient; import com.exadel.frs.commonservice.sdk.faces.feign.dto.FindFacesResponse; import com.exadel.frs.commonservice.sdk.faces.feign.dto.FindFacesResult; @@ -13,11 +14,17 @@ import com.exadel.frs.core.trainservice.component.classifiers.EuclideanDistanceClassifier; import com.exadel.frs.core.trainservice.dao.SubjectDao; import com.exadel.frs.core.trainservice.dto.EmbeddingInfo; +import com.exadel.frs.core.trainservice.dto.EmbeddingVerificationProcessResult; +import com.exadel.frs.core.trainservice.dto.EmbeddingsVerificationProcessResponse; import com.exadel.frs.core.trainservice.dto.FaceVerification; +import com.exadel.frs.core.trainservice.dto.ProcessEmbeddingsParams; import com.exadel.frs.core.trainservice.dto.ProcessImageParams; import com.exadel.frs.core.trainservice.system.global.Constants; +import java.util.stream.Collectors; +import java.util.stream.IntStream; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.tuple.Pair; import org.springframework.stereotype.Service; @@ -36,6 +43,7 @@ @Slf4j public class SubjectService { + private static final int MINIMUM_EMBEDDING_COUNT = 1; private static final int MAX_FACES_TO_SAVE = 1; public static final int MAX_FACES_TO_RECOGNIZE = 2; @@ -272,4 +280,28 @@ public Pair, PluginsVersions> verifyFace(ProcessImagePara Boolean.TRUE.equals(processImageParams.getStatus()) ? findFacesResponse.getPluginsVersions() : null ); } + + public EmbeddingsVerificationProcessResponse verifyEmbedding(ProcessEmbeddingsParams processEmbeddingsParams) { + double[][] targets = processEmbeddingsParams.getEmbeddings(); + if (ArrayUtils.isEmpty(targets)) { + throw new WrongEmbeddingCountException(MINIMUM_EMBEDDING_COUNT, 0); + } + + UUID sourceId = (UUID) processEmbeddingsParams.getAdditionalParams().get(Constants.IMAGE_ID); + String apiKey = processEmbeddingsParams.getApiKey(); + + List results = + Arrays.stream(targets) + .map(target -> processTarget(target, sourceId, apiKey)) + .sorted((e1, e2) -> Float.compare(e2.getSimilarity(), e1.getSimilarity())) + .collect(Collectors.toList()); + + return new EmbeddingsVerificationProcessResponse(results); + } + + private EmbeddingVerificationProcessResult processTarget(double[] target, UUID sourceId, String apiKey) { + double similarity = predictor.verify(apiKey, target, sourceId); + float scaledSimilarity = BigDecimal.valueOf(similarity).setScale(5, HALF_UP).floatValue(); + return new EmbeddingVerificationProcessResult(target, scaledSimilarity); + } }