Skip to content

Commit eb8d63a

Browse files
fix: Enforce that __enter__ is called on all user interfaces before use (#70)
* chore: Add license headers to all files * fix: Enforce that __enter__ is called on all user interfaces before use * fix: Enforce that __enter__ is called on all user interfaces before use
1 parent b0407f6 commit eb8d63a

4 files changed

Lines changed: 71 additions & 15 deletions

File tree

google/cloud/pubsublite/cloudpubsub/publisher_client.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from google.cloud.pubsublite.internal.constructable_from_service_account import (
3737
ConstructableFromServiceAccount,
3838
)
39+
from google.cloud.pubsublite.internal.require_started import RequireStarted
3940
from google.cloud.pubsublite.internal.wire.make_publisher import (
4041
DEFAULT_BATCHING_SETTINGS as WIRE_DEFAULT_BATCHING,
4142
)
@@ -52,6 +53,7 @@ class PublisherClient(PublisherClientInterface, ConstructableFromServiceAccount)
5253
"""
5354

5455
_impl: PublisherClientInterface
56+
_require_stared: RequireStarted
5557

5658
DEFAULT_BATCHING_SETTINGS = WIRE_DEFAULT_BATCHING
5759
"""
@@ -83,6 +85,7 @@ def __init__(
8385
transport=transport,
8486
)
8587
)
88+
self._require_stared = RequireStarted()
8689

8790
@overrides
8891
def publish(
@@ -92,18 +95,21 @@ def publish(
9295
ordering_key: str = "",
9396
**attrs: Mapping[str, str]
9497
) -> "Future[str]":
98+
self._require_stared.require_started()
9599
return self._impl.publish(
96100
topic=topic, data=data, ordering_key=ordering_key, **attrs
97101
)
98102

99103
@overrides
100104
def __enter__(self):
105+
self._require_stared.__enter__()
101106
self._impl.__enter__()
102107
return self
103108

104109
@overrides
105110
def __exit__(self, exc_type, exc_value, traceback):
106111
self._impl.__exit__(exc_type, exc_value, traceback)
112+
self._require_stared.__exit__(exc_type, exc_value, traceback)
107113

108114

109115
class AsyncPublisherClient(
@@ -117,6 +123,7 @@ class AsyncPublisherClient(
117123
"""
118124

119125
_impl: AsyncPublisherClientInterface
126+
_require_stared: RequireStarted
120127

121128
DEFAULT_BATCHING_SETTINGS = WIRE_DEFAULT_BATCHING
122129
"""
@@ -148,6 +155,7 @@ def __init__(
148155
transport=transport,
149156
)
150157
)
158+
self._require_stared = RequireStarted()
151159

152160
@overrides
153161
async def publish(
@@ -157,15 +165,18 @@ async def publish(
157165
ordering_key: str = "",
158166
**attrs: Mapping[str, str]
159167
) -> str:
168+
self._require_stared.require_started()
160169
return await self._impl.publish(
161170
topic=topic, data=data, ordering_key=ordering_key, **attrs
162171
)
163172

164173
@overrides
165174
async def __aenter__(self):
175+
self._require_stared.__enter__()
166176
await self._impl.__aenter__()
167177
return self
168178

169179
@overrides
170180
async def __aexit__(self, exc_type, exc_value, traceback):
171181
await self._impl.__aexit__(exc_type, exc_value, traceback)
182+
self._require_stared.__exit__(exc_type, exc_value, traceback)

google/cloud/pubsublite/cloudpubsub/subscriber_client.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from google.cloud.pubsublite.internal.constructable_from_service_account import (
4040
ConstructableFromServiceAccount,
4141
)
42+
from google.cloud.pubsublite.internal.require_started import RequireStarted
4243
from google.cloud.pubsublite.types import (
4344
FlowControlSettings,
4445
Partition,
@@ -56,6 +57,7 @@ class SubscriberClient(SubscriberClientInterface, ConstructableFromServiceAccoun
5657
"""
5758

5859
_impl: SubscriberClientInterface
60+
_require_started: RequireStarted
5961

6062
def __init__(
6163
self,
@@ -92,6 +94,7 @@ def __init__(
9294
client_options=client_options,
9395
),
9496
)
97+
self._require_started = RequireStarted()
9598

9699
@overrides
97100
def subscribe(
@@ -101,6 +104,7 @@ def subscribe(
101104
per_partition_flow_control_settings: FlowControlSettings,
102105
fixed_partitions: Optional[Set[Partition]] = None,
103106
) -> StreamingPullFuture:
107+
self._require_started.require_started()
104108
return self._impl.subscribe(
105109
subscription,
106110
callback,
@@ -110,12 +114,14 @@ def subscribe(
110114

111115
@overrides
112116
def __enter__(self):
117+
self._require_started.__enter__()
113118
self._impl.__enter__()
114119
return self
115120

116121
@overrides
117122
def __exit__(self, exc_type, exc_value, traceback):
118123
self._impl.__exit__(exc_type, exc_value, traceback)
124+
self._require_started.__exit__(exc_type, exc_value, traceback)
119125

120126

121127
class AsyncSubscriberClient(
@@ -130,6 +136,7 @@ class AsyncSubscriberClient(
130136
"""
131137

132138
_impl: AsyncSubscriberClientInterface
139+
_require_started: RequireStarted
133140

134141
def __init__(
135142
self,
@@ -161,6 +168,7 @@ def __init__(
161168
client_options=client_options,
162169
)
163170
)
171+
self._require_started = RequireStarted()
164172

165173
@overrides
166174
async def subscribe(
@@ -169,15 +177,18 @@ async def subscribe(
169177
per_partition_flow_control_settings: FlowControlSettings,
170178
fixed_partitions: Optional[Set[Partition]] = None,
171179
) -> AsyncIterator[Message]:
180+
self._require_started.require_started()
172181
return await self._impl.subscribe(
173182
subscription, per_partition_flow_control_settings, fixed_partitions
174183
)
175184

176185
@overrides
177186
async def __aenter__(self):
187+
self._require_started.__enter__()
178188
await self._impl.__aenter__()
179189
return self
180190

181191
@overrides
182192
async def __aexit__(self, exc_type, exc_value, traceback):
183193
await self._impl.__aexit__(exc_type, exc_value, traceback)
194+
self._require_started.__exit__(exc_type, exc_value, traceback)
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import ContextManager
16+
17+
from google.api_core.exceptions import FailedPrecondition
18+
19+
20+
class RequireStarted(ContextManager):
21+
def __init__(self):
22+
self._started = False
23+
24+
def __enter__(self):
25+
if self._started:
26+
raise FailedPrecondition("__enter__ called twice.")
27+
self._started = True
28+
return self
29+
30+
def require_started(self):
31+
if not self._started:
32+
raise FailedPrecondition("__enter__ has never been called.")
33+
34+
def __exit__(self, exc_type, exc_value, traceback):
35+
self.require_started()

samples/snippets/subscriber_example.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -59,21 +59,20 @@ def callback(message):
5959
print(f"Received {message_data} of ordering key {message.ordering_key}.")
6060
message.ack()
6161

62-
subscriber_client = SubscriberClient()
63-
64-
streaming_pull_future = subscriber_client.subscribe(
65-
subscription_path,
66-
callback=callback,
67-
per_partition_flow_control_settings=per_partition_flow_control_settings,
68-
)
69-
70-
print(f"Listening for messages on {str(subscription_path)}...")
71-
72-
try:
73-
streaming_pull_future.result(timeout=timeout)
74-
except TimeoutError or KeyboardInterrupt:
75-
streaming_pull_future.cancel()
76-
assert streaming_pull_future.done()
62+
with SubscriberClient() as subscriber_client:
63+
streaming_pull_future = subscriber_client.subscribe(
64+
subscription_path,
65+
callback=callback,
66+
per_partition_flow_control_settings=per_partition_flow_control_settings,
67+
)
68+
69+
print(f"Listening for messages on {str(subscription_path)}...")
70+
71+
try:
72+
streaming_pull_future.result(timeout=timeout)
73+
except TimeoutError or KeyboardInterrupt:
74+
streaming_pull_future.cancel()
75+
assert streaming_pull_future.done()
7776
# [END pubsublite_quickstart_subscriber]
7877

7978

0 commit comments

Comments
 (0)