forked from 1ytic/warp-rnnt
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: core.h
31 lines (24 loc) · 1004 Bytes
/
core.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
#ifndef RNNT_CORE_H
#define RNNT_CORE_H
#include <cuda_runtime.h>
/*
 * Status codes returned by run_warp_rnnt() and run_warp_rnnt_gather().
 *
 * Every enumerator is given an explicit value so the numeric codes do not
 * depend on declaration order. This matters because the type crosses a C
 * (extern "C") interface.
 *
 * NOTE(review): the meaning of each failure code is inferred from its name
 * only; confirm against the implementation that returns it.
 */
typedef enum {
    RNNT_STATUS_SUCCESS             = 0, /* no error */
    RNNT_STATUS_WARP_FAILED         = 1, /* presumably the forward/backward (alpha/beta) stage failed */
    RNNT_STATUS_GRADS_BLANK_FAILED  = 2, /* presumably the blank-gradient stage failed */
    RNNT_STATUS_GRADS_LABEL_FAILED  = 3, /* presumably the label-gradient stage failed */
    RNNT_STATUS_COSTS_FAILED        = 4  /* presumably the cost computation failed */
} rnntStatus_t;
#ifdef __cplusplus
#include <cstddef>
extern "C" {
#endif

/*
 * Host-callable entry points. When this header is compiled as C++ they are
 * wrapped in extern "C", so they keep C linkage and can be called from both
 * C and C++ translation units. Both report their outcome as an rnntStatus_t.
 *
 * Pointer parameters declared `const` are read-only inputs. Non-const pointers
 * may be written by the callee.
 *
 * NOTE(review): the descriptions below are inferred from parameter names and
 * are not established by this header. Confirm buffer sizes, layouts and
 * memory space (host vs. device) against the implementation.
 */

/*
 * RNN-T loss and gradients over the full vocabulary.
 *   stream           CUDA stream the work is issued on (presumably; TODO confirm)
 *   counts           non-const workspace, type unsigned int
 *   alphas, betas    non-const buffers, presumably forward/backward variables
 *   labels           read-only target label ids
 *   log_probs        read-only log-probabilities; presumably (N, T, U, V), TODO confirm
 *   grads, costs     non-const outputs
 *   xn, yn           read-only per-sequence lengths (presumably inputs / labels)
 *   N, T, U, V       presumably batch, time, label and vocabulary sizes
 *   blank            presumably the blank symbol's index into V
 *   fastemit_lambda  presumably the FastEmit regularization weight
 */
rnntStatus_t run_warp_rnnt(
    cudaStream_t stream,
    unsigned int *counts,
    float *alphas,
    float *betas,
    const int *labels,
    const float *log_probs,
    float *grads,
    float *costs,
    const int *xn,
    const int *yn,
    int N,
    int T,
    int U,
    int V,
    int blank,
    float fastemit_lambda);

/*
 * Variant of run_warp_rnnt() without labels, V or blank.
 * NOTE(review): presumably log_probs is already gathered down to the
 * blank/label entries per (t, u); confirm the expected layout.
 */
rnntStatus_t run_warp_rnnt_gather(
    cudaStream_t stream,
    unsigned int *counts,
    float *alphas,
    float *betas,
    const float *log_probs,
    float *grads,
    float *costs,
    const int *xn,
    const int *yn,
    int N,
    int T,
    int U,
    float fastemit_lambda);

#ifdef __cplusplus
}
#endif
#endif //RNNT_CORE_H