모형 서브클래싱 :: 대화형 AI - mindscale
Skip to content

모형 서브클래싱

import tensorflow as tf

단순한 모형

class SimpleModel(tf.keras.Model):
    def __init__(self):
        super().__init__()

        self.layer1 = tf.keras.layers.Dense(2, input_shape=(2,), activation='relu')
        self.layer2 = tf.keras.layers.Dense(1, activation='sigmoid')

    def call(self, x):
        z = self.layer1(x)
        return self.layer2(z)

모형 인스턴스를 생성한다.

model = SimpleModel()

모형에 데이터를 입력

x = tf.convert_to_tensor([[1.0, 2.0]])
model(x)
<tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.52212834]], dtype=float32)>

모형 요약을 보면 레이어 목록만 표시된다.

model.summary()
Model: "simple_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                multiple                  6         
_________________________________________________________________
dense_1 (Dense)              multiple                  3         
=================================================================
Total params: 9
Trainable params: 9
Non-trainable params: 0
_________________________________________________________________

모형을 시각화해서 레이어의 연결 관계를 볼 수 없다.

tf.keras.utils.plot_model(model)

복잡한 모형

class ComplexModel(tf.keras.Model):
    def __init__(self):
        super().__init__()

        self.layer1 = tf.keras.layers.Dense(2, activation='relu')
        self.layer2 = tf.keras.layers.Dense(2, activation='tanh')
        self.layer3 = tf.keras.layers.Add()
        self.layer4 = tf.keras.layers.Dense(1, activation='sigmoid')
        self.layer5 = tf.keras.layers.Dense(1, activation='sigmoid')

    def call(self, x1, x2):
        out1 = self.layer1(x1)
        out2 = self.layer2(x2)
        out3 = self.layer3([out1, out2])
        out4 = self.layer4(out3)
        out5 = self.layer5(out3)
model = ComplexModel()
model(x, x)

모형 요약에 레이어 목록만 표시되고 레이어들 간의 연결 관계가 표시되지 않는다.

model.summary()
Model: "complex_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_2 (Dense)              multiple                  6         
_________________________________________________________________
dense_3 (Dense)              multiple                  6         
_________________________________________________________________
add (Add)                    multiple                  0         
_________________________________________________________________
dense_4 (Dense)              multiple                  3         
_________________________________________________________________
dense_5 (Dense)              multiple                  3         
=================================================================
Total params: 18
Trainable params: 18
Non-trainable params: 0
_________________________________________________________________

모형을 시각화해서 레이어의 연결 관계를 볼 수 없다.

tf.keras.utils.plot_model(model)

레이어 없이 모형 만들기

class SignModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.w = tf.Variable(1.0)

    def call(self, x):
        if x > 0:
            return self.w * x
        else:
            return 0.0
model = SignModel()
model(3.0)
<tf.Tensor: shape=(), dtype=float32, numpy=3.0>
model.summary()
Model: "sign_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
Total params: 1
Trainable params: 1
Non-trainable params: 0
_________________________________________________________________