add pytorch code
caozhangjie committed Mar 22, 2018
1 parent 8e2a509 commit 10d715a
Showing 11 changed files with 5,144 additions and 1 deletion.
2 changes: 1 addition & 1 deletion README.md
@@ -7,7 +7,7 @@ This is the transfer learning library for the following paper:
### Unsupervised Domain Adaptation with Residual Transfer Networks
### Deep Transfer Learning with Joint Adaptation Networks

-The pytorch and tensorflow versions are under developing.
+The TensorFlow version is still under development.

## Citation
If you use this code for your research, please consider citing:
72 changes: 72 additions & 0 deletions pytorch/README.md
@@ -0,0 +1,72 @@
# Deep Transfer Learning on PyTorch

This is a PyTorch library for deep transfer learning. We use PyTorch version `0.2.0_3`.

Data Preparation
---------------
In `data/office/*.txt`, we provide the image lists of the three domains (amazon, dslr, webcam) of the [Office](https://cs.stanford.edu/~jhoffman/domainadapt/#datasets_code) dataset. Each line pairs an image path with its class label, separated by a space.
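
For reference, a list file looks like the following (the paths shown here are illustrative, not actual entries from the shipped lists):

```
/data/office/amazon/images/back_pack/frame_0001.jpg 0
/data/office/amazon/images/bike/frame_0002.jpg 1
```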

Training Model
---------------
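
Training scripts are not shown in this diff; the following is a minimal sketch of how the pieces in `src/` could fit together for one iteration. The toy network, batch sizes, and `trade_off` weight are assumptions, not part of this commit.

```python
import torch
import torch.nn as nn
from torch.autograd import Variable  # PyTorch 0.2-era API, as in src/loss.py
from loss import loss_dict           # run from pytorch/src

# Toy feature extractor and classifier; any backbone (e.g. a pretrained
# network with a bottleneck layer) can stand in here.
features = nn.Sequential(nn.Linear(256, 128), nn.ReLU())
classifier = nn.Linear(128, 31)      # 31 classes in the Office dataset

optimizer = torch.optim.SGD(
    list(features.parameters()) + list(classifier.parameters()),
    lr=0.001, momentum=0.9)

transfer_criterion = loss_dict["DAN"]   # or "JAN" with lists of activations
class_criterion = nn.CrossEntropyLoss()
trade_off = 1.0                         # assumed loss trade-off weight

# One iteration on random stand-in batches; replace with real data batches.
src_x = Variable(torch.randn(32, 256))
src_y = Variable((torch.rand(32) * 31).long())
tgt_x = Variable(torch.randn(32, 256))

src_f, tgt_f = features(src_x), features(tgt_x)
total_loss = (class_criterion(classifier(src_f), src_y)
              + trade_off * transfer_criterion(src_f, tgt_f))

optimizer.zero_grad()
total_loss.backward()
optimizer.step()
```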

## Citation
If you use this library in your research, please consider citing the following papers:

```
@inproceedings{DBLP:conf/icml/LongC0J15,
author = {Mingsheng Long and
Yue Cao and
Jianmin Wang and
Michael I. Jordan},
title = {Learning Transferable Features with Deep Adaptation Networks},
booktitle = {Proceedings of the 32nd International Conference on Machine Learning,
{ICML} 2015, Lille, France, 6-11 July 2015},
pages = {97--105},
year = {2015},
crossref = {DBLP:conf/icml/2015},
url = {http://jmlr.org/proceedings/papers/v37/long15.html},
timestamp = {Tue, 12 Jul 2016 21:51:15 +0200},
  biburl    = {http://dblp.uni-trier.de/rec/bib/conf/icml/LongC0J15},
bibsource = {dblp computer science bibliography, http://dblp.org}
}
@inproceedings{DBLP:conf/nips/LongZ0J16,
author = {Mingsheng Long and
Han Zhu and
Jianmin Wang and
Michael I. Jordan},
title = {Unsupervised Domain Adaptation with Residual Transfer Networks},
booktitle = {Advances in Neural Information Processing Systems 29: Annual Conference
on Neural Information Processing Systems 2016, December 5-10, 2016,
Barcelona, Spain},
pages = {136--144},
year = {2016},
crossref = {DBLP:conf/nips/2016},
url = {http://papers.nips.cc/paper/6110-unsupervised-domain-adaptation-with-residual-transfer-networks},
timestamp = {Fri, 16 Dec 2016 19:45:58 +0100},
biburl = {http://dblp.uni-trier.de/rec/bib/conf/nips/LongZ0J16},
bibsource = {dblp computer science bibliography, http://dblp.org}
}
@inproceedings{DBLP:conf/icml/LongZ0J17,
author = {Mingsheng Long and
Han Zhu and
Jianmin Wang and
Michael I. Jordan},
title = {Deep Transfer Learning with Joint Adaptation Networks},
booktitle = {Proceedings of the 34th International Conference on Machine Learning,
{ICML} 2017, Sydney, NSW, Australia, 6-11 August 2017},
pages = {2208--2217},
year = {2017},
crossref = {DBLP:conf/icml/2017},
url = {http://proceedings.mlr.press/v70/long17a.html},
timestamp = {Tue, 25 Jul 2017 17:27:57 +0200},
biburl = {http://dblp.uni-trier.de/rec/bib/conf/icml/LongZ0J17},
bibsource = {dblp computer science bibliography, http://dblp.org}
}
```

## Contact
If you have any problems with this library, please open an issue or email us at:
- [email protected]
- [email protected]
2,817 changes: 2,817 additions & 0 deletions pytorch/data/office/amazon_list.txt


498 changes: 498 additions & 0 deletions pytorch/data/office/dslr_list.txt


795 changes: 795 additions & 0 deletions pytorch/data/office/webcam_list.txt


203 changes: 203 additions & 0 deletions pytorch/src/data_list.py
@@ -0,0 +1,203 @@
#from __future__ import print_function, division

import torch
import numpy as np
from sklearn.preprocessing import StandardScaler
import random
from PIL import Image
import torch.utils.data as data
import os
import os.path

class TextData():
    """Batch loader for pre-extracted text features stored as .npy arrays.

    The first 92664 rows are treated as the labeled source domain and the
    remaining rows as the unlabeled target domain; validation reuses the
    source split. Features are standardized with a scaler fit on all data.
    """
    def __init__(self, text_file, label_file, source_batch_size=64, target_batch_size=64, val_batch_size=4):
        all_text = np.load(text_file)
        self.source_text = all_text[0:92664, :]
        self.target_text = all_text[92664:, :]
        self.val_text = all_text[0:92664, :]
        all_label = np.load(label_file)
        self.label_source = all_label[0:92664, :]
        self.label_target = all_label[92664:, :]
        self.label_val = all_label[0:92664, :]
        self.scaler = StandardScaler().fit(all_text)
self.source_id = 0
self.target_id = 0
self.val_id = 0
self.source_size = self.source_text.shape[0]
self.target_size = self.target_text.shape[0]
self.val_size = self.val_text.shape[0]
self.source_batch_size = source_batch_size
self.target_batch_size = target_batch_size
self.val_batch_size = val_batch_size
self.source_list = random.sample(range(self.source_size), self.source_size)
self.target_list = random.sample(range(self.target_size), self.target_size)
self.val_list = random.sample(range(self.val_size), self.val_size)
self.feature_dim = self.source_text.shape[1]

    def next_batch(self, train=True):
        # Epoch-wrapping batching: when a split runs out, reshuffle its index
        # list and top the batch up from the new permutation.
        data = []
        label = []
if train:
remaining = self.source_size - self.source_id
start = self.source_id
if remaining <= self.source_batch_size:
for i in self.source_list[start:]:
data.append(self.source_text[i, :])
label.append(self.label_source[i, :])
self.source_id += 1
self.source_list = random.sample(range(self.source_size), self.source_size)
self.source_id = 0
for i in self.source_list[0:(self.source_batch_size-remaining)]:
data.append(self.source_text[i, :])
label.append(self.label_source[i, :])
self.source_id += 1
else:
for i in self.source_list[start:start+self.source_batch_size]:
data.append(self.source_text[i, :])
label.append(self.label_source[i, :])
self.source_id += 1
remaining = self.target_size - self.target_id
start = self.target_id
if remaining <= self.target_batch_size:
for i in self.target_list[start:]:
data.append(self.target_text[i, :])
# no target label
#label.append(self.label_target[i, :])
self.target_id += 1
self.target_list = random.sample(range(self.target_size), self.target_size)
self.target_id = 0
for i in self.target_list[0:self.target_batch_size-remaining]:
data.append(self.target_text[i, :])
#label.append(self.label_target[i, :])
self.target_id += 1
else:
for i in self.target_list[start:start+self.target_batch_size]:
data.append(self.target_text[i, :])
#label.append(self.label_target[i, :])
self.target_id += 1
else:
remaining = self.val_size - self.val_id
start = self.val_id
if remaining <= self.val_batch_size:
for i in self.val_list[start:]:
data.append(self.val_text[i, :])
label.append(self.label_val[i, :])
self.val_id += 1
self.val_list = random.sample(range(self.val_size), self.val_size)
self.val_id = 0
for i in self.val_list[0:self.val_batch_size-remaining]:
data.append(self.val_text[i, :])
label.append(self.label_val[i, :])
self.val_id += 1
else:
for i in self.val_list[start:start+self.val_batch_size]:
data.append(self.val_text[i, :])
label.append(self.label_val[i, :])
self.val_id += 1
        # Standardize with the scaler fit on all data, then stack into tensors.
        data = self.scaler.transform(np.vstack(data))
        label = np.vstack(label)
        return torch.from_numpy(data).float(), torch.from_numpy(label).float()
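
# Usage sketch (the file names here are assumptions): given feature/label
# arrays saved via np.save, next_batch(train=True) returns standardized
# features for source_batch_size source rows followed by target_batch_size
# target rows, with labels only for the source rows.
#
#   loader = TextData("text_features.npy", "text_labels.npy")
#   inputs, labels = loader.next_batch(train=True)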


def make_dataset(image_list, labels):
    if labels is not None:
        # Labels given explicitly as an array aligned with image_list.
        len_ = len(image_list)
        images = [(image_list[i].strip(), labels[i, :]) for i in range(len_)]
    else:
        if len(image_list[0].split()) > 2:
            # Multi-label case: "path l1 l2 ..." per line.
            images = [(val.split()[0], np.array([int(la) for la in val.split()[1:]])) for val in image_list]
        else:
            # Single-label case: "path label" per line.
            images = [(val.split()[0], int(val.split()[1])) for val in image_list]
    return images


def pil_loader(path):
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
with open(path, 'rb') as f:
with Image.open(f) as img:
return img.convert('RGB')


def accimage_loader(path):
import accimage
try:
return accimage.Image(path)
except IOError:
# Potentially a decoding problem, fall back to PIL.Image
return pil_loader(path)


def default_loader(path):
#from torchvision import get_image_backend
#if get_image_backend() == 'accimage':
# return accimage_loader(path)
#else:
return pil_loader(path)


class ImageList(object):
    """A generic dataset backed by an image list rather than a folder tree.

    Each line of ``image_list`` has the form ``path label`` (or
    ``path label1 label2 ...`` for multi-label data); alternatively, an
    aligned ``labels`` array may be passed explicitly.

    Args:
        image_list (list of str): Lines of an image list file.
        labels (ndarray, optional): Label array aligned with ``image_list``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.

    Attributes:
        imgs (list): List of (image path, label) tuples.
    """

    def __init__(self, image_list, labels=None, transform=None, target_transform=None,
                 loader=default_loader):
        imgs = make_dataset(image_list, labels)
        if len(imgs) == 0:
            raise RuntimeError("Found 0 images in the given image list.")

self.imgs = imgs
self.transform = transform
self.target_transform = target_transform
self.loader = loader

def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is class_index of the target class.
"""
path, target = self.imgs[index]
img = self.loader(path)
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)

return img, target

def __len__(self):
return len(self.imgs)

def ClassSamplingImageList(image_list, transform, return_keys=False):
    # Group the lines of an image list file by class label and build one
    # ImageList per class, e.g. for class-balanced sampling.
    data = open(image_list).readlines()
    label_dict = {}
for line in data:
label_dict[int(line.split()[1])] = []
for line in data:
label_dict[int(line.split()[1])].append(line)
all_image_list = {}
for i in label_dict.keys():
all_image_list[i] = ImageList(label_dict[i], transform=transform)
if return_keys:
return all_image_list, label_dict.keys()
else:
return all_image_list
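
A usage sketch for `ImageList`, assuming the Office images have been downloaded so the paths in the shipped list files resolve; the exact preprocessing transforms here are an assumption, not fixed by this commit:

```python
import torch.utils.data as data
from torchvision import transforms
from data_list import ImageList

# Standard ImageNet-style preprocessing (illustrative choices).
transform = transforms.Compose([
    transforms.Scale(256),   # 'Scale' is the PyTorch 0.2-era name for Resize
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

image_lines = open("pytorch/data/office/amazon_list.txt").readlines()
dataset = ImageList(image_lines, transform=transform)
loader = data.DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4)

for imgs, targets in loader:
    pass  # feed imgs/targets to the model here
```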
71 changes: 71 additions & 0 deletions pytorch/src/loss.py
@@ -0,0 +1,71 @@
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable

def EntropyLoss(input_):
    # Mean entropy of a batch of probability vectors; entries below 1e-6 are
    # masked out to avoid log(0).
    mask = input_.ge(0.000001)
    mask_out = torch.masked_select(input_, mask)
    entropy = -(torch.sum(mask_out * torch.log(mask_out)))
    return entropy / float(input_.size(0))

def gaussian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    # Multi-kernel RBF: a sum of Gaussian kernels whose bandwidths form a
    # geometric family bandwidth * kernel_mul**i, i = 0..kernel_num-1. If
    # fix_sigma is None, the base bandwidth is set from the mean pairwise
    # squared distance of the joint source/target batch.
    n_samples = int(source.size()[0]) + int(target.size()[0])
    total = torch.cat([source, target], dim=0)
    total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
    total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
    L2_distance = ((total0 - total1) ** 2).sum(2)
    if fix_sigma:
        bandwidth = fix_sigma
    else:
        bandwidth = torch.sum(L2_distance.data) / (n_samples ** 2 - n_samples)
    bandwidth /= kernel_mul ** (kernel_num // 2)
    bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)]
    kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
    return sum(kernel_val)  # unnormalized sum over the kernel family


def DAN(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    # Multi-kernel MMD between source and target, computed with the
    # linear-time estimator over consecutive sample pairs.
    batch_size = int(source.size()[0])
    kernels = gaussian_kernel(source, target,
                              kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)
loss = 0
for i in range(batch_size):
s1, s2 = i, (i+1)%batch_size
t1, t2 = s1+batch_size, s2+batch_size
loss += kernels[s1, s2] + kernels[t1, t2]
loss -= kernels[s1, t2] + kernels[s2, t1]
return loss / float(batch_size)

def RTN():
    # Placeholder: the residual transfer network loss is not implemented yet.
    pass


def JAN(source_list, target_list, kernel_muls=[2.0, 2.0], kernel_nums=[5, 1], fix_sigma_list=[None, 1.68]):
    # Joint MMD (JMMD): per-layer kernels are multiplied elementwise into a
    # joint kernel over several layers' activations, then the same
    # linear-time estimator as in DAN is applied.
    batch_size = int(source_list[0].size()[0])
    layer_num = len(source_list)
    joint_kernels = None
    for i in range(layer_num):
        source = source_list[i]
        target = target_list[i]
        kernel_mul = kernel_muls[i]
        kernel_num = kernel_nums[i]
        fix_sigma = fix_sigma_list[i]
        kernels = gaussian_kernel(source, target,
                                  kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)
if joint_kernels is not None:
joint_kernels = joint_kernels * kernels
else:
joint_kernels = kernels

loss = 0
for i in range(batch_size):
s1, s2 = i, (i+1)%batch_size
t1, t2 = s1+batch_size, s2+batch_size
loss += joint_kernels[s1, s2] + joint_kernels[t1, t2]
loss -= joint_kernels[s1, t2] + joint_kernels[s2, t1]
return loss / float(batch_size)



loss_dict = {"DAN":DAN, "RTN":RTN, "JAN":JAN}
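
A quick sketch of calling these losses on random features. Shapes are illustrative; for `JAN`, which layer activations to pair in the two lists is an assumption here:

```python
import torch
from torch.autograd import Variable
from loss import loss_dict

src_feat = Variable(torch.randn(32, 256))   # source batch features
tgt_feat = Variable(torch.randn(32, 256))   # target batch features

mmd = loss_dict["DAN"](src_feat, tgt_feat)

src_prob = Variable(torch.rand(32, 31))     # e.g. softmax outputs
tgt_prob = Variable(torch.rand(32, 31))
jmmd = loss_dict["JAN"]([src_feat, src_prob], [tgt_feat, tgt_prob])
```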
13 changes: 13 additions & 0 deletions pytorch/src/lr_schedule.py
@@ -0,0 +1,13 @@
def inv_lr_scheduler(param_lr, optimizer, iter_num, gamma, power, init_lr=0.001):
    """Inverse decay schedule: lr = init_lr * (1 + gamma * iter_num) ** (-power),
    scaled per parameter group by the base multipliers in param_lr."""
    lr = init_lr * (1 + gamma * iter_num) ** (-power)

    for i, param_group in enumerate(optimizer.param_groups):
        param_group['lr'] = lr * param_lr[i]

    return optimizer


schedule_dict = {"inv":inv_lr_scheduler}
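
A usage sketch; the scheduler is meant to be called once per iteration, and the `gamma`/`power` values below are illustrative, not fixed by this file:

```python
import torch
import torch.nn as nn
from lr_schedule import schedule_dict

model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Record each group's base multiplier once, relative to the initial lr.
param_lr = [group['lr'] / 0.001 for group in optimizer.param_groups]

for iter_num in range(10000):
    optimizer = schedule_dict["inv"](param_lr, optimizer, iter_num,
                                     gamma=0.001, power=0.75, init_lr=0.001)
    # ... forward / backward / optimizer.step() ...
```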