Logistic Regression Practice - 2
Logistic Regression
Goal: predict gender using logistic regression, the building block of artificial neural networks
In this exercise we will use logistic regression, the building block of artificial neural networks, to build a model that predicts a person's gender from their height. The goals of this exercise are as follows:
- Learn the basics of Theano.
- Understand the logistic regression model.
- Learn how to implement gradient descent.
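For reference, logistic regression models the probability of the positive class by passing a linear function of the input through the sigmoid:

$$P(y = 1 \mid x) = \sigma(wx + b) = \frac{1}{1 + e^{-(wx + b)}}$$

Training amounts to finding the weight w and bias b that make these probabilities match the observed labels.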
import theano
import theano.tensor as T
import numpy as np
Here we set up the symbolic variables and the data.
x = T.vector('x')
w = T.scalar('w')
b = T.scalar('b')
y = T.vector('y')
out = 1 / (1 + T.exp(-(w * x) - b))
N = 100
The sigmoid function, written out by hand:
f = theano.function([x, w, b], out)
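As a sanity check on the compiled function, here is the same graph written in plain NumPy (sigmoid_np is a helper name we introduce here, not part of the original code):

import numpy as np

def sigmoid_np(x, w, b):
    # identical formula to the Theano graph above: 1 / (1 + exp(-(w*x) - b))
    return 1.0 / (1.0 + np.exp(-(w * x) - b))

# sigmoid_np(np.array([0.0, 1.0]), 1.0, -1.0) -> approximately [0.2689, 0.5]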
hair = np.concatenate([np.random.randn(N) + 1, np.random.randn(N)])  # two groups of 100 samples, centered at +1 and 0
prob = f(hair,
         np.array(1, dtype=theano.config.floatX),
         np.array(-1, dtype=theano.config.floatX))
prob
array([ 0.11558904, 0.49593163, 0.50266056, 0.18066072, 0.41688148,
0.49666854, 0.39019999, 0.37787053, 0.24365962, 0.5872073 ,
0.64218682, 0.87913439, 0.46804282, 0.49138378, 0.53244897,
0.56563533, 0.38386301, 0.72817262, 0.71435712, 0.8421329 ,
0.27449893, 0.31089732, 0.28456894, 0.5276074 , 0.77775312,
0.33467973, 0.11895283, 0.47499868, 0.43764116, 0.34815379,
0.55661846, 0.09179522, 0.38217598, 0.7989482 , 0.49020603,
0.3459761 , 0.5996382 , 0.57444145, 0.75427615, 0.10836856,
0.49189244, 0.09332962, 0.47537184, 0.63066922, 0.14194516,
0.78279751, 0.43462003, 0.55993874, 0.70784825, 0.25048662,
0.77757078, 0.63489736, 0.59578119, 0.34744468, 0.65984086,
0.25719687, 0.23028135, 0.27604124, 0.68334074, 0.67294121,
0.70615617, 0.5085993 , 0.67918025, 0.56140787, 0.64249315,
0.74695565, 0.3179371 , 0.5259163 , 0.84169246, 0.94411889,
0.47612601, 0.45173337, 0.48000097, 0.59060334, 0.42643188,
0.87934579, 0.57625341, 0.23529187, 0.09608668, 0.11590786,
0.22637512, 0.79000462, 0.2245848 , 0.3175366 , 0.10137861,
0.64306132, 0.50631107, 0.73406696, 0.50716892, 0.58947254,
0.44712662, 0.34006081, 0.15032411, 0.38505774, 0.38667379,
0.46387625, 0.70263627, 0.40758166, 0.51992322, 0.37372992,
0.46482379, 0.33613064, 0.29176525, 0.39665733, 0.29566868,
0.6830792 , 0.09914912, 0.44933601, 0.51109961, 0.43402864,
0.44240692, 0.30820354, 0.07419383, 0.51311126, 0.23962994,
0.68804707, 0.08277803, 0.23771643, 0.40546371, 0.34642179,
0.65205457, 0.11283301, 0.23097786, 0.39279025, 0.45560596,
0.20557529, 0.56630624, 0.44414538, 0.11992531, 0.08302729,
0.41761443, 0.07683355, 0.15940981, 0.41934893, 0.34585658,
0.46061428, 0.19044365, 0.32782305, 0.44463802, 0.39146187,
0.25396609, 0.37017432, 0.20984756, 0.68209838, 0.07416688,
0.41171599, 0.13801437, 0.68032696, 0.41079291, 0.19140806,
0.33883255, 0.25185193, 0.78723456, 0.33732148, 0.08421657,
0.19422472, 0.34108163, 0.05339644, 0.08862708, 0.34809861,
0.32221287, 0.19137922, 0.65283869, 0.44150487, 0.59837777,
0.38707145, 0.1333219 , 0.04298876, 0.60631316, 0.17610629,
0.32171648, 0.45718411, 0.32812774, 0.32396383, 0.70083906,
0.2164637 , 0.40866494, 0.16153463, 0.07100125, 0.21977632,
0.37657765, 0.33223294, 0.35561117, 0.56369684, 0.5964966 ,
0.51557555, 0.40247656, 0.26585279, 0.14015302, 0.10302591,
0.27458847, 0.12266008, 0.30145633, 0.36120606, 0.1769366 ,
0.28386606, 0.14812209, 0.17503618, 0.03308329, 0.3245861 ])
print(np.mean(prob[0:100]))
print(np.mean(prob[100:]))
0.479703084341
0.326590205584
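As expected, the first 100 samples (drawn around +1) receive a higher average probability (about 0.48) than the second 100 (about 0.33): with w = 1 and b = -1, larger inputs push the sigmoid output toward 1.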
This time, let's use Theano's built-in sigmoid.
W = T.matrix('W')
x = T.vector('x')
b = T.vector('b')
out2 = T.nnet.sigmoid(T.dot(W, x) + b)
f = theano.function([W, x, b], out2)
prob = f(np.array([[-1, 2, -3], [4, 5, 6]], dtype=theano.config.floatX),
         np.array([1, 2, 3], dtype=theano.config.floatX),
         np.array([4, 5], dtype=theano.config.floatX))
print(prob)
[ 0.11920292 1. ]
prediction = prob > 0.5
prediction
array([False, True], dtype=bool)
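For comparison, here is the same computation in plain NumPy (a sketch we add here, mirroring the values above):

import numpy as np

W = np.array([[-1., 2., -3.], [4., 5., 6.]])
x = np.array([1., 2., 3.])
b = np.array([4., 5.])

z = W.dot(x) + b              # logits: [-2., 37.]
p = 1.0 / (1.0 + np.exp(-z))  # element-wise sigmoid
print(p)                      # approximately [0.1192, 1.0], matching the Theano output

Note that sigmoid(37) is so close to 1 that it prints as 1. exactly.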
# binary cross-entropy loss and its gradients (assembled into a complete model below)
p = 1 / (1 + T.exp(-(w * x) - b))
xent = -y * T.log(p) - (1 - y) * T.log(1 - p)
cost = xent.mean()
gw, gb = T.grad(cost, [w, b])
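Written out, the cost is the mean binary cross-entropy, and T.grad derives its gradients automatically; for logistic regression they take the well-known closed form:

$$J(w, b) = -\frac{1}{N}\sum_{i=1}^{N}\big[y_i \log p_i + (1 - y_i)\log(1 - p_i)\big], \qquad p_i = \sigma(w x_i + b)$$

$$\frac{\partial J}{\partial w} = \frac{1}{N}\sum_{i=1}^{N}(p_i - y_i)\,x_i, \qquad \frac{\partial J}{\partial b} = \frac{1}{N}\sum_{i=1}^{N}(p_i - y_i)$$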
Putting it all together: the complete logistic regression
N = 200
np.random.seed(0)
len_hair = np.concatenate([np.random.randn(N) + 10, np.random.randn(N)])  # heights for 200 men and 200 women
gender = np.concatenate([np.zeros(N), np.ones(N)])  # labels: men = 0, women = 1
x = T.vector('x')
y = T.vector('y')
w = theano.shared(1., name='w')
b = theano.shared(0., name='b')
prob = 1 / (1 + T.exp(-(w * x) - b))
prediction = prob > 0.5
xent = -y * T.log(prob) - (1-y) * T.log(1-prob)
cost = xent.mean()
gw, gb = T.grad(cost, [w, b])
# delta rule: w = w - learning_rate * gw, b = b - learning_rate * gb
train = theano.function([x, y],
                        outputs=[prob, prediction],
                        updates=((w, w - 0.05 * gw), (b, b - 0.05 * gb)))
pred_prob = theano.function([x], prob)
predict = theano.function([x], prediction)
for i in range(1000):
    prob, pred = train(len_hair, gender)
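Each call to train applies one full-batch gradient-descent step through the updates pairs, so the loop performs 1000 steps. To watch the cost fall during training, one could also compile the loss (cost_fn is a name we introduce for this sketch):

cost_fn = theano.function([x, y], cost)  # compile the loss for monitoring
for i in range(1000):
    train(len_hair, gender)              # one gradient step on the full data
    if i % 200 == 0:
        print(i, cost_fn(len_hair, gender))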
w.get_value()
array(-0.896757297481667)
b.get_value()
array(3.1472479382974248)
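With w ≈ -0.90 and b ≈ 3.15, the fitted model outputs exactly 0.5 where wx + b = 0, i.e. at the decision boundary x = -b/w ≈ 3.5. Inputs above 3.5 (the group centered at 10) are classified as gender 0, and inputs below it (the group centered at 0) as gender 1.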
print(pred_prob(len_hair))
[ 6.09506135e-04 2.06786431e-03 1.23183782e-03 3.97522124e-04
5.55508329e-04 7.07600972e-03 1.26385537e-03 3.38640461e-03
3.24379336e-03 2.04863211e-03 2.60037161e-03 8.04523594e-04
1.49700431e-03 2.65291978e-03 1.98854284e-03 2.19461871e-03
7.76333794e-04 3.55319745e-03 2.23545891e-03 6.34067965e-03
2.84453508e-02 1.64813361e-03 1.36461894e-03 5.73860989e-03
3.87369514e-04 1.08132383e-02 2.83928606e-03 3.49658234e-03
7.49873421e-04 7.93722009e-04 2.57513395e-03 2.10896867e-03
6.53389508e-03 1.72247806e-02 4.03650142e-03 2.57190776e-03
9.83315059e-04 1.00821215e-03 4.18111731e-03 3.87536382e-03
7.53952868e-03 1.04886908e-02 1.35167753e-02 5.15582709e-04
4.66361098e-03 4.37493663e-03 9.04128418e-03 1.47511196e-03
1.24556095e-02 3.57735205e-03 6.57875766e-03 2.09253840e-03
4.66841276e-03 8.47953050e-03 3.03334178e-03 2.01637652e-03
2.78706614e-03 2.25675301e-03 5.21238328e-03 4.09031640e-03
5.39275673e-03 4.07868705e-03 6.11345990e-03 1.37581733e-02
2.52387392e-03 4.23543399e-03 1.26367001e-02 1.95515565e-03
6.64846504e-03 2.82362131e-03 1.54044499e-03 2.63563663e-03
1.06672912e-03 8.89804367e-03 2.06382587e-03 5.45248401e-03
6.43574365e-03 4.96069012e-03 3.90751678e-03 2.81298617e-03
8.36359284e-03 1.32086362e-03 1.95012218e-03 1.16274927e-02
7.80397811e-04 5.41580296e-04 1.02975486e-03 3.47397381e-03
7.68996158e-03 1.15106918e-03 4.24071708e-03 9.90250798e-04
2.45517965e-03 1.23415576e-03 2.15050619e-03 1.57181742e-03
2.93022656e-03 5.97703757e-04 2.64052268e-03 2.06447661e-03
5.47799006e-04 9.83706357e-03 9.18452473e-03 1.24218722e-03
8.42310426e-03 5.18899387e-04 4.28044276e-03 5.76573924e-03
5.28606596e-04 7.85827217e-04 5.55507846e-04 1.31470527e-03
6.38109124e-03 5.34742836e-04 3.75842024e-03 1.44250051e-03
1.26707013e-03 3.39747807e-03 1.70751843e-03 1.29581265e-03
2.11224934e-03 7.88850112e-03 2.26531789e-03 9.02200412e-04
5.50014094e-03 3.38119496e-03 4.36354271e-03 5.64691705e-04
1.62080493e-03 2.05439081e-03 5.88236193e-03 1.82581618e-03
5.40176939e-03 2.87486854e-03 5.21947441e-03 1.61481053e-03
1.76579474e-03 3.56318257e-03 2.07555921e-03 7.84413458e-03
1.11728932e-02 1.99651671e-03 2.54826580e-03 1.67578879e-03
3.49929975e-04 1.27022029e-03 6.68125975e-03 1.08833478e-03
9.56270933e-03 4.46773630e-03 3.14394225e-03 6.37844730e-04
5.75187611e-03 6.18631444e-03 3.23000289e-03 5.34972497e-03
1.07899675e-03 7.75302791e-03 8.23311091e-03 4.37394359e-03
4.61549122e-03 5.25493562e-04 1.26461129e-03 2.73513050e-03
8.82408783e-03 1.38937116e-03 7.22200378e-03 1.17157035e-02
1.02125685e-03 2.22772177e-03 1.29737780e-03 2.22416653e-03
1.37394512e-03 5.29063099e-03 7.44410990e-03 1.60736586e-03
6.06063540e-03 5.47558112e-03 4.44366184e-03 2.91199797e-03
4.05848650e-03 1.00774403e-02 5.25578878e-03 2.13218173e-02
1.69055587e-03 1.23256797e-02 7.92354698e-03 2.82306668e-03
5.72531169e-03 7.43027176e-04 9.36890328e-03 2.32941763e-03
3.06359511e-03 8.38551439e-03 1.85210764e-03 3.44805964e-03
1.48265985e-03 1.41556722e-03 4.26179930e-04 8.94039440e-04
9.70064411e-01 9.66492782e-01 8.96703352e-01 9.28218169e-01
9.29117084e-01 9.90021215e-01 9.59653304e-01 9.78312276e-01
9.47659436e-01 9.62140007e-01 9.11415207e-01 9.45975560e-01
9.19982008e-01 9.72495416e-01 9.81910829e-01 9.71110526e-01
9.59398874e-01 9.43065854e-01 7.54210522e-01 9.60271313e-01
9.82093075e-01 9.69454306e-01 9.72427618e-01 9.37933689e-01
9.89323422e-01 9.56499893e-01 9.52884854e-01 9.49742824e-01
9.75466959e-01 9.66450431e-01 9.88159110e-01 9.73133370e-01
9.74270804e-01 9.41262960e-01 9.84992190e-01 9.20320001e-01
8.59005933e-01 9.93330524e-01 9.40754775e-01 9.26914119e-01
9.76313402e-01 9.70787314e-01 9.63258311e-01 9.68148286e-01
9.68457158e-01 9.90530988e-01 8.92245536e-01 8.98356198e-01
9.79700254e-01 9.88595471e-01 9.35834875e-01 9.75000691e-01
9.53467325e-01 9.68738519e-01 9.26020306e-01 9.25822837e-01
9.78074440e-01 9.87724394e-01 9.89715295e-01 9.30854228e-01
9.85419266e-01 9.73447997e-01 9.75445446e-01 9.60622542e-01
9.92487313e-01 9.51568413e-01 9.35682523e-01 9.55551387e-01
9.68508433e-01 9.55208174e-01 9.42100337e-01 9.96437027e-01
8.01113656e-01 9.42536729e-01 9.76621905e-01 9.70626198e-01
9.37290564e-01 9.62722131e-01 9.93092932e-01 7.85143965e-01
9.62542674e-01 9.03121599e-01 9.77419929e-01 8.54394311e-01
9.47373177e-01 9.30942806e-01 9.83448467e-01 8.87068880e-01
9.26125939e-01 8.78660675e-01 9.76118737e-01 9.72843651e-01
7.46719837e-01 9.83662582e-01 9.63355594e-01 8.93569537e-01
9.55195710e-01 9.32420548e-01 9.70842633e-01 9.43502233e-01
9.86860237e-01 8.40279230e-01 9.62788375e-01 9.77183772e-01
9.27550942e-01 9.72358378e-01 9.87178836e-01 9.87319480e-01
9.25882921e-01 9.64096153e-01 9.63284358e-01 8.98509613e-01
9.84598017e-01 9.78171927e-01 9.70470514e-01 9.55324997e-01
9.60268381e-01 9.67845378e-01 9.60928694e-01 9.62437927e-01
9.77958893e-01 9.79693631e-01 9.47899473e-01 9.81038155e-01
9.85007729e-01 9.68546868e-01 9.64036929e-01 7.54640070e-01
9.77668951e-01 9.08990604e-01 9.22528014e-01 9.85420370e-01
9.20840912e-01 9.85354979e-01 9.96057068e-01 9.31088193e-01
9.91179837e-01 9.39509362e-01 9.77260277e-01 8.40108240e-01
8.99262272e-01 9.72181060e-01 9.77336412e-01 9.85740654e-01
9.71877192e-01 9.67662592e-01 9.69947309e-01 9.52876910e-01
9.32670567e-01 9.44469593e-01 9.78803598e-01 9.88302318e-01
8.72538681e-01 9.77368401e-01 9.76619550e-01 9.73779115e-01
9.91837704e-01 9.72771225e-01 9.72811144e-01 9.30276025e-01
9.25594157e-01 9.58666358e-01 9.09833701e-01 9.44923553e-01
9.59352100e-01 9.52706519e-01 9.65048530e-01 9.70725648e-01
9.67306526e-01 9.84614131e-01 9.47636430e-01 9.82670077e-01
9.16254288e-01 9.66784267e-01 9.57010685e-01 9.37285556e-01
9.28928872e-01 9.89602273e-01 9.65536740e-01 9.13563360e-01
9.90715092e-01 9.42673195e-01 9.94347235e-01 9.83113140e-01
9.57409734e-01 9.90367367e-01 9.82553434e-01 9.88650046e-01
8.41478590e-01 9.52573020e-01 9.33300280e-01 9.66004274e-01
9.69651524e-01 9.90016946e-01 9.67983239e-01 9.78754207e-01
9.15126345e-01 8.93209913e-01 8.62009657e-01 9.15499763e-01]
print(predict(len_hair))
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
gender
array([ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
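Finally, a quick accuracy check (our addition; it reuses the compiled predict function):

acc = np.mean(predict(len_hair) == gender)  # fraction of correct predictions
print(acc)  # essentially 1.0 here, since the two groups barely overlap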