Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions AsmParser/tog_generator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# DEPRECATED (timing path): legacy ONNX Tile-Operation-Graph producer. Builds
# the TOG and serializes it to ONNX for the C++ TileGraphParser. Superseded by
# the C++ trace pipeline (PyTorchSimFrontend/mlir/passes/build_skeleton.py +
# lower_to_emitc.py + cycle_table.py -> a compiled trace .so). Kept live so the
# current pipeline does not break; to be retired once the trace pipeline (P3+)
# stabilizes. See docs/design/togsim_cpp_trace.md.
import os
import sys
import importlib.util
Expand Down
41 changes: 41 additions & 0 deletions PyTorchSimFrontend/extension_codecache.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,19 @@ def load(cls, source_code,
# Run cyclesim
cyclesim = CycleSimulator()
cycle_list = cyclesim.compile_and_simulate(os.path.join(write_path, cycle_binary_name), vectorlane_size, silent_mode=silent_mode)
# Snapshot for the P3-trace hook below: generate_tile_graph consumes
# cycle_list in place (cycle_list.pop(0) per tile), leaving it empty.
cycle_list_for_trace = list(cycle_list)

# Create TOG
# DEPRECATED (timing path): this ONNX-TOG producer -- run_tog ->
# tog_generator.generate_tile_graph -> ONNX -> C++ TileGraphParser --
# is being superseded by the C++ trace pipeline (build_skeleton +
# lower_to_emitc -> compiled .so, + the cycle_table sidecar). The
# per-tile cycle_list / x_offset / w_offset computed here are exactly
# what cycle_table.build_cycle_table will reuse, so both paths stay
# cycle-consistent during the transition. Kept live (pipeline must not
# break); to be retired once the trace pipeline (P3+) stabilizes.
w_offset, x_offset = vectorlane_size, vectorlane_size
if kwargs['loop_size'] is not None and kwargs['loop_size'][-3] < vectorlane_size:
x_offset = kwargs['loop_size'][-3]
Expand All @@ -258,6 +269,36 @@ def load(cls, source_code,
w_offset=w_offset, # FIXME.
vector_lane=vectorlane_size
)

# P3 trace pipeline (opt-in, TORCHSIM_DUMP_TRACE_SO=1): also emit the
# compiled trace producer .so + the cycle-table TSV from the SAME
# post-vcix IR and gem5 cycle_list/offsets, so the trace path can be
# run and compared cycle-consistently against this legacy path.
# Best-effort: never breaks the legacy compile.
if os.environ.get("TORCHSIM_DUMP_TRACE_SO") == "1":
try:
import mlir.ir as ir
from PyTorchSimFrontend.mlir.passes import (
build_skeleton as _bs, cycle_table as _ct, lower_to_emitc as _l2e)
pv = sample_mlir_path + "_postvcix.mlir"
_ctx = ir.Context(); _ctx.allow_unregistered_dialects = True
with _ctx:
_mod = ir.Module.parse(open(pv).read(), _ctx)
_bs.build_skeleton(_mod)
_ntiles = len(_ct._compute_types(_mod))
# align lengths: gem5 gives one numCycles per compute node;
# pad with the last value / truncate if it disagrees.
_cl = list(cycle_list_for_trace)
if _cl and len(_cl) != _ntiles:
_cl = (_cl + [_cl[-1]] * _ntiles)[:_ntiles]
logger.info(f"[P3-trace] cycle_list={cycle_list_for_trace} -> {_cl} "
f"(#tiles={_ntiles}, x_off={x_offset}, w_off={w_offset})")
_tbl = _ct.build_cycle_table(_mod, _cl, x_offset, w_offset)
_ct.dump_cycle_table_tsv(_tbl, os.path.join(write_path, "trace_cycles.tsv"))
_l2e.build_trace_so(pv, os.path.join(write_path, "trace.so"))
logger.info(f"[P3-trace] wrote trace.so + trace_cycles.tsv in {write_path}")
except Exception as e:
logger.warning(f"[P3-trace] trace .so/sidecar dump skipped: {e}")
return key

class CustomAsyncCompile(AsyncCompile):
Expand Down
87 changes: 87 additions & 0 deletions PyTorchSimFrontend/mlir/passes/_mlir_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""Small, dependency-light helpers shared across the MLIR passes.

Every pass had its own copy of the same op-walk generator (named variously
`_iter_ops` / `_walk` / `_walk_ops`) and the same one-line attribute builders
(`_i32` / `_i64` / ...). This module is the single source for both.

Import-safety: `walk_ops` is pure block/op attribute access and needs no MLIR
bindings, so this module does NOT import `mlir.ir` at top level -- some passes
(e.g. lower_vlane_idx, decompose_transfer) are deliberately importable without
the bindings present and only touch `mlir.ir` inside their run functions. The
attribute builders therefore import `mlir.ir` lazily; they require an active
MLIR context (the caller's `with ctx:`), exactly as the per-pass copies did.
"""


def walk_ops(block):
    """Generate every operation reachable from `block`, parents before children.

    Each op of `block` is yielded first; afterwards the blocks of each of its
    regions are walked recursively, in order. The operation list of every
    visited block is copied into a plain list before iteration starts, so the
    loop runs over that copy rather than over the live container.
    """
    snapshot = list(block.operations)
    for current in snapshot:
        yield current
        # The regions are only read once the consumer resumes the generator,
        # i.e. after it has finished handling `current`.
        nested_blocks = (
            inner
            for region in current.operation.regions
            for inner in region.blocks
        )
        for inner in nested_blocks:
            yield from walk_ops(inner)


def _ir():
    """Import `mlir.ir` on demand and return the module object.

    The import lives inside the function so that importing this module does
    not itself require the MLIR bindings.
    """
    import mlir.ir as bindings
    return bindings


def i32(v):
    """Build an IntegerAttr of signless 32-bit type holding `int(v)`.

    Relies on the active MLIR context, as the module docstring describes.
    """
    ir = _ir()
    int32_type = ir.IntegerType.get_signless(32)
    return ir.IntegerAttr.get(int32_type, int(v))


def i64(v):
    """Build an IntegerAttr of signless 64-bit type holding `int(v)`."""
    ir = _ir()
    int64_type = ir.IntegerType.get_signless(64)
    return ir.IntegerAttr.get(int64_type, int(v))


def i64_array(vals):
    """Build an ArrayAttr whose elements are signless 64-bit IntegerAttrs.

    Every item of `vals` is converted with `int()` before being wrapped; the
    element order follows the iteration order of `vals`.
    """
    ir = _ir()
    int64_type = ir.IntegerType.get_signless(64)
    elements = []
    for value in vals:
        elements.append(ir.IntegerAttr.get(int64_type, int(value)))
    return ir.ArrayAttr.get(elements)


def str_attr(v):
    """Build a StringAttr from the `str()` form of `v`."""
    ir = _ir()
    text = str(v)
    return ir.StringAttr.get(text)


# ---------------------------------------------------------------------------
# attribute readers -- accept an OpView or an Operation; `default` is returned
# when `key` is absent (callers that want the strict "must be present" behaviour
# simply never pass an absent key).
# ---------------------------------------------------------------------------
def _attrs(op):
    """Return the `.attributes` of `op.operation` if present, else of `op`."""
    # Objects exposing `.operation` are unwrapped first; anything else is
    # used as-is.
    target = getattr(op, "operation", op)
    return target.attributes


def attr_int(op, key, default=None):
    """Read `op`'s `key` attribute as an integer.

    Returns `default` when `key` is not among the op's attributes; otherwise
    returns the `.value` of the attribute viewed as an IntegerAttr.
    """
    ir = _ir()
    attributes = _attrs(op)
    if key not in attributes:
        return default
    return ir.IntegerAttr(attributes[key]).value


def attr_bool(op, key, default=False):
    """Read `op`'s `key` attribute as a Python bool.

    Returns `default` when `key` is not among the op's attributes; otherwise
    returns `bool()` of the attribute's `.value` viewed as a BoolAttr.
    """
    ir = _ir()
    attributes = _attrs(op)
    if key not in attributes:
        return default
    return bool(ir.BoolAttr(attributes[key]).value)


def attr_i64_array(op, key, default=None):
    """Read `op`'s `key` ArrayAttr of integers into a Python list.

    Returns `default` when `key` is not among the op's attributes (callers
    wanting the "missing -> empty" convention pass `default=[]`). Otherwise
    each element is viewed as an IntegerAttr and its `.value` collected, in
    order.
    """
    ir = _ir()
    attributes = _attrs(op)
    if key not in attributes:
        return default
    elements = ir.ArrayAttr(attributes[key])
    return [ir.IntegerAttr(element).value for element in elements]
Loading
Loading