Merge pull request #1354 from min-jean-cho/launcher · pytorch/serve@c87bfec · GitHub

Commit c87bfec

Merge pull request #1354 from min-jean-cho/launcher
add config option for launcher
2 parents e15b4fe + 40aff06 commit c87bfec

4 files changed: 290 additions, 1 deletion
Lines changed: 225 additions & 0 deletions
@@ -0,0 +1,225 @@
1+
# TorchServe with Intel® Extension for PyTorch*
2+
3+
TorchServe can be used with Intel® Extension for PyTorch* (IPEX) to give a performance boost on Intel hardware<sup>1</sup>.
4+
Here we show how to use TorchServe with IPEX.
5+
6+
<sup>1. While IPEX benefits all platforms, platforms with AVX512 benefit the most. </sup>
7+
8+
## Contents of this Document
9+
* [Install Intel Extension for PyTorch](#install-intel-extension-for-pytorch)
10+
* [Serving model with Intel Extension for PyTorch](#serving-model-with-intel-extension-for-pytorch)
11+
* [TorchServe with Launcher](#torchserve-with-launcher)
12+
* [Creating and Exporting INT8 model for IPEX](#creating-and-exporting-int8-model-for-ipex)
13+
* [Benchmarking with Launcher](#benchmarking-with-launcher)
14+
15+
16+
## Install Intel Extension for PyTorch
17+
Refer to the documentation [here](https://github.com/intel/intel-extension-for-pytorch#installation).
18+
19+
## Serving model with Intel Extension for PyTorch
20+
After installation, all that needs to be done to use TorchServe with IPEX is to enable it in `config.properties`.
21+
```
ipex_enable=true
```
24+
Once IPEX is enabled, deploying a PyTorch model follows the same procedure shown [here](https://pytorch.org/serve/use_cases.html). TorchServe with IPEX can deploy any model and do inference.
25+
26+
## TorchServe with Launcher
27+
Launcher is a script that automates tuning configuration settings on Intel hardware to boost performance. Tuning configurations such as OMP_NUM_THREADS, thread affinity, and the memory allocator can have a dramatic effect on performance. Please refer to [here](https://github.com/intel/intel-extension-for-pytorch/blob/master/docs/tutorials/performance_tuning/tuning_guide.md) and [here](https://github.com/intel/intel-extension-for-pytorch/blob/master/docs/tutorials/performance_tuning/launch_script.md) for details on performance tuning with launcher.
28+
29+
All that needs to be done to use TorchServe with launcher is to set its configuration in `config.properties`.
30+
31+
Add the following lines in `config.properties` to use launcher with its default configuration.
32+
```
ipex_enable=true
cpu_launcher_enable=true
```
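
As a rough illustration of how these two keys are interpreted, the sketch below mirrors the defaults added to `ConfigManager.java` in this commit (`cpu_launcher_enable` falls back to `false`, `cpu_launcher_args` to no value). The helper names `parse_properties` and `launcher_settings` are hypothetical, and the parser is a simplification of the Java properties format, not TorchServe's own loader.

```python
def parse_properties(text):
    """Parse simple key=value lines, skipping blanks and '#' comments.

    Simplified on purpose: real .properties files support more syntax.
    """
    props = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        key, sep, value = line.partition("=")
        if sep:
            props[key.strip()] = value.strip()
    return props


def launcher_settings(props):
    """Return (enabled, args) using the same defaults as the Java getters."""
    # Boolean.parseBoolean is true only for a case-insensitive "true".
    enabled = props.get("cpu_launcher_enable", "false").lower() == "true"
    # No default for the args: a missing key yields None (null in Java).
    args = props.get("cpu_launcher_args")
    return enabled, args
```

With the two lines shown above, `launcher_settings` reports the launcher as enabled with no extra arguments, which corresponds to launcher's default configuration.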
36+
37+
Launcher by default uses `numactl`, if it is installed, to ensure that the socket is pinned and thus memory is allocated from the local NUMA node. To use launcher without numactl, add the following lines in `config.properties`.
38+
```
ipex_enable=true
cpu_launcher_enable=true
cpu_launcher_args=--disable_numactl
```
43+
44+
Launcher by default uses only non-hyperthreaded cores if hyperthreading is present to avoid core compute resource sharing. To use launcher with all cores, both physical and logical, add the following lines in `config.properties`.
45+
```
ipex_enable=true
cpu_launcher_enable=true
cpu_launcher_args=--use_logical_core
```
50+
51+
Below is an example of passing multiple args to `cpu_launcher_args`.
52+
```
ipex_enable=true
cpu_launcher_enable=true
cpu_launcher_args=--use_logical_core --disable_numactl
```
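
To make the handling of multiple args concrete, here is a Python sketch of what `launcherArgsToList` in this commit's `WorkerLifeCycle.java` does with the `cpu_launcher_args` string: it always starts with the launcher module and `--ninstance 1`, then appends the space-separated args. The Python function name is an illustrative stand-in, not part of TorchServe.

```python
def launcher_args_to_list(launcher_args=None):
    """Build the argument list placed between the Python runtime and the worker script."""
    args = ["-m", "intel_extension_for_pytorch.cpu.launch", "--ninstance", "1"]
    # Same guard as the Java code: skip a missing or one-character string.
    if launcher_args is not None and len(launcher_args) > 1:
        # Split on single spaces, as the Java code does with split(" ").
        args.extend(launcher_args.split(" "))
    return args
```

Because the split is on single spaces, the args are best written with exactly one space between them, as in the example above.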
57+
58+
Some useful `cpu_launcher_args` to note are:
59+
1. Memory Allocator: [ PTMalloc `--use_default_allocator` | *TCMalloc `--enable_tcmalloc`* | JeMalloc `--enable_jemalloc`]
60+
* PyTorch by default uses PTMalloc. TCMalloc/JeMalloc generally give better performance.
61+
2. OpenMP library: [GNU OpenMP `--disable_iomp` | *Intel OpenMP*]
62+
* PyTorch by default uses GNU OpenMP. Launcher by default uses Intel OpenMP. Intel OpenMP library generally gives better performance.
63+
3. Socket id: [`--socket_id`]
64+
* Launcher by default uses all physical cores. Use `--socket_id` to limit memory access to local memory on the Nth socket and avoid Non-Uniform Memory Access (NUMA).
65+
66+
Please refer to [here](https://github.com/intel/intel-extension-for-pytorch/blob/master/docs/tutorials/performance_tuning/launch_script.md) for the full list of tunable launcher configurations.
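
The options listed above can be combined into a single `cpu_launcher_args` line. The following sketch generates the `config.properties` lines from a few choices; the function and its parameters are hypothetical, and it assumes `--socket_id` takes the socket index as its value, so check the launcher documentation before relying on it.

```python
# Flags as listed in this section.
ALLOCATOR_FLAGS = {
    "ptmalloc": "--use_default_allocator",
    "tcmalloc": "--enable_tcmalloc",
    "jemalloc": "--enable_jemalloc",
}


def launcher_properties(allocator=None, use_gnu_openmp=False, socket_id=None):
    """Return config.properties lines that enable launcher with the chosen options."""
    flags = []
    if allocator is not None:
        flags.append(ALLOCATOR_FLAGS[allocator])
    if use_gnu_openmp:
        # Launcher uses Intel OpenMP unless told otherwise.
        flags.append("--disable_iomp")
    if socket_id is not None:
        # Assumption: the flag is followed by the socket index.
        flags.extend(["--socket_id", str(socket_id)])
    lines = ["ipex_enable=true", "cpu_launcher_enable=true"]
    if flags:
        lines.append("cpu_launcher_args=" + " ".join(flags))
    return "\n".join(lines)
```

For example, `print(launcher_properties("jemalloc", socket_id=0))` produces three lines that can be pasted into `config.properties`.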
67+
68+
69+
## Creating and Exporting INT8 model for IPEX
70+
Intel Extension for PyTorch supports both eager and TorchScript mode. In this section, we show how to deploy an INT8 model for IPEX.
71+
72+
### 1. Creating a serialized file
73+
First, create a `.pt` serialized file using IPEX INT8 inference. Here we show two examples with BERT and ResNet50.
74+
75+
#### BERT
76+
77+
```
import torch
import intel_extension_for_pytorch as ipex
import transformers
from transformers import AutoModelForSequenceClassification, AutoConfig

# load the model
config = AutoConfig.from_pretrained(
    "bert-base-uncased", return_dict=False, torchscript=True, num_labels=2)
model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased", config=config)
model = model.eval()

# define dummy input tensor to use for the model's forward call to record operations in the model for tracing
N, max_length = 1, 384
dummy_tensor = torch.ones((N, max_length), dtype=torch.long)

# calibration
# ipex supports two quantization schemes to be used for activation: torch.per_tensor_affine and torch.per_tensor_symmetric
# default qscheme is torch.per_tensor_affine
conf = ipex.quantization.QuantConf(qscheme=torch.per_tensor_affine)
n_iter = 100
with torch.no_grad():
    for i in range(n_iter):
        with ipex.quantization.calibrate(conf):
            model(dummy_tensor, dummy_tensor, dummy_tensor)

# optionally save the configuration for later use
# save:
# conf.save("model_conf.json")
# load:
# conf = ipex.quantization.QuantConf("model_conf.json")

# conversion
jit_inputs = (dummy_tensor, dummy_tensor, dummy_tensor)
model = ipex.quantization.convert(model, conf, jit_inputs)

# enable the fusion path (need to run forward propagation twice)
with torch.no_grad():
    y = model(dummy_tensor, dummy_tensor, dummy_tensor)
    y = model(dummy_tensor, dummy_tensor, dummy_tensor)

# save to .pt
torch.jit.save(model, 'bert_int8_jit.pt')
```
122+
123+
#### ResNet50
124+
125+
```
import torch
import torch.fx.experimental.optimization as optimization
import intel_extension_for_pytorch as ipex
import torchvision.models as models

# load the model
model = models.resnet50(pretrained=True)
model = model.eval()
model = optimization.fuse(model)

# define dummy input tensor to use for the model's forward call to record operations in the model for tracing
N, C, H, W = 1, 3, 224, 224
dummy_tensor = torch.randn(N, C, H, W).contiguous(memory_format=torch.channels_last)

# calibration
# ipex supports two quantization schemes to be used for activation: torch.per_tensor_affine and torch.per_tensor_symmetric
# default qscheme is torch.per_tensor_affine
conf = ipex.quantization.QuantConf(qscheme=torch.per_tensor_symmetric)
n_iter = 100
with torch.no_grad():
    for i in range(n_iter):
        with ipex.quantization.calibrate(conf):
            model(dummy_tensor)

# optionally save the configuration for later use
# save:
# conf.save("model_conf.json")
# load:
# conf = ipex.quantization.QuantConf("model_conf.json")

# conversion
jit_inputs = (dummy_tensor)
model = ipex.quantization.convert(model, conf, jit_inputs)

# enable the fusion path (need to run two iterations)
with torch.no_grad():
    y = model(dummy_tensor)
    y = model(dummy_tensor)

# save to .pt
torch.jit.save(model, 'rn50_int8_jit.pt')
```
168+
169+
### 2. Creating a Model Archive
170+
Once the serialized file (`.pt`) is created, it can be used with `torch-model-archiver` as usual. Use the following command to package the model.
171+
```
torch-model-archiver --model-name rn50_ipex_int8 --version 1.0 --serialized-file rn50_int8_jit.pt --handler image_classifier
```
174+
### 3. Start TorchServe to serve the model
175+
Make sure to set `ipex_enable=true` in `config.properties`. Use the following command to start TorchServe with IPEX.
176+
```
torchserve --start --ncs --model-store model_store --ts-config config.properties
```
179+
180+
### 4. Registering and Deploying model
181+
Registering and deploying the model follows the same steps shown [here](https://pytorch.org/serve/use_cases.html).
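
For orientation, registering and querying the archive packaged above might look like the sketch below. It assumes TorchServe's default addresses (management API on port 8081, inference API on port 8080) and that the archive `rn50_ipex_int8.mar` sits in the model store; the helper names are hypothetical. The functions only build the requests, so nothing is sent until they are passed to `urllib.request.urlopen`.

```python
import urllib.parse
import urllib.request

# Assumed defaults; adjust if config.properties overrides the addresses.
MANAGEMENT_URL = "http://localhost:8081"
INFERENCE_URL = "http://localhost:8080"


def register_request(mar_file, initial_workers=1):
    """Build the POST request that registers a model archive from the model store."""
    query = urllib.parse.urlencode(
        {"url": mar_file, "initial_workers": initial_workers})
    return urllib.request.Request(
        f"{MANAGEMENT_URL}/models?{query}", method="POST")


def predict_request(model_name, payload):
    """Build the POST request that sends raw bytes (e.g. an image) for inference."""
    return urllib.request.Request(
        f"{INFERENCE_URL}/predictions/{model_name}", data=payload, method="POST")
```

With the server running, `urllib.request.urlopen(register_request("rn50_ipex_int8.mar"))` registers the model, and `urllib.request.urlopen(predict_request("rn50_ipex_int8", image_bytes))` returns the prediction, where `image_bytes` is the content of an image file read in binary mode.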
182+
183+
## Benchmarking with Launcher
184+
Launcher can be used with the official TorchServe [benchmark](https://github.com/pytorch/serve/tree/master/benchmarks) to launch the server and benchmark requests with optimal configuration on Intel hardware.
185+
186+
In this section we provide examples of benchmarking with launcher with its default configuration.
187+
188+
Add the following lines to `config.properties` in the benchmark directory to use launcher with its default setting.
189+
```
ipex_enable=true
cpu_launcher_enable=true
```
193+
194+
The rest of the benchmarking steps follow those shown [here](https://github.com/pytorch/serve/tree/master/benchmarks).
195+
196+
`model_log.log` contains the information and the command used for this execution launch.
197+
198+
199+
CPU usage on a machine with Intel(R) Xeon(R) Platinum 8180 CPU, 2 sockets, 28 cores per socket, 2 threads per core is shown below:
200+
![launcher_default_2sockets](https://user-images.githubusercontent.com/93151422/144373537-07787510-039d-44c4-8cfd-6afeeb64ac78.gif)
201+
202+
```
$ cat logs/model_log.log
2021-12-01 21:22:40,096 - __main__ - WARNING - Both TCMalloc and JeMalloc are not found in $CONDA_PREFIX/lib or $VIRTUAL_ENV/lib or /.local/lib/ or /usr/local/lib/ or /usr/local/lib64/ or /usr/lib or /usr/lib64 or /home/<user>/.local/lib/ so the LD_PRELOAD environment variable will not be set. This may drop the performance
2021-12-01 21:22:40,096 - __main__ - INFO - OMP_NUM_THREADS=56
2021-12-01 21:22:40,096 - __main__ - INFO - Using Intel OpenMP
2021-12-01 21:22:40,096 - __main__ - INFO - KMP_AFFINITY=granularity=fine,compact,1,0
2021-12-01 21:22:40,096 - __main__ - INFO - KMP_BLOCKTIME=1
2021-12-01 21:22:40,096 - __main__ - INFO - LD_PRELOAD=<VIRTUAL_ENV>/lib/libiomp5.so
2021-12-01 21:22:40,096 - __main__ - WARNING - Numa Aware: cores:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55] in different NUMA node
```
212+
213+
CPU usage on a machine with Intel(R) Xeon(R) Platinum 8375C CPU, 1 socket, 2 cores per socket, 2 threads per socket is shown below:
214+
![launcher_default_1socket](https://user-images.githubusercontent.com/93151422/144372993-92b2ca96-f309-41e2-a5c8-bf2143815c93.gif)
215+
216+
```
$ cat logs/model_log.log
2021-12-02 06:15:03,981 - __main__ - WARNING - Both TCMalloc and JeMalloc are not found in $CONDA_PREFIX/lib or $VIRTUAL_ENV/lib or /.local/lib/ or /usr/local/lib/ or /usr/local/lib64/ or /usr/lib or /usr/lib64 or /home/<user>/.local/lib/ so the LD_PRELOAD environment variable will not be set. This may drop the performance
2021-12-02 06:15:03,981 - __main__ - INFO - OMP_NUM_THREADS=2
2021-12-02 06:15:03,982 - __main__ - INFO - Using Intel OpenMP
2021-12-02 06:15:03,982 - __main__ - INFO - KMP_AFFINITY=granularity=fine,compact,1,0
2021-12-02 06:15:03,982 - __main__ - INFO - KMP_BLOCKTIME=1
2021-12-02 06:15:03,982 - __main__ - INFO - LD_PRELOAD=<VIRTUAL_ENV>/lib/libiomp5.so
```
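
When comparing runs, it can help to pull the environment settings out of `model_log.log` programmatically. The sketch below assumes the line layout shown in the excerpts above (`timestamp - logger - LEVEL - message`) and collects the `KEY=value` messages logged at INFO level; the function name is hypothetical, and other log layouts would need a different pattern.

```python
import re

# timestamp - logger name - level - message, as in the excerpts above.
LOG_LINE = re.compile(
    r"^(?P<time>\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2},\d{3}) - (?P<logger>\S+) - "
    r"(?P<level>[A-Z]+) - (?P<message>.*)$"
)
# Messages such as OMP_NUM_THREADS=56 or KMP_BLOCKTIME=1.
ENV_SETTING = re.compile(r"^(?P<key>[A-Z_]+)=(?P<value>.+)$")


def launcher_env(log_text):
    """Return a dict of the environment settings reported at INFO level."""
    env = {}
    for line in log_text.splitlines():
        match = LOG_LINE.match(line.strip())
        if not match or match.group("level") != "INFO":
            continue
        setting = ENV_SETTING.match(match.group("message"))
        if setting:
            env[setting.group("key")] = setting.group("value")
    return env
```

Applied to the single-socket log above, this yields the `OMP_NUM_THREADS`, `KMP_AFFINITY`, `KMP_BLOCKTIME` and `LD_PRELOAD` values, while warnings and free-text messages such as `Using Intel OpenMP` are skipped.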

frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java

Lines changed: 10 additions & 0 deletions
@@ -70,6 +70,8 @@ public final class ConfigManager {
7070

7171
// IPEX config option that can be set at config.properties
7272
private static final String TS_IPEX_ENABLE = "ipex_enable";
73+
private static final String TS_CPU_LAUNCHER_ENABLE = "cpu_launcher_enable";
74+
private static final String TS_CPU_LAUNCHER_ARGS = "cpu_launcher_args";
7375

7476
private static final String TS_ASYNC_LOGGING = "async_logging";
7577
private static final String TS_CORS_ALLOWED_ORIGIN = "cors_allowed_origin";
@@ -339,6 +341,14 @@ public boolean isMetricApiEnable() {
339341
return Boolean.parseBoolean(getProperty(TS_ENABLE_METRICS_API, "true"));
340342
}
341343

344+
public boolean isCPULauncherEnabled() {
345+
return Boolean.parseBoolean(getProperty(TS_CPU_LAUNCHER_ENABLE, "false"));
346+
}
347+
348+
public String getCPULauncherArgs() {
349+
return getProperty(TS_CPU_LAUNCHER_ARGS, null);
350+
}
351+
342352
public int getNettyThreads() {
343353
return getIntProperty(TS_NUMBER_OF_NETTY_THREADS, 0);
344354
}

frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java

Lines changed: 54 additions & 0 deletions
@@ -29,6 +29,7 @@ public class WorkerLifeCycle {
2929
private Connector connector;
3030
private ReaderThread errReader;
3131
private ReaderThread outReader;
32+
private String launcherArgs;
3233

3334
public WorkerLifeCycle(ConfigManager configManager, Model model) {
3435
this.configManager = configManager;
@@ -39,6 +40,46 @@ public Process getProcess() {
3940
return process;
4041
}
4142

43+
public ArrayList<String> launcherArgsToList() {
44+
ArrayList<String> arrlist = new ArrayList<String>();
45+
arrlist.add("-m");
46+
arrlist.add("intel_extension_for_pytorch.cpu.launch");
47+
arrlist.add("--ninstance");
48+
arrlist.add("1");
49+
if (launcherArgs != null && launcherArgs.length() > 1) {
50+
String[] argarray = launcherArgs.split(" ");
51+
for (int i = 0; i < argarray.length; i++) {
52+
arrlist.add(argarray[i]);
53+
}
54+
}
55+
return arrlist;
56+
}
57+
58+
public boolean isLauncherAvailable()
59+
throws WorkerInitializationException, InterruptedException {
60+
boolean launcherAvailable = false;
61+
try {
62+
ArrayList<String> cmd = new ArrayList<String>();
63+
cmd.add("python");
64+
ArrayList<String> args = launcherArgsToList();
65+
cmd.addAll(args);
66+
cmd.add("--no_python");
67+
// try launching dummy command to check launcher availability
68+
String dummyCmd = "hostname";
69+
cmd.add(dummyCmd);
70+
71+
String[] cmd_ = new String[cmd.size()];
72+
cmd_ = cmd.toArray(cmd_);
73+
74+
Process process = Runtime.getRuntime().exec(cmd_);
75+
int ret = process.waitFor();
76+
launcherAvailable = (ret == 0);
77+
} catch (IOException | InterruptedException e) {
78+
throw new WorkerInitializationException("Failed to start launcher", e);
79+
}
80+
return launcherAvailable;
81+
}
82+
4283
public void startWorker(int port) throws WorkerInitializationException, InterruptedException {
4384
File workingDir = new File(configManager.getModelServerHome());
4485
File modelPath;
@@ -51,6 +92,19 @@ public void startWorker(int port) throws WorkerInitializationException, Interrup
5192

5293
ArrayList<String> argl = new ArrayList<String>();
5394
argl.add(EnvironmentUtils.getPythonRunTime(model));
95+
96+
if (configManager.isCPULauncherEnabled()) {
97+
launcherArgs = configManager.getCPULauncherArgs();
98+
boolean launcherAvailable = isLauncherAvailable();
99+
if (launcherAvailable) {
100+
ArrayList<String> args = launcherArgsToList();
101+
argl.addAll(args);
102+
} else {
103+
logger.warn(
104+
"CPU launcher is enabled but launcher is not available. Proceeding without launcher.");
105+
}
106+
}
107+
54108
argl.add(new File(workingDir, "ts/model_service_worker.py").getAbsolutePath());
55109
argl.add("--sock-type");
56110
argl.add(connector.getSocketType());
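
For readers who want to reproduce the availability check outside the Java frontend, the following Python sketch runs the same probe as `isLauncherAvailable` above: it asks the launcher to run the dummy command `hostname` with `--no_python` and treats exit code 0 as "available". It is an approximation under stated assumptions: the function name is hypothetical, it uses the current interpreter rather than the literal `python` command, and it returns `False` where the Java code would raise `WorkerInitializationException`.

```python
import subprocess
import sys


def is_launcher_available(launcher_args=None, python=sys.executable):
    """Return True if the IPEX launcher can run a dummy command successfully."""
    cmd = [python, "-m", "intel_extension_for_pytorch.cpu.launch", "--ninstance", "1"]
    if launcher_args:
        cmd.extend(launcher_args.split())
    # --no_python tells the launcher to run a plain command instead of a script.
    cmd.extend(["--no_python", "hostname"])
    try:
        result = subprocess.run(
            cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    except OSError:
        # The interpreter itself could not be started.
        return False
    return result.returncode == 0
```

If this returns `False` on a machine where `cpu_launcher_enable=true` is set, the worker is started without launcher and the warning shown in the diff is logged.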

ts/torch_handler/base_handler.py

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@
2020
import intel_extension_for_pytorch as ipex
2121
ipex_enabled = True
2222
except ImportError as error:
23-
logger.warning("IPEX was not installed. Please install IPEX if wanted.")
23+
logger.warning("IPEX is enabled but intel-extension-for-pytorch is not installed. Proceeding without IPEX.")
2424

2525
class BaseHandler(abc.ABC):
2626
"""
