Commit

resolve some warnings from the latest pytorch / chromvar_bug (reported by Conrad) /

ruochiz committed Aug 13, 2024
1 parent 90456c4 commit f51f55d
Showing 6 changed files with 322 additions and 199 deletions.
7 changes: 5 additions & 2 deletions scprinter/chromvar.py
@@ -237,8 +237,11 @@ def compute_deviations(adata, chunk_size: int = 10000, device="cuda"):
temp_adata = adata[start:end].copy()
X_chunk = temp_adata.X
expectation_obs_chunk = backend.asarray(expectation_obs[start:end])
if sparse.isspmatrix(X_chunk) and device == "cuda":
X_chunk = scipy_to_cupy_sparse(X_chunk)
if sparse.isspmatrix(X_chunk):
if device == "cuda":
X_chunk = scipy_to_cupy_sparse(X_chunk)
else:
X_chunk = X_chunk.tocsr()
else:
X_chunk = backend.array(X_chunk)
res = _compute_deviations(
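For reference, the fix above makes the sparse conversion device-dependent: on CUDA the scipy chunk is copied to a CuPy sparse matrix, while on CPU it is simply converted to CSR. A minimal sketch of that pattern, assuming cupy/cupyx are installed (to_device_sparse is an illustrative helper, not the scprinter function):

    import scipy.sparse as sparse

    def to_device_sparse(X_chunk, device="cuda"):
        # Mirror the branch added above: leave dense arrays alone,
        # move sparse chunks to the requested device/format.
        if sparse.isspmatrix(X_chunk):
            if device == "cuda":
                # cupyx can build a GPU CSR matrix directly from a scipy sparse matrix
                from cupyx.scipy.sparse import csr_matrix as cupy_csr
                return cupy_csr(X_chunk.tocsr())
            # on CPU, CSR keeps row slicing and matrix products fast
            return X_chunk.tocsr()
        return X_chunk
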
175 changes: 140 additions & 35 deletions scprinter/seq/Models.py
@@ -125,6 +125,33 @@ def validation_step_footprint(model, validation_data, validation_size, dispmodel
)


def adjust_embedding_scale(embedding, coverages, coverage_in_lora, A_embedding):
# test A_output distribution
with torch.no_grad():

embedding.eval()
A_embedding.eval()
test_cell_num = min(100, embedding.weight.shape[0])

A_cells = embedding(torch.arange(test_cell_num).long())
if coverages is not None:
coverages_cells = coverages(torch.arange(test_cell_num).long())
else:
coverages_cells = None
if coverage_in_lora:
A_cells = torch.cat([A_cells, coverages_cells], dim=-1)
A_output = A_embedding(A_cells)
mean, std = A_output.mean(), A_output.std()

print("A_output mean: {}, std: {}".format(mean, std))
# self.scale *= 1 / (std * r)
rescale_factor = 1 / (std)
embedding.weight.data[...] *= rescale_factor
if coverage_in_lora:
coverages.weight.data[...] *= rescale_factor
return rescale_factor # rescale the embedding matrix
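
The new adjust_embedding_scale helper rescales the frozen embedding table by 1/std of the A-side projection's output, so the LoRA conditioning pathway starts out at roughly unit scale. A toy sketch of the same heuristic, with illustrative names rather than the scprinter API:

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    embedding = nn.Embedding(100, 32)      # stands in for the frozen cell embeddings
    A_embedding = nn.Linear(32, 8)         # stands in for the A-side projection

    with torch.no_grad():
        A_out = A_embedding(embedding(torch.arange(100)))
        rescale_factor = 1.0 / A_out.std()
        embedding.weight.data *= rescale_factor   # same in-place rescale as above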


class seq2PRINT(nn.Module):
"""
This is the seq2PRINT model
@@ -143,6 +170,8 @@ class seq2PRINT(nn.Module):
The length of the output peak window
embeddings: np.ndarray | None
The embeddings to use for single cells or pseudobulks
coverages: np.ndarray | None
The bias terms to regress out when doing LoRA
rank: int
The rank to use for the LoRA model
hidden_dim: int
@@ -159,8 +188,8 @@
Whether to use LoRA on the footprint part of the output CNN model
lora_count_cnn: bool
Whether to use LoRA on the count part of the output CNN model
coverage: bool
Whether embedding contains the coverage information at the last dimension
coverage_in_lora: bool
Whether to use the coverage information in the LoRA finetuning weights
"""

def __init__(
@@ -171,6 +200,7 @@ def __init__(
dna_len=2114,
output_len=1000,
embeddings=None,
coverages=None,
rank=8,
hidden_dim=None,
n_lora_layers=0,
@@ -179,7 +209,7 @@
lora_pff_cnn=False,
lora_profile_cnn=False,
lora_count_cnn=False,
coverage=True,
coverage_in_lora=False,
):
super().__init__()
self.dna_cnn_model = dna_cnn_model
@@ -188,67 +218,96 @@ def __init__(
self.dna_len = dna_len
self.output_len = output_len

if coverages is not None:
self.coverages = nn.Embedding(coverages.shape[0], coverages.shape[1])
self.coverages.weight.data = torch.from_numpy(coverages).float()
self.coverages.weight.requires_grad = False
else:
self.coverages = None
if embeddings is not None:
if coverage:
coverages = embeddings[:, -1][:, None]
embeddings = embeddings[:, :-1]
self.coverages = nn.Embedding(coverages.shape[0], coverages.shape[1])
self.coverages.weight.data = torch.from_numpy(coverages).float()
self.coverages.weight.requires_grad = False
else:
self.coverages = None
self.embeddings = nn.Embedding(embeddings.shape[0], embeddings.shape[1])
self.embeddings.weight.data = torch.from_numpy(embeddings).float()
self.embeddings.weight.requires_grad = False
lora_embedding_dim = (
embeddings.shape[-1] + coverages.shape[-1]
if coverage_in_lora
else embeddings.shape[-1]
)
else:
self.embeddings = None
self.coverages = None
self.coverage_in_lora = coverage_in_lora

# make the LoRA models for the DNA CNN, hidden layer, and profile CNN
if lora_dna_cnn:
assert self.embeddings is not None, "Embeddings must be provided for LoRA"
self.dna_cnn_model.conv = Conv1dLoRA(
self.dna_cnn_model.conv,
A_embedding=self.embeddings,
B_embedding=self.embeddings,
A_embedding_dim=lora_embedding_dim,
B_embedding_dim=lora_embedding_dim,
r=rank,
hidden_dim=hidden_dim,
n_layers=n_lora_layers,
)
rescale_factor = adjust_embedding_scale(
self.embeddings,
self.coverages,
self.coverage_in_lora,
self.dna_cnn_model.conv.A_embedding,
)

hidden_layers = self.hidden_layer_model.layers
for i in range(len(hidden_layers)):
if lora_dilated_cnn:
assert self.embeddings is not None, "Embeddings must be provided for LoRA"
hidden_layers[i].module.conv1 = Conv1dLoRA(
hidden_layers[i].module.conv1,
A_embedding=self.embeddings,
B_embedding=self.embeddings,
A_embedding_dim=lora_embedding_dim,
B_embedding_dim=lora_embedding_dim,
r=rank,
hidden_dim=hidden_dim,
n_layers=n_lora_layers,
)
rescale_factor = adjust_embedding_scale(
self.embeddings,
self.coverages,
self.coverage_in_lora,
hidden_layers[i].module.conv1.A_embedding,
)

if lora_pff_cnn:
assert self.embeddings is not None, "Embeddings must be provided for LoRA"
hidden_layers[i].module.conv2 = Conv1dLoRA(
hidden_layers[i].module.conv2,
A_embedding=self.embeddings,
B_embedding=self.embeddings,
A_embedding_dim=lora_embedding_dim,
B_embedding_dim=lora_embedding_dim,
r=rank,
hidden_dim=hidden_dim,
n_layers=n_lora_layers,
)
rescale_factor = adjust_embedding_scale(
self.embeddings,
self.coverages,
self.coverage_in_lora,
hidden_layers[i].module.conv2.A_embedding,
)

if lora_profile_cnn:
assert self.embeddings is not None, "Embeddings must be provided for LoRA"
self.profile_cnn_model.conv_layer = Conv1dLoRA(
self.profile_cnn_model.conv_layer,
A_embedding=self.embeddings,
B_embedding=self.embeddings,
A_embedding_dim=lora_embedding_dim,
B_embedding_dim=lora_embedding_dim,
r=rank,
hidden_dim=hidden_dim,
n_layers=n_lora_layers,
)
rescale_factor = adjust_embedding_scale(
self.embeddings,
self.coverages,
self.coverage_in_lora,
self.profile_cnn_model.conv_layer.A_embedding,
)

# Historical code
# if isinstance(self.profile_cnn_model.linear, nn.Linear):
@@ -265,16 +324,23 @@ def __init__(
assert self.embeddings is not None, "Embeddings must be provided for LoRA"
self.profile_cnn_model.linear = Conv1dLoRA(
self.profile_cnn_model.linear,
A_embedding=self.embeddings,
B_embedding=self.embeddings,
A_embedding_dim=lora_embedding_dim,
B_embedding_dim=lora_embedding_dim,
r=1,
hidden_dim=hidden_dim,
n_layers=n_lora_layers,
)
rescale_factor = adjust_embedding_scale(
self.embeddings,
self.coverages,
self.coverage_in_lora,
self.profile_cnn_model.linear.A_embedding,
)

# Bias adjusted footprints head should come after the LoRA model
if self.coverages is not None:
self.profile_cnn_model = BiasAdjustedFootprintsHead(
self.profile_cnn_model, self.coverages
self.profile_cnn_model, self.coverages.weight.data.shape[1]
)
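
Both embeddings and coverages are stored as frozen lookup tables built from numpy arrays, as in the constructor above. A compact sketch of that construction (the helper name is illustrative):

    import numpy as np
    import torch
    import torch.nn as nn

    def frozen_embedding(arr: np.ndarray) -> nn.Embedding:
        # Wrap a precomputed [n_rows, dim] matrix as a non-trainable lookup table;
        # only the LoRA weights are trained on top of it.
        table = nn.Embedding(arr.shape[0], arr.shape[1])
        table.weight.data = torch.from_numpy(arr).float()
        table.weight.requires_grad = False
        return table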

def return_origin(self):
@@ -309,7 +375,7 @@ def return_origin(self):

return model_clone

def collapse(self, cell, turn_on_grads=True):
def collapse(self, cell=None, turn_on_grads=True, A_cells=None, B_cells=None, coverages=None):
"""
This function collapses the LoRA model to a model for one cell or subset of cells.
@@ -327,25 +393,45 @@ def collapse(self, cell, turn_on_grads=True):

# self = self.to('cpu')
model_clone = deepcopy(self)

if cell is not None:
if type(cell) not in [int, list, np.ndarray, torch.Tensor]:
raise ValueError("cell must be integer(s)")
if type(cell) is int:
cell = [cell]
if self.embeddings is not None:
cell = torch.tensor(cell).long().to(self.embeddings.weight.data.device)
A_cells = self.embeddings(cell)
B_cells = self.embeddings(cell)
if self.coverages is not None:
coverages = self.coverages(cell)
else:
coverages = None
if self.coverage_in_lora:
A_cells = torch.cat([A_cells, coverages], dim=-1)
B_cells = torch.cat([B_cells, coverages], dim=-1)

if not isinstance(model_clone.dna_cnn_model.conv, Conv1dWrapper):
model_clone.dna_cnn_model.conv = model_clone.dna_cnn_model.conv.collapse_layer(cell)
model_clone.dna_cnn_model.conv = model_clone.dna_cnn_model.conv.collapse_layer(
A_cells, B_cells
)
if not isinstance(model_clone.hidden_layer_model.layers[0].module.conv1, Conv1dWrapper):
for layer in model_clone.hidden_layer_model.layers:
layer.module.conv1 = layer.module.conv1.collapse_layer(cell)
layer.module.conv1 = layer.module.conv1.collapse_layer(A_cells, B_cells)
if not isinstance(model_clone.hidden_layer_model.layers[0].module.conv2, Conv1dWrapper):
for layer in model_clone.hidden_layer_model.layers:
layer.module.conv2 = layer.module.conv2.collapse_layer(cell)
layer.module.conv2 = layer.module.conv2.collapse_layer(A_cells, B_cells)

if isinstance(model_clone.profile_cnn_model, BiasAdjustedFootprintsHead):
model = model_clone.profile_cnn_model.footprints_head
model_clone.profile_cnn_model.collapse_layer(cell)
model_clone.profile_cnn_model.collapse_layer(coverages)
else:
model = model_clone.profile_cnn_model

if not isinstance(model.conv_layer, Conv1dWrapper):
model.conv_layer = model.conv_layer.collapse_layer(cell)
model.conv_layer = model.conv_layer.collapse_layer(A_cells, B_cells)
if not isinstance(model.linear, Conv1dWrapper):
model.linear = model.linear.collapse_layer(cell)
model.linear = model.linear.collapse_layer(A_cells, B_cells)
if turn_on_grads:
for p in model_clone.parameters():
p.requires_grad = True
@@ -375,14 +461,31 @@ def forward(self, X, cells=None, output_len=None, modes=None):
"""
if output_len is None:
output_len = self.output_len
# get the motifs
X = self.dna_cnn_model(X, cells=cells)

# get the hidden layer
X = self.hidden_layer_model(X, cells=cells)
A_cells, B_cells, coverages = None, None, None
if (cells is not None) and (self.embeddings is not None):
A_cells = self.embeddings(cells)
B_cells = self.embeddings(cells)

if self.coverages is not None:
coverages = self.coverages(cells)
if self.coverage_in_lora:
A_cells = torch.cat([A_cells, coverages], dim=-1)
B_cells = torch.cat([B_cells, coverages], dim=-1)
# get the motifs
X = self.dna_cnn_model(X, A_cells=A_cells, B_cells=B_cells)
# get the hidden layer
X = self.hidden_layer_model(X, A_cells=A_cells, B_cells=B_cells)
# get the profile
return self.profile_cnn_model(X, cells=cells, output_len=output_len, modes=modes)
output = self.profile_cnn_model(
X,
A_cells=A_cells,
B_cells=B_cells,
coverages=coverages,
output_len=output_len,
modes=modes,
)
return output
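
forward() (and collapse()) now build the LoRA conditioning vectors explicitly: the cell embedding is looked up once, and the coverage term is concatenated onto it only when coverage_in_lora is set. A toy sketch of that construction with made-up shapes:

    import torch

    cell_emb = torch.randn(4, 32)       # [batch, embedding_dim]
    coverage = torch.randn(4, 1)        # [batch, coverage_dim]
    coverage_in_lora = True

    A_cells = torch.cat([cell_emb, coverage], dim=-1) if coverage_in_lora else cell_emb
    B_cells = A_cells.clone()           # A and B sides share the same conditioning input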

def load_train_state_dict(self, ema, optimizer, scaler, savename):
"""
@@ -505,7 +608,7 @@ def fit(
if use_amp:
print("Using amp")

scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
scaler = torch.amp.GradScaler("cuda", enabled=use_amp)
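
The GradScaler change swaps the torch.cuda.amp.GradScaler constructor, which recent PyTorch flags as deprecated, for the device-generic torch.amp API. A minimal mixed-precision step with the new form, assuming PyTorch >= 2.4 and a CUDA device:

    import torch

    model = torch.nn.Linear(8, 1).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    scaler = torch.amp.GradScaler("cuda", enabled=True)

    x = torch.randn(16, 8, device="cuda")
    with torch.amp.autocast("cuda"):
        loss = model(x).mean()
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad(set_to_none=True)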

if coverage_warming > 0:
# This is a dict to store the requires_grads for each parameter so we can revert back to the original state when coverage warming is done
@@ -521,6 +624,8 @@
p.requires_grad = True
for p in self.profile_cnn_model.adjustment_footprint.parameters():
p.requires_grad = True
total_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
print("total trainable params", total_params)

for epoch in range(max_epochs):
moving_avg_loss = 0
(diffs for the remaining 4 changed files not loaded)

0 comments on commit f51f55d
