Commit bf87ab8c authored by Sushant Mahajan

modified model and corrected loss calculation

parent c90c56f0
Id,Label
0,0
1,1
2,1
3,1
4,0
5,0
6,0
7,0
8,0
9,0
10,1
11,1
12,0
13,0
14,1
15,0
16,1
17,1
18,0
19,0
20,0
21,0
22,0
23,1
24,0
25,1
26,0
27,0
28,0
29,0
30,0
31,1
32,0
33,1
34,0
35,0
36,0
37,1
38,0
39,1
40,0
41,1
42,1
43,1
44,0
45,0
46,0
47,0
48,1
49,0
50,1
51,0
52,1
53,0
54,1
55,1
56,0
57,0
58,1
59,1
60,1
61,0
62,0
63,0
64,0
65,1
66,0
67,0
68,0
69,0
70,1
71,0
72,0
73,1
74,1
75,0
76,1
77,0
78,0
79,0
80,1
81,0
82,0
83,0
84,1
85,1
86,0
87,1
88,0
89,0
90,1
91,1
92,0
93,0
94,1
95,0
96,1
97,0
98,0
99,0
100,1
101,0
102,0
103,0
104,1
105,1
106,1
107,0
108,1
109,0
110,1
111,1
112,1
113,1
114,0
115,0
116,0
117,0
118,1
119,0
120,1
121,0
122,1
123,0
124,0
125,1
126,1
127,1
128,0
129,0
130,1
131,1
132,0
133,0
134,0
135,1
136,1
137,1
138,0
139,0
140,0
141,0
142,1
143,0
144,1
145,0
146,0
147,0
148,1
149,1
150,0
151,0
152,1
153,1
154,0
155,1
156,0
157,1
158,0
159,0
160,1
161,0
162,0
163,1
164,1
165,0
166,1
167,1
168,0
169,0
170,0
171,0
172,1
173,1
174,1
175,0
176,1
177,1
178,0
179,1
180,0
181,0
182,1
183,1
184,0
185,0
186,0
187,1
188,0
189,1
190,1
191,0
192,0
193,1
194,0
195,1
196,0
197,0
198,1
199,1
200,0
201,0
202,0
203,0
204,0
205,0
206,1
207,0
208,0
209,0
210,0
211,1
212,0
213,0
214,0
215,1
216,1
217,1
218,1
219,0
220,0
221,1
222,0
223,1
224,1
225,1
226,1
227,0
228,0
229,1
230,1
231,0
232,1
233,0
234,0
235,1
236,0
237,0
238,1
239,1
240,0
241,1
242,0
243,1
244,0
245,0
246,0
247,0
248,0
249,1
250,1
251,0
252,1
253,0
254,1
255,1
256,0
257,0
258,1
259,1
260,1
261,0
262,0
263,1
264,0
265,1
266,1
267,0
268,1
269,0
270,1
271,1
272,1
273,0
274,0
275,0
276,0
277,0
278,0
279,0
280,1
281,0
282,0
283,0
284,1
285,1
286,0
287,0
288,0
289,0
290,1
291,1
292,0
293,1
294,0
295,0
296,1
297,0
298,0
299,0
300,0
301,0
302,0
303,0
304,1
305,1
306,0
307,0
308,0
309,1
310,1
311,1
312,0
313,1
314,0
315,0
316,1
317,0
318,0
319,0
320,0
321,0
322,0
323,0
324,1
325,0
326,0
327,0
328,0
329,1
330,1
331,0
332,1
333,1
334,0
335,1
336,1
337,1
338,1
339,1
340,1
341,0
342,1
343,0
344,1
345,1
346,0
347,1
348,0
349,0
350,0
351,1
352,0
353,0
354,0
355,0
356,0
357,1
358,1
359,1
360,1
361,0
362,1
363,0
364,1
365,0
366,0
367,0
368,0
369,0
370,0
371,0
372,0
373,0
374,0
375,1
376,0
377,0
378,1
379,0
380,1
381,0
382,0
383,0
384,1
385,1
386,0
387,0
388,0
389,0
390,1
391,0
392,1
393,0
394,0
395,1
396,1
397,0
398,0
399,1
400,1
401,1
402,0
403,0
404,0
405,1
406,1
407,0
408,0
409,0
410,0
411,0
412,0
413,1
414,0
415,0
416,0
417,1
418,1
419,0
420,1
421,1
422,0
423,1
424,1
425,0
426,1
427,0
428,0
429,0
430,0
431,0
432,1
433,1
434,1
435,0
436,1
437,0
438,0
439,0
440,0
441,0
442,0
443,0
444,0
445,1
446,0
447,0
448,1
449,0
450,1
451,1
452,1
453,1
454,1
455,1
456,0
457,0
458,1
459,1
460,1
461,0
462,1
463,1
464,0
465,1
466,1
467,1
468,1
469,0
470,1
471,1
472,0
473,1
474,1
475,0
476,0
477,0
478,1
479,1
480,1
481,1
482,0
483,0
484,0
485,0
486,0
487,0
488,0
489,1
490,0
491,0
492,0
493,0
494,1
495,0
496,0
497,1
498,1
499,0
500,0
501,0
502,0
503,1
504,0
505,0
506,1
507,1
508,0
509,0
510,0
511,0
512,0
513,0
514,0
515,0
516,1
517,1
518,0
519,1
520,1
521,0
522,0
523,0
524,0
525,1
526,1
527,0
528,0
529,1
530,0
531,1
532,1
533,0
534,1
535,0
536,0
537,1
538,0
539,0
540,1
541,0
542,0
543,0
544,0
545,0
546,0
547,0
548,1
549,0
550,1
551,1
552,1
553,1
554,0
555,0
556,1
557,1
558,0
559,0
560,0
561,1
562,1
563,0
564,0
565,0
566,1
567,0
568,1
569,0
570,0
571,1
572,1
573,1
574,0
575,1
576,0
577,0
578,0
579,0
580,0
581,1
582,1
583,0
584,1
585,0
586,0
587,0
588,1
589,0
590,1
591,0
592,1
593,0
594,0
595,0
596,0
597,0
598,0
599,1
600,1
601,1
602,1
603,0
604,0
605,0
606,0
607,0
608,1
609,0
610,1
611,1
612,1
613,0
614,1
615,0
616,0
617,0
618,1
619,0
620,0
621,0
622,0
623,0
624,1
625,0
626,0
627,0
628,1
629,0
630,0
631,0
632,0
633,0
634,1
635,1
636,1
637,1
638,1
639,0
640,0
641,0
642,1
643,0
644,1
645,1
646,0
647,0
648,0
649,1
650,0
651,0
652,1
653,0
654,1
655,0
656,1
657,0
658,0
659,0
660,1
661,0
662,0
663,0
664,1
665,1
666,0
667,0
668,1
669,0
670,1
671,0
672,1
673,0
674,0
675,1
676,1
677,0
678,0
679,1
680,0
681,1
682,0
683,1
684,1
685,1
686,1
687,0
688,0
689,0
690,0
691,0
692,0
693,0
694,0
695,1
696,0
697,0
698,0
699,1
700,1
701,0
702,0
703,0
704,0
705,1
706,0
707,0
708,0
709,0
710,1
711,1
712,1
713,0
714,0
715,0
716,0
717,0
718,0
719,0
720,1
721,1
722,1
723,0
724,0
725,0
726,1
727,1
728,1
729,1
730,0
731,0
732,1
733,0
734,1
735,0
736,0
737,0
738,1
739,0
740,0
741,0
742,0
743,0
744,1
745,0
746,1
747,1
748,0
749,1
750,0
751,1
752,0
753,1
754,0
755,0
756,0
757,0
758,1
759,1
760,1
761,0
762,1
763,0
764,1
765,1
766,0
767,0
768,1
769,0
770,1
771,1
772,0
773,1
774,1
775,0
776,1
777,0
778,0
779,0
780,0
781,0
782,1
783,0
784,0
785,1
786,1
787,0
788,1
789,0
790,1
791,0
792,1
793,0
794,1
795,0
796,0
797,0
798,0
799,1
800,0
801,0
802,0
803,1
804,1
805,0
806,0
807,0
808,1
809,0
810,1
811,0
812,0
813,1
814,0
815,1
816,1
817,1
818,0
819,0
820,0
821,0
822,0
823,1
824,0
825,0
826,1
827,0
828,1
829,1
830,0
831,1
832,0
833,0
834,0
835,1
836,0
837,0
838,0
839,1
840,1
841,1
842,0
843,1
844,0
845,0
846,0
847,0
848,1
849,1
850,1
851,1
852,1
853,1
854,1
855,0
856,0
857,0
858,0
859,1
860,0
861,0
862,0
863,1
864,0
865,1
866,0
867,0
868,0
869,1
870,1
871,1
872,0
873,0
874,1
875,1
876,1
877,0
878,0
879,1
880,0
881,1
882,0
883,0
884,1
885,1
886,0
887,0
888,0
889,0
890,1
891,0
892,1
893,0
894,1
895,0
896,1
897,1
898,0
899,0
900,1
901,0
902,0
903,1
904,0
905,0
906,1
907,1
908,0
909,0
910,1
911,0
912,1
913,1
914,0
915,0
916,0
917,0
918,1
919,1
920,1
921,0
922,0
923,0
924,0
925,1
926,0
927,1
928,1
929,0
930,1
931,1
932,0
933,0
934,0
935,1
936,0
937,1
938,0
939,0
940,0
941,0
942,1
943,0
944,1
945,1
946,0
947,1
948,1
949,0
950,0
951,0
952,0
953,0
954,1
955,0
956,0
957,1
958,0
959,0
960,0
961,0
962,1
963,0
964,0
965,1
966,0
967,0
968,0
969,0
970,1
971,1
972,1
973,1
974,0
975,0
976,1
977,0
978,0
979,0
980,1
981,0
982,0
983,0
984,0
985,1
986,0
987,0
988,1
989,0
990,0
991,0
992,0
993,0
994,0
995,1
996,0
997,0
998,0
999,0
1000,0
1001,1
1002,1
1003,0
1004,1
1005,0
1006,0
1007,0
1008,1
1009,0
1010,0
1011,0
1012,1
1013,1
1014,1
1015,0
1016,0
1017,0
1018,0
1019,1
1020,0
1021,1
1022,1
1023,0
1024,0
1025,1
1026,1
1027,1
1028,0
1029,1
1030,0
1031,0
1032,1
1033,0
1034,0
1035,0
1036,1
1037,1
1038,1
1039,1
1040,0
1041,0
1042,0
1043,1
1044,0
1045,0
1046,0
1047,1
1048,0
1049,1
1050,1
1051,0
1052,0
1053,0
1054,0
1055,0
1056,0
1057,0
1058,0
1059,1
1060,1
1061,1
1062,0
1063,0
1064,0
1065,1
1066,1
1067,0
1068,0
1069,0
1070,1
1071,1
1072,0
1073,0
1074,1
1075,0
1076,1
1077,1
1078,0
1079,0
1080,0
1081,0
1082,0
1083,1
1084,1
1085,1
1086,1
1087,0
1088,1
1089,0
1090,1
1091,0
1092,1
1093,0
1094,0
1095,0
1096,0
1097,0
1098,1
1099,1
1100,0
1101,0
1102,0
1103,1
1104,0
1105,0
1106,0
1107,1
1108,1
1109,0
1110,0
1111,1
1112,0
1113,0
1114,0
1115,1
1116,0
1117,0
1118,1
1119,0
1120,0
1121,1
1122,1
1123,0
1124,1
1125,1
1126,0
1127,1
1128,0
1129,1
1130,0
1131,0
1132,0
1133,0
1134,0
1135,0
1136,0
1137,1
1138,1
1139,1
1140,0
1141,1
1142,0
1143,1
1144,0
1145,0
1146,0
1147,0
1148,1
1149,0
1150,1
1151,0
1152,1
1153,1
1154,1
1155,1
1156,0
1157,0
1158,1
1159,0
1160,1
1161,0
1162,0
1163,0
1164,0
1165,0
1166,1
1167,0
1168,0
1169,0
1170,0
1171,1
1172,1
1173,0
1174,0
1175,0
1176,0
1177,0
1178,0
1179,1
1180,0
1181,1
1182,0
1183,0
1184,0
1185,0
1186,1
1187,0
1188,0
1189,0
1190,1
1191,1
1192,0
1193,0
1194,0
1195,1
1196,0
1197,0
1198,0
1199,0
1200,1
1201,1
1202,0
1203,1
1204,1
1205,1
1206,0
1207,0
1208,1
1209,1
1210,0
1211,1
1212,1
1213,1
1214,1
1215,1
1216,0
1217,0
1218,0
1219,1
1220,0
1221,0
1222,0
1223,0
1224,1
1225,0
1226,1
1227,1
1228,1
1229,0
1230,0
1231,1
1232,0
1233,0
1234,0
1235,0
1236,1
1237,1
1238,0
1239,1
1240,1
1241,0
1242,0
1243,0
1244,1
1245,0
1246,0
1247,0
1248,0
1249,0
1250,1
1251,0
1252,0
1253,1
1254,1
1255,0
1256,0
1257,0
1258,0
1259,0
1260,1
1261,1
1262,0
1263,1
1264,1
1265,0
1266,0
1267,0
1268,0
1269,0
1270,1
1271,1
1272,0
1273,1
1274,0
1275,0
1276,1
1277,0
1278,0
1279,0
1280,0
1281,0
1282,0
1283,0
1284,0
1285,0
1286,0
1287,0
1288,0
1289,1
1290,0
1291,0
1292,1
1293,1
1294,0
1295,1
1296,1
1297,0
1298,0
1299,1
1300,0
1301,1
1302,0
1303,0
1304,0
1305,1
1306,0
1307,1
1308,1
1309,0
1310,1
1311,1
1312,1
1313,0
1314,0
1315,0
1316,0
1317,0
1318,0
1319,0
1320,0
1321,0
1322,0
1323,1
1324,0
1325,0
1326,0
1327,0
1328,0
1329,0
1330,1
1331,0
1332,0
1333,0
1334,0
1335,0
1336,1
1337,0
1338,0
1339,1
1340,1
1341,1
1342,1
1343,1
1344,0
1345,0
1346,0
1347,1
1348,0
1349,1
1350,1
1351,1
1352,1
1353,0
1354,0
1355,1
1356,1
1357,1
1358,0
1359,0
1360,0
1361,0
1362,1
1363,1
1364,1
1365,0
1366,0
1367,0
1368,1
1369,1
1370,0
1371,1
1372,0
1373,0
1374,0
1375,0
1376,1
1377,0
1378,1
1379,0
1380,0
1381,0
1382,0
1383,1
1384,0
1385,1
1386,0
1387,1
1388,0
1389,0
1390,0
1391,1
1392,0
1393,0
1394,0
1395,0
1396,0
1397,0
1398,0
1399,1
1400,1
1401,0
1402,0
1403,0
1404,1
1405,0
1406,0
1407,0
1408,0
1409,0
1410,0
1411,1
1412,0
1413,1
1414,1
1415,1
1416,1
1417,1
1418,0
1419,1
1420,0
1421,1
1422,1
1423,0
1424,0
1425,0
1426,0
1427,1
1428,0
1429,1
1430,1
1431,0
1432,0
1433,0
1434,1
1435,0
1436,1
1437,1
1438,1
1439,1
1440,0
1441,0
1442,1
1443,0
1444,1
1445,1
1446,0
1447,0
1448,1
1449,1
1450,0
1451,1
1452,0
1453,1
1454,0
1455,0
1456,0
1457,1
1458,1
1459,0
1460,0
1461,1
1462,1
1463,1
1464,0
1465,0
1466,1
1467,1
1468,0
1469,0
1470,1
1471,1
1472,1
1473,1
1474,0
1475,1
1476,1
1477,0
1478,1
1479,1
1480,0
1481,1
1482,0
1483,0
1484,1
1485,0
1486,0
1487,1
1488,1
1489,0
1490,1
1491,0
1492,1
1493,0
1494,1
1495,1
1496,1
1497,0
1498,0
1499,1
1500,0
1501,0
1502,0
1503,0
1504,0
1505,0
1506,0
1507,1
1508,1
1509,0
1510,1
1511,1
1512,1
1513,0
1514,0
1515,0
1516,1
1517,1
1518,0
1519,1
1520,0
1521,0
1522,1
1523,0
1524,0
1525,1
1526,0
1527,0
1528,0
1529,0
1530,0
1531,0
1532,0
1533,0
1534,0
1535,0
1536,0
1537,0
1538,1
1539,0
1540,0
1541,1
1542,1
1543,1
1544,0
1545,0
1546,0
1547,0
1548,1
1549,0
1550,1
1551,1
1552,0
1553,1
1554,0
1555,1
1556,1
1557,0
1558,0
1559,0
1560,1
1561,1
1562,1
1563,0
1564,1
1565,0
1566,0
1567,1
1568,1
1569,0
1570,1
1571,0
1572,0
1573,1
1574,0
1575,1
1576,1
1577,0
1578,0
1579,1
1580,1
1581,0
1582,1
1583,1
1584,0
1585,1
1586,0
1587,0
1588,0
1589,1
1590,1
1591,1
1592,0
1593,1
1594,0
1595,1
1596,1
1597,0
1598,0
1599,0
@@ -2,7 +2,7 @@
 import sys
 import os
 import csv
-from random import random
+from random import seed, random
 from pprint import pprint as pp
 from math import log, exp
 import numpy as np
@@ -45,7 +45,7 @@ def sigmoid(v):
     return 1.0/(1+exp(-v))
 
 def sigmoidGradient(v):
-    return [a*b for a,b in zip([sigmoid(x) for x in v], [sigmoid(x) for x in v])]
+    return [a*b for a,b in zip([sigmoid(x) for x in v], [1-sigmoid(x) for x in v])]
 
 def regularization(cost, w1, w2, lamb, m):
     reg = sum(w1*w1)+sum(w2*w2)
@@ -104,34 +104,66 @@ def cost(li, lh, lo, weights, X, y, lamb):
     return J,grad
 
-def fit(X, y, li, lh, lo, weights, lamb, eta, passes=10000, verbose=True):
+def fit(X, y, li, lh, lo, weight, lamb, eta, passes=1000, verbose=True):
+    weights = np.copy(weight)
     for i in range(1,passes+1):
         J, dw = cost(li, lh, lo, weights, X, y, lamb)
         #print(weights.shape, dw.shape)
         weights += -eta*dw
         print(i,"\r", end='')
-        if verbose and i%1000 == 0:
+        if verbose and i%(passes/10) == 0:
             print()
             print(J)
+    return weights
 
 def predict(x, w1, w2):
-    x=[1]+x #58x1
+    #x=[1]+x #58x1
     x = np.array(x)
-    h1 = sigmoid(np.dot(w1,x).tolist()) #28x58 * 58x1 = 28x1
+    h1 = [sigmoid(z) for z in np.dot(w1,x).tolist()] #28x58 * 58x1 = 28x1
     h1 = [1]+h1
-    h2 = sigmoid(np.dot(w2,h1).tolist()) #1x29 * 29x1 = 1x1
-    return 1 if h2>0.5 else 0
+    h2 = sigmoid(np.dot(w2,h1).tolist()[0]) #1x29 * 29x1 = 1x1
+    return h2
+
+def setWeightsFromFile(weights):
+    if os.path.isfile("weights"):
+        with open("weights","rb") as wfile:
+            weights[:] = np.load(wfile) #in-place copy so the caller's array actually receives the loaded weights
+        return True
+    return False
 
 if __name__ == "__main__":
     np.random.rand(47)
     X,y = getData(params["train"])
     tX,ty = getData(params["test"], isTrain=False)
     # print(len(X), len(X[0]), len(y), X[0])
     # print(len(tX), len(ty), tX[0])
     li,lh,lo = tuple(params["layers"])
-    weights = np.array([random() for _ in range(lh*(li+1)+lo*(lh+1))])
+    weights = np.random.rand(lh*(li+1)+lo*(lh+1))
     lamb,eta = 0.1,0.1
-    fit(X, y, li, lh, lo, weights, lamb, eta)
+    if not setWeightsFromFile(weights):
+        weights = fit(X, y, li, lh, lo, weights, lamb, eta, passes=300)
+        with open("weights","wb") as wfile:
+            np.save(wfile, weights)
     w1 = weights[:(li+1)*lh].reshape(lh,li+1) #28x58
     w2 = weights[(li+1)*lh:].reshape(lo,lh+1) #1x29
+    py = []
+    for x in X:
+        py.append(predict(x,w1,w2))
+    print(py)
     # print("train accuracy", len(list(filter(lambda z:z==0,[a-b for a,b in zip(py,y)])))*1.0/len(y))
     # pty = []
     # for x in tX:
     #     pty.append(predict(x,w1,w2))
     # with open("answer.txt","w") as dest:
     #     writer = csv.writer(dest)
     #     for i,ans in enumerate(pty):
     #         writer.writerow([i,ans])
     # J,grad = cost(li, lh, lo, weights, X, y, 0.1)
     # print(J,grad)
     #print(len(w1), len(w1[0]), len(w2), len(w2[0]))
\ No newline at end of file
#! /usr/bin/env python3
import sys
import os
import csv
from random import seed, random
from pprint import pprint as pp
from math import log, exp
import numpy as np
def doNormalize(X):
#do 0 mean 1 std normalization
x1 = np.array(X,dtype=float)
for i in range(len(X[0])):
col = x1[:,i]
mean,std = col.mean(),col.std()
std = std if std!=0.0 else 1.0
x1[:,i] = (x1[:,i]-mean)/std
return x1.tolist()
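# Each feature column is z-scored, x -> (x - mean)/std, with std replaced by
# 1.0 for constant columns so the division is always safe; this keeps the
# sigmoid inputs in a range where the gradients do not saturate at the start.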
def getData(srcF, isTrain=True, addBias=True, normalize=True):
X,y=[],[]
with open(srcF) as src:
reader = csv.reader(src, delimiter=',')
for i,row in enumerate(reader):
temp = []
if addBias:
temp.append(1)
end = -1 if isTrain else len(row)
temp.extend(row[:end])
#correct data type
X.append(list(map(float, temp)))
if isTrain:
v = int(row[-1])
entry = [1,0] if v==1 else [0,1]
y.append(entry)
if normalize:
X = doNormalize(X)
#print(X[0])
return (np.array(X),np.array(y))
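# Note: addBias prepends the constant-1 feature, so the 57 raw columns become
# a 58-wide input row, and training labels are one-hot encoded as [1,0] for
# class 1 and [0,1] for class 0.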
def sigmoid(v):
return 1.0/(1+np.exp(-v))
def sigmoidDiff(v):
return sigmoid(v)*(1-sigmoid(v))
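# This uses the identity sigmoid'(z) = sigmoid(z)*(1 - sigmoid(z)); the same
# identity is behind the sigmoidGradient fix in the diff above (the second
# factor must be 1-sigmoid(x), not sigmoid(x)).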
def feedforward(model, X):
w1,w2 = model['w1'],model['w2'] #58x28, 29x2
z1 = X.dot(w1) #mx58 * 58x28 = mx28
a1 = sigmoid(z1) #mx28
a1 = np.insert(a1,0,np.ones(a1.shape[0]),axis=1) #mx29
z2 = a1.dot(w2) #mx29 * 29x2
h = sigmoid(z2) #mx2
return h
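# np.insert prepends the hidden layer's constant-1 bias activation, which is
# why w2 has lh+1 rows.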
def restrictProb(a):
return min([max([a,1e-15]), 1-1e-15])
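# Clamping probabilities to [1e-15, 1-1e-15] keeps log(py) and log(1-py)
# finite in the cross-entropy below when a sigmoid output rounds to exactly
# 0.0 or 1.0.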
def cost(model, X, y):
m = X.shape[0]
h = feedforward(model, X)
y2 = y.astype(float)
vf = np.vectorize(restrictProb)
py = vf(h)
loss = -(1.0/m)*np.sum(y*np.log(py)+(1-y)*np.log(1-py)) #mx2 .* mx2
#regularize
w1,w2 = model['w1'],model['w2']
loss += model['lambda']*(np.sum(np.square(w1)) + np.sum(np.square(w2)))/(2*m)
return loss
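# The loss above is the averaged cross-entropy plus an L2 penalty:
#   J = -(1/m) * sum_{i,k} [ y_ik*log(h_ik) + (1-y_ik)*log(1-h_ik) ]
#       + (lambda/(2m)) * (sum(w1^2) + sum(w2^2))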
def predict(model, x):
w1,w2 = model['w1'],model['w2']
#print(x.shape, w1.shape)
z1 = x.dot(w1) #1x58 * 58x28 = 1x28
a1 = sigmoid(z1) #1x28
a1 = np.insert(a1,0,1)
z2 = a1.dot(w2) #1x29 x 29x2
h = sigmoid(z2)
return 1-np.argmax(h)
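# getData encodes label 1 as [1,0] and label 0 as [0,1], so output column 0
# scores the "1" class and 1-np.argmax(h) maps the winning column back to the
# original 0/1 label.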
def fit(model, X, y, passes=1000):
m = X.shape[0]
w1,w2 = model['w1'],model['w2'] #58x28, 29x2
li,lh,lo=model['li'],model['lh'],model['lo']
for i in range(passes):
z1 = X.dot(w1) #mx58 * 58x28 = mx28
a2 = sigmoid(z1) #mx28
a2 = np.insert(a2,0,np.ones(a2.shape[0]),axis=1) #mx29
z2 = a2.dot(w2) #mx29 * 29x2
h = sigmoid(z2) #mx2
#backpropagation
del3 = h-y #mx2
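        # del3 = h - y is the exact output-layer error for sigmoid units under
        # cross-entropy: the sigmoid' factor cancels against the loss gradient.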
z1 = np.insert(z1,0,np.ones(z1.shape[0]),axis=1)
#mx29
        del2 = del3.dot(w2.T)*sigmoidDiff(z1)
        #mx2 * 2x29 .* mx29 = mx29; w2.T (not reshape, which scrambles element order) aligns each hidden unit's weights with its error
del2 = del2[:,1:] #mx28
dw1 = np.dot(X.T,del2) #58xm*mx28=58x28
dw2 = (a2.T).dot(del3) #29xm*mx2=29x2
dw1 += (model['lambda']/m)*w1
dw2 += (model['lambda']/m)*w2
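        # Note: these L2 terms also shrink the bias rows of w1/w2; the
        # textbook formulation usually excludes the biases from the penalty.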
w1 += -model['eta']*dw1
w2 += -model['eta']*dw2
model['w1'] = w1
model['w2'] = w2
if i % (passes/10)==0:
print(i,cost(model, X, y))
return model
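# Illustrative addition, not part of the original commit: a finite-difference
# probe of one weight's gradient for sanity-checking the backprop above.
# numericGrad is a hypothetical helper; cost() already folds in the 1/m and
# L2 terms, so compare against an analytic gradient scaled the same way.
def numericGrad(model, X, y, i=0, j=0, eps=1e-4):
    w1 = model['w1']
    w1[i,j] += eps
    above = cost(model, X, y)
    w1[i,j] -= 2*eps
    below = cost(model, X, y)
    w1[i,j] += eps  #restore the original weight
    return (above-below)/(2*eps)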
if __name__ == "__main__":
np.random.seed(47)
    model = {'li':57,'lh':85,'lo':2,'lambda':0.05,'eta':0.01}
# model['w1'] = np.random.randn(model['li']+1, model['lh'])/np.sqrt(model['li']+1) #58x28
# model['w2'] = np.random.randn(model['lh']+1, model['lo'])/np.sqrt(model['lh']+1) #29x2
model['w1'] = np.random.rand(model['li']+1, model['lh'])*0.24 - 0.12
model['w2'] = np.random.rand(model['lh']+1, model['lo'])*0.24 - 0.12
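    # Illustrative note, not in the original: with lh=85 the actual shapes are
    # w1: 58x85 and w2: 86x2; the 58x28 / 29x2 shape comments elsewhere date
    # from an earlier lh=28 configuration. The init rand()*0.24 - 0.12 draws
    # each weight uniformly from [-0.12, 0.12], a small symmetric interval
    # that keeps the sigmoids away from saturation; the commented-out lines
    # above are a Xavier-style alternative.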
X,y = getData("Train.csv")
tX,ty = getData("TestX.csv",isTrain=False)
#cost(model, X, y)
# for h in [57/3, 57/2, 2*57/3, 57, 3*57/2]:
# h=int(h)
# model = {'li':57,'lh':h,'lo':2,'lambda':0.1,'eta':0.1}
# model['w1'] = np.random.randn(model['li']+1, model['lh'])/np.sqrt(model['li']+1) #58x28
# model['w2'] = np.random.randn(model['lh']+1, model['lo'])/np.sqrt(model['lh']+1) #29x2
model = fit(model, X, y)
m = X.shape[0]
py,y2=[],[]
for i,row in enumerate(tX):
ans = predict(model, np.array(row))
py.append(ans)
# y2.append(1 if y[i][0]==1 else 0)
# acc = m-np.sum(abs(np.array(py)-np.array(y2)))
# print(h, acc*100/m)
with open("answer.txt","w") as wfile:
writer = csv.writer(wfile)
writer.writerow(['Id','Label'])
for i,ans in enumerate(py):
writer.writerow([i,ans])
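        # These Id,Label rows appear to be the prediction file shown at the
        # top of this commit.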
# acc = m-np.sum(abs(np.array(py)-np.array(y2)))
# print(acc*100/m)
\ No newline at end of file