|
1 | 1 | # Copyright (c) Microsoft Corporation. All rights reserved.
|
2 | 2 | # Licensed under the MIT License.
|
3 |
| -from typing import Dict |
| 3 | +from typing import Dict, List |
4 | 4 |
|
5 | 5 | from botbuilder.schema import Activity
|
6 | 6 |
|
@@ -73,63 +73,82 @@ async def validate_auth_header(
|
73 | 73 | if not auth_header:
|
74 | 74 | raise ValueError("argument auth_header is null")
|
75 | 75 |
|
76 |
| - if SkillValidation.is_skill_token(auth_header): |
77 |
| - return await SkillValidation.authenticate_channel_token( |
78 |
| - auth_header, |
79 |
| - credentials, |
80 |
| - channel_service, |
81 |
| - channel_id, |
82 |
| - auth_configuration, |
83 |
| - ) |
84 |
| - |
85 |
| - if EmulatorValidation.is_token_from_emulator(auth_header): |
86 |
| - return await EmulatorValidation.authenticate_emulator_token( |
87 |
| - auth_header, credentials, channel_service, channel_id |
88 |
| - ) |
89 |
| - |
90 |
| - # If the channel is Public Azure |
91 |
| - if not channel_service: |
92 |
| - if service_url: |
93 |
| - return await ChannelValidation.authenticate_channel_token_with_service_url( |
| 76 | + async def get_claims() -> ClaimsIdentity: |
| 77 | + if SkillValidation.is_skill_token(auth_header): |
| 78 | + return await SkillValidation.authenticate_channel_token( |
94 | 79 | auth_header,
|
95 | 80 | credentials,
|
96 |
| - service_url, |
| 81 | + channel_service, |
97 | 82 | channel_id,
|
98 | 83 | auth_configuration,
|
99 | 84 | )
|
100 | 85 |
|
101 |
| - return await ChannelValidation.authenticate_channel_token( |
102 |
| - auth_header, credentials, channel_id, auth_configuration |
103 |
| - ) |
| 86 | + if EmulatorValidation.is_token_from_emulator(auth_header): |
| 87 | + return await EmulatorValidation.authenticate_emulator_token( |
| 88 | + auth_header, credentials, channel_service, channel_id |
| 89 | + ) |
| 90 | + |
| 91 | + # If the channel is Public Azure |
| 92 | + if not channel_service: |
| 93 | + if service_url: |
| 94 | + return await ChannelValidation.authenticate_channel_token_with_service_url( |
| 95 | + auth_header, |
| 96 | + credentials, |
| 97 | + service_url, |
| 98 | + channel_id, |
| 99 | + auth_configuration, |
| 100 | + ) |
| 101 | + |
| 102 | + return await ChannelValidation.authenticate_channel_token( |
| 103 | + auth_header, credentials, channel_id, auth_configuration |
| 104 | + ) |
104 | 105 |
|
105 |
| - if JwtTokenValidation.is_government(channel_service): |
| 106 | + if JwtTokenValidation.is_government(channel_service): |
| 107 | + if service_url: |
| 108 | + return await GovernmentChannelValidation.authenticate_channel_token_with_service_url( |
| 109 | + auth_header, |
| 110 | + credentials, |
| 111 | + service_url, |
| 112 | + channel_id, |
| 113 | + auth_configuration, |
| 114 | + ) |
| 115 | + |
| 116 | + return await GovernmentChannelValidation.authenticate_channel_token( |
| 117 | + auth_header, credentials, channel_id, auth_configuration |
| 118 | + ) |
| 119 | + |
| 120 | + # Otherwise use Enterprise Channel Validation |
106 | 121 | if service_url:
|
107 |
| - return await GovernmentChannelValidation.authenticate_channel_token_with_service_url( |
| 122 | + return await EnterpriseChannelValidation.authenticate_channel_token_with_service_url( |
108 | 123 | auth_header,
|
109 | 124 | credentials,
|
110 | 125 | service_url,
|
111 | 126 | channel_id,
|
| 127 | + channel_service, |
112 | 128 | auth_configuration,
|
113 | 129 | )
|
114 | 130 |
|
115 |
| - return await GovernmentChannelValidation.authenticate_channel_token( |
116 |
| - auth_header, credentials, channel_id, auth_configuration |
117 |
| - ) |
118 |
| - |
119 |
| - # Otherwise use Enterprise Channel Validation |
120 |
| - if service_url: |
121 |
| - return await EnterpriseChannelValidation.authenticate_channel_token_with_service_url( |
| 131 | + return await EnterpriseChannelValidation.authenticate_channel_token( |
122 | 132 | auth_header,
|
123 | 133 | credentials,
|
124 |
| - service_url, |
125 | 134 | channel_id,
|
126 | 135 | channel_service,
|
127 | 136 | auth_configuration,
|
128 | 137 | )
|
129 | 138 |
|
130 |
| - return await EnterpriseChannelValidation.authenticate_channel_token( |
131 |
| - auth_header, credentials, channel_id, channel_service, auth_configuration |
132 |
| - ) |
| 139 | + claims = await get_claims() |
| 140 | + |
| 141 | + if claims: |
| 142 | + await JwtTokenValidation.validate_claims(auth_configuration, claims.claims) |
| 143 | + |
| 144 | + return claims |
| 145 | + |
| 146 | + @staticmethod |
| 147 | + async def validate_claims( |
| 148 | + auth_config: AuthenticationConfiguration, claims: List[Dict] |
| 149 | + ): |
| 150 | + if auth_config and auth_config.claims_validator: |
| 151 | + await auth_config.claims_validator(claims) |
133 | 152 |
|
134 | 153 | @staticmethod
|
135 | 154 | def is_government(channel_service: str) -> bool:
|
|
0 commit comments