Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

error when trying to populate PoseEstimation #1207

Closed
MichaelCoulter opened this issue Jan 6, 2025 · 6 comments · Fixed by #1208
Closed

error when trying to populate PoseEstimation #1207

MichaelCoulter opened this issue Jan 6, 2025 · 6 comments · Fixed by #1208

Comments

@MichaelCoulter
Copy link
Collaborator

this is the code i am running

#RS21 2nd half (after model training)
video_lists = [[
    {"nwb_file_name": "RS2120241016_.nwb", "epoch": 2},
    {"nwb_file_name": "RS2120241016_.nwb", "epoch": 4},
    {"nwb_file_name": "RS2120241016_.nwb", "epoch": 6},
    {"nwb_file_name": "RS2120241016_.nwb", "epoch": 8},
    {"nwb_file_name": "RS2120241016_.nwb", "epoch": 10}],
                   [
    {"nwb_file_name": "RS2120241017_.nwb", "epoch": 2},
    {"nwb_file_name": "RS2120241017_.nwb", "epoch": 4},
    {"nwb_file_name": "RS2120241017_.nwb", "epoch": 6},
    {"nwb_file_name": "RS2120241017_.nwb", "epoch": 8},],
                   [
    {"nwb_file_name": "RS2120241018_.nwb", "epoch": 2},
    {"nwb_file_name": "RS2120241018_.nwb", "epoch": 4},
    {"nwb_file_name": "RS2120241018_.nwb", "epoch": 6},
    {"nwb_file_name": "RS2120241018_.nwb", "epoch": 8},
    {"nwb_file_name": "RS2120241018_.nwb", "epoch": 10}],]

project_names = ["RS2120241016_run","RS2120241017_run","RS2120241018_run",]
nwbs = ["RS2120241016_.nwb","RS2120241017_.nwb","RS2120241018_.nwb",]
vids = ['*20241016_RS21*','*20241017_RS21*','*20241018_RS21*',]
team_name = (sgc.LabTeam & {"team_name": "mcoulter section"}).fetch("team_name")[0]  # If on lab DB, "LorenLab"
frames_per_video = 20
bodyparts = ["redLED_C", "greenLED", "tailBase"]
for n in range(0,3):
    video_list = video_lists[n]
    project_name = project_names[n]
    
    project_key = sgp.DLCProject.insert_new_project(
    project_name=project_name,
    bodyparts=bodyparts,
    lab_team=team_name,
    frames_per_video=frames_per_video,
    video_list=video_list,
    skip_duplicates=True,
    )
    
    gputouse = 1
    
    training_params_name = "tutorial"
    sgp.DLCModelTrainingParams.insert_new_params(
        paramset_name=training_params_name,
        params={
            "trainingsetindex": 0,
            "shuffle": 1,
            "gputouse": gputouse,
            "net_type": "resnet_50",
            "augmenter_type": "imgaug",
        },
        skip_duplicates=True,
    )

    if "config_path" in project_key:
        del project_key["config_path"]
    
    sgp.DLCModelTrainingSelection().insert1(
       {
           **project_key,
           "dlc_training_params_name": training_params_name,
           "training_id": 0,
           "model_prefix": "",
       },
        skip_duplicates=True,
    )
    model_training_key = (
        sgp.DLCModelTrainingSelection
        & {
            **project_key,
            "dlc_training_params_name": training_params_name,
        }
    ).fetch1("KEY")
    
    temp_model_key = (sgp.DLCModelSource & model_training_key).fetch1("KEY")
    
    # comment these lines out after successfully inserting, for each project
    sgp.DLCModelSelection().insert1(
        {**temp_model_key, "dlc_model_params_name": "default"}, skip_duplicates=True
    )

    model_key = (sgp.DLCModelSelection & temp_model_key).fetch1("KEY")
    sgp.DLCModel.populate(model_key)

    camera_name = "MEC_run_camera"
    nwb_file_name = nwbs[n]
    matching_rows = sgc.VideoFile() & {"camera_name": camera_name} & {"nwb_file_name": nwb_file_name}
    print(matching_rows)
    
    ! find /nimbus/deeplabcut/video -type f -name vids[n] -delete
    for row in matching_rows:
        col1val = row["nwb_file_name"]
        if nwb_file_name in col1val:  # *** change depending on rat/day!!!
            col2val = row["epoch"]
            col3val = row["video_file_num"]
            ##insert pose estimation task
            pose_estimation_key = (
                sgp.DLCPoseEstimationSelection().insert_estimation_task(
                    {
                        "nwb_file_name": col1val,
                        "epoch": col2val,
                        "video_file_num": col3val,
                        **model_key,
                    },
                    task_mode="trigger",  # load or trigger
                    params={"gputouse": gputouse, "videotype": "mp4"},
                )
            )
            
            ##populate DLC Pose Estimation
            sgp.DLCPoseEstimation().populate(pose_estimation_key)
    

and this is the error.

[15:16:25][WARNING] Spyglass: project name: RS2120241016_run is already in use.
WARNING:spyglass:project name: RS2120241016_run is already in use.
[15:16:25][INFO] Spyglass: New param set not added
A param set with name: tutorial already exists
INFO:spyglass:New param set not added
A param set with name: tutorial already exists
*nwb_file_name *epoch    *video_file_nu camera_name    video_file_obj
+------------+ +-------+ +------------+ +------------+ +------------+
RS2120241016_. 2         1              MEC_run_camera 3b76b894-d00c-
RS2120241016_. 4         1              MEC_run_camera 24616ada-16e2-
RS2120241016_. 6         1              MEC_run_camera 165946a8-f87d-
RS2120241016_. 8         1              MEC_run_camera 1f1a105e-ef0b-
RS2120241016_. 10        1              MEC_run_camera 15e3a344-c670-
 (Total: 5)

/home/mcoulter/mambaforge/envs/spyglass-position/lib/python3.9/site-packages/hdmf/container.py:420: UserWarning: The table for this DynamicTableRegion has not been added to the parent.
  warn(msg)
INFO:spyglass:Pose Estimation Selection
INFO:spyglass:video_dir: /stelmo/nwb/video/
INFO:spyglass:/stelmo/nwb/deeplabcut/video/20241016_RS21_02_r1.2.mp4 already exists, skipping conversion
[15:16:27][INFO] Spyglass: inserted entry into Pose Estimation Selection
INFO:spyglass:inserted entry into Pose Estimation Selection
[15:16:27][INFO] Spyglass: ----------------------
INFO:spyglass:----------------------
[15:16:27][INFO] Spyglass: Pose Estimation
INFO:spyglass:Pose Estimation
/home/mcoulter/mambaforge/envs/spyglass-position/lib/python3.9/site-packages/tensorflow/python/keras/engine/base_layer_v1.py:1694: UserWarning: `layer.apply` is deprecated and will be removed in a future version. Please use `layer.__call__` method instead.
  warnings.warn('`layer.apply` is deprecated and '
Using snapshot-1030000 for model /nimbus/deeplabcut/projects/RS2120241016_run-mcoulter_section-2024-12-02/dlc-models/iteration-0/RS2120241016_runDec2-trainset95shuffle1
2025-01-06 15:16:30.224305: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1613] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 78948 MB memory:  -> device: 0, name: NVIDIA A100 80GB PCIe, pci bus id: 0000:52:00.0, compute capability: 8.0
[15:16:31][WARNING] Spyglass: DEPRECATION scheduled for version 0.6: dlc_reader: PoseEstimation
WARNING:spyglass:DEPRECATION scheduled for version 0.6: dlc_reader: PoseEstimation
[15:16:31][INFO] Spyglass: getting raw position
INFO:spyglass:getting raw position
Starting to analyze %  /stelmo/nwb/deeplabcut/video/20241016_RS21_02_r1.2.mp4
The videos are analyzed. Now your research can truly start! 
 You can create labeled videos with 'create_labeled_video'
If the tracking is not satisfactory for some videos, consider expanding the training set. You can use the function 'extract_outlier_frames' to extract a few representative outlier frames.
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[7], line 110
     96 pose_estimation_key = (
     97     sgp.DLCPoseEstimationSelection().insert_estimation_task(
     98         {
   (...)
    106     )
    107 )
    109 ##populate DLC Pose Estimation
--> 110 sgp.DLCPoseEstimation().populate(pose_estimation_key)
    112 ##start smooth interpolation
    113 si_params_name = "just_nan"

File ~/spyglass/src/spyglass/utils/dj_mixin.py:608, in SpyglassMixin.populate(self, *restrictions, **kwargs)
    606 if use_transact:  # Pass single-process populate to super
    607     kwargs["processes"] = processes
--> 608     return super().populate(*restrictions, **kwargs)
    609 else:  # No transaction protection, use bare make
    610     for key in keys:

File ~/mambaforge/envs/spyglass-position/lib/python3.9/site-packages/datajoint/autopopulate.py:241, in AutoPopulate.populate(self, suppress_errors, return_exception_objects, reserve_jobs, order, limit, max_calls, display_progress, processes, make_kwargs, *restrictions)
    237 if processes == 1:
    238     for key in (
    239         tqdm(keys, desc=self.__class__.__name__) if display_progress else keys
    240     ):
--> 241         error = self._populate1(key, jobs, **populate_kwargs)
    242         if error is not None:
    243             error_list.append(error)

File ~/mambaforge/envs/spyglass-position/lib/python3.9/site-packages/datajoint/autopopulate.py:292, in AutoPopulate._populate1(self, key, jobs, suppress_errors, return_exception_objects, make_kwargs)
    290 self.__class__._allow_insert = True
    291 try:
--> 292     make(dict(key), **(make_kwargs or {}))
    293 except (KeyboardInterrupt, SystemExit, Exception) as error:
    294     try:

File ~/spyglass/src/spyglass/position/v1/position_dlc_pose_estimation.py:213, in DLCPoseEstimation.make(self, key)
    209 """.populate() method will launch training for each PoseEstimationTask"""
    210 self.log_path = (
    211     Path(infer_output_dir(key=key, makedir=False)) / "log.log"
    212 )
--> 213 self._logged_make(key)

File ~/spyglass/src/spyglass/position/v1/dlc_utils.py:198, in file_log.<locals>.decorator.<locals>.wrapper(self, *args, **kwargs)
    196     logger.removeHandler(logger.handlers[0])
    197 try:
--> 198     return func(self, *args, **kwargs)
    199 finally:
    200     if not console:

File ~/spyglass/src/spyglass/position/v1/position_dlc_pose_estimation.py:252, in DLCPoseEstimation._logged_make(self, key)
    246 creation_time = datetime.fromtimestamp(
    247     dlc_result.creation_time
    248 ).strftime("%Y-%m-%d %H:%M:%S")
    250 logger.info("getting raw position")
    251 interval_list_name = (
--> 252     convert_epoch_interval_name_to_position_interval_name(
    253         {
    254             "nwb_file_name": key["nwb_file_name"],
    255             "epoch": key["epoch"],
    256         },
    257         populate_missing=False,
    258     )
    259 )
    260 if interval_list_name:
    261     spatial_series = (
    262         RawPosition()
    263         & {**key, "interval_list_name": interval_list_name}
    264     ).fetch_nwb()[0]["raw_position"]

File ~/spyglass/src/spyglass/common/common_behav.py:649, in convert_epoch_interval_name_to_position_interval_name(key, populate_missing)
    646     PositionIntervalMap()._no_transaction_make(key)
    647     pos_query = PositionIntervalMap & key
--> 649 if pos_query.fetch(pos_str)[0] == "":
    650     logger.info(f"No position intervals found for {key}")
    651     return []

IndexError: index 0 is out of bounds for axis 0 with size 0

it looks like for a different nwbfile ("RS2120241020_.nwb") there are entries in the 'PositionIntervalMap' table, but i can't seem to populate this table for the nwbfile i am working with now ("RS2120241016_.nwb"). any ideas would be appreciated, thank you.

@samuelbray32
Copy link
Collaborator

Immediate solution
@MichaelCoulter Until solution is merged in, can avoid by running this before the populate. Will ensure that every epoch has an appropriate entry (note this is now effectively done at session insertion for new files, should only pop up for older pre-existing files in database)

from spyglass.common import convert_epoch_interval_name_to_position_interval_name, TaskEpoch

key = {"nwb_file_name": "RS2120241016_.nwb",}
for k in (TaskEpoch & key).fetch("KEY"):
    convert_epoch_interval_name_to_position_interval_name(k)

Spyglass solution
I think this line can allow for populate_missing=True since it doesn't use a transact. This would prevent the error

@samuelbray32
Copy link
Collaborator

@MichaelCoulter if you can test the branch for the #1208 PR on another one that hit this error that would be great

@MichaelCoulter
Copy link
Collaborator Author

thanks! i will test this branch.

@MichaelCoulter
Copy link
Collaborator Author

that change didnt work for me. i recreated deleting that line in my own repo and got this error

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[3], line 101
     87 pose_estimation_key = (
     88     sgp.DLCPoseEstimationSelection().insert_estimation_task(
     89         {
   (...)
     97     )
     98 )
    100 ##populate DLC Pose Estimation
--> 101 sgp.DLCPoseEstimation().populate(pose_estimation_key)
    103 ##start smooth interpolation
    104 si_params_name = "just_nan"

File ~/spyglass/src/spyglass/utils/dj_mixin.py:608, in SpyglassMixin.populate(self, *restrictions, **kwargs)
    606 if use_transact:  # Pass single-process populate to super
    607     kwargs["processes"] = processes
--> 608     return super().populate(*restrictions, **kwargs)
    609 else:  # No transaction protection, use bare make
    610     for key in keys:

File ~/mambaforge/envs/spyglass-position/lib/python3.9/site-packages/datajoint/autopopulate.py:241, in AutoPopulate.populate(self, suppress_errors, return_exception_objects, reserve_jobs, order, limit, max_calls, display_progress, processes, make_kwargs, *restrictions)
    237 if processes == 1:
    238     for key in (
    239         tqdm(keys, desc=self.__class__.__name__) if display_progress else keys
    240     ):
--> 241         error = self._populate1(key, jobs, **populate_kwargs)
    242         if error is not None:
    243             error_list.append(error)

File ~/mambaforge/envs/spyglass-position/lib/python3.9/site-packages/datajoint/autopopulate.py:292, in AutoPopulate._populate1(self, key, jobs, suppress_errors, return_exception_objects, make_kwargs)
    290 self.__class__._allow_insert = True
    291 try:
--> 292     make(dict(key), **(make_kwargs or {}))
    293 except (KeyboardInterrupt, SystemExit, Exception) as error:
    294     try:

File ~/spyglass/src/spyglass/position/v1/position_dlc_pose_estimation.py:213, in DLCPoseEstimation.make(self, key)
    209 """.populate() method will launch training for each PoseEstimationTask"""
    210 self.log_path = (
    211     Path(infer_output_dir(key=key, makedir=False)) / "log.log"
    212 )
--> 213 self._logged_make(key)

File ~/spyglass/src/spyglass/position/v1/dlc_utils.py:198, in file_log.<locals>.decorator.<locals>.wrapper(self, *args, **kwargs)
    196     logger.removeHandler(logger.handlers[0])
    197 try:
--> 198     return func(self, *args, **kwargs)
    199 finally:
    200     if not console:

File ~/spyglass/src/spyglass/position/v1/position_dlc_pose_estimation.py:252, in DLCPoseEstimation._logged_make(self, key)
    246 creation_time = datetime.fromtimestamp(
    247     dlc_result.creation_time
    248 ).strftime("%Y-%m-%d %H:%M:%S")
    250 logger.info("getting raw position")
    251 interval_list_name = (
--> 252     convert_epoch_interval_name_to_position_interval_name(
    253         {
    254             "nwb_file_name": key["nwb_file_name"],
    255             "epoch": key["epoch"],
    256         },
    257         #populate_missing=False,
    258     )
    259 )
    260 if interval_list_name:
    261     spatial_series = (
    262         RawPosition()
    263         & {**key, "interval_list_name": interval_list_name}
    264     ).fetch_nwb()[0]["raw_position"]

File ~/spyglass/src/spyglass/common/common_behav.py:646, in convert_epoch_interval_name_to_position_interval_name(key, populate_missing)
    644     if null_entry:
    645         pos_query.delete(safemode=False)  # no prompt
--> 646     PositionIntervalMap()._no_transaction_make(key)
    647     pos_query = PositionIntervalMap & key
    649 if pos_query.fetch(pos_str)[0] == "":

File ~/spyglass/src/spyglass/common/common_behav.py:547, in PositionIntervalMap._no_transaction_make(self, key)
    545 if len(pos_intervals) == 0:
    546     logger.error(f"NO POS INTERVALS FOR {key};\n{no_pop_msg}")
--> 547     self.insert1(null_key, **insert_opts)
    548     return
    550 valid_times = (IntervalList & key).fetch1("valid_times")

File ~/mambaforge/envs/spyglass-position/lib/python3.9/site-packages/datajoint/table.py:337, in Table.insert1(self, row, **kwargs)
    330 def insert1(self, row, **kwargs):
    331     """
    332     Insert one data record into the table. For ``kwargs``, see ``insert()``.
    333 
    334     :param row: a numpy record, a dict-like object, or an ordered sequence to be inserted
    335         as one row.
    336     """
--> 337     self.insert((row,), **kwargs)

File ~/mambaforge/envs/spyglass-position/lib/python3.9/site-packages/datajoint/table.py:419, in Table.insert(self, rows, replace, skip_duplicates, ignore_extra_fields, allow_direct_insert)
    416     return
    418 field_list = []  # collects the field list from first row (passed by reference)
--> 419 rows = list(
    420     self.__make_row_to_insert(row, field_list, ignore_extra_fields)
    421     for row in rows
    422 )
    423 if rows:
    424     try:

File ~/mambaforge/envs/spyglass-position/lib/python3.9/site-packages/datajoint/table.py:420, in <genexpr>(.0)
    416     return
    418 field_list = []  # collects the field list from first row (passed by reference)
    419 rows = list(
--> 420     self.__make_row_to_insert(row, field_list, ignore_extra_fields)
    421     for row in rows
    422 )
    423 if rows:
    424     try:

File ~/mambaforge/envs/spyglass-position/lib/python3.9/site-packages/datajoint/table.py:871, in Table.__make_row_to_insert(self, row, field_list, ignore_extra_fields)
    865     attributes = [
    866         self.__make_placeholder(name, row[name], ignore_extra_fields)
    867         for name in self.heading
    868         if name in row.dtype.fields
    869     ]
    870 elif isinstance(row, collections.abc.Mapping):  # dict-based
--> 871     check_fields(row)
    872     attributes = [
    873         self.__make_placeholder(name, row[name], ignore_extra_fields)
    874         for name in self.heading
    875         if name in row
    876     ]
    877 else:  # positional

File ~/mambaforge/envs/spyglass-position/lib/python3.9/site-packages/datajoint/table.py:857, in Table.__make_row_to_insert.<locals>.check_fields(fields)
    855         for field in fields:
    856             if field not in self.heading:
--> 857                 raise KeyError(
    858                     "`{0:s}` is not in the table heading".format(field)
    859                 )
    860 elif set(field_list) != set(fields).intersection(self.heading.names):
    861     raise DataJointError("Attempt to insert rows with different fields.")

KeyError: '`epoch` is not in the table heading'

@samuelbray32
Copy link
Collaborator

@MichaelCoulter what is the value of pose_estimation_key

@samuelbray32
Copy link
Collaborator

Ah, I see. The _no_transaction_make isn't strictly called py the populate function, which means that extra values can be in the key. I'll put a fix in the same PR

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging a pull request may close this issue.

2 participants