0x520 Tensorflow
1. API
use https://cs.opensource.google/tensorflow for code search
1.1. Tensor
tf.Tensor is an immutable tensor. A number of specialized tensors are available (e.g. tf.Variable, tf.constant, tf.placeholder and tf.RaggedTensor).
tf.Variable is a mutable tensor; it requires an initial value and maintains state during execution.
v = tf.Variable(1.)
v.assign(2.)
v.assign_add(0.5)
tf.Module is a named container for tf.Variables; it exposes variables, trainable_variables and submodules.
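A minimal sketch of a tf.Module (the Dense class below is illustrative, not from TF itself):
import tensorflow as tf

class Dense(tf.Module):
    def __init__(self, in_dim, out_dim, name=None):
        super().__init__(name=name)
        self.w = tf.Variable(tf.random.normal([in_dim, out_dim]), name='w')
        self.b = tf.Variable(tf.zeros([out_dim]), name='b')
    def __call__(self, x):
        return x @ self.w + self.b

layer = Dense(3, 2)
# two variables (w, b), both trainable, and no submodules
print(len(layer.variables), len(layer.trainable_variables), layer.submodules)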
1.2. Operation
tf.raw_ops is the namespace of low-level raw operations in TensorFlow; the full list is here
1.2.1. Control Flow
There are a few levels of TF APIs for control flow:
- high-level API: tf.map_fn, tf.case, ...
- Base API: tf.cond, tf.while_loop, ...
The gradient of tf.cond(pred, fn1, fn2) is computed as tf.cond(pred, grad(fn1), grad(fn2)).
The gradient of tf.while_loop(cond_fn, body_fn, loop_vars) is computed as another while loop, roughly tf.while_loop(lambda i, g_vars: i < N, lambda i, g_vars: (i + 1, grad(body_fn)(g_vars)), (0, grad_ys)), where \(N\) is the iteration count recorded by the forward while loop.
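A minimal sketch of the base APIs under GradientTape in eager mode (the values in the comments are worked out by hand):
x = tf.constant(2.)
with tf.GradientTape(persistent=True) as tape:
    tape.watch(x)
    y = tf.cond(x > 0., lambda: x * x, lambda: -x)   # gradient follows the taken branch
    _, z = tf.while_loop(lambda i, acc: i < 3,
                         lambda i, acc: (i + 1, acc * x),
                         [0, tf.constant(1.)])       # z = x**3 after 3 iterations
print(tape.gradient(y, x))  # 4.0  (= 2x)
print(tape.gradient(z, x))  # 12.0 (= 3x**2)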
In tf1, the base APIs were lowered to low-level primitives: Enter, Switch, Exit, Merge, NextIteration. Switch and Merge manage the start and end of conditional execution; dead tensors flow through the untaken branch as well. Enter, Exit and NextIteration manage frames for while loops, one frame per loop iteration, which allows ops in the loop body to run multiple times.
See this paper and this whitepaper for the implementation of the control-flow ops.
In tf2, these are replaced with the functional ops If and While, but they still get lowered (i.e. inlined) to the tf1 primitives before execution, at least at the time of this video.
Inside tf.function, Python control flow is converted to the base API by AutoGraph.
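A quick way to see what AutoGraph generates (a sketch; the printed code varies by TF version):
@tf.function
def clipped_double(x):
    if x > 10.:          # this Python `if` becomes tf.cond in the traced graph
        x = tf.constant(10.)
    return x * 2.

# show the Python code AutoGraph generates before tracing
print(tf.autograph.to_code(clipped_double.python_function))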
1.3. Resource
A resource represents a stateful object in TensorFlow (with DT_RESOURCE dtype). For example, tf.Variable is a resource.
Marking statefulness will enforce:
- no constant folding (because the variable might be updated)
- no common subexpression elimination (otherwise all layers would be initialized with the same random values)
A tf.Variable is a handle (i.e. a pointer) to some underlying tensor. Reading/writing through the handle inserts ReadVariableOp and AssignVariableOp operations.
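A quick check (a sketch; the exact op list may differ slightly across versions):
v = tf.Variable(1.)

@tf.function
def update_and_read():
    v.assign(2.)     # inserts an AssignVariableOp
    return v + 1.    # inserts a ReadVariableOp

ops = update_and_read.get_concrete_function().graph.get_operations()
print([op.type for op in ops])  # expect AssignVariableOp and ReadVariableOp among them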
2. Trace
A function decorated with @tf.function is traced into a tf.Graph when it is executed (not when it is defined). Shape inference also happens during tracing, and the traced graph is cached.
Importantly, tf.Graph-level tracing allows unknown shapes: feeding something like tf.TensorSpec([1, None]) to input_signature prevents TF from retracing the graph when specialized shapes are fed.
2.1. PolymorphicFunction
tf.Graph is not built by default in eager mode; tf2 builds a graph by decorating a function with tf.function and tracing it, see tf.function doc and PolymorphicFunction doc
Basically, tf.function creates a PolymorphicFunction (or GenericFunction), which can encapsulate several tf.Graphs; decorating with jit_compile=True in addition will trigger XLA compilation.
tf.function wraps a Python function and returns a PolymorphicFunction object. It manages a set of ConcreteFunctions and automatically picks the right one for your inputs. The interface can be forced with input_signature using tf.TensorSpec:
# f is a PolymorphicFunction
@tf.function
def f(x):
    return x + 1

# the input_signature can be constrained with tf.TensorSpec
@tf.function(input_signature=[tf.TensorSpec([1, None])])
def constrained_foo(t):
    print("tracing...")
    return t
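Calling it with different second dimensions reuses the same trace (assuming the definition above):
constrained_foo(tf.constant([[1.]]))          # prints "tracing..." on the first call
constrained_foo(tf.constant([[1., 2., 3.]]))  # no retracing: shape [1, 3] matches [1, None]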
2.2. ConcreteFunction
When the GenericFunction is fed with TensorSpec arguments, it traces to create a tf.Graph and wraps it in a ConcreteFunction, also known as a trace.
# f.get_concrete_function(TensorSpec()) returns a ConcreteFunction
concrete_function = f.get_concrete_function(tf.TensorSpec((1,), dtype=tf.int32))
# if a specific tensor is passed, it is abstracted into a TensorSpec by ignoring its values
concrete_function = f.get_concrete_function(tf.constant([1.]))
The signature of a concrete function can be accessed via concrete_function.inputs and concrete_function.outputs. Both of them are SymbolicTensors.
# inputs/outputs of concrete_function are SymbolicTensor (i.e. not eager tensor)
assert tf.is_symbolic_tensor(concrete_function.inputs[0])
assert tf.is_symbolic_tensor(concrete_function.outputs[0])
A ConcreteFunction manages two elements:
- AtomicFunction: further contains the actual FuncGraph
- captured inputs
print(concrete_function.function_def)
# signature {
#   name: "__inference_f_51110"
#   input_arg {
#     name: "x"
#     type: DT_INT32
#   }
#   output_arg {
#     name: "identity"
#     type: DT_INT32
#   }
# }
# node_def {
#   name: "add/y"
#   op: "Const"
#   attr {
#     key: "dtype"
#     value {
#       type: DT_INT32
#     }
#   }
#   attr {
#     key: "value"
#     value {
#       tensor {
#         dtype: DT_INT32
#         tensor_shape {
#         }
#         int_val: 1
#       }
#     }
#   }
# }...
2.2.1. AtomicFunction (FuncGraph)
AtomicFunction wraps a FuncGraph in its cached graph, which is the actual graph inside the ConcreteFunction. The inputs to an atomic function are typically the user inputs plus the captured inputs (e.g. variables).
AtomicFunction is callable and can be accessed via inference_fn:
# AtomicFunction
concrete_function.inference_fn
# FuncGraph
graph = concrete_function.inference_fn.graph
# or simply
graph = concrete_function.graph
print(graph.as_graph_def())
# node {
#   name: "x"
#   op: "Placeholder"
#   attr {
#     key: "_user_specified_name"
#     value {
#       s: "x"
#     }
#   }
#   attr {
#     key: "dtype"
#     value {
#       type: DT_INT32
#     }
#   }
#   attr {
#     key: "shape"
#     value {
#       shape {
#       }
#     }
#   }
# }
print(graph.operations)
# [<tf.Operation 'x' type=Placeholder>,
# <tf.Operation 'add/y' type=Const>,
# <tf.Operation 'add' type=AddV2>,
# <tf.Operation 'Identity' type=Identity>]
2.2.2. Captured Objects
Many objects are captured during tracing. concrete_function.captured_inputs are implicitly passed to the function as extra arguments on every call after tracing.
See the following example:
a = tf.Variable([3.14, 3.14]) # captured as ResourceHandle (i.e just pointer)
b = tf.constant(1.0) # captured as actual tensor
c = 2. # built into graph's Add node attribute directly
@tf.function
def f(x):
    print('tracing...')
    d = tf.constant(3.)  # built into graph's Const node
    return x + a + b + c + d
#<tf.Tensor: shape=(2,), dtype=float32, numpy=array([9.14, 9.14], dtype=float32)>
f(0.) # trigger tracing, 9.14 = 3.14 + 1 + 2 + 3
# [<tf.Tensor: shape=(), dtype=resource, value=<ResourceHandle(name="Variable/14", device="/job:localhost/replica:0/task:0/device:CPU:0", container="Anonymous", type="tensorflow::Var", dtype and shapes : "[ DType enum: 1, Shape: [2] ]")>>,
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>]
# a, b are captured as inputs here
f.get_concrete_function(0.).captured_inputs
# rebinding a, b, c to new Python objects will not cause retracing or affect the results, for example
b = tf.constant(2.0)
f(0.) # no re-tracing, result are still 9.14
f.get_concrete_function(0.).captured_inputs # 2nd captured input is still 1.0
# assign can change the result without retracing, because the captured input for a variable is a ResourceHandle (pointer), not the actual values
a.assign([4.14, 4.14])
f(0.) # no-retracing, 10.14 = 4.14 + 1 + 2 + 3
3. Transform
3.1. Pre-placement
If/While ops are lowered into low-level primitives (e.g. Switch, Merge)
3.2. Cluster
Ops that can be compiled are auto-clustered together (in tf2, only code inside tf.function is clustered).
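A minimal sketch: in tf2 the common way to get a compiled cluster is jit_compile=True on a tf.function; auto-clustering of un-annotated graph code can also be enabled globally with TF_XLA_FLAGS=--tf_xla_auto_jit=2.
# the whole body of this tf.function is compiled as one XLA cluster
@tf.function(jit_compile=True)
def fused(x, y):
    return tf.nn.relu(x @ y + 1.)

fused(tf.ones([2, 3]), tf.ones([3, 4]))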
3.3. Placer
3.4. Grappler
3.5. Partitioner
The partitioner splits the graph across devices and inserts communication primitives (e.g. Send/Recv) to transfer data.
Actual communication is implemented through the rendezvous interface. See this blog
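A minimal sketch of what the partitioner sees (assuming a second logical device has been configured; Send/Recv pairs are inserted where an edge crosses the device boundary):
# assumes a second logical CPU device was created, e.g. via
# tf.config.set_logical_device_configuration on the physical CPU
@tf.function
def cross_device(x):
    with tf.device('/CPU:0'):
        a = x * 2.
    with tf.device('/CPU:1'):   # the edge a -> b crosses devices, so Send/Recv is inserted
        b = a + 1.
    return b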
4. Compile
The execution path is roughly as follows; watch this youtube video for an introduction.
4.1. Lower
Each traced ConcreteFunction (tf.Graph) is lowered to HLO again whenever the shape changes, because XLA needs static shapes while a ConcreteFunction allows dynamic shapes (i.e. None).
Note that lowering/compiling can happen even without retracing; see the following example:
# dump XLA artifacts for debugging
# note that this line has to be added at the beginning of code
os.environ['XLA_FLAGS'] = '--xla_dump_to=/tmp/dump'
@tf.function(input_signature=[tf.TensorSpec([1, None])], jit_compile=True)
def f(t):
print("tracing...")
return tf.sin(t)
# tracing...
# HloModule a_inference_f_2574__.7, entry_computation_layout={(f32[1,1]{1,0})->f32[1,1]{1,0}}...
f.experimental_get_compiler_ir(tf.constant([[1.,]]))(stage='hlo')
# no retracing is triggered, but compilation is; notice that the shape of the argument has changed
# HloModule a_inference_f_2574__.7, entry_computation_layout={(f32[1,2]{1,0})->f32[1,2]{1,0}}...
f.experimental_get_compiler_ir(tf.constant([[1.,2]]))(stage='hlo')
Compiled results are also cached.
4.2. Optimize
The optimized results after the HLO passes can be observed by diffing the before_optimizations.txt and after_optimizations.txt dumps, or by checking:
# before optimization
f.experimental_get_compiler_ir(tf.constant([[1.,2]]))(stage='hlo')
# after optimization
f.experimental_get_compiler_ir(tf.constant([[1.,2]]))(stage='optimized_hlo')
4.3. Compile
When xla_dump_to is enabled, it generates a few files:
module_0002.a_inference_f_6__.7.before_optimizations.txt
module_0002.a_inference_f_6__.7.cpu_after_optimizations-buffer-assignment.txt
module_0002.a_inference_f_6__.7.cpu_after_optimizations.txt
module_0002.a_inference_f_6__.7.ir-no-opt.ll
module_0002.a_inference_f_6__.7.ir-with-opt.ll
module_0002.a_inference_f_6__.7.o
The .ll files contain LLVM IR; for example, my module_0002.a_inference_f_6__.7.ir-with-opt.ll has the following content:
; ModuleID = '__compute_module'
source_filename = "__compute_module"
target datalayout = ""
target triple = ""
; Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(readwrite, inaccessiblemem: none) uwtable
define void @a_inference_f_6__.7(ptr nocapture readnone %retval, ptr noalias nocapture readnone %run_options, ptr noalias nocapture readnone %params, ptr noalias nocapture readonly %buffer_table, ptr noalias nocapture readnone %status, ptr noalias nocapture readnone %prof_counters) local_unnamed_addr #0 {
entry:
%0 = getelementptr inbounds i8, ptr %buffer_table, i64 8
%arg0.1 = load ptr, ptr %0, align 8, !invariant.load !0, !dereferenceable !1, !align !1
%sine.3 = load ptr, ptr %buffer_table, align 8, !invariant.load !0, !dereferenceable !1, !align !1
%1 = load float, ptr %arg0.1, align 4, !invariant.load !0, !noalias !2
%2 = tail call float @llvm.sin.f32(float %1)
store float %2, ptr %sine.3, align 4, !alias.scope !2
ret void
}
; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare float @llvm.sin.f32(float) #1
attributes #0 = { mustprogress nofree norecurse nosync nounwind willreturn memory(readwrite, inaccessiblemem: none) uwtable "denormal-fp-math"="preserve-sign" "no-frame-pointer-elim"="false" }
attributes #1 = { mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) }
!0 = !{}
!1 = !{i64 4}
!2 = !{!3}
!3 = !{!"buffer: {index:0, offset:0, size:4}", !4}
!4 = !{!"XLA global AA domain"}
The .o file is, of course, the compiled object file.
5. Distribution
All of TensorFlow's visible devices can be retrieved with:
# this typically includes CPU as well
logging.info("All TF devices: %s", tf.config.list_logical_devices())
5.1. GShard (xla sharding)
Implemented as xla_sharding in tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py.
This is probably the lowest-level sharding API for annotating tensors.
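A rough sketch of the annotation style in that module (hedged: the exact function names and signatures may differ across versions):
from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding

@tf.function(jit_compile=True)
def f(x):
    x = xla_sharding.split(x, 0, 2)  # annotate x to be split along dimension 0 across 2 devices
    return x + 1.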
5.2. tf.distribute
TPUStrategy is useful for creating a TPU SPMD runtime:
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
device_assignment = tf.tpu.experimental.DeviceAssignment.build(
    topology,
    computation_shape=[2, 2, 1, 1],  # == topology.mesh_shape
    num_replicas=1)
strategy = tf.distribute.TPUStrategy(
    resolver, experimental_device_assignment=device_assignment)
strategy.run(concrete_function, (arg_0, arg_1))
5.3. DTensor
The latest SPMD features are exposed through DTensor, a low-level SPMD API for TensorFlow. See this DTensor tutorial and distributed training with DTensors
from tensorflow.experimental import dtensor
tf_tpu_devices = [f'TPU:{i}' for i in range(4)]
mesh_2d = dtensor.create_mesh([('x', 4)], devices=tf_tpu_devices)
layout = dtensor.Layout(['x', dtensor.UNSHARDED], mesh_2d)
def dtensor_from_array(arr, layout, shape=None, dtype=None):
    """Convert an array-like object or Tensor into a DTensor.

    This function is convenient for quickly doodling DTensors from a known,
    unsharded data object in a single-client environment. This is not the
    most efficient way of creating a DTensor, but it will do for this
    tutorial.
    """
    if shape is not None or dtype is not None:
        arr = tf.constant(arr, shape=shape, dtype=dtype)
    # replicate the input onto the mesh
    a = dtensor.copy_to_mesh(
        arr, layout=dtensor.Layout.replicated(layout.mesh, rank=layout.rank))
    # reshard the copy into the desired layout
    return dtensor.relayout(a, layout=layout)
sharded_tf_x = dtensor_from_array(tf_x, layout)
6. Serialization
There are two related concepts:
- checkpoint: contains only the values of parameters; no graph is included. It is typically only useful when the source code that will use the saved parameter values is available.
- SavedModel: includes a serialized description of the computation at the tf.Graph level in addition to the parameter values (checkpoint).
6.1. SavedModel
This section deals with SavedModel. In the simplest case, we can export a single concrete function:
# save a tf.function with an explicit input_signature
@tf.function(input_signature=[tf.TensorSpec([2, ], tf.float32)], jit_compile=True)
def f(x):
    return x + 1

concrete_f = f.get_concrete_function()
tf.saved_model.save(concrete_f, "/tmp/saved_model_f")
loaded_f = tf.saved_model.load("/tmp/saved_model_f")
loaded_f(tf.constant([1., 2.]))
# save a tf.function without an explicit input_signature; the concrete function is derived from example inputs
@tf.function(jit_compile=True)
def f(x):
    return x + 1

concrete_f = f.get_concrete_function(tf.constant([1., 2.]))
tf.saved_model.save(concrete_f, "/tmp/saved_model_f")
A SavedModel usually saves a Trackable object (typically a tf.Module). SavedModel will save:
- signature (alias of ConcreteFunction): tf.functions decorated with an explicit input_signature are saved, or the target signatures should be passed as an argument.
- checkpoint (recursive attributes of trackable objects): the saved tensor variables
# save a tf.Module
class Adder(tf.Module):
    def __init__(self):
        self.weight = tf.Variable(3.14, dtype=tf.float32)

    @tf.function(input_signature=[tf.TensorSpec(shape=(), dtype=tf.float32)])
    def __call__(self, x):
        return x + self.weight

adder = Adder()
tf.saved_model.save(adder, "/tmp/adder")
This will create the following files:
- saved_model.pb: a MetaGraphDef protobuf file, which contains GraphDef as a child. Typically a small file.
- variables: the checkpoint directory; it contains files such as variables.index and variables.data-00000-of-00001. See the next section.
To load a SavedModel and use its concrete functions:
f = tf.saved_model.load("/tmp/adder")
# use its signature (concrete function)
saved_concrete_function = f.signatures['serving_default']
# <ConcreteFunction (*, x: TensorSpec(shape=(), dtype=tf.float32, name='x')) -> Dict[['output_0', TensorSpec(shape=(), dtype=tf.float32, name='output_0')]] at ...>
print(saved_concrete_function)
# {'output_0': <tf.Tensor: shape=(), dtype=float32, numpy=4.1400003>}
saved_concrete_function(tf.constant(1.0))
# error: concrete function cannot run over incompatible input_signature
saved_concrete_function(tf.constant(1))
To inspect nodes of a graph from a signature
graph = f.signatures['serving_default'].graph
for node in graph.as_graph_def().node:
    print(node)
Note that the graph before saving and the graph after loading are not identical. Compare the following two graph_defs: the saved graph typically encapsulates most of the computation into a single raw_ops.PartitionedCall referring to some __inference___call___ function in the function library. See its doc
class Adder(tf.Module):
    @tf.function(input_signature=[tf.TensorSpec(shape=(), dtype=tf.float32)])
    def __call__(self, x):
        return x + 3.14
# before saving
adder = Adder()
adder.__call__.get_concrete_function(tf.constant(1.0)).graph.as_graph_def()
# after saving
tf.saved_model.save(adder, "/tmp/adder")
saved_adder = tf.saved_model.load("/tmp/adder")
saved_adder_concrete_function = saved_adder.signatures['serving_default']
saved_adder_concrete_function.graph.as_graph_def()
Marking jit_compile=True produces the same graph before saving, but attaches an _XlaMustCompile=True attribute to the PartitionedCall during saving. This probably triggers JIT compilation after loading.
6.2. MetaGraph
In a tf2 SavedModel, the graph is stored in the SavedModel format; as defined in saved_model.proto, the graph is serialized using the MetaGraphDef protobuf (which contains GraphDef as a child).
message SavedModel {
  // The schema version of the SavedModel instance. Used for versioning when
  // making future changes to the specification/implementation. Initial value
  // at release will be 1.
  int64 saved_model_schema_version = 1;

  // One or more MetaGraphs.
  repeated MetaGraphDef meta_graphs = 2;
}
MetaGraphDef proto is defined as follows:
message MetaGraphDef {
  // Meta information regarding the graph to be exported. To be used by users
  MetaInfoDef meta_info_def = 1;

  // GraphDef.
  GraphDef graph_def = 2;

  // signature_def: Map from user supplied key for a signature to a single
  // SignatureDef.
  map<string, SignatureDef> signature_def = 5;

  // Extra information about the structure of functions and stateful objects.
  SavedObjectGraph object_graph_def = 7;
}
GraphDef encapsulates the tf.Graph-level information described above.
MetaInfoDef contains the following information:
- stripped_op_list: a copy of the OpDefs used by the producer of this graph_def
- tags: e.g. gpu, serve
- function_aliases: FunctionDef name mapping (note this is saved using tf.saved_model.SaveOptions)
SavedObjectGraph is used to reconstruct the object structure.
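A minimal sketch of reading saved_model.pb directly with these protos (the path assumes the Adder example saved to /tmp/adder above):
from tensorflow.core.protobuf import saved_model_pb2

sm = saved_model_pb2.SavedModel()
with open("/tmp/adder/saved_model.pb", "rb") as fp:
    sm.ParseFromString(fp.read())

meta_graph = sm.meta_graphs[0]                 # MetaGraphDef
print(meta_graph.meta_info_def.tags)           # e.g. ['serve']
print(len(meta_graph.graph_def.node))          # nodes of the outer GraphDef
print(list(meta_graph.signature_def.keys()))   # e.g. ['serving_default', ...]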
6.3. GraphDef
A tf.Graph
is the raw, language-agnostic, portable representation of a TensorFlow computation.
Its protobuf is as follows:
message GraphDef {
  repeated NodeDef node = 1;
  FunctionDefLibrary library = 2;
  VersionDef versions = 4;
}
FunctionDefLibrary
contains repeated FunctionDef
message FunctionDefLibrary {
  repeated FunctionDef function = 1;
  repeated GradientDef gradient = 2;
  repeated RegisteredGradient registered_gradients = 3;
}
message FunctionDef {
  OpDef signature = 1;
  map<string, AttrValue> attr = 5;
  map<uint32, ArgAttrs> arg_attr = 7;
  map<uint32, uint32> resource_arg_unique_id = 8;
  repeated NodeDef node_def = 3;
  map<string, string> ret = 4;
  map<string, string> control_ret = 6;
}
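Continuing the Adder example from section 6.1, the __inference___call___ function referenced by the outer PartitionedCall lives in this function library (the exact generated names differ per run):
graph_def = saved_adder_concrete_function.graph.as_graph_def()
for fn in graph_def.library.function:
    print(fn.signature.name)               # e.g. __inference___call___123
    print([n.op for n in fn.node_def])     # e.g. ['AddV2', 'Identity']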
6.3.1. Node
Both GraphDef and FunctionDef are defined in terms of NodeDef:
message NodeDef {
  string name = 1;
  string op = 2;
  repeated string input = 3;
  string device = 4;
  map<string, AttrValue> attr = 5;

  message ExperimentalDebugInfo {
    repeated string original_node_names = 1;
    repeated string original_func_names = 2;
  }

  // This stores debug information associated with the node.
  ExperimentalDebugInfo experimental_debug_info = 6;

  FullTypeDef experimental_type = 7;
}
A node represents a tf.Operation, which takes zero or more Tensor objects as input and produces zero or more Tensor objects as output.
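Each tf.Operation carries one NodeDef; reusing the FuncGraph `graph` from section 2.2.1:
op = graph.get_operations()[0]
print(op.node_def)    # prints the name, op, input and attr fields defined above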
6.4. SavedObjectGraph
SavedObjectGraph is used to reconstruct the object structure.
message SavedObjectGraph {
  // Flattened list of objects in the object graph.
  //
  // The position of the object in this list indicates its id.
  // Nodes[0] is considered the root node.
  repeated SavedObject nodes = 1;

  // Information about captures and output structures in concrete functions.
  // Referenced from SavedBareConcreteFunction and SavedFunction.
  map<string, SavedConcreteFunction> concrete_functions = 2;
}
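Reusing the parsed `meta_graph` from section 6.2, the flattened object list can be walked like this (a sketch; field names follow saved_object_graph.proto):
for i, obj in enumerate(meta_graph.object_graph_def.nodes):
    print(i, obj.WhichOneof("kind"))   # e.g. user_object, variable, function, ...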
6.5. Checkpoint
A checkpoint is a serialization of the weight container (e.g. a tf.Module); it typically consists of two files (ref: stackoverflow):
- variables.index: a string-string immutable table (tensorflow::table::Table). Each key is the name of a tensor and its value is a serialized BundleEntryProto. Each BundleEntryProto describes the metadata of a tensor: which "data" file contains the content of the tensor, the offset into that file, checksum, some auxiliary data, etc.
- variables.data-00000-of-00001: a TensorBundle collection that saves the values of all variables.
To inspect variables (ckpt) from a SavedModel
f = tf.saved_model.load("/tmp/adder")
# trackable attributes are accessible after loading
# this info is stored in MetaGraphDef.object_graph_def
# <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=3.14>
f.weight
# it is also accessible from signature.variables
signature = f.signatures['serving_default']
# (<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=3.14>,)
print(signature.variables)
# variables can also be loaded with load_checkpoint
reader = tf.train.load_checkpoint("/tmp/adder/variables/variables")
# {'_CHECKPOINTABLE_OBJECT_GRAPH': tf.string, 'weight/.ATTRIBUTES/VARIABLE_VALUE': tf.float32}
dtype_from_key = reader.get_variable_to_dtype_map()
shape_from_key = reader.get_variable_to_shape_map()
key = 'weight/.ATTRIBUTES/VARIABLE_VALUE'
print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
reader.get_tensor(key)
Note that only Trackable objects and attributes are saved in the checkpoint; non-trackable tf objects (e.g. tf.constant) are saved in concrete_function.captured_inputs, and native Python values are built directly into the graph.
It is also possible to save/restore a checkpoint without using SavedModel
model = tf.keras.Model(...)
checkpoint = tf.train.Checkpoint(model)
# Save a checkpoint to /tmp/training_checkpoints-{save_counter}. Every time
# checkpoint.save is called, the save counter is increased.
save_path = checkpoint.save('/tmp/training_checkpoints')
# Restore the checkpointed values to the `model` object.
checkpoint.restore(save_path)
6.5.1. Inference Converter
The Inference Converter takes an exported SavedModel and performs the following steps:
For TPU models, it adds a TPUPartitionedCall, which wraps a few TPU-specific ops (e.g. TPUReplicateMetadata, TPUCompilationResult) and the original inference function. This makes the function servable on the TPU.
6.6. TFlite
Unlike SavedModel, which uses protocol buffers, TFLite is a serialization format based on FlatBuffers.
It has a very similar structure to GraphDef; see its schema starting from table Model.
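A minimal sketch of producing the FlatBuffer for the Adder SavedModel above:
converter = tf.lite.TFLiteConverter.from_saved_model("/tmp/adder")
tflite_model = converter.convert()            # bytes of the FlatBuffer (table Model)
with open("/tmp/adder.tflite", "wb") as fp:
    fp.write(tflite_model)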
7. Runtime
There are a few runtimes:
- Executor: the traditional runtime
- TF-XLA: XLA runtime used together with Executor
- TFRT: the new runtime based on MLIR (but it seems not fully adopted yet)
- TFLite: interpreter for mobile devices
- Tensorflow.js: interpreter for Javascript
7.1. Memory Management
Check this blog for allocator implementation