77import org .tensorflow .framework .initializers .Glorot ;
88import org .tensorflow .framework .initializers .VarianceScaling ;
99import org .tensorflow .framework .utils .TestSession ;
10+ import org .tensorflow .ndarray .FloatNdArray ;
1011import org .tensorflow .ndarray .Shape ;
1112import org .tensorflow .ndarray .buffer .DataBuffers ;
1213import org .tensorflow .op .Op ;
2526import org .tensorflow .types .family .TType ;
2627
2728import java .util .ArrayList ;
29+ import java .util .Arrays ;
2830import java .util .List ;
2931
32+ import static org .junit .jupiter .api .Assertions .assertArrayEquals ;
3033import static org .junit .jupiter .api .Assertions .assertEquals ;
3134
3235/** Test cases for GradientDescent Optimizer */
@@ -125,6 +128,7 @@ public void testDeterminism() {
125128 GraphDef def ;
126129 String initName ;
127130 String trainName ;
131+ String lossName ;
128132
129133 String fcWeightName , fcBiasName , outputWeightName , outputBiasName ;
130134
@@ -159,8 +163,9 @@ public void testDeterminism() {
159163 Mean <TFloat32 > loss =
160164 tf .math .mean (
161165 tf .nn .raw .softmaxCrossEntropyWithLogits (output , placeholder ).loss (), tf .constant (0 ));
166+ lossName = loss .op ().name ();
162167
163- GradientDescent gd = new GradientDescent (g , 0.1f );
168+ GradientDescent gd = new GradientDescent (g , 10.0f );
164169 Op trainingOp = gd .minimize (loss );
165170 trainName = trainingOp .op ().name ();
166171
@@ -177,12 +182,14 @@ public void testDeterminism() {
177182 -14.0f , -15.0f , 0.16f , 0.17f , 0.18f , 1.9f , 0.2f
178183 };
179184 TFloat32 dataTensor = TFloat32 .tensorOf (Shape .of (1 , 20 ), DataBuffers .of (data ));
180- float [] target = new float [] {0.0f , 1.0f };
185+ float [] target = new float [] {0.2f , 0.8f };
181186 TFloat32 targetTensor = TFloat32 .tensorOf (Shape .of (1 , 2 ), DataBuffers .of (target ));
182187
183- int numRuns = 10 ;
188+ int numRuns = 20 ;
184189 List <List <Tensor >> initialized = new ArrayList <>(numRuns );
185190 List <List <Tensor >> trained = new ArrayList <>(numRuns );
191+ float [] initialLoss = new float [numRuns ];
192+ float [] postTrainingLoss = new float [numRuns ];
186193
187194 for (int i = 0 ; i < numRuns ; i ++) {
188195 try (Graph g = new Graph ();
@@ -197,12 +204,16 @@ public void testDeterminism() {
197204 .fetch (outputWeightName )
198205 .fetch (outputBiasName )
199206 .run ());
207+ System .out .println ("Initialized - " + ndArrToString ((TFloat32 )initialized .get (i ).get (3 )));
200208
201- s .runner ()
209+ TFloat32 lossVal = ( TFloat32 ) s .runner ()
202210 .addTarget (trainName )
203211 .feed ("input" , dataTensor )
204212 .feed ("output" , targetTensor )
205- .run ();
213+ .fetch (lossName )
214+ .run ().get (0 );
215+ initialLoss [i ] = lossVal .getFloat ();
216+ lossVal .close ();
206217
207218 trained .add (
208219 s .runner ()
@@ -211,10 +222,25 @@ public void testDeterminism() {
211222 .fetch (outputWeightName )
212223 .fetch (outputBiasName )
213224 .run ());
225+ System .out .println ("Initialized - " + ndArrToString ((TFloat32 )initialized .get (i ).get (3 )));
226+ System .out .println ("Trained - " + ndArrToString ((TFloat32 )trained .get (i ).get (3 )));
227+
228+ lossVal = (TFloat32 ) s .runner ()
229+ .addTarget (trainName )
230+ .feed ("input" , dataTensor )
231+ .feed ("output" , targetTensor )
232+ .fetch (lossName )
233+ .run ().get (0 );
234+ postTrainingLoss [i ] = lossVal .getFloat ();
235+ lossVal .close ();
214236 }
215237 }
216238
217239 for (int i = 1 ; i < numRuns ; i ++) {
240+ assertEquals (initialLoss [0 ],initialLoss [i ]);
241+ assertEquals (postTrainingLoss [0 ],postTrainingLoss [i ]);
242+ // Because the weights are references not copies.
243+ assertEquals (initialized .get (i ),trained .get (i ));
218244 assertEquals (
219245 initialized .get (0 ),
220246 initialized .get (i ),
@@ -234,4 +260,10 @@ public void testDeterminism() {
234260 }
235261 }
236262 }
263+
264+ private static String ndArrToString (FloatNdArray ndarray ) {
265+ StringBuffer sb = new StringBuffer ();
266+ ndarray .scalars ().forEachIndexed ((idx ,array ) -> sb .append (Arrays .toString (idx )).append (" = " ).append (array .getFloat ()).append ("\n " ));
267+ return sb .toString ();
268+ }
237269}
0 commit comments