모형 서브클래싱
import tensorflow as tf
단순한 모형
class SimpleModel(tf.keras.Model):
def __init__(self):
super().__init__()
self.layer1 = tf.keras.layers.Dense(2, input_shape=(2,), activation='relu')
self.layer2 = tf.keras.layers.Dense(1, activation='sigmoid')
def call(self, x):
z = self.layer1(x)
return self.layer2(z)
모형 인스턴스를 생성한다.
model = SimpleModel()
모형에 데이터를 입력
x = tf.convert_to_tensor([[1.0, 2.0]])
model(x)
<tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.52212834]], dtype=float32)>
모형 요약을 보면 레이어 목록만 표시된다.
model.summary()
Model: "simple_model" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense (Dense) multiple 6 _________________________________________________________________ dense_1 (Dense) multiple 3 ================================================================= Total params: 9 Trainable params: 9 Non-trainable params: 0 _________________________________________________________________
모형을 시각화해서 레이어의 연결 관계를 볼 수 없다.
tf.keras.utils.plot_model(model)
복잡한 모형
class ComplexModel(tf.keras.Model):
def __init__(self):
super().__init__()
self.layer1 = tf.keras.layers.Dense(2, activation='relu')
self.layer2 = tf.keras.layers.Dense(2, activation='tanh')
self.layer3 = tf.keras.layers.Add()
self.layer4 = tf.keras.layers.Dense(1, activation='sigmoid')
self.layer5 = tf.keras.layers.Dense(1, activation='sigmoid')
def call(self, x1, x2):
out1 = self.layer1(x1)
out2 = self.layer2(x2)
out3 = self.layer3([out1, out2])
out4 = self.layer4(out3)
out5 = self.layer5(out3)
model = ComplexModel()
model(x, x)
모형 요약에 레이어 목록만 표시되고 레이어들 간의 연결 관계가 표시되지 않는다.
model.summary()
Model: "complex_model" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense_2 (Dense) multiple 6 _________________________________________________________________ dense_3 (Dense) multiple 6 _________________________________________________________________ add (Add) multiple 0 _________________________________________________________________ dense_4 (Dense) multiple 3 _________________________________________________________________ dense_5 (Dense) multiple 3 ================================================================= Total params: 18 Trainable params: 18 Non-trainable params: 0 _________________________________________________________________
모형을 시각화해서 레이어의 연결 관계를 볼 수 없다.
tf.keras.utils.plot_model(model)
레이어 없이 모형 만들기
class SignModel(tf.keras.Model):
def __init__(self):
super().__init__()
self.w = tf.Variable(1.0)
def call(self, x):
if x > 0:
return self.w * x
else:
return 0.0
model = SignModel()
model(3.0)
<tf.Tensor: shape=(), dtype=float32, numpy=3.0>
model.summary()
Model: "sign_model" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= Total params: 1 Trainable params: 1 Non-trainable params: 0 _________________________________________________________________