
0x512 Compiler

Check this repo

1. XLA

XLA is a collection of domain-specific compilers and runtimes for machine-learning workloads.

The IR is typically represented with MLIR.

MLIR (Multi-Level Intermediate Representation) defines multiple dialects that are progressively converted (lowered) towards machine code.
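
For example, the MLIR for a tf.function (in the tf dialect, before further lowering) can be inspected with the experimental tf.mlir API; this is a minimal sketch, assuming a reasonably recent TF version:

# Sketch: print the MLIR module (tf dialect) for a small tf.function.
import tensorflow as tf

@tf.function
def f(x):
  return x + 1

concrete = f.get_concrete_function(tf.TensorSpec([10, 10], tf.float32))
print(tf.mlir.experimental.convert_function(concrete))  # MLIR in the tf dialect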

Relevant links

1.1. Frontend

1.1.1. Tensorflow

tf2xla implements an XlaOpKernel for each supported operation in a tf.Graph. For example, see the softmax XlaOpKernel here

tf2xla

Auto-clustering: ops that can be compiled are automatically clustered together. (In TF2, only code inside a tf.function will be clustered.)

tf_jit
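
A minimal sketch of turning auto-clustering on globally (both switches are standard TF/XLA knobs; exact clustering behavior depends on the TF version):

# Sketch: enable XLA auto-clustering for the whole program.
# Equivalent environment-variable form: TF_XLA_FLAGS=--tf_xla_auto_jit=2
import tensorflow as tf

tf.config.optimizer.set_jit(True)  # auto-cluster compilable ops inside tf.functions

@tf.function
def g(x):
  return x * x + 1

print(g(tf.ones([4, 4])))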

For explicit compilation, see the following example:

import tensorflow as tf

@tf.function(jit_compile=True)
def f(x):
  return x + 1

print(f.experimental_get_compiler_ir(tf.random.normal([10, 10]))(stage='hlo'))

This produces the following HLO:

HloModule a_inference_f_13__.9

ENTRY %a_inference_f_13__.9 (arg0.1: f32[10,10]) -> f32[10,10] {
  %arg0.1 = f32[10,10]{1,0} parameter(0), parameter_replication={false}
  %reshape.2 = f32[10,10]{1,0} reshape(f32[10,10]{1,0} %arg0.1)
  %constant.3 = f32[] constant(1)
  %broadcast.4 = f32[10,10]{1,0} broadcast(f32[] %constant.3)
  %add.5 = f32[10,10]{1,0} add(f32[10,10]{1,0} %reshape.2,
                               f32[10,10]{1,0} %broadcast.4)
  %reshape.6 = f32[10,10]{1,0} reshape(f32[10,10]{1,0} %add.5)
  %tuple.7 = (f32[10,10]{1,0}) tuple(f32[10,10]{1,0} %reshape.6)
  ROOT %get-tuple-element.8 = f32[10,10]{1,0}
    get-tuple-element((f32[10,10]{1,0}) %tuple.7), index=0
}

XLA can fuse multiple operations into one, improving both execution speed and memory usage. This can be seen by setting the stage to optimized_hlo:

print(f.experimental_get_compiler_ir(tf.random.normal([10, 10]))(stage='optimized_hlo'))

which yields the following. Notice that it is shorter because the redundant reshape ops have been removed:

HloModule a_inference_f_87__.9, alias_passthrough_params=true, entry_computation_layout={(f32[10,10]{1,0})->f32[10,10]{1,0}}

%fused_computation (param_0: f32[10,10]) -> f32[10,10] {
  %param_0 = f32[10,10]{1,0} parameter(0)
  %constant.0 = f32[] constant(1), metadata={op_type="AddV2" op_name="add"}
  %broadcast.0 = f32[10,10]{1,0} broadcast(f32[] %constant.0), dimensions={}, metadata={op_type="AddV2" op_name="add"}
  ROOT %add.0 = f32[10,10]{1,0} add(f32[10,10]{1,0} %param_0, f32[10,10]{1,0} %broadcast.0), metadata={op_type="AddV2" op_name="add"}
}

ENTRY %a_inference_f_87__.9 (arg0.1: f32[10,10]) -> f32[10,10] {
  %arg0.1 = f32[10,10]{1,0} parameter(0), parameter_replication={false}, metadata={op_name="XLA_Args"}
  ROOT %fusion = f32[10,10]{1,0} fusion(f32[10,10]{1,0} %arg0.1), kind=kLoop, calls=%fused_computation, metadata={op_type="AddV2" op_name="add"}
}

1.1.2. JAX

1.1.3. PyTorch

1.2. IR

StableHLO has a limited set of orthogonal, well-defined ops (< 100) and can be lowered from a TensorFlow graph. The semantics of HLO are defined in the XLA semantics page
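
A quick way to look at StableHLO is via JAX lowering (a sketch; in recent JAX versions the lowered text is a StableHLO module in MLIR syntax):

# Sketch: print the StableHLO produced by lowering a small JAX function.
import jax
import jax.numpy as jnp

def f(x):
  return x + 1

print(jax.jit(f).lower(jnp.ones((10, 10))).as_text())  # StableHLO (MLIR) module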

1.3. Optimization

1.3.1. Fusion

The major speedup comes from the fusion process:

  • without the compiler: each op reads its operands from memory, does some computation, and writes the result back to memory; the next op repeats the same steps
  • with the compiler: if the next op operates on the same operand, it works on it directly (e.g. within registers), avoiding the redundant memory write/read
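
To see fusion in action, here is a minimal sketch using JAX (mirroring the compile().as_text() approach mentioned under Tools): the optimized HLO for a chain of elementwise ops typically ends up as a single fusion instruction.

# Sketch: inspect optimized HLO for a chain of elementwise ops.
# Without fusion, each op would round-trip through memory; with XLA the whole
# chain is typically emitted as a single `fusion` instruction.
import jax
import jax.numpy as jnp

def f(x):
  return jnp.sin(x) * 2.0 + 1.0  # three elementwise ops

print(jax.jit(f).lower(jnp.ones((1024, 1024))).compile().as_text())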

1.3.2. SPMD Partitioner

1.4. Backend

The backend of XLA produces the target-dependent hardware code. The CPU and GPU backends depend on LLVM: they emit the LLVM IR necessary to represent the XLA HLO computation and then invoke LLVM to emit native code from that IR. For example, the CPU backend implementation for XLA can be found here

1.4.1. IFRT

IFRT is a high-level API focusing on the distributed runtime.

1.4.2. PJRT

PJRT is a low-level, hardware- and framework-independent device interface for compilers and runtimes. It consumes StableHLO as input and produces/executes executables. Relevant code is here

Each hardware vendor should register its compiler/runtime as a plugin to PJRT. Internally, XLA then uses the PJRT API to compile and execute code. See this doc for more details

PJRT can be implemented in either C or C++ (C++ implementations are wrapped into the C API later)

For example,

// PjRtClient::Compile turns an XLA computation into a loaded, runnable executable.
virtual absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Compile(
    const XlaComputation& computation, CompileOptions options) = 0;

There are two existing implementations of PJRT: StreamExecutor and TFRT. Each supports CPU/GPU/TPU to some extent.

1.4.3. XLA:CPU
1.4.4. XLA:TPU

An LLO instruction roughly corresponds to a single instruction in the TPU ISA.

1.4.5. XLA:GPU

XLA:GPU seems to use StreamExecutor as its PJRT client right now (XLA GPU client instantiation ends up here)

1.5. Tools

run_hlo_module can be used to run a dumped HLO module.

HLO dumps can be obtained via compile().as_text() in JAX or via XLA_FLAGS in TF.
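
For example (a sketch; dump file names and tool flags may vary across XLA versions):

# Sketch: save pre-optimization HLO text from JAX, then replay it with run_hlo_module.
import jax
import jax.numpy as jnp

f = jax.jit(lambda x: x + 1)
hlo_text = f.lower(jnp.ones((10, 10))).as_text(dialect="hlo")  # unoptimized HLO

with open("/tmp/module.hlo", "w") as fp:
  fp.write(hlo_text)

# Replay on a chosen platform:
#   run_hlo_module --platform=CPU /tmp/module.hlo
# For TF, a dump directory can instead be requested from the shell, e.g.:
#   XLA_FLAGS=--xla_dump_to=/tmp/xla_dump python my_program.py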

2. Triton

Unlike XLA, which targets compiling an entire graph, Triton seems to focus on writing efficient code for a specific kernel.

It is built on top of the tiling concept: a kernel processes the data one tile (block) at a time.

openai blog post
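
For example, a tiled vector-add kernel in the style of the Triton tutorials (a sketch; each program instance handles one BLOCK_SIZE tile of the data):

# Sketch: tiled vector add in Triton; one program instance processes one tile.
import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
  pid = tl.program_id(axis=0)                    # which tile this program handles
  offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
  mask = offsets < n_elements                    # guard the final partial tile
  x = tl.load(x_ptr + offsets, mask=mask)
  y = tl.load(y_ptr + offsets, mask=mask)
  tl.store(out_ptr + offsets, x + y, mask=mask)

def add(x, y):
  # x, y: CUDA tensors of the same shape
  out = torch.empty_like(x)
  n = out.numel()
  grid = (triton.cdiv(n, 1024),)                 # one program per 1024-element tile
  add_kernel[grid](x, y, out, n, BLOCK_SIZE=1024)
  return out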

2.1. Frontend

2.2. IR

2.3. Optimization

2.4. Backend

A related effort for JAX is Pallas, a kernel language that targets both TPU and GPU.
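
A minimal sketch of a Pallas kernel (API per the Pallas quickstart; details such as grid/block specs are omitted and may differ across JAX versions):

# Sketch: an elementwise Pallas kernel launched via pallas_call.
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_one_kernel(x_ref, o_ref):
  o_ref[...] = x_ref[...] + 1.0  # refs point at the block mapped to this program

@jax.jit
def add_one(x):
  return pl.pallas_call(
      add_one_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
  )(x)

print(add_one(jnp.zeros((8, 128))))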

3. Reference