99import com .example .core .model .tensor .GGMLTensorEntry ;
1010import com .example .core .model .tensor .Q4_0FloatTensor ;
1111import com .example .core .model .tensor .Q8_0FloatTensor ;
12- import com .example .core .types .Float16 ;
1312import com .example .core .types .Pair ;
1413import com .example .inference .engine .impl .Configuration ;
1514import com .example .inference .engine .impl .Llama ;
1615import com .example .inference .operation .RoPE ;
1716import com .example .tokenizer .impl .Tokenizer ;
1817import com .example .tokenizer .vocabulary .Vocabulary ;
18+ import uk .ac .manchester .tornado .api .types .HalfFloat ;
1919import uk .ac .manchester .tornado .api .types .arrays .ByteArray ;
2020import uk .ac .manchester .tornado .api .types .arrays .FloatArray ;
21+ import uk .ac .manchester .tornado .api .types .arrays .HalfFloatArray ;
2122
2223import java .io .IOException ;
23- import java .lang .foreign .MemorySegment ;
2424import java .nio .ByteOrder ;
2525import java .nio .FloatBuffer ;
2626import java .nio .channels .FileChannel ;
3333import java .util .stream .Collectors ;
3434import java .util .stream .IntStream ;
3535
36- import static com .example .core .model .tensor .FloatTensor .readByte ;
37- import static com .example .core .model .tensor .FloatTensor .readShort ;
38-
3936public final class ModelLoader {
4037 private static final String TOKENIZER_LLAMA_3_MODEL = "gpt2" ;
4138
@@ -104,15 +101,15 @@ private static Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tenso
104101 return new Weights (
105102 // Load directly to TornadoVM format
106103 loadTensorAsFloatArray (tokenEmbeddings ), loadArrayAsFloatArrayFromBuffer (config .numberOfLayers , i -> tensorEntries .get ("blk." + i + ".attn_norm.weight" )),
107- loadArrayAsFloatArray (config .numberOfLayers , i -> tensorEntries .get ("blk." + i + ".attn_q.weight" )),
108- loadArrayAsFloatArray (config .numberOfLayers , i -> tensorEntries .get ("blk." + i + ".attn_k.weight" )),
109- loadArrayAsFloatArray (config .numberOfLayers , i -> tensorEntries .get ("blk." + i + ".attn_v.weight" )),
110- loadArrayAsFloatArray (config .numberOfLayers , i -> tensorEntries .get ("blk." + i + ".attn_output.weight" )),
104+ loadArrayAsHalfFloatArray (config .numberOfLayers , i -> tensorEntries .get ("blk." + i + ".attn_q.weight" )),
105+ loadArrayAsHalfFloatArray (config .numberOfLayers , i -> tensorEntries .get ("blk." + i + ".attn_k.weight" )),
106+ loadArrayAsHalfFloatArray (config .numberOfLayers , i -> tensorEntries .get ("blk." + i + ".attn_v.weight" )),
107+ loadArrayAsHalfFloatArray (config .numberOfLayers , i -> tensorEntries .get ("blk." + i + ".attn_output.weight" )),
111108 loadArrayAsFloatArrayFromBuffer (config .numberOfLayers , i -> tensorEntries .get ("blk." + i + ".ffn_norm.weight" )),
112- loadArrayAsFloatArray (config .numberOfLayers , i -> tensorEntries .get ("blk." + i + ".ffn_gate.weight" )),
113- loadArrayAsFloatArray (config .numberOfLayers , i -> tensorEntries .get ("blk." + i + ".ffn_down.weight" )),
114- loadArrayAsFloatArray (config .numberOfLayers , i -> tensorEntries .get ("blk." + i + ".ffn_up.weight" )), floatBufferToFloatArray (tensorEntries .get ("output_norm.weight" )),
115- FloatArray .fromArray (ropeFreqs .first ()), FloatArray .fromArray (ropeFreqs .second ()), createByteArrayFromTensor (outputWeight ), outputWeight .ggmlType ());
109+ loadArrayAsHalfFloatArray (config .numberOfLayers , i -> tensorEntries .get ("blk." + i + ".ffn_gate.weight" )),
110+ loadArrayAsHalfFloatArray (config .numberOfLayers , i -> tensorEntries .get ("blk." + i + ".ffn_down.weight" )),
111+ loadArrayAsHalfFloatArray (config .numberOfLayers , i -> tensorEntries .get ("blk." + i + ".ffn_up.weight" )), floatBufferToFloatArray (tensorEntries .get ("output_norm.weight" )),
112+ FloatArray .fromArray (ropeFreqs .first ()), FloatArray .fromArray (ropeFreqs .second ()), loadTensorAsHalfFloatArray (outputWeight ), outputWeight .ggmlType ());
116113 }
117114
118115 /**
@@ -132,15 +129,51 @@ private static Weights createStandardWeights(Map<String, GGMLTensorEntry> tensor
132129 FloatBuffer .wrap (ropeFreqs .first ()), FloatBuffer .wrap (ropeFreqs .second ()), loadQuantized (outputWeight ), outputWeight .ggmlType ());
133130 }
134131
135- private static FloatArray [] loadArrayAsFloatArray (int size , IntFunction <GGMLTensorEntry > getTensorEntry ) {
132+ private static Tokenizer createTokenizer (Map <String , Object > metadata , Vocabulary vocabulary ) {
133+ String [] mergeLines = (String []) metadata .get ("tokenizer.ggml.merges" );
134+ List <Pair <Integer , Integer >> merges = Arrays .stream (mergeLines ).map (line -> line .split (" " ))
135+ .map (parts -> new Pair <>(vocabulary .getIndex (parts [0 ]).orElseThrow (), vocabulary .getIndex (parts [1 ]).orElseThrow ())).toList ();
136+
137+ int allTokens = vocabulary .size ();
138+ int baseTokens = 128000 ; // assume all tokens after the base ones are special.
139+ int reservedSpecialTokens = allTokens - baseTokens ;
140+ List <String > specialTokensList = Arrays .stream (vocabulary .tokens (), baseTokens , allTokens ).toList ();
141+
142+ assert specialTokensList .stream ().allMatch (token -> vocabulary .getIndex (token ).isPresent ());
143+
144+ Map <String , Integer > specialTokens = IntStream .range (0 , specialTokensList .size ()).boxed ().collect (Collectors .toMap (i -> specialTokensList .get (i ), i -> baseTokens + i ));
145+
146+ return new Tokenizer (vocabulary , merges , LLAMA_3_PATTERN , specialTokens );
147+ }
148+
149+ public static FloatTensor loadQuantized (GGMLTensorEntry entry ) {
150+ GGMLType ggmlType = entry .ggmlType ();
151+ return switch (ggmlType ) {
152+ // case F32 -> new F32FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
153+ case Q8_0 -> new Q8_0FloatTensor (FloatTensor .numberOfElements (entry .shape ()), entry .memorySegment ());
154+ case Q4_0 -> new Q4_0FloatTensor (FloatTensor .numberOfElements (entry .shape ()), entry .memorySegment ());
155+ case F16 -> new F16FloatTensor (FloatTensor .numberOfElements (entry .shape ()), entry .memorySegment ());
156+ default -> throw new UnsupportedOperationException ("Quantization format " + ggmlType );
157+ };
158+ }
159+
160+ public static FloatArray [] loadArrayAsFloatArray (int size , IntFunction <GGMLTensorEntry > getTensorEntry ) {
136161 FloatArray [] array = new FloatArray [size ];
137162 for (int i = 0 ; i < size ; i ++) {
138163 array [i ] = loadTensorAsFloatArray (getTensorEntry .apply (i ));
139164 }
140165 return array ;
141166 }
142167
143- private static FloatArray floatBufferToFloatArray (GGMLTensorEntry tensorEntry ) {
168+ public static HalfFloatArray [] loadArrayAsHalfFloatArray (int size , IntFunction <GGMLTensorEntry > getTensorEntry ) {
169+ HalfFloatArray [] array = new HalfFloatArray [size ];
170+ for (int i = 0 ; i < size ; i ++) {
171+ array [i ] = loadTensorAsHalfFloatArray (getTensorEntry .apply (i ));
172+ }
173+ return array ;
174+ }
175+
176+ public static FloatArray floatBufferToFloatArray (GGMLTensorEntry tensorEntry ) {
144177 if (tensorEntry .ggmlType () == GGMLType .F32 ) {
145178 FloatBuffer buffer = tensorEntry .memorySegment ().asByteBuffer ().order (ByteOrder .LITTLE_ENDIAN ).asFloatBuffer ();
146179 return FloatArray .fromFloatBuffer (buffer );
@@ -149,20 +182,20 @@ private static FloatArray floatBufferToFloatArray(GGMLTensorEntry tensorEntry) {
149182 }
150183 }
151184
152- private static FloatArray [] loadArrayAsFloatArrayFromBuffer (int size , IntFunction <GGMLTensorEntry > getTensorEntry ) {
185+ public static FloatArray [] loadArrayAsFloatArrayFromBuffer (int size , IntFunction <GGMLTensorEntry > getTensorEntry ) {
153186 FloatArray [] array = new FloatArray [size ];
154187 for (int i = 0 ; i < size ; i ++) {
155188 array [i ] = floatBufferToFloatArray (getTensorEntry .apply (i ));
156189 }
157190 return array ;
158191 }
159192
160- private static ByteArray createByteArrayFromTensor (GGMLTensorEntry entry ) {
193+ public static ByteArray createByteArrayFromTensor (GGMLTensorEntry entry ) {
161194 FloatTensor tensor = loadQuantized (entry );
162195 return ByteArray .fromSegment (tensor .asMemorySegment ());
163196 }
164197
165- private static FloatArray loadTensorAsFloatArray (GGMLTensorEntry entry ) {
198+ public static FloatArray loadTensorAsFloatArray (GGMLTensorEntry entry ) {
166199 if (entry .ggmlType () == GGMLType .F32 ) {
167200 // For F32, we can directly create FloatArray from memory
168201 FloatBuffer buffer = entry .memorySegment ().asByteBuffer ().order (ByteOrder .LITTLE_ENDIAN ).asFloatBuffer ();
@@ -182,50 +215,20 @@ private static FloatArray loadTensorAsFloatArray(GGMLTensorEntry entry) {
182215 }
183216 }
184217
185- public static float getFloat (int index , int size , MemorySegment memorySegment ) {
186- assert 0 <= index && index < size ;
187- int blockIndex = index / GGMLType .Q4_0 .getBlockSize ();
188- int blockOffset = blockIndex * GGMLType .Q4_0 .getTypeSize ();
189- float scale = Float .float16ToFloat (readShort (memorySegment , blockOffset ));
190- byte quant ;
191- int modIndex = index % GGMLType .Q4_0 .getBlockSize ();
192- if (modIndex < GGMLType .Q4_0 .getBlockSize () / 2 ) {
193- quant = (byte ) (readByte (memorySegment , blockOffset + Float16 .BYTES + modIndex ) & 0x0F );
218+ public static HalfFloatArray loadTensorAsHalfFloatArray (GGMLTensorEntry entry ) {
219+ if (entry .ggmlType () == GGMLType .F32 ) {
220+ System .out .println ("Loading F32 tensor as HalfFloatArray" );
221+ return null ;
194222 } else {
195- quant = (byte ) ((readByte (memorySegment , blockOffset + Float16 .BYTES + modIndex - GGMLType .Q4_0 .getBlockSize () / 2 ) >>> 4 ) & 0x0F );
223+ // For quantized formats, we need to load through FloatTensor
224+ FloatTensor tensor = loadQuantized (entry );
225+ HalfFloatArray array = new HalfFloatArray (tensor .size ());
226+ for (int i = 0 ; i < tensor .size (); i ++) {
227+ HalfFloat x = new HalfFloat (tensor .getFloat (i ));
228+ array .set (i , x );
229+ }
230+ return array ;
196231 }
197- quant -= 8 ;
198- return quant * scale ;
199- }
200-
201- private static Tokenizer createTokenizer (Map <String , Object > metadata , Vocabulary vocabulary ) {
202- String [] mergeLines = (String []) metadata .get ("tokenizer.ggml.merges" );
203- List <Pair <Integer , Integer >> merges = Arrays .stream (mergeLines ).map (line -> line .split (" " ))
204- .map (parts -> new Pair <>(vocabulary .getIndex (parts [0 ]).orElseThrow (), vocabulary .getIndex (parts [1 ]).orElseThrow ())).toList ();
205-
206- int allTokens = vocabulary .size ();
207- int baseTokens = 128000 ; // assume all tokens after the base ones are special.
208- int reservedSpecialTokens = allTokens - baseTokens ;
209- List <String > specialTokensList = Arrays .stream (vocabulary .tokens (), baseTokens , allTokens ).toList ();
210-
211- assert specialTokensList .stream ().allMatch (token -> vocabulary .getIndex (token ).isPresent ());
212-
213- Map <String , Integer > specialTokens = IntStream .range (0 , specialTokensList .size ()).boxed ().collect (Collectors .toMap (i -> specialTokensList .get (i ), i -> baseTokens + i ));
214-
215- return new Tokenizer (vocabulary , merges , LLAMA_3_PATTERN , specialTokens );
216- }
217-
218- public static FloatTensor loadQuantized (GGMLTensorEntry entry ) {
219- GGMLType ggmlType = entry .ggmlType ();
220- // System.out.println("Loading quantized tensor of type " + entry.name());
221- return switch (ggmlType ) {
222- // case F32 -> new F32FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
223- case Q8_0 -> new Q8_0FloatTensor (FloatTensor .numberOfElements (entry .shape ()), entry .memorySegment ());
224- case Q4_0 -> new Q4_0FloatTensor (FloatTensor .numberOfElements (entry .shape ()), entry .memorySegment ());
225- // case BF16 -> new BF16FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
226- case F16 -> new F16FloatTensor (FloatTensor .numberOfElements (entry .shape ()), entry .memorySegment ());
227- default -> throw new UnsupportedOperationException ("Quantization format " + ggmlType );
228- };
229232 }
230233
231234 public static FloatTensor [] loadArrayOfQuantized (int size , IntFunction <GGMLTensorEntry > getTensorEntry ) {
0 commit comments