Add init_weights method to FlaxMixin (huggingface#513)
* Add `init_weights` method to `FlaxMixin`

* Rename `random_state` -> `shape_state`

* Use `PRNGKey(0)` for `jax.eval_shape` (see the sketch below)

* Do not allow mismatched sizes

* Update src/diffusers/modeling_flax_utils.py

Co-authored-by: Suraj Patil <[email protected]>

* Update src/diffusers/modeling_flax_utils.py

Co-authored-by: Suraj Patil <[email protected]>

* Update docstrings

Co-authored-by: Suraj Patil <[email protected]>
Mishig Davaadorj and patil-suraj authored Sep 15, 2022
1 parent d144c46 commit fb5468a
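For context, a minimal sketch (not part of this commit) of what an `init_weights` implementation might look like on a Flax model, and how `jax.eval_shape` with `PRNGKey(0)` turns it into a shape-only parameter tree without allocating real weights; the `ToyModel` module and its dummy input are hypothetical.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class ToyModel(nn.Module):
    """Hypothetical stand-in for a diffusers Flax model."""
    features: int = 8

    @nn.compact
    def __call__(self, x):
        return nn.Dense(self.features)(x)


def init_weights(rng: jax.random.PRNGKey):
    # Mirrors the contract added in this commit: build the parameter
    # pytree from an rng (here via a dummy input of fixed shape).
    dummy = jnp.zeros((1, 4), dtype=jnp.float32)
    return ToyModel().init(rng, dummy)["params"]


# eval_shape only traces the function abstractly, so no weights are
# materialized; PRNGKey(0) is used purely for shape/dtype inference.
params_shape_tree = jax.eval_shape(init_weights, rng=jax.random.PRNGKey(0))
print(jax.tree_util.tree_map(lambda x: x.shape, params_shape_tree))
```

Because `eval_shape` returns `ShapeDtypeStruct` leaves rather than arrays, the resulting tree is cheap to build, and it is what the loading code below compares checkpoint shapes against.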
Showing 1 changed file with 70 additions and 5 deletions.
75 changes: 70 additions & 5 deletions src/diffusers/modeling_flax_utils.py
@@ -20,7 +20,7 @@
import jax
import jax.numpy as jnp
import msgpack.exceptions
from flax.core.frozen_dict import FrozenDict
from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.serialization import from_bytes, to_bytes
from flax.traverse_util import flatten_dict, unflatten_dict
from huggingface_hub import hf_hub_download
@@ -183,6 +183,9 @@ def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
```"""
return self._cast_floating_to(params, jnp.float16, mask)

def init_weights(self, rng: jax.random.PRNGKey) -> Dict:
raise NotImplementedError(f"init_weights method has to be implemented for {self}")

@classmethod
def from_pretrained(
cls,
@@ -227,10 +230,6 @@ def from_pretrained(
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory in which a downloaded pretrained model configuration should be cached if the
standard cache should not be used.
ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
checkpoint with 3 labels).
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
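A hypothetical call exercising these kwargs (not taken from this commit): `SomeFlaxModel` and the checkpoint id are placeholders, and the `model, params` unpacking assumes the tuple return used by diffusers' Flax models.

```python
# Placeholder class and checkpoint id; shown only to illustrate the kwargs
# documented above, not an API guaranteed by this diff.
model, params = SomeFlaxModel.from_pretrained(
    "org/checkpoint",
    cache_dir="/tmp/diffusers-cache",  # cache downloads outside the default location
    force_download=True,               # re-fetch files even if a cached copy exists
)
```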
@@ -394,6 +393,72 @@ def from_pretrained(
# flatten dicts
state = flatten_dict(state)

params_shape_tree = jax.eval_shape(model.init_weights, rng=jax.random.PRNGKey(0))
required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())

shape_state = flatten_dict(unfreeze(params_shape_tree))

missing_keys = required_params - set(state.keys())
unexpected_keys = set(state.keys()) - required_params

if missing_keys:
logger.warning(
f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. "
"Make sure to call model.init_weights to initialize the missing weights."
)
cls._missing_keys = missing_keys

# Mismatched keys contain tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
# matching the weights in the model.
mismatched_keys = []
for key in state.keys():
if key in shape_state and state[key].shape != shape_state[key].shape:
raise ValueError(
f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
f"{state[key].shape} which is incompatible with the model shape {shape_state[key].shape}. "
)

# remove unexpected keys to not be saved again
for unexpected_key in unexpected_keys:
del state[unexpected_key]

if len(unexpected_keys) > 0:
logger.warning(
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
" with another architecture."
)
else:
logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")

if len(missing_keys) > 0:
logger.warning(
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
" TRAIN this model on a down-stream task to be able to use it for predictions and inference."
)
elif len(mismatched_keys) == 0:
logger.info(
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
" training."
)
if len(mismatched_keys) > 0:
mismatched_warning = "\n".join(
[
f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
for key, shape1, shape2 in mismatched_keys
]
)
logger.warning(
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
" to use it for predictions and inference."
)

# dictionary of key: dtypes for the model params
param_dtypes = jax.tree_map(lambda x: x.dtype, state)
# extract keys of parameters not in jnp.float32
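To make the key bookkeeping in the new loading code easier to follow, here is a rough, self-contained illustration (not part of the diff) of how the flattened checkpoint `state` is compared against the flattened `shape_state`; the two dictionaries below are made-up stand-ins with the tuple keys that `flatten_dict` produces.

```python
import jax.numpy as jnp

# Made-up stand-ins for the flattened checkpoint and shape trees.
state = {
    ("dense", "kernel"): jnp.zeros((4, 8)),
    ("old_head", "bias"): jnp.zeros((3,)),
}
shape_state = {
    ("dense", "kernel"): jnp.zeros((4, 8)),
    ("dense", "bias"): jnp.zeros((8,)),
}

required_params = set(shape_state.keys())
missing_keys = required_params - set(state.keys())     # {("dense", "bias")}
unexpected_keys = set(state.keys()) - required_params  # {("old_head", "bias")}

# Shapes must match exactly; the commit raises instead of collecting mismatches.
for key in state.keys():
    if key in shape_state and state[key].shape != shape_state[key].shape:
        raise ValueError(f"shape mismatch for {key}")

# Unexpected keys are dropped so they are not saved again.
for key in unexpected_keys:
    del state[key]
```

With these stand-ins the loop raises nothing, `("dense", "bias")` would be reported as missing, and `("old_head", "bias")` is discarded.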
