Skip to content

Commit c9f00ca

Browse files
authored
Merge pull request #62 from beehive-lab/refactor/generarilize
[refactor] Generalize the design of `tornadovm` package to support multiple new models and types for GPU exec
2 parents 6367a00 + af0eb5c commit c9f00ca

File tree

80 files changed

+3663
-3626
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

80 files changed

+3663
-3626
lines changed

Makefile

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,14 @@ package:
2323
package-with-clean:
2424
$(MVN) clean package -DskipTests
2525

26+
lint:
27+
$(MVN) -T12C -Pspotless spotless:check
28+
29+
# Automatically format the code to conform to a style guide.
30+
# Modifies the code to ensure consistent formatting.
31+
format:
32+
$(MVN) -T12C -Pspotless spotless:apply
33+
2634
# Display help
2735
help:
2836
@echo "Available targets:"

external/tornadovm

Submodule tornadovm updated 50 files

pom.xml

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,85 @@
157157
<autoPublish>true</autoPublish>
158158
</configuration>
159159
</plugin>
160+
160161
</plugins>
161162
</build>
163+
164+
<!-- Profiles for optional/conditional builds -->
165+
<profiles>
166+
<!-- Spotless: Code formatting and style checking (Optional)
167+
Usage: mvn -Pspotless spotless:check (to check violations)
168+
mvn -Pspotless spotless:apply (to apply fixes)
169+
-->
170+
<profile>
171+
<id>spotless</id>
172+
<build>
173+
<plugins>
174+
<plugin>
175+
<groupId>com.diffplug.spotless</groupId>
176+
<artifactId>spotless-maven-plugin</artifactId>
177+
<version>2.44.4</version>
178+
<configuration>
179+
<!-- Only check files changed since main branch (smart incremental checking) -->
180+
<ratchetFrom>origin/main</ratchetFrom>
181+
182+
<!-- Java files formatting -->
183+
<java>
184+
<includes>
185+
<include>src/main/java/**/*.java</include>
186+
<include>src/test/java/**/*.java</include>
187+
</includes>
188+
<excludes>
189+
<exclude>**/target/**</exclude>
190+
</excludes>
191+
<!-- Google Java Format (AOSP style) -->
192+
<googleJavaFormat>
193+
<version>1.19.2</version>
194+
<style>AOSP</style>
195+
</googleJavaFormat>
196+
<!-- Ensure newline at end of file -->
197+
<endWithNewline/>
198+
<!-- Remove trailing whitespace -->
199+
<trimTrailingWhitespace/>
200+
</java>
201+
202+
<!-- POM file (pom.xml) formatting -->
203+
<pom>
204+
<includes>
205+
<include>pom.xml</include>
206+
</includes>
207+
<sortPom>
208+
<nrOfIndentSpace>4</nrOfIndentSpace>
209+
<expandEmptyElements>false</expandEmptyElements>
210+
</sortPom>
211+
</pom>
212+
213+
<!-- Markdown files -->
214+
<markdown>
215+
<includes>
216+
<include>**/*.md</include>
217+
</includes>
218+
<excludes>
219+
<exclude>**/target/**</exclude>
220+
</excludes>
221+
</markdown>
222+
223+
<!-- Properties files -->
224+
<format>
225+
<name>props</name>
226+
<includes>
227+
<include>src/**/*.properties</include>
228+
</includes>
229+
<excludes>
230+
<exclude>**/target/**</exclude>
231+
</excludes>
232+
<trimTrailingWhitespace/>
233+
<endWithNewline/>
234+
</format>
235+
</configuration>
236+
</plugin>
237+
</plugins>
238+
</build>
239+
</profile>
240+
</profiles>
162241
</project>

src/main/java/org/beehive/gpullama3/auxiliary/Tuple2.java

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,23 @@ public U getSecond() {
1919

2020
@Override
2121
public String toString() {
22-
return "Tuple2{" +
23-
"first=" + first +
24-
", second=" + second +
25-
'}';
22+
return "Tuple2{" + "first=" + first + ", second=" + second + '}';
2623
}
2724

2825
@Override
2926
public boolean equals(Object o) {
30-
if (this == o) return true;
31-
if (o == null || getClass() != o.getClass()) return false;
27+
if (this == o) {
28+
return true;
29+
}
30+
if (o == null || getClass() != o.getClass()) {
31+
return false;
32+
}
3233

3334
Tuple2<?, ?> tuple2 = (Tuple2<?, ?>) o;
3435

35-
if (!first.equals(tuple2.first)) return false;
36+
if (!first.equals(tuple2.first)) {
37+
return false;
38+
}
3639
return second.equals(tuple2.second);
3740
}
3841

src/main/java/org/beehive/gpullama3/inference/InferenceEngine.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import org.beehive.gpullama3.inference.state.State;
66
import org.beehive.gpullama3.model.Configuration;
77
import org.beehive.gpullama3.model.Model;
8-
import org.beehive.gpullama3.tokenizer.impl.Tokenizer;
8+
import org.beehive.gpullama3.tokenizer.Tokenizer;
99
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;
1010
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
1111

src/main/java/org/beehive/gpullama3/inference/sampler/Sampler.java

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,37 @@
33
import org.beehive.gpullama3.Options;
44
import org.beehive.gpullama3.core.model.tensor.FloatTensor;
55
import org.beehive.gpullama3.model.Model;
6-
import org.beehive.gpullama3.tornadovm.FloatArrayUtils;
6+
import org.beehive.gpullama3.tornadovm.utils.FloatArrayUtils;
77
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
88

99
import java.util.random.RandomGenerator;
1010
import java.util.random.RandomGeneratorFactory;
1111

1212
/**
13-
* Generic interface for sampling tokens from probability distributions.
14-
* Supports both FloatTensor and FloatArray tensor implementations.
13+
* Generic interface for sampling tokens from probability distributions. Supports both FloatTensor and FloatArray tensor implementations.
1514
*/
1615
@FunctionalInterface
1716
public interface Sampler {
1817

18+
/**
19+
* Argmax implementation that accepts either a FloatTensor or a FloatArray.
20+
*/
21+
Sampler TENSOR_ARGMAX = tensor -> {
22+
if (tensor instanceof FloatTensor) {
23+
return ((FloatTensor) tensor).argmax();
24+
} else if (tensor instanceof FloatArray) {
25+
return argmaxFloatArray((FloatArray) tensor);
26+
}
27+
throw new IllegalArgumentException("Unsupported tensor type: " + (tensor != null ? tensor.getClass().getName() : "null"));
28+
};
29+
/**
30+
* Legacy ARGMAX for backward compatibility.
31+
*
32+
* @deprecated Use TENSOR_ARGMAX instead
33+
*/
34+
@Deprecated
35+
Sampler ARGMAX = TENSOR_ARGMAX;
36+
1937
/**
2038
* Creates and configures a sampler for token generation based on specified parameters.
2139
*
@@ -103,42 +121,15 @@ static Sampler selectSampler(int vocabularySize, float temperature, float topp,
103121
return sampler;
104122
}
105123

106-
public static Sampler createSampler(Model model, Options options) {
124+
static Sampler createSampler(Model model, Options options) {
107125
return selectSampler(model.configuration().vocabularySize(), options.temperature(), options.topp(), options.seed());
108126
}
109127

110-
/**
111-
* Sample a token from the provided tensor.
112-
*
113-
* @param tensor The tensor containing probabilities/logits
114-
* @return The selected token index
115-
*/
116-
int sampleToken(Object tensor);
117-
118-
/**
119-
* Argmax implementation for FloatTensor.
120-
*/
121-
Sampler TENSOR_ARGMAX = tensor -> {
122-
if (tensor instanceof FloatTensor) {
123-
return ((FloatTensor) tensor).argmax();
124-
} else if (tensor instanceof FloatArray) {
125-
return argmaxFloatArray((FloatArray) tensor);
126-
}
127-
throw new IllegalArgumentException("Unsupported tensor type: " +
128-
(tensor != null ? tensor.getClass().getName() : "null"));
129-
};
130-
131-
/**
132-
* Legacy ARGMAX for backward compatibility.
133-
* @deprecated Use TENSOR_ARGMAX instead
134-
*/
135-
@Deprecated
136-
Sampler ARGMAX = TENSOR_ARGMAX;
137-
138128
/**
139129
* Find the index of the maximum value in a FloatArray.
140130
*
141-
* @param array The FloatArray to find the maximum value in
131+
* @param array
132+
* The FloatArray to find the maximum value in
142133
* @return The index of the maximum value
143134
*/
144135
static int argmaxFloatArray(FloatArray array) {
@@ -155,4 +146,13 @@ static int argmaxFloatArray(FloatArray array) {
155146

156147
return maxIndex;
157148
}
149+
150+
/**
151+
* Sample a token from the provided tensor.
152+
*
153+
* @param tensor
154+
* The tensor containing probabilities/logits
155+
* @return The selected token index
156+
*/
157+
int sampleToken(Object tensor);
158158
}

src/main/java/org/beehive/gpullama3/inference/weights/tornado/q8_0/Q8_0Weights.java renamed to src/main/java/org/beehive/gpullama3/inference/weights/tornado/q8_0/LlamaTornadoWeightsQ8_0.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights;
66
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
77

8-
public class Q8_0Weights implements TornadoWeights {
8+
public class LlamaTornadoWeightsQ8_0 implements TornadoWeights {
99
public FloatArray[] rms_att_weightLayered; // (layer, dim) rmsnorm weights
1010
public Q8_0QuantizedTensor[] wqLayered; // (layer, n_heads * head_size)
1111
public Q8_0QuantizedTensor[] wkLayered; // (layer, n_kv_heads, head_size)
@@ -24,7 +24,7 @@ public class Q8_0Weights implements TornadoWeights {
2424
// (optional) classifier weights for the logits, on the last layer
2525
protected final GGMLType weightType;
2626

27-
public Q8_0Weights(
27+
public LlamaTornadoWeightsQ8_0(
2828
FloatArray tokenEmbeddingTable,
2929
FloatArray[] rms_att_weightLayered,
3030
Q8_0QuantizedTensor[] wqLayered,

src/main/java/org/beehive/gpullama3/inference/weights/tornado/q8_0/Phi3TornadoWeightsQ8_0.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
66

77

8-
public class Phi3TornadoWeightsQ8_0 extends Q8_0Weights {
8+
public class Phi3TornadoWeightsQ8_0 extends LlamaTornadoWeightsQ8_0 {
99

1010
// Phi3-specific weight arrays
1111
public Q8_0QuantizedTensor[] wqkvLayered; // Combined QKV weights: (layer, op_size, dim) where op_size = dim + 2 * (n_kv_heads * head_dim)

src/main/java/org/beehive/gpullama3/inference/weights/tornado/q8_0/Qwen2TornadoWeightsQ8_0.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
66

77

8-
public class Qwen2TornadoWeightsQ8_0 extends Q8_0Weights {
8+
public class Qwen2TornadoWeightsQ8_0 extends LlamaTornadoWeightsQ8_0 {
99

1010
// Qwen2-specific tornado weights
1111
public FloatArray[] q_biasLayered;
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,15 @@
55
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
66

77

8-
public class Qwen3Q8_0TornadoWeights extends Q8_0Weights{
8+
public class Qwen3TornadoWeightsQ8_0 extends LlamaTornadoWeightsQ8_0 {
99

1010
//attnKNorm
1111
public FloatArray[] rms_att_KNormLayered;
1212
//attnQNorm
1313
public FloatArray[] rms_att_QNormLayered;
1414

1515
// @formatter:off
16-
public Qwen3Q8_0TornadoWeights(
16+
public Qwen3TornadoWeightsQ8_0(
1717
FloatArray tokenEmbeddingTable,
1818
FloatArray[] rms_att_weightLayered,
1919
Q8_0QuantizedTensor[] wqLayered,

0 commit comments

Comments
 (0)