Skip to content

Commit

Permalink
EFRS-1333: Implemented the ability to verify embedding by id with oth…
Browse files Browse the repository at this point in the history
…er embeddings
  • Loading branch information
VolodymyrBushko committed Nov 10, 2022
1 parent 475f750 commit 130ff41
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
import com.exadel.frs.core.trainservice.aspect.WriteEndpoint;
import com.exadel.frs.core.trainservice.dto.Base64File;
import com.exadel.frs.core.trainservice.dto.EmbeddingDto;
import com.exadel.frs.core.trainservice.dto.EmbeddingsRecognitionRequest;
import com.exadel.frs.core.trainservice.dto.EmbeddingsVerificationProcessResponse;
import com.exadel.frs.core.trainservice.dto.ProcessEmbeddingsParams;
import com.exadel.frs.core.trainservice.dto.ProcessImageParams;
import com.exadel.frs.core.trainservice.dto.VerificationResult;
import com.exadel.frs.core.trainservice.mapper.EmbeddingMapper;
Expand Down Expand Up @@ -69,7 +72,7 @@

@Validated
@RestController
@RequestMapping(API_V1 + "/recognition/faces")
@RequestMapping(API_V1 + "/recognition")
@RequiredArgsConstructor
public class EmbeddingController {

Expand All @@ -81,7 +84,7 @@ public class EmbeddingController {

@WriteEndpoint
@ResponseStatus(CREATED)
@PostMapping
@PostMapping("/faces")
public EmbeddingDto addEmbedding(
@ApiParam(value = IMAGE_WITH_ONE_FACE_DESC, required = true)
@RequestParam
Expand Down Expand Up @@ -112,7 +115,7 @@ public EmbeddingDto addEmbedding(

@WriteEndpoint
@ResponseStatus(CREATED)
@PostMapping(consumes = MediaType.APPLICATION_JSON_VALUE)
@PostMapping(value = "/faces", consumes = MediaType.APPLICATION_JSON_VALUE)
public EmbeddingDto addEmbeddingBase64(
@ApiParam(value = API_KEY_DESC, required = true)
@RequestHeader(X_FRS_API_KEY_HEADER)
Expand Down Expand Up @@ -142,7 +145,7 @@ public EmbeddingDto addEmbeddingBase64(
}

@ResponseBody
@GetMapping(value = "/{embeddingId}/img", produces = MediaType.APPLICATION_OCTET_STREAM_VALUE)
@GetMapping(value = "/faces/{embeddingId}/img", produces = MediaType.APPLICATION_OCTET_STREAM_VALUE)
public byte[] downloadImg(HttpServletResponse response,
@ApiParam(value = API_KEY_DESC, required = true)
@RequestHeader(name = X_FRS_API_KEY_HEADER)
Expand All @@ -157,7 +160,7 @@ public byte[] downloadImg(HttpServletResponse response,
.orElse(new byte[]{});
}

@GetMapping
@GetMapping("/faces")
public Faces listEmbeddings(
@ApiParam(value = API_KEY_DESC, required = true)
@RequestHeader(name = X_FRS_API_KEY_HEADER)
Expand All @@ -172,7 +175,7 @@ public Faces listEmbeddings(
}

@WriteEndpoint
@DeleteMapping
@DeleteMapping("/faces")
public Map<String, Object> removeAllSubjectEmbeddings(
@ApiParam(value = API_KEY_DESC, required = true)
@RequestHeader(name = X_FRS_API_KEY_HEADER)
Expand All @@ -190,7 +193,7 @@ public Map<String, Object> removeAllSubjectEmbeddings(
}

@WriteEndpoint
@DeleteMapping("/{embeddingId}")
@DeleteMapping("/faces/{embeddingId}")
public EmbeddingDto deleteEmbeddingById(
@ApiParam(value = API_KEY_DESC, required = true)
@RequestHeader(name = X_FRS_API_KEY_HEADER)
Expand All @@ -204,7 +207,7 @@ public EmbeddingDto deleteEmbeddingById(
}

@WriteEndpoint
@PostMapping("/delete")
@PostMapping("/faces/delete")
public List<EmbeddingDto> deleteEmbeddingsById(
@ApiParam(value = API_KEY_DESC, required = true)
@RequestHeader(name = X_FRS_API_KEY_HEADER)
Expand All @@ -220,7 +223,7 @@ public List<EmbeddingDto> deleteEmbeddingsById(
return dtoList;
}

@PostMapping(value = "/{embeddingId}/verify", consumes = MediaType.MULTIPART_FORM_DATA_VALUE)
@PostMapping(value = "/faces/{embeddingId}/verify", consumes = MediaType.MULTIPART_FORM_DATA_VALUE)
public VerificationResult recognizeFile(
@ApiParam(value = API_KEY_DESC, required = true)
@RequestHeader(X_FRS_API_KEY_HEADER)
Expand Down Expand Up @@ -263,7 +266,7 @@ public VerificationResult recognizeFile(
);
}

@PostMapping(value = "/{embeddingId}/verify", consumes = MediaType.APPLICATION_JSON_VALUE)
@PostMapping(value = "/faces/{embeddingId}/verify", consumes = MediaType.APPLICATION_JSON_VALUE)
public VerificationResult recognizeBase64(
@ApiParam(value = API_KEY_DESC, required = true)
@RequestHeader(X_FRS_API_KEY_HEADER)
Expand Down Expand Up @@ -307,6 +310,28 @@ public VerificationResult recognizeBase64(
);
}

@PostMapping(value = "/embeddings/faces/{embeddingId}/verify", consumes = MediaType.APPLICATION_JSON_VALUE)
public EmbeddingsVerificationProcessResponse recognizeEmbeddings(
@ApiParam(value = API_KEY_DESC, required = true)
@RequestHeader(X_FRS_API_KEY_HEADER)
final String apiKey,
@ApiParam(value = IMAGE_ID_DESC, required = true)
@PathVariable
final UUID embeddingId,
@RequestBody
@Valid
final EmbeddingsRecognitionRequest recognitionRequest
) {
ProcessEmbeddingsParams processParams =
ProcessEmbeddingsParams.builder()
.apiKey(apiKey)
.embeddings(recognitionRequest.getEmbeddings())
.additionalParams(Map.of(IMAGE_ID, embeddingId))
.build();

return subjectService.verifyEmbedding(processParams);
}

@RequiredArgsConstructor
private static final class Faces {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import com.exadel.frs.commonservice.entity.Subject;
import com.exadel.frs.commonservice.exception.EmbeddingNotFoundException;
import com.exadel.frs.commonservice.exception.TooManyFacesException;
import com.exadel.frs.commonservice.exception.WrongEmbeddingCountException;
import com.exadel.frs.commonservice.sdk.faces.FacesApiClient;
import com.exadel.frs.commonservice.sdk.faces.feign.dto.FindFacesResponse;
import com.exadel.frs.commonservice.sdk.faces.feign.dto.FindFacesResult;
Expand All @@ -13,11 +14,17 @@
import com.exadel.frs.core.trainservice.component.classifiers.EuclideanDistanceClassifier;
import com.exadel.frs.core.trainservice.dao.SubjectDao;
import com.exadel.frs.core.trainservice.dto.EmbeddingInfo;
import com.exadel.frs.core.trainservice.dto.EmbeddingVerificationProcessResult;
import com.exadel.frs.core.trainservice.dto.EmbeddingsVerificationProcessResponse;
import com.exadel.frs.core.trainservice.dto.FaceVerification;
import com.exadel.frs.core.trainservice.dto.ProcessEmbeddingsParams;
import com.exadel.frs.core.trainservice.dto.ProcessImageParams;
import com.exadel.frs.core.trainservice.system.global.Constants;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.stereotype.Service;
Expand All @@ -36,6 +43,7 @@
@Slf4j
public class SubjectService {

private static final int MINIMUM_EMBEDDING_COUNT = 1;
private static final int MAX_FACES_TO_SAVE = 1;
public static final int MAX_FACES_TO_RECOGNIZE = 2;

Expand Down Expand Up @@ -272,4 +280,28 @@ public Pair<List<FaceVerification>, PluginsVersions> verifyFace(ProcessImagePara
Boolean.TRUE.equals(processImageParams.getStatus()) ? findFacesResponse.getPluginsVersions() : null
);
}

public EmbeddingsVerificationProcessResponse verifyEmbedding(ProcessEmbeddingsParams processEmbeddingsParams) {
double[][] targets = processEmbeddingsParams.getEmbeddings();
if (ArrayUtils.isEmpty(targets)) {
throw new WrongEmbeddingCountException(MINIMUM_EMBEDDING_COUNT, 0);
}

UUID sourceId = (UUID) processEmbeddingsParams.getAdditionalParams().get(Constants.IMAGE_ID);
String apiKey = processEmbeddingsParams.getApiKey();

List<EmbeddingVerificationProcessResult> results =
Arrays.stream(targets)
.map(target -> processTarget(target, sourceId, apiKey))
.sorted((e1, e2) -> Float.compare(e2.getSimilarity(), e1.getSimilarity()))
.collect(Collectors.toList());

return new EmbeddingsVerificationProcessResponse(results);
}

private EmbeddingVerificationProcessResult processTarget(double[] target, UUID sourceId, String apiKey) {
double similarity = predictor.verify(apiKey, target, sourceId);
float scaledSimilarity = BigDecimal.valueOf(similarity).setScale(5, HALF_UP).floatValue();
return new EmbeddingVerificationProcessResult(target, scaledSimilarity);
}
}

0 comments on commit 130ff41

Please sign in to comment.