Commit 93a40cd6 authored by Sushant Mahajan's avatar Sushant Mahajan

added progress bar

parent 94c9f577
Pipeline #322 skipped
...@@ -6,6 +6,7 @@ from random import seed, random ...@@ -6,6 +6,7 @@ from random import seed, random
from pprint import pprint as pp from pprint import pprint as pp
from math import log, exp from math import log, exp
import numpy as np import numpy as np
from fractions import gcd
removed=[] removed=[]
...@@ -90,7 +91,7 @@ def predict(model, x): ...@@ -90,7 +91,7 @@ def predict(model, x):
return 1-np.argmax(h) return 1-np.argmax(h)
def fit(model, X, y, passes=1000): def fit(model, X, y, passes=1000, verbose=False):
m = X.shape[0] m = X.shape[0]
w1,w2 = model['w1'],model['w2'] #58x28, 29x2 w1,w2 = model['w1'],model['w2'] #58x28, 29x2
li,lh,lo=model['li'],model['lh'],model['lo'] li,lh,lo=model['li'],model['lh'],model['lo']
...@@ -128,7 +129,9 @@ def fit(model, X, y, passes=1000): ...@@ -128,7 +129,9 @@ def fit(model, X, y, passes=1000):
w1,w2 = ws[idx][0],ws[idx][1] w1,w2 = ws[idx][0],ws[idx][1]
model['w1'],model['w2'] = w1,w2 model['w1'],model['w2'] = w1,w2
if i % (passes/10)==0: print("training... %0.2f%%\r" % (i*100/passes), end='')
if verbose and i % (passes/10)==0:
print(i,costs[idx]) print(i,costs[idx])
return model return model
...@@ -154,9 +157,10 @@ if __name__ == "__main__": ...@@ -154,9 +157,10 @@ if __name__ == "__main__":
# model = {'li':57,'lh':h,'lo':2,'lambda':0.1,'eta':0.1} # model = {'li':57,'lh':h,'lo':2,'lambda':0.1,'eta':0.1}
# model['w1'] = np.random.randn(model['li']+1, model['lh'])/np.sqrt(model['li']+1) #58x28 # model['w1'] = np.random.randn(model['li']+1, model['lh'])/np.sqrt(model['li']+1) #58x28
# model['w2'] = np.random.randn(model['lh']+1, model['lo'])/np.sqrt(model['lh']+1) #29x2 # model['w2'] = np.random.randn(model['lh']+1, model['lo'])/np.sqrt(model['lh']+1) #29x2
model = fit(model, X, y, passes=1500) model = fit(model, X, y, passes=1500)
print("\npredicting...")
m = X.shape[0] m = X.shape[0]
py,y2=[],[] py,y2=[],[]
for i,row in enumerate(tX): for i,row in enumerate(tX):
...@@ -173,5 +177,7 @@ if __name__ == "__main__": ...@@ -173,5 +177,7 @@ if __name__ == "__main__":
for i,ans in enumerate(py): for i,ans in enumerate(py):
writer.writerow([i,ans]) writer.writerow([i,ans])
print("done...")
# acc = m-np.sum(abs(np.array(py)-np.array(y2))) # acc = m-np.sum(abs(np.array(py)-np.array(y2)))
# print(acc*100/m) # print(acc*100/m)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment