모형 서브클래싱
import tensorflow as tf
단순한 모형
class SimpleModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.layer1 = tf.keras.layers.Dense(2, input_shape=(2,), activation='relu')
        self.layer2 = tf.keras.layers.Dense(1, activation='sigmoid')
    def call(self, x):
        z = self.layer1(x)
        return self.layer2(z)
모형 인스턴스를 생성한다.
model = SimpleModel()
모형에 데이터를 입력
x = tf.convert_to_tensor([[1.0, 2.0]])
model(x)
<tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.52212834]], dtype=float32)>
모형 요약을 보면 레이어 목록만 표시된다.
model.summary()
Model: "simple_model" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense (Dense) multiple 6 _________________________________________________________________ dense_1 (Dense) multiple 3 ================================================================= Total params: 9 Trainable params: 9 Non-trainable params: 0 _________________________________________________________________
모형을 시각화해서 레이어의 연결 관계를 볼 수 없다.
tf.keras.utils.plot_model(model)
복잡한 모형
class ComplexModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.layer1 = tf.keras.layers.Dense(2, activation='relu')
        self.layer2 = tf.keras.layers.Dense(2, activation='tanh')
        self.layer3 = tf.keras.layers.Add()
        self.layer4 = tf.keras.layers.Dense(1, activation='sigmoid')
        self.layer5 = tf.keras.layers.Dense(1, activation='sigmoid')
    def call(self, x1, x2):
        out1 = self.layer1(x1)
        out2 = self.layer2(x2)
        out3 = self.layer3([out1, out2])
        out4 = self.layer4(out3)
        out5 = self.layer5(out3)
model = ComplexModel()
model(x, x)
모형 요약에 레이어 목록만 표시되고 레이어들 간의 연결 관계가 표시되지 않는다.
model.summary()
Model: "complex_model" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense_2 (Dense) multiple 6 _________________________________________________________________ dense_3 (Dense) multiple 6 _________________________________________________________________ add (Add) multiple 0 _________________________________________________________________ dense_4 (Dense) multiple 3 _________________________________________________________________ dense_5 (Dense) multiple 3 ================================================================= Total params: 18 Trainable params: 18 Non-trainable params: 0 _________________________________________________________________
모형을 시각화해서 레이어의 연결 관계를 볼 수 없다.
tf.keras.utils.plot_model(model)
레이어 없이 모형 만들기
class SignModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.w = tf.Variable(1.0)
    def call(self, x):
        if x > 0:
            return self.w * x
        else:
            return 0.0
model = SignModel()
model(3.0)
<tf.Tensor: shape=(), dtype=float32, numpy=3.0>
model.summary()
Model: "sign_model" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= Total params: 1 Trainable params: 1 Non-trainable params: 0 _________________________________________________________________