Commit e53add22 authored by Sushant Mahajan's avatar Sushant Mahajan

fixed normalization for first column and added parameter search for gradient descnet

parent 72cdc309
Pipeline #298 skipped
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -5,14 +5,14 @@ Id,Label
3,0
4,0
5,0
6,0
6,1
7,0
8,0
9,0
10,1
11,1
11,0
12,0
13,1
13,0
14,1
15,0
16,1
......@@ -38,12 +38,12 @@ Id,Label
36,0
37,1
38,0
39,0
39,1
40,0
41,1
42,0
43,1
44,0
44,1
45,0
46,0
47,0
......@@ -51,11 +51,11 @@ Id,Label
49,0
50,1
51,0
52,1
52,0
53,0
54,1
55,1
56,1
56,0
57,0
58,1
59,1
......@@ -66,10 +66,10 @@ Id,Label
64,0
65,1
66,0
67,1
67,0
68,0
69,0
70,1
70,0
71,0
72,0
73,0
......@@ -92,7 +92,7 @@ Id,Label
90,1
91,1
92,0
93,0
93,1
94,0
95,0
96,1
......@@ -102,9 +102,9 @@ Id,Label
100,1
101,0
102,0
103,1
103,0
104,0
105,1
105,0
106,1
107,0
108,1
......@@ -114,13 +114,13 @@ Id,Label
112,1
113,1
114,0
115,1
115,0
116,1
117,0
118,1
119,0
120,1
121,0
120,0
121,1
122,1
123,0
124,0
......@@ -130,12 +130,12 @@ Id,Label
128,0
129,0
130,1
131,1
131,0
132,0
133,0
134,0
135,1
136,1
136,0
137,1
138,0
139,0
......@@ -147,8 +147,8 @@ Id,Label
145,0
146,0
147,0
148,1
149,1
148,0
149,0
150,0
151,0
152,1
......@@ -162,8 +162,8 @@ Id,Label
160,1
161,0
162,0
163,1
164,1
163,0
164,0
165,0
166,1
167,1
......@@ -186,7 +186,7 @@ Id,Label
184,0
185,0
186,0
187,1
187,0
188,0
189,1
190,1
......@@ -207,28 +207,28 @@ Id,Label
205,0
206,1
207,0
208,0
208,1
209,0
210,0
211,1
212,0
213,0
214,0
215,0
215,1
216,1
217,0
218,1
219,0
220,0
220,1
221,1
222,0
223,1
223,0
224,1
225,1
226,1
227,0
228,0
229,1
228,1
229,0
230,1
231,0
232,1
......@@ -237,7 +237,7 @@ Id,Label
235,1
236,0
237,0
238,1
238,0
239,1
240,0
241,1
......@@ -250,24 +250,24 @@ Id,Label
248,0
249,1
250,1
251,1
252,1
251,0
252,0
253,0
254,1
255,1
255,0
256,0
257,0
258,1
259,1
260,1
261,1
261,0
262,0
263,0
264,0
265,1
266,1
267,0
268,1
268,0
269,0
270,1
271,1
......@@ -281,7 +281,7 @@ Id,Label
279,0
280,1
281,0
282,1
282,0
283,0
284,1
285,1
......@@ -309,7 +309,7 @@ Id,Label
307,0
308,0
309,1
310,1
310,0
311,1
312,0
313,0
......@@ -319,11 +319,11 @@ Id,Label
317,0
318,0
319,0
320,0
320,1
321,0
322,0
323,0
324,1
324,0
325,0
326,0
327,0
......@@ -383,7 +383,7 @@ Id,Label
381,0
382,0
383,0
384,0
384,1
385,1
386,1
387,0
......@@ -398,7 +398,7 @@ Id,Label
396,1
397,0
398,0
399,0
399,1
400,0
401,1
402,0
......@@ -436,7 +436,7 @@ Id,Label
434,1
435,0
436,1
437,1
437,0
438,0
439,0
440,0
......@@ -449,7 +449,7 @@ Id,Label
447,0
448,1
449,0
450,0
450,1
451,1
452,1
453,1
......@@ -464,16 +464,16 @@ Id,Label
462,1
463,1
464,0
465,0
465,1
466,1
467,1
467,0
468,1
469,0
470,0
471,1
470,1
471,0
472,0
473,1
474,0
474,1
475,0
476,0
477,0
......@@ -509,14 +509,14 @@ Id,Label
507,1
508,0
509,0
510,1
510,0
511,0
512,0
513,0
514,0
515,0
516,1
517,0
517,1
518,0
519,1
520,1
......@@ -524,13 +524,13 @@ Id,Label
522,0
523,0
524,0
525,1
525,0
526,1
527,0
528,0
529,0
530,0
531,0
531,1
532,1
533,0
534,1
......@@ -561,16 +561,16 @@ Id,Label
559,0
560,0
561,1
562,1
562,0
563,0
564,1
564,0
565,0
566,1
567,0
568,0
568,1
569,0
570,0
571,1
571,0
572,1
573,1
574,0
......@@ -581,7 +581,7 @@ Id,Label
579,0
580,0
581,0
582,1
582,0
583,0
584,1
585,0
......@@ -591,7 +591,7 @@ Id,Label
589,0
590,1
591,1
592,1
592,0
593,0
594,0
595,0
......@@ -603,7 +603,7 @@ Id,Label
601,1
602,0
603,0
604,1
604,0
605,0
606,0
607,0
......@@ -623,11 +623,11 @@ Id,Label
621,0
622,0
623,1
624,1
624,0
625,0
626,0
627,0
628,1
628,0
629,0
630,0
631,0
......@@ -635,7 +635,7 @@ Id,Label
633,0
634,1
635,1
636,1
636,0
637,1
638,1
639,0
......@@ -648,7 +648,7 @@ Id,Label
646,0
647,0
648,0
649,1
649,0
650,0
651,0
652,1
......@@ -658,7 +658,7 @@ Id,Label
656,1
657,0
658,0
659,1
659,0
660,1
661,0
662,0
......@@ -667,31 +667,31 @@ Id,Label
665,1
666,0
667,0
668,1
668,0
669,0
670,1
671,0
672,1
673,0
674,0
675,1
675,0
676,1
677,1
677,0
678,0
679,1
680,0
681,1
682,0
683,1
684,1
685,1
683,0
684,0
685,0
686,1
687,0
688,0
689,0
690,0
691,0
692,1
692,0
693,0
694,0
695,1
......@@ -711,13 +711,13 @@ Id,Label
709,0
710,1
711,1
712,1
712,0
713,0
714,0
715,0
716,0
717,1
718,0
717,0
718,1
719,0
720,1
721,0
......@@ -744,11 +744,11 @@ Id,Label
742,0
743,0
744,1
745,0
746,1
745,1
746,0
747,1
748,0
749,1
749,0
750,0
751,1
752,0
......@@ -768,8 +768,8 @@ Id,Label
766,0
767,0
768,1
769,1
770,0
769,0
770,1
771,1
772,0
773,1
......@@ -787,16 +787,16 @@ Id,Label
785,1
786,0
787,0
788,1
788,0
789,0
790,0
790,1
791,0
792,1
793,0
794,1
794,0
795,0
796,0
797,1
797,0
798,0
799,1
800,0
......@@ -827,7 +827,7 @@ Id,Label
825,0
826,1
827,0
828,0
828,1
829,1
830,0
831,1
......@@ -839,7 +839,7 @@ Id,Label
837,0
838,0
839,1
840,1
840,0
841,1
842,0
843,1
......@@ -847,7 +847,7 @@ Id,Label
845,0
846,0
847,0
848,1
848,0
849,1
850,1
851,1
......@@ -858,7 +858,7 @@ Id,Label
856,0
857,0
858,0
859,1
859,0
860,0
861,0
862,0
......@@ -875,7 +875,7 @@ Id,Label
873,0
874,1
875,1
876,1
876,0
877,0
878,0
879,1
......@@ -891,9 +891,9 @@ Id,Label
889,0
890,1
891,0
892,1
892,0
893,0
894,1
894,0
895,0
896,0
897,1
......@@ -918,7 +918,7 @@ Id,Label
916,0
917,0
918,1
919,1
919,0
920,0
921,0
922,0
......@@ -937,9 +937,9 @@ Id,Label
935,1
936,0
937,1
938,1
938,0
939,0
940,1
940,0
941,0
942,1
943,0
......@@ -967,7 +967,7 @@ Id,Label
965,1
966,0
967,0
968,0
968,1
969,0
970,1
971,1
......@@ -995,7 +995,7 @@ Id,Label
993,0
994,0
995,1
996,1
996,0
997,0
998,0
999,0
......@@ -1015,12 +1015,12 @@ Id,Label
1013,0
1014,1
1015,0
1016,1
1016,0
1017,0
1018,0
1019,1
1020,0
1021,1
1021,0
1022,1
1023,0
1024,0
......@@ -1028,22 +1028,22 @@ Id,Label
1026,1
1027,1
1028,0
1029,0
1030,0
1029,1
1030,1
1031,0
1032,1
1033,0
1034,0
1035,0
1036,1
1037,1
1038,1
1039,1
1037,0
1038,0
1039,0
1040,0
1041,0
1042,0
1043,1
1044,0
1044,1
1045,0
1046,0
1047,1
......@@ -1069,7 +1069,7 @@ Id,Label
1067,0
1068,0
1069,0
1070,1
1070,0
1071,0
1072,0
1073,0
......@@ -1084,12 +1084,12 @@ Id,Label
1082,0
1083,1
1084,1
1085,0
1085,1
1086,1
1087,0
1088,1
1088,0
1089,0
1090,1
1090,0
1091,0
1092,1
1093,0
......@@ -1102,7 +1102,7 @@ Id,Label
1100,0
1101,0
1102,0
1103,1
1103,0
1104,0
1105,0
1106,0
......@@ -1128,7 +1128,7 @@ Id,Label
1126,0
1127,1
1128,0
1129,1
1129,0
1130,0
1131,0
1132,0
......@@ -1136,9 +1136,9 @@ Id,Label
1134,0
1135,0
1136,0
1137,1
1137,0
1138,1
1139,1
1139,0
1140,0
1141,1
1142,0
......@@ -1152,7 +1152,7 @@ Id,Label
1150,0
1151,0
1152,1
1153,1
1153,0
1154,1
1155,1
1156,0
......@@ -1165,7 +1165,7 @@ Id,Label
1163,0
1164,0
1165,0
1166,1
1166,0
1167,0
1168,0
1169,1
......@@ -1173,7 +1173,7 @@ Id,Label
1171,1
1172,1
1173,0
1174,0
1174,1
1175,0
1176,0
1177,0
......@@ -1189,23 +1189,23 @@ Id,Label
1187,0
1188,0
1189,0
1190,0
1191,1
1190,1
1191,0
1192,0
1193,0
1194,0
1195,1
1195,0
1196,0
1197,0
1198,0
1199,0
1200,1
1200,0
1201,1
1202,0
1203,1
1204,1
1205,1
1206,1
1206,0
1207,0
1208,1
1209,1
......@@ -1214,9 +1214,9 @@ Id,Label
1212,1
1213,1
1214,0
1215,1
1215,0
1216,0
1217,1
1217,0
1218,0
1219,1
1220,0
......@@ -1225,7 +1225,7 @@ Id,Label
1223,0
1224,1
1225,0
1226,1
1226,0
1227,0
1228,0
1229,0
......@@ -1246,8 +1246,8 @@ Id,Label
1244,1
1245,0
1246,0
1247,0
1248,1
1247,1
1248,0
1249,0
1250,0
1251,0
......@@ -1259,7 +1259,7 @@ Id,Label
1257,0
1258,0
1259,0
1260,1
1260,0
1261,1
1262,0
1263,1
......@@ -1269,7 +1269,7 @@ Id,Label
1267,0
1268,0
1269,0
1270,1
1270,0
1271,1
1272,0
1273,1
......@@ -1285,13 +1285,13 @@ Id,Label
1283,0
1284,0
1285,0
1286,1
1286,0
1287,0
1288,0
1289,1
1290,0
1291,0
1292,1
1292,0
1293,1
1294,0
1295,1
......@@ -1309,9 +1309,9 @@ Id,Label
1307,1
1308,1
1309,0
1310,1
1310,0
1311,1
1312,0
1312,1
1313,0
1314,0
1315,0
......@@ -1333,7 +1333,7 @@ Id,Label
1331,0
1332,0
1333,0
1334,0
1334,1
1335,0
1336,1
1337,0
......@@ -1347,7 +1347,7 @@ Id,Label
1345,0
1346,0
1347,1
1348,1
1348,0
1349,1
1350,1
1351,1
......@@ -1377,7 +1377,7 @@ Id,Label
1375,0
1376,1
1377,0
1378,1
1378,0
1379,0
1380,0
1381,0
......@@ -1395,7 +1395,7 @@ Id,Label
1393,0
1394,0
1395,0
1396,1
1396,0
1397,0
1398,0
1399,1
......@@ -1413,7 +1413,7 @@ Id,Label
1411,1
1412,0
1413,1
1414,0
1414,1
1415,0
1416,1
1417,1
......@@ -1433,10 +1433,10 @@ Id,Label
1431,1
1432,0
1433,0
1434,1
1434,0
1435,0
1436,1
1437,1
1437,0
1438,1
1439,1
1440,0
......@@ -1461,7 +1461,7 @@ Id,Label
1459,0
1460,0
1461,1
1462,1
1462,0
1463,1
1464,0
1465,0
......@@ -1472,18 +1472,18 @@ Id,Label
1470,1
1471,1
1472,1
1473,1
1473,0
1474,0
1475,1
1476,1
1477,1
1477,0
1478,1
1479,1
1480,0
1481,1
1482,0
1483,0
1484,0
1484,1
1485,0
1486,0
1487,1
......@@ -1499,14 +1499,14 @@ Id,Label
1497,0
1498,0
1499,1
1500,0
1500,1
1501,0
1502,0
1503,0
1504,0
1505,0
1506,0
1507,1
1507,0
1508,1
1509,0
1510,1
......@@ -1524,11 +1524,11 @@ Id,Label
1522,1
1523,0
1524,0
1525,1
1525,0
1526,0
1527,0
1528,0
1529,0
1529,1
1530,0
1531,0
1532,0
......@@ -1560,7 +1560,7 @@ Id,Label
1558,0
1559,0
1560,1
1561,1
1561,0
1562,1
1563,0
1564,1
......@@ -1584,15 +1584,15 @@ Id,Label
1582,1
1583,1
1584,0
1585,1
1585,0
1586,0
1587,0
1588,1
1588,0
1589,1
1590,1
1591,1
1592,0
1593,1
1593,0
1594,0
1595,1
1596,1
......
......@@ -7,14 +7,21 @@ from pprint import pprint as pp
from math import log, exp
import numpy as np
def doNormalize(X):
removed=[]
def doNormalize(X,isTrain):
#do 0 mean 1 std normalization
x1 = np.array(X,dtype=float)
for i in range(len(X[0])):
col = x1[:,i]
mean,std = col.mean(),col.std()
std = std if std!=0.0 else 1.0
x1[:,i] = (x1[:,i]-mean)/std
#std = std if std!=0.0 else 1.0
if i!=0:
if std<0.1 and isTrain:
removed.append(i)
else:
x1[:,i] = (x1[:,i]-mean)/max(std,1.0)
x1 = np.delete(x1, removed, axis=1)
return x1.tolist()
def getData(srcF, isTrain=True, addBias=True, normalize=True):
......@@ -37,7 +44,7 @@ def getData(srcF, isTrain=True, addBias=True, normalize=True):
y.append(entry)
if normalize:
X = doNormalize(X)
X = doNormalize(X,isTrain)
#print(X[0])
return (np.array(X),np.array(y))
......@@ -111,14 +118,21 @@ def fit(model, X, y, passes=1000):
dw1 += (model['lambda']/m)*w1
dw2 += (model['lambda']/m)*w2
w1 += -model['eta']*dw1
w2 += -model['eta']*dw2
model['w1'] = w1
model['w2'] = w2
costs = []
ws = []
for eta in model['eta']:
tw1 = w1-eta*dw1
tw2 = w2-eta*dw2
model['w1'] = tw1
model['w2'] = tw2
costs.append(cost(model,X,y))
ws.append((tw1,tw2))
idx = np.argmin(costs)
w1,w2 = ws[idx][0],ws[idx][1]
model['w1'],model['w2'] = w1,w2
if i % (passes/10)==0:
print(i,cost(model, X, y))
print(i,costs[idx])
return model
......@@ -126,15 +140,17 @@ def fit(model, X, y, passes=1000):
if __name__ == "__main__":
np.random.seed(47)
np.seterr(over='raise')
X,y = getData("Train.csv")
tX,ty = getData("TestX.csv",isTrain=False)
model = {}
model = {'li':57,'lh':85,'lo':2,'lambda':0.1,'eta':0.01}
model = {'li':X.shape[1]-1,'lh':int(3*(X.shape[1]-1)/2),'lo':2,'lambda':0.1,'eta':[0.01,0.06,0.1]}
# model['w1'] = np.random.randn(model['li']+1, model['lh'])/np.sqrt(model['li']+1) #58x28
# model['w2'] = np.random.randn(model['lh']+1, model['lo'])/np.sqrt(model['lh']+1) #29x2
model['w1'] = np.random.rand(model['li']+1, model['lh'])*0.24 - 0.12
model['w2'] = np.random.rand(model['lh']+1, model['lo'])*0.24 - 0.12
X,y = getData("Train.csv")
tX,ty = getData("TestX.csv",isTrain=False)
#cost(model, X, y)
# for h in [57/3, 57/2, 2*57/3, 57, 3*57/2]:
# h=int(h)
......@@ -142,7 +158,7 @@ if __name__ == "__main__":
# model['w1'] = np.random.randn(model['li']+1, model['lh'])/np.sqrt(model['li']+1) #58x28
# model['w2'] = np.random.randn(model['lh']+1, model['lo'])/np.sqrt(model['lh']+1) #29x2
model = fit(model, X, y, passes=500)
model = fit(model, X, y, passes=1500)
m = X.shape[0]
py,y2=[],[]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment