Nenhuma descrição

sync_subscriptions_from_storekit.py 13KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375
  1. #!/usr/bin/env python3
  2. """Sync App Store Connect subscriptions from Paywall.storekit."""
  3. from __future__ import annotations
  4. import json
  5. import sys
  6. import time
  7. from pathlib import Path
  8. import jwt
  9. import requests
  10. API_BASE = "https://api.appstoreconnect.apple.com/v1"
  11. ROOT = Path(__file__).resolve().parents[1]
  12. REPO_ROOT = ROOT.parent
  13. ENV_PATH = ROOT / ".env"
  14. STOREKIT_PATH = REPO_ROOT / "App for Indeed" / "Paywall.storekit"
  15. MANIFEST_PATH = ROOT / "subscriptions.manifest.json"
  16. BASE_TERRITORY = "USA"
  17. LOCALE = "en-US"
  18. def load_env(path: Path) -> dict[str, str]:
  19. values: dict[str, str] = {}
  20. for line in path.read_text(encoding="utf-8").splitlines():
  21. line = line.strip()
  22. if not line or line.startswith("#") or "=" not in line:
  23. continue
  24. key, value = line.split("=", 1)
  25. values[key.strip()] = value.strip()
  26. return values
  27. def storekit_locale_to_asc(locale: str) -> str:
  28. return locale.replace("_", "-")
  29. def period_from_iso8601(value: str) -> str:
  30. mapping = {
  31. "P1W": "ONE_WEEK",
  32. "P1M": "ONE_MONTH",
  33. "P1Y": "ONE_YEAR",
  34. }
  35. if value not in mapping:
  36. raise ValueError(f"Unsupported subscription period {value!r}")
  37. return mapping[value]
  38. class ASCClient:
  39. def __init__(self, issuer_id: str, key_id: str, private_key_path: str):
  40. self.issuer_id = issuer_id
  41. self.key_id = key_id
  42. self.private_key = Path(private_key_path).expanduser().read_text(encoding="utf-8")
  43. self.session = requests.Session()
  44. self._token: str | None = None
  45. self._token_exp = 0.0
  46. def _ensure_token(self) -> str:
  47. now = time.time()
  48. if not self._token or now >= self._token_exp - 60:
  49. issued_at = int(now)
  50. payload = {
  51. "iss": self.issuer_id,
  52. "iat": issued_at,
  53. "exp": issued_at + 1200,
  54. "aud": "appstoreconnect-v1",
  55. }
  56. self._token = jwt.encode(
  57. payload,
  58. self.private_key,
  59. algorithm="ES256",
  60. headers={"kid": self.key_id, "typ": "JWT"},
  61. )
  62. self._token_exp = issued_at + 1200
  63. return self._token
  64. def request(self, method: str, path: str, **kwargs) -> requests.Response:
  65. headers = kwargs.pop("headers", {})
  66. headers["Authorization"] = f"Bearer {self._ensure_token()}"
  67. headers.setdefault("Content-Type", "application/json")
  68. url = path if path.startswith("http") else f"{API_BASE}{path}"
  69. return self.session.request(method, url, headers=headers, timeout=90, **kwargs)
  70. def get_json(self, path: str, params: dict | None = None) -> dict:
  71. response = self.request("GET", path, params=params)
  72. if not response.ok:
  73. raise RuntimeError(f"GET {path} failed ({response.status_code}): {response.text}")
  74. return response.json()
  75. def post(self, resource_type: str, attributes: dict, relationships: dict) -> dict:
  76. body = {
  77. "data": {
  78. "type": resource_type,
  79. "attributes": attributes,
  80. "relationships": relationships,
  81. }
  82. }
  83. collection = resource_type
  84. response = self.request("POST", f"/{collection}", json=body)
  85. if not response.ok:
  86. raise RuntimeError(
  87. f"POST {resource_type} failed ({response.status_code}): {response.text}"
  88. )
  89. return response.json()["data"]
  90. def find_app(client: ASCClient, bundle_id: str) -> dict:
  91. data = client.get_json("/apps", params={"filter[bundleId]": bundle_id, "limit": 1})
  92. apps = data.get("data", [])
  93. if not apps:
  94. raise RuntimeError(f"No app found for bundle id {bundle_id!r}")
  95. return apps[0]
  96. def load_storekit(path: Path) -> dict:
  97. return json.loads(path.read_text(encoding="utf-8"))
  98. def list_subscription_groups(client: ASCClient, app_id: str) -> list[dict]:
  99. data = client.get_json(
  100. f"/apps/{app_id}/subscriptionGroups",
  101. params={"include": "subscriptions", "limit": 50},
  102. )
  103. groups = data.get("data", [])
  104. included = {item["id"]: item for item in data.get("included", []) if item["type"] == "subscriptions"}
  105. for group in groups:
  106. rel_ids = [
  107. item["id"]
  108. for item in group.get("relationships", {})
  109. .get("subscriptions", {})
  110. .get("data", [])
  111. ]
  112. group["_subscriptions"] = [included[i] for i in rel_ids if i in included]
  113. return groups
  114. def find_or_create_group(client: ASCClient, app_id: str, reference_name: str) -> dict:
  115. for group in list_subscription_groups(client, app_id):
  116. if group.get("attributes", {}).get("referenceName") == reference_name:
  117. return group
  118. created = client.post(
  119. "subscriptionGroups",
  120. {"referenceName": reference_name},
  121. {"app": {"data": {"type": "apps", "id": app_id}}},
  122. )
  123. created["_subscriptions"] = []
  124. return created
  125. def ensure_group_localization(
  126. client: ASCClient, group_id: str, group_name: str, app_display_name: str
  127. ) -> None:
  128. data = client.get_json(f"/subscriptionGroups/{group_id}/subscriptionGroupLocalizations")
  129. for loc in data.get("data", []):
  130. if loc.get("attributes", {}).get("locale") == LOCALE:
  131. return
  132. client.post(
  133. "subscriptionGroupLocalizations",
  134. {"name": group_name, "locale": LOCALE, "customAppName": app_display_name},
  135. {"subscriptionGroup": {"data": {"type": "subscriptionGroups", "id": group_id}}},
  136. )
  137. def find_subscription_by_product_id(subscriptions: list[dict], product_id: str) -> dict | None:
  138. for sub in subscriptions:
  139. if sub.get("attributes", {}).get("productId") == product_id:
  140. return sub
  141. return None
  142. def create_subscription(
  143. client: ASCClient, group_id: str, spec: dict, group_number: int
  144. ) -> dict:
  145. loc = spec["localizations"][0]
  146. return client.post(
  147. "subscriptions",
  148. {
  149. "name": spec["referenceName"],
  150. "productId": spec["productID"],
  151. "subscriptionPeriod": period_from_iso8601(spec["recurringSubscriptionPeriod"]),
  152. "groupLevel": group_number,
  153. },
  154. {"group": {"data": {"type": "subscriptionGroups", "id": group_id}}},
  155. )
  156. def ensure_subscription_localization(client: ASCClient, sub_id: str, spec: dict) -> None:
  157. loc = spec["localizations"][0]
  158. asc_locale = storekit_locale_to_asc(loc["locale"])
  159. data = client.get_json(f"/subscriptions/{sub_id}/subscriptionLocalizations")
  160. for item in data.get("data", []):
  161. if item.get("attributes", {}).get("locale") == asc_locale:
  162. return
  163. client.post(
  164. "subscriptionLocalizations",
  165. {
  166. "name": loc["displayName"],
  167. "description": loc["description"],
  168. "locale": asc_locale,
  169. },
  170. {"subscription": {"data": {"type": "subscriptions", "id": sub_id}}},
  171. )
  172. def ensure_availability(client: ASCClient, sub_id: str) -> None:
  173. try:
  174. client.get_json(f"/subscriptionAvailabilities/{sub_id}")
  175. return
  176. except RuntimeError:
  177. pass
  178. client.post(
  179. "subscriptionAvailabilities",
  180. {"availableInNewTerritories": True},
  181. {
  182. "subscription": {"data": {"type": "subscriptions", "id": sub_id}},
  183. "availableTerritories": {"data": [{"type": "territories", "id": BASE_TERRITORY}]},
  184. },
  185. )
  186. def price_point_for_amount(client: ASCClient, sub_id: str, amount: str) -> str:
  187. path = f"/subscriptions/{sub_id}/pricePoints?filter[territory]={BASE_TERRITORY}&limit=200"
  188. while path:
  189. data = client.get_json(path)
  190. for point in data.get("data", []):
  191. if point.get("attributes", {}).get("customerPrice") == amount:
  192. return point["id"]
  193. next_url = data.get("links", {}).get("next")
  194. if not next_url:
  195. break
  196. path = next_url.replace(API_BASE, "")
  197. raise RuntimeError(f"No {BASE_TERRITORY} price point for ${amount} on subscription {sub_id}")
  198. def ensure_price(client: ASCClient, sub_id: str, amount: str) -> None:
  199. prices = client.get_json(f"/subscriptions/{sub_id}/prices")
  200. if prices.get("meta", {}).get("paging", {}).get("total", 0) > 0:
  201. return
  202. price_point_id = price_point_for_amount(client, sub_id, amount)
  203. client.post(
  204. "subscriptionPrices",
  205. {},
  206. {
  207. "subscription": {"data": {"type": "subscriptions", "id": sub_id}},
  208. "subscriptionPricePoint": {
  209. "data": {"type": "subscriptionPricePoints", "id": price_point_id}
  210. },
  211. },
  212. )
  213. def has_intro_offer(client: ASCClient, sub_id: str) -> bool:
  214. data = client.get_json(f"/subscriptions/{sub_id}/introductoryOffers")
  215. return data.get("meta", {}).get("paging", {}).get("total", 0) > 0
  216. def ensure_free_trial(client: ASCClient, sub_id: str, days: int) -> None:
  217. if has_intro_offer(client, sub_id):
  218. return
  219. duration = {3: "THREE_DAYS"}.get(days)
  220. if duration is None:
  221. raise ValueError(f"Unsupported free-trial length: {days} days")
  222. client.post(
  223. "subscriptionIntroductoryOffers",
  224. {"offerMode": "FREE_TRIAL", "duration": duration, "numberOfPeriods": 1},
  225. {
  226. "subscription": {"data": {"type": "subscriptions", "id": sub_id}},
  227. "territory": {"data": {"type": "territories", "id": BASE_TERRITORY}},
  228. },
  229. )
  230. def intro_days_from_storekit(offer: dict | None) -> int | None:
  231. if not offer:
  232. return None
  233. if offer.get("paymentMode") != "free":
  234. return None
  235. period = offer.get("subscriptionPeriod", "")
  236. if period == "P3D":
  237. return 3
  238. return None
  239. def main() -> int:
  240. if not ENV_PATH.exists():
  241. print(f"Missing {ENV_PATH}", file=sys.stderr)
  242. return 1
  243. if not STOREKIT_PATH.exists():
  244. print(f"Missing {STOREKIT_PATH}", file=sys.stderr)
  245. return 1
  246. env = load_env(ENV_PATH)
  247. issuer = env.get("APP_STORE_CONNECT_ISSUER_ID", "")
  248. key_id = env.get("APP_STORE_CONNECT_KEY_ID", "")
  249. key_path = env.get("APP_STORE_CONNECT_PRIVATE_KEY_PATH", "")
  250. bundle_id = env.get("APP_STORE_CONNECT_BUNDLE_ID", "")
  251. if not all([issuer, key_id, key_path, bundle_id]):
  252. print("Fill App Store Connect credentials in app-connect/.env", file=sys.stderr)
  253. return 1
  254. storekit = load_storekit(STOREKIT_PATH)
  255. groups = storekit.get("subscriptionGroups", [])
  256. if len(groups) != 1:
  257. print("Expected exactly one subscription group in Paywall.storekit", file=sys.stderr)
  258. return 1
  259. sk_group = groups[0]
  260. client = ASCClient(issuer, key_id, key_path)
  261. app = find_app(client, bundle_id)
  262. app_id = app["id"]
  263. print(f"App: {app['attributes'].get('name')} ({app_id})")
  264. asc_group = find_or_create_group(client, app_id, sk_group["name"])
  265. group_id = asc_group["id"]
  266. print(f"Subscription group: {sk_group['name']} ({group_id})")
  267. ensure_group_localization(client, group_id, sk_group["name"], "App for Indeed")
  268. manifest: dict = {
  269. "subscriptionGroupId": group_id,
  270. "subscriptionGroupName": sk_group["name"],
  271. "products": [],
  272. }
  273. existing = list_subscription_groups(client, app_id)
  274. current_subs: list[dict] = []
  275. for group in existing:
  276. if group["id"] == group_id:
  277. current_subs = group.get("_subscriptions", [])
  278. break
  279. for spec in sorted(sk_group["subscriptions"], key=lambda s: s["groupNumber"]):
  280. product_id = spec["productID"]
  281. sub = find_subscription_by_product_id(current_subs, product_id)
  282. if sub is None:
  283. sub = create_subscription(client, group_id, spec, spec["groupNumber"])
  284. print(f"Created subscription {product_id} ({sub['id']})")
  285. current_subs.append(sub)
  286. else:
  287. print(f"Found subscription {product_id} ({sub['id']})")
  288. sub_id = sub["id"]
  289. ensure_subscription_localization(client, sub_id, spec)
  290. ensure_availability(client, sub_id)
  291. ensure_price(client, sub_id, spec["displayPrice"])
  292. trial_days = intro_days_from_storekit(spec.get("introductoryOffer"))
  293. if trial_days:
  294. ensure_free_trial(client, sub_id, trial_days)
  295. print(f" Free trial: {trial_days} days")
  296. state = sub.get("attributes", {}).get("state")
  297. refreshed = client.get_json(f"/subscriptions/{sub_id}")
  298. state = refreshed["data"]["attributes"].get("state", state)
  299. manifest["products"].append(
  300. {
  301. "productId": product_id,
  302. "subscriptionId": sub_id,
  303. "referenceName": spec["referenceName"],
  304. "displayPrice": spec["displayPrice"],
  305. "period": spec["recurringSubscriptionPeriod"],
  306. "state": state,
  307. }
  308. )
  309. print(f" State: {state}")
  310. MANIFEST_PATH.write_text(json.dumps(manifest, indent=2) + "\n", encoding="utf-8")
  311. print(f"Wrote manifest: {MANIFEST_PATH}")
  312. print("Done. Review subscriptions in App Store Connect and submit for review when ready.")
  313. return 0
  314. if __name__ == "__main__":
  315. raise SystemExit(main())