Logistic Regression Practice 2 :: Deep Learning Basics - mindscale

Logistic Regression Practice 2

Logistic Regression

Goal: Predict gender with logistic regression, the building block of artificial neural networks.

In this exercise we will use logistic regression, which forms the basis of artificial neural networks, to build a model that predicts a person's gender from their height. The goals of the exercise are:

  • Learn the basics of Theano.
  • Understand the logistic regression model.
  • Learn how to implement gradient descent.
import theano
import theano.tensor as T
import numpy as np

First we set up the symbolic variables and create the data.

x = T.vector('x')  # input values
w = T.scalar('w')  # weight
b = T.scalar('b')  # bias
y = T.vector('y')  # target labels (0 or 1)
out = 1 / (1 + T.exp(-(w * x) - b))  # sigmoid written out by hand
N = 100  # number of samples per group

Sigmoid function values computed by hand:

f = theano.function([x, w, b], out)
hair = np.concatenate([np.random.randn(N) + 1, np.random.randn(N)])  # 100 samples with mean 1, 100 samples with mean 0
prob = f(hair,
        np.array(1, dtype = theano.config.floatX),
        np.array(-1, dtype = theano.config.floatX))
prob
array([ 0.11558904,  0.49593163,  0.50266056,  0.18066072,  0.41688148,
        0.49666854,  0.39019999,  0.37787053,  0.24365962,  0.5872073 ,
        0.64218682,  0.87913439,  0.46804282,  0.49138378,  0.53244897,
        0.56563533,  0.38386301,  0.72817262,  0.71435712,  0.8421329 ,
        0.27449893,  0.31089732,  0.28456894,  0.5276074 ,  0.77775312,
        0.33467973,  0.11895283,  0.47499868,  0.43764116,  0.34815379,
        0.55661846,  0.09179522,  0.38217598,  0.7989482 ,  0.49020603,
        0.3459761 ,  0.5996382 ,  0.57444145,  0.75427615,  0.10836856,
        0.49189244,  0.09332962,  0.47537184,  0.63066922,  0.14194516,
        0.78279751,  0.43462003,  0.55993874,  0.70784825,  0.25048662,
        0.77757078,  0.63489736,  0.59578119,  0.34744468,  0.65984086,
        0.25719687,  0.23028135,  0.27604124,  0.68334074,  0.67294121,
        0.70615617,  0.5085993 ,  0.67918025,  0.56140787,  0.64249315,
        0.74695565,  0.3179371 ,  0.5259163 ,  0.84169246,  0.94411889,
        0.47612601,  0.45173337,  0.48000097,  0.59060334,  0.42643188,
        0.87934579,  0.57625341,  0.23529187,  0.09608668,  0.11590786,
        0.22637512,  0.79000462,  0.2245848 ,  0.3175366 ,  0.10137861,
        0.64306132,  0.50631107,  0.73406696,  0.50716892,  0.58947254,
        0.44712662,  0.34006081,  0.15032411,  0.38505774,  0.38667379,
        0.46387625,  0.70263627,  0.40758166,  0.51992322,  0.37372992,
        0.46482379,  0.33613064,  0.29176525,  0.39665733,  0.29566868,
        0.6830792 ,  0.09914912,  0.44933601,  0.51109961,  0.43402864,
        0.44240692,  0.30820354,  0.07419383,  0.51311126,  0.23962994,
        0.68804707,  0.08277803,  0.23771643,  0.40546371,  0.34642179,
        0.65205457,  0.11283301,  0.23097786,  0.39279025,  0.45560596,
        0.20557529,  0.56630624,  0.44414538,  0.11992531,  0.08302729,
        0.41761443,  0.07683355,  0.15940981,  0.41934893,  0.34585658,
        0.46061428,  0.19044365,  0.32782305,  0.44463802,  0.39146187,
        0.25396609,  0.37017432,  0.20984756,  0.68209838,  0.07416688,
        0.41171599,  0.13801437,  0.68032696,  0.41079291,  0.19140806,
        0.33883255,  0.25185193,  0.78723456,  0.33732148,  0.08421657,
        0.19422472,  0.34108163,  0.05339644,  0.08862708,  0.34809861,
        0.32221287,  0.19137922,  0.65283869,  0.44150487,  0.59837777,
        0.38707145,  0.1333219 ,  0.04298876,  0.60631316,  0.17610629,
        0.32171648,  0.45718411,  0.32812774,  0.32396383,  0.70083906,
        0.2164637 ,  0.40866494,  0.16153463,  0.07100125,  0.21977632,
        0.37657765,  0.33223294,  0.35561117,  0.56369684,  0.5964966 ,
        0.51557555,  0.40247656,  0.26585279,  0.14015302,  0.10302591,
        0.27458847,  0.12266008,  0.30145633,  0.36120606,  0.1769366 ,
        0.28386606,  0.14812209,  0.17503618,  0.03308329,  0.3245861 ])
print(np.mean(prob[0:100]))
print(np.mean(prob[100:]))
0.479703084341
0.326590205584
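
With w = 1 and b = -1, the first 100 samples (drawn with mean +1) receive a noticeably higher average probability than the second 100, as expected. As a quick sanity check, not part of the original notebook, the same values can be computed directly with NumPy:

# Sketch: NumPy cross-check of the hand-written Theano sigmoid (my own addition).
def sigmoid_np(x, w, b):
    return 1.0 / (1.0 + np.exp(-(w * x) - b))

np_prob = sigmoid_np(hair, 1.0, -1.0)
print(np.allclose(np_prob, prob))  # True if both computations agree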

This time, let's use Theano's built-in sigmoid.

W = T.matrix('W')
x = T.vector('x')
b = T.vector('b')

out2 = T.nnet.sigmoid(T.dot(W, x) + b)
f = theano.function([W, x, b], out2)
prob = f(np.array([[-1, 2, -3], [4,5,6]], dtype = theano.config.floatX),
        np.array([1,2,3], dtype = theano.config.floatX),
        np.array([4,5], dtype = theano.config.floatX))
print(prob)
[ 0.11920292  1.        ]
prediction = prob > 0.5
prediction
array([False,  True], dtype=bool)
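
The second probability prints as 1.0 only because its pre-activation, 4·1 + 5·2 + 6·3 + 5 = 37, saturates the sigmoid; thresholding at 0.5 then turns the probabilities into class predictions. A small verification, added here and not part of the original notebook:

# Recompute the pre-activations and sigmoids with plain NumPy.
z = np.dot(np.array([[-1, 2, -3], [4, 5, 6]]), np.array([1, 2, 3])) + np.array([4, 5])
print(z)                         # [-2 37]
print(1.0 / (1.0 + np.exp(-z)))  # matches prob above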
The binary cross-entropy cost and its gradients with respect to the parameters can be defined symbolically as well:

xent = -y * T.log(out2) - (1 - y) * T.log(1 - out2)  # binary cross-entropy
cost = xent.mean()
gW, gb = T.grad(cost, [W, b])  # gradients of the mean cost w.r.t. W and b
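
For intuition, the gradients that T.grad derives symbolically have a simple closed form in the one-dimensional case used below: the derivative of the mean cross-entropy with respect to the weight is mean((p - y) * x), and with respect to the bias it is mean(p - y). A minimal NumPy sketch of that closed form, added as a reference and not part of the original notebook:

# Closed-form gradients of the mean binary cross-entropy for 1-D logistic regression.
def manual_grads(x, y, w, b):
    p = 1.0 / (1.0 + np.exp(-(w * x) - b))  # predicted probabilities
    gw = np.mean((p - y) * x)               # d(cost)/dw
    gb = np.mean(p - y)                     # d(cost)/db
    return gw, gb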

Putting it all together to complete the logistic regression

N = 200
np.random.seed(0)
len_hair = np.concatenate([np.random.randn(N) + 10, np.random.randn(N)])  # heights for 200 men and 200 women
gender = np.concatenate([np.zeros(N), np.ones(N)])
x = T.vector('x')
y = T.vector('y')

w = theano.shared(1., name = 'w')  # weight, initialized to 1
b = theano.shared(0., name = 'b')  # bias, initialized to 0
prob = 1 / (1 + T.exp(-(w * x) - b))

prediction = prob > 0.5

xent = -y * T.log(prob) - (1-y) * T.log(1-prob)
cost = xent.mean()

gw, gb = T.grad(cost, [w, b])
# delta rule: w = w - learning_rate * gw, b = b - learning_rate * gb
train = theano.function([x, y],
                       outputs = [prob, prediction],
                       updates = ((w, w - 0.05 * gw), (b, b - 0.05 * gb)))
pred_prob = theano.function([x], prob)
predict = theano.function([x], prediction)
for i in range(1000):
    prob, pred = train(len_hair, gender)
w.get_value()
array(-0.896757297481667)
b.get_value()
array(3.1472479382974248)
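
With the fitted parameters in hand, a simple check, not in the original notebook, is to measure how often the model's predictions match the labels:

# Training accuracy of the fitted model.
acc = np.mean(predict(len_hair) == gender)
print('training accuracy:', acc)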
print(pred_prob(len_hair))
[  6.09506135e-04   2.06786431e-03   1.23183782e-03   3.97522124e-04
   5.55508329e-04   7.07600972e-03   1.26385537e-03   3.38640461e-03
   3.24379336e-03   2.04863211e-03   2.60037161e-03   8.04523594e-04
   1.49700431e-03   2.65291978e-03   1.98854284e-03   2.19461871e-03
   7.76333794e-04   3.55319745e-03   2.23545891e-03   6.34067965e-03
   2.84453508e-02   1.64813361e-03   1.36461894e-03   5.73860989e-03
   3.87369514e-04   1.08132383e-02   2.83928606e-03   3.49658234e-03
   7.49873421e-04   7.93722009e-04   2.57513395e-03   2.10896867e-03
   6.53389508e-03   1.72247806e-02   4.03650142e-03   2.57190776e-03
   9.83315059e-04   1.00821215e-03   4.18111731e-03   3.87536382e-03
   7.53952868e-03   1.04886908e-02   1.35167753e-02   5.15582709e-04
   4.66361098e-03   4.37493663e-03   9.04128418e-03   1.47511196e-03
   1.24556095e-02   3.57735205e-03   6.57875766e-03   2.09253840e-03
   4.66841276e-03   8.47953050e-03   3.03334178e-03   2.01637652e-03
   2.78706614e-03   2.25675301e-03   5.21238328e-03   4.09031640e-03
   5.39275673e-03   4.07868705e-03   6.11345990e-03   1.37581733e-02
   2.52387392e-03   4.23543399e-03   1.26367001e-02   1.95515565e-03
   6.64846504e-03   2.82362131e-03   1.54044499e-03   2.63563663e-03
   1.06672912e-03   8.89804367e-03   2.06382587e-03   5.45248401e-03
   6.43574365e-03   4.96069012e-03   3.90751678e-03   2.81298617e-03
   8.36359284e-03   1.32086362e-03   1.95012218e-03   1.16274927e-02
   7.80397811e-04   5.41580296e-04   1.02975486e-03   3.47397381e-03
   7.68996158e-03   1.15106918e-03   4.24071708e-03   9.90250798e-04
   2.45517965e-03   1.23415576e-03   2.15050619e-03   1.57181742e-03
   2.93022656e-03   5.97703757e-04   2.64052268e-03   2.06447661e-03
   5.47799006e-04   9.83706357e-03   9.18452473e-03   1.24218722e-03
   8.42310426e-03   5.18899387e-04   4.28044276e-03   5.76573924e-03
   5.28606596e-04   7.85827217e-04   5.55507846e-04   1.31470527e-03
   6.38109124e-03   5.34742836e-04   3.75842024e-03   1.44250051e-03
   1.26707013e-03   3.39747807e-03   1.70751843e-03   1.29581265e-03
   2.11224934e-03   7.88850112e-03   2.26531789e-03   9.02200412e-04
   5.50014094e-03   3.38119496e-03   4.36354271e-03   5.64691705e-04
   1.62080493e-03   2.05439081e-03   5.88236193e-03   1.82581618e-03
   5.40176939e-03   2.87486854e-03   5.21947441e-03   1.61481053e-03
   1.76579474e-03   3.56318257e-03   2.07555921e-03   7.84413458e-03
   1.11728932e-02   1.99651671e-03   2.54826580e-03   1.67578879e-03
   3.49929975e-04   1.27022029e-03   6.68125975e-03   1.08833478e-03
   9.56270933e-03   4.46773630e-03   3.14394225e-03   6.37844730e-04
   5.75187611e-03   6.18631444e-03   3.23000289e-03   5.34972497e-03
   1.07899675e-03   7.75302791e-03   8.23311091e-03   4.37394359e-03
   4.61549122e-03   5.25493562e-04   1.26461129e-03   2.73513050e-03
   8.82408783e-03   1.38937116e-03   7.22200378e-03   1.17157035e-02
   1.02125685e-03   2.22772177e-03   1.29737780e-03   2.22416653e-03
   1.37394512e-03   5.29063099e-03   7.44410990e-03   1.60736586e-03
   6.06063540e-03   5.47558112e-03   4.44366184e-03   2.91199797e-03
   4.05848650e-03   1.00774403e-02   5.25578878e-03   2.13218173e-02
   1.69055587e-03   1.23256797e-02   7.92354698e-03   2.82306668e-03
   5.72531169e-03   7.43027176e-04   9.36890328e-03   2.32941763e-03
   3.06359511e-03   8.38551439e-03   1.85210764e-03   3.44805964e-03
   1.48265985e-03   1.41556722e-03   4.26179930e-04   8.94039440e-04
   9.70064411e-01   9.66492782e-01   8.96703352e-01   9.28218169e-01
   9.29117084e-01   9.90021215e-01   9.59653304e-01   9.78312276e-01
   9.47659436e-01   9.62140007e-01   9.11415207e-01   9.45975560e-01
   9.19982008e-01   9.72495416e-01   9.81910829e-01   9.71110526e-01
   9.59398874e-01   9.43065854e-01   7.54210522e-01   9.60271313e-01
   9.82093075e-01   9.69454306e-01   9.72427618e-01   9.37933689e-01
   9.89323422e-01   9.56499893e-01   9.52884854e-01   9.49742824e-01
   9.75466959e-01   9.66450431e-01   9.88159110e-01   9.73133370e-01
   9.74270804e-01   9.41262960e-01   9.84992190e-01   9.20320001e-01
   8.59005933e-01   9.93330524e-01   9.40754775e-01   9.26914119e-01
   9.76313402e-01   9.70787314e-01   9.63258311e-01   9.68148286e-01
   9.68457158e-01   9.90530988e-01   8.92245536e-01   8.98356198e-01
   9.79700254e-01   9.88595471e-01   9.35834875e-01   9.75000691e-01
   9.53467325e-01   9.68738519e-01   9.26020306e-01   9.25822837e-01
   9.78074440e-01   9.87724394e-01   9.89715295e-01   9.30854228e-01
   9.85419266e-01   9.73447997e-01   9.75445446e-01   9.60622542e-01
   9.92487313e-01   9.51568413e-01   9.35682523e-01   9.55551387e-01
   9.68508433e-01   9.55208174e-01   9.42100337e-01   9.96437027e-01
   8.01113656e-01   9.42536729e-01   9.76621905e-01   9.70626198e-01
   9.37290564e-01   9.62722131e-01   9.93092932e-01   7.85143965e-01
   9.62542674e-01   9.03121599e-01   9.77419929e-01   8.54394311e-01
   9.47373177e-01   9.30942806e-01   9.83448467e-01   8.87068880e-01
   9.26125939e-01   8.78660675e-01   9.76118737e-01   9.72843651e-01
   7.46719837e-01   9.83662582e-01   9.63355594e-01   8.93569537e-01
   9.55195710e-01   9.32420548e-01   9.70842633e-01   9.43502233e-01
   9.86860237e-01   8.40279230e-01   9.62788375e-01   9.77183772e-01
   9.27550942e-01   9.72358378e-01   9.87178836e-01   9.87319480e-01
   9.25882921e-01   9.64096153e-01   9.63284358e-01   8.98509613e-01
   9.84598017e-01   9.78171927e-01   9.70470514e-01   9.55324997e-01
   9.60268381e-01   9.67845378e-01   9.60928694e-01   9.62437927e-01
   9.77958893e-01   9.79693631e-01   9.47899473e-01   9.81038155e-01
   9.85007729e-01   9.68546868e-01   9.64036929e-01   7.54640070e-01
   9.77668951e-01   9.08990604e-01   9.22528014e-01   9.85420370e-01
   9.20840912e-01   9.85354979e-01   9.96057068e-01   9.31088193e-01
   9.91179837e-01   9.39509362e-01   9.77260277e-01   8.40108240e-01
   8.99262272e-01   9.72181060e-01   9.77336412e-01   9.85740654e-01
   9.71877192e-01   9.67662592e-01   9.69947309e-01   9.52876910e-01
   9.32670567e-01   9.44469593e-01   9.78803598e-01   9.88302318e-01
   8.72538681e-01   9.77368401e-01   9.76619550e-01   9.73779115e-01
   9.91837704e-01   9.72771225e-01   9.72811144e-01   9.30276025e-01
   9.25594157e-01   9.58666358e-01   9.09833701e-01   9.44923553e-01
   9.59352100e-01   9.52706519e-01   9.65048530e-01   9.70725648e-01
   9.67306526e-01   9.84614131e-01   9.47636430e-01   9.82670077e-01
   9.16254288e-01   9.66784267e-01   9.57010685e-01   9.37285556e-01
   9.28928872e-01   9.89602273e-01   9.65536740e-01   9.13563360e-01
   9.90715092e-01   9.42673195e-01   9.94347235e-01   9.83113140e-01
   9.57409734e-01   9.90367367e-01   9.82553434e-01   9.88650046e-01
   8.41478590e-01   9.52573020e-01   9.33300280e-01   9.66004274e-01
   9.69651524e-01   9.90016946e-01   9.67983239e-01   9.78754207e-01
   9.15126345e-01   8.93209913e-01   8.62009657e-01   9.15499763e-01]
print(predict(len_hair))
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
gender
array([ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
        1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
        1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
        1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
        1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
        1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
        1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
        1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
        1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
        1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
        1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
        1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
        1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
        1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
        1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
        1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.])
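
Since the model is a single sigmoid over one input, its decision boundary sits where w*x + b = 0, i.e. at x = -b/w. A short sketch, added here as an illustration, reads it off from the trained shared variables:

# Decision boundary of the trained model: the input value at which prob = 0.5.
boundary = -b.get_value() / w.get_value()
print('decision boundary:', boundary)  # roughly 3.5 for the fitted values above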