Theano Basics :: Deep Learning Basics - mindscale
Theano Basics

In this lecture, we will

  1. look at the characteristics of Theano,
  2. see what features it offers, and
  3. learn how those features can be put to use.

1. function

import theano
import theano.tensor as T
# symbolic scalar variables of type double (float64)
x = T.dscalar('x')
y = T.dscalar('y')

# z is a symbolic expression built from x and y; nothing is computed yet
z = x + y
# compile the expression into a callable function
f = theano.function(inputs=[x, y], outputs=z)
f(2,3)
array(5.0)
# a single function can return several outputs at once
mult = x * y
f = theano.function([x, y], [z, mult])
f(2,3)
[array(5.0), array(6.0)]
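
The key idea in this section is that `z = x + y` does not compute a number; it builds an expression that is only evaluated once a function is compiled and called. The following is a minimal sketch of that idea in plain Python, not Theano itself. The names `Var`, `Add`, `Mul`, and `compile_function` are hypothetical and exist only for this illustration.

```python
# Toy symbolic expressions in plain Python (an illustration, not Theano).
# Var, Add, Mul and compile_function are hypothetical names.

class Expr:
    def __add__(self, other):
        return Add(self, other)

    def __mul__(self, other):
        return Mul(self, other)


class Var(Expr):
    def __init__(self, name):
        self.name = name

    def evaluate(self, env):
        return env[self.name]


class Add(Expr):
    def __init__(self, left, right):
        self.left, self.right = left, right

    def evaluate(self, env):
        return self.left.evaluate(env) + self.right.evaluate(env)


class Mul(Expr):
    def __init__(self, left, right):
        self.left, self.right = left, right

    def evaluate(self, env):
        return self.left.evaluate(env) * self.right.evaluate(env)


def compile_function(inputs, outputs):
    """Return a callable that binds positional arguments to the input variables."""
    def run(*args):
        env = {var.name: float(arg) for var, arg in zip(inputs, args)}
        return [out.evaluate(env) for out in outputs]
    return run


x, y = Var('x'), Var('y')
z = x + y        # builds an Add node; no arithmetic happens here
mult = x * y     # builds a Mul node
f = compile_function([x, y], [z, mult])
result = f(2, 3)
print(result)    # [5.0, 6.0]
```

Theano additionally optimizes the graph and generates compiled code before the function is returned, which this sketch does not attempt.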

2. shared variable & updates

After each function evaluation, the updates mechanism can replace the value of any SharedVariable (implicit) input with a new value computed from the expressions in the updates list. An exception is raised if you give two update expressions for the same SharedVariable input, since it would be ambiguous which one to apply.

state = theano.shared(0.0)
inc = T.dscalar('inc')
accumulator = theano.function([inc], updates=[(state, state + inc)])
state.get_value()
array(0.0)
accumulator(100)
accumulator(100)  # called twice; each call adds inc to state
state.get_value()
array(200.0)
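
The update rule described above can be sketched in plain Python. This is only an analogy under simplifying assumptions, not Theano's implementation: `Shared` and `make_function` are hypothetical names, and each update rule here may depend only on its own shared value and the single input.

```python
# Plain-Python sketch of the "updates" idea (not Theano's implementation).
# Shared and make_function are hypothetical names.

class Shared:
    def __init__(self, value):
        self.value = value

    def get_value(self):
        return self.value


def make_function(updates):
    """updates: list of (shared, rule) pairs, where rule(current, arg) -> new value."""
    targets = [shared for shared, _ in updates]
    if len(set(map(id, targets))) != len(targets):
        raise ValueError("two update expressions for the same shared variable")

    def run(arg):
        # Compute every new value from the old state first, then assign.
        new_values = [rule(shared.value, arg) for shared, rule in updates]
        for (shared, _), new in zip(updates, new_values):
            shared.value = new
    return run


state = Shared(0.0)
accumulator = make_function([(state, lambda s, inc: s + inc)])

accumulator(100)
after_one = state.get_value()   # 100.0 after one call
accumulator(100)
after_two = state.get_value()   # 200.0 after two calls

# Giving two rules for the same shared variable is rejected.
try:
    make_function([(state, lambda s, inc: s + inc),
                   (state, lambda s, inc: s - inc)])
    duplicate_rejected = False
except ValueError:
    duplicate_rejected = True
```

Starting from 0.0, one call with 100 leaves the state at 100.0 and a second call brings it to 200.0.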

3. replace variables

fn_of_state = state * 2 + inc
foo = T.scalar(dtype=state.dtype)
# givens substitutes foo for state inside this function only; state itself is not modified
skip_shared = theano.function([inc, foo], fn_of_state, givens=[(state, foo)])
state.get_value()
array(200.0)
skip_shared(1,3) # (3 * 2) + 1
array(7.0)

# random numbers: RandomStreams creates symbolic random variables backed by shared state
from theano.tensor.shared_randomstreams import RandomStreams
srng = RandomStreams(seed=123)
rv_u = srng.uniform((2,2))
rv_n = srng.normal((2,2))
gen_uni = theano.function([], outputs=rv_u)
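# no_default_updates=True leaves the generator state untouched, so every call returns the same numbers
gen_norm = theano.function([], outputs=rv_n, no_default_updates=True)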
gen_uni()
array([[ 0.72803009,  0.59044123],
       [ 0.23715077,  0.69958932]])
gen_uni()
array([[ 0.86608782,  0.35008632],
       [ 0.37173976,  0.5594106 ]])
gen_norm()
array([[-0.35867012,  0.97187258],
       [-0.07658328, -0.86469693]])
gen_norm()
array([[-0.35867012,  0.97187258],
       [-0.07658328, -0.86469693]])
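
Why `gen_uni` changes between calls while `gen_norm` repeats itself comes down to whether the generator state is written back after each call. The sketch below is a rough plain-Python analogy using the standard `random` module, not Theano's mechanism. `make_sampler` and `advance_state` are hypothetical names.

```python
import random

# Plain-Python analogy for no_default_updates (not Theano's implementation).
# make_sampler and advance_state are hypothetical names.

def make_sampler(seed, advance_state=True):
    rng = random.Random(seed)
    saved = rng.getstate()

    def sample():
        nonlocal saved
        rng.setstate(saved)            # start from the stored generator state
        value = rng.random()
        if advance_state:
            saved = rng.getstate()     # store the advanced state for the next call
        return value
    return sample


updating = make_sampler(123, advance_state=True)   # behaves like gen_uni
frozen = make_sampler(123, advance_state=False)    # behaves like gen_norm

a1, a2 = updating(), updating()   # two different draws
b1, b2 = frozen(), frozen()       # the same draw twice
```

When the state is never stored back, every call starts from the same point and therefore produces the same sample.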

4. calculate gradient

x = T.dvector('x')
y = T.sum(x**2 + x)
dydx = T.grad(cost=y, wrt=x)
f = theano.function(inputs=[x], outputs=dydx)

f([3, 4])  # the gradient of sum(x**2 + x) is 2 * x + 1
array([ 7.,  9.])
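
A symbolic gradient like this one can be sanity-checked against a numerical approximation. The sketch below uses a central finite difference in plain Python (no Theano required). `cost` and `numerical_grad` are hypothetical helper names, and the step size `eps` is an arbitrary choice.

```python
# Central finite-difference check of the gradient of y = sum(x**2 + x).
# cost and numerical_grad are hypothetical helper names.

def cost(x):
    return sum(xi ** 2 + xi for xi in x)


def numerical_grad(fn, x, eps=1e-6):
    grad = []
    for i in range(len(x)):
        plus = list(x)
        plus[i] += eps
        minus = list(x)
        minus[i] -= eps
        # Slope of fn along coordinate i, estimated from both sides.
        grad.append((fn(plus) - fn(minus)) / (2 * eps))
    return grad


point = [3.0, 4.0]
approx = numerical_grad(cost, point)
exact = [2 * xi + 1 for xi in point]   # analytic gradient 2x + 1, i.e. [7.0, 9.0]

max_error = max(abs(a - e) for a, e in zip(approx, exact))
```

The numerical estimate should agree with the analytic values 7 and 9 up to a small floating-point error, which is a quick way to catch a mistaken cost expression.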