-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
42 lines (33 loc) · 1.52 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: MIT-0
"""Top-level entrypoint script for model training (also supports loading for inference)
"""
# Python Built-Ins:
import logging
import os
import sys
def run_training():
    """Configure logging, import local modules and run the training job.

    Root logging is directed to stdout with a timestamped format *before* the training code is
    imported. The threshold is read from the ``LOG_LEVEL`` environment variable, which may be a
    level name in any letter case (e.g. ``debug``) or a numeric level (e.g. ``10``). When the
    variable is unset or blank, ``INFO`` is used.

    Returns:
        Whatever ``code.train.main()`` returns.

    Raises:
        ValueError: If ``LOG_LEVEL`` holds a value the ``logging`` module does not recognise.
    """
    raw_level = os.environ.get("LOG_LEVEL", "").strip()
    if not raw_level:
        level = logging.INFO
    elif raw_level.isdecimal():
        # Environment variables are always strings, but logging only treats ints as numeric
        # levels - a string such as "10" would otherwise be rejected as an unknown level name.
        level = int(raw_level)
    else:
        # The standard level names are registered in upper case, so normalise e.g. "debug".
        level = raw_level.upper()

    consolehandler = logging.StreamHandler(sys.stdout)
    consolehandler.setFormatter(
        logging.Formatter("%(asctime)s [%(name)s] %(levelname)s %(message)s")
    )
    logging.basicConfig(handlers=[consolehandler], level=level)

    # Deliberately imported only after logging has been configured (see above):
    from code.train import main

    return main()
if __name__ != "__main__":
    # Imported as a module: we're in inference mode, so pass through the override functions
    # defined in the inference module. This supports deploying the model directly via SageMaker
    # SDK's Estimator.deploy(), which carries over the environment variable
    # SAGEMAKER_PROGRAM=train.py from training - causing the server to try and load its handlers
    # from here rather than inference.py.
    from code.inference import *  # noqa: F401,F403
else:
    # Running as a script: we're in training mode, so run the actual training routine (which does
    # a little logging setup before its own imports, to make sure output shows up ok).
    run_training()
def _mp_fn(index):
    """Entry point for torch_xla / SageMaker Training Compiler

    (See smtc_launcher.py in this folder for configuration tips)

    The ``index`` argument is accepted but not used: the call is simply delegated to
    ``run_training()`` and its result is returned.
    """
    outcome = run_training()
    return outcome