1+ from posixpath import join
2+ import numpy
3+ from numpy .lib .npyio import save
4+ from script .data_iterator import DataIterator
5+ import tensorflow as tf
6+ import time
7+ import random
8+ import sys
9+ from script .utils import *
10+ from tensorflow .python .framework import ops
11+ import os
12+ import json
13+
# Size constants; the hidden and attention sizes are both twice the
# embedding width.
EMBEDDING_DIM = 18
HIDDEN_SIZE = 2 * EMBEDDING_DIM
ATTENTION_SIZE = 2 * EMBEDDING_DIM

# Score trackers, both starting at zero.
best_auc = best_case_acc = 0.0

# Batch size and maximum history length handed to the data iterator.
batch_size, maxlen = 1, 100

# Every input file lives directly under ``data_location``.
data_location = '../data'
test_file, uid_voc, mid_voc, cat_voc = (
    os.path.join(data_location, file_name)
    for file_name in (
        "local_test_splitByUser",
        "uid_voc.pkl",
        "mid_voc.pkl",
        "cat_voc.pkl",
    )
)
27+
def prepare_data(input, target, maxlen=None, return_neg=False):
    """Pad a batch of samples into dense, zero-padded numpy arrays.

    Each element of ``input`` is read positionally:

    * ``[0]``, ``[1]``, ``[2]`` -- scalars collected into ``uids``, ``mids``
      and ``cats`` (presumably user / item / category ids -- confirm against
      the data iterator).
    * ``[3]``, ``[4]`` -- the ``mid`` and ``cat`` history sequences. The
      history length is measured on ``[4]``; ``[3]`` is assumed to have the
      same length.
    * ``[5]``, ``[6]`` -- the ``noclk`` histories: one fixed-width list per
      history step.

    Args:
        input: batch of samples as described above. Must support repeated
            iteration (e.g. a list).
        target: batch labels; returned as ``numpy.array(target)``.
        maxlen: when not ``None``, histories longer than ``maxlen`` keep only
            their last ``maxlen`` steps.
        return_neg: when true, the two ``noclk`` arrays are appended to the
            result.

    Returns:
        ``(None, None, None, None)`` for an empty batch. Otherwise the tuple
        ``(uids, mids, cats, mid_his, cat_his, mid_mask, target, lengths)``,
        followed by ``noclk_mid_his, noclk_cat_his`` when ``return_neg`` is
        set. ``mid_his`` / ``cat_his`` are int64 ``(n, max_len)`` arrays,
        ``mid_mask`` is float32 with 1.0 on real steps and 0.0 on padding,
        and the ``noclk`` arrays are int64 ``(n, max_len, neg_width)``.
    """
    lengths_x = [len(inp[4]) for inp in input]
    seqs_mid = [inp[3] for inp in input]
    seqs_cat = [inp[4] for inp in input]
    noclk_seqs_mid = [inp[5] for inp in input]
    noclk_seqs_cat = [inp[6] for inp in input]

    # Guard the empty batch up front, independent of ``maxlen``: otherwise
    # ``numpy.max`` below raises on an empty list.
    # NOTE(review): this 4-tuple does not match the arity of the normal
    # return value; it is kept as-is for backward compatibility.
    if not lengths_x:
        return None, None, None, None

    if maxlen is not None:
        # Offset of the first retained step; 0 means the history already fits.
        starts = [max(l_x - maxlen, 0) for l_x in lengths_x]
        seqs_mid = [seq[s:] for seq, s in zip(seqs_mid, starts)]
        seqs_cat = [seq[s:] for seq, s in zip(seqs_cat, starts)]
        noclk_seqs_mid = [seq[s:] for seq, s in zip(noclk_seqs_mid, starts)]
        noclk_seqs_cat = [seq[s:] for seq, s in zip(noclk_seqs_cat, starts)]
        lengths_x = [min(l_x, maxlen) for l_x in lengths_x]

    n_samples = len(seqs_mid)
    maxlen_x = numpy.max(lengths_x)
    # Width of one ``noclk`` step, taken from the first step that exists so
    # that a sample with an empty history does not raise IndexError.
    neg_samples = next(
        (len(step) for seq in noclk_seqs_mid for step in seq), 0)

    # Allocate directly with the final dtype instead of float64 + astype.
    mid_his = numpy.zeros((n_samples, maxlen_x), dtype='int64')
    cat_his = numpy.zeros((n_samples, maxlen_x), dtype='int64')
    noclk_mid_his = numpy.zeros(
        (n_samples, maxlen_x, neg_samples), dtype='int64')
    noclk_cat_his = numpy.zeros(
        (n_samples, maxlen_x, neg_samples), dtype='int64')
    mid_mask = numpy.zeros((n_samples, maxlen_x), dtype='float32')

    rows = zip(lengths_x, seqs_mid, seqs_cat, noclk_seqs_mid, noclk_seqs_cat)
    for idx, (length, s_x, s_y, no_sx, no_sy) in enumerate(rows):
        if not length:
            # Empty history: the row stays all-zero and fully masked.
            # Assigning an empty list into the 3-D slice would not broadcast.
            continue
        mid_mask[idx, :length] = 1.
        mid_his[idx, :length] = s_x
        cat_his[idx, :length] = s_y
        noclk_mid_his[idx, :length, :] = no_sx
        noclk_cat_his[idx, :length, :] = no_sy

    uids = numpy.array([inp[0] for inp in input])
    mids = numpy.array([inp[1] for inp in input])
    cats = numpy.array([inp[2] for inp in input])

    result = (uids, mids, cats, mid_his, cat_his, mid_mask,
              numpy.array(target), numpy.array(lengths_x))
    if return_neg:
        return result + (noclk_mid_his, noclk_cat_his)
    return result
94+
95+
# Iterator over the test split, yielding (source, target) batches.
test_data = DataIterator(
    test_file,
    uid_voc,
    mid_voc,
    cat_voc,
    batch_size,
    maxlen,
    data_location=data_location,
)

counter = 0

# Dump the first two batches to CSV for inspection. The context manager
# closes the file even if prepare_data or a write raises part-way through.
with open("./test_data.csv", "w") as f:
    for src, tgt in test_data:
        uids, mids, cats, mid_his, cat_his, mid_mask, target, sl = \
            prepare_data(src, tgt)
        all_data = [uids, mids, cats, mid_his, cat_his, mid_mask, target, sl]
        for cur_data in all_data:
            # Flatten each array to 1-D so it becomes a run of CSV fields.
            cur_data = numpy.squeeze(cur_data).reshape(-1)
            fields = [str(value) for value in cur_data]
            if fields:
                # ",k," marks the end of one array's run of values.
                f.write(",".join(fields) + ",k,")

        # NOTE(review): the literal is newline + space exactly as found;
        # confirm the trailing space is intended.
        f.write("\n ")
        # Stop once the second batch (counter == 1) has been written.
        if counter >= 1:
            break
        counter += 1
0 commit comments