
@riga
Last active May 26, 2024 19:44
Accelerated Graph Evaluation with TensorFlow XLA AOT

This test runs inside a Docker container, but the steps should carry over to any environment with a working TensorFlow installation. Please also see the additional files below this readme.

1. Copy files from this gist

mkdir test_files
cd test_files
curl -L https://gist.github.com/riga/f9a18023d9f7fb647d74daa9744bb978/download -o gist.zip
unzip -j gist.zip && rm gist.zip

2. Launch the container

docker run -ti -v $PWD:/test_files tensorflow/tensorflow:2.8.0

3. Install additional software

apt-get -y update
apt-get -y install nano cmake wget

4. Compile the xla_aot_runtime library once

Note: One file is missing from the bundled runtime sources, so we fetch it manually. Without it, linking any program that uses the AOT-compiled model later on fails with undefined references to xla::CustomCallStatusGetMessage in libtf_xla_runtime.a. The missing file, custom_call_status.cc, adds exactly this symbol to the xla_aot_runtime library.

# remember the TF install path
export TF_INSTALL_PATH="/usr/local/lib/python3.8/dist-packages/tensorflow"

cd "${TF_INSTALL_PATH}/xla_aot_runtime_src"

# download the missing file
( cd tensorflow/compiler/xla/service && wget https://raw.githubusercontent.com/tensorflow/tensorflow/v2.8.0/tensorflow/compiler/xla/service/custom_call_status.cc )

# compile and create the static library libtf_xla_runtime.a
cmake .
make -j
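The TF_INSTALL_PATH exported above hard-codes the python3.8 dist-packages directory of the tensorflow/tensorflow:2.8.0 image. In other environments the package may live somewhere else. The sketch below shows one way to locate it with the standard library only: importlib.util.find_spec resolves the package without importing TensorFlow. The helper names find_tf_install_path and aot_runtime_src_dir are hypothetical and not part of TensorFlow.

```python
import importlib.util
import os


def find_tf_install_path():
    """Return the directory of the installed tensorflow package, or None if it is not installed."""
    spec = importlib.util.find_spec("tensorflow")
    if spec is None or spec.origin is None:
        return None
    # spec.origin points at tensorflow/__init__.py; its parent is the package directory
    return os.path.dirname(spec.origin)


def aot_runtime_src_dir(tf_path):
    """Return <tf_path>/xla_aot_runtime_src if that directory exists, else None."""
    if tf_path is None:
        return None
    candidate = os.path.join(tf_path, "xla_aot_runtime_src")
    return candidate if os.path.isdir(candidate) else None


if __name__ == "__main__":
    # print the path (or an empty line) so a shell can capture it
    print(find_tf_install_path() or "")
```

If this were saved as a hypothetical find_tf_path.py, you could run export TF_INSTALL_PATH="$(python find_tf_path.py)" in place of the hard-coded export. The common one-liner python -c "import tensorflow as tf, os; print(os.path.dirname(tf.__file__))" should give the same directory, but it imports TensorFlow to do so.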

5. Create the SavedModel

cd /test_files
TF_XLA_FLAGS="--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit" python create_model.py
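Before AOT compiling, it can help to confirm in Python that the SavedModel gives the expected values. The following is a minimal sketch. It assumes that the "default" signature's input is named x after the argument of my_model, since signature functions are called with keyword arguments. The helper name run_default_signature is hypothetical.

```python
def run_default_signature(model_dir="my_model"):
    """Evaluate the "default" signature on the same (2, 5) input that test_model.cc uses."""
    import tensorflow as tf  # imported lazily so the helper can be defined without TensorFlow

    loaded = tf.saved_model.load(model_dir)
    infer = loaded.signatures["default"]
    x = tf.constant([[0, 2, 4, 6, 8], [1, 3, 5, 7, 9]], dtype=tf.float32)
    # signature functions take keyword arguments named after the traced inputs (assumed: "x")
    outputs = infer(x=x)
    # outputs is a dict of named tensors; this model has exactly one output
    (result,) = outputs.values()
    return [float(v) for v in result.numpy()]
```

Calling print(run_default_signature()) from /test_files inside the container should show [20.0, 25.0], the same values that test_model checks later.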

6. XLA AOT compile via saved_model_cli

saved_model_cli aot_compile_cpu \
    --dir my_model \
    --tag_set serve \
    --signature_def_key default \
    --output_prefix my_model \
    --cpp_class MyModel

This should have created my_model.h, my_model.o, my_model_makefile.inc and my_model_metadata.o.
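To confirm that step 6 produced everything before you move on, a small standard-library sketch can look for the four files listed above. The suffix list simply mirrors those file names for --output_prefix my_model. missing_aot_outputs is a hypothetical helper.

```python
import os

# suffixes of the files listed above, relative to the --output_prefix value
EXPECTED_SUFFIXES = (".h", ".o", "_makefile.inc", "_metadata.o")


def missing_aot_outputs(prefix, directory="."):
    """Return the expected AOT output file names that are not present in directory."""
    expected = [prefix + suffix for suffix in EXPECTED_SUFFIXES]
    return [name for name in expected if not os.path.isfile(os.path.join(directory, name))]
```

For example, missing_aot_outputs("my_model") run in /test_files should return an empty list once the compilation has succeeded.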

Note: To compile for architectures other than the default (x86_64), add an LLVM-style --target_triple. Examples can be found here.

7. Compile the test_model.cc program

g++ \
    -D_GLIBCXX_USE_CXX11_ABI=0 \
    -I${TF_INSTALL_PATH}/include \
    -L${TF_INSTALL_PATH}/xla_aot_runtime_src \
    test_model.cc my_model.o \
    -o test_model \
    -lpthread -ltf_xla_runtime

or

make

8. Execute it

./test_model

You should see the output result: [20, 25]. The program exits with status 0 if both values match and 1 otherwise.
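The expected values follow directly from the model. test_model.cc writes ten floats into the flat input buffer, and that buffer is read as a (2, 5) matrix in row-major order. reduce_sum along axis 1 then adds up each row: 0 + 2 + 4 + 6 + 8 = 20 and 1 + 3 + 5 + 7 + 9 = 25. The pure-Python reference below reproduces this calculation. It assumes the row-major layout that the test program relies on. reference_reduce_sum_axis1 is a hypothetical name.

```python
def reference_reduce_sum_axis1(flat, rows, cols):
    """Sum each row of a row-major (rows, cols) buffer, mirroring tf.reduce_sum(x, axis=1)."""
    if len(flat) != rows * cols:
        raise ValueError("buffer size does not match shape")
    # element (i, j) of a row-major buffer lives at index i * cols + j
    return [sum(flat[i * cols:(i + 1) * cols]) for i in range(rows)]


# the same ten values that test_model.cc copies into arg0_data()
args = [0, 2, 4, 6, 8, 1, 3, 5, 7, 9]
print(reference_reduce_sum_axis1(args, rows=2, cols=5))  # [20, 25]
```

If the buffer were interpreted with a different shape, the sums would change. For example, a (5, 2) reading gives [2, 10, 9, 8, 16]. This is why the input layout has to match the shape fixed in create_model.py.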

create_model.py

# coding: utf-8

"""
Script that creates a dummy graph as a SavedModel named "my_model" in the same directory.
Run as:
> TF_XLA_FLAGS="--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit" python create_model.py
"""

import tensorflow as tf


@tf.function(jit_compile=True)
def my_model(x):
    """
    Dummy model that does nothing except reduce axis 1 via a sum.
    """
    return tf.reduce_sum(x, axis=1)


if __name__ == "__main__":
    import os

    this_dir = os.path.dirname(os.path.abspath(__file__))
    model_dir = os.path.join(this_dir, "my_model")

    # save the model with a concrete signature (fixed input shape (2, 5), float32)
    tf.saved_model.save(my_model, model_dir, signatures={
        "default": my_model.get_concrete_function(tf.TensorSpec(shape=[2, 5], dtype=tf.float32)),
    })

Makefile

INC = -I${TF_INSTALL_PATH}/include
LIB = -L${TF_INSTALL_PATH}/xla_aot_runtime_src
LIBS = -lpthread -ltf_xla_runtime
CXXFLAGS = -D_GLIBCXX_USE_CXX11_ABI=0

.PHONY: all clean

all: test_model

test_model: test_model.cc my_model.o
	g++ ${CXXFLAGS} ${INC} ${LIB} $^ -o $@ ${LIBS}

clean:
	rm -f test_model

test_model.cc

//
// Test program that uses the AOT compiled model
// with a fixed shape of (2, 5). Adapted from
// https://www.tensorflow.org/xla/tfcompile
//

#define EIGEN_USE_THREADS
#define EIGEN_USE_CUSTOM_THREAD_POOL

#include <algorithm>  // std::copy
#include <iostream>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "my_model.h"

int main(int argc, char** argv) {
  // threadpool settings
  Eigen::ThreadPool tp(1);
  Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());

  // instantiate the model
  MyModel myModel;
  myModel.set_thread_pool(&device);

  // define dummy inputs, laid out row by row for the (2, 5) shape
  const float args[10] = {0, 2, 4, 6, 8, 1, 3, 5, 7, 9};

  // copy to the input memory
  std::copy(args + 0, args + 10, myModel.arg0_data());

  // run it
  myModel.Run();

  // check result (myModel is doing a simple sum reduction along axis 1)
  float* result = myModel.result0_data();
  std::cout << "result: [" << *result << ", " << *(result + 1) << "]" << std::endl;

  return (*result == 20.0 && *(result + 1) == 25.0) ? 0 : 1;
}