diff --git a/init.sh b/init.sh index 4168db6..bda5ae4 100755 --- a/init.sh +++ b/init.sh @@ -3,4 +3,3 @@ # Linux/MacOS 下初始化 monotonic_align 模块 cd monotonic_align python3 setup.py build_ext --inplace -mv monotonic_align/*.so core.dll diff --git a/monotonic_align/__init__.py b/monotonic_align/__init__.py index aaa260a..b63d171 100644 --- a/monotonic_align/__init__.py +++ b/monotonic_align/__init__.py @@ -1,8 +1,6 @@ -import numpy as np -import torch -from ctypes import cdll -maximum_path_c = cdll.LoadLibrary('./monotonic_align/core.dll') - +from numpy import zeros, int32, float32 +from torch import from_numpy +from monotonic_align.monotonic_align.core import maximum_path_c def maximum_path(neg_cent, mask): """ Cython optimized version. @@ -11,10 +9,10 @@ def maximum_path(neg_cent, mask): """ device = neg_cent.device dtype = neg_cent.dtype - neg_cent = neg_cent.data.cpu().numpy().astype(np.float32) - path = np.zeros(neg_cent.shape, dtype=np.int32) + neg_cent = neg_cent.data.cpu().numpy().astype(float32) + path = zeros(neg_cent.shape, dtype=int32) - t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(np.int32) - t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(np.int32) + t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(int32) + t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(int32) maximum_path_c(path, neg_cent, t_t_max, t_s_max) - return torch.from_numpy(path).to(device=device, dtype=dtype) + return from_numpy(path).to(device=device, dtype=dtype)