
0x520 Tensorflow

1. API

Use https://cs.opensource.google/tensorflow for code search.

1.1. Tensor

tf.Tensor is an immutable tensor. A number of specialized tensors are available (e.g. tf.Variable, tf.constant, tf.placeholder, and tf.RaggedTensor).

tf.Variable is a mutable tensor; it requires an initial value and maintains state during execution.

v = tf.Variable(1.)
v.assign(2.)      # overwrite the value in place
v.assign_add(0.5) # in-place addition

tf.Module is a named container for tf.Variable objects; it exposes variables, trainable_variables, and submodules.
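
A minimal sketch of these attributes (the Dense module here is illustrative):

class Dense(tf.Module):
  def __init__(self):
    super().__init__()
    self.w = tf.Variable(tf.ones([2, 2]))
    self.frozen = tf.Variable(1., trainable=False)

m = Dense()
print(m.variables)            # all variables, trainable or not, collected recursively
print(m.trainable_variables)  # only the trainable ones (w)
print(m.submodules)           # nested tf.Modules (empty here)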

1.2. Operation

tf.raw_ops is the namespace of low-level raw operations in TensorFlow; see the tf.raw_ops documentation for the full list.

1.2.1. Control Flow

There are a few levels of TF APIs for control flow:

  • High-level API: tf.map_fn, tf.case, ...
  • Base API: tf.cond, tf.while_loop, ...

The gradient of tf.cond(pred, fn1, fn2) is computed as tf.cond(pred, grad(fn1), grad(fn2)).

The gradient of tf.while_loop(cond_fn, body_fn, loop_vars) is computed as another while loop, roughly tf.while_loop(lambda i, g_vars: i < N, lambda i, g_vars: (i+1, grad(body_fn)(g_vars)), (0, grad_ys)), where \(N\) is the iteration count recorded by the forward loop.
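
A quick sanity check of the tf.cond rule with GradientTape (a minimal sketch; only the taken branch contributes to the gradient):

x = tf.constant(2.)
with tf.GradientTape() as tape:
  tape.watch(x)  # x is a constant, so it must be watched explicitly
  y = tf.cond(x > 0., lambda: x * x, lambda: -x)
print(tape.gradient(y, x))  # 4.0, the gradient of the taken branch x*x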

In tf1, the base APIs were lowered to the low-level primitives Enter, Switch, Exit, Merge, and NextIteration. Switch and Merge manage the start and end of conditional execution; dead tensors flow through the untaken branch as well. Enter, Exit, and NextIteration manage frames for while loops, one frame per loop iteration, which allows ops in the loop body to run multiple times.

See this paper and this whitepaper for the implementation of the control-flow ops.

In tf2, these are replaced with the functional ops If and While, but they still get lowered (i.e. inlined) to the tf1 primitives before execution, at least at the time of this video.

Inside tf.function, Python control flow is converted to the base API by AutoGraph.
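
The generated code can be inspected with tf.autograph.to_code, applied to the wrapped Python function:

@tf.function
def g(x):
  if x > 0:
    x = x + 1
  return x

# prints the AutoGraph-transformed source; the `if` is rewritten into
# ag__.if_stmt, which dispatches to tf.cond for tensor predicates
print(tf.autograph.to_code(g.python_function))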

1.3. Resource

A resource represents a stateful object in TensorFlow (with the DT_RESOURCE dtype). For example, tf.Variable is backed by a resource.
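
For example, the handle of a tf.Variable is itself a tensor of resource dtype:

v = tf.Variable(1.)
print(v.handle.dtype)  # <dtype: 'resource'>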

Marking an op as stateful enforces:

  • no constant folding (because the variable might be updated)
  • no common subexpression elimination (otherwise all layers would be initialized with the same random values)

A tf.Variable is a handle (i.e. a pointer) to some underlying tensor. Reading/writing through the handle inserts ReadVariableOp and AssignVariableOp operations.
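
This is easy to observe in a traced graph (a small sketch):

w = tf.Variable(1.)

@tf.function
def update():
  w.assign(w + 1.)

graph = update.get_concrete_function().graph
print([op.type for op in graph.get_operations()])
# [..., 'ReadVariableOp', 'AddV2', 'AssignVariableOp', ...]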

2. Trace

A function decorated with @tf.function is traced into a tf.Graph when it is executed (not when it is defined). Shape inference also happens during tracing. The traced graph is cached.

Importantly, tf.Graph-level tracing allows unknown shapes: feeding something like tf.TensorSpec([1, None]) to input_signature prevents tf from retracing the graph when inputs with specialized shapes are fed.

2.1. PolymorphicFunction

Graph building is not enabled by default in eager mode; tf2 builds a graph by decorating a function with tf.function and tracing it. See the tf.function doc and the PolymorphicFunction doc.

Basically, tf.function creates a PolymorphicFunction (formerly GenericFunction), which can encapsulate several tf.Graphs; further decorating with jit_compile=True triggers XLA compilation.

tf.function wraps a Python function, returning a PolymorphicFunction object. It manages a set of ConcreteFunctions and automatically picks the right one for the given inputs. The interface can be constrained by passing input_signature with tf.TensorSpec.

# f is a PolymorphicFunction
@tf.function
def f(x):
  return x + 1

# input_signature can be constrained with tf.TensorSpec
@tf.function(input_signature=[tf.TensorSpec([1, None])])
def constrained_foo(t):
  print("tracing...")
  return t
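
Calling it with different trailing dimensions reuses the same trace:

constrained_foo(tf.constant([[1., 2.]]))      # prints "tracing..." on the first call
constrained_foo(tf.constant([[1., 2., 3.]]))  # (1, 3) still matches [1, None]: no retracing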

2.2. ConcreteFunction

When a PolymorphicFunction is fed TensorSpec arguments, it traces to create a tf.Graph and wraps it in a ConcreteFunction, also known as a trace.

# f.get_concrete_function(TensorSpec()) returns a ConcreteFunction
concrete_function = f.get_concrete_function(tf.TensorSpec((1,), dtype=tf.int32))

# if a specific tensor is passed, it is abstracted into a TensorSpec by ignoring its values
concrete_function = f.get_concrete_function(tf.constant([1.]))

The signature of a concrete function can be accessed via concrete_function.inputs and concrete_function.outputs. Both of them are SymbolicTensors.

# inputs/outputs of concrete_function are SymbolicTensor (i.e. not eager tensor)
assert tf.is_symbolic_tensor(concrete_function.inputs[0])
assert tf.is_symbolic_tensor(concrete_function.outputs[0])

A ConcreteFunction manages two elements:

  • AtomicFunction: further contains the actual FuncGraph
  • captured inputs

print(concrete_function.function_def)
# signature {
#   name: "__inference_f_51110"
#   input_arg {
#     name: "x"
#     type: DT_INT32
#   }
#   output_arg {
#     name: "identity"
#     type: DT_INT32
#   }
# }
# node_def {
#   name: "add/y"
#   op: "Const"
#   attr {
#     key: "dtype"
#     value {
#       type: DT_INT32
#     }
#   }
#   attr {
#     key: "value"
#     value {
#       tensor {
#         dtype: DT_INT32
#         tensor_shape {
#         }
#         int_val: 1
#       }
#     }
#   }
# }...

2.2.1. AtomicFunction (FuncGraph)

AtomicFunction wraps a FuncGraph in its cached graph, which is the actual graph inside a ConcreteFunction. The inputs to an AtomicFunction are typically the user inputs plus the captured inputs (e.g. variables).

An AtomicFunction is callable and can be accessed via inference_fn:

# AtomicFunction
concrete_function.inference_fn

# FuncGraph
graph = concrete_function.inference_fn.graph

# or simply
graph = concrete_function.graph

print(graph.as_graph_def())
# node {
#   name: "x"
#   op: "Placeholder"
#   attr {
#     key: "_user_specified_name"
#     value {
#       s: "x"
#     }
#   }
#   attr {
#     key: "dtype"
#     value {
#       type: DT_INT32
#     }
#   }
#   attr {
#     key: "shape"
#     value {
#       shape {
#       }
#     }
#   }
# }

print(graph.operations)
# [<tf.Operation 'x' type=Placeholder>,
#  <tf.Operation 'add/y' type=Const>,
#  <tf.Operation 'add' type=AddV2>,
#  <tf.Operation 'Identity' type=Identity>]

2.2.2. Captured Objects

Many objects are captured during tracing. concrete_function.captured_inputs are implicitly passed to the function as extra args on every call after tracing.

See the following example:

a = tf.Variable([3.14, 3.14]) # captured as ResourceHandle (i.e just pointer)
b = tf.constant(1.0) # captured as actual tensor
c = 2. # built into graph's Add node attribute directly

@tf.function
def f(x):
  print('tracing...')
  d = tf.constant(3.) # built into graph's Const node
  return x + a + b + c + d

#<tf.Tensor: shape=(2,), dtype=float32, numpy=array([9.14, 9.14], dtype=float32)>
f(0.) # trigger tracing, 9.14 = 3.14 + 1 + 2 + 3

# [<tf.Tensor: shape=(), dtype=resource, value=<ResourceHandle(name="Variable/14", device="/job:localhost/replica:0/task:0/device:CPU:0", container="Anonymous", type="tensorflow::Var", dtype and shapes : "[ DType enum: 1, Shape: [2] ]")>>,
#  <tf.Tensor: shape=(), dtype=float32, numpy=1.0>]
# a, b are captured as inputs here 
f.get_concrete_function(0.).captured_inputs

# rebinding the Python names a, b, c neither causes retracing nor changes the result
b = tf.constant(2.0)
f(0.) # no re-tracing, result is still 9.14
f.get_concrete_function(0.).captured_inputs # 2nd captured input is still 1.0

# variable.assign does change the result, still without retracing: the captured
# input for a variable is its ResourceHandle (a pointer), not the actual values
a.assign([4.14, 4.14])
f(0.) # no retracing, 10.14 = 4.14 + 1 + 2 + 3

3. Transform

3.1. Pre-placement

If/While ops are lowered into the low-level primitives (e.g. Switch, Merge).

3.2. Cluster

Ops that can be compiled by XLA are auto-clustered together (in tf2, only code inside tf.function is clustered).

[figure: tf_jit]
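
Auto-clustering is switched on with the TF_XLA_FLAGS environment variable, which must be set before TensorFlow initializes (a minimal sketch; --tf_xla_auto_jit=2 is the documented flag):

import os
# enable XLA auto-clustering of compilable ops
os.environ['TF_XLA_FLAGS'] = '--tf_xla_auto_jit=2'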

3.3. Placer

3.4. Grappler

3.5. Partitioner

The partitioner splits the graph across devices, and communication primitives (e.g. Send/Recv) are inserted to transfer data.

The actual communication implements the rendezvous interface. See this blog.

4. Compile

The execution path is roughly

\[\text{python} \xrightarrow{\text{trace}} \text{tf.Graph} \xrightarrow{\text{lower}} \text{HLO} \xrightarrow{\text{compile}} \text{native}\]

Watch this youtube video for an introduction.

4.1. Lower

Each traced ConcreteFunction (tf.Graph) is lowered to HLO again whenever the input shapes change; this is because XLA needs static shapes, while a ConcreteFunction allows dynamic shapes (i.e. None).

Note that lowering/compiling can happen even without retracing. See the following example:

import os

# dump XLA artifacts for debugging; this must be set at the very beginning of the program
os.environ['XLA_FLAGS'] = '--xla_dump_to=/tmp/dump'

@tf.function(input_signature=[tf.TensorSpec([1, None])], jit_compile=True)
def f(t):
  print("tracing...")
  return tf.sin(t)

# tracing...
# HloModule a_inference_f_2574__.7, entry_computation_layout={(f32[1,1]{1,0})->f32[1,1]{1,0}}...
f.experimental_get_compiler_ir(tf.constant([[1.,]]))(stage='hlo')

# no tracing is triggered, but compilation is. notice the shape of the args has changed
# HloModule a_inference_f_2574__.7, entry_computation_layout={(f32[1,2]{1,0})->f32[1,2]{1,0}}...
f.experimental_get_compiler_ir(tf.constant([[1.,2]]))(stage='hlo')

Compiled results are also cached.

4.2. Optimize

The optimized result after the HLO passes can be observed by diffing before_optimizations.txt against *_after_optimizations.txt, or by checking:

# before optimization
f.experimental_get_compiler_ir(tf.constant([[1.,2]]))(stage='hlo')

# after optimization
f.experimental_get_compiler_ir(tf.constant([[1.,2]]))(stage='optimized_hlo')

4.3. Compile

When xla_dump_to is enabled, XLA generates a few files:

module_0002.a_inference_f_6__.7.before_optimizations.txt
module_0002.a_inference_f_6__.7.cpu_after_optimizations-buffer-assignment.txt
module_0002.a_inference_f_6__.7.cpu_after_optimizations.txt
module_0002.a_inference_f_6__.7.ir-no-opt.ll
module_0002.a_inference_f_6__.7.ir-with-opt.ll
module_0002.a_inference_f_6__.7.o

The .ll files contain LLVM IR; for example, my module_0002.a_inference_f_6__.7.ir-with-opt.ll has the following content:

; ModuleID = '__compute_module'
source_filename = "__compute_module"
target datalayout = ""
target triple = ""

; Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(readwrite, inaccessiblemem: none) uwtable
define void @a_inference_f_6__.7(ptr nocapture readnone %retval, ptr noalias nocapture readnone %run_options, ptr noalias nocapture readnone %params, ptr noalias nocapture readonly %buffer_table, ptr noalias nocapture readnone %status, ptr noalias nocapture readnone %prof_counters) local_unnamed_addr #0 {
entry:
  %0 = getelementptr inbounds i8, ptr %buffer_table, i64 8
  %arg0.1 = load ptr, ptr %0, align 8, !invariant.load !0, !dereferenceable !1, !align !1
  %sine.3 = load ptr, ptr %buffer_table, align 8, !invariant.load !0, !dereferenceable !1, !align !1
  %1 = load float, ptr %arg0.1, align 4, !invariant.load !0, !noalias !2
  %2 = tail call float @llvm.sin.f32(float %1)
  store float %2, ptr %sine.3, align 4, !alias.scope !2
  ret void
}

; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare float @llvm.sin.f32(float) #1

attributes #0 = { mustprogress nofree norecurse nosync nounwind willreturn memory(readwrite, inaccessiblemem: none) uwtable "denormal-fp-math"="preserve-sign" "no-frame-pointer-elim"="false" }
attributes #1 = { mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) }

!0 = !{}
!1 = !{i64 4}
!2 = !{!3}
!3 = !{!"buffer: {index:0, offset:0, size:4}", !4}
!4 = !{!"XLA global AA domain"}

The .o file is, of course, the compiled object file.

5. Distribution

All of TensorFlow's visible devices can be retrieved with:

import logging

# this typically includes CPU as well
logging.info("All TF devices: %s", tf.config.list_logical_devices())

5.1. GShard (xla sharding)

Implemented as xla_sharding in tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py

This is probably the lowest-level sharding API for annotating tensors.
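
A hedged sketch of the annotation calls (split and replicate are the commonly used entry points; these annotations only take effect inside a compiled TPU computation):

from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding

t = tf.ones([8, 8])
# annotate the tensor to be split along dim 0 across 4 devices
t = xla_sharding.split(t, split_dimension=0, num_devices=4)
# or mark it as explicitly replicated
t = xla_sharding.replicate(t)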

5.2. tf.distribute

TPUStrategy is useful for creating a TPU SPMD runtime:

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
device_assignment = tf.tpu.experimental.DeviceAssignment.build(
    topology,
    computation_shape=[2, 2, 1, 1], # == topology.mesh_shape
    num_replicas=1)
strategy = tf.distribute.TPUStrategy(
    resolver, experimental_device_assignment=device_assignment)

strategy.run(concrete_function, (arg_0, arg_1))

5.3. DTensor

The latest SPMD features are exposed through DTensor, a low-level SPMD API for TensorFlow. See this DTensor tutorial and distributed training with DTensors.

from tensorflow.experimental import dtensor
tf_tpu_devices = [f'TPU:{i}' for i in range(4)]
mesh = dtensor.create_mesh([('x', 4)], devices=tf_tpu_devices)  # a 1-D mesh over 4 TPUs
layout = dtensor.Layout(['x', dtensor.UNSHARDED], mesh)  # shard dim 0 along 'x', replicate dim 1

def dtensor_from_array(arr, layout, shape=None, dtype=None):
  """Convert a DTensor from something that looks like an array or Tensor.

  This function is convenient for quick doodling DTensors from a known,
  unsharded data object in a single-client environment. This is not the
  most efficient way of creating a DTensor, but it will do for this
  tutorial.
  """
  if shape is not None or dtype is not None:
    arr = tf.constant(arr, shape=shape, dtype=dtype)

  # replicate the input to the mesh
  a = dtensor.copy_to_mesh(arr,
          layout=dtensor.Layout.replicated(layout.mesh, rank=layout.rank))
  # shard the copy to the desirable layout
  return dtensor.relayout(a, layout=layout)

tf_x = tf.random.uniform([4, 8])  # an example unsharded tensor
sharded_tf_x = dtensor_from_array(tf_x, layout)
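
The resulting placement can be verified with dtensor.fetch_layout:

# confirms dim 0 is sharded along 'x' and dim 1 is unsharded
print(dtensor.fetch_layout(sharded_tf_x))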

6. Serialization

There are two related concepts:

  • checkpoint: contains only the values of parameters; no graph is included. It is typically only useful when the source code that will use the saved parameter values is available.
  • SavedModel: includes a serialized description of the computation at the tf.Graph level in addition to the parameter values (a checkpoint)

6.1. SavedModel

This section deals with SavedModel.

In the simplest case, we can export a single concrete function:

# save a tf.function with signature
@tf.function(input_signature=[tf.TensorSpec([2, ], tf.float32)], jit_compile=True)
def f(x):
  return x+1

concrete_f = f.get_concrete_function()
tf.saved_model.save(concrete_f, "/tmp/saved_model_f")

loaded_f = tf.saved_model.load("tmp/saved_model_f")
loaded_f(tf.constant([1., 2.]))

# save a tf.function without an explicit input_signature (traced from an example input)
@tf.function(jit_compile=True)
def f(x):
  return x+1

concrete_f = f.get_concrete_function(tf.constant([1., 2.]))
tf.saved_model.save(concrete_f, "/tmp/saved_model_f")

SavedModel usually saves a Trackable object (typically a tf.Module). SavedModel will save:

  • signatures (aliases of ConcreteFunctions): tf.functions decorated with an explicit input_signature are saved; otherwise target signatures should be passed as an argument.
  • a checkpoint (the recursive attributes of trackable objects): the saved tensor variables

# save a tf.Module
class Adder(tf.Module):

  def __init__(self):
    self.weight = tf.Variable(3.14, dtype=tf.float32)

  @tf.function(input_signature=[tf.TensorSpec(shape=(), dtype=tf.float32)])
  def __call__(self, x):
    return x + self.weight

adder = Adder()
tf.saved_model.save(adder, "/tmp/adder")

This will create the following files:

  • saved_model.pb: a MetaGraphDef protobuf file, which contains a GraphDef as a child. Typically a small file.

  • variables: the checkpoint directory; it contains files such as variables.index and variables.data-00000-of-00001. See the checkpoint section below.

To load a SavedModel and use its concrete functions:

f = tf.saved_model.load("/tmp/adder")

# use its signature (concrete function)
saved_concrete_function = f.signatures['serving_default']

# <ConcreteFunction (*, x: TensorSpec(shape=(), dtype=tf.float32, name='x')) -> Dict[['output_0', TensorSpec(shape=(), dtype=tf.float32, name='output_0')]] at ...>
print(saved_concrete_function)


# {'output_0': <tf.Tensor: shape=(), dtype=float32, numpy=4.1400003>}
saved_concrete_function(tf.constant(1.0))

# error: concrete function cannot run over incompatible input_signature
saved_concrete_function(tf.constant(1))

To inspect the nodes of a graph from a signature:

graph = f.signatures['serving_default'].graph
for node in graph.as_graph_def().node:
  print(node)

Note that the graph before saving and the graph after loading are not identical. Compare the following two graph_defs. The saved graph typically encapsulates most of the computation into a single raw_ops.PartitionedCall referring to some __inference___call___ function in the function library. See its doc.

class Adder(tf.Module):

  @tf.function(input_signature=[tf.TensorSpec(shape=(), dtype=tf.float32)])
  def __call__(self, x):
    return x + 3.14

# before saving
adder = Adder()
adder.__call__.get_concrete_function(tf.constant(1.0)).graph.as_graph_def()

# after saving
tf.saved_model.save(adder, "/tmp/adder")
saved_adder = tf.saved_model.load("/tmp/adder")
saved_adder_concrete_function = saved_adder.signatures['serving_default']
saved_adder_concrete_function.graph.as_graph_def()

Marking jit_compile=True produces the same graph before saving, but attaches the _XlaMustCompile=True attribute to the PartitionedCall during saving. This presumably triggers JIT compilation after loading.

6.2. MetaGraph

In tf2, a graph is stored in the SavedModel format: in saved_model.proto, the graph is serialized as a MetaGraphDef protobuf (which contains a GraphDef as a child).

message SavedModel {
  // The schema version of the SavedModel instance. Used for versioning when
  // making future changes to the specification/implementation. Initial value
  // at release will be 1.
  int64 saved_model_schema_version = 1;

  // One or more MetaGraphs.
  repeated MetaGraphDef meta_graphs = 2;
}

The MetaGraphDef proto is defined as follows:

message MetaGraphDef {
  // Meta information regarding the graph to be exported.  To be used by users
  MetaInfoDef meta_info_def = 1;

  // GraphDef.
  GraphDef graph_def = 2;

  // signature_def: Map from user supplied key for a signature to a single
  // SignatureDef.
  map<string, SignatureDef> signature_def = 5;

  // Extra information about the structure of functions and stateful objects.
  SavedObjectGraph object_graph_def = 7;
}
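
These protos can also be inspected directly by parsing saved_model.pb (a sketch, reusing the /tmp/adder export from above):

from tensorflow.core.protobuf import saved_model_pb2

sm = saved_model_pb2.SavedModel()
with open("/tmp/adder/saved_model.pb", "rb") as fp:
  sm.ParseFromString(fp.read())

meta_graph = sm.meta_graphs[0]
print(meta_graph.meta_info_def.tags)   # e.g. ['serve']
print(list(meta_graph.signature_def))  # e.g. ['serving_default', ...]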

GraphDef encapsulates the tf.Graph-level information described above.

MetaInfoDef contains the following information:

  • stripped_op_list: a copy of the OpDefs used by the producer of this graph_def
  • tags: e.g. gpu, serve
  • function_aliases: FunctionDef mapping (note: this is saved via tf.saved_model.SaveOptions)

SavedObjectGraph is used to reconstruct the object structure.

6.3. GraphDef

GraphDef is the raw, language-agnostic, portable representation of a tf.Graph computation.

Its protobuf is as follows:

message GraphDef {
  repeated NodeDef node = 1;
  FunctionDefLibrary library = 2;
  VersionDef versions = 4;
}

FunctionDefLibrary contains repeated FunctionDefs:

message FunctionDefLibrary {
  repeated FunctionDef function = 1;
  repeated GradientDef gradient = 2;
  repeated RegisteredGradient registered_gradients = 3;
}

message FunctionDef {
  OpDef signature = 1;
  map<string, AttrValue> attr = 5;
  map<uint32, ArgAttrs> arg_attr = 7;
  map<uint32, uint32> resource_arg_unique_id = 8;
  repeated NodeDef node_def = 3;
  map<string, string> ret = 4;
  map<string, string> control_ret = 6;
}
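
The function library of a loaded graph can be listed from Python (reusing saved_adder_concrete_function from the SavedModel section):

graph_def = saved_adder_concrete_function.graph.as_graph_def()
for fn in graph_def.library.function:
  print(fn.signature.name)  # e.g. an __inference___call___... FunctionDef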

6.3.1. Node

Both GraphDef and FunctionDef are defined in terms of NodeDef:

message NodeDef {
  string name = 1;
  string op = 2;
  repeated string input = 3;
  string device = 4;
  map<string, AttrValue> attr = 5;

  message ExperimentalDebugInfo {
    repeated string original_node_names = 1;
    repeated string original_func_names = 2;
  }

  // This stores debug information associated with the node.
  ExperimentalDebugInfo experimental_debug_info = 6;
  FullTypeDef experimental_type = 7;
}

A node represents a tf.Operation, which takes zero or more Tensor objects as input and produces zero or more Tensor objects as output.

6.4. SavedObjectGraph

SavedObjectGraph is used to reconstruct the object structure:

message SavedObjectGraph {
  // Flattened list of objects in the object graph.
  //
  // The position of the object in this list indicates its id.
  // Nodes[0] is considered the root node.
  repeated SavedObject nodes = 1;

  // Information about captures and output structures in concrete functions.
  // Referenced from SavedBareConcreteFunction and SavedFunction.
  map<string, SavedConcreteFunction> concrete_functions = 2;
}

6.5. Checkpoint

A checkpoint is a serialization of the weight container (e.g. a tf.Module); it typically consists of two files (ref: stackoverflow):

  • variables.index: a string-string immutable table (tensorflow::table::Table). Each key is the name of a tensor and its value is a serialized BundleEntryProto. Each BundleEntryProto describes the metadata of a tensor: which "data" file contains the content of the tensor, the offset into that file, a checksum, some auxiliary data, etc.
  • variables.data-00000-of-00001: a TensorBundle collection that saves the values of all variables.

To inspect the variables (checkpoint) of a SavedModel:

f = tf.saved_model.load("/tmp/adder")

# trackable attributes are accessible after loading
# this info is stored in MetaGraphDef.object_graph_def

# <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=3.14>
f.weight

# it is also accessible from signature.variables
signature = f.signatures['serving_default']

# (<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=3.14>,)
print(signature.variables)

# variables can also be loaded with load_checkpoint
reader = tf.train.load_checkpoint("/tmp/adder/variables/variables")

# {'_CHECKPOINTABLE_OBJECT_GRAPH': tf.string, 'weight/.ATTRIBUTES/VARIABLE_VALUE': tf.float32}
dtype_from_key = reader.get_variable_to_dtype_map()
shape_from_key = reader.get_variable_to_shape_map()

key = 'weight/.ATTRIBUTES/VARIABLE_VALUE'

print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
reader.get_tensor(key)

Note that only Trackable objects and attributes are saved; non-trackable tf objects (e.g. tf.constant) are saved in concrete_function.captured_inputs, and native Python values are built directly into the graph.

It is also possible to save/restore a checkpoint without using SavedModel:

model = tf.keras.Model(...)
checkpoint = tf.train.Checkpoint(model)

# Save a checkpoint to /tmp/training_checkpoints-{save_counter}. Every time
# checkpoint.save is called, the save counter is increased.
save_path = checkpoint.save('/tmp/training_checkpoints')

# Restore the checkpointed values to the `model` object.
checkpoint.restore(save_path)
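
restore returns a status object whose assert_consumed can verify that the checkpoint and the object graph match:

status = checkpoint.restore(save_path)
status.assert_consumed()  # raises if some saved value was not matched to an object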

6.5.1. Inference Converter

The Inference Converter takes an exported SavedModel and rewrites it for serving.

For TPU models, it adds a TPUPartitionedCall op, which wraps a few TPU-specific ops (e.g. TPUReplicateMetadata, TPUCompilationResult) and the original inference function; this makes the function servable on the TPU.

6.6. TFlite

Unlike SavedModel, which uses protocol buffers, TFLite is a serialization format based on FlatBuffers.

It has a structure very similar to GraphDef; see its schema starting from table Model.
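
A SavedModel can be converted into this format with the TFLite converter (a minimal sketch, reusing the /tmp/adder export):

converter = tf.lite.TFLiteConverter.from_saved_model("/tmp/adder")
tflite_model = converter.convert()  # the serialized FlatBuffer bytes
with open("/tmp/adder.tflite", "wb") as fp:
  fp.write(tflite_model)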

7. Runtime

There are a few runtimes:

  • Executor: the traditional runtime
  • TF-XLA: XLA runtime used together with Executor
  • TFRT: a new runtime based on MLIR (but seemingly not fully adopted yet)
  • TFLite: interpreter for mobile devices
  • Tensorflow.js: interpreter for Javascript

7.1. Memory Management

Check this blog for the allocator implementation.

8. Reference

Tensorflow code repository