Commit b6217c5 (1 parent: 5a21265)
Showing 43 changed files with 1,189 additions and 0 deletions.
Binary file not shown.
data/base_mtl.mtl (new file; the path is the one the script below reads)
@@ -0,0 +1,8 @@
newmtl material_0
Ka 0.200000 0.200000 0.200000
Kd 0.000000 0.000000 0.000000
Ks 1.000000 1.000000 1.000000
Tr 0.000000
illum 2
Ns 0.000000
map_Kd 1.jpg
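For reference: in the .mtl format, Ka/Kd/Ks are the ambient/diffuse/specular colors, Tr is transparency, illum 2 selects the standard highlight model, Ns is the specular exponent, and map_Kd names the diffuse texture image. The map_Kd line (line index 7) is the one the script below rewrites per mesh.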
Texture-assignment helper (new Python file; splices predicted meshes into a template .obj and writes a matching .mtl for each)
@@ -0,0 +1,48 @@
import os


def get_mtl(in_path='results/'):
    # Template mesh: line 0 is the mtllib reference, lines 1..26317 are
    # geometry lines that get replaced by each result mesh.
    with open('data/template_mesh.obj', 'r') as base_obj:
        base_lines = base_obj.readlines()

    with open('data/base_mtl.mtl', 'r') as base_mtl:
        base_mtl_lines = base_mtl.readlines()

    # Results are laid out as results/<identity>/<expression>/<frame>.obj
    ids = os.listdir(in_path)
    for idd in ids:
        exps = os.listdir(os.path.join(in_path, idd))
        for exp in exps:
            files = os.listdir(os.path.join(in_path, idd, exp))
            files.sort()
            for f in files:
                if 'obj' in f:
                    obj_filename = os.path.join(in_path, idd, exp, f)
                    with open(obj_filename, 'r') as a_file:
                        lines = a_file.readlines()

                    # Splice the predicted geometry into the template and
                    # point the header at a per-mesh .mtl file.
                    base_lines[1:26318] = lines[:26317]
                    base_lines[0] = 'mtllib ./' + f[:-4] + '.mtl\n'

                    with open(obj_filename, 'w') as new_obj:
                        new_obj.writelines(base_lines)

                    # Write a matching .mtl whose diffuse texture (map_Kd,
                    # line index 7) points at this mesh's .jpg texture.
                    with open(obj_filename[:-4] + '.mtl', 'w') as new_mtl:
                        base_mtl_lines[7] = 'map_Kd ' + f[:-4] + '.jpg\n'
                        new_mtl.writelines(base_mtl_lines)
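The commit includes no entry point for this helper; a minimal way to invoke it, assuming the results/<id>/<exp>/*.obj layout the loops above expect:

if __name__ == "__main__":
    # Hypothetical driver, not part of the commit.
    get_mtl(in_path='results/')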
Nine binary files and two empty files added (contents not shown).
shape_model/ae/datasets.py (new file; path inferred from the import in the trainer below)
@@ -0,0 +1,21 @@
import torch
import numpy as np


class CostumDataset(object):
    def __init__(self, args):
        kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {}

        # Each row: 78,951 displacement values, then an identity label
        # and an expression label.
        Data = np.load(args.dataset)
        Feature = Data[:, 0:78951]
        Label_id = Data[:, 78951]
        Label_ex = Data[:, 78952]
        tensor_x = torch.Tensor(Feature)  # transform to torch tensors
        tensor_y = torch.Tensor(Label_id)
        tensor_z = torch.Tensor(Label_ex)

        trainset = torch.utils.data.TensorDataset(tensor_x, tensor_y, tensor_z)

        self.train_loader = torch.utils.data.DataLoader(
            trainset, batch_size=args.batch_size, shuffle=True, **kwargs)
        self.test_loader = None
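The class only needs three attributes on args. A sketch of a compatible namespace (the attribute names are fixed by the constructor; the values here are hypothetical):

import torch
from types import SimpleNamespace

args = SimpleNamespace(cuda=torch.cuda.is_available(),
                       dataset='./data/displace_data.npy',  # hypothetical path
                       batch_size=64)                        # hypothetical value
loader = CostumDataset(args).train_loader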
Dimensionality-reduction script (new Python file; encodes each training mesh to its identity and expression latent codes)
@@ -0,0 +1,38 @@
import sys
sys.path.append(".")

import numpy as np
import torch

from shape_model.architectures import Encoder_identity, Encoder_expression, Decoder
from shape_model.mesh_obj import mesh_obj

device = "cpu"

encoder_id = Encoder_identity(input_size=78951, num_features=100, num_classes=847).to(device)
encoder_ex = Encoder_expression(input_size=78951, num_features=30, num_classes=20).to(device)
decoder = Decoder(num_features=130, output_size=78951).to(device)

encoder_id.load_state_dict(torch.load("./checkpoints/Encoder_id/2000", map_location="cpu"))
encoder_ex.load_state_dict(torch.load("./checkpoints/Encoder_exp/2000", map_location="cpu"))
decoder.load_state_dict(torch.load("./checkpoints/Decoder/2000", map_location="cpu"))

# Inference mode for the encoders.
encoder_id.eval()
encoder_ex.eval()

train_data_disp = np.load('./data/displace_data.npy')
# 100 identity features + 30 expression features + 2 labels = 132 columns.
reduced_train_data = np.empty((train_data_disp.shape[0], 132), dtype=np.float32)

print(reduced_train_data.shape)
for i, mesh in enumerate(train_data_disp):
    mesh_disp = mesh[:78951]
    mesh_label_id = mesh[78951]
    mesh_label_exp = mesh[78952]

    with torch.no_grad():
        z_id, id_pred = encoder_id(torch.from_numpy(mesh_disp).float())
        z_exp, exp_pred = encoder_ex(torch.from_numpy(mesh_disp).float())

    reduced_train_data[i, :100] = z_id.detach().numpy()
    reduced_train_data[i, 100:130] = z_exp.detach().numpy()
    reduced_train_data[i, 130] = mesh_label_id - 1.0  # convert 1-indexed label to 0-indexed
    reduced_train_data[i, 131] = mesh_label_exp - 1.0  # convert 1-indexed label to 0-indexed

np.save("./data/reduced_train_data_test.npy", reduced_train_data)
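The Decoder is loaded above but never called; a hypothetical sanity check would map the 130-dim code of a reduced row back to displacements (assuming the Decoder accepts a batched flat vector; this round trip is not part of the commit):

row = reduced_train_data[0]
z = torch.from_numpy(row[:130]).float().unsqueeze(0)  # add a batch dimension
with torch.no_grad():
    recon = decoder(z)  # expected shape: (1, 78951)
print(recon.shape)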
Autoencoder trainer (new Python file; two label-supervised encoders feed a shared decoder)
@@ -0,0 +1,122 @@
import torch
from torch import nn, optim

import sys
import os

sys.path.append(".")

from shape_model.architectures import Encoder_identity, Encoder_expression, Decoder
from shape_model.ae.datasets import CostumDataset


class Network(nn.Module):
    """Disentangling autoencoder: an identity encoder and an expression
    encoder whose concatenated codes feed a shared decoder."""

    def __init__(self, args):
        super(Network, self).__init__()

        self.encoder1 = Encoder_identity(input_size=78951, num_features=100, num_classes=847)
        self.encoder2 = Encoder_expression(input_size=78951, num_features=30, num_classes=20)
        self.decoder = Decoder(num_features=130, output_size=78951)

    def encode1(self, x):
        return self.encoder1(x)

    def encode2(self, x):
        return self.encoder2(x)

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        id_feature, id_label = self.encode1(x)
        ex_feature, ex_label = self.encode2(x)
        # Concatenate the 100-dim identity code and the 30-dim expression code.
        z = torch.cat((id_feature, ex_feature), 1)

        return self.decode(z), id_label, ex_label


class AE(object):
    def __init__(self, args):
        self.args = args
        self.device = torch.device("cuda" if args.cuda else "cpu")
        self._init_dataset()
        self.train_loader = self.data.train_loader
        self.test_loader = self.data.test_loader

        self.model = Network(args)
        self.model.to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=1e-4)

    def _init_dataset(self):
        self.data = CostumDataset(self.args)

    def loss_function(self, recon_x, x, id_label_pr, ex_label_pr, id_label_gr, ex_label_gr):
        # L1 reconstruction loss plus cross-entropy on both label heads.
        L1 = torch.nn.L1Loss()
        L2 = torch.nn.CrossEntropyLoss()
        BCE = L1(recon_x, x.view(-1, 78951))  # reconstruction term (L1, despite the name)
        id_label_gr = id_label_gr.long()
        ex_label_gr = ex_label_gr.long()
        CEL_id = L2(id_label_pr, id_label_gr)
        CEL_ex = L2(ex_label_pr, ex_label_gr)

        return (BCE + CEL_id + CEL_ex), BCE, CEL_id, CEL_ex

    def train(self, epoch):
        self.model.train()
        train_loss = 0
        recon_loss = 0
        for batch_idx, (data, label_id, label_ex) in enumerate(self.train_loader):
            data = data.to(self.device)
            label_id = label_id.to(self.device)
            label_ex = label_ex.to(self.device)

            self.optimizer.zero_grad()
            recon_batch, id_label_pr, ex_label_pr = self.model(data)
            # Per-batch counts of correct identity/expression predictions.
            id_ac = torch.sum(torch.eq(torch.argmax(id_label_pr, dim=1), label_id))
            ex_ac = torch.sum(torch.eq(torch.argmax(ex_label_pr, dim=1), label_ex))

            loss, BCE, CEL_id, CEL_ex = self.loss_function(
                recon_batch, data, id_label_pr, ex_label_pr, label_id, label_ex)
            loss.backward()
            train_loss += loss.item()
            recon_loss += BCE.item()
            self.optimizer.step()
            if batch_idx % self.args.log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}, recon: {:.6f}, id_acc: {:.6f}, ex_acc: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(self.train_loader.dataset),
                    100. * batch_idx / len(self.train_loader),
                    loss.item() / len(data), BCE.item() / len(data),
                    id_ac.item() / len(data), ex_ac.item() / len(data)))

        print('====> Epoch: {} Average loss: {:.4f} Recon_loss: {:.4f}'.format(
            epoch, train_loss / len(self.train_loader.dataset), recon_loss / len(self.train_loader.dataset)))
        # Checkpoint the three sub-networks every 100 epochs.
        if epoch % 100 == 0:
            torch.save(self.model.encoder1.state_dict(), os.path.join(self.args.path_enc_id, str(int(epoch))))
            torch.save(self.model.encoder2.state_dict(), os.path.join(self.args.path_enc_exp, str(int(epoch))))
            torch.save(self.model.decoder.state_dict(), os.path.join(self.args.path_dec, str(int(epoch))))

    def test(self, epoch):
        self.model.eval()
        test_loss = 0
        recon_loss = 0
        with torch.no_grad():
            for batch_idx, (data, label_id, label_ex) in enumerate(self.test_loader):
                data = data.to(self.device)
                label_id = label_id.to(self.device)
                label_ex = label_ex.to(self.device)

                recon_batch, id_label_pr, ex_label_pr = self.model(data)
                id_ac = torch.sum(torch.eq(torch.argmax(id_label_pr, dim=1), label_id))
                ex_ac = torch.sum(torch.eq(torch.argmax(ex_label_pr, dim=1), label_ex))
                loss, BCE, CEL_id, CEL_ex = self.loss_function(
                    recon_batch, data, id_label_pr, ex_label_pr, label_id, label_ex)
                test_loss += loss.item()
                recon_loss += BCE.item()
                if batch_idx % self.args.log_interval == 0:
                    print('Test Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}, recon: {:.6f}, id_acc: {:.6f}, ex_acc: {:.6f}'.format(
                        epoch, batch_idx * len(data), len(self.test_loader.dataset),
                        100. * batch_idx / len(self.test_loader),
                        loss.item() / len(data), BCE.item() / len(data),
                        id_ac.item() / len(data), ex_ac.item() / len(data)))

        print('====> Epoch: {} Average loss: {:.4f} Recon_loss: {:.4f}'.format(
            epoch, test_loss / len(self.test_loader.dataset), recon_loss / len(self.test_loader.dataset)))
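No entry point for the trainer appears in this commit; a minimal driver sketch, assuming the attribute names the code above reads from args (cuda, dataset, batch_size, log_interval, and the three checkpoint paths; all defaults below are hypothetical):

import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='./data/displace_data.npy')
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--log-interval', type=int, default=10)
parser.add_argument('--epochs', type=int, default=2000)  # checkpoints above were saved at epoch 2000
parser.add_argument('--path-enc-id', default='./checkpoints/Encoder_id')
parser.add_argument('--path-enc-exp', default='./checkpoints/Encoder_exp')
parser.add_argument('--path-dec', default='./checkpoints/Decoder')
args = parser.parse_args()
args.cuda = torch.cuda.is_available()

# The checkpoint directories must already exist before the first save.
ae = AE(args)
for epoch in range(1, args.epochs + 1):
    ae.train(epoch)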
Binary file not shown.