Commit 4e719f5

Merge pull request #8 from mikepapadim/feat/fp16_input_weights
Switch to FP16 -> HalfFloatArray for model weights to avoid on-the-fly decoding
2 parents: bb1676b + c6d7a8c
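
In short: the GGUF weights for this model family are stored as FP16, and the previous FloatArray path either widened them to FP32 at load or decoded them on the fly during inference. Holding weights in TornadoVM's HalfFloatArray keeps the 16-bit representation end to end, halving the device-side footprint and reducing each read to a single widening conversion. A minimal sketch of the idea, not code from this commit (Float.floatToFloat16 / Float.float16ToFloat are JDK 20+ methods):

```java
// Sketch: pack weights to FP16 once at load time so inference reads
// are a single widening conversion instead of per-block dequantization.
public final class Fp16PackingSketch {
    public static void main(String[] args) {
        float[] weights = {0.1f, -1.5f, 3.25f};

        // One-time encode at model-load time (JDK 20+).
        short[] fp16 = new short[weights.length];
        for (int i = 0; i < weights.length; i++) {
            fp16[i] = Float.floatToFloat16(weights[i]);
        }

        // Inference-time read: one conversion, no block decoding.
        System.out.println(Float.float16ToFloat(fp16[2])); // 3.25
    }
}
```

The trade-off is a one-time O(n) conversion pass at load, which the diff below pays in loadTensorAsHalfFloatArray.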

File tree

8 files changed: +202 −508 lines

.gitmodules

Lines changed: 2 additions & 0 deletions
```diff
@@ -1,3 +1,5 @@
 [submodule "external/tornadovm"]
 	path = external/tornadovm
 	url = https://github.com/beehive-lab/TornadoVM.git
+	branch = master
+
```

README.md

Lines changed: 0 additions & 6 deletions
```diff
@@ -529,12 +529,6 @@ The secret sauce that transforms regular Java code into GPU-accelerated compute
 -----------
 
 
-## Early performance of v1.0
-
-![GPULlama3.java Performance Comparison](./docs/performance.png)
-
------------
-
 ## License
 
 
```


src/main/java/com/example/loader/weights/ModelLoader.java

Lines changed: 63 additions & 60 deletions
```diff
@@ -9,18 +9,18 @@
 import com.example.core.model.tensor.GGMLTensorEntry;
 import com.example.core.model.tensor.Q4_0FloatTensor;
 import com.example.core.model.tensor.Q8_0FloatTensor;
-import com.example.core.types.Float16;
 import com.example.core.types.Pair;
 import com.example.inference.engine.impl.Configuration;
 import com.example.inference.engine.impl.Llama;
 import com.example.inference.operation.RoPE;
 import com.example.tokenizer.impl.Tokenizer;
 import com.example.tokenizer.vocabulary.Vocabulary;
+import uk.ac.manchester.tornado.api.types.HalfFloat;
 import uk.ac.manchester.tornado.api.types.arrays.ByteArray;
 import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
 
 import java.io.IOException;
-import java.lang.foreign.MemorySegment;
 import java.nio.ByteOrder;
 import java.nio.FloatBuffer;
 import java.nio.channels.FileChannel;
@@ -33,9 +33,6 @@
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
-import static com.example.core.model.tensor.FloatTensor.readByte;
-import static com.example.core.model.tensor.FloatTensor.readShort;
-
 public final class ModelLoader {
     private static final String TOKENIZER_LLAMA_3_MODEL = "gpt2";
 
```

```diff
@@ -104,15 +101,15 @@ private static Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tenso
         return new Weights(
                 // Load directly to TornadoVM format
                 loadTensorAsFloatArray(tokenEmbeddings), loadArrayAsFloatArrayFromBuffer(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
-                loadArrayAsFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
-                loadArrayAsFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
-                loadArrayAsFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
-                loadArrayAsFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
+                loadArrayAsHalfFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
+                loadArrayAsHalfFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
+                loadArrayAsHalfFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
+                loadArrayAsHalfFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
                 loadArrayAsFloatArrayFromBuffer(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
-                loadArrayAsFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")),
-                loadArrayAsFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_down.weight")),
-                loadArrayAsFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), floatBufferToFloatArray(tensorEntries.get("output_norm.weight")),
-                FloatArray.fromArray(ropeFreqs.first()), FloatArray.fromArray(ropeFreqs.second()), createByteArrayFromTensor(outputWeight), outputWeight.ggmlType());
+                loadArrayAsHalfFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")),
+                loadArrayAsHalfFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_down.weight")),
+                loadArrayAsHalfFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), floatBufferToFloatArray(tensorEntries.get("output_norm.weight")),
+                FloatArray.fromArray(ropeFreqs.first()), FloatArray.fromArray(ropeFreqs.second()), loadTensorAsHalfFloatArray(outputWeight), outputWeight.ggmlType());
     }
 
     /**
```

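With the attention and FFN projection weights now held as one HalfFloatArray per layer, a TornadoVM kernel can consume FP16 weights directly. A hypothetical matrix-vector kernel in that style (not part of this commit; it assumes TornadoVM's @Parallel annotation and the HalfFloat.getFloat32() accessor):

```java
import uk.ac.manchester.tornado.api.annotations.Parallel;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;

public final class Fp16MatVec {
    // out[i] = sum_j w[i][j] * x[j], with w stored row-major in FP16.
    // Each weight is widened to FP32 in-register; no separate decode pass.
    public static void matVec(HalfFloatArray w, FloatArray x, FloatArray out, int rows, int cols) {
        for (@Parallel int i = 0; i < rows; i++) {
            float acc = 0.0f;
            for (int j = 0; j < cols; j++) {
                acc += w.get(i * cols + j).getFloat32() * x.get(j);
            }
            out.set(i, acc);
        }
    }
}
```

Note that the norm weights (attn_norm, ffn_norm, output_norm) stay on the FP32 FloatArray path, consistent with keeping small, precision-sensitive tensors in full precision.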
```diff
@@ -132,15 +129,51 @@ private static Weights createStandardWeights(Map<String, GGMLTensorEntry> tensor
                 FloatBuffer.wrap(ropeFreqs.first()), FloatBuffer.wrap(ropeFreqs.second()), loadQuantized(outputWeight), outputWeight.ggmlType());
     }
 
-    private static FloatArray[] loadArrayAsFloatArray(int size, IntFunction<GGMLTensorEntry> getTensorEntry) {
+    private static Tokenizer createTokenizer(Map<String, Object> metadata, Vocabulary vocabulary) {
+        String[] mergeLines = (String[]) metadata.get("tokenizer.ggml.merges");
+        List<Pair<Integer, Integer>> merges = Arrays.stream(mergeLines).map(line -> line.split(" "))
+                .map(parts -> new Pair<>(vocabulary.getIndex(parts[0]).orElseThrow(), vocabulary.getIndex(parts[1]).orElseThrow())).toList();
+
+        int allTokens = vocabulary.size();
+        int baseTokens = 128000; // assume all tokens after the base ones are special.
+        int reservedSpecialTokens = allTokens - baseTokens;
+        List<String> specialTokensList = Arrays.stream(vocabulary.tokens(), baseTokens, allTokens).toList();
+
+        assert specialTokensList.stream().allMatch(token -> vocabulary.getIndex(token).isPresent());
+
+        Map<String, Integer> specialTokens = IntStream.range(0, specialTokensList.size()).boxed().collect(Collectors.toMap(i -> specialTokensList.get(i), i -> baseTokens + i));
+
+        return new Tokenizer(vocabulary, merges, LLAMA_3_PATTERN, specialTokens);
+    }
+
+    public static FloatTensor loadQuantized(GGMLTensorEntry entry) {
+        GGMLType ggmlType = entry.ggmlType();
+        return switch (ggmlType) {
+            // case F32 -> new F32FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
+            case Q8_0 -> new Q8_0FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
+            case Q4_0 -> new Q4_0FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
+            case F16 -> new F16FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
+            default -> throw new UnsupportedOperationException("Quantization format " + ggmlType);
+        };
+    }
+
+    public static FloatArray[] loadArrayAsFloatArray(int size, IntFunction<GGMLTensorEntry> getTensorEntry) {
         FloatArray[] array = new FloatArray[size];
         for (int i = 0; i < size; i++) {
             array[i] = loadTensorAsFloatArray(getTensorEntry.apply(i));
         }
         return array;
     }
 
-    private static FloatArray floatBufferToFloatArray(GGMLTensorEntry tensorEntry) {
+    public static HalfFloatArray[] loadArrayAsHalfFloatArray(int size, IntFunction<GGMLTensorEntry> getTensorEntry) {
+        HalfFloatArray[] array = new HalfFloatArray[size];
+        for (int i = 0; i < size; i++) {
+            array[i] = loadTensorAsHalfFloatArray(getTensorEntry.apply(i));
+        }
+        return array;
+    }
+
+    public static FloatArray floatBufferToFloatArray(GGMLTensorEntry tensorEntry) {
         if (tensorEntry.ggmlType() == GGMLType.F32) {
             FloatBuffer buffer = tensorEntry.memorySegment().asByteBuffer().order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
             return FloatArray.fromFloatBuffer(buffer);
```

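The new loadArrayAsHalfFloatArray mirrors loadArrayAsFloatArray, producing one HalfFloatArray per transformer layer. A call site looks the same as in createTornadoVMWeights above, e.g.:

```java
// Hypothetical call site: FP16 query weights for every layer.
// tensorEntries and config come from the GGUF parser, as in the loader.
HalfFloatArray[] wq = ModelLoader.loadArrayAsHalfFloatArray(
        config.numberOfLayers,
        i -> tensorEntries.get("blk." + i + ".attn_q.weight"));
```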
```diff
@@ -149,20 +182,20 @@ private static FloatArray floatBufferToFloatArray(GGMLTensorEntry tensorEntry) {
         }
     }
 
-    private static FloatArray[] loadArrayAsFloatArrayFromBuffer(int size, IntFunction<GGMLTensorEntry> getTensorEntry) {
+    public static FloatArray[] loadArrayAsFloatArrayFromBuffer(int size, IntFunction<GGMLTensorEntry> getTensorEntry) {
         FloatArray[] array = new FloatArray[size];
         for (int i = 0; i < size; i++) {
             array[i] = floatBufferToFloatArray(getTensorEntry.apply(i));
         }
         return array;
     }
 
-    private static ByteArray createByteArrayFromTensor(GGMLTensorEntry entry) {
+    public static ByteArray createByteArrayFromTensor(GGMLTensorEntry entry) {
         FloatTensor tensor = loadQuantized(entry);
         return ByteArray.fromSegment(tensor.asMemorySegment());
     }
 
-    private static FloatArray loadTensorAsFloatArray(GGMLTensorEntry entry) {
+    public static FloatArray loadTensorAsFloatArray(GGMLTensorEntry entry) {
         if (entry.ggmlType() == GGMLType.F32) {
             // For F32, we can directly create FloatArray from memory
             FloatBuffer buffer = entry.memorySegment().asByteBuffer().order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
@@ -182,50 +215,20 @@ private static FloatArray loadTensorAsFloatArray(GGMLTensorEntry entry) {
         }
     }
 
-    public static float getFloat(int index, int size, MemorySegment memorySegment) {
-        assert 0 <= index && index < size;
-        int blockIndex = index / GGMLType.Q4_0.getBlockSize();
-        int blockOffset = blockIndex * GGMLType.Q4_0.getTypeSize();
-        float scale = Float.float16ToFloat(readShort(memorySegment, blockOffset));
-        byte quant;
-        int modIndex = index % GGMLType.Q4_0.getBlockSize();
-        if (modIndex < GGMLType.Q4_0.getBlockSize() / 2) {
-            quant = (byte) (readByte(memorySegment, blockOffset + Float16.BYTES + modIndex) & 0x0F);
+    public static HalfFloatArray loadTensorAsHalfFloatArray(GGMLTensorEntry entry) {
+        if (entry.ggmlType() == GGMLType.F32) {
+            System.out.println("Loading F32 tensor as HalfFloatArray");
+            return null;
         } else {
-            quant = (byte) ((readByte(memorySegment, blockOffset + Float16.BYTES + modIndex - GGMLType.Q4_0.getBlockSize() / 2) >>> 4) & 0x0F);
+            // For quantized formats, we need to load through FloatTensor
+            FloatTensor tensor = loadQuantized(entry);
+            HalfFloatArray array = new HalfFloatArray(tensor.size());
+            for (int i = 0; i < tensor.size(); i++) {
+                HalfFloat x = new HalfFloat(tensor.getFloat(i));
+                array.set(i, x);
+            }
+            return array;
         }
-        quant -= 8;
-        return quant * scale;
-    }
-
-    private static Tokenizer createTokenizer(Map<String, Object> metadata, Vocabulary vocabulary) {
-        String[] mergeLines = (String[]) metadata.get("tokenizer.ggml.merges");
-        List<Pair<Integer, Integer>> merges = Arrays.stream(mergeLines).map(line -> line.split(" "))
-                .map(parts -> new Pair<>(vocabulary.getIndex(parts[0]).orElseThrow(), vocabulary.getIndex(parts[1]).orElseThrow())).toList();
-
-        int allTokens = vocabulary.size();
-        int baseTokens = 128000; // assume all tokens after the base ones are special.
-        int reservedSpecialTokens = allTokens - baseTokens;
-        List<String> specialTokensList = Arrays.stream(vocabulary.tokens(), baseTokens, allTokens).toList();
-
-        assert specialTokensList.stream().allMatch(token -> vocabulary.getIndex(token).isPresent());
-
-        Map<String, Integer> specialTokens = IntStream.range(0, specialTokensList.size()).boxed().collect(Collectors.toMap(i -> specialTokensList.get(i), i -> baseTokens + i));
-
-        return new Tokenizer(vocabulary, merges, LLAMA_3_PATTERN, specialTokens);
-    }
-
-    public static FloatTensor loadQuantized(GGMLTensorEntry entry) {
-        GGMLType ggmlType = entry.ggmlType();
-        // System.out.println("Loading quantized tensor of type " + entry.name());
-        return switch (ggmlType) {
-            // case F32 -> new F32FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
-            case Q8_0 -> new Q8_0FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
-            case Q4_0 -> new Q4_0FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
-            // case BF16 -> new BF16FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
-            case F16 -> new F16FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
-            default -> throw new UnsupportedOperationException("Quantization format " + ggmlType);
-        };
     }
 
     public static FloatTensor[] loadArrayOfQuantized(int size, IntFunction<GGMLTensorEntry> getTensorEntry) {
```

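Two things stand out in the final hunk: the F32 branch of loadTensorAsHalfFloatArray currently logs a message and returns null rather than converting, and the quantized/F16 branch pays a one-time element-wise decode-and-re-encode through the FloatTensor view. A hypothetical sanity check for a change like this, comparing the two load paths within half-precision tolerance (not part of this commit):

```java
// Hypothetical check: the FP16 load should match the FP32 load within
// half-precision relative error (about 1e-3 for normal values).
FloatArray f32 = ModelLoader.loadTensorAsFloatArray(entry);
HalfFloatArray f16 = ModelLoader.loadTensorAsHalfFloatArray(entry);
for (int i = 0; i < f32.getSize(); i++) {
    float a = f32.get(i);
    float b = f16.get(i).getFloat32();
    assert Math.abs(a - b) <= 1e-3f * Math.max(1.0f, Math.abs(a));
}
```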
