Skip to content

Commit c4d674e

Browse files
committed
Merge branch 'master' of https://github.com/vacancy/Jacinle
* 'master' of https://github.com/vacancy/Jacinle: [master] sync mldash [master] mldash sync [master] bump jacmldash version [master] fix git command error, jactorch dataset supports sort and shuffle [master] as_numpy now does detach
2 parents e36abc2 + bf19ae0 commit c4d674e

File tree

4 files changed

+19
-5
lines changed

4 files changed

+19
-5
lines changed

jacinle/cli/git.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,12 @@
1212

1313

1414
def get_git_revision_hash(short=False):
15-
if short:
16-
return subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD'])
17-
return subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode('utf-8').strip()
15+
try:
16+
if short:
17+
return subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD'])
18+
return subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode('utf-8').strip()
19+
except subprocess.CalledProcessError:
20+
return None
1821

1922

2023
def get_git_uncommitted_files():

jactorch/data/dataset.py

+11
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,17 @@ def repeat(self, nr_repeats):
207207
logger.critical('Repeat the dataset: #before={}, #after={}.'.format(len(self), len(indices)))
208208
return type(self)(self, indices=indices, filter_name='repeat[{}]'.format(nr_repeats))
209209

210+
def sort(self, key, key_name=None):
211+
if key_name is None:
212+
key_name = str(key)
213+
indices = sorted(range(len(self)), key=lambda x: key(self.get_metainfo(x)))
214+
return type(self)(self, indices=indices, filter_name='sort[{}]'.format(key_name))
215+
216+
def random_shuffle(self):
217+
indices = list(range(len(self)))
218+
random.shuffle(indices)
219+
return type(self)(self, indices=indices, filter_name='random_shuffle')
220+
210221
def __getitem__(self, index):
211222
if self.indices is None:
212223
return self.owner_dataset[index]

jactorch/utils/meta.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def _as_numpy(o):
7676
if isinstance(o, Variable):
7777
o = o
7878
if torch.is_tensor(o):
79-
return o.cpu().numpy()
79+
return o.detach().cpu().numpy()
8080
return np.array(o)
8181

8282

0 commit comments

Comments
 (0)