Skip to content

Commit

Permalink
minor correction
Browse files Browse the repository at this point in the history
  • Loading branch information
belli13 committed Mar 15, 2021
1 parent bbb1a91 commit 4a42d23
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 26 deletions.
8 changes: 1 addition & 7 deletions test_DSL.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,11 @@
{'source': 'Hollywood', 'path': os.path.join('data','Hollywood2','test')},
{'source': 'UCFSports', 'path': os.path.join('data','UCF','test')}]


dataset_index = 0
fromVideo=False
dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
image_size=(128, 192)


dataset_source = source_datasets[dataset_index]['source']
encoder_pretrained = False

Expand Down Expand Up @@ -76,7 +74,6 @@ def main():
file_info.writelines(info)
file_info.close()


if fromVideo:
if dataset_source=='LEDOV' or dataset_source=='UAV123':
list_video= pd.read_csv(os.path.join('data',dataset_source,'test.csv'))['0'].values.tolist()
Expand Down Expand Up @@ -105,7 +102,7 @@ def main():

original_length= len(list_frames)

#if number of video frames are less of 2*lentemporal, we append the frames to the list by going back
#if the number of video frames is less than 2*len_temporal, we append frames to the list by going back
if original_length<2*len_temporal-1:
num_missed_frames = 2*len_temporal -1 - original_length
for k in range(num_missed_frames):
Expand All @@ -122,7 +119,6 @@ def main():
for i in tqdm(range(len(list_frames))):
img = list_frames[i]


snippet.append(img)

if i<original_length:
Expand Down Expand Up @@ -172,9 +168,7 @@ def resized_frames_from_video(v, path_video):
print(os.path.join(path_video,v))
vidcap = cv2.VideoCapture(os.path.join(path_video,v))
success,image = vidcap.read()

frames=[]

success = True
while success:
image = cv2.resize(image, dsize=(image_size[1], image_size[0]),interpolation=cv2.INTER_CUBIC)
Expand Down
23 changes: 4 additions & 19 deletions train_DSL.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,14 @@
from dataset.videoDataset import Dataset3D
from dataset.infiniteDataLoader import InfiniteDataLoader

'''

source_datasets = [{'source': 'DHF1K', 'path': os.path.join('data','DHF1K','source')},
{'source': 'Hollywood', 'path': os.path.join('data','Hollywood2','train')},
{'source': 'UCFSports', 'path': os.path.join('data','UCF','train')}]

validation_datasets = [{'source': 'DHF1K', 'path': os.path.join('data','DHF1K','source')},
{'source': 'Hollywood', 'path': os.path.join('data','Hollywood2','train')},
{'source': 'UCFSports', 'path': os.path.join('data','UCF','train')}]
'''

source_datasets = [{'source': 'DHF1K', 'path': os.path.join("C:\\","Users","gbellitto","Desktop","GitRepository","video-saliency-detection","data","DHF1K","source")},
{'source': 'Hollywood', 'path': os.path.join("C:\\","Users","gbellitto","Desktop","GitRepository","video-saliency-detection","data",'Hollywood2','train')},
{'source': 'UCFSports', 'path': os.path.join("C:\\","Users","gbellitto","Desktop","GitRepository","video-saliency-detection","data",'UCF','train')}]

validation_datasets = [{'source': 'DHF1K', 'path': os.path.join("C:\\","Users","gbellitto","Desktop","GitRepository","video-saliency-detection","data",'DHF1K','source')},
{'source': 'Hollywood', 'path': os.path.join("C:\\","Users","gbellitto","Desktop","GitRepository","video-saliency-detection","data",'Hollywood2','train')},
{'source': 'UCFSports', 'path': os.path.join("C:\\","Users","gbellitto","Desktop","GitRepository","video-saliency-detection","data",'UCF','train')}]


def main():
Expand Down Expand Up @@ -87,7 +78,6 @@ def main():

for idx, p in enumerate(path_source_data):

#print(idx)
print(p)

if 'LEDOV' in p:
Expand Down Expand Up @@ -428,7 +418,7 @@ def main():
plt.savefig(os.path.join('output', subfolder, test_name,'loss_validation.png'))
plt.close()

#Plot validation loss Per-Dataset
#Plot validation per-dataset loss
x = torch.arange(1, len(check_point['per_dataset_loss'][f'val_{list_source_validation[0]}'])+1).numpy()
for s in list_source_validation:
plt.plot(x, check_point['per_dataset_loss'][f'val_{s}'], label=f"val_loss {s}")
Expand All @@ -452,15 +442,10 @@ def main():


def transform(snippet):
snippet = np.concatenate(snippet, axis=-1) # 224 x 384 x 24
snippet = torch.from_numpy(snippet).permute(2, 0, 1).contiguous().float() #24 x 224 x 384
snippet = np.concatenate(snippet, axis=-1)
snippet = torch.from_numpy(snippet).permute(2, 0, 1).contiguous().float()
snippet = snippet.mul_(2.).sub_(255).div(255)
# 1 x 8 x 3 x 224 x 384 1 x 3 x 8 x 224 x 384
# | |
snippet = snippet.view(1,-1,3,snippet.size(1),snippet.size(2)).permute(0,2,1,3,4)
''' unlike the transform used in the dataloader, here we add that leading 1 to represent
the batch dimension, because the model expects input of shape: batch x C x len_temporal x H x W
'''
return snippet

if __name__ == '__main__':
Expand Down

0 comments on commit 4a42d23

Please sign in to comment.