-
Notifications
You must be signed in to change notification settings - Fork 542
Matrix compression for federated learning broadcasts #2524
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
3211565
36d7263
23d7022
ec8f5df
58e592d
7ef87da
ec3f6be
b58b601
136ff91
4087cbe
65e35b3
f01d81f
927fc19
c6e5ec5
43a937b
b173446
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,46 @@ | ||
# CI workflow that builds the project and runs only the two compression
# test classes named in the "Run compression tests" step.
name: Compression Tests

on:
  push:
    branches: [ feature/compression ]
  pull_request:
    branches: [ main ]

jobs:
  test:
    runs-on: ubuntu-latest

    steps:
      # All first-party actions are pinned to the same major version (v4),
      # matching the upload-artifact step below.
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Set up Java 17
        uses: actions/setup-java@v4
        with:
          java-version: '17'
          distribution: 'temurin'

      # Reuse the local Maven repository between runs; the key changes
      # whenever any pom.xml changes.
      - name: Cache Maven dependencies
        uses: actions/cache@v4
        with:
          path: ~/.m2/repository
          key: ${{ runner.os }}-maven-${{ hashFiles('**/pom.xml') }}
          restore-keys: |
            ${{ runner.os }}-maven-

      # Tests are skipped here (not even compiled); they run in the next step.
      - name: Build project
        run: mvn clean package -Dmaven.test.skip=true -Dmaven.javadoc.skip=true

      - name: Run compression tests
        run: |
          mvn test \
            -Dtest=TopKCompressorTest,ProbabilisticQuantizationCompressorTest \
            -Dmaven.test.failure.ignore=false \
            -Dmaven.javadoc.skip=true

      # always() keeps the surefire reports available when the tests fail.
      - name: Upload test results
        if: always()
        uses: actions/upload-artifact@v4
        with:
          name: test-results
          path: target/surefire-reports/
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,64 @@ | ||
| package org.apache.sysds.runtime.compress; | ||
|
|
||
| import java.io.Serializable; | ||
|
|
||
| /** | ||
| * Generic container for compressed matrix data. | ||
| * Stores the compressed representation along with metadata | ||
| * needed for decompression and size estimation. | ||
| * | ||
| * @author Nirvan C. Udaysingh Jhurree | ||
| */ | ||
| public class CompressedMatrix implements Serializable { | ||
|
|
||
| private static final long serialVersionUID = 1L; | ||
|
|
||
| private final CompressionType type; | ||
| private final int numRows; | ||
| private final int numCols; | ||
| private final Object compressedData; // Technique-specific data | ||
| private final double compressionRatio; | ||
| private final byte[] metadata; // Optional: scaling factors, etc. | ||
|
|
||
| public CompressedMatrix(CompressionType type, int numRows, int numCols, | ||
| Object compressedData, double compressionRatio) { | ||
| this(type, numRows, numCols, compressedData, compressionRatio, null); | ||
| } | ||
|
|
||
| public CompressedMatrix(CompressionType type, int numRows, int numCols, | ||
| Object compressedData, double compressionRatio, | ||
| byte[] metadata) { | ||
| this.type = type; | ||
| this.numRows = numRows; | ||
| this.numCols = numCols; | ||
| this.compressedData = compressedData; | ||
| this.compressionRatio = compressionRatio; | ||
| this.metadata = metadata; | ||
| } | ||
|
|
||
| public CompressionType getType() { return type; } | ||
| public int getNumRows() { return numRows; } | ||
| public int getNumCols() { return numCols; } | ||
| public Object getCompressedData() { return compressedData; } | ||
| public double getCompressionRatio() { return compressionRatio; } | ||
| public byte[] getMetadata() { return metadata; } | ||
|
|
||
| /** Estimate original size in bytes (8 bytes per double) */ | ||
| public long estimateOriginalSizeBytes() { | ||
| return (long) numRows * numCols * 8; | ||
| } | ||
|
|
||
| /** Estimate compressed size in bytes */ | ||
| public long getCompressedSizeBytes() { | ||
| if (compressedData instanceof byte[]) { | ||
| return ((byte[]) compressedData).length; | ||
| } | ||
| return 0; | ||
| } | ||
|
|
||
| @Override | ||
| public String toString() { | ||
| return String.format("CompressedMatrix[%s, %dx%d, ratio=%.2fx]", | ||
| type.getId(), numRows, numCols, compressionRatio); | ||
| } | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,93 @@ | ||
| package org.apache.sysds.runtime.compress; | ||
|
|
||
| import java.util.HashMap; | ||
| import java.util.Map; | ||
|
|
||
| /** | ||
| * Immutable configuration for compression in federated operations. | ||
| * Uses the Builder pattern for flexible, readable configuration. | ||
| * | ||
| * Usage example: | ||
| * CompressionConfig config = CompressionConfig.builder() | ||
| * .enable(true) | ||
| * .withType(CompressionType.TOPK) | ||
| * .withSparsity(0.01) | ||
| * .build(); | ||
| * | ||
| * | ||
|
Comment on lines
+16
to
+17
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Remove empty lines. |
||
| */ | ||
| public class CompressionConfig { | ||
|
|
||
| private final boolean enabled; | ||
| private final CompressionType type; | ||
| private final Map<String, Object> parameters; | ||
|
|
||
| private CompressionConfig(Builder builder) { | ||
| this.enabled = builder.enabled; | ||
| this.type = builder.enabled ? builder.type : CompressionType.NONE; | ||
| this.parameters = new HashMap<>(builder.parameters); | ||
| } | ||
|
|
||
| public boolean isEnabled() { return enabled; } | ||
| public CompressionType getType() { return type; } | ||
| public Map<String, Object> getParameters() { return new HashMap<>(parameters); } | ||
|
|
||
| /** Convenience getter for sparsity parameter (TopK) */ | ||
| public double getSparsity() { | ||
| return (double) parameters.getOrDefault("sparsity", 0.01); | ||
| } | ||
|
|
||
| /** Convenience getter for bits parameter (Quantization) */ | ||
| public int getBits() { | ||
| return (int) parameters.getOrDefault("bits", 4); | ||
| } | ||
|
|
||
| @Override | ||
| public String toString() { | ||
| return String.format("CompressionConfig[enabled=%s, type=%s, params=%s]", | ||
| enabled, type.getId(), parameters); | ||
| } | ||
|
|
||
| // ----------------------------------------------------------------------- | ||
| // Builder | ||
| // ----------------------------------------------------------------------- | ||
|
|
||
| public static Builder builder() { | ||
| return new Builder(); | ||
| } | ||
|
|
||
| public static class Builder { | ||
| private boolean enabled = false; | ||
| private CompressionType type = CompressionType.NONE; | ||
| private final Map<String, Object> parameters = new HashMap<>(); | ||
|
|
||
| public Builder enable(boolean enabled) { | ||
| this.enabled = enabled; | ||
| return this; | ||
| } | ||
|
|
||
| public Builder withType(CompressionType type) { | ||
| this.type = type; | ||
| return this; | ||
| } | ||
|
|
||
| public Builder withParameter(String key, Object value) { | ||
| this.parameters.put(key, value); | ||
| return this; | ||
| } | ||
|
|
||
| /** Shorthand for TopK sparsity ratio */ | ||
| public Builder withSparsity(double sparsity) { | ||
| return withParameter("sparsity", sparsity); | ||
| } | ||
|
|
||
| /** Shorthand for quantization bit width */ | ||
| public Builder withBits(int bits) { | ||
| return withParameter("bits", bits); | ||
| } | ||
|
|
||
| public CompressionConfig build() { | ||
| return new CompressionConfig(this); | ||
| } | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,95 @@ | ||
| package org.apache.sysds.runtime.compress; | ||
|
|
||
| import org.apache.sysds.runtime.compress.TopK.TopKCompressor; | ||
| import org.apache.sysds.runtime.compress.Quantization.ProbabilisticQuantizationCompressor; | ||
|
|
||
| /** | ||
| * Factory for creating compressor instances from configuration. | ||
| * Centralizes compressor instantiation and parameter validation. | ||
| * | ||
| * Usage: | ||
| * CompressionConfig config = CompressionConfig.builder() | ||
| * .enable(true) | ||
| * .withType(CompressionType.TOPK) | ||
| * .withSparsity(0.01) | ||
| * .build(); | ||
| * MatrixCompressor compressor = CompressionFactory.create(config); | ||
| * | ||
| * | ||
| */ | ||
| public class CompressionFactory { | ||
|
|
||
| private CompressionFactory() { | ||
| // Utility class — no instantiation | ||
| } | ||
|
|
||
| /** | ||
| * Create a compressor from a CompressionConfig. | ||
| * @param config The compression configuration | ||
| * @return A ready-to-use MatrixCompressor | ||
| * @throws IllegalArgumentException if the config is invalid | ||
| */ | ||
| public static MatrixCompressor create(CompressionConfig config) { | ||
| if (config == null || !config.isEnabled()) { | ||
| return new PassthroughCompressor(); | ||
| } | ||
| return create(config.getType(), config); | ||
| } | ||
|
|
||
| /** | ||
| * Create a compressor for a specific type with given config. | ||
| */ | ||
| public static MatrixCompressor create(CompressionType type, CompressionConfig config) { | ||
|
Comment on lines
+32
to
+42
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You can merge these two methods into a single one, as only the above method is called from outside this class. |
||
| switch (type) { | ||
| case TOPK: | ||
| double sparsity = config.getSparsity(); | ||
| return new TopKCompressor(sparsity, true); | ||
|
|
||
| case PROBABILISTIC_QUANTIZATION: | ||
| int bits = config.getBits(); | ||
| return new ProbabilisticQuantizationCompressor(bits); | ||
|
|
||
| case ONE_BIT_CS: | ||
| throw new UnsupportedOperationException( | ||
| "1-Bit Compressed Sensing not yet implemented"); | ||
|
Comment on lines
+52
to
+54
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are you still planning on implementing compressed sensing? |
||
|
|
||
| case NONE: | ||
| default: | ||
| return new PassthroughCompressor(); | ||
| } | ||
| } | ||
|
|
||
| // ----------------------------------------------------------------------- | ||
| // Passthrough compressor (no-op) for when compression is disabled | ||
| // ----------------------------------------------------------------------- | ||
|
|
||
| /** | ||
| * No-op compressor: returns the matrix as-is. | ||
| * Used when compression is disabled or type is NONE. | ||
| */ | ||
| private static class PassthroughCompressor implements MatrixCompressor { | ||
|
|
||
| @Override | ||
| public CompressedMatrix compress(org.apache.sysds.runtime.matrix.data.MatrixBlock input) | ||
| throws org.apache.sysds.runtime.compress.exceptions.CompressionException { | ||
| return new CompressedMatrix( | ||
| CompressionType.NONE, | ||
| input.getNumRows(), | ||
| input.getNumColumns(), | ||
| input, | ||
| 1.0 | ||
| ); | ||
| } | ||
|
|
||
| @Override | ||
| public org.apache.sysds.runtime.matrix.data.MatrixBlock decompress(CompressedMatrix compressed) | ||
| throws org.apache.sysds.runtime.compress.exceptions.DecompressionException { | ||
| return (org.apache.sysds.runtime.matrix.data.MatrixBlock) compressed.getCompressedData(); | ||
|
Comment on lines
+85
to
+87
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Import the MatrixBlock and use it as |
||
| } | ||
|
|
||
| @Override | ||
| public CompressionType getCompressionType() { | ||
| return CompressionType.NONE; | ||
|
Comment on lines
+90
to
+92
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove if not needed. |
||
| } | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,43 @@ | ||
| package org.apache.sysds.runtime.compress; | ||
|
|
||
/**
 * Enumeration of supported compression techniques for federated learning.
 * Used for configuration, serialization, and technique selection.
 *
 * @author Nirvan C. Udaysingh Jhurree
 */
public enum CompressionType {

    /** TopK sparsification: keep largest-magnitude elements only */
    TOPK("topk", "Top-K Sparsification"),

    /** Probabilistic quantization: reduce precision with stochastic rounding */
    PROBABILISTIC_QUANTIZATION("prob_quant", "Probabilistic Quantization"),

    /** 1-bit compressed sensing: sign-only transmission + iterative reconstruction */
    ONE_BIT_CS("1bit_cs", "1-Bit Compressed Sensing"),

    /** No compression (passthrough) */
    NONE("none", "No Compression");

    private final String id;
    private final String description;

    CompressionType(String id, String description) {
        this.id = id;
        this.description = description;
    }

    /** @return the short identifier, e.g. {@code "topk"} */
    public String getId() { return id; }

    /** @return the human-readable name of the technique */
    public String getDescription() { return description; }

    /**
     * Parse from string identifier (case-insensitive).
     *
     * <p>Surrounding whitespace is ignored, and both the short identifier
     * (e.g. {@code "prob_quant"}) and the constant name
     * (e.g. {@code "PROBABILISTIC_QUANTIZATION"}) are accepted.
     *
     * @param text identifier or constant name to look up
     * @return the matching compression type
     * @throws IllegalArgumentException if {@code text} is {@code null} or
     *         matches no type
     */
    public static CompressionType fromString(String text) {
        if (text != null) {
            String wanted = text.trim();
            for (CompressionType candidate : values()) {
                if (candidate.id.equalsIgnoreCase(wanted)
                        || candidate.name().equalsIgnoreCase(wanted)) {
                    return candidate;
                }
            }
        }
        throw new IllegalArgumentException("Unknown compression type: " + text);
    }
}
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,44 @@ | ||
| package org.apache.sysds.runtime.compress; | ||
|
|
||
| import org.apache.sysds.runtime.compress.exceptions.CompressionException; | ||
| import org.apache.sysds.runtime.compress.exceptions.DecompressionException; | ||
| import org.apache.sysds.runtime.matrix.data.MatrixBlock; | ||
|
|
||
| /** | ||
| * Interface for matrix compression techniques in federated learning. | ||
| * All compressors must implement compress/decompress operations. | ||
| * | ||
| * @author Nirvan C. Udaysingh Jhurree | ||
| */ | ||
| public interface MatrixCompressor { | ||
|
|
||
| /** | ||
| * Compress a matrix block for transmission. | ||
| * @param input The source matrix to compress | ||
| * @return CompressedMatrix containing compressed data and metadata | ||
| * @throws CompressionException if compression fails | ||
| */ | ||
| CompressedMatrix compress(MatrixBlock input) throws CompressionException; | ||
|
|
||
| /** | ||
| * Decompress a compressed matrix back to MatrixBlock. | ||
| * @param compressed The compressed data to decompress | ||
| * @return Reconstructed MatrixBlock (may be approximate) | ||
| * @throws DecompressionException if decompression fails | ||
| */ | ||
| MatrixBlock decompress(CompressedMatrix compressed) throws DecompressionException; | ||
|
|
||
| /** | ||
| * Get the compression technique identifier. | ||
| * @return CompressionType enum value | ||
| */ | ||
| CompressionType getCompressionType(); | ||
|
Comment on lines
+31
to
+35
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove if not needed. |
||
|
|
||
| /** | ||
| * Estimate the compression ratio achieved. | ||
| * Higher is better (e.g. 10.0 means 10x smaller). | ||
| */ | ||
| default double estimateCompressionRatio(long originalSize, long compressedSize) { | ||
| return compressedSize == 0 ? Double.MAX_VALUE : (double) originalSize / compressedSize; | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No need to create a separate workflow for this test. You can move your test class to one of the test folders so that it gets executed automatically, or you can add the path to your test to the `javaTests.yml` workflow.