diff --git a/aitextgen/aitextgen.py b/aitextgen/aitextgen.py index 6011d8a..0d61db7 100644 --- a/aitextgen/aitextgen.py +++ b/aitextgen/aitextgen.py @@ -188,13 +188,13 @@ def __init__( ) if gradient_checkpointing: - self.model.config["gradient_checkpointing"] = True + setattr(self.model.config, "gradient_checkpointing", True) if schema_tokens: - self.model.config["schema_tokens"] = schema_tokens + setattr(self.model.config, "schema_tokens", schema_tokens) if schema_tokens: - self.model.config["schema_return"] = schema_return + setattr(self.model.config, "schema_return", schema_return) if self.tokenizer is None: # Update tokenizer settings (if not set already)