https://github.com/pytorch/serve#install-torchserve-and-torch-model-archiver
torch-model-archiver --model-name pulse --version 1.0 --model-file PULSE.py --serialized-file model.pt --export-path ./model_store --handler pulse_serve_handler:entry_point_function_name --extra-files PULSE.py,align_face.py,bicubic.py,drive.py,gaussian_fit.pt,loss.py,mapping.pt,SphericalOptimizer.py,shape_predictor.py,stylegan.py,synthesis.pt -f
Note the --extra-files parameter: all of the model's dependent files must be listed here, and they will be packaged into the .mar file.
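Since a .mar file is essentially a zip archive, you can sanity-check that every dependency made it in (the path below assumes the --export-path used above):

# List the packaged model, handler and extra files inside the archive
unzip -l model_store/pulse.mar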
In our case it is hard to load the dependent parameters (e.g., the mapping and synthesis networks) while serving a request, so we save the entire model up front and load it back in one piece at run time:
import torch
from PULSE import PULSE

model = PULSE(...)
# Save the whole module (not just the state_dict) so torch.load can restore it in the handler
torch.save(model, 'model.pt')
TorchServe only ships with four built-in handler types, covering basic tasks such as classification. In our case the API should return a super-resolution image to the client, so a custom handler is needed.
import os

import torch

# Create model object
model = None


def entry_point_function_name(data, context):
    """
    Works on data and context to create the model object or process an inference request.
    The following sample shows how the model object can be initialized for JIT mode;
    eager-mode models work the same way.
    :param data: Input data for prediction
    :param context: context contains model server system properties
    :return: prediction output
    """
    global model

    if not data:
        manifest = context.manifest
        properties = context.system_properties
        model_dir = properties.get("model_dir")
        device = torch.device(
            "cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu")

        # Read model serialize/pt file
        serialized_file = manifest['model']['serializedFile']
        model_pt_path = os.path.join(model_dir, serialized_file)
        if not os.path.isfile(model_pt_path):
            raise RuntimeError("Missing the model.pt file")

        model = torch.load(model_pt_path)
    else:
        ref_im = torch.randn(
            (1, 3, 1024, 1024), dtype=torch.float, requires_grad=True, device='cpu')
        kwargs = {'input_dir': './input',
                  'output_dir': './output',
                  'cache_dir': 'cache',
                  'duplicates': 1,
                  'batch_size': 1,
                  'seed': 23,
                  'loss_str': '100*L2+0.05*GEOCROSS',
                  'eps': 0.02,
                  'noise_type': 'zero',
                  'num_trainable_noise_layers': 5,
                  'tile_latent': False,
                  'bad_noise_layers': '17',
                  'opt_name': 'adam',
                  'learning_rate': 0.4,
                  'steps': 10,
                  'lr_schedule': 'linear1cycledrop',
                  'save_intermediate': True}
        (HR, LR) = next(model(ref_im, **kwargs))
        return [HR]
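Returning the raw HR tensor works, but HTTP clients generally expect image bytes. Here is a minimal sketch of encoding the result as PNG before returning it, assuming HR is a 1x3xHxW float tensor with values in [0, 1] (the helper name is ours, not part of PULSE or TorchServe):

import io

import torch
from PIL import Image


def tensor_to_png_bytes(hr: torch.Tensor) -> bytes:
    # Take the first image in the batch, clamp to [0, 1] and convert to HxWx3 uint8
    img = hr[0].clamp(0, 1).mul(255).byte().permute(1, 2, 0).cpu().numpy()
    buf = io.BytesIO()
    Image.fromarray(img).save(buf, format="PNG")
    return buf.getvalue()

The handler's else branch could then end with return [tensor_to_png_bytes(HR)] so the client receives a ready-to-save image instead of a pickled tensor.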
When we start the server, it invokes this entry point once with no data, which initializes the model instance. Later, when a request arrives, the same entry point is invoked with the request data and the else branch runs the model's forward pass.
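It is worth exercising this two-phase flow locally before packaging. A rough sketch, run from the same module as the handler with model.pt in the working directory; MockContext is a hypothetical stand-in for the TorchServe context object, mocking only the two attributes the handler reads:

class MockContext:
    # Only the fields the handler touches are mocked here
    manifest = {'model': {'serializedFile': 'model.pt'}}
    system_properties = {'model_dir': '.', 'gpu_id': 0}


ctx = MockContext()
entry_point_function_name(None, ctx)                   # first call: loads model.pt
outputs = entry_point_function_name([b'image'], ctx)   # later calls: run inference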
The server starts multiple workers for each model by default (four in our setup). This can be changed through the config.properties file:
default_workers_per_model=1
torchserve --start --ncs --model-store model_store --models pulse=pulse.mar
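Once the server is up, a client can hit the standard inference endpoint to get a prediction back; a quick example, assuming a low-resolution input image named face_lr.png:

# Send the low-resolution face and save the super-resolved response
curl -o sr.png http://127.0.0.1:8080/predictions/pulse -T face_lr.png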
In our case the server needs to return the generated SR image to the client. However, the server limits the size of a response, and the following error occurs:
io.netty.handler.codec.CorruptedFrameException: Message size exceed limit: 12583737
at org.pytorch.serve.util.codec.CodecUtils.readLength(CodecUtils.java:24)
at org.pytorch.serve.util.codec.ModelResponseDecoder.decode(ModelResponseDecoder.java:72)
at io.netty.handler.codec.ByteToMessageDecoder.decodeRemovalReentryProtection(ByteToMessageDecoder.java:501)
at io.netty.handler.codec.ByteToMessageDecoder.callDecode(ByteToMessageDecoder.java:440)
at io.netty.handler.codec.ByteToMessageDecoder.channelRead(ByteToMessageDecoder.java:276)
at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:379)
at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:365)
at io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:357)
at io.netty.channel.DefaultChannelPipeline$HeadContext.channelRead(DefaultChannelPipeline.java:1410)
at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:379)
at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:365)
at io.netty.channel.DefaultChannelPipeline.fireChannelRead(DefaultChannelPipeline.java:919)
at io.netty.channel.kqueue.AbstractKQueueStreamChannel$KQueueStreamUnsafe.readReady(AbstractKQueueStreamChannel.java:544)
at io.netty.channel.kqueue.KQueueDomainSocketChannel$KQueueDomainUnsafe.readReady(KQueueDomainSocketChannel.java:131)
at io.netty.channel.kqueue.AbstractKQueueChannel$AbstractKQueueUnsafe.readReady(AbstractKQueueChannel.java:382)
at io.netty.channel.kqueue.KQueueEventLoop.processReady(KQueueEventLoop.java:211)
at io.netty.channel.kqueue.KQueueEventLoop.run(KQueueEventLoop.java:289)
at io.netty.util.concurrent.SingleThreadEventExecutor$4.run(SingleThreadEventExecutor.java:989)
at io.netty.util.internal.ThreadExecutorMap$2.run(ThreadExecutorMap.java:74)
at io.netty.util.concurrent.FastThreadLocalRunnable.run(FastThreadLocalRunnable.java:30)
at java.base/java.lang.Thread.run(Thread.java:829)
TorchServe uses Netty to serve requests, and CodecUtils.java fixes the maximum size of a message, so responses larger than that limit are rejected. For our use case a traditional server may therefore be more suitable.
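To make that concrete, here is a minimal sketch of such a server built with Flask; the framework choice, route name and request/response handling are our assumptions, not part of the original setup, and the optimisation arguments are simply reused from the handler above:

import io

import torch
from flask import Flask, request, send_file
from PIL import Image
from torchvision import transforms

app = Flask(__name__)
model = torch.load('model.pt')  # the same model.pt produced earlier

KWARGS = {'input_dir': './input', 'output_dir': './output', 'cache_dir': 'cache',
          'duplicates': 1, 'batch_size': 1, 'seed': 23,
          'loss_str': '100*L2+0.05*GEOCROSS', 'eps': 0.02, 'noise_type': 'zero',
          'num_trainable_noise_layers': 5, 'tile_latent': False,
          'bad_noise_layers': '17', 'opt_name': 'adam', 'learning_rate': 0.4,
          'steps': 10, 'lr_schedule': 'linear1cycledrop', 'save_intermediate': True}


@app.route('/predictions/pulse', methods=['POST', 'PUT'])
def predict():
    # Decode the uploaded low-resolution image into a 1x3xHxW float tensor
    lr = Image.open(io.BytesIO(request.get_data())).convert('RGB')
    ref_im = transforms.ToTensor()(lr).unsqueeze(0)

    # Run the PULSE optimisation loop, as in the TorchServe handler above
    (HR, LR) = next(model(ref_im, **KWARGS))

    # Encode the super-resolved tensor as PNG and stream it back, free of the Netty message-size limit
    img = HR[0].clamp(0, 1).mul(255).byte().permute(1, 2, 0).cpu().numpy()
    buf = io.BytesIO()
    Image.fromarray(img).save(buf, format='PNG')
    buf.seek(0)
    return send_file(buf, mimetype='image/png')


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=8080)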