Increase head #103

Open · wants to merge 4 commits into core

Conversation

kirilklein (Owner) commented Jan 5, 2025

Summary by CodeRabbit

  • New Features

    • Enhanced model classification architecture with multi-layer neural network classifiers.
    • Added dropout regularization to improve model performance.
    • New classifier options introduced: BigHead and StandardHead.
  • Improvements

    • Updated cross-validation function to support additional run tracking parameter.
    • Enhanced flexibility in counterfactual data creation by preserving original data structure.
  • Technical Updates

    • Modified neural network model heads to include more complex classification layers.
    • Updated configuration to specify classifier type.

coderabbitai bot commented Jan 5, 2025

Walkthrough

The pull request modifies several files in the ehr2vec project. In 04_finetune_cv.py, a run parameter is added to the finetune_fold function call for tracking purposes. In model/heads.py, new classifier classes, BigHead and StandardHead, replace the simple linear classifiers in the BaseRNN and FineTuneHead classes. The example configuration is updated to specify the classifier type, and create_counterfactual_data is changed to accept optional regex parameters while preserving the original data structure.
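
For orientation, the review comments below quote the new head implementations. A minimal sketch of the two classes, reconstructed from those quotes (the import and the forward methods are assumptions here; the 128 hidden size and 0.1 dropout are the hard-coded values flagged later in the review), looks roughly like this:

import torch.nn as nn

class BigHead(nn.Module):
    """Two-layer MLP head: Linear -> ReLU -> Dropout -> Linear."""
    def __init__(self, input_size):
        super().__init__()
        self.classifier_hidden = 128  # hard-coded, see review comments below
        self.classifier = nn.Sequential(
            nn.Linear(input_size, self.classifier_hidden),
            nn.ReLU(),
            nn.Dropout(0.1),  # hard-coded, see review comments below
            nn.Linear(self.classifier_hidden, 1),
        )

    def forward(self, x):
        return self.classifier(x)

class StandardHead(nn.Module):
    """Single linear layer, matching the previous classifier behaviour."""
    def __init__(self, input_size):
        super().__init__()
        self.classifier = nn.Linear(input_size, 1)

    def forward(self, x):
        return self.classifier(x)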

Changes

File | Change Summary
ehr2vec/main/04_finetune_cv.py | Added run parameter to the finetune_fold function call
ehr2vec/model/heads.py | Introduced BigHead and StandardHead classes; replaced the linear classifiers in BaseRNN and FineTuneHead with conditional assignments to the new classes
ehr2vec/configs/example_configs/05_02_finetune_simulated.yaml | Added new parameter classifier: "big" in the model section
ehr2vec/double_robust/counterfactual.py | Updated create_counterfactual_data to accept optional exposure_regex and control_regex parameters

Possibly related PRs

  • Compilation #88: Modifications to the load_state_dict_into_model method, not directly related to the main PR changes.
  • Improve freq #100: Discusses improvements related to exposure handling, aligning with the enhancements in the main PR regarding model training.
  • Fix sampler #101: Focuses on the initialization of training components, relevant to the changes made in the finetune_fold function.

Poem

🐰 In code we hop, with joy and cheer,
New classifiers bloom, our path is clear,
With run in tow, we track our quest,
A rabbit's heart beats in every test.
Layers and heads, a dance of delight,
Together we code, into the night! 🌙


coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (2)
ehr2vec/model/heads.py (2)

58-64: Consider making dropout and hidden size configurable.
Using hardcoded values (128 for hidden size and 0.1 for dropout) restricts the flexibility of your model. Providing these as configuration options allows for easier experimentation and tuning.

Here's a sample diff to fetch them from the config (with a fallback to the current values):

-        classifier_hidden = 128
+        classifier_hidden = config.to_dict().get("classifier_hidden_dim", 128)
...
-            nn.Dropout(0.1),
+            nn.Dropout(config.to_dict().get("classifier_dropout", 0.1)),

145-151: Unify the configuration retrieval for the classifier parameters.
Similar to BaseRNN, using configurable parameters for hidden size and dropout in FineTuneHead fosters consistency and makes it easier to share hyperparameters across multiple classes.

Below is a sample diff for consistency:

-            classifier_hidden = 128
+            classifier_hidden = config.to_dict().get("classifier_hidden_dim", 128)
...
-            nn.Dropout(0.1),
+            nn.Dropout(config.to_dict().get("classifier_dropout", 0.1)),
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a3f9fcf and f2e21a7.

📒 Files selected for processing (2)
  • ehr2vec/main/04_finetune_cv.py (1 hunks)
  • ehr2vec/model/heads.py (2 hunks)
🔇 Additional comments (1)
ehr2vec/main/04_finetune_cv.py (1)

198-198: Good practice adding run parameter to finetune_fold.
This enhancement allows for more dynamic logging or tracking (e.g., with W&B), aligning with common MLOps patterns.

coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (5)
ehr2vec/model/heads.py (2)

198-210: Consider parameterizing the BigHead configuration.

The implementation looks good, but consider making the following improvements:

  1. Add configuration parameters for classifier_hidden and dropout instead of hard-coding
  2. Add a docstring explaining the purpose and architecture of this classifier
 class BigHead(nn.Module):
     def __init__(self, input_size):
         super().__init__()
-        self.classifier_hidden = 128
+        self.classifier_hidden = config.classifier_hidden_size
         self.classifier = nn.Sequential(
             nn.Linear(input_size, self.classifier_hidden),
             nn.ReLU(),
-            nn.Dropout(0.1),
+            nn.Dropout(config.classifier_dropout),
             nn.Linear(self.classifier_hidden, 1),
         )

58-64: Consider using enum for classifier types.

Instead of string literals, consider using an enum or constants for classifier types to prevent typos and enable better IDE support.

from enum import Enum

class ClassifierType(Enum):
    BIG = "big"
    STANDARD = "standard"
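
If the enum route is taken, the head-selection logic quoted further below could compare against the enum's value instead of raw string literals. A sketch, assuming the YAML config stores the classifier type as a plain string:

classifier_type = config.to_dict().get("classifier", ClassifierType.STANDARD.value)
if classifier_type == ClassifierType.BIG.value:
    self.classifier = BigHead(classifier_input_size)
else:
    self.classifier = StandardHead(classifier_input_size)
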
ehr2vec/configs/example_configs/05_02_finetune_simulated.yaml (1)

18-18: Document available classifier options.

Please add a comment documenting the available classifier options ("big" or "standard") and their implications.

+  # Classifier type: "big" (MLP with dropout) or "standard" (linear)
   classifier: "big"
ehr2vec/double_robust/counterfactual.py (2)

14-14: Improve type hints for optional parameters.

Consider using Optional from typing module for optional parameters and document the behavior in the docstring.

-def create_counterfactual_data(data: Data, exposure_regex: List[str] = None, control_regex: str = None) -> Data:
+from typing import Optional
+def create_counterfactual_data(
+    data: Data,
+    exposure_regex: Optional[List[str]] = None,
+    control_regex: Optional[str] = None
+) -> Data:

Also applies to: 19-27
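
A placeholder docstring could accompany the signature change; the exact fallback behaviour when the regexes are None is not spelled out on this page, so the wording below is only a sketch:

def create_counterfactual_data(
    data: Data,
    exposure_regex: Optional[List[str]] = None,
    control_regex: Optional[str] = None,
) -> Data:
    """Create a counterfactual version of ``data``.

    The input ``data`` is not modified; a new Data object is returned.
    Both regex parameters are optional; when they are None the function
    falls back to its default handling of exposures and controls.
    """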


61-70: Simplify the conditional logic in the return statement.

The current implementation can be simplified using a conditional expression.

-        exposed_patients=(
-            [
-                pid
-                for pid in data.pids
-                if counterfactual_exposures[data.pids.index(pid)] == 1
-            ]
-            if counterfactual_exposures is not None
-            else []
-        ),
+        exposed_patients=[
+            pid for pid in data.pids
+            if counterfactual_exposures and counterfactual_exposures[data.pids.index(pid)] == 1
+        ],
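
A further possibility (not part of the PR) is to avoid the repeated data.pids.index(pid) lookups, which make the comprehension quadratic in the number of patients. This assumes counterfactual_exposures is positionally aligned with data.pids, as the quoted code already implies:

exposed_patients=(
    [pid for pid, exposure in zip(data.pids, counterfactual_exposures) if exposure == 1]
    if counterfactual_exposures is not None
    else []
),
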
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f2e21a7 and 3c3c6fb.

📒 Files selected for processing (3)
  • ehr2vec/configs/example_configs/05_02_finetune_simulated.yaml (1 hunks)
  • ehr2vec/double_robust/counterfactual.py (2 hunks)
  • ehr2vec/model/heads.py (4 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (2)
  • GitHub Check: tests
  • GitHub Check: tests
🔇 Additional comments (2)
ehr2vec/model/heads.py (1)

213-219: LGTM!

Clean implementation of a standard linear classifier.

ehr2vec/double_robust/counterfactual.py (1)

51-54: LGTM! Good practice to avoid modifying input data.

Creating a new list instead of modifying the original preserves data immutability.

Comment on lines +146 to +153
classifier_input_size = config.hidden_size + self.exposure_dim
if config.to_dict().get("classifier", None) is not None:
    if config.classifier == "big":
        self.classifier = BigHead(classifier_input_size)
    else:
        self.classifier = StandardHead(classifier_input_size)
else:
    self.classifier = StandardHead(classifier_input_size)

🛠️ Refactor suggestion

Extract classifier creation logic to avoid duplication.

The classifier creation logic is duplicated between BaseRNN and FineTuneHead. Consider extracting it to a factory method.

def create_classifier(config, input_size):
    """Create classifier based on config."""
    if config.to_dict().get("classifier", None) is not None:
        if config.classifier == "big":
            return BigHead(input_size)
    return StandardHead(input_size)
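
With such a factory in place, both call sites would reduce to a one-liner, e.g. (sketch):

# in BaseRNN.__init__ and FineTuneHead.__init__
self.classifier = create_classifier(config, classifier_input_size)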
