-
Notifications
You must be signed in to change notification settings - Fork 885
add config option for launcher #1354
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
bae1519
b39976c
48ad12f
f0a9bf3
feb65cc
4313dc9
a90217f
3c413d5
a744304
e9e8184
f12a736
477fa3c
1d8ab51
93e5dd6
4109c6d
b1b9023
4e0d29a
f182802
9fee5b4
6d1e315
370f62e
769cf5e
bc49f07
c12061e
773434f
d1e7104
506fc08
4cc8a76
b99100d
9283e32
40aff06
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,225 @@ | ||
# TorchServe with Intel® Extension for PyTorch* | ||
|
||
TorchServe can be used with Intel® Extension for PyTorch* (IPEX) to give performance boost on Intel hardware<sup>1</sup>. | ||
Here we show how to use TorchServe with IPEX. | ||
|
||
<sup>1. While IPEX benefits all platforms, plaforms with AVX512 benefit the most. </sup> | ||
|
||
## Contents of this Document | ||
* [Install Intel Extension for PyTorch](#install-intel-extension-for-pytorch) | ||
* [Serving model with Intel Extension for PyTorch](#serving-model-with-intel-extension-for-pytorch) | ||
* [TorchServe with Launcher](#torchserve-with-launcher) | ||
* [Creating and Exporting INT8 model for IPEX](#creating-and-exporting-int8-model-for-ipex) | ||
* [Benchmarking with Launcher](#benchmarking-with-launcher) | ||
|
||
|
||
## Install Intel Extension for PyTorch | ||
Refer to the documentation [here](https://github.com/intel/intel-extension-for-pytorch#installation). | ||
|
||
## Serving model with Intel Extension for PyTorch | ||
After installation, all it needs to be done to use TorchServe with IPEX is to enable it in `config.properties`. | ||
``` | ||
ipex_enable=true | ||
``` | ||
Once IPEX is enabled, deploying PyTorch model follows the same procedure shown [here](https://pytorch.org/serve/use_cases.html). TorchServe with IPEX can deploy any model and do inference. | ||
|
||
## TorchServe with Launcher | ||
Launcher is a script to automate the process of tunining configuration setting on intel hardware to boost performance. Tuning configurations such as OMP_NUM_THREADS, thread affininty, memory allocator can have a dramatic effect on performance. Please refer to [here](https://github.com/intel/intel-extension-for-pytorch/blob/master/docs/tutorials/performance_tuning/tuning_guide.md) and [here](https://github.com/intel/intel-extension-for-pytorch/blob/master/docs/tutorials/performance_tuning/launch_script.md) for details on performance tuning with launcher. | ||
|
||
All it needs to be done to use TorchServe with launcher is to set its configuration in `config.properties`. | ||
|
||
Add the following lines in `config.properties` to use launcher with its default configuration. | ||
``` | ||
ipex_enable=true | ||
cpu_launcher_enable=true | ||
``` | ||
|
||
Launcher by default uses `numactl` if its installed to ensure socket is pinned and thus memory is allocated from local numa node. To use launcher without numactl, add the following lines in `config.properties`. | ||
``` | ||
ipex_enable=true | ||
cpu_launcher_enable=true | ||
cpu_launcher_args=--disable_numactl | ||
``` | ||
|
||
Launcher by default uses only non-hyperthreaded cores if hyperthreading is present to avoid core compute resource sharing. To use launcher with all cores, both physical and logical, add the following lines in `config.properties`. | ||
``` | ||
ipex_enable=true | ||
cpu_launcher_enable=true | ||
cpu_launcher_args=--use_logical_core | ||
``` | ||
|
||
Below is an example of passing multiple args to `cpu_launcher_args`. | ||
``` | ||
ipex_enable=true | ||
cpu_launcher_enable=true | ||
cpu_launcher_args=--use_logical_core --disable_numactl | ||
``` | ||
|
||
Some useful `cpu_launcher_args` to note are: | ||
1. Memory Allocator: [ PTMalloc `--use_default_allocator` | *TCMalloc `--enable_tcmalloc`* | JeMalloc `--enable_jemalloc`] | ||
* PyTorch by defualt uses PTMalloc. TCMalloc/JeMalloc generally gives better performance. | ||
2. OpenMP library: [GNU OpenMP `--disable_iomp` | *Intel OpenMP*] | ||
* PyTorch by default uses GNU OpenMP. Launcher by default uses Intel OpenMP. Intel OpenMP library generally gives better performance. | ||
3. Socket id: [`--socket_id`] | ||
* Launcher by default uses all physical cores. Limit memory access to local memories on the Nth socket to avoid Non-Uniform Memory Access (NUMA). | ||
|
||
Please refer to [here](https://github.com/intel/intel-extension-for-pytorch/blob/master/docs/tutorials/performance_tuning/launch_script.md) for a full list of tunable configuration of launcher. | ||
|
||
|
||
## Creating and Exporting INT8 model for IPEX | ||
Intel Extension for PyTorch supports both eager and torchscript mode. In this section, we show how to deploy INT8 model for IPEX. | ||
|
||
### 1. Creating a serialized file | ||
First create `.pt` serialized file using IPEX INT8 inference. Here we show two examples with BERT and ResNet50. | ||
|
||
#### BERT | ||
|
||
``` | ||
import torch | ||
import intel_extension_for_pytorch as ipex | ||
import transformers | ||
from transformers import AutoModelForSequenceClassification, AutoConfig | ||
|
||
# load the model | ||
config = AutoConfig.from_pretrained( | ||
"bert-base-uncased", return_dict=False, torchscript=True, num_labels=2) | ||
model = AutoModelForSequenceClassification.from_pretrained( | ||
"bert-base-uncased", config=config) | ||
model = model.eval() | ||
|
||
# define dummy input tensor to use for the model's forward call to record operations in the model for tracing | ||
N, max_length = 1, 384 | ||
dummy_tensor = torch.ones((N, max_length), dtype=torch.long) | ||
|
||
# calibration | ||
# ipex supports two quantization schemes to be used for activation: torch.per_tensor_affine and torch.per_tensor_symmetric | ||
# default qscheme is torch.per_tensor_affine | ||
conf = ipex.quantization.QuantConf(qscheme=torch.per_tensor_affine) | ||
n_iter = 100 | ||
with torch.no_grad(): | ||
for i in range(n_iter): | ||
with ipex.quantization.calibrate(conf): | ||
model(dummy_tensor, dummy_tensor, dummy_tensor) | ||
|
||
# optionally save the configuraiton for later use | ||
# save: | ||
# conf.save("model_conf.json") | ||
# load: | ||
# conf = ipex.quantization.QuantConf("model_conf.json") | ||
|
||
# conversion | ||
jit_inputs = (dummy_tensor, dummy_tensor, dummy_tensor) | ||
model = ipex.quantization.convert(model, conf, jit_inputs) | ||
|
||
# enable fusion path work(need to run forward propagation twice) | ||
with torch.no_grad(): | ||
y = model(dummy_tensor,dummy_tensor,dummy_tensor) | ||
y = model(dummy_tensor,dummy_tensor,dummy_tensor) | ||
|
||
# save to .pt | ||
torch.jit.save(model, 'bert_int8_jit.pt') | ||
``` | ||
|
||
#### ResNet50 | ||
|
||
``` | ||
import torch | ||
import torch.fx.experimental.optimization as optimization | ||
import intel_extension_for_pytorch as ipex | ||
import torchvision.models as models | ||
|
||
# load the model | ||
model = models.resnet50(pretrained=True) | ||
model = model.eval() | ||
model = optimization.fuse(model) | ||
|
||
# define dummy input tensor to use for the model's forward call to record operations in the model for tracing | ||
N, C, H, W = 1, 3, 224, 224 | ||
dummy_tensor = torch.randn(N, C, H, W).contiguous(memory_format=torch.channels_last) | ||
|
||
# calibration | ||
# ipex supports two quantization schemes to be used for activation: torch.per_tensor_affine and torch.per_tensor_symmetric | ||
# default qscheme is torch.per_tensor_affine | ||
conf = ipex.quantization.QuantConf(qscheme=torch.per_tensor_symmetric) | ||
n_iter = 100 | ||
with torch.no_grad(): | ||
for i in range(n_iter): | ||
with ipex.quantization.calibrate(conf): | ||
model(dummy_tensor) | ||
|
||
# optionally save the configuraiton for later use | ||
# save: | ||
# conf.save("model_conf.json") | ||
# load: | ||
# conf = ipex.quantization.QuantConf("model_conf.json") | ||
|
||
# conversion | ||
jit_inputs = (dummy_tensor) | ||
model = ipex.quantization.convert(model, conf, jit_inputs) | ||
|
||
# enable fusion path work(need to run two iterations) | ||
with torch.no_grad(): | ||
y = model(dummy_tensor) | ||
y = model(dummy_tensor) | ||
|
||
# save to .pt | ||
torch.jit.save(model, 'rn50_int8_jit.pt') | ||
``` | ||
|
||
### 2. Creating a Model Archive | ||
Once the serialized file ( `.pt`) is created, it can be used with `torch-model-archiver` as ususal. Use the following command to package the model. | ||
``` | ||
torch-model-archiver --model-name rn50_ipex_int8 --version 1.0 --serialized-file rn50_int8_jit.pt --handler image_classifier | ||
``` | ||
### 3. Start TorchServe to serve the model | ||
Make sure to set `ipex_enable=true` in `config.properties`. Use the following command to start TorchServe with IPEX. | ||
``` | ||
torchserve --start --ncs --model-store model_store --ts-config config.properties | ||
``` | ||
|
||
### 4. Registering and Deploying model | ||
Registering and deploying the model follows the same steps shown [here](https://pytorch.org/serve/use_cases.html). | ||
|
||
## Benchmarking with Launcher | ||
Launcher can be used with TorchServe official [benchmark](https://github.com/pytorch/serve/tree/master/benchmarks) to launch server and benchmark requests with optimal configuration on Intel hardware. | ||
|
||
In this section we provide examples of benchmarking with launcher with its default configuration. | ||
|
||
Add the following lines to `config.properties` in the benchmark directory to use launcher with its default setting. | ||
``` | ||
ipex_enable=true | ||
cpu_launcher_enable=true | ||
``` | ||
|
||
The rest of the steps for benchmarking follows the same steps shown [here](https://github.com/pytorch/serve/tree/master/benchmarks). | ||
|
||
`model_log.log` contains information and command that were used for this execution launch. | ||
|
||
|
||
CPU usage on a machine with Intel(R) Xeon(R) Platinum 8180 CPU, 2 sockets, 28 cores per socket, 2 threads per core is shown as below: | ||
 | ||
|
||
``` | ||
$ cat logs/model_log.log | ||
2021-12-01 21:22:40,096 - __main__ - WARNING - Both TCMalloc and JeMalloc are not found in $CONDA_PREFIX/lib or $VIRTUAL_ENV/lib or /.local/lib/ or /usr/local/lib/ or /usr/local/lib64/ or /usr/lib or /usr/lib64 or /home/<user>/.local/lib/ so the LD_PRELOAD environment variable will not be set. This may drop the performance | ||
2021-12-01 21:22:40,096 - __main__ - INFO - OMP_NUM_THREADS=56 | ||
2021-12-01 21:22:40,096 - __main__ - INFO - Using Intel OpenMP | ||
2021-12-01 21:22:40,096 - __main__ - INFO - KMP_AFFINITY=granularity=fine,compact,1,0 | ||
2021-12-01 21:22:40,096 - __main__ - INFO - KMP_BLOCKTIME=1 | ||
2021-12-01 21:22:40,096 - __main__ - INFO - LD_PRELOAD=<VIRTUAL_ENV>/lib/libiomp5.so | ||
2021-12-01 21:22:40,096 - __main__ - WARNING - Numa Aware: cores:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55] in different NUMA node | ||
``` | ||
|
||
CPU usage on a machine with Intel(R) Xeon(R) Platinum 8375C CPU, 1 socket, 2 cores per socket, 2 threads per socket is shown as below: | ||
 | ||
|
||
``` | ||
$ cat logs/model_log.log | ||
2021-12-02 06:15:03,981 - __main__ - WARNING - Both TCMalloc and JeMalloc are not found in $CONDA_PREFIX/lib or $VIRTUAL_ENV/lib or /.local/lib/ or /usr/local/lib/ or /usr/local/lib64/ or /usr/lib or /usr/lib64 or /home/<user>/.local/lib/ so the LD_PRELOAD environment variable will not be set. This may drop the performance | ||
2021-12-02 06:15:03,981 - __main__ - INFO - OMP_NUM_THREADS=2 | ||
2021-12-02 06:15:03,982 - __main__ - INFO - Using Intel OpenMP | ||
2021-12-02 06:15:03,982 - __main__ - INFO - KMP_AFFINITY=granularity=fine,compact,1,0 | ||
2021-12-02 06:15:03,982 - __main__ - INFO - KMP_BLOCKTIME=1 | ||
2021-12-02 06:15:03,982 - __main__ - INFO - LD_PRELOAD=<VIRTUAL_ENV>/lib/libiomp5.so | ||
|
||
``` |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,6 +29,7 @@ public class WorkerLifeCycle { | |
private Connector connector; | ||
private ReaderThread errReader; | ||
private ReaderThread outReader; | ||
private String launcherArgs; | ||
|
||
public WorkerLifeCycle(ConfigManager configManager, Model model) { | ||
this.configManager = configManager; | ||
|
@@ -39,6 +40,46 @@ public Process getProcess() { | |
return process; | ||
} | ||
|
||
public ArrayList<String> launcherArgsToList() { | ||
ArrayList<String> arrlist = new ArrayList<String>(); | ||
arrlist.add("-m"); | ||
arrlist.add("intel_extension_for_pytorch.cpu.launch"); | ||
arrlist.add("--ninstance"); | ||
arrlist.add("1"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is "- F438 -ninstance" always 1? If not, please make it configurable and update doc to specify the relationship b/w torchserve worker and ipex ninstance. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
if (launcherArgs != null && launcherArgs.length() > 1) { | ||
String[] argarray = launcherArgs.split(" "); | ||
for (int i = 0; i < argarray.length; i++) { | ||
arrlist.add(argarray[i]); | ||
} | ||
} | ||
return arrlist; | ||
} | ||
|
||
public boolean isLauncherAvailable() | ||
throws WorkerInitializationException, InterruptedException { | ||
boolean launcherAvailable = false; | ||
try { | ||
ArrayList<String> cmd = new ArrayList<String>(); | ||
cmd.add("python"); | ||
ArrayList<String> args = launcherArgsToList(); | ||
cmd.addAll(args); | ||
cmd.add("--no_python"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How does this work exactly? We have two commands There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Launcher uses the following command line: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok got it, could we add a quick comment then to explain that this is just a dummy command to see if launcher exists? Also do users need to install something for the launcher script? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure, will add a comment. Only |
||
// try launching dummy command to check launcher availability | ||
String dummyCmd = "hostname"; | ||
cmd.add(dummyCmd); | ||
|
||
String[] cmd_ = new String[cmd.size()]; | ||
cmd_ = cmd.toArray(cmd_); | ||
|
||
Process process = Runtime.getRuntime().exec(cmd_); | ||
int ret = process.waitFor(); | ||
launcherAvailable = (ret == 0); | ||
} catch (IOException | InterruptedException e) { | ||
throw new WorkerInitializationException("Failed to start launcher", e); | ||
} | ||
return launcherAvailable; | ||
} | ||
|
||
public void startWorker(int port) throws WorkerInitializationException, InterruptedException { | ||
File workingDir = new File(configManager.getModelServerHome()); | ||
File modelPath; | ||
|
@@ -51,6 +92,19 @@ public void startWorker(int port) throws WorkerInitializationException, Interrup | |
|
||
ArrayList<String> argl = new ArrayList<String>(); | ||
argl.add(EnvironmentUtils.getPythonRunTime(model)); | ||
|
||
if (configManager.isCPULauncherEnabled()) { | ||
launcherArgs = configManager.getCPULauncherArgs(); | ||
boolean launcherAvailable = isLauncherAvailable(); | ||
if (launcherAvailable) { | ||
ArrayList<String> args = launcherArgsToList(); | ||
argl.addAll(args); | ||
} else { | ||
logger.warn( | ||
"CPU launcher is enabled but launcher is not available. Proceeding without launcher."); | ||
} | ||
} | ||
|
||
argl.add(new File(workingDir, "ts/model_service_worker.py").getAbsolutePath()); | ||
argl.add("--sock-type"); | ||
argl.add(connector.getSocketType()); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm curious could you point me to the launcher source code that determines how to set
OMP_NUM_THREADS
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please have a look at here and here; launcher by default uses non-hyperthreaded cores only (if hyperthreading is enabled)