Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions include/infinicore/analyzer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

// Convenience header — includes all analyzer components.

#include "analyzer/op_type.hpp"
#include "analyzer/intent_generator.hpp"
#include "analyzer/mutual_awareness_analyzer.hpp"
#include "analyzer/op_trace.hpp"
#include "analyzer/op_type.hpp"
#include "analyzer/optimization_intent.hpp"
#include "analyzer/phase_detector.hpp"
#include "analyzer/resource_sensor.hpp"
#include "analyzer/intent_generator.hpp"
#include "analyzer/mutual_awareness_analyzer.hpp"
16 changes: 10 additions & 6 deletions include/infinicore/analyzer/intent_generator.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#pragma once

#include "optimization_intent.hpp"
#include "op_trace.hpp"
#include "optimization_intent.hpp"

#include <algorithm>
#include <vector>
Expand Down Expand Up @@ -72,7 +72,9 @@ class IntentGenerator {
PhaseType phase,
const std::vector<OpTraceEntry> &window) const {

if (window.empty()) return 0.0f;
if (window.empty()) {
return 0.0f;
}

size_t heavy_compute_ops = 0;
for (auto &e : window) {
Expand Down Expand Up @@ -201,7 +203,7 @@ class IntentGenerator {

// Fusion is beneficial for bandwidth-bound phases (reduce memory traffic)
hint.prefer_fused_ops = (bottleneck == BottleneckType::BANDWIDTH_BOUND)
|| phase == PhaseType::DECODE;
|| phase == PhaseType::DECODE;

// In-place when memory is tight
hint.prefer_in_place = (bottleneck == BottleneckType::MEMORY_BOUND);
Expand All @@ -218,8 +220,8 @@ class IntentGenerator {

// Async comm overlap for multi-device and communication phases
hint.prefer_async_comm = (device_intents.size() > 1)
&& (phase == PhaseType::GEMM_MLP_DENSE
|| phase == PhaseType::COMMUNICATION);
&& (phase == PhaseType::GEMM_MLP_DENSE
|| phase == PhaseType::COMMUNICATION);

return hint;
}
Expand Down Expand Up @@ -254,7 +256,9 @@ class IntentGenerator {
default:
break;
}
if (match) matching++;
if (match) {
matching++;
}
}

return static_cast<float>(matching) / static_cast<float>(window.size());
Expand Down
6 changes: 3 additions & 3 deletions include/infinicore/analyzer/op_trace.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,9 @@ OpTraceRing &getGlobalOpTrace();
/// This is the function called from the INFINICORE_GRAPH_OP_RECORD_OR_RUN
/// macro hook (when ENABLE_MUTUAL_AWARENESS is defined).
inline void traceOp(OpType op_type,
const size_t *shape, size_t ndim,
uint8_t dtype,
uint8_t device_type, int8_t device_id) {
const size_t *shape, size_t ndim,
uint8_t dtype,
uint8_t device_type, int8_t device_id) {
OpTraceEntry entry;
entry.op_type = op_type;
entry.setShape(shape, ndim);
Expand Down
138 changes: 92 additions & 46 deletions include/infinicore/analyzer/op_type.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,52 +84,98 @@ enum class OpType : uint8_t {
/// Convert an OpType to a human-readable string.
///
/// @param type  The operator type to describe.
/// @return Pointer to a string literal holding the lower-case, snake_case
///         name of the operator (for example "flash_attention"). Values that
///         have no dedicated case below yield "unknown". The returned pointer
///         refers to a string literal, so it must not be freed or modified.
inline const char *opTypeToString(OpType type) {
    // Exactly one label per enumerator: a switch may not contain two case
    // labels with the same value, so each mapping is listed once.
    switch (type) {
    case OpType::ATTENTION:
        return "attention";
    case OpType::FLASH_ATTENTION:
        return "flash_attention";
    case OpType::CAUSAL_SOFTMAX:
        return "causal_softmax";
    case OpType::PAGED_ATTENTION:
        return "paged_attention";
    case OpType::PAGED_ATTENTION_PREFILL:
        return "paged_attention_prefill";
    case OpType::MHA_KVCACHE:
        return "mha_kvcache";
    case OpType::MHA_VARLEN:
        return "mha_varlen";
    case OpType::SOFTMAX:
        return "softmax";
    case OpType::GEMM:
        return "gemm";
    case OpType::LINEAR:
        return "linear";
    case OpType::MATMUL:
        return "matmul";
    case OpType::INT8_GEMM:
        return "int8_gemm";
    case OpType::SCALED_MM_I8:
        return "scaled_mm_i8";
    case OpType::SILU:
        return "silu";
    case OpType::SILU_AND_MUL:
        return "silu_and_mul";
    case OpType::GELU:
        return "gelu";
    case OpType::SWIGLU:
        return "swiglu";
    case OpType::RELU:
        return "relu";
    case OpType::SIGMOID:
        return "sigmoid";
    case OpType::RMS_NORM:
        return "rms_norm";
    case OpType::ADD_RMS_NORM:
        return "add_rms_norm";
    case OpType::LAYER_NORM:
        return "layer_norm";
    case OpType::EMBEDDING:
        return "embedding";
    case OpType::ROPE:
        return "rope";
    case OpType::KV_CACHING:
        return "kv_caching";
    case OpType::PAGED_CACHING:
        return "paged_caching";
    case OpType::ADD:
        return "add";
    case OpType::MUL:
        return "mul";
    case OpType::SUB:
        return "sub";
    case OpType::SUM:
        return "sum";
    case OpType::RECIPROCAL:
        return "reciprocal";
    case OpType::PER_TENSOR_QUANT_I8:
        return "per_tensor_quant_i8";
    case OpType::PER_TENSOR_DEQUANT_I8:
        return "per_tensor_dequant_i8";
    case OpType::PER_CHANNEL_QUANT_I8:
        return "per_channel_quant_i8";
    case OpType::DEQUANTIZE_AWQ:
        return "dequantize_awq";
    case OpType::DEQUANTIZE_GPTQ:
        return "dequantize_gptq";
    case OpType::RANDOM_SAMPLE:
        return "random_sample";
    case OpType::TOPK:
        return "topk";
    case OpType::TOPK_ROUTER:
        return "topk_router";
    case OpType::TOPK_SOFTMAX:
        return "topk_softmax";
    case OpType::ALLREDUCE:
        return "allreduce";
    case OpType::REARRANGE:
        return "rearrange";
    case OpType::ONES:
        return "ones";
    case OpType::ZEROS:
        return "zeros";
    case OpType::TAKE:
        return "take";
    default:
        // Any value not listed above falls through to a generic label.
        return "unknown";
    }
}

Expand Down
42 changes: 21 additions & 21 deletions include/infinicore/analyzer/op_type_registry.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,36 +11,36 @@ namespace infinicore::analyzer {
inline OpType opTypeFromName(const char *name) {
static const std::unordered_map<std::string, OpType> registry = {
// Attention
{"FlashAttention", OpType::FLASH_ATTENTION},
{"CausalSoftmax", OpType::CAUSAL_SOFTMAX},
{"PagedAttention", OpType::PAGED_ATTENTION},
{"MhaKVCache", OpType::MHA_KVCACHE},
{"FlashAttention", OpType::FLASH_ATTENTION},
{"CausalSoftmax", OpType::CAUSAL_SOFTMAX},
{"PagedAttention", OpType::PAGED_ATTENTION},
{"MhaKVCache", OpType::MHA_KVCACHE},
{"MultiheadAttentionVarlen", OpType::MHA_VARLEN},
// GEMM / MLP
{"Gemm", OpType::GEMM},
{"I8Gemm", OpType::SCALED_MM_I8},
{"Gemm", OpType::GEMM},
{"I8Gemm", OpType::SCALED_MM_I8},
// Activation
{"SiluAndMul", OpType::SILU_AND_MUL},
{"SwiGLU", OpType::SWIGLU},
{"SiluAndMul", OpType::SILU_AND_MUL},
{"SwiGLU", OpType::SWIGLU},
// Norm
{"RMSNorm", OpType::RMS_NORM},
{"AddRMSNorm", OpType::ADD_RMS_NORM},
{"RMSNorm", OpType::RMS_NORM},
{"AddRMSNorm", OpType::ADD_RMS_NORM},
// Embedding / Positional
{"Embedding", OpType::EMBEDDING},
{"RoPE", OpType::ROPE},
{"Embedding", OpType::EMBEDDING},
{"RoPE", OpType::ROPE},
// KV Cache
{"KVCaching", OpType::KV_CACHING},
{"PagedCaching", OpType::PAGED_CACHING},
{"KVCaching", OpType::KV_CACHING},
{"PagedCaching", OpType::PAGED_CACHING},
// Elementwise
{"Add", OpType::ADD},
{"Mul", OpType::MUL},
{"Add", OpType::ADD},
{"Mul", OpType::MUL},
// Quantization
{"PerTensorQuantI8", OpType::PER_TENSOR_QUANT_I8},
{"PerTensorDequantI8", OpType::PER_TENSOR_DEQUANT_I8},
{"PerChannelQuantI8", OpType::PER_CHANNEL_QUANT_I8},
{"DequantizeAWQ", OpType::DEQUANTIZE_AWQ},
{"PerTensorQuantI8", OpType::PER_TENSOR_QUANT_I8},
{"PerTensorDequantI8", OpType::PER_TENSOR_DEQUANT_I8},
{"PerChannelQuantI8", OpType::PER_CHANNEL_QUANT_I8},
{"DequantizeAWQ", OpType::DEQUANTIZE_AWQ},
// Misc
{"Rearrange", OpType::REARRANGE},
{"Rearrange", OpType::REARRANGE},
};
auto it = registry.find(name);
return it != registry.end() ? it->second : OpType::UNKNOWN;
Expand Down
Loading
Loading