making DeepCreamPy 2.6x faster with ONNX and model surgery
DeepCreamPy is a tool that uses a deep neural network to remove censors from explicit anime images. It is able to fill in parts of images that are missing (also known as inpainting) and undo mosaic censoring (where parts of images have had their resolution reduced).
In this example image, areas colored bright green are known as bar censors, and these areas can be plausibly filled in using DeepCreamPy:
DeepCreamPy's original repository has been deleted by its creator but a mirror can be found on GitHub. The links in the project's documentation to download DeepCreamPy's TensorFlow models still appear to be working.
This post discusses how I've improved the performance of DeepCreamPy with negligible impact to model accuracy. The link to the optimized project itself is at the end of the page.
performance comparison
As an overview, we will compare the performance of the original and optimized models and code. All results were obtained on a server with 8 virtual AMD cores and 16GB of RAM.
image processing time
The example image was processed[1] 4 times and the overall time was averaged. There are no perceivable differences in the overall quality of the resultant image.[2]
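The averaging can be done with a simple helper like this (a sketch; decensor is DeepCreamPy's top-level entry point, shown later in this post):
import time

def average_processing_time(decensor, image, runs=4):
    # Time the full decensoring pipeline several times and average the results.
    durations = []
    for _ in range(runs):
        start = time.perf_counter()
        decensor(image)
        durations.append(time.perf_counter() - start)
    return sum(durations) / len(durations)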
peak memory usage
Memory usage is for just one model. DeepCreamPy has separate model weights for inpainting and removing mosaic censors which means you need approximately 12GB of RAM to run both of the original models simultaneously.
On Linux systems, the historical peak memory usage for a Python process can be obtained through the following code:
import resource
usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
print("Peak memory usage in KB:", usage)
Note that this only obtains the memory usage for the process itself and does not include its descendants.
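If child processes did matter, the same call could be made with RUSAGE_CHILDREN, which reports the peak of terminated, waited-for children rather than a sum (a sketch):
import resource

# Peak resident set size of this process (reported in KB on Linux).
self_peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
# Peak resident set size among terminated, waited-for child processes.
children_peak = resource.getrusage(resource.RUSAGE_CHILDREN).ru_maxrss
print("Peak memory usage in KB:", self_peak, "(self)", children_peak, "(children)")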
model size on disk
As above, this is only the disk usage for one model.
model load time
As above, this is only the time it takes to load one model.
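Load time can be measured by timing session construction; a minimal sketch for the ONNX model (the filename is just an example):
import time
import onnxruntime

start = time.perf_counter()
session = onnxruntime.InferenceSession("bar.onnx")
print(f"Model load time: {time.perf_counter() - start:.2f} seconds")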
converting the TensorFlow models to ONNX
TensorFlow is a machine learning framework that is often used to define and train neural networks. DeepCreamPy uses TensorFlow and so its model files are in the TensorFlow checkpoint format.
ONNX (Open Neural Network Exchange) is an open format for representing neural networks and machine learning models. Using the ONNX format is advantageous because many runtimes exist to run ONNX models on different platforms and hardware. There are also many tools to convert between the ONNX format and other model formats.
We will use tf2onnx to convert DeepCreamPy's TensorFlow models to ONNX with no loss in model accuracy. Importantly, tf2onnx also optimizes the resultant ONNX model during the conversion process.
finding the model inputs and outputs
Neural networks are represented as graphs where the nodes are neurons. Each node has inputs and outputs which accept and produce tensors (N-dimensional arrays of numbers). Every input and output in a TensorFlow graph is uniquely named to identify them. tf2onnx needs to know what the overall graph inputs and outputs are so that it can define them in the ONNX format and perform its optimizations correctly.
The easiest way to find the names of the inputs and outputs in a graph is to print them out. Inside DeepCreamPy's model.py predict method, we can see what the input and output tensors of the graph are:
def predict(self, censored, unused, mask):
    return self.sess.run(
        self.image_result,
        feed_dict={
            self.X: censored,
            self.Y: unused,
            self.MASK: mask
        }
    )
self.image_result is the output tensor, and self.X, self.Y, and self.MASK are input tensors. We print them out during a call of predict to show their names.
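A minimal sketch of these temporary print statements inside predict (they can be removed once we have the names):
def predict(self, censored, unused, mask):
    # Temporary prints to discover the tensors' unique graph names.
    print(self.image_result)
    print(self.X)
    print(self.Y)
    print(self.MASK)
    return self.sess.run(
        self.image_result,
        feed_dict={self.X: censored, self.Y: unused, self.MASK: mask}
    )
Running a prediction then prints the following tensor descriptions: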
Tensor("add:0", shape=(1, 256, 256, 3), dtype=float32)
Tensor("Placeholder:0", shape=(1, 256, 256, 3), dtype=float32)
Tensor("Placeholder_1:0", shape=(1, 256, 256, 3), dtype=float32)
Tensor("Placeholder_2:0", shape=(1, 256, 256, 3), dtype=float32)
Now that we have their names, we can invoke tf2onnx's command line tool to convert the models:
$ python3 -m tf2onnx.convert \
    --checkpoint ./models/bar/Train_775000.meta \
    --output bar.onnx \
    --outputs add:0 \
    --inputs Placeholder:0,Placeholder_1:0,Placeholder_2:0
...
... - ERROR - Tensorflow op [CB1/ExtractImagePatches_1: ExtractImagePatches] is not supported
... - ERROR - Tensorflow op [CB1/ExtractImagePatches: ExtractImagePatches] is not supported
... - ERROR - Unsupported ops: Counter({'ExtractImagePatches': 2})
...
Unfortunately, the conversion fails: tf2onnx does not support the ExtractImagePatches operation.
stubbing the ExtractImagePatches operation
tf.extract_image_patches is a TensorFlow 1 operation[3] that produces multiple (potentially overlapping) copies of regions of an image represented as a tensor.
The exact details of how it works are not important.
The error produced by tf2onnx means that this operation cannot be converted into equivalent ONNX operations.
To work around this, we mark this TensorFlow operation as a custom operation. Later, we will handle it during the execution of the ONNX model by delegating this operation directly to TensorFlow through Python. As we will learn, this strategy is bad for performance.
First we need a function that rewrites the offending operation node:
import numpy as np
from tf2onnx import utils, constants

def convert_extract_image_patches(ctx, node, _name, _args):
    node.type = "ExtractImagePatches"
    node.domain = constants.CONTRIB_OPS_DOMAIN
    ksizes = node.get_attr_value("ksizes")
    # ...
    ksizes_const = ctx.make_const(
        utils.make_name("ksizes_const"),
        np.array(ksizes, dtype=np.int32)
    )
    # ...
    # Remove all node attributes.
    for key in list(node.attr.keys()):
        del node.attr[key]
    # Add attribute values as inputs to operation.
    ctx.replace_inputs(node, node.input + [
        ksizes_const.output[0],
        # ...
    ])
TensorFlow's extract_image_patches operation accepts keyword arguments as parameters, but these are not automatically converted to inputs that we can read from during model execution. Hence, we construct new constants for keyword arguments such as ksizes and add them as inputs to our custom operation node.
Then, we invoke tf2onnx's conversion API with our ExtractImagePatches handler:
# ...
(onnx_graph, _) = tf2onnx.convert.from_graph_def(
    graph, input_names=inputs, output_names=outputs,
    custom_op_handlers={"ExtractImagePatches": (convert_extract_image_patches, [])},
    extra_opset=[utils.make_opsetid(constants.CONTRIB_OPS_DOMAIN, 1)],
)
# ...
With this, the TensorFlow models are converted without errors.
The produced ONNX models are already less than 15MB in size. One reason for this is that DeepCreamPy uses a GAN (generative adversarial network) structure for its model which consists of a generator (that produces plausible image fills) and a discriminator (that essentially decides whether the fills are accurate or not). Since the model output does not use the discriminator, this entire part of the graph has been pruned away.
The full code for this section is in this file.
running the ONNX models
To run our new ONNX models, we use the onnxruntime package:
import onnxruntime

session_options = onnxruntime.SessionOptions()
session = onnxruntime.InferenceSession("bar.onnx", session_options)

def predict(self, censored, unused, mask):
    # ...
    [output] = session.run(["add:0"], {
        "Placeholder:0": censored,
        "Placeholder_1:0": unused,
        "Placeholder_2:0": mask,
    })
    # Return our only output.
    return output
However, some adjustments need to be made for this code to work.
delegating ExtractImagePatches
To handle our stubbed ExtractImagePatches operation, we delegate its execution directly to TensorFlow. We use the onnxruntime_extensions package to register a handler for this operation:
import tensorflow as tf
from onnxruntime_extensions import onnx_op, PyCustomOpDef

tf = tf.compat.v1
tf.enable_eager_execution()

@onnx_op(op_type="ExtractImagePatches",
         inputs=[PyCustomOpDef.dt_float,
                 PyCustomOpDef.dt_int32,
                 ...],
         outputs=[PyCustomOpDef.dt_float])
def extract_image_patches(arr, ksizes, ...):
    return tf.extract_image_patches(
        arr,
        ksizes.tolist(),
        # ...
    ).numpy()
ONNX converts the inputs to the ExtractImagePatches operation into numpy arrays and passes them to our handler. We pass these inputs into TensorFlow's extract_image_patches operation and then convert the result back into a numpy array, which we return.
We also need to tell ONNX that our Python code contains custom operation handlers:
from onnxruntime_extensions import get_library_path
# ...
session_options.register_custom_ops_library(get_library_path())
With this, all the operations in the ONNX models can be correctly executed.
batching prediction requests
Unfortunately, after some trial and error, we find that running our model produces this error:
INVALID_ARGUMENT : Got invalid dimensions for input: Placeholder_2:0 for the following indices
index: 0 Got: 1 Expected: 8
Please fix either the inputs or the model.
Our model inputs expect not 1, but exactly 8 cropped image areas to be passed in. Why is this?
Unfortunately, DeepCreamPy hard-codes a batch size in the construction of its neural network:
def __init__(self, input_height=256, input_width=256, batch_size = 1, ...):
    self.batch_size = batch_size
    # ...

def build_model(self):
    # ...
    self.X = tf.placeholder(tf.float32, [self.batch_size, ...])
    # ...
The batch size controls how many input samples should be provided to the network at the same time. Increasing the batch size during model training can reduce the time it takes for the model to train by allowing the GPU to do more work at once. It is likely that batch_size was set to 8 during training.
Normally, batch_size should be set to None to allow you to train and run the model with a variable number of input samples.
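For comparison, a dynamic batch dimension in the placeholder above would look like this (a sketch; the trailing dimensions are assumed to be the 256 x 256 x 3 crops seen earlier):
# `None` makes the batch dimension dynamic, so any number of samples can be fed in.
self.X = tf.placeholder(tf.float32, [None, input_height, input_width, 3])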
The original code using TensorFlow does not suffer from this problem as it reconstructs the network graph with a batch_size of 1 from scratch. Hence, TensorFlow only loads the saved checkpoint weights into 1 part of the network instead of all 8 parts.
We could fix this by duplicating the inputs 8 times:
def predict(self, censored, unused, mask):
    censored = [censored] * 8
    unused = [unused] * 8
    mask = [mask] * 8
    # ...
    [output] = session.run(...)
    # Return the first result from our batched output.
    return output[0]
But this would be very inefficient. DeepCreamPy runs predict separately for every masked (censored) region in an image and so one image can require multiple prediction requests.
Instead, we collect all prediction requests into a list and run the model on them in groups of 8. If we don't have enough requests to make a group of 8, we add useless inputs until we do:
BATCH_SIZE = 8

def run_predictions(requests, ...):
    results = []
    for start_index in range(0, len(requests), BATCH_SIZE):
        batch = requests[start_index:start_index + BATCH_SIZE]
        batch_size = len(batch)
        # Fill to required size.
        while len(batch) < BATCH_SIZE:
            batch.append(batch[0])
        # Run model on batch.
        # ...
    return results
Now the model will always receive 8 inputs to be processed.
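The elided "run model on batch" step feeds the padded batch to the ONNX session and keeps only the outputs that correspond to real requests. A rough sketch, under the assumption that each request is a (censored, unused, mask) tuple of 256 x 256 x 3 arrays:
import numpy as np

def run_batch(session, batch, batch_size):
    # Stack the (possibly padded) batch into arrays of shape (8, 256, 256, 3).
    censored = np.stack([request[0] for request in batch])
    unused = np.stack([request[1] for request in batch])
    mask = np.stack([request[2] for request in batch])
    [output] = session.run(["add:0"], {
        "Placeholder:0": censored,
        "Placeholder_1:0": unused,
        "Placeholder_2:0": mask,
    })
    # Discard the outputs produced by the padding inputs.
    return list(output[:batch_size])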
trying it out
The code works! Loading the model takes just a few seconds and the peak memory usage has halved to less than 3GB. But now it takes over 17 seconds (3x slower) to process the example image! Next, we will explore how to improve the ONNX model and code.
The full code for this section is in this file.
unbatching the network
The example image contains 17 areas colored green, which produces 17 inputs for the ONNX model to process. However, to make a multiple of 8, we need to add 7 useless inputs. We can avoid doing useless computations by removing the need for batching from the ONNX model.
In machine learning, the term model surgery refers to the general idea of modifying a neural network after it has already been trained. We will perform model surgery to manually alter the ONNX model graph so that it accepts just 1 input instead of 8 inputs.
exploring the model graph
Netron is an online application that allows us to visualize an ONNX model's nodes and edges. We use Netron to visually identify nodes and parts of the model graph that need to be pruned and modified. If we open our ONNX model, we find a component that has been duplicated 8 times:
If we zoom in to the top of this subgraph, we can examine these components' inputs more closely:
The custom ExtractImagePatches operation outputs 8 combined tensors, each of size 30 x 30 x 2304. These tensors pass through Reshape and Transpose operations before passing through a Slice operation[4] that selects just 1 of the tensors to feed into a batch component.
If we zoom into the bottom of the overall subgraph, we see that the outputs of all 8 batch components go into a single Concat operation node:
Now we have all the information needed to remove batching.
modifying the graph
Most operations in a neural network are adaptive to their inputs. This means we can provide tensors of different input sizes to operations, and the sizes of the output tensors will automatically adjust to compensate. Consequently, we only need to make a few precise modifications to the model graph for us to remove batching.
After loading the ONNX model into Python, we define a helper function to get a node by its name:
import onnx

model = onnx.load_model("bar.onnx")

def find_node_by_name(name):
    for node in model.graph.node:
        if node.name == name:
            return node
In Netron, we can view the name of nodes by clicking on them.
pruning the inputs to Concat
Concat is an operation that combines all the inputs it receives into one output tensor. We retrieve the Concat node by its name and delete all inputs going into it except the first one[5]:
subgraph_end = find_node_by_name("CB1/concat_6")
del subgraph_end.input[1:]
Now 7 of the 8 batch components don't have their outputs being used anywhere. Later, the model optimization process will remove these unused batch components entirely.
fixing the tensor size of Reshape
We also need to modify the Reshape operation. Reshape is an operation that transforms a tensor of one size into another, as long as the total number of elements in each tensor remains the same. For example, if we have a tensor of size 2 x 2:
[[1, 2],
[3, 4]]
then Reshape can transform this into a tensor of size 1 x 4:
[[1, 2, 3, 4]]
The total number of elements in each tensor is 4.
If we click on the Reshape operation in Netron, we see that it expects to transform a tensor from ExtractImagePatches of size 8 x 30 x 30 x 2304 into a tensor of size 8 x 1 x 1 x 900 x 2304:
However, after removing batching, ExtractImagePatches will only produce a tensor of size 1 x 30 x 30 x 2304. To compensate, we need to change Reshape to produce a tensor of size 1 x 1 x 1 x 900 x 2304.
The Reshape operation retrieves the intended tensor size from an initializer node, so let's define a helper function to get an initializer by name:
def find_initializer_by_name(name):
    for initializer in model.graph.initializer:
        if initializer.name == name:
            return initializer
Initializer nodes are similar to operation nodes, but they only produce an output and do not accept inputs. They are also not explicitly shown in the graph in Netron.
We then update the tensor size stored in the initializer:
import numpy as np
from onnx.numpy_helper import from_array

reshape = find_node_by_name("CB1/Reshape_1")
reshape_const = find_initializer_by_name(reshape.input[1])
reshape_const.CopyFrom(from_array(np.int64([1, 1, 1, 900, 2304]), reshape_const.name))
Now Reshape will correctly transform its input tensor.
fixing the model input and output sizes
We also need to update the dimensions of the model inputs and outputs to accept and produce just 1 tensor instead of 8 combined tensors:
for node in list(model.graph.input) + list(model.graph.output):
    node.type.tensor_type.shape.dim[0].dim_value = 1
Without this change, onnxruntime will raise an exception when we provide an input of an unexpected size.
performing model optimization
Finally, we will manually perform model optimization so that unused nodes and operations can be pruned from the graph. We use tf2onnx's API for model optimization:
import onnx
from tf2onnx.graph import GraphUtil
# ...
model = GraphUtil.optimize_model_proto(model)
onnx.checker.check_model(model, full_check=True)
We invoke ONNX's check_model function to verify that our graph is still valid after performing our model surgery.
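Saving the modified model back to disk is then a one-liner (the output filename here is just an example):
onnx.save_model(model, "bar_unbatched.onnx")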
If we save the model and load it into Netron, we can see the graph has become a lot simpler:
trying it out
After these changes, processing the example image takes under 14 seconds, a 15% improvement from our initial ONNX model. We still have more optimizations to make though!
The full code for this section is in this file.
implementing ExtractImagePatches in ONNX natively
If we run cProfile to profile our Python code, and visualize the results using SnakeViz, we can get a better idea of where the bottlenecks are in the code:
The region highlighted in purple is the relative time spent in a function named _on_pyop_invocation. This is the function that invokes our custom ExtractImagePatches handler which delegates to TensorFlow.
However, TensorFlow is normally very fast so why is so much time being spent here?
If we dig deeper, we notice that most of the time is actually used in numpy's tolist method. When we return a numpy array from our custom operation handler, ONNX converts this array into a Python list before converting it again into its own internal tensor representation. Constructing Python lists with millions of elements can easily become very slow.
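To get a feel for the cost, here is a rough micro-benchmark of tolist on an array the size of one ExtractImagePatches output (purely illustrative, not part of the project code):
import time
import numpy as np

# 8 x 30 x 30 x 2304 is roughly 16.6 million float32 elements.
arr = np.random.rand(8, 30, 30, 2304).astype(np.float32)

start = time.perf_counter()
arr.tolist()
print(f"tolist() took {time.perf_counter() - start:.2f} seconds")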
The problem is much worse than at first glance. When running Python code, the Python interpreter's GIL (global interpreter lock) needs to be acquired. Even though onnxruntime (which runs outside of Python) has builtin support for multithreading, only one ExtractImagePatches operation can be running at any time, which means it has to wait for these operations to complete before continuing.
We will resolve these issues by implementing the ExtractImagePatches operation in ONNX natively.
rewriting ExtractImagePatches
We need to transform ExtractImagePatches into simpler operations that ONNX supports. Fortunately, there is a comment on the tensorflow-onnx repository that contains a prototype of emulating tf.extract_image_patches using a TensorFlow convolution operation. Convolutions are essential for computer vision related tasks and so naturally ONNX supports them.
The exact details of how we adapt this prototype are not important. To use it in our ONNX model, we will write a function that generates a TensorFlow graph emulating the tf.extract_image_patches function:
import tensorflow as tf
def extract_image_patches(sizes, ...):
    # ...
    @tf.function
    def function(tensor):
        # ...
        patches_simulation = tf.nn.conv2d(tensor, ...)
        # ...
        return ...
    return function
Our extract_image_patches function accepts the same keyword arguments as TensorFlow's tf.extract_image_patches function, and returns a TensorFlow graph that accepts a single tensor input. The @tf.function annotation is what compiles the nested Python function into a TensorFlow graph.
We also define a helper function that takes a tf2onnx graph and copies it into another tf2onnx graph:
from tf2onnx.graph import Graph

# Insert the graph `copy` into the graph `g` while
# linking the input `input_name` in `g` to the input of `copy`.
def insert_graph(g: Graph, copy: Graph, input_name):
    # We want to process the nodes in `copy` in a particular
    # order so that we never encounter an undefined input.
    copy.topological_sort(copy.get_nodes())
    # `new_output_names` maps names from `copy` to names in `g`.
    new_output_names = {copy.input_names[0]: input_name}
    # Copy each node in the `copy` graph.
    for node in copy.get_nodes():
        if node.type == "Placeholder":
            continue
        # For the current node, find the inputs in `g` that should be accepted.
        inputs = [new_output_names[name] for name in node.input]
        # Make a new node in `g` with the same data as `node`.
        new_node = g.make_node(node.type, inputs, attr=node.attr,
                               shapes=node.output_shapes, dtypes=node.output_dtypes)
        # Add the output of the new node in `g` into our mapping.
        new_output_names[node.output[0]] = new_node.output[0]
    # Return the name of the final output of our inserted graph in `g`.
    return new_output_names[copy.outputs[0]]
A tf2onnx graph is a data structure that tf2onnx creates during the conversion process to process and transform TensorFlow graphs. It is similar in structure to an ONNX model graph.
Then, we define a rewriter function that takes a tf2onnx graph, looks for ExtractImagePatches nodes, and rewrites them with our emulated version:
import numpy as np
import tf2onnx
from tf2onnx.graph import GraphUtil, Graph, Node
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher

def rewrite_extract_image_patches(g: Graph, ops):
    # Look for nodes of type `ExtractImagePatches`.
    pattern = OpTypePattern("ExtractImagePatches", name="extract")
    matches = GraphMatcher(pattern).match_ops(ops)
    if not matches:
        return
    for match in matches:
        # Get the `ExtractImagePatches` node.
        node: Node = match.get_op("extract")
        # Create a TensorFlow graph.
        f = extract_image_patches(
            sizes=node.get_attr_value("ksizes"),
            # ...
        )
        # The tensor size of the subgraph input.
        input_signature = [np.empty([1, 32, 32, 256], dtype=np.float32)]
        # Convert our `ExtractImagePatches` TensorFlow graph into an ONNX model.
        f_model, _ = tf2onnx.convert.from_function(f, input_signature=input_signature)
        # Create a tf2onnx graph from our ONNX model.
        f_graph = GraphUtil.create_graph_from_onnx_model(f_model)
        # Insert the tf2onnx graph into our overall graph.
        output_name = insert_graph(g, f_graph, node.input[0])
        g.replace_all_inputs(node.output[0], output_name)
        # Remove the original `ExtractImagePatches` node.
        g.remove_node(node.name)
input_signature needs to be manually specified because tf2onnx doesn't know the sizes of the inputs we are providing to our generated subgraph.
Finally, we pass our rewriter to tf2onnx's API:
# ...
onnx_graph, _ = tf2onnx.convert.from_graph_def(
    graph, input_names=inputs, output_names=outputs,
    custom_rewriter=[rewrite_extract_image_patches],
)
tf2onnx will run our rewriter on the tf2onnx graph it generates during the model conversion process.
Note that this procedure is not specific to modifying the DeepCreamPy models. You could write similar functions to replace any node of your choosing with a TensorFlow graph if you had other unsupported operations in your model.
trying it out
After these changes, processing the example image takes ~4.6 seconds, a slight improvement over the original TensorFlow model, and a huge improvement over our initial ONNX model.
You can find the full code for this section in this file.
Now that we know we can emulate ExtractImagePatches with other supported ONNX operations, we can write a more comprehensive ExtractImagePatches rewriter for use in converting any TensorFlow model. I've submitted a pull request in the tensorflow-onnx repository that does this.
other optimizations
These optimizations relate to DeepCreamPy's Python code that uses the models, rather than the models themselves.
finding connected components quickly with scipy.ndimage.measurements.label
If we examine the SnakeViz visualization, we see that another function that the code spends a disproportionate amount of time in is the find_regions function (colored in gray). This function is in DeepCreamPy's utils.py file.
find_regions returns all the connected components of an image. More specifically, it is being used to get the coordinates of all the mask areas (colored in green) in an image. For example, if we have a 4 x 4 image where G represents the color green:
GG..
G...
..GG
..GG
Then, find_regions will return two lists of coordinates (top left is (0, 0)):
[(0, 0), (1, 0), (0, 1)]
[(2, 2), (2, 3), (3, 2), (3, 3)]
DeepCreamPy's implementation is standard, but it's written in pure Python which is slow.
We can make it significantly faster by using the measurements.label function from the SciPy package, which is partially implemented in C.
measurements.label can take the above image and replace each green pixel with an index representing the connected component that it belongs to:
1100
1000
0022
0022
The rest of the pixels are replaced with the number 0. Further details on this function are available in SciPy's documentation.
Based on a StackOverflow answer, we can rewrite DeepCreamPy's find_regions function using measurements.label[6] and numpy:
import numpy as np
from scipy.ndimage import measurements

def find_regions(image, mask_color):
    pixels = np.array(image)
    array = np.all(pixels == mask_color, axis=2)
    labeled, n_components = measurements.label(array)
    indices = np.moveaxis(np.indices(array.shape), 0, -1)[:, :, [1, 0]]
    regions = []
    for index in range(1, n_components + 1):
        regions.append(indices[labeled == index].tolist())
    # Sort by largest areas first.
    regions.sort(key=len, reverse=True)
    return regions
Now this function runs near instantly.
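As a quick sanity check, we can run the new implementation on the 4 x 4 example from above. find_regions normally receives a PIL image, but np.array accepts a plain array just as well, so for illustration we pass one directly (a sketch):
import numpy as np

# The 4 x 4 example: green pixels form two connected components.
G, B = (0, 255, 0), (0, 0, 0)
image = np.array([
    [G, G, B, B],
    [G, B, B, B],
    [B, B, G, G],
    [B, B, G, G],
], dtype=np.uint8)

print(find_regions(image, (0, 255, 0)))
# The 2 x 2 component (4 pixels) is listed first, then the 3-pixel component.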
maximizing CPU utilization with Python multithreading
Even though TensorFlow and ONNX are both multithreaded, it appears that neither of them fully uses all the CPU cores during model inference. We can exploit this to further speed up our image processing.
First we need to understand how DeepCreamPy's decensor function works. It is roughly equivalent to this function written in pseudocode:
def decensor(image):
    cropped_masks = find_masks(image)
    for cropped_mask in cropped_masks:
        replacement = predict(cropped_mask, ...)
        paste_replacement_onto_image(image, replacement)
    return image
decensor is the actual function that takes an image and performs all the steps to remove censors from it. DeepCreamPy's exact implementation can be found in this file.
In the above pseudocode, predict is the function that takes an image and runs an ONNX model on it. If we had an image with 4 masks, then the time spent in Python and ONNX would look like this:
Time: | 0 5 10 15 20 25
Python: | ### # # # #
ONNX: | 1111 2222 3333 4444
The Python code has to wait for the current predict call to finish before moving onto the next one. What if we rewrite this function to collect all the prediction requests first, and then run all the predictions simultaneously?
def decensor(image):
    cropped_masks = find_masks(image)
    replacements = predict_simultaneously(cropped_masks, ...)
    for replacement in replacements:
        paste_replacement_onto_image(image, replacement)
    return image
Then, because we aren't fully utilizing the CPU with just one predict call, the time spent in Python and ONNX might look like this:
Time: | 0 5 10 15 20 25
Python: | ##### ##
ONNX: | 1111 3333
| 2222 4444
Python is single-threaded, so we haven't saved any time there. However, the running of the ONNX models can now overlap with each other, reducing the overall time it takes to process an image.
In the actual Python code, we implement the predict_simultaneously function with a ThreadPool:
from multiprocessing.pool import ThreadPool
# ...
with ThreadPool() as pool:
    results = pool.map(predict_region, regions)
# ...
pool.map runs a function on separate Python threads[7] for each element in a list. It is important to note that this only works because onnxruntime releases the GIL before running model inference, allowing the Python code to continue to execute calls to predict_region.
If it did not release the GIL, then the above code would be equivalent to:
# ...
results = []
for region in regions:
    results.append(predict_region(region))
# ...
This optimization is more important the more CPU cores you have.[8] With this change, the time it takes to process the example image is reduced from ~2.8 seconds to our final number of ~1.85 seconds.
The full code for this section is in this file.
conclusion
We set out to improve the performance of DeepCreamPy and along the way we explored a variety of techniques for optimizing and surgically manipulating neural networks and ONNX models. In the end, we significantly reduced DeepCreamPy's memory usage and cut the image processing time from ~4.85 seconds down to ~1.85 seconds which is a result that I'm happy with.
project
If you would like to run the optimized models and code yourself, you can find the link to this project here: GitHub repository. The code and models are encapsulated in a Docker web service to make it easy to install and use.
footnotes
1. The original models were run on the default version of TensorFlow, which is not compiled with AVX support. Support for AVX may improve TensorFlow's times, but would require building TensorFlow from source or using a community wheel, both of which are not ideal.
2. ONNX and TensorFlow computations are non-deterministic so there will always be minor differences between runs.
3. In TensorFlow 2, the function has been renamed to tf.image.extract_patches.
4. Two Slice operations are depicted because each component actually receives inputs from two different ExtractImagePatches operations, not just one.
5. We also need to ensure that the Slice operation feeding into the first batch component selects the first tensor from the output of ExtractImagePatches. Fortunately, this is already the case.
6. Be careful not to use this function directly as scipy.ndimage.measurements.label(...)! The nested modules in scipy are loaded dynamically which can incur significant overhead.
7. How can Python have threads if it's single-threaded? Only one Python thread can hold the GIL at any time.
8. Does this optimization work on the original TensorFlow models? I tried this but it actually increased image processing time even though the CPU was clearly being underutilized. It is unclear what the cause is. TensorFlow does release the GIL.