@@ -23,94 +23,59 @@ static const char* model_config = "{ \
2323 } " ;
2424
2525
26- ::tensorflow::eas::ArrayProto get_proto_float_1 (std::vector<float >& cur_vector){
27- ::tensorflow::eas::ArrayShape array_shape;
28- ::tensorflow::eas::ArrayDataType dtype_f =
29- ::tensorflow::eas::ArrayDataType::DT_FLOAT;
30-
31- array_shape.add_dim (1 );
32- ::tensorflow::eas::ArrayProto input;
33- input.add_float_val ((float )cur_vector.back ());
34- input.set_dtype (dtype_f);
35- *(input.mutable_array_shape ()) = array_shape;
36- return input;
37-
38- }
3926
40- ::tensorflow::eas::ArrayProto get_proto_float_2 (std::vector<float >& cur_vector){
27+ ::tensorflow::eas::ArrayProto get_proto_cc (std::vector<char * >& cur_vector, ::tensorflow::eas::ArrayDataType dtype_f ){
4128 ::tensorflow::eas::ArrayShape array_shape;
42- ::tensorflow::eas::ArrayDataType dtype_f =
43- ::tensorflow::eas::ArrayDataType::DT_FLOAT;
44- int num_elem = (int )cur_vector.size ();
45-
46- array_shape.add_dim (1 );
47- if ((int )cur_vector.size () < 0 ){
48-
49- array_shape.add_dim (1 );
50- ::tensorflow::eas::ArrayProto input;
51- input.add_float_val (1.0 );
52- input.set_dtype (dtype_f);
53- *(input.mutable_array_shape ()) = array_shape;
54-
55- return input;
56- }
57- array_shape.add_dim ((int )cur_vector.size ());
58-
5929 ::tensorflow::eas::ArrayProto input;
60- for (int tt = 0 ; tt < (int )cur_vector.size (); ++tt)
61- {
62- input.add_float_val ((float )cur_vector[tt]);
63- }
30+
31+ int num_elem = (int )cur_vector.size ();
6432 input.set_dtype (dtype_f);
65- *(input.mutable_array_shape ()) = array_shape;
66-
67- return input;
68-
69- }
70-
71- ::tensorflow::eas::ArrayProto get_proto_int_1 (std::vector<int >& cur_vector){
72- ::tensorflow::eas::ArrayShape array_shape;
73- ::tensorflow::eas::ArrayDataType dtype_i =
74- ::tensorflow::eas::ArrayDataType::DT_INT32;
75-
76- array_shape.add_dim (1 );
77- ::tensorflow::eas::ArrayProto input;
78- input.add_int_val ((int )cur_vector.back ());
79- input.set_dtype (dtype_i);
80- *(input.mutable_array_shape ()) = array_shape;
81- return input;
8233
34+ switch (dtype_f){
35+ case 1 :
36+ array_shape.add_dim (1 );
37+ if (num_elem == 1 ){
38+ input.add_float_val ((float )atof (cur_vector.back ()));
39+ *(input.mutable_array_shape ()) = array_shape;
40+ return input;
41+ }
42+ array_shape.add_dim (cur_vector.size ());
43+ for (unsigned int tt = 0 ; tt < cur_vector.size (); ++tt)
44+ {
45+ input.add_float_val ((float )atof (cur_vector[tt]));
46+ }
47+ *(input.mutable_array_shape ()) = array_shape;
48+
49+ return input;
50+
51+ break ;
52+
53+ case 3 :
54+ array_shape.add_dim (1 );
55+ if (num_elem == 1 ){
56+ input.add_int_val ((int )atoi (cur_vector.back ()));
57+ *(input.mutable_array_shape ()) = array_shape;
58+ return input;
59+ }
60+ array_shape.add_dim (cur_vector.size ());
61+ for (unsigned int tt = 0 ; tt < cur_vector.size (); ++tt)
62+ {
63+ input.add_int_val ((int )atoi (cur_vector[tt]));
64+ }
65+ *(input.mutable_array_shape ()) = array_shape;
66+
67+ return input;
68+ break ;
69+
70+ default :
71+ break ;
72+ }
73+
74+ std::cerr << " type error\n " ;
75+ return input;
8376}
8477
85- ::tensorflow::eas::ArrayProto get_proto_int_2 (std::vector<int >& cur_vector){
86- ::tensorflow::eas::ArrayShape array_shape;
87- ::tensorflow::eas::ArrayDataType dtype_f =
88- ::tensorflow::eas::ArrayDataType::DT_INT32;
89- int num_elem = (int )cur_vector.size ();
90-
91- array_shape.add_dim (1 );
92- if ((int )cur_vector.size () < 0 ){
93-
94- array_shape.add_dim (1 );
95- ::tensorflow::eas::ArrayProto input;
96- input.add_int_val (1 );
97- input.set_dtype (dtype_f);
98- *(input.mutable_array_shape ()) = array_shape;
99-
100- return input;
101- }
102- array_shape.add_dim ((int )cur_vector.size ());
103- ::tensorflow::eas::ArrayProto input;
104- for (int tt = 0 ; tt < (int )cur_vector.size (); ++tt)
105- {
106- input.add_int_val ((int )cur_vector[tt]);
107- }
108- input.set_dtype (dtype_f);
109- *(input.mutable_array_shape ()) = array_shape;
11078
111- return input;
112-
113- }
11479
11580
11681int main (int argc, char ** argv) {
@@ -134,14 +99,14 @@ int main(int argc, char** argv) {
13499 int cur_type = 0 ;
135100
136101 // vector variables
137- std::vector<int > cur_uids;
138- std::vector<int > cur_mids;
139- std::vector<int > cur_cats;
140- std::vector<int > cur_sl; // single
141- std::vector<int > cur_mid_his;
142- std::vector<int > cur_cat_his;
143- std::vector<float > cur_mid_mask;
144- std::vector<float > cur_target; // multiple
102+ std::vector<char * > cur_uids;
103+ std::vector<char * > cur_mids;
104+ std::vector<char * > cur_cats;
105+ std::vector<char * > cur_sl; // single
106+ std::vector<char * > cur_mid_his;
107+ std::vector<char * > cur_cat_his;
108+ std::vector<char * > cur_mid_mask;
109+ std::vector<char * > cur_target; // multiple
145110
146111 // temp pointers
147112 std::vector<char *> temp_ptrs;
@@ -184,56 +149,56 @@ int main(int argc, char** argv) {
184149
185150 temp_ptrs.push_back ((char *) malloc (sizeof (char )*strlen (record)));
186151 strcpy (temp_ptrs.back (),record);
187- cur_uids.push_back (( int ) atoi ( temp_ptrs.back () ));
152+ cur_uids.push_back (temp_ptrs.back ());
188153 break ;
189154
190155 case 1 :
191156
192157 temp_ptrs.push_back ((char *) malloc (sizeof (char )*strlen (record)));
193158 strcpy (temp_ptrs.back (),record);
194- cur_mids.push_back (( int ) atoi ( temp_ptrs.back () ));
159+ cur_mids.push_back (temp_ptrs.back ());
195160 break ;
196161
197162 case 2 :
198163
199164 temp_ptrs.push_back ((char *) malloc (sizeof (char )*strlen (record)));
200165 strcpy (temp_ptrs.back (),record);
201- cur_cats.push_back (( int ) atoi ( temp_ptrs.back () ));
166+ cur_cats.push_back (temp_ptrs.back ());
202167 break ;
203168
204169 case 3 :
205170
206171 temp_ptrs.push_back ((char *) malloc (sizeof (char )*strlen (record)));
207172 strcpy (temp_ptrs.back (),record);
208- cur_mid_his.push_back (( int ) atoi ( temp_ptrs.back () ));
173+ cur_mid_his.push_back (temp_ptrs.back ());
209174 break ;
210175
211176 case 4 :
212177
213178 temp_ptrs.push_back ((char *) malloc (sizeof (char )*strlen (record)));
214179 strcpy (temp_ptrs.back (),record);
215- cur_cat_his.push_back (( int ) atoi ( temp_ptrs.back () ));
180+ cur_cat_his.push_back (temp_ptrs.back ());
216181 break ;
217182
218183 case 5 :
219184
220185 temp_ptrs.push_back ((char *) malloc (sizeof (char )*strlen (record)));
221186 strcpy (temp_ptrs.back (),record);
222- cur_mid_mask.push_back (( float ) atof ( temp_ptrs.back () ));
187+ cur_mid_mask.push_back (temp_ptrs.back ());
223188 break ;
224189
225190 case 6 :
226191
227192 temp_ptrs.push_back ((char *) malloc (sizeof (char )*strlen (record)));
228193 strcpy (temp_ptrs.back (),record);
229- cur_target.push_back (( float ) atof ( temp_ptrs.back () ));
194+ cur_target.push_back (temp_ptrs.back ());
230195 break ;
231196
232197 case 7 :
233198
234199 temp_ptrs.push_back ((char *) malloc (sizeof (char )*strlen (record)));
235200 strcpy (temp_ptrs.back (),record);
236- cur_sl.push_back (( int ) atoi ( temp_ptrs.back () ));
201+ cur_sl.push_back (temp_ptrs.back ());
237202 break ;
238203
239204 default :
@@ -244,18 +209,20 @@ int main(int argc, char** argv) {
244209
245210 }
246211
247- // // ---------------------------------------prepare request--------------------------------------
248212
249-
213+ ::tensorflow::eas::ArrayDataType dtype_i =
214+ ::tensorflow::eas::ArrayDataType::DT_INT32;
215+ ::tensorflow::eas::ArrayDataType dtype_f =
216+ ::tensorflow::eas::ArrayDataType::DT_FLOAT;
250217 // get all inputs
251- ::tensorflow::eas::ArrayProto proto_uids = get_proto_int_1 (cur_uids); // -1
252- ::tensorflow::eas::ArrayProto proto_mids = get_proto_int_1 (cur_mids); // -1
253- ::tensorflow::eas::ArrayProto proto_cats = get_proto_int_1 (cur_cats); // -1
254- ::tensorflow::eas::ArrayProto proto_mid_his = get_proto_int_2 (cur_mid_his); // -1 -1
255- ::tensorflow::eas::ArrayProto proto_cat_his = get_proto_int_2 (cur_cat_his); // -1 -1
256- ::tensorflow::eas::ArrayProto proto_mid_mask= get_proto_float_2 (cur_mid_mask); // float // -1 -1
257- ::tensorflow::eas::ArrayProto proto_target = get_proto_float_2 (cur_target); // float // -1 -1
258- ::tensorflow::eas::ArrayProto proto_sl = get_proto_int_1 (cur_sl); // -1
218+ ::tensorflow::eas::ArrayProto proto_uids = get_proto_cc (cur_uids,dtype_i ); // -1
219+ ::tensorflow::eas::ArrayProto proto_mids = get_proto_cc (cur_mids,dtype_i ); // -1
220+ ::tensorflow::eas::ArrayProto proto_cats = get_proto_cc (cur_cats,dtype_i ); // -1
221+ ::tensorflow::eas::ArrayProto proto_mid_his = get_proto_cc (cur_mid_his,dtype_i ); // -1 -1
222+ ::tensorflow::eas::ArrayProto proto_cat_his = get_proto_cc (cur_cat_his,dtype_i ); // -1 -1
223+ ::tensorflow::eas::ArrayProto proto_mid_mask= get_proto_cc (cur_mid_mask,dtype_f ); // float // -1 -1
224+ ::tensorflow::eas::ArrayProto proto_target = get_proto_cc (cur_target,dtype_f ); // float // -1 -1
225+ ::tensorflow::eas::ArrayProto proto_sl = get_proto_cc (cur_sl,dtype_i ); // -1
259226
260227
261228 // setup request
@@ -276,7 +243,7 @@ int main(int argc, char** argv) {
276243 void *buffer1 = malloc (size);
277244 req.SerializeToArray (buffer1, size);
278245
279- // // -------------------------------------process and get feedback-----------------------------------------
246+ // ---------------------------------------------- process and get feedback---------- -----------------------------------------
280247 void * output = nullptr ;
281248 int output_size = 0 ;
282249 state = process (model, buffer1, size, &output, &output_size);
0 commit comments